Skip to content

Commit

Permalink
tie accesses from string annotation to the string node (#483)
Browse files Browse the repository at this point in the history
  • Loading branch information
zsol authored May 12, 2021
1 parent d1606b7 commit 4d2ccc5
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 7 deletions.
40 changes: 33 additions & 7 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ def __new__(cls) -> "Tree":

#: The node of the access. A name is an access when the expression context is
#: :attr:`ExpressionContext.LOAD`. This is usually the name node representing the
#: access, except for dotted imports, when it might be the attribute that
#: represents the most specific part of the imported symbol.
node: Union[cst.Name, cst.Attribute]
#: access, except for: 1) dotted imports, when it might be the attribute that
#: represents the most specific part of the imported symbol; and 2) string
#: annotations, when it is the entire string literal
node: Union[cst.Name, cst.Attribute, cst.BaseString]

#: The scope of the access. Note that a access could be in a child scope of its
#: assignment.
Expand Down Expand Up @@ -422,7 +423,7 @@ def _record_assignment_as_parent(self, name: str, node: cst.CSTNode) -> None:

@abc.abstractmethod
def __contains__(self, name: str) -> bool:
""" Check if the name str exist in current scope by ``name in scope``. """
"""Check if the name str exist in current scope by ``name in scope``."""
...

@abc.abstractmethod
Expand Down Expand Up @@ -775,18 +776,26 @@ def _is_assignment(node: cst.CSTNode, assignment_node: cst.CSTNode) -> bool:
return False


@dataclass(frozen=True)
class DeferredAccess:
access: Access
enclosing_attribute: Optional[cst.Attribute]
enclosing_string_annotation: Optional[cst.BaseString]


class ScopeVisitor(cst.CSTVisitor):
# since it's probably not useful. That can makes this visitor cleaner.
def __init__(self, provider: "ScopeProvider") -> None:
self.provider: ScopeProvider = provider
self.scope: Scope = GlobalScope()
self.__deferred_accesses: List[Tuple[Access, Optional[cst.Attribute]]] = []
self.__deferred_accesses: List[DeferredAccess] = []
self.__top_level_attribute_stack: List[Optional[cst.Attribute]] = [None]
self.__in_annotation: Set[
Union[cst.Call, cst.Annotation, cst.Subscript]
] = set()
self.__in_type_hint: Set[Union[cst.Call, cst.Annotation, cst.Subscript]] = set()
self.__in_ignored_subscript: Set[cst.Subscript] = set()
self.__last_string_annotation: Optional[cst.BaseString] = None
self.__ignore_annotation: int = 0

@contextmanager
Expand Down Expand Up @@ -887,8 +896,13 @@ def _handle_string_annotation(
) and not self.__in_ignored_subscript:
value = node.evaluated_value
if value:
top_level_annotation = self.__last_string_annotation is None
if top_level_annotation:
self.__last_string_annotation = node
mod = cst.parse_module(value)
mod.visit(self)
if top_level_annotation:
self.__last_string_annotation = None
return True
return False

Expand Down Expand Up @@ -920,7 +934,11 @@ def visit_Name(self, node: cst.Name) -> Optional[bool]:
is_type_hint=bool(self.__in_type_hint),
)
self.__deferred_accesses.append(
(access, self.__top_level_attribute_stack[-1])
DeferredAccess(
access=access,
enclosing_attribute=self.__top_level_attribute_stack[-1],
enclosing_string_annotation=self.__last_string_annotation,
)
)

def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
Expand Down Expand Up @@ -1074,7 +1092,12 @@ def infer_accesses(self) -> None:
# In worst case, all accesses (m) and assignments (n) refer to the same name,
# the time complexity is O(m x n), this optimizes it as O(m + n).
scope_name_accesses = defaultdict(set)
for (access, enclosing_attribute) in self.__deferred_accesses:
for def_access in self.__deferred_accesses:
access, enclosing_attribute, enclosing_string_annotation = (
def_access.access,
def_access.enclosing_attribute,
def_access.enclosing_string_annotation,
)
name = ensure_type(access.node, cst.Name).value
if enclosing_attribute is not None:
# if _gen_dotted_names doesn't generate any values, fall back to
Expand All @@ -1085,6 +1108,9 @@ def infer_accesses(self) -> None:
name = attr_name
break

if enclosing_string_annotation is not None:
access.node = enclosing_string_annotation

scope_name_accesses[(access.scope, name)].add(access)
access.record_assignments(name)
access.scope.record_access(name, access)
Expand Down
72 changes: 72 additions & 0 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,10 @@ class Test(Generic[J]):
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "B")

assignment = list(scope["C"])[0]
self.assertIsInstance(assignment, Assignment)
Expand All @@ -1104,6 +1108,10 @@ class Test(Generic[J]):
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "E")

assignment = list(scope["E2"])[0]
self.assertIsInstance(assignment, Assignment)
Expand All @@ -1119,6 +1127,10 @@ class Test(Generic[J]):
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "Optional[G]")

assignment = list(scope["G2"])[0]
self.assertIsInstance(assignment, Assignment)
Expand All @@ -1130,6 +1142,10 @@ class Test(Generic[J]):
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "H")

assignment = list(scope["I"])[0]
self.assertIsInstance(assignment, Assignment)
Expand All @@ -1148,6 +1164,10 @@ class Test(Generic[J]):
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "K")

assignment = list(scope["K2"])[0]
self.assertIsInstance(assignment, Assignment)
Expand All @@ -1157,12 +1177,64 @@ class Test(Generic[J]):
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
reference_node = references[0].node
self.assertIsInstance(reference_node, cst.SimpleString)
if isinstance(reference_node, cst.SimpleString):
self.assertEqual(reference_node.evaluated_value, "L")

assignment = list(scope["M"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)

def test_insane_annotation_access(self) -> None:
m, scopes = get_scope_metadata_provider(
r"""
from typing import TypeVar
from a import G
TypeVar("G2", bound="Optional[\"G\"]")
"""
)
imp = ensure_type(
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.ImportFrom
)
call = ensure_type(
ensure_type(
ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.Call,
)
bound = call.args[1].value
scope = scopes[imp]
assignment = next(iter(scope["G"]))
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
self.assertEqual(list(assignment.references)[0].node, bound)

def test_dotted_annotation_access(self) -> None:
m, scopes = get_scope_metadata_provider(
r"""
from typing import TypeVar
import a.G
TypeVar("G2", bound="a.G")
"""
)
imp = ensure_type(
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Import
)
call = ensure_type(
ensure_type(
ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Expr
).value,
cst.Call,
)
bound = call.args[1].value
scope = scopes[imp]
assignment = next(iter(scope["a.G"]))
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
self.assertEqual(list(assignment.references)[0].node, bound)

def test_node_of_scopes(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
Expand Down

0 comments on commit 4d2ccc5

Please sign in to comment.