Skip to content

Commit

Permalink
Add codemod RenameCommand (#308)
Browse files Browse the repository at this point in the history
Add codemod RenameCommand which renames all instances of a local or imported object.
  • Loading branch information
josieesh authored Jun 25, 2020
1 parent 73608ad commit 030df06
Show file tree
Hide file tree
Showing 2 changed files with 952 additions and 0 deletions.
359 changes: 359 additions & 0 deletions libcst/codemod/commands/rename.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,359 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# pyre-strict
import argparse
from typing import Callable, Optional, Sequence, Set, Tuple, Union

import libcst as cst
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
from libcst.helpers import get_full_name_for_node
from libcst.metadata import QualifiedNameProvider


def leave_import_decorator(
method: Callable[..., Union[cst.Import, cst.ImportFrom]]
) -> Callable[..., Union[cst.Import, cst.ImportFrom]]:
# We want to record any 'as name' that is relevant but only after we leave the corresponding Import/ImportFrom node since
# we don't want the 'as name' to interfere with children 'Name' and 'Attribute' nodes.
def wrapper(
self: "RenameCommand",
original_node: Union[cst.Import, cst.ImportFrom],
updated_node: Union[cst.Import, cst.ImportFrom],
) -> Union[cst.Import, cst.ImportFrom]:
updated_node = method(self, original_node, updated_node)
if original_node != updated_node:
self.record_asname(original_node)
return updated_node

return wrapper


class RenameCommand(VisitorBasedCodemodCommand):
"""
Rename all instances of a local or imported object.
"""

DESCRIPTION: str = "Rename all instances of a local or imported object."

METADATA_DEPENDENCIES = (QualifiedNameProvider,)

@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--old_name",
dest="old_name",
required=True,
help="Full dotted name of object to rename. Eg: `foo.bar.baz`",
)

parser.add_argument(
"--new_name",
dest="new_name",
required=True,
help=(
"Full dotted name of replacement object. You may provide a single-colon-delimited name to specify how you want the new import to be structured."
+ "\nEg: `foo:bar.baz` will be translated to `from foo import bar`."
+ "\nIf no ':' character is provided, the import statement will default to `from foo.bar import baz` for a `new_name` value of `foo.bar.baz`"
+ " or simply replace the old import on the spot if the old import is an exact match."
),
)

def __init__(self, context: CodemodContext, old_name: str, new_name: str) -> None:
super().__init__(context)

new_module, has_colon, new_mod_or_obj = new_name.rpartition(":")
# Exit early if improperly formatted args.
if ":" in new_module:
raise ValueError("Error: `new_name` should contain at most one colon.")
if ":" in old_name:
raise ValueError("Error: `old_name` should not contain any colons.")

if not has_colon or not new_module:
new_module, _, new_mod_or_obj = new_name.rpartition(".")

self.new_name: str = new_name.replace(":", ".").strip(".")
self.new_module: str = new_module.replace(":", ".").strip(".")
self.new_mod_or_obj: str = new_mod_or_obj

# If `new_name` contains a single colon at the end, then we assume the user wants the import
# to be structured as 'import new_name'. So both self.new_mod_or_obj and self.old_mod_or_obj
# will be empty in this case.
if not self.new_mod_or_obj:
old_module = old_name
old_mod_or_obj = ""
else:
old_module, _, old_mod_or_obj = old_name.rpartition(".")

self.old_name: str = old_name
self.old_module: str = old_module
self.old_mod_or_obj: str = old_mod_or_obj

self.as_name: Optional[Tuple[str, str]] = None

# A set of nodes that have been renamed to help with the cleanup of now potentially unused
# imports, during import cleanup in `leave_Module`.
self.scheduled_removals: Set[cst.CSTNode] = set()
# If an import has been renamed while inside an `Import` or `ImportFrom` node, we want to flag
# this so that we do not end up with two of the same import.
self.bypass_import = False

def visit_Import(self, node: cst.Import) -> None:
for import_alias in node.names:
alias_name = get_full_name_for_node(import_alias.name)
if alias_name is not None:
if alias_name == self.old_name or alias_name.startswith(
self.old_name + "."
):
# If the import statement is exactly equivalent to the old name, or we are renaming a top-level module of the import,
# it will be taken care of in `leave_Name` or `leave_Attribute` when visiting the Name and Attribute children of this Import.
self.bypass_import = True

