From 302f329efd49f3973ba30954a42ba87c789a57d9 Mon Sep 17 00:00:00 2001 From: Zsolt Dollenstein Date: Thu, 28 Nov 2024 13:25:24 +0000 Subject: [PATCH 1/2] rename: store state in scratch This PR changes RenameCodemod to store its per-module state in `self.context.scratch` which gets properly reset between files. --- libcst/codemod/commands/rename.py | 40 ++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/libcst/codemod/commands/rename.py b/libcst/codemod/commands/rename.py index 290f1ac1..9d710cca 100644 --- a/libcst/codemod/commands/rename.py +++ b/libcst/codemod/commands/rename.py @@ -92,14 +92,38 @@ def __init__(self, context: CodemodContext, old_name: str, new_name: str) -> Non self.old_module: str = old_module self.old_mod_or_obj: str = old_mod_or_obj - self.as_name: Optional[Tuple[str, str]] = None - - # A set of nodes that have been renamed to help with the cleanup of now potentially unused - # imports, during import cleanup in `leave_Module`. - self.scheduled_removals: Set[cst.CSTNode] = set() - # If an import has been renamed while inside an `Import` or `ImportFrom` node, we want to flag - # this so that we do not end up with two of the same import. - self.bypass_import = False + @property + def as_name(self) -> Optional[Tuple[str, str]]: + if "as_name" not in self.context.scratch: + self.context.scratch["as_name"] = None + return self.context.scratch["as_name"] + + @as_name.setter + def as_name(self, value: Optional[Tuple[str, str]]) -> None: + self.context.scratch["as_name"] = value + + @property + def scheduled_removals(self) -> Set[cst.CSTNode]: + """A set of nodes that have been renamed to help with the cleanup of now potentially unused + imports, during import cleanup in `leave_Module`.""" + if "scheduled_removals" not in self.context.scratch: + self.context.scratch["scheduled_removals"] = set() + return self.context.scratch["scheduled_removals"] + + @scheduled_removals.setter + def scheduled_removals(self, value: Set[cst.CSTNode]) -> None: + self.context.scratch["scheduled_removals"] = value + + @property + def bypass_import(self) -> bool: + """A flag to indicate that an import has been renamed while inside an `Import` or `ImportFrom` node.""" + if "bypass_import" not in self.context.scratch: + self.context.scratch["bypass_import"] = False + return self.context.scratch["bypass_import"] + + @bypass_import.setter + def bypass_import(self, value: bool) -> None: + self.context.scratch["bypass_import"] = value def visit_Import(self, node: cst.Import) -> None: for import_alias in node.names: From c7001957ef91a43f9d97a15c02cfe7bcca1ee0c1 Mon Sep 17 00:00:00 2001 From: Zsolt Dollenstein Date: Thu, 28 Nov 2024 14:47:00 +0000 Subject: [PATCH 2/2] rename: handle imports via a parent module When requesting a rename for `a.b.c`, we want to act on `import a` when it's used to access `a.b.c` --- libcst/codemod/commands/rename.py | 48 ++++++++++---------- libcst/codemod/commands/tests/test_rename.py | 36 +++++++++++++++ 2 files changed, 60 insertions(+), 24 deletions(-) diff --git a/libcst/codemod/commands/rename.py b/libcst/codemod/commands/rename.py index 9d710cca..ae7138c8 100644 --- a/libcst/codemod/commands/rename.py +++ b/libcst/codemod/commands/rename.py @@ -142,29 +142,23 @@ def leave_Import( ) -> cst.Import: new_names = [] for import_alias in updated_node.names: + # We keep the original import_alias here in case it's used by other symbols. + # It will be removed later in RemoveImportsVisitor if it's unused. + new_names.append(import_alias) import_alias_name = import_alias.name import_alias_full_name = get_full_name_for_node(import_alias_name) if import_alias_full_name is None: raise Exception("Could not parse full name for ImportAlias.name node.") - if isinstance(import_alias_name, cst.Name) and self.old_name.startswith( - import_alias_full_name + "." - ): - # Might, be in use elsewhere in the code, so schedule a potential removal, and add another alias. - new_names.append(import_alias) - replacement_module = self.gen_replacement_module(import_alias_full_name) - self.bypass_import = True - if replacement_module != import_alias_name.value: - self.scheduled_removals.add(original_node) - new_names.append( - cst.ImportAlias(name=cst.Name(value=replacement_module)) - ) - elif isinstance( - import_alias_name, cst.Attribute + if isinstance( + import_alias_name, (cst.Name, cst.Attribute) ) and self.old_name.startswith(import_alias_full_name + "."): - # Same idea as above. - new_names.append(import_alias) replacement_module = self.gen_replacement_module(import_alias_full_name) + if not replacement_module: + # here import_alias_full_name isn't an exact match for old_name + # don't add an import here, it will be handled either in more + # specific import aliases or at the very end + continue self.bypass_import = True if replacement_module != import_alias_full_name: self.scheduled_removals.add(original_node) @@ -172,8 +166,6 @@ def leave_Import( self.gen_name_or_attr_node(replacement_module) ) new_names.append(cst.ImportAlias(name=new_name_node)) - else: - new_names.append(import_alias) return updated_node.with_changes(names=new_names) @@ -289,10 +281,14 @@ def leave_Attribute( if not inside_import_statement: self.scheduled_removals.add(original_node.value) if full_replacement_name == self.new_name: - return updated_node.with_changes( - value=cst.parse_expression(new_value), - attr=cst.Name(value=new_attr.rstrip(".")), - ) + value = cst.parse_expression(new_value) + if new_attr: + return updated_node.with_changes( + value=value, + attr=cst.Name(value=new_attr.rstrip(".")), + ) + assert isinstance(value, (cst.Name, cst.Attribute)) + return value return self.gen_name_or_attr_node(new_attr) @@ -329,8 +325,12 @@ def gen_replacement(self, original_name: str) -> str: if original_name == self.old_mod_or_obj: return self.new_mod_or_obj - elif original_name == ".".join([self.old_module, self.old_mod_or_obj]): - return self.new_name + elif original_name == self.old_name: + return ( + self.new_mod_or_obj + if (not self.bypass_import and self.new_mod_or_obj) + else self.new_name + ) elif original_name.endswith("." + self.old_mod_or_obj): return self.new_mod_or_obj else: diff --git a/libcst/codemod/commands/tests/test_rename.py b/libcst/codemod/commands/tests/test_rename.py index 2f897119..efcfbc6e 100644 --- a/libcst/codemod/commands/tests/test_rename.py +++ b/libcst/codemod/commands/tests/test_rename.py @@ -705,3 +705,39 @@ def test_rename_single_with_colon(self) -> None: old_name="a.b.qux", new_name="a:b.qux", ) + + def test_import_parent_module(self) -> None: + before = """ + import a + a.b.c(a.b.c.d) + """ + after = """ + from z import c + + c(c.d) + """ + self.assertCodemod(before, after, old_name="a.b.c", new_name="z.c") + + def test_import_parent_module_2(self) -> None: + before = """ + import a.b + a.b.c.d(a.b.c.d.x) + """ + after = """ + from z import c + + c(c.x) + """ + self.assertCodemod(before, after, old_name="a.b.c.d", new_name="z.c") + + def test_import_parent_module_3(self) -> None: + before = """ + import a + a.b.c(a.b.c.d) + """ + after = """ + import z.c + + z.c(z.c.d) + """ + self.assertCodemod(before, after, old_name="a.b.c", new_name="z.c:")