From 28e0f397b278f061f3c6cef9bf80a0422b7b447e Mon Sep 17 00:00:00 2001 From: Zsolt Dollenstein Date: Thu, 28 Nov 2024 20:02:23 +0000 Subject: [PATCH] rename: handle imports via a parent module (#1251) When requesting a rename for `a.b.c`, we want to act on `import a` when it's used to access `a.b.c` --- libcst/codemod/commands/rename.py | 48 ++++++++++---------- libcst/codemod/commands/tests/test_rename.py | 36 +++++++++++++++ 2 files changed, 60 insertions(+), 24 deletions(-) diff --git a/libcst/codemod/commands/rename.py b/libcst/codemod/commands/rename.py index 9d710cca2..ae7138c8f 100644 --- a/libcst/codemod/commands/rename.py +++ b/libcst/codemod/commands/rename.py @@ -142,29 +142,23 @@ def leave_Import( ) -> cst.Import: new_names = [] for import_alias in updated_node.names: + # We keep the original import_alias here in case it's used by other symbols. + # It will be removed later in RemoveImportsVisitor if it's unused. + new_names.append(import_alias) import_alias_name = import_alias.name import_alias_full_name = get_full_name_for_node(import_alias_name) if import_alias_full_name is None: raise Exception("Could not parse full name for ImportAlias.name node.") - if isinstance(import_alias_name, cst.Name) and self.old_name.startswith( - import_alias_full_name + "." - ): - # Might, be in use elsewhere in the code, so schedule a potential removal, and add another alias. - new_names.append(import_alias) - replacement_module = self.gen_replacement_module(import_alias_full_name) - self.bypass_import = True - if replacement_module != import_alias_name.value: - self.scheduled_removals.add(original_node) - new_names.append( - cst.ImportAlias(name=cst.Name(value=replacement_module)) - ) - elif isinstance( - import_alias_name, cst.Attribute + if isinstance( + import_alias_name, (cst.Name, cst.Attribute) ) and self.old_name.startswith(import_alias_full_name + "."): - # Same idea as above. - new_names.append(import_alias) replacement_module = self.gen_replacement_module(import_alias_full_name) + if not replacement_module: + # here import_alias_full_name isn't an exact match for old_name + # don't add an import here, it will be handled either in more + # specific import aliases or at the very end + continue self.bypass_import = True if replacement_module != import_alias_full_name: self.scheduled_removals.add(original_node) @@ -172,8 +166,6 @@ def leave_Import( self.gen_name_or_attr_node(replacement_module) ) new_names.append(cst.ImportAlias(name=new_name_node)) - else: - new_names.append(import_alias) return updated_node.with_changes(names=new_names) @@ -289,10 +281,14 @@ def leave_Attribute( if not inside_import_statement: self.scheduled_removals.add(original_node.value) if full_replacement_name == self.new_name: - return updated_node.with_changes( - value=cst.parse_expression(new_value), - attr=cst.Name(value=new_attr.rstrip(".")), - ) + value = cst.parse_expression(new_value) + if new_attr: + return updated_node.with_changes( + value=value, + attr=cst.Name(value=new_attr.rstrip(".")), + ) + assert isinstance(value, (cst.Name, cst.Attribute)) + return value return self.gen_name_or_attr_node(new_attr) @@ -329,8 +325,12 @@ def gen_replacement(self, original_name: str) -> str: if original_name == self.old_mod_or_obj: return self.new_mod_or_obj - elif original_name == ".".join([self.old_module, self.old_mod_or_obj]): - return self.new_name + elif original_name == self.old_name: + return ( + self.new_mod_or_obj + if (not self.bypass_import and self.new_mod_or_obj) + else self.new_name + ) elif original_name.endswith("." + self.old_mod_or_obj): return self.new_mod_or_obj else: diff --git a/libcst/codemod/commands/tests/test_rename.py b/libcst/codemod/commands/tests/test_rename.py index 2f8971192..efcfbc6ef 100644 --- a/libcst/codemod/commands/tests/test_rename.py +++ b/libcst/codemod/commands/tests/test_rename.py @@ -705,3 +705,39 @@ def test_rename_single_with_colon(self) -> None: old_name="a.b.qux", new_name="a:b.qux", ) + + def test_import_parent_module(self) -> None: + before = """ + import a + a.b.c(a.b.c.d) + """ + after = """ + from z import c + + c(c.d) + """ + self.assertCodemod(before, after, old_name="a.b.c", new_name="z.c") + + def test_import_parent_module_2(self) -> None: + before = """ + import a.b + a.b.c.d(a.b.c.d.x) + """ + after = """ + from z import c + + c(c.x) + """ + self.assertCodemod(before, after, old_name="a.b.c.d", new_name="z.c") + + def test_import_parent_module_3(self) -> None: + before = """ + import a + a.b.c(a.b.c.d) + """ + after = """ + import z.c + + z.c(z.c.d) + """ + self.assertCodemod(before, after, old_name="a.b.c", new_name="z.c:")