@leave_import_decorator
def leave_Import(
self, original_node: cst.Import, updated_node: cst.Import
) -> cst.Import:
new_names = []
for import_alias in updated_node.names:
import_alias_name = import_alias.name
import_alias_full_name = get_full_name_for_node(import_alias_name)
if import_alias_full_name is None:
raise Exception("Could not parse full name for ImportAlias.name node.")

if isinstance(import_alias_name, cst.Name) and self.old_name.startswith(
import_alias_full_name + "."
):
# Might, be in use elsewhere in the code, so schedule a potential removal, and add another alias.
new_names.append(import_alias)
self.scheduled_removals.add(original_node)
new_names.append(
cst.ImportAlias(
name=cst.Name(
value=self.gen_replacement_module(import_alias_full_name)
)
)
)
self.bypass_import = True
elif isinstance(
import_alias_name, cst.Attribute
) and self.old_name.startswith(import_alias_full_name + "."):
# Same idea as above.
new_names.append(import_alias)
self.scheduled_removals.add(original_node)
new_name_node: Union[
cst.Attribute, cst.Name
] = self.gen_name_or_attr_node(
self.gen_replacement_module(import_alias_full_name)
)
new_names.append(cst.ImportAlias(name=new_name_node))
self.bypass_import = True
else:
new_names.append(import_alias)

return updated_node.with_changes(names=new_names)

def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
module = node.module
if module is None:
return
imported_module_name = get_full_name_for_node(module)
if imported_module_name is None:
return
if imported_module_name == self.old_name or imported_module_name.startswith(
self.old_name + "."
):
# If the imported module is exactly equivalent to the old name or we are renaming a parent module of the current module,
# it will be taken care of in `leave_Name` or `leave_Attribute` when visiting the children of this ImportFrom.
self.bypass_import = True

@leave_import_decorator
def leave_ImportFrom(
self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
) -> cst.ImportFrom:
module = updated_node.module
if module is None:
return updated_node
imported_module_name = get_full_name_for_node(module)
names = original_node.names

if imported_module_name is None or not isinstance(names, Sequence):
return updated_node

else:
new_names = []
for import_alias in names:
alias_name = get_full_name_for_node(import_alias.name)
if alias_name is not None:
qual_name = f"{imported_module_name}.{alias_name}"
if self.old_name == qual_name:

replacement_module = self.gen_replacement_module(
imported_module_name
)
replacement_obj = self.gen_replacement(alias_name)
if not replacement_obj:
# The user has requested an `import` statement rather than an `from ... import`.
# This will be taken care of in `leave_Module`, in the meantime, schedule for potential removal.
new_names.append(import_alias)
self.scheduled_removals.add(original_node)
continue

new_import_alias_name: Union[
cst.Attribute, cst.Name
] = self.gen_name_or_attr_node(replacement_obj)
# Rename on the spot only if this is the only imported name under the module.
if len(names) == 1:
self.bypass_import = True
return updated_node.with_changes(
module=cst.parse_expression(replacement_module),
names=(cst.ImportAlias(name=new_import_alias_name),),
)
# Or if the module name is to stay the same.
elif replacement_module == imported_module_name:
self.bypass_import = True
new_names.append(
cst.ImportAlias(name=new_import_alias_name)
)
else:
if self.old_name.startswith(qual_name + "."):
# This import might be in use elsewhere in the code, so schedule a potential removal.
self.scheduled_removals.add(original_node)
new_names.append(import_alias)

return updated_node.with_changes(names=new_names)
return updated_node

def leave_Name(
self, original_node: cst.Name, updated_node: cst.Name
) -> Union[cst.Attribute, cst.Name]:
full_name_for_node: str = original_node.value
full_replacement_name = self.gen_replacement(full_name_for_node)

# If a node has no associated QualifiedName, we are still inside an import statement.
inside_import_statement: bool = not self.get_metadata(
QualifiedNameProvider, original_node, set()
)
if QualifiedNameProvider.has_name(self, original_node, self.old_name) or (
inside_import_statement and full_replacement_name == self.new_name
):
if not full_replacement_name:
full_replacement_name = self.new_name
if not inside_import_statement:
self.scheduled_removals.add(original_node)
return self.gen_name_or_attr_node(full_replacement_name)

