Skip to content

Commit

Permalink
fix PEP 604 union annotations in decorators (#828)
Browse files Browse the repository at this point in the history
  • Loading branch information
Carl Meyer authored Nov 29, 2022
1 parent 987aff6 commit f668e88
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 18 deletions.
20 changes: 13 additions & 7 deletions libcst/matchers/_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@
)
from libcst.matchers._return_types import TYPED_FUNCTION_RETURN_MAPPING

try:
# PEP 604 unions, in Python 3.10+
from types import UnionType
except ImportError:
# We use this for isinstance; no annotation will be an instance of this
class UnionType:
pass


CONCRETE_METHODS: Set[str] = {
*{f"visit_{cls.__name__}" for cls in TYPED_FUNCTION_RETURN_MAPPING},
*{f"leave_{cls.__name__}" for cls in TYPED_FUNCTION_RETURN_MAPPING},
Expand Down Expand Up @@ -78,18 +87,15 @@ def _get_possible_match_classes(matcher: BaseMatcherNode) -> List[Type[cst.CSTNo
return [getattr(cst, matcher.__class__.__name__)]


def _annotation_looks_like_union(annotation: object) -> bool:
if getattr(annotation, "__origin__", None) is Union:
return True
# support PEP-604 style unions introduced in Python 3.10
def _annotation_is_union(annotation: object) -> bool:
return (
annotation.__class__.__name__ == "Union"
and annotation.__class__.__module__ == "types"
isinstance(annotation, UnionType)
or getattr(annotation, "__origin__", None) is Union
)


def _get_possible_annotated_classes(annotation: object) -> List[Type[object]]:
if _annotation_looks_like_union(annotation):
if _annotation_is_union(annotation):
return getattr(annotation, "__args__", [])
else:
return [cast(Type[object], annotation)]
Expand Down
15 changes: 4 additions & 11 deletions libcst/matchers/tests/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys
from ast import literal_eval
from textwrap import dedent
from typing import List, Set
from unittest.mock import Mock
from unittest import skipIf

import libcst as cst
import libcst.matchers as m
Expand Down Expand Up @@ -996,22 +997,14 @@ def bar() -> None:
self.assertEqual(visitor.visits, ['"baz"'])


# This is meant to simulate `cst.ImportFrom | cst.RemovalSentinel` in py3.10
FakeUnionClass: Mock = Mock()
setattr(FakeUnionClass, "__name__", "Union")
setattr(FakeUnionClass, "__module__", "types")
FakeUnion: Mock = Mock()
FakeUnion.__class__ = FakeUnionClass
FakeUnion.__args__ = [cst.ImportFrom, cst.RemovalSentinel]


class MatchersUnionDecoratorsTest(UnitTest):
@skipIf(bool(sys.version_info < (3, 10)), "new union syntax not available")
def test_init_with_new_union_annotation(self) -> None:
class TransformerWithUnionReturnAnnotation(m.MatcherDecoratableTransformer):
@m.leave(m.ImportFrom(module=m.Name(value="typing")))
def test(
self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
) -> FakeUnion:
) -> cst.ImportFrom | cst.RemovalSentinel:
pass

# assert that init (specifically _check_types on return annotation) passes
Expand Down

0 comments on commit f668e88

Please sign in to comment.