From 552af63d2923d390c678d9d1ec2123e21e7f21a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20C=2E=20Silva?= <12188364+andrecsilva@users.noreply.github.com> Date: Sun, 1 Oct 2023 14:34:42 -0300 Subject: [PATCH] ScopeProvider: Record Access for Attributes and Decorators (#1019) * Support for Attributes and Decorators in _NameUtil * Replaced _NameUtil with get_full_name_for_node * Added tests --- libcst/metadata/scope_provider.py | 21 +---------- libcst/metadata/tests/test_scope_provider.py | 39 ++++++++++++++++++++ 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index 73bb61f56..75f37a06e 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -330,7 +330,7 @@ def __iter__(self) -> Iterator[BaseAssignment]: def __getitem__(self, node: Union[str, cst.CSTNode]) -> Collection[BaseAssignment]: """Get assignments given a name str or :class:`~libcst.CSTNode` by ``scope.assignments[node]``""" - name = _NameUtil.get_name_for(node) + name = get_full_name_for_node(node) return set(self._assignments[name]) if name in self._assignments else set() def __contains__(self, node: Union[str, cst.CSTNode]) -> bool: @@ -352,7 +352,7 @@ def __iter__(self) -> Iterator[Access]: def __getitem__(self, node: Union[str, cst.CSTNode]) -> Collection[Access]: """Get accesses given a name str or :class:`~libcst.CSTNode` by ``scope.accesses[node]``""" - name = _NameUtil.get_name_for(node) + name = get_full_name_for_node(node) return self._accesses[name] if name in self._accesses else set() def __contains__(self, node: Union[str, cst.CSTNode]) -> bool: @@ -360,23 +360,6 @@ def __contains__(self, node: Union[str, cst.CSTNode]) -> bool: return len(self[node]) > 0 -class _NameUtil: - @staticmethod - def get_name_for(node: Union[str, cst.CSTNode]) -> Optional[str]: - """A helper function to retrieve simple name str from a CSTNode or str""" - if isinstance(node, cst.Name): - return node.value - elif isinstance(node, str): - return node - elif isinstance(node, cst.Call): - return _NameUtil.get_name_for(node.func) - elif isinstance(node, cst.Subscript): - return _NameUtil.get_name_for(node.value) - elif isinstance(node, (cst.FunctionDef, cst.ClassDef)): - return _NameUtil.get_name_for(node.name) - return None - - class Scope(abc.ABC): """ Base class of all scope classes. Scope object stores assignments from imports, diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index 5f6d485bb..a2087645c 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -253,6 +253,45 @@ def test_dotted_import_access(self) -> None: self.assertEqual(list(scope_of_module["x.y"])[0].references, set()) self.assertEqual(scope_of_module.accesses["x.y"], set()) + def test_dotted_import_access_reference_by_node(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + import a.b.c + a.b.c() + """ + ) + scope_of_module = scopes[m] + first_statement = ensure_type(m.body[1], cst.SimpleStatementLine) + call = ensure_type( + ensure_type(first_statement.body[0], cst.Expr).value, cst.Call + ) + + a_b_c_assignment = cast(ImportAssignment, list(scope_of_module["a.b.c"])[0]) + a_b_c_access = list(a_b_c_assignment.references)[0] + self.assertEqual(scope_of_module.accesses[call], {a_b_c_access}) + self.assertEqual(a_b_c_access.node, call.func) + + def test_decorator_access_reference_by_node(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + import decorator + + @decorator + def f(): + pass + """ + ) + scope_of_module = scopes[m] + function_def = ensure_type(m.body[1], cst.FunctionDef) + decorator = function_def.decorators[0] + self.assertTrue("decorator" in scope_of_module) + + decorator_assignment = cast( + ImportAssignment, list(scope_of_module["decorator"])[0] + ) + decorator_access = list(decorator_assignment.references)[0] + self.assertEqual(scope_of_module.accesses[decorator], {decorator_access}) + def test_dotted_import_with_call_access(self) -> None: m, scopes = get_scope_metadata_provider( """