diff --git a/libcst/codemod/commands/rename.py b/libcst/codemod/commands/rename.py index 290f1ac1..9d710cca 100644 --- a/libcst/codemod/commands/rename.py +++ b/libcst/codemod/commands/rename.py @@ -92,14 +92,38 @@ def __init__(self, context: CodemodContext, old_name: str, new_name: str) -> Non self.old_module: str = old_module self.old_mod_or_obj: str = old_mod_or_obj - self.as_name: Optional[Tuple[str, str]] = None - - # A set of nodes that have been renamed to help with the cleanup of now potentially unused - # imports, during import cleanup in `leave_Module`. - self.scheduled_removals: Set[cst.CSTNode] = set() - # If an import has been renamed while inside an `Import` or `ImportFrom` node, we want to flag - # this so that we do not end up with two of the same import. - self.bypass_import = False + @property + def as_name(self) -> Optional[Tuple[str, str]]: + if "as_name" not in self.context.scratch: + self.context.scratch["as_name"] = None + return self.context.scratch["as_name"] + + @as_name.setter + def as_name(self, value: Optional[Tuple[str, str]]) -> None: + self.context.scratch["as_name"] = value + + @property + def scheduled_removals(self) -> Set[cst.CSTNode]: + """A set of nodes that have been renamed to help with the cleanup of now potentially unused + imports, during import cleanup in `leave_Module`.""" + if "scheduled_removals" not in self.context.scratch: + self.context.scratch["scheduled_removals"] = set() + return self.context.scratch["scheduled_removals"] + + @scheduled_removals.setter + def scheduled_removals(self, value: Set[cst.CSTNode]) -> None: + self.context.scratch["scheduled_removals"] = value + + @property + def bypass_import(self) -> bool: + """A flag to indicate that an import has been renamed while inside an `Import` or `ImportFrom` node.""" + if "bypass_import" not in self.context.scratch: + self.context.scratch["bypass_import"] = False + return self.context.scratch["bypass_import"] + + @bypass_import.setter + def bypass_import(self, value: bool) -> None: + self.context.scratch["bypass_import"] = value def visit_Import(self, node: cst.Import) -> None: for import_alias in node.names: