diff --git a/libcst/matchers/_visitors.py b/libcst/matchers/_visitors.py index ded6eb9dc..a314fc4d4 100644 --- a/libcst/matchers/_visitors.py +++ b/libcst/matchers/_visitors.py @@ -45,6 +45,15 @@ ) from libcst.matchers._return_types import TYPED_FUNCTION_RETURN_MAPPING +try: + # PEP 604 unions, in Python 3.10+ + from types import UnionType +except ImportError: + # We use this for isinstance; no annotation will be an instance of this + class UnionType: + pass + + CONCRETE_METHODS: Set[str] = { *{f"visit_{cls.__name__}" for cls in TYPED_FUNCTION_RETURN_MAPPING}, *{f"leave_{cls.__name__}" for cls in TYPED_FUNCTION_RETURN_MAPPING}, @@ -78,18 +87,15 @@ def _get_possible_match_classes(matcher: BaseMatcherNode) -> List[Type[cst.CSTNo return [getattr(cst, matcher.__class__.__name__)] -def _annotation_looks_like_union(annotation: object) -> bool: - if getattr(annotation, "__origin__", None) is Union: - return True - # support PEP-604 style unions introduced in Python 3.10 +def _annotation_is_union(annotation: object) -> bool: return ( - annotation.__class__.__name__ == "Union" - and annotation.__class__.__module__ == "types" + isinstance(annotation, UnionType) + or getattr(annotation, "__origin__", None) is Union ) def _get_possible_annotated_classes(annotation: object) -> List[Type[object]]: - if _annotation_looks_like_union(annotation): + if _annotation_is_union(annotation): return getattr(annotation, "__args__", []) else: return [cast(Type[object], annotation)] diff --git a/libcst/matchers/tests/test_decorators.py b/libcst/matchers/tests/test_decorators.py index 7486cee82..8b28657c3 100644 --- a/libcst/matchers/tests/test_decorators.py +++ b/libcst/matchers/tests/test_decorators.py @@ -3,10 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import sys from ast import literal_eval from textwrap import dedent from typing import List, Set -from unittest.mock import Mock +from unittest import skipIf import libcst as cst import libcst.matchers as m @@ -996,22 +997,14 @@ def bar() -> None: self.assertEqual(visitor.visits, ['"baz"']) -# This is meant to simulate `cst.ImportFrom | cst.RemovalSentinel` in py3.10 -FakeUnionClass: Mock = Mock() -setattr(FakeUnionClass, "__name__", "Union") -setattr(FakeUnionClass, "__module__", "types") -FakeUnion: Mock = Mock() -FakeUnion.__class__ = FakeUnionClass -FakeUnion.__args__ = [cst.ImportFrom, cst.RemovalSentinel] - - class MatchersUnionDecoratorsTest(UnitTest): + @skipIf(bool(sys.version_info < (3, 10)), "new union syntax not available") def test_init_with_new_union_annotation(self) -> None: class TransformerWithUnionReturnAnnotation(m.MatcherDecoratableTransformer): @m.leave(m.ImportFrom(module=m.Name(value="typing"))) def test( self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom - ) -> FakeUnion: + ) -> cst.ImportFrom | cst.RemovalSentinel: pass # assert that init (specifically _check_types on return annotation) passes