Skip to content

Commit

Permalink
Scope provider changes for type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
zsol committed Sep 16, 2023
1 parent 7406106 commit 6a4fecf
Show file tree
Hide file tree
Showing 3 changed files with 291 additions and 20 deletions.
6 changes: 3 additions & 3 deletions libcst/helpers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
from typing import Type
from typing import Type, TypeVar

from libcst._types import CSTNodeT
T = TypeVar("T")


def ensure_type(node: object, nodetype: Type[CSTNodeT]) -> CSTNodeT:
def ensure_type(node: object, nodetype: Type[T]) -> T:
"""
Takes any python object, and a LibCST :class:`~libcst.CSTNode` subclass and
refines the type of the python object. This is most useful when you already
Expand Down
96 changes: 79 additions & 17 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import abc
import builtins
from collections import defaultdict
from contextlib import contextmanager
from contextlib import contextmanager, ExitStack
from dataclasses import dataclass
from enum import auto, Enum
from typing import (
Expand All @@ -33,6 +33,7 @@
ExpressionContext,
ExpressionContextProvider,
)
from libcst.metadata.position_provider import PositionProvider

# Comprehensions are handled separately in _visit_comp_alike due to
# the complexity of the semantics
Expand All @@ -51,6 +52,10 @@
cst.Nonlocal,
cst.Parameters,
cst.WithItem,
cst.TypeVar,
cst.TypeAlias,
cst.TypeVarTuple,
cst.ParamSpec,
)


Expand Down Expand Up @@ -755,6 +760,23 @@ def _make_name_prefix(self) -> str:
return ".".join(filter(None, [self.parent._name_prefix, "<comprehension>"]))


class AnnotationScope(LocalScope):
"""
Scopes used for type aliases and type parameters as defined by PEP-695.
These scopes are created for type parameters using the special syntax, as well as
type aliases. See https://peps.python.org/pep-0695/#scoping-behavior for more.
"""

def _make_name_prefix(self) -> str:
# these scopes are transparent for the purposes of qualified names
return self.parent._name_prefix

def _next_visible_parent(self, first: Optional[Scope] = None) -> "Scope":
# ignore _is_visible_from_children explicitly
return first if first is not None else self.parent


# Generates dotted names from an Attribute or Name node:
# Attribute(value=Name(value="a"), attr=Name(value="b")) -> ("a.b", "a")
# each string has the corresponding CSTNode attached to it
Expand Down Expand Up @@ -822,6 +844,7 @@ class DeferredAccess:
class ScopeVisitor(cst.CSTVisitor):
# since it's probably not useful. That can makes this visitor cleaner.
def __init__(self, provider: "ScopeProvider") -> None:
super().__init__()
self.provider: ScopeProvider = provider
self.scope: Scope = GlobalScope()
self.__deferred_accesses: List[DeferredAccess] = []
Expand Down Expand Up @@ -992,15 +1015,22 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
self.scope.record_assignment(node.name.value, node)
self.provider.set_metadata(node.name, self.scope)

with self._new_scope(FunctionScope, node, get_full_name_for_node(node.name)):
node.params.visit(self)
node.body.visit(self)
with ExitStack() as stack:
if node.type_parameters:
stack.enter_context(self._new_scope(AnnotationScope, node, None))
node.type_parameters.visit(self)

for decorator in node.decorators:
decorator.visit(self)
returns = node.returns
if returns:
returns.visit(self)
with self._new_scope(
FunctionScope, node, get_full_name_for_node(node.name)
):
node.params.visit(self)
node.body.visit(self)

for decorator in node.decorators:
decorator.visit(self)
returns = node.returns
if returns:
returns.visit(self)

return False

Expand Down Expand Up @@ -1032,14 +1062,20 @@ def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
self.provider.set_metadata(node.name, self.scope)
for decorator in node.decorators:
decorator.visit(self)
for base in node.bases:
base.visit(self)
for keyword in node.keywords:
keyword.visit(self)

with self._new_scope(ClassScope, node, get_full_name_for_node(node.name)):
for statement in node.body.body:
statement.visit(self)

with ExitStack() as stack:
if node.type_parameters:
stack.enter_context(self._new_scope(AnnotationScope, node, None))
node.type_parameters.visit(self)

for base in node.bases:
base.visit(self)
for keyword in node.keywords:
keyword.visit(self)

with self._new_scope(ClassScope, node, get_full_name_for_node(node.name)):
for statement in node.body.body:
statement.visit(self)
return False

def visit_ClassDef_bases(self, node: cst.ClassDef) -> None:
Expand Down Expand Up @@ -1174,6 +1210,32 @@ def on_leave(self, original_node: cst.CSTNode) -> None:
self.scope._assignment_count += 1
super().on_leave(original_node)

def visit_TypeAlias(self, node: cst.TypeAlias) -> Optional[bool]:
self.scope.record_assignment(node.name.value, node)

with self._new_scope(AnnotationScope, node, None):
if node.type_parameters is not None:
node.type_parameters.visit(self)
node.value.visit(self)

return False

def visit_TypeVar(self, node: cst.TypeVar) -> Optional[bool]:
self.scope.record_assignment(node.name.value, node)

if node.bound is not None:
node.bound.visit(self)

return False

def visit_TypeVarTuple(self, node: cst.TypeVarTuple) -> Optional[bool]:
self.scope.record_assignment(node.name.value, node)
return False

def visit_ParamSpec(self, node: cst.ParamSpec) -> Optional[bool]:
self.scope.record_assignment(node.name.value, node)
return False


class ScopeProvider(BatchableMetadataProvider[Optional[Scope]]):
"""
Expand Down
209 changes: 209 additions & 0 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import libcst as cst
from libcst import ensure_type
from libcst.metadata import MetadataWrapper
from libcst.metadata.position_provider import PositionProvider
from libcst.metadata.scope_provider import (
_gen_dotted_names,
AnnotationScope,
Assignment,
BuiltinAssignment,
BuiltinScope,
Expand Down Expand Up @@ -1982,3 +1984,210 @@ def something():
scope.get_qualified_names_for(cst.Name("something_else")),
set(),
)

