From bee920d43483b654d23a3239576f637d864704d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20C=2E=20Silva?= <12188364+andrecsilva@users.noreply.github.com> Date: Sun, 1 Oct 2023 10:38:33 -0300 Subject: [PATCH] AddImportsVisitor: add imports before the first non-import statement (#1024) * AddImportsVisitor will now only add at the top of module - Also added new tests to cover these cases * Fixed an issue with from imports * Added a couple tests for AddImportsVisitor * Refactoring of GatherImportsVisitor * Refactors, typos and typing changes --- libcst/codemod/visitors/_add_imports.py | 120 +++++++++++++----- libcst/codemod/visitors/_gather_imports.py | 105 ++++++++------- .../visitors/tests/test_add_imports.py | 103 +++++++++++++++ 3 files changed, 250 insertions(+), 78 deletions(-) diff --git a/libcst/codemod/visitors/_add_imports.py b/libcst/codemod/visitors/_add_imports.py index 8081adf9f..f734af5cd 100644 --- a/libcst/codemod/visitors/_add_imports.py +++ b/libcst/codemod/visitors/_add_imports.py @@ -8,11 +8,51 @@ import libcst from libcst import matchers as m, parse_statement +from libcst._nodes.statement import Import, ImportFrom, SimpleStatementLine from libcst.codemod._context import CodemodContext from libcst.codemod._visitor import ContextAwareTransformer -from libcst.codemod.visitors._gather_imports import GatherImportsVisitor +from libcst.codemod.visitors._gather_imports import _GatherImportsMixin from libcst.codemod.visitors._imports import ImportItem from libcst.helpers import get_absolute_module_from_package_for_import +from libcst.helpers.common import ensure_type + + +class _GatherTopImportsBeforeStatements(_GatherImportsMixin): + """ + Works similarly to GatherImportsVisitor, but only considers imports + declared before any other statements of the module with the exception + of docstrings and __strict__ flag. + """ + + def __init__(self, context: CodemodContext) -> None: + super().__init__(context) + # Track all of the imports found in this transform + self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = [] + + def leave_Module(self, original_node: libcst.Module) -> None: + start = 1 if _skip_first(original_node) else 0 + for stmt in original_node.body[start:]: + if m.matches( + stmt, + m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()]), + ): + stmt = ensure_type(stmt, SimpleStatementLine) + # Workaround for python 3.8 and 3.9, won't accept Union for isinstance + if m.matches(stmt.body[0], m.ImportFrom()): + imp = ensure_type(stmt.body[0], ImportFrom) + self.all_imports.append(imp) + if m.matches(stmt.body[0], m.Import()): + imp = ensure_type(stmt.body[0], Import) + self.all_imports.append(imp) + else: + break + for imp in self.all_imports: + if m.matches(imp, m.Import()): + imp = ensure_type(imp, Import) + self._handle_Import(imp) + else: + imp = ensure_type(imp, ImportFrom) + self._handle_ImportFrom(imp) class AddImportsVisitor(ContextAwareTransformer): @@ -169,12 +209,12 @@ def __init__( for module in sorted(from_imports_aliases) } - # Track the list of imports found in the file + # Track the list of imports found at the top of the file self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = [] def visit_Module(self, node: libcst.Module) -> None: - # Do a preliminary pass to gather the imports we already have - gatherer = GatherImportsVisitor(self.context) + # Do a preliminary pass to gather the imports we already have at the top + gatherer = _GatherTopImportsBeforeStatements(self.context) node.visit(gatherer) self.all_imports = gatherer.all_imports @@ -213,6 +253,10 @@ def leave_ImportFrom( # There's nothing to do here! return updated_node + # Ensure this is one of the imports at the top + if original_node not in self.all_imports: + return updated_node + # Get the module we're importing as a string, see if we have work to do. module = get_absolute_module_from_package_for_import( self.context.full_package_name, updated_node @@ -260,39 +304,26 @@ def _split_module( statement_before_import_location = 0 import_add_location = 0 - # never insert an import before initial __strict__ flag - if m.matches( - orig_module, - m.Module( - body=[ - m.SimpleStatementLine( - body=[ - m.Assign( - targets=[m.AssignTarget(target=m.Name("__strict__"))] - ) - ] - ), - m.ZeroOrMore(), - ] - ), - ): - statement_before_import_location = import_add_location = 1 - # This works under the principle that while we might modify node contents, # we have yet to modify the number of statements. So we can match on the # original tree but break up the statements of the modified tree. If we # change this assumption in this visitor, we will have to change this code. - for i, statement in enumerate(orig_module.body): - if i == 0 and m.matches( - statement, m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]) + + # Finds the location to add imports. It is the end of the first import block that occurs before any other statement (save for docstrings) + + # Never insert an import before initial __strict__ flag or docstring + if _skip_first(orig_module): + statement_before_import_location = import_add_location = 1 + + for i, statement in enumerate( + orig_module.body[statement_before_import_location:] + ): + if m.matches( + statement, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()]) ): - statement_before_import_location = import_add_location = 1 - elif isinstance(statement, libcst.SimpleStatementLine): - for possible_import in statement.body: - for last_import in self.all_imports: - if possible_import is last_import: - import_add_location = i + 1 - break + import_add_location = i + statement_before_import_location + 1 + else: + break return ( list(updated_module.body[:statement_before_import_location]), @@ -414,3 +445,28 @@ def leave_Module( *statements_after_imports, ) ) + + +def _skip_first(orig_module: libcst.Module) -> bool: + # Is there a __strict__ flag or docstring at the top? + if m.matches( + orig_module, + m.Module( + body=[ + m.SimpleStatementLine( + body=[ + m.Assign(targets=[m.AssignTarget(target=m.Name("__strict__"))]) + ] + ), + m.ZeroOrMore(), + ] + ) + | m.Module( + body=[ + m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]), + m.ZeroOrMore(), + ] + ), + ): + return True + return False diff --git a/libcst/codemod/visitors/_gather_imports.py b/libcst/codemod/visitors/_gather_imports.py index 4847afc1f..6b187c53e 100644 --- a/libcst/codemod/visitors/_gather_imports.py +++ b/libcst/codemod/visitors/_gather_imports.py @@ -12,43 +12,9 @@ from libcst.helpers import get_absolute_module_from_package_for_import -class GatherImportsVisitor(ContextAwareVisitor): +class _GatherImportsMixin(ContextAwareVisitor): """ - Gathers all imports in a module and stores them as attributes on the instance. - Intended to be instantiated and passed to a :class:`~libcst.Module` - :meth:`~libcst.CSTNode.visit` method in order to gather up information about - imports on a module. Note that this is not a substitute for scope analysis or - qualified name support. Please see :ref:`libcst-scope-tutorial` for a more - robust way of determining the qualified name and definition for an arbitrary - node. - - After visiting a module the following attributes will be populated: - - module_imports - A sequence of strings representing modules that were imported directly, such as - in the case of ``import typing``. Each module directly imported but not aliased - will be included here. - object_mapping - A mapping of strings to sequences of strings representing modules where we - imported objects from, such as in the case of ``from typing import Optional``. - Each from import that was not aliased will be included here, where the keys of - the mapping are the module we are importing from, and the value is a - sequence of objects we are importing from the module. - module_aliases - A mapping of strings representing modules that were imported and aliased, - such as in the case of ``import typing as t``. Each module imported this - way will be represented as a key in this mapping, and the value will be - the local alias of the module. - alias_mapping - A mapping of strings to sequences of tuples representing modules where we - imported objects from and aliased using ``as`` syntax, such as in the case - of ``from typing import Optional as opt``. Each from import that was aliased - will be included here, where the keys of the mapping are the module we are - importing from, and the value is a tuple representing the original object - name and the alias. - all_imports - A collection of all :class:`~libcst.Import` and :class:`~libcst.ImportFrom` - statements that were encountered in the module. + A Mixin class for tracking visited imports. """ def __init__(self, context: CodemodContext) -> None: @@ -59,15 +25,10 @@ def __init__(self, context: CodemodContext) -> None: # Track the aliased imports in this transform self.module_aliases: Dict[str, str] = {} self.alias_mapping: Dict[str, List[Tuple[str, str]]] = {} - # Track all of the imports found in this transform - self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = [] # Track the import for every symbol introduced into the module self.symbol_mapping: Dict[str, ImportItem] = {} - def visit_Import(self, node: libcst.Import) -> None: - # Track this import statement for later analysis. - self.all_imports.append(node) - + def _handle_Import(self, node: libcst.Import) -> None: for name in node.names: alias = name.evaluated_alias imp = ImportItem(name.evaluated_name, alias=alias) @@ -80,10 +41,7 @@ def visit_Import(self, node: libcst.Import) -> None: self.module_imports.add(name.evaluated_name) self.symbol_mapping[name.evaluated_name] = imp - def visit_ImportFrom(self, node: libcst.ImportFrom) -> None: - # Track this import statement for later analysis. - self.all_imports.append(node) - + def _handle_ImportFrom(self, node: libcst.ImportFrom) -> None: # Get the module we're importing as a string. module = get_absolute_module_from_package_for_import( self.context.full_package_name, node @@ -128,3 +86,58 @@ def visit_ImportFrom(self, node: libcst.ImportFrom) -> None: ) key = ia.evaluated_alias or ia.evaluated_name self.symbol_mapping[key] = imp + + +class GatherImportsVisitor(_GatherImportsMixin): + """ + Gathers all imports in a module and stores them as attributes on the instance. + Intended to be instantiated and passed to a :class:`~libcst.Module` + :meth:`~libcst.CSTNode.visit` method in order to gather up information about + imports on a module. Note that this is not a substitute for scope analysis or + qualified name support. Please see :ref:`libcst-scope-tutorial` for a more + robust way of determining the qualified name and definition for an arbitrary + node. + + After visiting a module the following attributes will be populated: + + module_imports + A sequence of strings representing modules that were imported directly, such as + in the case of ``import typing``. Each module directly imported but not aliased + will be included here. + object_mapping + A mapping of strings to sequences of strings representing modules where we + imported objects from, such as in the case of ``from typing import Optional``. + Each from import that was not aliased will be included here, where the keys of + the mapping are the module we are importing from, and the value is a + sequence of objects we are importing from the module. + module_aliases + A mapping of strings representing modules that were imported and aliased, + such as in the case of ``import typing as t``. Each module imported this + way will be represented as a key in this mapping, and the value will be + the local alias of the module. + alias_mapping + A mapping of strings to sequences of tuples representing modules where we + imported objects from and aliased using ``as`` syntax, such as in the case + of ``from typing import Optional as opt``. Each from import that was aliased + will be included here, where the keys of the mapping are the module we are + importing from, and the value is a tuple representing the original object + name and the alias. + all_imports + A collection of all :class:`~libcst.Import` and :class:`~libcst.ImportFrom` + statements that were encountered in the module. + """ + + def __init__(self, context: CodemodContext) -> None: + super().__init__(context) + # Track all of the imports found in this transform + self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = [] + + def visit_Import(self, node: libcst.Import) -> None: + # Track this import statement for later analysis. + self.all_imports.append(node) + self._handle_Import(node) + + def visit_ImportFrom(self, node: libcst.ImportFrom) -> None: + # Track this import statement for later analysis. + self.all_imports.append(node) + self._handle_ImportFrom(node) diff --git a/libcst/codemod/visitors/tests/test_add_imports.py b/libcst/codemod/visitors/tests/test_add_imports.py index 0682fa513..613da9071 100644 --- a/libcst/codemod/visitors/tests/test_add_imports.py +++ b/libcst/codemod/visitors/tests/test_add_imports.py @@ -923,3 +923,106 @@ def func(): full_module_name="a.b.foobar", full_package_name="a.b" ), ) + + def test_add_at_first_block(self) -> None: + """ + Should add the import only at the end of the first import block. + """ + + before = """ + import a + import b + + e() + + import c + import d + """ + + after = """ + import a + import b + import e + + e() + + import c + import d + """ + + self.assertCodemod(before, after, [ImportItem("e", None, None)]) + + def test_add_no_import_block_before_statement(self) -> None: + """ + Should add the import before the call. + """ + + before = """ + '''docstring''' + e() + import a + import b + """ + + after = """ + '''docstring''' + import c + + e() + import a + import b + """ + + self.assertCodemod(before, after, [ImportItem("c", None, None)]) + + def test_do_not_add_existing(self) -> None: + """ + Should not add the new object import at existing import since it's not at the top + """ + + before = """ + '''docstring''' + e() + import a + import b + from c import f + """ + + after = """ + '''docstring''' + from c import e + + e() + import a + import b + from c import f + """ + + self.assertCodemod(before, after, [ImportItem("c", "e", None)]) + + def test_add_existing_at_top(self) -> None: + """ + Should add new import at exisitng from import at top + """ + + before = """ + '''docstring''' + from c import d + e() + import a + import b + from c import f + """ + + after = """ + '''docstring''' + from c import e, x, d + e() + import a + import b + from c import f + """ + + self.assertCodemod( + before, after, [ImportItem("c", "x", None), ImportItem("c", "e", None)] + )