From 468ccc5c92fed95a6775eaf8d667cf6214265b15 Mon Sep 17 00:00:00 2001 From: Zsolt Dollenstein Date: Sat, 9 Sep 2023 21:30:47 +0100 Subject: [PATCH] Scope provider changes for type annotations --- libcst/metadata/scope_provider.py | 40 ++++++++++++++++++++ libcst/metadata/tests/test_scope_provider.py | 26 +++++++++++++ 2 files changed, 66 insertions(+) diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index 4268c5d4d..b5ce73a89 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -33,6 +33,7 @@ ExpressionContext, ExpressionContextProvider, ) +from libcst.metadata.position_provider import PositionProvider # Comprehensions are handled separately in _visit_comp_alike due to # the complexity of the semantics @@ -754,6 +755,21 @@ def _make_name_prefix(self) -> str: # filter falsey strings out return ".".join(filter(None, [self.parent._name_prefix, ""])) +class AnnotationScope(LocalScope): + """ + Scopes used for type aliases and type parameters as defined by PEP-695. + + These scopes are created for type parameters using the special syntax, as well as + type aliases. See https://peps.python.org/pep-0695/#scoping-behavior for more. + """ + + def _make_name_prefix(self) -> str: + # these scopes are transparent for the purposes of qualified names + return self.parent._name_prefix + + def _next_visible_parent(self, first: Optional[Scope] = None) -> "Scope": + # ignore _is_visible_from_children explicitly + return first if first is not None else self.parent # Generates dotted names from an Attribute or Name node: # Attribute(value=Name(value="a"), attr=Name(value="b")) -> ("a.b", "a") @@ -820,6 +836,7 @@ class DeferredAccess: class ScopeVisitor(cst.CSTVisitor): + METADATA_DEPENDENCIES = (PositionProvider, ) # since it's probably not useful. That can makes this visitor cleaner. def __init__(self, provider: "ScopeProvider") -> None: self.provider: ScopeProvider = provider @@ -1146,6 +1163,8 @@ def infer_accesses(self) -> None: def_access.enclosing_string_annotation, ) name = ensure_type(access.node, cst.Name).value + if name == "Nested": + breakpoint() if enclosing_attribute is not None: # if _gen_dotted_names doesn't generate any values, fall back to # the original name node above @@ -1174,6 +1193,27 @@ def on_leave(self, original_node: cst.CSTNode) -> None: self.scope._assignment_count += 1 super().on_leave(original_node) + def visit_TypeAlias(self, node: cst.TypeAlias) -> Optional[bool]: + self.scope.record_assignment(node.name.value, node) + + with self._new_scope(AnnotationScope, node, None): + if node.type_parameters is not None: + node.type_parameters.visit(self) + node.value.visit(self) + + return False + + def visit_TypeVar(self, node: cst.TypeVar) -> Optional[bool]: + self.scope.record_assignment(node.name.value, node) + + if node.bound is not None: + with self._new_scope(AnnotationScope, node, None): + node.bound.visit(self) + + return False + + + class ScopeProvider(BatchableMetadataProvider[Optional[Scope]]): """ diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index 9908cb4cd..51b6f2aaf 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -27,6 +27,7 @@ QualifiedNameSource, Scope, ScopeProvider, + AnnotationScope, ) from libcst.testing.utils import data_provider, UnitTest @@ -1982,3 +1983,28 @@ def something(): scope.get_qualified_names_for(cst.Name("something_else")), set(), ) + + def test_annotation_refers_to_nested_class(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + class Outer: + class Nested: + pass + + type Alias = Nested + + def meth1[T: Nested](self): pass + def meth2[T](self, arg: Nested): pass + """ + ) + outer = ensure_type(m.body[0], cst.ClassDef) + nested = ensure_type(outer.body.body[0], cst.ClassDef) + alias = ensure_type(ensure_type(outer.body.body[1], cst.SimpleStatementLine).body[0], cst.TypeAlias) + self.assertIsInstance(scopes[alias.value], AnnotationScope) + nested_refs_within_alias = list(scopes[alias.value].accesses["Nested"]) + self.assertEqual(len(nested_refs_within_alias), 1) + self.assertEqual(nested_refs_within_alias[0].referents, {nested.name}) + + meth1 = ensure_type(outer.body.body[2], cst.FunctionDef) + meth1_scope = scopes[meth1] + self.assertIsInstance(meth1_scope, AnnotationScope)