diff --git a/libcst/codemod/visitors/_remove_imports.py b/libcst/codemod/visitors/_remove_imports.py index 67e42fd72..629fc021a 100644 --- a/libcst/codemod/visitors/_remove_imports.py +++ b/libcst/codemod/visitors/_remove_imports.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # -from typing import Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union import libcst as cst from libcst.codemod._context import CodemodContext @@ -337,24 +337,15 @@ def leave_Import( ] return updated_node.with_changes(names=names_to_keep) - def leave_ImportFrom( - self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom - ) -> Union[cst.ImportFrom, cst.RemovalSentinel]: - names = original_node.names - if isinstance(names, cst.ImportStar): - # This is a star import, so we won't remove it. - return updated_node - - # Make sure we actually know the absolute module. - module_name = get_absolute_module_for_import( - self.context.full_module_name, updated_node - ) - if module_name is None or module_name not in self.unused_obj_imports: - # This node isn't on our list of todos, so let's bail. - return updated_node - objects_to_remove = self.unused_obj_imports[module_name] - + def _process_importfrom_aliases( + self, + updated_node: cst.ImportFrom, + names: Iterable[cst.ImportAlias], + module_name: str, + ) -> Dict[str, Any]: + updates = {} names_to_keep = [] + objects_to_remove = self.unused_obj_imports[module_name] for import_alias in names: # Figure out if it is in our list of things to kill for name, alias in objects_to_remove: @@ -374,6 +365,56 @@ def leave_ImportFrom( names_to_keep.append(import_alias) continue + # We are about to remove `import_alias`. Check if there are any + # trailing comments and reparent them to the previous import. + # We only do this in case there's a trailing comma, otherwise the + # entire import statement is going to be removed anyway. + comma = import_alias.comma + if isinstance(comma, cst.Comma): + if len(names_to_keep) != 0: + # there is a previous import alias + prev = names_to_keep[-1] + names_to_keep[-1] = prev.with_deep_changes( + whitespace_after=_merge_whitespace_after( + prev.comma.whitespace_after, + comma.whitespace_after, + ) + ) + else: + # No previous import alias, need to attach comment to `ImportFrom`. + # We can only do this if there was a leftparen on the import + # statement. Otherwise there can't be any standalone comments + # anyway, so it's fine to skip this logic. + lpar = updated_node.lpar + if isinstance(lpar, cst.LeftParen): + updates["lpar"] = lpar.with_changes( + whitespace_after=_merge_whitespace_after( + lpar.whitespace_after, + comma.whitespace_after, + ) + ) + updates["names"] = names_to_keep + return updates + + def leave_ImportFrom( + self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom + ) -> Union[cst.ImportFrom, cst.RemovalSentinel]: + names = original_node.names + if isinstance(names, cst.ImportStar): + # This is a star import, so we won't remove it. + return updated_node + + # Make sure we actually know the absolute module. + module_name = get_absolute_module_for_import( + self.context.full_module_name, updated_node + ) + if module_name is None or module_name not in self.unused_obj_imports: + # This node isn't on our list of todos, so let's bail. + return updated_node + + updates = self._process_importfrom_aliases(updated_node, names, module_name) + names_to_keep = updates["names"] + # no changes if names_to_keep == names: return updated_node @@ -389,4 +430,20 @@ def leave_ImportFrom( *names_to_keep[:-1], names_to_keep[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT), ] - return updated_node.with_changes(names=names_to_keep) + updates["names"] = names_to_keep + return updated_node.with_changes(**updates) + + +def _merge_whitespace_after( + left: cst.BaseParenthesizableWhitespace, right: cst.BaseParenthesizableWhitespace +) -> cst.BaseParenthesizableWhitespace: + if not isinstance(right, cst.ParenthesizedWhitespace): + return left + if not isinstance(left, cst.ParenthesizedWhitespace): + return right + + return left.with_changes( + empty_lines=tuple( + line for line in right.empty_lines if line.comment is not None + ), + ) diff --git a/libcst/codemod/visitors/tests/test_remove_imports.py b/libcst/codemod/visitors/tests/test_remove_imports.py index d8d0e1863..f98aacbb6 100644 --- a/libcst/codemod/visitors/tests/test_remove_imports.py +++ b/libcst/codemod/visitors/tests/test_remove_imports.py @@ -57,6 +57,95 @@ def foo() -> None: self.assertCodemod(before, after, [("baz", None, None)]) + def test_remove_fromimport_simple(self) -> None: + before = "from a import b, c" + after = "from a import c" + self.assertCodemod(before, after, [("a", "b", None)]) + + def test_remove_fromimport_keeping_standalone_comment(self) -> None: + before = """ + from foo import ( + bar, + # comment + baz, + ) + from loooong import ( + bar, + # comment + short, + this_stays + ) + from third import ( + # comment + short, + this_stays_too + ) + """ + after = """ + from foo import ( + # comment + baz, + ) + from loooong import ( + this_stays + ) + from third import ( + this_stays_too + ) + """ + self.assertCodemod( + before, + after, + [ + ("foo", "bar", None), + ("loooong", "short", None), + ("loooong", "bar", None), + ("third", "short", None), + ], + ) + + def test_remove_fromimport_keeping_inline_comment(self) -> None: + before = """ + from foo import ( # comment + bar, + # comment2 + baz, + ) + from loooong import ( + bar, + short, # comment + # comment2 + this_stays + ) + from third import ( + short, # comment + this_stays_too # comment2 + ) + """ + after = """ + from foo import ( # comment + # comment2 + baz, + ) + from loooong import ( + # comment2 + this_stays + ) + from third import ( + this_stays_too # comment2 + ) + """ + self.assertCodemod( + before, + after, + [ + ("foo", "bar", None), + ("loooong", "short", None), + ("loooong", "bar", None), + ("third", "short", None), + ], + ) + def test_remove_import_alias_simple(self) -> None: """ Should remove aliased module as import