From 6a4fecf8bea39cbced03b0b07da490046fefdc87 Mon Sep 17 00:00:00 2001 From: Zsolt Dollenstein Date: Sat, 9 Sep 2023 21:30:47 +0100 Subject: [PATCH] Scope provider changes for type annotations --- libcst/helpers/common.py | 6 +- libcst/metadata/scope_provider.py | 96 +++++++-- libcst/metadata/tests/test_scope_provider.py | 209 +++++++++++++++++++ 3 files changed, 291 insertions(+), 20 deletions(-) diff --git a/libcst/helpers/common.py b/libcst/helpers/common.py index 0965abebb..16c77669a 100644 --- a/libcst/helpers/common.py +++ b/libcst/helpers/common.py @@ -3,12 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # -from typing import Type +from typing import Type, TypeVar -from libcst._types import CSTNodeT +T = TypeVar("T") -def ensure_type(node: object, nodetype: Type[CSTNodeT]) -> CSTNodeT: +def ensure_type(node: object, nodetype: Type[T]) -> T: """ Takes any python object, and a LibCST :class:`~libcst.CSTNode` subclass and refines the type of the python object. This is most useful when you already diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index 4268c5d4d..e3327adb3 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -7,7 +7,7 @@ import abc import builtins from collections import defaultdict -from contextlib import contextmanager +from contextlib import contextmanager, ExitStack from dataclasses import dataclass from enum import auto, Enum from typing import ( @@ -33,6 +33,7 @@ ExpressionContext, ExpressionContextProvider, ) +from libcst.metadata.position_provider import PositionProvider # Comprehensions are handled separately in _visit_comp_alike due to # the complexity of the semantics @@ -51,6 +52,10 @@ cst.Nonlocal, cst.Parameters, cst.WithItem, + cst.TypeVar, + cst.TypeAlias, + cst.TypeVarTuple, + cst.ParamSpec, ) @@ -755,6 +760,23 @@ def _make_name_prefix(self) -> str: return ".".join(filter(None, [self.parent._name_prefix, ""])) +class AnnotationScope(LocalScope): + """ + Scopes used for type aliases and type parameters as defined by PEP-695. + + These scopes are created for type parameters using the special syntax, as well as + type aliases. See https://peps.python.org/pep-0695/#scoping-behavior for more. + """ + + def _make_name_prefix(self) -> str: + # these scopes are transparent for the purposes of qualified names + return self.parent._name_prefix + + def _next_visible_parent(self, first: Optional[Scope] = None) -> "Scope": + # ignore _is_visible_from_children explicitly + return first if first is not None else self.parent + + # Generates dotted names from an Attribute or Name node: # Attribute(value=Name(value="a"), attr=Name(value="b")) -> ("a.b", "a") # each string has the corresponding CSTNode attached to it @@ -822,6 +844,7 @@ class DeferredAccess: class ScopeVisitor(cst.CSTVisitor): # since it's probably not useful. That can makes this visitor cleaner. def __init__(self, provider: "ScopeProvider") -> None: + super().__init__() self.provider: ScopeProvider = provider self.scope: Scope = GlobalScope() self.__deferred_accesses: List[DeferredAccess] = [] @@ -992,15 +1015,22 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: self.scope.record_assignment(node.name.value, node) self.provider.set_metadata(node.name, self.scope) - with self._new_scope(FunctionScope, node, get_full_name_for_node(node.name)): - node.params.visit(self) - node.body.visit(self) + with ExitStack() as stack: + if node.type_parameters: + stack.enter_context(self._new_scope(AnnotationScope, node, None)) + node.type_parameters.visit(self) - for decorator in node.decorators: - decorator.visit(self) - returns = node.returns - if returns: - returns.visit(self) + with self._new_scope( + FunctionScope, node, get_full_name_for_node(node.name) + ): + node.params.visit(self) + node.body.visit(self) + + for decorator in node.decorators: + decorator.visit(self) + returns = node.returns + if returns: + returns.visit(self) return False @@ -1032,14 +1062,20 @@ def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: self.provider.set_metadata(node.name, self.scope) for decorator in node.decorators: decorator.visit(self) - for base in node.bases: - base.visit(self) - for keyword in node.keywords: - keyword.visit(self) - - with self._new_scope(ClassScope, node, get_full_name_for_node(node.name)): - for statement in node.body.body: - statement.visit(self) + + with ExitStack() as stack: + if node.type_parameters: + stack.enter_context(self._new_scope(AnnotationScope, node, None)) + node.type_parameters.visit(self) + + for base in node.bases: + base.visit(self) + for keyword in node.keywords: + keyword.visit(self) + + with self._new_scope(ClassScope, node, get_full_name_for_node(node.name)): + for statement in node.body.body: + statement.visit(self) return False def visit_ClassDef_bases(self, node: cst.ClassDef) -> None: @@ -1174,6 +1210,32 @@ def on_leave(self, original_node: cst.CSTNode) -> None: self.scope._assignment_count += 1 super().on_leave(original_node) + def visit_TypeAlias(self, node: cst.TypeAlias) -> Optional[bool]: + self.scope.record_assignment(node.name.value, node) + + with self._new_scope(AnnotationScope, node, None): + if node.type_parameters is not None: + node.type_parameters.visit(self) + node.value.visit(self) + + return False + + def visit_TypeVar(self, node: cst.TypeVar) -> Optional[bool]: + self.scope.record_assignment(node.name.value, node) + + if node.bound is not None: + node.bound.visit(self) + + return False + + def visit_TypeVarTuple(self, node: cst.TypeVarTuple) -> Optional[bool]: + self.scope.record_assignment(node.name.value, node) + return False + + def visit_ParamSpec(self, node: cst.ParamSpec) -> Optional[bool]: + self.scope.record_assignment(node.name.value, node) + return False + class ScopeProvider(BatchableMetadataProvider[Optional[Scope]]): """ diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index 9908cb4cd..129b9832a 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -12,8 +12,10 @@ import libcst as cst from libcst import ensure_type from libcst.metadata import MetadataWrapper +from libcst.metadata.position_provider import PositionProvider from libcst.metadata.scope_provider import ( _gen_dotted_names, + AnnotationScope, Assignment, BuiltinAssignment, BuiltinScope, @@ -1982,3 +1984,210 @@ def something(): scope.get_qualified_names_for(cst.Name("something_else")), set(), ) + + def test_type_alias_scope(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + type A = C + lol: A + """ + ) + alias = ensure_type( + ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.TypeAlias + ) + self.assertIsInstance(scopes[alias], GlobalScope) + a_assignments = list(scopes[alias]["A"]) + self.assertEqual(len(a_assignments), 1) + lol = ensure_type( + ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.AnnAssign + ) + self.assertEqual(len(a_references := list(a_assignments[0].references)), 1) + self.assertEqual(a_references[0].node, lol.annotation.annotation) + + self.assertIsInstance(scopes[alias.value], AnnotationScope) + + def test_type_alias_param(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + B = int + type A[T: B] = T + lol: T + """ + ) + alias = ensure_type( + ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.TypeAlias + ) + assert alias.type_parameters + param_scope = scopes[alias.type_parameters] + self.assertEqual(len(t_assignments := list(param_scope["T"])), 1) + self.assertEqual(len(t_refs := list(t_assignments[0].references)), 1) + self.assertEqual(t_refs[0].node, alias.value) + + b = ( + ensure_type( + ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.Assign + ) + .targets[0] + .target + ) + b_assignment = list(scopes[b]["B"])[0] + self.assertEqual( + {ref.node for ref in b_assignment.references}, + {ensure_type(alias.type_parameters.params[0].param, cst.TypeVar).bound}, + ) + + def test_type_alias_tuple_and_paramspec(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + type A[*T] = T + lol: T + type A[**T] = T + lol: T + """ + ) + alias_tuple = ensure_type( + ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.TypeAlias + ) + assert alias_tuple.type_parameters + param_scope = scopes[alias_tuple.type_parameters] + self.assertEqual(len(t_assignments := list(param_scope["T"])), 1) + self.assertEqual(len(t_refs := list(t_assignments[0].references)), 1) + self.assertEqual(t_refs[0].node, alias_tuple.value) + + alias_paramspec = ensure_type( + ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.TypeAlias + ) + assert alias_paramspec.type_parameters + param_scope = scopes[alias_paramspec.type_parameters] + self.assertEqual(len(t_assignments := list(param_scope["T"])), 1) + self.assertEqual(len(t_refs := list(t_assignments[0].references)), 1) + self.assertEqual(t_refs[0].node, alias_paramspec.value) + + def test_class_type_params(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + class W[T]: + def f() -> T: pass + def g[T]() -> T: pass + """ + ) + cls = ensure_type(m.body[0], cst.ClassDef) + cls_scope = scopes[cls.body.body[0]] + self.assertEqual(len(t_assignments_in_cls := list(cls_scope["T"])), 1) + assert cls.type_parameters + self.assertEqual( + ensure_type(t_assignments_in_cls[0], Assignment).node, + cls.type_parameters.params[0].param, + ) + self.assertEqual( + len(t_refs_in_cls := list(t_assignments_in_cls[0].references)), 1 + ) + f = ensure_type(cls.body.body[0], cst.FunctionDef) + assert f.returns + self.assertEqual(t_refs_in_cls[0].node, f.returns.annotation) + + g = ensure_type(cls.body.body[1], cst.FunctionDef) + assert g.type_parameters + assert g.returns + self.assertEqual(len(t_assignments_in_g := list(scopes[g.body]["T"])), 1) + self.assertEqual( + ensure_type(t_assignments_in_g[0], Assignment).node, + g.type_parameters.params[0].param, + ) + self.assertEqual(len(t_refs_in_g := list(t_assignments_in_g[0].references)), 1) + self.assertEqual(t_refs_in_g[0].node, g.returns.annotation) + + def test_nested_class_type_params(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + class Outer: + class Nested[T: Outer]: pass + """ + ) + outer = ensure_type(m.body[0], cst.ClassDef) + outer_refs = list(list(scopes[outer]["Outer"])[0].references) + self.assertEqual(len(outer_refs), 1) + inner = ensure_type(outer.body.body[0], cst.ClassDef) + assert inner.type_parameters + self.assertEqual( + outer_refs[0].node, + ensure_type(inner.type_parameters.params[0].param, cst.TypeVar).bound, + ) + + def test_annotation_refers_to_nested_class(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + class Outer: + class Nested: + pass + + type Alias = Nested + + def meth1[T: Nested](self): pass + def meth2[T](self, arg: Nested): pass + """ + ) + outer = ensure_type(m.body[0], cst.ClassDef) + nested = ensure_type(outer.body.body[0], cst.ClassDef) + alias = ensure_type( + ensure_type(outer.body.body[1], cst.SimpleStatementLine).body[0], + cst.TypeAlias, + ) + self.assertIsInstance(scopes[alias.value], AnnotationScope) + nested_refs_within_alias = list(scopes[alias.value].accesses["Nested"]) + self.assertEqual(len(nested_refs_within_alias), 1) + self.assertEqual( + { + ensure_type(ref, Assignment).node + for ref in nested_refs_within_alias[0].referents + }, + {nested}, + ) + + meth1 = ensure_type(outer.body.body[2], cst.FunctionDef) + self.assertIsInstance(scopes[meth1], ClassScope) + assert meth1.type_parameters + meth1_typevar = ensure_type(meth1.type_parameters.params[0].param, cst.TypeVar) + meth1_typevar_scope = scopes[meth1_typevar] + self.assertIsInstance(meth1_typevar_scope, AnnotationScope) + nested_refs_within_meth1 = list(meth1_typevar_scope.accesses["Nested"]) + self.assertEqual(len(nested_refs_within_meth1), 1) + self.assertEqual( + { + ensure_type(ref, Assignment).node + for ref in nested_refs_within_meth1[0].referents + }, + {nested}, + ) + + meth2 = ensure_type(outer.body.body[3], cst.FunctionDef) + meth2_annotation = meth2.params.params[1].annotation + assert meth2_annotation + nested_refs_within_meth2 = list(scopes[meth2_annotation].accesses["Nested"]) + self.assertEqual(len(nested_refs_within_meth2), 1) + self.assertEqual( + { + ensure_type(ref, Assignment).node + for ref in nested_refs_within_meth2[0].referents + }, + {nested}, + ) + + def test_body_isnt_subject_to_special_annotation_rule(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + class Outer: + class Inner: pass + def f[T: Inner](self): Inner() + """ + ) + outer = ensure_type(m.body[0], cst.ClassDef) + inner = ensure_type(outer.body.body[0], cst.ClassDef) + # note: this is different from global scope + outer_scope = scopes[outer.body.body[0]] + inner_assignment = list(outer_scope["Inner"])[0] + self.skipTest( + "f's body is subject to the annotation scoping rules" + " created by its type parameters, but it shouldn't be" + ) + self.assertEqual(len(inner_assignment.references), 1)