Skip to content

Commit

Permalink
ScopeProvider: Record Access for Attributes and Decorators (#1019)
Browse files Browse the repository at this point in the history
* Support for Attributes and Decorators in _NameUtil

* Replaced _NameUtil with get_full_name_for_node

* Added tests
  • Loading branch information
andrecsilva authored Oct 1, 2023
1 parent e1da64b commit 552af63
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 19 deletions.
21 changes: 2 additions & 19 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def __iter__(self) -> Iterator[BaseAssignment]:

def __getitem__(self, node: Union[str, cst.CSTNode]) -> Collection[BaseAssignment]:
"""Get assignments given a name str or :class:`~libcst.CSTNode` by ``scope.assignments[node]``"""
name = _NameUtil.get_name_for(node)
name = get_full_name_for_node(node)
return set(self._assignments[name]) if name in self._assignments else set()

def __contains__(self, node: Union[str, cst.CSTNode]) -> bool:
Expand All @@ -352,31 +352,14 @@ def __iter__(self) -> Iterator[Access]:

def __getitem__(self, node: Union[str, cst.CSTNode]) -> Collection[Access]:
"""Get accesses given a name str or :class:`~libcst.CSTNode` by ``scope.accesses[node]``"""
name = _NameUtil.get_name_for(node)
name = get_full_name_for_node(node)
return self._accesses[name] if name in self._accesses else set()

def __contains__(self, node: Union[str, cst.CSTNode]) -> bool:
"""Check if a name str or :class:`~libcst.CSTNode` has any access by ``node in scope.accesses``"""
return len(self[node]) > 0


class _NameUtil:
@staticmethod
def get_name_for(node: Union[str, cst.CSTNode]) -> Optional[str]:
"""A helper function to retrieve simple name str from a CSTNode or str"""
if isinstance(node, cst.Name):
return node.value
elif isinstance(node, str):
return node
elif isinstance(node, cst.Call):
return _NameUtil.get_name_for(node.func)
elif isinstance(node, cst.Subscript):
return _NameUtil.get_name_for(node.value)
elif isinstance(node, (cst.FunctionDef, cst.ClassDef)):
return _NameUtil.get_name_for(node.name)
return None


class Scope(abc.ABC):
"""
Base class of all scope classes. Scope object stores assignments from imports,
Expand Down
39 changes: 39 additions & 0 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,45 @@ def test_dotted_import_access(self) -> None:
self.assertEqual(list(scope_of_module["x.y"])[0].references, set())
self.assertEqual(scope_of_module.accesses["x.y"], set())

def test_dotted_import_access_reference_by_node(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
import a.b.c
a.b.c()
"""
)
scope_of_module = scopes[m]
first_statement = ensure_type(m.body[1], cst.SimpleStatementLine)
call = ensure_type(
ensure_type(first_statement.body[0], cst.Expr).value, cst.Call
)

a_b_c_assignment = cast(ImportAssignment, list(scope_of_module["a.b.c"])[0])
a_b_c_access = list(a_b_c_assignment.references)[0]
self.assertEqual(scope_of_module.accesses[call], {a_b_c_access})
self.assertEqual(a_b_c_access.node, call.func)

def test_decorator_access_reference_by_node(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
import decorator
@decorator
def f():
pass
"""
)
scope_of_module = scopes[m]
function_def = ensure_type(m.body[1], cst.FunctionDef)
decorator = function_def.decorators[0]
self.assertTrue("decorator" in scope_of_module)

decorator_assignment = cast(
ImportAssignment, list(scope_of_module["decorator"])[0]
)
decorator_access = list(decorator_assignment.references)[0]
self.assertEqual(scope_of_module.accesses[decorator], {decorator_access})

def test_dotted_import_with_call_access(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
Expand Down

0 comments on commit 552af63

Please sign in to comment.