def test_type_alias_scope(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
type A = C
lol: A
"""
)
alias = ensure_type(
ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.TypeAlias
)
self.assertIsInstance(scopes[alias], GlobalScope)
a_assignments = list(scopes[alias]["A"])
self.assertEqual(len(a_assignments), 1)
lol = ensure_type(
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.AnnAssign
)
self.assertEqual(len(a_references := list(a_assignments[0].references)), 1)
self.assertEqual(a_references[0].node, lol.annotation.annotation)

self.assertIsInstance(scopes[alias.value], AnnotationScope)

def test_type_alias_param(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
B = int
type A[T: B] = T
lol: T
"""
)
alias = ensure_type(
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.TypeAlias
)
assert alias.type_parameters
param_scope = scopes[alias.type_parameters]
self.assertEqual(len(t_assignments := list(param_scope["T"])), 1)
self.assertEqual(len(t_refs := list(t_assignments[0].references)), 1)
self.assertEqual(t_refs[0].node, alias.value)

b = (
ensure_type(
ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.Assign
)
.targets[0]
.target
)
b_assignment = list(scopes[b]["B"])[0]
self.assertEqual(
{ref.node for ref in b_assignment.references},
{ensure_type(alias.type_parameters.params[0].param, cst.TypeVar).bound},
)

def test_type_alias_tuple_and_paramspec(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
type A[*T] = T
lol: T
type A[**T] = T
lol: T
"""
)
alias_tuple = ensure_type(
ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.TypeAlias
)
assert alias_tuple.type_parameters
param_scope = scopes[alias_tuple.type_parameters]
self.assertEqual(len(t_assignments := list(param_scope["T"])), 1)
self.assertEqual(len(t_refs := list(t_assignments[0].references)), 1)
self.assertEqual(t_refs[0].node, alias_tuple.value)

alias_paramspec = ensure_type(
ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.TypeAlias
)
assert alias_paramspec.type_parameters
param_scope = scopes[alias_paramspec.type_parameters]
self.assertEqual(len(t_assignments := list(param_scope["T"])), 1)
self.assertEqual(len(t_refs := list(t_assignments[0].references)), 1)
self.assertEqual(t_refs[0].node, alias_paramspec.value)

def test_class_type_params(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
class W[T]:
def f() -> T: pass
def g[T]() -> T: pass
"""
)
cls = ensure_type(m.body[0], cst.ClassDef)
cls_scope = scopes[cls.body.body[0]]
self.assertEqual(len(t_assignments_in_cls := list(cls_scope["T"])), 1)
assert cls.type_parameters
self.assertEqual(
ensure_type(t_assignments_in_cls[0], Assignment).node,
cls.type_parameters.params[0].param,
)
self.assertEqual(
len(t_refs_in_cls := list(t_assignments_in_cls[0].references)), 1
)
f = ensure_type(cls.body.body[0], cst.FunctionDef)
assert f.returns
self.assertEqual(t_refs_in_cls[0].node, f.returns.annotation)

