Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scope provider changes for type annotations #1014

Merged
merged 1 commit into from
Sep 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions libcst/helpers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
from typing import Type
from typing import Type, TypeVar

from libcst._types import CSTNodeT
T = TypeVar("T")


def ensure_type(node: object, nodetype: Type[CSTNodeT]) -> CSTNodeT:
def ensure_type(node: object, nodetype: Type[T]) -> T:
"""
Takes any python object, and a LibCST :class:`~libcst.CSTNode` subclass and
refines the type of the python object. This is most useful when you already
Expand Down
146 changes: 111 additions & 35 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import abc
import builtins
from collections import defaultdict
from contextlib import contextmanager
from contextlib import contextmanager, ExitStack
from dataclasses import dataclass
from enum import auto, Enum
from typing import (
Expand Down Expand Up @@ -51,6 +51,10 @@
cst.Nonlocal,
cst.Parameters,
cst.WithItem,
cst.TypeVar,
cst.TypeAlias,
cst.TypeVarTuple,
cst.ParamSpec,
)


Expand Down Expand Up @@ -116,15 +120,17 @@ def record_assignment(self, assignment: "BaseAssignment") -> None:
self.__assignments.add(assignment)

def record_assignments(self, name: str) -> None:
assignments = self.scope[name]
assignments = self.scope._resolve_scope_for_access(name, self.scope)
# filter out assignments that happened later than this access
previous_assignments = {
assignment
for assignment in assignments
if assignment.scope != self.scope or assignment._index < self.__index
}
if not previous_assignments and assignments and self.scope.parent != self.scope:
previous_assignments = self.scope.parent[name]
previous_assignments = self.scope.parent._resolve_scope_for_access(
name, self.scope
)
self.__assignments |= previous_assignments


Expand Down Expand Up @@ -440,7 +446,7 @@ def record_access(self, name: str, access: Access) -> None:
self._accesses_by_name[name].add(access)
self._accesses_by_node[access.node].add(access)

def _is_visible_from_children(self) -> bool:
def _is_visible_from_children(self, from_scope: "Scope") -> bool:
"""Returns if the assignments in this scope can be accessed from children.

This is normally True, except for class scopes::
Expand All @@ -459,9 +465,11 @@ def inner_fn():
"""
return True

def _next_visible_parent(self, first: Optional["Scope"] = None) -> "Scope":
def _next_visible_parent(
self, from_scope: "Scope", first: Optional["Scope"] = None
) -> "Scope":
parent = first if first is not None else self.parent
while not parent._is_visible_from_children():
while not parent._is_visible_from_children(from_scope):
parent = parent.parent
return parent

Expand All @@ -470,7 +478,6 @@ def __contains__(self, name: str) -> bool:
"""Check if the name str exist in current scope by ``name in scope``."""
...

@abc.abstractmethod
def __getitem__(self, name: str) -> Set[BaseAssignment]:
"""
Get assignments given a name str by ``scope[name]``.
Expand Down Expand Up @@ -508,6 +515,12 @@ def __getitem__(self, name: str) -> Set[BaseAssignment]:
defined a given name by the time a piece of code is executed.
For the above example, value would resolve to a set of both assignments.
"""
return self._resolve_scope_for_access(name, self)

@abc.abstractmethod
def _resolve_scope_for_access(
self, name: str, from_scope: "Scope"
) -> Set[BaseAssignment]:
...

def __hash__(self) -> int:
Expand Down Expand Up @@ -612,7 +625,9 @@ def __init__(self, globals: Scope) -> None:
def __contains__(self, name: str) -> bool:
return hasattr(builtins, name)

def __getitem__(self, name: str) -> Set[BaseAssignment]:
def _resolve_scope_for_access(
self, name: str, from_scope: "Scope"
) -> Set[BaseAssignment]:
if name in self._assignments:
return self._assignments[name]
if hasattr(builtins, name):
Expand Down Expand Up @@ -644,13 +659,15 @@ def __init__(self) -> None:
def __contains__(self, name: str) -> bool:
if name in self._assignments:
return len(self._assignments[name]) > 0
return name in self._next_visible_parent()
return name in self._next_visible_parent(self)

def __getitem__(self, name: str) -> Set[BaseAssignment]:
def _resolve_scope_for_access(
self, name: str, from_scope: "Scope"
) -> Set[BaseAssignment]:
if name in self._assignments:
return self._assignments[name]

parent = self._next_visible_parent()
parent = self._next_visible_parent(from_scope)
return parent[name]

def record_global_overwrite(self, name: str) -> None:
Expand Down Expand Up @@ -688,7 +705,7 @@ def record_nonlocal_overwrite(self, name: str) -> None:
def _find_assignment_target(self, name: str) -> "Scope":
if name in self._scope_overwrites:
scope = self._scope_overwrites[name]
return self._next_visible_parent(scope)._find_assignment_target(name)
return self._next_visible_parent(self, scope)._find_assignment_target(name)
else:
return super()._find_assignment_target(name)

Expand All @@ -697,16 +714,22 @@ def __contains__(self, name: str) -> bool:
return name in self._scope_overwrites[name]
if name in self._assignments:
return len(self._assignments[name]) > 0
return name in self._next_visible_parent()
return name in self._next_visible_parent(self)

def __getitem__(self, name: str) -> Set[BaseAssignment]:
def _resolve_scope_for_access(
self, name: str, from_scope: "Scope"
) -> Set[BaseAssignment]:
if name in self._scope_overwrites:
scope = self._scope_overwrites[name]
return self._next_visible_parent(scope)[name]
return self._next_visible_parent(
from_scope, scope
)._resolve_scope_for_access(name, from_scope)
if name in self._assignments:
return self._assignments[name]
else:
return self._next_visible_parent()[name]
return self._next_visible_parent(from_scope)._resolve_scope_for_access(
name, from_scope
)

def _make_name_prefix(self) -> str:
# filter falsey strings out
Expand All @@ -728,8 +751,8 @@ class ClassScope(LocalScope):
When a class is defined, it creates a ClassScope.
"""

