Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rename: handle imports via a parent module #1251

Merged
merged 2 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 56 additions & 32 deletions libcst/codemod/commands/rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,38 @@ def __init__(self, context: CodemodContext, old_name: str, new_name: str) -> Non
self.old_module: str = old_module
self.old_mod_or_obj: str = old_mod_or_obj

self.as_name: Optional[Tuple[str, str]] = None

# A set of nodes that have been renamed to help with the cleanup of now potentially unused
# imports, during import cleanup in `leave_Module`.
self.scheduled_removals: Set[cst.CSTNode] = set()
# If an import has been renamed while inside an `Import` or `ImportFrom` node, we want to flag
# this so that we do not end up with two of the same import.
self.bypass_import = False
@property
def as_name(self) -> Optional[Tuple[str, str]]:
if "as_name" not in self.context.scratch:
self.context.scratch["as_name"] = None
return self.context.scratch["as_name"]

@as_name.setter
def as_name(self, value: Optional[Tuple[str, str]]) -> None:
self.context.scratch["as_name"] = value

@property
def scheduled_removals(self) -> Set[cst.CSTNode]:
"""A set of nodes that have been renamed to help with the cleanup of now potentially unused
imports, during import cleanup in `leave_Module`."""
if "scheduled_removals" not in self.context.scratch:
self.context.scratch["scheduled_removals"] = set()
return self.context.scratch["scheduled_removals"]

@scheduled_removals.setter
def scheduled_removals(self, value: Set[cst.CSTNode]) -> None:
self.context.scratch["scheduled_removals"] = value

@property
def bypass_import(self) -> bool:
"""A flag to indicate that an import has been renamed while inside an `Import` or `ImportFrom` node."""
if "bypass_import" not in self.context.scratch:
self.context.scratch["bypass_import"] = False
return self.context.scratch["bypass_import"]

@bypass_import.setter
def bypass_import(self, value: bool) -> None:
self.context.scratch["bypass_import"] = value

def visit_Import(self, node: cst.Import) -> None:
for import_alias in node.names:
Expand All @@ -118,38 +142,30 @@ def leave_Import(
) -> cst.Import:
new_names = []
for import_alias in updated_node.names:
# We keep the original import_alias here in case it's used by other symbols.
# It will be removed later in RemoveImportsVisitor if it's unused.
new_names.append(import_alias)
import_alias_name = import_alias.name
import_alias_full_name = get_full_name_for_node(import_alias_name)
if import_alias_full_name is None:
raise Exception("Could not parse full name for ImportAlias.name node.")

if isinstance(import_alias_name, cst.Name) and self.old_name.startswith(
import_alias_full_name + "."
):
# Might, be in use elsewhere in the code, so schedule a potential removal, and add another alias.
new_names.append(import_alias)
replacement_module = self.gen_replacement_module(import_alias_full_name)
self.bypass_import = True
if replacement_module != import_alias_name.value:
self.scheduled_removals.add(original_node)
new_names.append(
cst.ImportAlias(name=cst.Name(value=replacement_module))
)
elif isinstance(
import_alias_name, cst.Attribute
if isinstance(
import_alias_name, (cst.Name, cst.Attribute)
) and self.old_name.startswith(import_alias_full_name + "."):
# Same idea as above.
new_names.append(import_alias)
replacement_module = self.gen_replacement_module(import_alias_full_name)
if not replacement_module:
# here import_alias_full_name isn't an exact match for old_name
# don't add an import here, it will be handled either in more
# specific import aliases or at the very end
continue
self.bypass_import = True
if replacement_module != import_alias_full_name:
self.scheduled_removals.add(original_node)
new_name_node: Union[cst.Attribute, cst.Name] = (
self.gen_name_or_attr_node(replacement_module)
)
new_names.append(cst.ImportAlias(name=new_name_node))
else:
new_names.append(import_alias)

return updated_node.with_changes(names=new_names)

Expand Down Expand Up @@ -265,10 +281,14 @@ def leave_Attribute(
if not inside_import_statement:
self.scheduled_removals.add(original_node.value)
if full_replacement_name == self.new_name:
return updated_node.with_changes(
value=cst.parse_expression(new_value),
attr=cst.Name(value=new_attr.rstrip(".")),
)
value = cst.parse_expression(new_value)
if new_attr:
return updated_node.with_changes(
value=value,
attr=cst.Name(value=new_attr.rstrip(".")),
)
assert isinstance(value, (cst.Name, cst.Attribute))
return value

return self.gen_name_or_attr_node(new_attr)

Expand Down Expand Up @@ -305,8 +325,12 @@ def gen_replacement(self, original_name: str) -> str:

if original_name == self.old_mod_or_obj:
return self.new_mod_or_obj
elif original_name == ".".join([self.old_module, self.old_mod_or_obj]):
return self.new_name
elif original_name == self.old_name:
return (
self.new_mod_or_obj
if (not self.bypass_import and self.new_mod_or_obj)
else self.new_name
)
elif original_name.endswith("." + self.old_mod_or_obj):
return self.new_mod_or_obj
else:
Expand Down
36 changes: 36 additions & 0 deletions libcst/codemod/commands/tests/test_rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,3 +705,39 @@ def test_rename_single_with_colon(self) -> None:
old_name="a.b.qux",
new_name="a:b.qux",
)

def test_import_parent_module(self) -> None:
before = """
import a
a.b.c(a.b.c.d)
"""
after = """
from z import c

c(c.d)
"""
self.assertCodemod(before, after, old_name="a.b.c", new_name="z.c")

def test_import_parent_module_2(self) -> None:
before = """
import a.b
a.b.c.d(a.b.c.d.x)
"""
after = """
from z import c

c(c.x)
"""
self.assertCodemod(before, after, old_name="a.b.c.d", new_name="z.c")

def test_import_parent_module_3(self) -> None:
before = """
import a
a.b.c(a.b.c.d)
"""
after = """
import z.c

z.c(z.c.d)
"""
self.assertCodemod(before, after, old_name="a.b.c", new_name="z.c:")
Loading