g = ensure_type(cls.body.body[1], cst.FunctionDef)
assert g.type_parameters
assert g.returns
self.assertEqual(len(t_assignments_in_g := list(scopes[g.body]["T"])), 1)
self.assertEqual(
ensure_type(t_assignments_in_g[0], Assignment).node,
g.type_parameters.params[0].param,
)
self.assertEqual(len(t_refs_in_g := list(t_assignments_in_g[0].references)), 1)
self.assertEqual(t_refs_in_g[0].node, g.returns.annotation)

def test_nested_class_type_params(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
class Outer:
class Nested[T: Outer]: pass
"""
)
outer = ensure_type(m.body[0], cst.ClassDef)
outer_refs = list(list(scopes[outer]["Outer"])[0].references)
self.assertEqual(len(outer_refs), 1)
inner = ensure_type(outer.body.body[0], cst.ClassDef)
assert inner.type_parameters
self.assertEqual(
outer_refs[0].node,
ensure_type(inner.type_parameters.params[0].param, cst.TypeVar).bound,
)

def test_annotation_refers_to_nested_class(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
class Outer:
class Nested:
pass
type Alias = Nested
def meth1[T: Nested](self): pass
def meth2[T](self, arg: Nested): pass
"""
)
outer = ensure_type(m.body[0], cst.ClassDef)
nested = ensure_type(outer.body.body[0], cst.ClassDef)
alias = ensure_type(
ensure_type(outer.body.body[1], cst.SimpleStatementLine).body[0],
cst.TypeAlias,
)
self.assertIsInstance(scopes[alias.value], AnnotationScope)
nested_refs_within_alias = list(scopes[alias.value].accesses["Nested"])
self.assertEqual(len(nested_refs_within_alias), 1)
self.assertEqual(
{
ensure_type(ref, Assignment).node
for ref in nested_refs_within_alias[0].referents
},
{nested},
)

meth1 = ensure_type(outer.body.body[2], cst.FunctionDef)
self.assertIsInstance(scopes[meth1], ClassScope)
assert meth1.type_parameters
meth1_typevar = ensure_type(meth1.type_parameters.params[0].param, cst.TypeVar)
meth1_typevar_scope = scopes[meth1_typevar]
self.assertIsInstance(meth1_typevar_scope, AnnotationScope)
nested_refs_within_meth1 = list(meth1_typevar_scope.accesses["Nested"])
self.assertEqual(len(nested_refs_within_meth1), 1)
self.assertEqual(
{
ensure_type(ref, Assignment).node
for ref in nested_refs_within_meth1[0].referents
},
{nested},
)

meth2 = ensure_type(outer.body.body[3], cst.FunctionDef)
meth2_annotation = meth2.params.params[1].annotation
assert meth2_annotation
nested_refs_within_meth2 = list(scopes[meth2_annotation].accesses["Nested"])
self.assertEqual(len(nested_refs_within_meth2), 1)
self.assertEqual(
{
ensure_type(ref, Assignment).node
for ref in nested_refs_within_meth2[0].referents
},
{nested},
)

def test_body_isnt_subject_to_special_annotation_rule(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
class Outer:
class Inner: pass
def f[T: Inner](self): Inner()
"""
)
outer = ensure_type(m.body[0], cst.ClassDef)
inner = ensure_type(outer.body.body[0], cst.ClassDef)
# note: this is different from global scope
outer_scope = scopes[outer.body.body[0]]
inner_assignment = list(outer_scope["Inner"])[0]
self.skipTest(
"f's body is subject to the annotation scoping rules"
" created by its type parameters, but it shouldn't be"
)
self.assertEqual(len(inner_assignment.references), 1)

0 comments on commit 6a4fecf

Please sign in to comment.