From 1c3a85da7f922141823e29dcf6628b94b7357d3b Mon Sep 17 00:00:00 2001 From: Matthew Shaer Date: Tue, 8 Nov 2022 12:24:20 +0000 Subject: [PATCH 1/2] Adding a provider which can tell what accessor to use to go from the parent to that child node ufmt --- libcst/metadata/__init__.py | 2 + libcst/metadata/accessor_provider.py | 19 ++++++ .../metadata/tests/test_accessor_provider.py | 68 +++++++++++++++++++ 3 files changed, 89 insertions(+) create mode 100644 libcst/metadata/accessor_provider.py create mode 100644 libcst/metadata/tests/test_accessor_provider.py diff --git a/libcst/metadata/__init__.py b/libcst/metadata/__init__.py index 75e382292..66e7e5251 100644 --- a/libcst/metadata/__init__.py +++ b/libcst/metadata/__init__.py @@ -5,6 +5,7 @@ from libcst._position import CodePosition, CodeRange +from libcst.metadata.accessor_provider import AccessorProvider from libcst.metadata.base_provider import ( BaseMetadataProvider, BatchableMetadataProvider, @@ -86,6 +87,7 @@ "Accesses", "TypeInferenceProvider", "FullRepoManager", + "AccessorProvider", # Experimental APIs: "ExperimentalReentrantCodegenProvider", "CodegenPartial", diff --git a/libcst/metadata/accessor_provider.py b/libcst/metadata/accessor_provider.py new file mode 100644 index 000000000..5d4f22e42 --- /dev/null +++ b/libcst/metadata/accessor_provider.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import dataclasses + +import libcst as cst + +from libcst.metadata.base_provider import VisitorMetadataProvider + + +class AccessorProvider(VisitorMetadataProvider[str]): + def on_visit(self, node: cst.CSTNode) -> bool: + for f in dataclasses.fields(node): + child = getattr(node, f.name) + self.set_metadata(child, f.name) + return True diff --git a/libcst/metadata/tests/test_accessor_provider.py b/libcst/metadata/tests/test_accessor_provider.py new file mode 100644 index 000000000..6ccfad5ee --- /dev/null +++ b/libcst/metadata/tests/test_accessor_provider.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import dataclasses + +from textwrap import dedent + +import libcst as cst +from libcst.metadata import AccessorProvider, MetadataWrapper +from libcst.testing.utils import data_provider, UnitTest + + +class DependentVisitor(cst.CSTVisitor): + METADATA_DEPENDENCIES = (AccessorProvider,) + + def __init__(self, *, test: UnitTest) -> None: + self.test = test + + def on_visit(self, node: cst.CSTNode) -> bool: + for f in dataclasses.fields(node): + child = getattr(node, f.name) + if type(child) is cst.CSTNode: + accessor = self.get_metadata(AccessorProvider, child) + self.test.assertEqual(accessor, f.name) + + return True + + +class AccessorProviderTest(UnitTest): + @data_provider( + ( + ( + """ + foo = 'toplevel' + fn1(foo) + fn2(foo) + def fn_def(): + foo = 'shadow' + fn3(foo) + """, + ), + ( + """ + global_var = None + @cls_attr + class Cls(cls_attr, kwarg=cls_attr): + cls_attr = 5 + def f(): + pass + """, + ), + ( + """ + iterator = None + condition = None + [elt for target in iterator if condition] + {elt for target in iterator if condition} + {elt: target for target in iterator if condition} + (elt for target in iterator if condition) + """, + ), + ) + ) + def test_accessor_provier(self, code: str) -> None: + wrapper = MetadataWrapper(cst.parse_module(dedent(code))) + wrapper.visit(DependentVisitor(test=self)) From 3061e90862c7d7fd21025150c7bb27d94c8dc0d0 Mon Sep 17 00:00:00 2001 From: Matthew Shaer Date: Fri, 18 Nov 2022 17:45:44 +0000 Subject: [PATCH 2/2] Handling accessors for Sequences --- libcst/metadata/accessor_provider.py | 8 +++++++- libcst/metadata/tests/test_accessor_provider.py | 9 ++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/libcst/metadata/accessor_provider.py b/libcst/metadata/accessor_provider.py index 5d4f22e42..8fe506090 100644 --- a/libcst/metadata/accessor_provider.py +++ b/libcst/metadata/accessor_provider.py @@ -6,6 +6,8 @@ import dataclasses +from typing import Sequence + import libcst as cst from libcst.metadata.base_provider import VisitorMetadataProvider @@ -15,5 +17,9 @@ class AccessorProvider(VisitorMetadataProvider[str]): def on_visit(self, node: cst.CSTNode) -> bool: for f in dataclasses.fields(node): child = getattr(node, f.name) - self.set_metadata(child, f.name) + if isinstance(child, cst.CSTNode): + self.set_metadata(child, f.name) + elif isinstance(child, Sequence): + for idx, subchild in enumerate(child): + self.set_metadata(subchild, f.name + "[" + str(idx) + "]") return True diff --git a/libcst/metadata/tests/test_accessor_provider.py b/libcst/metadata/tests/test_accessor_provider.py index 6ccfad5ee..8905bcc22 100644 --- a/libcst/metadata/tests/test_accessor_provider.py +++ b/libcst/metadata/tests/test_accessor_provider.py @@ -7,6 +7,8 @@ from textwrap import dedent +from typing import Sequence + import libcst as cst from libcst.metadata import AccessorProvider, MetadataWrapper from libcst.testing.utils import data_provider, UnitTest @@ -22,8 +24,13 @@ def on_visit(self, node: cst.CSTNode) -> bool: for f in dataclasses.fields(node): child = getattr(node, f.name) if type(child) is cst.CSTNode: - accessor = self.get_metadata(AccessorProvider, child) + accessor = self.get_metadata(AccessorProvider, child, None) self.test.assertEqual(accessor, f.name) + elif isinstance(child, Sequence): + for idx, subchild in enumerate(child): + if type(subchild) is cst.CSTNode: + accessor = self.get_metadata(AccessorProvider, subchild, None) + self.test.assertEqual(accessor, f.name + "[" + str(idx) + "]") return True