diff --git a/libcst/helpers/common.py b/libcst/helpers/common.py index 0965abebb..16c77669a 100644 --- a/libcst/helpers/common.py +++ b/libcst/helpers/common.py @@ -3,12 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # -from typing import Type +from typing import Type, TypeVar -from libcst._types import CSTNodeT +T = TypeVar("T") -def ensure_type(node: object, nodetype: Type[CSTNodeT]) -> CSTNodeT: +def ensure_type(node: object, nodetype: Type[T]) -> T: """ Takes any python object, and a LibCST :class:`~libcst.CSTNode` subclass and refines the type of the python object. This is most useful when you already diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index 4268c5d4d..73bb61f56 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -7,7 +7,7 @@ import abc import builtins from collections import defaultdict -from contextlib import contextmanager +from contextlib import contextmanager, ExitStack from dataclasses import dataclass from enum import auto, Enum from typing import ( @@ -51,6 +51,10 @@ cst.Nonlocal, cst.Parameters, cst.WithItem, + cst.TypeVar, + cst.TypeAlias, + cst.TypeVarTuple, + cst.ParamSpec, ) @@ -116,7 +120,7 @@ def record_assignment(self, assignment: "BaseAssignment") -> None: self.__assignments.add(assignment) def record_assignments(self, name: str) -> None: - assignments = self.scope[name] + assignments = self.scope._resolve_scope_for_access(name, self.scope) # filter out assignments that happened later than this access previous_assignments = { assignment @@ -124,7 +128,9 @@ def record_assignments(self, name: str) -> None: if assignment.scope != self.scope or assignment._index < self.__index } if not previous_assignments and assignments and self.scope.parent != self.scope: - previous_assignments = self.scope.parent[name] + previous_assignments = self.scope.parent._resolve_scope_for_access( + name, self.scope + ) self.__assignments |= previous_assignments @@ -440,7 +446,7 @@ def record_access(self, name: str, access: Access) -> None: self._accesses_by_name[name].add(access) self._accesses_by_node[access.node].add(access) - def _is_visible_from_children(self) -> bool: + def _is_visible_from_children(self, from_scope: "Scope") -> bool: """Returns if the assignments in this scope can be accessed from children. This is normally True, except for class scopes:: @@ -459,9 +465,11 @@ def inner_fn(): """ return True - def _next_visible_parent(self, first: Optional["Scope"] = None) -> "Scope": + def _next_visible_parent( + self, from_scope: "Scope", first: Optional["Scope"] = None + ) -> "Scope": parent = first if first is not None else self.parent - while not parent._is_visible_from_children(): + while not parent._is_visible_from_children(from_scope): parent = parent.parent return parent @@ -470,7 +478,6 @@ def __contains__(self, name: str) -> bool: """Check if the name str exist in current scope by ``name in scope``.""" ... - @abc.abstractmethod def __getitem__(self, name: str) -> Set[BaseAssignment]: """ Get assignments given a name str by ``scope[name]``. @@ -508,6 +515,12 @@ def __getitem__(self, name: str) -> Set[BaseAssignment]: defined a given name by the time a piece of code is executed. For the above example, value would resolve to a set of both assignments. """ + return self._resolve_scope_for_access(name, self) + + @abc.abstractmethod + def _resolve_scope_for_access( + self, name: str, from_scope: "Scope" + ) -> Set[BaseAssignment]: ... def __hash__(self) -> int: @@ -612,7 +625,9 @@ def __init__(self, globals: Scope) -> None: def __contains__(self, name: str) -> bool: return hasattr(builtins, name) - def __getitem__(self, name: str) -> Set[BaseAssignment]: + def _resolve_scope_for_access( + self, name: str, from_scope: "Scope" + ) -> Set[BaseAssignment]: if name in self._assignments: return self._assignments[name] if hasattr(builtins, name): @@ -644,13 +659,15 @@ def __init__(self) -> None: def __contains__(self, name: str) -> bool: if name in self._assignments: return len(self._assignments[name]) > 0 - return name in self._next_visible_parent() + return name in self._next_visible_parent(self) - def __getitem__(self, name: str) -> Set[BaseAssignment]: + def _resolve_scope_for_access( + self, name: str, from_scope: "Scope" + ) -> Set[BaseAssignment]: if name in self._assignments: return self._assignments[name] - parent = self._next_visible_parent() + parent = self._next_visible_parent(from_scope) return parent[name] def record_global_overwrite(self, name: str) -> None: @@ -688,7 +705,7 @@ def record_nonlocal_overwrite(self, name: str) -> None: def _find_assignment_target(self, name: str) -> "Scope": if name in self._scope_overwrites: scope = self._scope_overwrites[name] - return self._next_visible_parent(scope)._find_assignment_target(name) + return self._next_visible_parent(self, scope)._find_assignment_target(name) else: return super()._find_assignment_target(name) @@ -697,16 +714,22 @@ def __contains__(self, name: str) -> bool: return name in self._scope_overwrites[name] if name in self._assignments: return len(self._assignments[name]) > 0 - return name in self._next_visible_parent() + return name in self._next_visible_parent(self) - def __getitem__(self, name: str) -> Set[BaseAssignment]: + def _resolve_scope_for_access( + self, name: str, from_scope: "Scope" + ) -> Set[BaseAssignment]: if name in self._scope_overwrites: scope = self._scope_overwrites[name] - return self._next_visible_parent(scope)[name] + return self._next_visible_parent( + from_scope, scope + )._resolve_scope_for_access(name, from_scope) if name in self._assignments: return self._assignments[name] else: - return self._next_visible_parent()[name] + return self._next_visible_parent(from_scope)._resolve_scope_for_access( + name, from_scope + ) def _make_name_prefix(self) -> str: # filter falsey strings out @@ -728,8 +751,8 @@ class ClassScope(LocalScope): When a class is defined, it creates a ClassScope. """ - def _is_visible_from_children(self) -> bool: - return False + def _is_visible_from_children(self, from_scope: "Scope") -> bool: + return from_scope.parent is self and isinstance(from_scope, AnnotationScope) def _make_name_prefix(self) -> str: # filter falsey strings out @@ -755,6 +778,19 @@ def _make_name_prefix(self) -> str: return ".".join(filter(None, [self.parent._name_prefix, ""])) +class AnnotationScope(LocalScope): + """ + Scopes used for type aliases and type parameters as defined by PEP-695. + + These scopes are created for type parameters using the special syntax, as well as + type aliases. See https://peps.python.org/pep-0695/#scoping-behavior for more. + """ + + def _make_name_prefix(self) -> str: + # these scopes are transparent for the purposes of qualified names + return self.parent._name_prefix + + # Generates dotted names from an Attribute or Name node: # Attribute(value=Name(value="a"), attr=Name(value="b")) -> ("a.b", "a") # each string has the corresponding CSTNode attached to it @@ -822,6 +858,7 @@ class DeferredAccess: class ScopeVisitor(cst.CSTVisitor): # since it's probably not useful. That can makes this visitor cleaner. def __init__(self, provider: "ScopeProvider") -> None: + super().__init__() self.provider: ScopeProvider = provider self.scope: Scope = GlobalScope() self.__deferred_accesses: List[DeferredAccess] = [] @@ -992,15 +1029,22 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: self.scope.record_assignment(node.name.value, node) self.provider.set_metadata(node.name, self.scope) - with self._new_scope(FunctionScope, node, get_full_name_for_node(node.name)): - node.params.visit(self) - node.body.visit(self) + with ExitStack() as stack: + if node.type_parameters: + stack.enter_context(self._new_scope(AnnotationScope, node, None)) + node.type_parameters.visit(self) - for decorator in node.decorators: - decorator.visit(self) - returns = node.returns - if returns: - returns.visit(self) + with self._new_scope( + FunctionScope, node, get_full_name_for_node(node.name) + ): + node.params.visit(self) + node.body.visit(self) + + for decorator in node.decorators: + decorator.visit(self) + returns = node.returns + if returns: + returns.visit(self) return False @@ -1032,14 +1076,20 @@ def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: self.provider.set_metadata(node.name, self.scope) for decorator in node.decorators: decorator.visit(self) - for base in node.bases: - base.visit(self) - for keyword in node.keywords: - keyword.visit(self) - - with self._new_scope(ClassScope, node, get_full_name_for_node(node.name)): - for statement in node.body.body: - statement.visit(self) + + with ExitStack() as stack: + if node.type_parameters: + stack.enter_context(self._new_scope(AnnotationScope, node, None)) + node.type_parameters.visit(self) + + for base in node.bases: + base.visit(self) + for keyword in node.keywords: + keyword.visit(self) + + with self._new_scope(ClassScope, node, get_full_name_for_node(node.name)): + for statement in node.body.body: + statement.visit(self) return False def visit_ClassDef_bases(self, node: cst.ClassDef) -> None: @@ -1163,7 +1213,7 @@ def infer_accesses(self) -> None: access.scope.record_access(name, access) for (scope, name), accesses in scope_name_accesses.items(): - for assignment in scope[name]: + for assignment in scope._resolve_scope_for_access(name, scope): assignment.record_accesses(accesses) self.__deferred_accesses = [] @@ -1174,6 +1224,32 @@ def on_leave(self, original_node: cst.CSTNode) -> None: self.scope._assignment_count += 1 super().on_leave(original_node) + def visit_TypeAlias(self, node: cst.TypeAlias) -> Optional[bool]: + self.scope.record_assignment(node.name.value, node) + + with self._new_scope(AnnotationScope, node, None): + if node.type_parameters is not None: + node.type_parameters.visit(self) + node.value.visit(self) + + return False + + def visit_TypeVar(self, node: cst.TypeVar) -> Optional[bool]: + self.scope.record_assignment(node.name.value, node) + + if node.bound is not None: + node.bound.visit(self) + + return False + + def visit_TypeVarTuple(self, node: cst.TypeVarTuple) -> Optional[bool]: + self.scope.record_assignment(node.name.value, node) + return False + + def visit_ParamSpec(self, node: cst.ParamSpec) -> Optional[bool]: + self.scope.record_assignment(node.name.value, node) + return False + class ScopeProvider(BatchableMetadataProvider[Optional[Scope]]): """ diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index 9908cb4cd..5f6d485bb 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -11,9 +11,11 @@ import libcst as cst from libcst import ensure_type +from libcst._parser.entrypoints import is_native from libcst.metadata import MetadataWrapper from libcst.metadata.scope_provider import ( _gen_dotted_names, + AnnotationScope, Assignment, BuiltinAssignment, BuiltinScope, @@ -1982,3 +1984,228 @@ def something(): scope.get_qualified_names_for(cst.Name("something_else")), set(), ) + + def test_type_alias_scope(self) -> None: + if not is_native(): + self.skipTest("type aliases are only supported in the native parser") + m, scopes = get_scope_metadata_provider( + """ + type A = C + lol: A + """ + ) + alias = ensure_type( + ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.TypeAlias + ) + self.assertIsInstance(scopes[alias], GlobalScope) + a_assignments = list(scopes[alias]["A"]) + self.assertEqual(len(a_assignments), 1) + lol = ensure_type( + ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.AnnAssign + ) + self.assertEqual(len(a_references := list(a_assignments[0].references)), 1) + self.assertEqual(a_references[0].node, lol.annotation.annotation) + + self.assertIsInstance(scopes[alias.value], AnnotationScope) + + def test_type_alias_param(self) -> None: + if not is_native(): + self.skipTest("type parameters are only supported in the native parser") + m, scopes = get_scope_metadata_provider( + """ + B = int + type A[T: B] = T + lol: T + """ + ) + alias = ensure_type( + ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.TypeAlias + ) + assert alias.type_parameters + param_scope = scopes[alias.type_parameters] + self.assertEqual(len(t_assignments := list(param_scope["T"])), 1) + self.assertEqual(len(t_refs := list(t_assignments[0].references)), 1) + self.assertEqual(t_refs[0].node, alias.value) + + b = ( + ensure_type( + ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.Assign + ) + .targets[0] + .target + ) + b_assignment = list(scopes[b]["B"])[0] + self.assertEqual( + {ref.node for ref in b_assignment.references}, + {ensure_type(alias.type_parameters.params[0].param, cst.TypeVar).bound}, + ) + + def test_type_alias_tuple_and_paramspec(self) -> None: + if not is_native(): + self.skipTest("type parameters are only supported in the native parser") + m, scopes = get_scope_metadata_provider( + """ + type A[*T] = T + lol: T + type A[**T] = T + lol: T + """ + ) + alias_tuple = ensure_type( + ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.TypeAlias + ) + assert alias_tuple.type_parameters + param_scope = scopes[alias_tuple.type_parameters] + self.assertEqual(len(t_assignments := list(param_scope["T"])), 1) + self.assertEqual(len(t_refs := list(t_assignments[0].references)), 1) + self.assertEqual(t_refs[0].node, alias_tuple.value) + + alias_paramspec = ensure_type( + ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.TypeAlias + ) + assert alias_paramspec.type_parameters + param_scope = scopes[alias_paramspec.type_parameters] + self.assertEqual(len(t_assignments := list(param_scope["T"])), 1) + self.assertEqual(len(t_refs := list(t_assignments[0].references)), 1) + self.assertEqual(t_refs[0].node, alias_paramspec.value) + + def test_class_type_params(self) -> None: + if not is_native(): + self.skipTest("type parameters are only supported in the native parser") + m, scopes = get_scope_metadata_provider( + """ + class W[T]: + def f() -> T: pass + def g[T]() -> T: pass + """ + ) + cls = ensure_type(m.body[0], cst.ClassDef) + cls_scope = scopes[cls.body.body[0]] + self.assertEqual(len(t_assignments_in_cls := list(cls_scope["T"])), 1) + assert cls.type_parameters + self.assertEqual( + ensure_type(t_assignments_in_cls[0], Assignment).node, + cls.type_parameters.params[0].param, + ) + self.assertEqual( + len(t_refs_in_cls := list(t_assignments_in_cls[0].references)), 1 + ) + f = ensure_type(cls.body.body[0], cst.FunctionDef) + assert f.returns + self.assertEqual(t_refs_in_cls[0].node, f.returns.annotation) + + g = ensure_type(cls.body.body[1], cst.FunctionDef) + assert g.type_parameters + assert g.returns + self.assertEqual(len(t_assignments_in_g := list(scopes[g.body]["T"])), 1) + self.assertEqual( + ensure_type(t_assignments_in_g[0], Assignment).node, + g.type_parameters.params[0].param, + ) + self.assertEqual(len(t_refs_in_g := list(t_assignments_in_g[0].references)), 1) + self.assertEqual(t_refs_in_g[0].node, g.returns.annotation) + + def test_nested_class_type_params(self) -> None: + if not is_native(): + self.skipTest("type parameters are only supported in the native parser") + m, scopes = get_scope_metadata_provider( + """ + class Outer: + class Nested[T: Outer]: pass + """ + ) + outer = ensure_type(m.body[0], cst.ClassDef) + outer_refs = list(list(scopes[outer]["Outer"])[0].references) + self.assertEqual(len(outer_refs), 1) + inner = ensure_type(outer.body.body[0], cst.ClassDef) + assert inner.type_parameters + self.assertEqual( + outer_refs[0].node, + ensure_type(inner.type_parameters.params[0].param, cst.TypeVar).bound, + ) + + def test_annotation_refers_to_nested_class(self) -> None: + if not is_native(): + self.skipTest("type parameters are only supported in the native parser") + m, scopes = get_scope_metadata_provider( + """ + class Outer: + class Nested: + pass + + type Alias = Nested + + def meth1[T: Nested](self): pass + def meth2[T](self, arg: Nested): pass + """ + ) + outer = ensure_type(m.body[0], cst.ClassDef) + nested = ensure_type(outer.body.body[0], cst.ClassDef) + alias = ensure_type( + ensure_type(outer.body.body[1], cst.SimpleStatementLine).body[0], + cst.TypeAlias, + ) + self.assertIsInstance(scopes[alias.value], AnnotationScope) + nested_refs_within_alias = list(scopes[alias.value].accesses["Nested"]) + self.assertEqual(len(nested_refs_within_alias), 1) + self.assertEqual( + { + ensure_type(ref, Assignment).node + for ref in nested_refs_within_alias[0].referents + }, + {nested}, + ) + + meth1 = ensure_type(outer.body.body[2], cst.FunctionDef) + self.assertIsInstance(scopes[meth1], ClassScope) + assert meth1.type_parameters + meth1_typevar = ensure_type(meth1.type_parameters.params[0].param, cst.TypeVar) + meth1_typevar_scope = scopes[meth1_typevar] + self.assertIsInstance(meth1_typevar_scope, AnnotationScope) + nested_refs_within_meth1 = list(meth1_typevar_scope.accesses["Nested"]) + self.assertEqual(len(nested_refs_within_meth1), 1) + self.assertEqual( + { + ensure_type(ref, Assignment).node + for ref in nested_refs_within_meth1[0].referents + }, + {nested}, + ) + + meth2 = ensure_type(outer.body.body[3], cst.FunctionDef) + meth2_annotation = meth2.params.params[1].annotation + assert meth2_annotation + nested_refs_within_meth2 = list(scopes[meth2_annotation].accesses["Nested"]) + self.assertEqual(len(nested_refs_within_meth2), 1) + self.assertEqual( + { + ensure_type(ref, Assignment).node + for ref in nested_refs_within_meth2[0].referents + }, + {nested}, + ) + + def test_body_isnt_subject_to_special_annotation_rule(self) -> None: + if not is_native(): + self.skipTest("type parameters are only supported in the native parser") + m, scopes = get_scope_metadata_provider( + """ + class Outer: + class Inner: pass + def f[T: Inner](self): Inner + """ + ) + outer = ensure_type(m.body[0], cst.ClassDef) + # note: this is different from global scope + outer_scope = scopes[outer.body.body[0]] + inner_assignment = list(outer_scope["Inner"])[0] + self.assertEqual(len(inner_assignment.references), 1) + f = ensure_type(outer.body.body[1], cst.FunctionDef) + assert f.type_parameters + T = ensure_type(f.type_parameters.params[0].param, cst.TypeVar) + self.assertIs(list(inner_assignment.references)[0].node, T.bound) + + inner_in_func_body = ensure_type(f.body.body[0], cst.Expr) + f_scope = scopes[inner_in_func_body] + self.assertIn(inner_in_func_body.value, f_scope.accesses) + self.assertEqual(list(f_scope.accesses)[0].referents, set())