Skip to content

Commit

Permalink
rename: handle imports via a parent module (#1251)
Browse files Browse the repository at this point in the history
When requesting a rename for `a.b.c`, we want to act on `import a` when it's used to access `a.b.c`
  • Loading branch information
zsol authored Nov 28, 2024
1 parent 6fdca74 commit 28e0f39
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 24 deletions.
48 changes: 24 additions & 24 deletions libcst/codemod/commands/rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,38 +142,30 @@ def leave_Import(
) -> cst.Import:
new_names = []
for import_alias in updated_node.names:
# We keep the original import_alias here in case it's used by other symbols.
# It will be removed later in RemoveImportsVisitor if it's unused.
new_names.append(import_alias)
import_alias_name = import_alias.name
import_alias_full_name = get_full_name_for_node(import_alias_name)
if import_alias_full_name is None:
raise Exception("Could not parse full name for ImportAlias.name node.")

if isinstance(import_alias_name, cst.Name) and self.old_name.startswith(
import_alias_full_name + "."
):
# Might, be in use elsewhere in the code, so schedule a potential removal, and add another alias.
new_names.append(import_alias)
replacement_module = self.gen_replacement_module(import_alias_full_name)
self.bypass_import = True
if replacement_module != import_alias_name.value:
self.scheduled_removals.add(original_node)
new_names.append(
cst.ImportAlias(name=cst.Name(value=replacement_module))
)
elif isinstance(
import_alias_name, cst.Attribute
if isinstance(
import_alias_name, (cst.Name, cst.Attribute)
) and self.old_name.startswith(import_alias_full_name + "."):
# Same idea as above.
new_names.append(import_alias)
replacement_module = self.gen_replacement_module(import_alias_full_name)
if not replacement_module:
# here import_alias_full_name isn't an exact match for old_name
# don't add an import here, it will be handled either in more
# specific import aliases or at the very end
continue
self.bypass_import = True
if replacement_module != import_alias_full_name:
self.scheduled_removals.add(original_node)
new_name_node: Union[cst.Attribute, cst.Name] = (
self.gen_name_or_attr_node(replacement_module)
)
new_names.append(cst.ImportAlias(name=new_name_node))
else:
new_names.append(import_alias)

return updated_node.with_changes(names=new_names)

Expand Down Expand Up @@ -289,10 +281,14 @@ def leave_Attribute(
if not inside_import_statement:
self.scheduled_removals.add(original_node.value)
if full_replacement_name == self.new_name:
return updated_node.with_changes(
value=cst.parse_expression(new_value),
attr=cst.Name(value=new_attr.rstrip(".")),
)
value = cst.parse_expression(new_value)
if new_attr:
return updated_node.with_changes(
value=value,
attr=cst.Name(value=new_attr.rstrip(".")),
)
assert isinstance(value, (cst.Name, cst.Attribute))
return value

return self.gen_name_or_attr_node(new_attr)

Expand Down Expand Up @@ -329,8 +325,12 @@ def gen_replacement(self, original_name: str) -> str:

if original_name == self.old_mod_or_obj:
return self.new_mod_or_obj
elif original_name == ".".join([self.old_module, self.old_mod_or_obj]):
return self.new_name
elif original_name == self.old_name:
return (
self.new_mod_or_obj
if (not self.bypass_import and self.new_mod_or_obj)
else self.new_name
)
elif original_name.endswith("." + self.old_mod_or_obj):
return self.new_mod_or_obj
else:
Expand Down
36 changes: 36 additions & 0 deletions libcst/codemod/commands/tests/test_rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,3 +705,39 @@ def test_rename_single_with_colon(self) -> None:
old_name="a.b.qux",
new_name="a:b.qux",
)

def test_import_parent_module(self) -> None:
before = """
import a
a.b.c(a.b.c.d)
"""
after = """
from z import c
c(c.d)
"""
self.assertCodemod(before, after, old_name="a.b.c", new_name="z.c")

def test_import_parent_module_2(self) -> None:
before = """
import a.b
a.b.c.d(a.b.c.d.x)
"""
after = """
from z import c
c(c.x)
"""
self.assertCodemod(before, after, old_name="a.b.c.d", new_name="z.c")

def test_import_parent_module_3(self) -> None:
before = """
import a
a.b.c(a.b.c.d)
"""
after = """
import z.c
z.c(z.c.d)
"""
self.assertCodemod(before, after, old_name="a.b.c", new_name="z.c:")

0 comments on commit 28e0f39

Please sign in to comment.