return updated_node

def leave_Attribute(
self, original_node: cst.Attribute, updated_node: cst.Attribute
) -> Union[cst.Name, cst.Attribute]:
full_name_for_node = get_full_name_for_node(original_node)
if full_name_for_node is None:
raise Exception("Could not parse full name for Attribute node.")
full_replacement_name = self.gen_replacement(full_name_for_node)

# If a node has no associated QualifiedName, we are still inside an import statement.
inside_import_statement: bool = not self.get_metadata(
QualifiedNameProvider, original_node, set()
)
if QualifiedNameProvider.has_name(self, original_node, self.old_name,) or (
inside_import_statement and full_replacement_name == self.new_name
):
new_value, new_attr = self.new_module, self.new_mod_or_obj
if not inside_import_statement:
self.scheduled_removals.add(original_node.value)
if full_replacement_name == self.new_name:
return updated_node.with_changes(
value=cst.parse_expression(new_value),
attr=cst.Name(value=new_attr.rstrip(".")),
)

return self.gen_name_or_attr_node(new_attr)

return updated_node

def leave_Module(
self, original_node: cst.Module, updated_node: cst.Module
) -> cst.Module:
for removal_node in self.scheduled_removals:
RemoveImportsVisitor.remove_unused_import_by_node(
self.context, removal_node
)
# If bypass_import is False, we know that no import statements were directly renamed, and the fact
# that we have any `self.scheduled_removals` tells us we encountered a matching `old_name` in the code.
if not self.bypass_import and self.scheduled_removals:
if self.new_module:
new_obj: Optional[str] = self.new_mod_or_obj.split(".")[
0
] if self.new_mod_or_obj else None
AddImportsVisitor.add_needed_import(
self.context, module=self.new_module, obj=new_obj
)
return updated_node

def gen_replacement(self, original_name: str) -> str:
module_as_name = self.as_name
if module_as_name is not None:
if original_name == module_as_name[0]:
original_name = module_as_name[1]
elif original_name.startswith(module_as_name[0] + "."):
original_name = original_name.replace(
module_as_name[0] + ".", module_as_name[1] + ".", 1
)

if original_name == self.old_mod_or_obj:
return self.new_mod_or_obj
elif original_name == ".".join([self.old_module, self.old_mod_or_obj]):
return self.new_name
elif original_name.endswith("." + self.old_mod_or_obj):
return self.new_mod_or_obj
else:
return self.gen_replacement_module(original_name)

def gen_replacement_module(self, original_module: str) -> str:
return self.new_module if original_module == self.old_module else ""

def gen_name_or_attr_node(
self, dotted_expression: str
) -> Union[cst.Attribute, cst.Name]:
name_or_attr_node: cst.BaseExpression = cst.parse_expression(dotted_expression)
if not isinstance(name_or_attr_node, (cst.Name, cst.Attribute)):
raise Exception(
"`parse_expression()` on dotted path returned non-Attribute-or-Name."
)
return name_or_attr_node

def record_asname(self, original_node: Union[cst.Import, cst.ImportFrom]) -> None:
# Record the import's `as` name if it has one, and set the attribute mapping.
names = original_node.names
if not isinstance(names, Sequence):
return
for import_alias in names:
alias_name = get_full_name_for_node(import_alias.name)
if isinstance(original_node, cst.ImportFrom):
module = original_node.module
if module is None:
return
module_name = get_full_name_for_node(module)
if module_name is None:
return
qual_name = f"{module_name}.{alias_name}"
else:
qual_name = alias_name
if qual_name is not None and alias_name is not None:
if qual_name == self.old_name or self.old_name.startswith(
qual_name + "."
):
as_name_optional = import_alias.asname
as_name_node = (
as_name_optional.name if as_name_optional is not None else None
)
if as_name_node is not None and isinstance(
as_name_node, (cst.Name, cst.Attribute)
):
full_as_name = get_full_name_for_node(as_name_node)
if full_as_name is not None:
self.as_name = (full_as_name, alias_name)
Loading

0 comments on commit 030df06

Please sign in to comment.