From 4d2ccc54b2b68a9059bc0670a472eded1cb767f9 Mon Sep 17 00:00:00 2001 From: Zsolt Dollenstein Date: Wed, 12 May 2021 14:50:15 +0100 Subject: [PATCH] tie accesses from string annotation to the string node (#483) --- libcst/metadata/scope_provider.py | 40 +++++++++-- libcst/metadata/tests/test_scope_provider.py | 72 ++++++++++++++++++++ 2 files changed, 105 insertions(+), 7 deletions(-) diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index 21e7a9e52..3b5d380e9 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -74,9 +74,10 @@ def __new__(cls) -> "Tree": #: The node of the access. A name is an access when the expression context is #: :attr:`ExpressionContext.LOAD`. This is usually the name node representing the - #: access, except for dotted imports, when it might be the attribute that - #: represents the most specific part of the imported symbol. - node: Union[cst.Name, cst.Attribute] + #: access, except for: 1) dotted imports, when it might be the attribute that + #: represents the most specific part of the imported symbol; and 2) string + #: annotations, when it is the entire string literal + node: Union[cst.Name, cst.Attribute, cst.BaseString] #: The scope of the access. Note that a access could be in a child scope of its #: assignment. @@ -422,7 +423,7 @@ def _record_assignment_as_parent(self, name: str, node: cst.CSTNode) -> None: @abc.abstractmethod def __contains__(self, name: str) -> bool: - """ Check if the name str exist in current scope by ``name in scope``. """ + """Check if the name str exist in current scope by ``name in scope``.""" ... @abc.abstractmethod @@ -775,18 +776,26 @@ def _is_assignment(node: cst.CSTNode, assignment_node: cst.CSTNode) -> bool: return False +@dataclass(frozen=True) +class DeferredAccess: + access: Access + enclosing_attribute: Optional[cst.Attribute] + enclosing_string_annotation: Optional[cst.BaseString] + + class ScopeVisitor(cst.CSTVisitor): # since it's probably not useful. That can makes this visitor cleaner. def __init__(self, provider: "ScopeProvider") -> None: self.provider: ScopeProvider = provider self.scope: Scope = GlobalScope() - self.__deferred_accesses: List[Tuple[Access, Optional[cst.Attribute]]] = [] + self.__deferred_accesses: List[DeferredAccess] = [] self.__top_level_attribute_stack: List[Optional[cst.Attribute]] = [None] self.__in_annotation: Set[ Union[cst.Call, cst.Annotation, cst.Subscript] ] = set() self.__in_type_hint: Set[Union[cst.Call, cst.Annotation, cst.Subscript]] = set() self.__in_ignored_subscript: Set[cst.Subscript] = set() + self.__last_string_annotation: Optional[cst.BaseString] = None self.__ignore_annotation: int = 0 @contextmanager @@ -887,8 +896,13 @@ def _handle_string_annotation( ) and not self.__in_ignored_subscript: value = node.evaluated_value if value: + top_level_annotation = self.__last_string_annotation is None + if top_level_annotation: + self.__last_string_annotation = node mod = cst.parse_module(value) mod.visit(self) + if top_level_annotation: + self.__last_string_annotation = None return True return False @@ -920,7 +934,11 @@ def visit_Name(self, node: cst.Name) -> Optional[bool]: is_type_hint=bool(self.__in_type_hint), ) self.__deferred_accesses.append( - (access, self.__top_level_attribute_stack[-1]) + DeferredAccess( + access=access, + enclosing_attribute=self.__top_level_attribute_stack[-1], + enclosing_string_annotation=self.__last_string_annotation, + ) ) def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: @@ -1074,7 +1092,12 @@ def infer_accesses(self) -> None: # In worst case, all accesses (m) and assignments (n) refer to the same name, # the time complexity is O(m x n), this optimizes it as O(m + n). scope_name_accesses = defaultdict(set) - for (access, enclosing_attribute) in self.__deferred_accesses: + for def_access in self.__deferred_accesses: + access, enclosing_attribute, enclosing_string_annotation = ( + def_access.access, + def_access.enclosing_attribute, + def_access.enclosing_string_annotation, + ) name = ensure_type(access.node, cst.Name).value if enclosing_attribute is not None: # if _gen_dotted_names doesn't generate any values, fall back to @@ -1085,6 +1108,9 @@ def infer_accesses(self) -> None: name = attr_name break + if enclosing_string_annotation is not None: + access.node = enclosing_string_annotation + scope_name_accesses[(access.scope, name)].add(access) access.record_assignments(name) access.scope.record_access(name, access) diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index 59a20aec7..8a1bf4b3f 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -1082,6 +1082,10 @@ class Test(Generic[J]): self.assertEqual(len(assignment.references), 1) references = list(assignment.references) self.assertTrue(references[0].is_annotation) + reference_node = references[0].node + self.assertIsInstance(reference_node, cst.SimpleString) + if isinstance(reference_node, cst.SimpleString): + self.assertEqual(reference_node.evaluated_value, "B") assignment = list(scope["C"])[0] self.assertIsInstance(assignment, Assignment) @@ -1104,6 +1108,10 @@ class Test(Generic[J]): references = list(assignment.references) self.assertFalse(references[0].is_annotation) self.assertTrue(references[0].is_type_hint) + reference_node = references[0].node + self.assertIsInstance(reference_node, cst.SimpleString) + if isinstance(reference_node, cst.SimpleString): + self.assertEqual(reference_node.evaluated_value, "E") assignment = list(scope["E2"])[0] self.assertIsInstance(assignment, Assignment) @@ -1119,6 +1127,10 @@ class Test(Generic[J]): references = list(assignment.references) self.assertFalse(references[0].is_annotation) self.assertTrue(references[0].is_type_hint) + reference_node = references[0].node + self.assertIsInstance(reference_node, cst.SimpleString) + if isinstance(reference_node, cst.SimpleString): + self.assertEqual(reference_node.evaluated_value, "Optional[G]") assignment = list(scope["G2"])[0] self.assertIsInstance(assignment, Assignment) @@ -1130,6 +1142,10 @@ class Test(Generic[J]): references = list(assignment.references) self.assertFalse(references[0].is_annotation) self.assertTrue(references[0].is_type_hint) + reference_node = references[0].node + self.assertIsInstance(reference_node, cst.SimpleString) + if isinstance(reference_node, cst.SimpleString): + self.assertEqual(reference_node.evaluated_value, "H") assignment = list(scope["I"])[0] self.assertIsInstance(assignment, Assignment) @@ -1148,6 +1164,10 @@ class Test(Generic[J]): self.assertEqual(len(assignment.references), 1) references = list(assignment.references) self.assertFalse(references[0].is_annotation) + reference_node = references[0].node + self.assertIsInstance(reference_node, cst.SimpleString) + if isinstance(reference_node, cst.SimpleString): + self.assertEqual(reference_node.evaluated_value, "K") assignment = list(scope["K2"])[0] self.assertIsInstance(assignment, Assignment) @@ -1157,12 +1177,64 @@ class Test(Generic[J]): self.assertIsInstance(assignment, Assignment) self.assertEqual(len(assignment.references), 1) references = list(assignment.references) + reference_node = references[0].node + self.assertIsInstance(reference_node, cst.SimpleString) + if isinstance(reference_node, cst.SimpleString): + self.assertEqual(reference_node.evaluated_value, "L") assignment = list(scope["M"])[0] self.assertIsInstance(assignment, Assignment) self.assertEqual(len(assignment.references), 1) references = list(assignment.references) + def test_insane_annotation_access(self) -> None: + m, scopes = get_scope_metadata_provider( + r""" + from typing import TypeVar + from a import G + TypeVar("G2", bound="Optional[\"G\"]") + """ + ) + imp = ensure_type( + ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.ImportFrom + ) + call = ensure_type( + ensure_type( + ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr + ).value, + cst.Call, + ) + bound = call.args[1].value + scope = scopes[imp] + assignment = next(iter(scope["G"])) + self.assertIsInstance(assignment, Assignment) + self.assertEqual(len(assignment.references), 1) + self.assertEqual(list(assignment.references)[0].node, bound) + + def test_dotted_annotation_access(self) -> None: + m, scopes = get_scope_metadata_provider( + r""" + from typing import TypeVar + import a.G + TypeVar("G2", bound="a.G") + """ + ) + imp = ensure_type( + ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Import + ) + call = ensure_type( + ensure_type( + ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr + ).value, + cst.Call, + ) + bound = call.args[1].value + scope = scopes[imp] + assignment = next(iter(scope["a.G"])) + self.assertIsInstance(assignment, Assignment) + self.assertEqual(len(assignment.references), 1) + self.assertEqual(list(assignment.references)[0].node, bound) + def test_node_of_scopes(self) -> None: m, scopes = get_scope_metadata_provider( """