diff --git a/libcst/codemod/commands/rename.py b/libcst/codemod/commands/rename.py new file mode 100644 index 000000000..ce50cc123 --- /dev/null +++ b/libcst/codemod/commands/rename.py @@ -0,0 +1,359 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# +# pyre-strict +import argparse +from typing import Callable, Optional, Sequence, Set, Tuple, Union + +import libcst as cst +from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand +from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor +from libcst.helpers import get_full_name_for_node +from libcst.metadata import QualifiedNameProvider + + +def leave_import_decorator( + method: Callable[..., Union[cst.Import, cst.ImportFrom]] +) -> Callable[..., Union[cst.Import, cst.ImportFrom]]: + # We want to record any 'as name' that is relevant but only after we leave the corresponding Import/ImportFrom node since + # we don't want the 'as name' to interfere with children 'Name' and 'Attribute' nodes. + def wrapper( + self: "RenameCommand", + original_node: Union[cst.Import, cst.ImportFrom], + updated_node: Union[cst.Import, cst.ImportFrom], + ) -> Union[cst.Import, cst.ImportFrom]: + updated_node = method(self, original_node, updated_node) + if original_node != updated_node: + self.record_asname(original_node) + return updated_node + + return wrapper + + +class RenameCommand(VisitorBasedCodemodCommand): + """ + Rename all instances of a local or imported object. + """ + + DESCRIPTION: str = "Rename all instances of a local or imported object." + + METADATA_DEPENDENCIES = (QualifiedNameProvider,) + + @staticmethod + def add_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--old_name", + dest="old_name", + required=True, + help="Full dotted name of object to rename. Eg: `foo.bar.baz`", + ) + + parser.add_argument( + "--new_name", + dest="new_name", + required=True, + help=( + "Full dotted name of replacement object. You may provide a single-colon-delimited name to specify how you want the new import to be structured." + + "\nEg: `foo:bar.baz` will be translated to `from foo import bar`." + + "\nIf no ':' character is provided, the import statement will default to `from foo.bar import baz` for a `new_name` value of `foo.bar.baz`" + + " or simply replace the old import on the spot if the old import is an exact match." + ), + ) + + def __init__(self, context: CodemodContext, old_name: str, new_name: str) -> None: + super().__init__(context) + + new_module, has_colon, new_mod_or_obj = new_name.rpartition(":") + # Exit early if improperly formatted args. + if ":" in new_module: + raise ValueError("Error: `new_name` should contain at most one colon.") + if ":" in old_name: + raise ValueError("Error: `old_name` should not contain any colons.") + + if not has_colon or not new_module: + new_module, _, new_mod_or_obj = new_name.rpartition(".") + + self.new_name: str = new_name.replace(":", ".").strip(".") + self.new_module: str = new_module.replace(":", ".").strip(".") + self.new_mod_or_obj: str = new_mod_or_obj + + # If `new_name` contains a single colon at the end, then we assume the user wants the import + # to be structured as 'import new_name'. So both self.new_mod_or_obj and self.old_mod_or_obj + # will be empty in this case. + if not self.new_mod_or_obj: + old_module = old_name + old_mod_or_obj = "" + else: + old_module, _, old_mod_or_obj = old_name.rpartition(".") + + self.old_name: str = old_name + self.old_module: str = old_module + self.old_mod_or_obj: str = old_mod_or_obj + + self.as_name: Optional[Tuple[str, str]] = None + + # A set of nodes that have been renamed to help with the cleanup of now potentially unused + # imports, during import cleanup in `leave_Module`. + self.scheduled_removals: Set[cst.CSTNode] = set() + # If an import has been renamed while inside an `Import` or `ImportFrom` node, we want to flag + # this so that we do not end up with two of the same import. + self.bypass_import = False + + def visit_Import(self, node: cst.Import) -> None: + for import_alias in node.names: + alias_name = get_full_name_for_node(import_alias.name) + if alias_name is not None: + if alias_name == self.old_name or alias_name.startswith( + self.old_name + "." + ): + # If the import statement is exactly equivalent to the old name, or we are renaming a top-level module of the import, + # it will be taken care of in `leave_Name` or `leave_Attribute` when visiting the Name and Attribute children of this Import. + self.bypass_import = True + + @leave_import_decorator + def leave_Import( + self, original_node: cst.Import, updated_node: cst.Import + ) -> cst.Import: + new_names = [] + for import_alias in updated_node.names: + import_alias_name = import_alias.name + import_alias_full_name = get_full_name_for_node(import_alias_name) + if import_alias_full_name is None: + raise Exception("Could not parse full name for ImportAlias.name node.") + + if isinstance(import_alias_name, cst.Name) and self.old_name.startswith( + import_alias_full_name + "." + ): + # Might, be in use elsewhere in the code, so schedule a potential removal, and add another alias. + new_names.append(import_alias) + self.scheduled_removals.add(original_node) + new_names.append( + cst.ImportAlias( + name=cst.Name( + value=self.gen_replacement_module(import_alias_full_name) + ) + ) + ) + self.bypass_import = True + elif isinstance( + import_alias_name, cst.Attribute + ) and self.old_name.startswith(import_alias_full_name + "."): + # Same idea as above. + new_names.append(import_alias) + self.scheduled_removals.add(original_node) + new_name_node: Union[ + cst.Attribute, cst.Name + ] = self.gen_name_or_attr_node( + self.gen_replacement_module(import_alias_full_name) + ) + new_names.append(cst.ImportAlias(name=new_name_node)) + self.bypass_import = True + else: + new_names.append(import_alias) + + return updated_node.with_changes(names=new_names) + + def visit_ImportFrom(self, node: cst.ImportFrom) -> None: + module = node.module + if module is None: + return + imported_module_name = get_full_name_for_node(module) + if imported_module_name is None: + return + if imported_module_name == self.old_name or imported_module_name.startswith( + self.old_name + "." + ): + # If the imported module is exactly equivalent to the old name or we are renaming a parent module of the current module, + # it will be taken care of in `leave_Name` or `leave_Attribute` when visiting the children of this ImportFrom. + self.bypass_import = True + + @leave_import_decorator + def leave_ImportFrom( + self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom + ) -> cst.ImportFrom: + module = updated_node.module + if module is None: + return updated_node + imported_module_name = get_full_name_for_node(module) + names = original_node.names + + if imported_module_name is None or not isinstance(names, Sequence): + return updated_node + + else: + new_names = [] + for import_alias in names: + alias_name = get_full_name_for_node(import_alias.name) + if alias_name is not None: + qual_name = f"{imported_module_name}.{alias_name}" + if self.old_name == qual_name: + + replacement_module = self.gen_replacement_module( + imported_module_name + ) + replacement_obj = self.gen_replacement(alias_name) + if not replacement_obj: + # The user has requested an `import` statement rather than an `from ... import`. + # This will be taken care of in `leave_Module`, in the meantime, schedule for potential removal. + new_names.append(import_alias) + self.scheduled_removals.add(original_node) + continue + + new_import_alias_name: Union[ + cst.Attribute, cst.Name + ] = self.gen_name_or_attr_node(replacement_obj) + # Rename on the spot only if this is the only imported name under the module. + if len(names) == 1: + self.bypass_import = True + return updated_node.with_changes( + module=cst.parse_expression(replacement_module), + names=(cst.ImportAlias(name=new_import_alias_name),), + ) + # Or if the module name is to stay the same. + elif replacement_module == imported_module_name: + self.bypass_import = True + new_names.append( + cst.ImportAlias(name=new_import_alias_name) + ) + else: + if self.old_name.startswith(qual_name + "."): + # This import might be in use elsewhere in the code, so schedule a potential removal. + self.scheduled_removals.add(original_node) + new_names.append(import_alias) + + return updated_node.with_changes(names=new_names) + return updated_node + + def leave_Name( + self, original_node: cst.Name, updated_node: cst.Name + ) -> Union[cst.Attribute, cst.Name]: + full_name_for_node: str = original_node.value + full_replacement_name = self.gen_replacement(full_name_for_node) + + # If a node has no associated QualifiedName, we are still inside an import statement. + inside_import_statement: bool = not self.get_metadata( + QualifiedNameProvider, original_node, set() + ) + if QualifiedNameProvider.has_name(self, original_node, self.old_name) or ( + inside_import_statement and full_replacement_name == self.new_name + ): + if not full_replacement_name: + full_replacement_name = self.new_name + if not inside_import_statement: + self.scheduled_removals.add(original_node) + return self.gen_name_or_attr_node(full_replacement_name) + + return updated_node + + def leave_Attribute( + self, original_node: cst.Attribute, updated_node: cst.Attribute + ) -> Union[cst.Name, cst.Attribute]: + full_name_for_node = get_full_name_for_node(original_node) + if full_name_for_node is None: + raise Exception("Could not parse full name for Attribute node.") + full_replacement_name = self.gen_replacement(full_name_for_node) + + # If a node has no associated QualifiedName, we are still inside an import statement. + inside_import_statement: bool = not self.get_metadata( + QualifiedNameProvider, original_node, set() + ) + if QualifiedNameProvider.has_name(self, original_node, self.old_name,) or ( + inside_import_statement and full_replacement_name == self.new_name + ): + new_value, new_attr = self.new_module, self.new_mod_or_obj + if not inside_import_statement: + self.scheduled_removals.add(original_node.value) + if full_replacement_name == self.new_name: + return updated_node.with_changes( + value=cst.parse_expression(new_value), + attr=cst.Name(value=new_attr.rstrip(".")), + ) + + return self.gen_name_or_attr_node(new_attr) + + return updated_node + + def leave_Module( + self, original_node: cst.Module, updated_node: cst.Module + ) -> cst.Module: + for removal_node in self.scheduled_removals: + RemoveImportsVisitor.remove_unused_import_by_node( + self.context, removal_node + ) + # If bypass_import is False, we know that no import statements were directly renamed, and the fact + # that we have any `self.scheduled_removals` tells us we encountered a matching `old_name` in the code. + if not self.bypass_import and self.scheduled_removals: + if self.new_module: + new_obj: Optional[str] = self.new_mod_or_obj.split(".")[ + 0 + ] if self.new_mod_or_obj else None + AddImportsVisitor.add_needed_import( + self.context, module=self.new_module, obj=new_obj + ) + return updated_node + + def gen_replacement(self, original_name: str) -> str: + module_as_name = self.as_name + if module_as_name is not None: + if original_name == module_as_name[0]: + original_name = module_as_name[1] + elif original_name.startswith(module_as_name[0] + "."): + original_name = original_name.replace( + module_as_name[0] + ".", module_as_name[1] + ".", 1 + ) + + if original_name == self.old_mod_or_obj: + return self.new_mod_or_obj + elif original_name == ".".join([self.old_module, self.old_mod_or_obj]): + return self.new_name + elif original_name.endswith("." + self.old_mod_or_obj): + return self.new_mod_or_obj + else: + return self.gen_replacement_module(original_name) + + def gen_replacement_module(self, original_module: str) -> str: + return self.new_module if original_module == self.old_module else "" + + def gen_name_or_attr_node( + self, dotted_expression: str + ) -> Union[cst.Attribute, cst.Name]: + name_or_attr_node: cst.BaseExpression = cst.parse_expression(dotted_expression) + if not isinstance(name_or_attr_node, (cst.Name, cst.Attribute)): + raise Exception( + "`parse_expression()` on dotted path returned non-Attribute-or-Name." + ) + return name_or_attr_node + + def record_asname(self, original_node: Union[cst.Import, cst.ImportFrom]) -> None: + # Record the import's `as` name if it has one, and set the attribute mapping. + names = original_node.names + if not isinstance(names, Sequence): + return + for import_alias in names: + alias_name = get_full_name_for_node(import_alias.name) + if isinstance(original_node, cst.ImportFrom): + module = original_node.module + if module is None: + return + module_name = get_full_name_for_node(module) + if module_name is None: + return + qual_name = f"{module_name}.{alias_name}" + else: + qual_name = alias_name + if qual_name is not None and alias_name is not None: + if qual_name == self.old_name or self.old_name.startswith( + qual_name + "." + ): + as_name_optional = import_alias.asname + as_name_node = ( + as_name_optional.name if as_name_optional is not None else None + ) + if as_name_node is not None and isinstance( + as_name_node, (cst.Name, cst.Attribute) + ): + full_as_name = get_full_name_for_node(as_name_node) + if full_as_name is not None: + self.as_name = (full_as_name, alias_name) diff --git a/libcst/codemod/commands/tests/test_rename.py b/libcst/codemod/commands/tests/test_rename.py new file mode 100644 index 000000000..b5427280f --- /dev/null +++ b/libcst/codemod/commands/tests/test_rename.py @@ -0,0 +1,593 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# +# pyre-strict + +from libcst.codemod import CodemodTest +from libcst.codemod.commands.rename import RenameCommand + + +class TestRenameCommand(CodemodTest): + + TRANSFORM = RenameCommand + + def test_rename_name(self) -> None: + + before = """ + from foo import bar + + def test() -> None: + bar(5) + """ + after = """ + from baz import qux + + def test() -> None: + qux(5) + """ + + self.assertCodemod(before, after, old_name="foo.bar", new_name="baz.qux") + + def test_rename_name_asname(self) -> None: + + before = """ + from foo import bar as bla + + def test() -> None: + bla(5) + """ + after = """ + from baz import qux + + def test() -> None: + qux(5) + """ + + self.assertCodemod( + before, after, old_name="foo.bar", new_name="baz.qux", + ) + + def test_rename_repeated_name_with_asname(self) -> None: + before = """ + from foo import foo as bla + + def test() -> None: + bla.bla(5) + """ + after = """ + from baz import qux + + def test() -> None: + qux.bla(5) + """ + self.assertCodemod( + before, after, old_name="foo.foo", new_name="baz.qux", + ) + + def test_rename_attr(self) -> None: + + before = """ + import a.b + + def test() -> None: + a.b.c(5) + """ + after = """ + import d.e + + def test() -> None: + d.e.f(5) + """ + + self.assertCodemod( + before, after, old_name="a.b.c", new_name="d.e.f", + ) + + def test_rename_attr_asname(self) -> None: + + before = """ + import foo as bar + + def test() -> None: + bar.qux(5) + """ + after = """ + import baz + + def test() -> None: + baz.quux(5) + """ + + self.assertCodemod( + before, after, old_name="foo.qux", new_name="baz.quux", + ) + + def test_rename_module_import(self) -> None: + before = """ + import a.b + + class Foo(a.b.C): + pass + """ + after = """ + import c.b + + class Foo(c.b.C): + pass + """ + + self.assertCodemod( + before, after, old_name="a.b", new_name="c.b", + ) + + def test_rename_module_import_2(self) -> None: + before = """ + import a.b + + class Foo(a.b.C): + pass + """ + after = """ + import c.b + + class Foo(c.b.C): + pass + """ + + self.assertCodemod( + before, after, old_name="a", new_name="c", + ) + + def test_rename_module_import_no_change(self) -> None: + # Full qualified names don't match, so don't codemod + before = """ + import a.b + + class Foo(a.b.C): + pass + """ + self.assertCodemod( + before, before, old_name="b", new_name="c.b", + ) + + def test_rename_module_import_from(self) -> None: + before = """ + from a import b + + class Foo(b.C): + pass + """ + after = """ + from c import b + + class Foo(b.C): + pass + """ + + self.assertCodemod( + before, after, old_name="a.b", new_name="c.b", + ) + + def test_rename_module_import_from_2(self) -> None: + before = """ + from a import b + + class Foo(b.C): + pass + """ + after = """ + from c import b + + class Foo(b.C): + pass + """ + + self.assertCodemod( + before, after, old_name="a", new_name="c", + ) + + def test_rename_class(self) -> None: + before = """ + from a.b import some_class + + class Foo(some_class): + pass + """ + after = """ + from c.b import some_class + + class Foo(some_class): + pass + """ + self.assertCodemod( + before, after, old_name="a.b.some_class", new_name="c.b.some_class", + ) + + def test_rename_importfrom_same_module(self) -> None: + before = """ + from a.b import Class_1, Class_2 + + class Foo(Class_1): + pass + """ + after = """ + from a.b import Class_3, Class_2 + + class Foo(Class_3): + pass + """ + self.assertCodemod( + before, after, old_name="a.b.Class_1", new_name="a.b.Class_3", + ) + + def test_rename_importfrom_same_module_2(self) -> None: + before = """ + from a.b import module_1, module_2 + + class Foo(module_1.Class_1): + pass + class Fooo(module_2.Class_2): + pass + """ + after = """ + from a.b import module_2 + from a.b.module_3 import Class_3 + + class Foo(Class_3): + pass + class Fooo(module_2.Class_2): + pass + """ + self.assertCodemod( + before, + after, + old_name="a.b.module_1.Class_1", + new_name="a.b.module_3.Class_3", + ) + + def test_rename_local_variable(self) -> None: + before = """ + x = 5 + y = 5 + x + """ + after = """ + z = 5 + y = 5 + z + """ + + self.assertCodemod( + before, after, old_name="x", new_name="z", + ) + + def test_module_does_not_change(self) -> None: + before = """ + from a import b + + class Foo(b): + pass + """ + after = """ + from a import c + + class Foo(c): + pass + """ + self.assertCodemod(before, after, old_name="a.b", new_name="a.c") + + def test_other_imports_untouched(self) -> None: + before = """ + import a, b, c + + class Foo(a.z): + bar: b.bar + baz: c.baz + """ + after = """ + import d, b, c + + class Foo(d.z): + bar: b.bar + baz: c.baz + """ + self.assertCodemod( + before, after, old_name="a.z", new_name="d.z", + ) + + def test_other_import_froms_untouched(self) -> None: + before = """ + from a import b, c, d + + class Foo(b): + bar: c.bar + baz: d.baz + """ + after = """ + from a import c, d + from f import b + + class Foo(b): + bar: c.bar + baz: d.baz + """ + self.assertCodemod( + before, after, old_name="a.b", new_name="f.b", + ) + + def test_no_removal_of_import_in_use(self) -> None: + before = """ + import a + + class Foo(a.b): + pass + class Foo2(a.c): + pass + """ + after = """ + import a, z + + class Foo(z.b): + pass + class Foo2(a.c): + pass + """ + self.assertCodemod( + before, after, old_name="a.b", new_name="z.b", + ) + + def test_no_removal_of_dotted_import_in_use(self) -> None: + before = """ + import a.b + + class Foo(a.b.c): + pass + class Foo2(a.b.d): + pass + """ + after = """ + import a.b, z.b + + class Foo(z.b.c): + pass + class Foo2(a.b.d): + pass + """ + self.assertCodemod( + before, after, old_name="a.b.c", new_name="z.b.c", + ) + + def test_no_removal_of_import_from_in_use(self) -> None: + before = """ + from a import b + + class Foo(b.some_class): + bar: b.some_other_class + """ + after = """ + from a import b + from blah import some_class + + class Foo(some_class): + bar: b.some_other_class + """ + self.assertCodemod( + before, after, old_name="a.b.some_class", new_name="blah.some_class", + ) + + def test_other_unused_imports_untouched(self) -> None: + before = """ + import a + import b + + class Foo(a.obj): + pass + """ + after = """ + import c + import b + + class Foo(c.obj): + pass + """ + self.assertCodemod( + before, after, old_name="a.obj", new_name="c.obj", + ) + + def test_complex_module_rename(self) -> None: + before = """ + from a.b.c import d + + class Foo(d.e.f): + pass + """ + after = """ + from g.h.i import j + + class Foo(j): + pass + """ + self.assertCodemod(before, after, old_name="a.b.c.d.e.f", new_name="g.h.i.j") + + def test_complex_module_rename_with_asname(self) -> None: + before = """ + from a.b.c import d as ddd + + class Foo(ddd.e.f): + pass + """ + after = """ + from g.h.i import j + + class Foo(j): + pass + """ + self.assertCodemod(before, after, old_name="a.b.c.d.e.f", new_name="g.h.i.j") + + def test_names_with_repeated_substrings(self) -> None: + before = """ + from aa import aaaa + + class Foo(aaaa.Bar): + pass + """ + after = """ + from b import c + + class Foo(c.Bar): + pass + """ + self.assertCodemod( + before, after, old_name="aa.aaaa", new_name="b.c", + ) + + def test_repeated_name(self) -> None: + before = """ + from foo import foo + + def bar(): + foo(5) + """ + after = """ + from qux import qux + + def bar(): + qux(5) + """ + self.assertCodemod( + before, after, old_name="foo.foo", new_name="qux.qux", + ) + + def test_no_codemod(self) -> None: + before = """ + from foo import bar + + def baz(): + bar(5) + """ + self.assertCodemod( + before, before, old_name="bar", new_name="qux", + ) + + def test_rename_import_prefix(self) -> None: + before = """ + import a.b.c.d + """ + after = """ + import x.y.c.d + """ + self.assertCodemod( + before, after, old_name="a.b", new_name="x.y", + ) + + def test_rename_import_from_prefix(self) -> None: + before = """ + from a.b.c.d import foo + """ + after = """ + from x.y.c.d import foo + """ + self.assertCodemod( + before, after, old_name="a.b", new_name="x.y", + ) + + def test_rename_multiple_occurrences(self) -> None: + before = """ + from a import b + + class Foo(b.some_class): + pass + class Foobar(b.some_class): + pass + """ + after = """ + from c.d import some_class + + class Foo(some_class): + pass + class Foobar(some_class): + pass + """ + self.assertCodemod( + before, after, old_name="a.b.some_class", new_name="c.d.some_class" + ) + + def test_rename_multiple_imports(self) -> None: + before = """ + import a + from a import b + from a.c import d + + class Foo(d): + pass + class Fooo(b.some_class): + pass + class Foooo(a.some_class): + pass + """ + after = """ + import z + from z import b + from z.c import d + + class Foo(d): + pass + class Fooo(b.some_class): + pass + class Foooo(z.some_class): + pass + """ + self.assertCodemod(before, after, old_name="a", new_name="z") + + def test_input_with_colon_sep(self) -> None: + before = """ + from a.b.c import d + + class Foo(d.e.f): + pass + """ + after = """ + from g.h import i + + class Foo(i.j): + pass + """ + self.assertCodemod(before, after, old_name="a.b.c.d.e.f", new_name="g.h:i.j") + + def test_input_with_colon_sep_at_the_end(self) -> None: + before = """ + from a.b.c import d + + class Foo(d.e): + pass + """ + after = """ + import g.h.i.j + + class Foo(g.h.i.j.e): + pass + """ + self.assertCodemod(before, after, old_name="a.b.c.d", new_name="g.h.i.j:") + + def test_input_with_colon_sep_at_the_front(self) -> None: + # This case should treat it as if no colon separator. + before = """ + from a.b.c import d + + class Foo(d.e): + pass + """ + after = """ + from g.h.i import j + + class Foo(j.e): + pass + """ + self.assertCodemod(before, after, old_name="a.b.c.d", new_name=":g.h.i.j") + + def test_no_change_because_no_match_was_found(self) -> None: + before = """ + from foo import bar + bar(42) + """ + self.assertCodemod(before, before, old_name="baz.bar", new_name="qux.bar")