Skip to content

Commit

Permalink
rename: store state in scratch
Browse files Browse the repository at this point in the history
This PR changes RenameCodemod to store its per-module state in `self.context.scratch` which gets properly reset between files.
  • Loading branch information
zsol committed Nov 28, 2024
1 parent 08da127 commit 302f329
Showing 1 changed file with 32 additions and 8 deletions.
40 changes: 32 additions & 8 deletions libcst/codemod/commands/rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,38 @@ def __init__(self, context: CodemodContext, old_name: str, new_name: str) -> Non
self.old_module: str = old_module
self.old_mod_or_obj: str = old_mod_or_obj

self.as_name: Optional[Tuple[str, str]] = None

# A set of nodes that have been renamed to help with the cleanup of now potentially unused
# imports, during import cleanup in `leave_Module`.
self.scheduled_removals: Set[cst.CSTNode] = set()
# If an import has been renamed while inside an `Import` or `ImportFrom` node, we want to flag
# this so that we do not end up with two of the same import.
self.bypass_import = False
@property
def as_name(self) -> Optional[Tuple[str, str]]:
if "as_name" not in self.context.scratch:
self.context.scratch["as_name"] = None
return self.context.scratch["as_name"]

@as_name.setter
def as_name(self, value: Optional[Tuple[str, str]]) -> None:
self.context.scratch["as_name"] = value

@property
def scheduled_removals(self) -> Set[cst.CSTNode]:
"""A set of nodes that have been renamed to help with the cleanup of now potentially unused
imports, during import cleanup in `leave_Module`."""
if "scheduled_removals" not in self.context.scratch:
self.context.scratch["scheduled_removals"] = set()
return self.context.scratch["scheduled_removals"]

@scheduled_removals.setter
def scheduled_removals(self, value: Set[cst.CSTNode]) -> None:
self.context.scratch["scheduled_removals"] = value

@property
def bypass_import(self) -> bool:
"""A flag to indicate that an import has been renamed while inside an `Import` or `ImportFrom` node."""
if "bypass_import" not in self.context.scratch:
self.context.scratch["bypass_import"] = False
return self.context.scratch["bypass_import"]

@bypass_import.setter
def bypass_import(self, value: bool) -> None:
self.context.scratch["bypass_import"] = value

def visit_Import(self, node: cst.Import) -> None:
for import_alias in node.names:
Expand Down

0 comments on commit 302f329

Please sign in to comment.