From 21d37b94b253876511a8e0be65b9088129c1f787 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20M=C3=A9ndez=20Bravo?= Date: Mon, 12 Oct 2020 11:12:28 -0700 Subject: [PATCH] Support string annotations for type aliases (#401) --- libcst/_parser/types/tests/test_config.py | 2 +- libcst/metadata/scope_provider.py | 26 ++++++++++++-------- libcst/metadata/tests/test_scope_provider.py | 18 ++++++++++++-- libcst/tests/test_type_enforce.py | 1 + 4 files changed, 34 insertions(+), 13 deletions(-) diff --git a/libcst/_parser/types/tests/test_config.py b/libcst/_parser/types/tests/test_config.py index 6c0c0d0b7..8b68bd182 100644 --- a/libcst/_parser/types/tests/test_config.py +++ b/libcst/_parser/types/tests/test_config.py @@ -12,7 +12,7 @@ class TestConfig(UnitTest): @data_provider( { - "empty": (lambda: PartialParserConfig(),), + "empty": (PartialParserConfig,), "python_version_a": (lambda: PartialParserConfig(python_version="3.7"),), "python_version_b": (lambda: PartialParserConfig(python_version="3.7.1"),), "encoding": (lambda: PartialParserConfig(encoding="latin-1"),), diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index e1c7f1961..f0cde1dc1 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -643,7 +643,10 @@ def __init__(self, provider: "ScopeProvider") -> None: self.scope: Scope = GlobalScope() self.__deferred_accesses: List[Tuple[Access, Optional[cst.Attribute]]] = [] self.__top_level_attribute_stack: List[Optional[cst.Attribute]] = [None] - self.__in_annotation: Set[Union[cst.Call, cst.Annotation]] = set() + self.__in_annotation: Set[ + Union[cst.Call, cst.Annotation, cst.Subscript] + ] = set() + self.__in_ignored_subscript: Set[cst.Subscript] = set() @contextmanager def _new_scope( @@ -699,10 +702,8 @@ def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]: def visit_Call(self, node: cst.Call) -> Optional[bool]: self.__top_level_attribute_stack.append(None) - if any( - qn.name == "typing.TypeVar" - for qn in self.scope.get_qualified_names_for(node) - ): + qnames = self.scope.get_qualified_names_for(node) + if any(qn.name in {"typing.NewType", "typing.TypeVar"} for qn in qnames): node.func.visit(self) self.__in_annotation.add(node) for arg in node.args[1:]: @@ -731,21 +732,26 @@ def visit_ConcatenatedString(self, node: cst.ConcatenatedString) -> Optional[boo def _handle_string_annotation( self, node: Union[cst.SimpleString, cst.ConcatenatedString] ) -> None: - if self.__in_annotation: + if self.__in_annotation and not self.__in_ignored_subscript: value = node.evaluated_value if value: mod = cst.parse_module(value) mod.visit(self) def visit_Subscript(self, node: cst.Subscript) -> Optional[bool]: + qnames = self.scope.get_qualified_names_for(node.value) + if any(qn.name.startswith(("typing.", "typing_extensions.")) for qn in qnames): + self.__in_annotation.add(node) if any( - qn.name in ("typing.Literal", "typing_extensions.Literal") - for qn in self.scope.get_qualified_names_for(node.value) + qn.name in {"typing.Literal", "typing_extensions.Literal"} for qn in qnames ): - node.value.visit(self) - return False + self.__in_ignored_subscript.add(node) return True + def leave_Subscript(self, original_node: cst.Subscript) -> None: + self.__in_annotation.discard(original_node) + self.__in_ignored_subscript.discard(original_node) + def visit_Name(self, node: cst.Name) -> Optional[bool]: # not all Name have ExpressionContext context = self.provider.get_metadata(ExpressionContextProvider, node, None) diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index d6337f3aa..228fb2768 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -1018,8 +1018,8 @@ def g(): def test_annotation_access(self) -> None: m, scopes = get_scope_metadata_provider( """ - from typing import Literal, TypeVar - from a import A, B, C, D, E, F + from typing import Literal, NewType, Optional, TypeVar + from a import A, B, C, D, E, F, G, H def x(a: A): pass def y(b: "B"): @@ -1029,6 +1029,8 @@ def z(c: Literal["C"]): DType = TypeVar("DType", bound=D) EType = TypeVar("EType", bound="E") FType = TypeVar("F") + GType = NewType("GType", "Optional[G]") + HType = Optional["H"] """ ) imp = ensure_type( @@ -1068,6 +1070,18 @@ def z(c: Literal["C"]): self.assertIsInstance(assignment, Assignment) self.assertEqual(len(assignment.references), 0) + assignment = list(scope["G"])[0] + self.assertIsInstance(assignment, Assignment) + self.assertEqual(len(assignment.references), 1) + references = list(assignment.references) + self.assertTrue(references[0].is_annotation) + + assignment = list(scope["H"])[0] + self.assertIsInstance(assignment, Assignment) + self.assertEqual(len(assignment.references), 1) + references = list(assignment.references) + self.assertTrue(references[0].is_annotation) + def test_node_of_scopes(self) -> None: m, scopes = get_scope_metadata_provider( """ diff --git a/libcst/tests/test_type_enforce.py b/libcst/tests/test_type_enforce.py index 0779ec370..edc283e5b 100644 --- a/libcst/tests/test_type_enforce.py +++ b/libcst/tests/test_type_enforce.py @@ -53,6 +53,7 @@ class MyExampleClassWithMetaclass(metaclass=MyExampleMetaclass): pass +# lint-ignore: NoNamedTupleRule class NamedTupleSubclass(NamedTuple): a: str b: int