def _is_visible_from_children(self) -> bool:
return False
def _is_visible_from_children(self, from_scope: "Scope") -> bool:
return from_scope.parent is self and isinstance(from_scope, AnnotationScope)

def _make_name_prefix(self) -> str:
# filter falsey strings out
Expand All @@ -755,6 +778,19 @@ def _make_name_prefix(self) -> str:
return ".".join(filter(None, [self.parent._name_prefix, "<comprehension>"]))


class AnnotationScope(LocalScope):
"""
Scopes used for type aliases and type parameters as defined by PEP-695.

These scopes are created for type parameters using the special syntax, as well as
type aliases. See https://peps.python.org/pep-0695/#scoping-behavior for more.
"""

def _make_name_prefix(self) -> str:
# these scopes are transparent for the purposes of qualified names
return self.parent._name_prefix


# Generates dotted names from an Attribute or Name node:
# Attribute(value=Name(value="a"), attr=Name(value="b")) -> ("a.b", "a")
# each string has the corresponding CSTNode attached to it
Expand Down Expand Up @@ -822,6 +858,7 @@ class DeferredAccess:
class ScopeVisitor(cst.CSTVisitor):
# since it's probably not useful. That can makes this visitor cleaner.
def __init__(self, provider: "ScopeProvider") -> None:
super().__init__()
self.provider: ScopeProvider = provider
self.scope: Scope = GlobalScope()
self.__deferred_accesses: List[DeferredAccess] = []
Expand Down Expand Up @@ -992,15 +1029,22 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
self.scope.record_assignment(node.name.value, node)
self.provider.set_metadata(node.name, self.scope)

with self._new_scope(FunctionScope, node, get_full_name_for_node(node.name)):
node.params.visit(self)
node.body.visit(self)
with ExitStack() as stack:
if node.type_parameters:
stack.enter_context(self._new_scope(AnnotationScope, node, None))
node.type_parameters.visit(self)

for decorator in node.decorators:
decorator.visit(self)
returns = node.returns
if returns:
returns.visit(self)
with self._new_scope(
FunctionScope, node, get_full_name_for_node(node.name)
):
node.params.visit(self)
node.body.visit(self)

for decorator in node.decorators:
decorator.visit(self)
returns = node.returns
if returns:
returns.visit(self)

return False

Expand Down Expand Up @@ -1032,14 +1076,20 @@ def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
self.provider.set_metadata(node.name, self.scope)
for decorator in node.decorators:
decorator.visit(self)
for base in node.bases:
base.visit(self)
for keyword in node.keywords:
keyword.visit(self)

with self._new_scope(ClassScope, node, get_full_name_for_node(node.name)):
for statement in node.body.body:
statement.visit(self)

with ExitStack() as stack:
if node.type_parameters:
stack.enter_context(self._new_scope(AnnotationScope, node, None))
node.type_parameters.visit(self)

for base in node.bases:
base.visit(self)
for keyword in node.keywords:
keyword.visit(self)

with self._new_scope(ClassScope, node, get_full_name_for_node(node.name)):
for statement in node.body.body:
statement.visit(self)
return False

def visit_ClassDef_bases(self, node: cst.ClassDef) -> None:
Expand Down Expand Up @@ -1163,7 +1213,7 @@ def infer_accesses(self) -> None:
access.scope.record_access(name, access)

for (scope, name), accesses in scope_name_accesses.items():
for assignment in scope[name]:
for assignment in scope._resolve_scope_for_access(name, scope):
assignment.record_accesses(accesses)

self.__deferred_accesses = []
Expand All @@ -1174,6 +1224,32 @@ def on_leave(self, original_node: cst.CSTNode) -> None:
self.scope._assignment_count += 1
super().on_leave(original_node)

def visit_TypeAlias(self, node: cst.TypeAlias) -> Optional[bool]:
self.scope.record_assignment(node.name.value, node)

with self._new_scope(AnnotationScope, node, None):
if node.type_parameters is not None:
node.type_parameters.visit(self)
node.value.visit(self)

return False

def visit_TypeVar(self, node: cst.TypeVar) -> Optional[bool]:
self.scope.record_assignment(node.name.value, node)

if node.bound is not None:
node.bound.visit(self)

return False

def visit_TypeVarTuple(self, node: cst.TypeVarTuple) -> Optional[bool]:
self.scope.record_assignment(node.name.value, node)
return False

def visit_ParamSpec(self, node: cst.ParamSpec) -> Optional[bool]:
self.scope.record_assignment(node.name.value, node)
return False


class ScopeProvider(BatchableMetadataProvider[Optional[Scope]]):
"""
Expand Down
Loading
Loading