diff --git a/libcst/codemod/visitors/_add_imports.py b/libcst/codemod/visitors/_add_imports.py index 0bcc3cbcb..f734af5cd 100644 --- a/libcst/codemod/visitors/_add_imports.py +++ b/libcst/codemod/visitors/_add_imports.py @@ -11,19 +11,24 @@ from libcst._nodes.statement import Import, ImportFrom, SimpleStatementLine from libcst.codemod._context import CodemodContext from libcst.codemod._visitor import ContextAwareTransformer -from libcst.codemod.visitors._gather_imports import GatherImportsVisitor +from libcst.codemod.visitors._gather_imports import _GatherImportsMixin from libcst.codemod.visitors._imports import ImportItem from libcst.helpers import get_absolute_module_from_package_for_import from libcst.helpers.common import ensure_type -class _GatherTopImportsBeforeStatements(GatherImportsVisitor): +class _GatherTopImportsBeforeStatements(_GatherImportsMixin): """ Works similarly to GatherImportsVisitor, but only considers imports declared before any other statements of the module with the exception of docstrings and __strict__ flag. """ + def __init__(self, context: CodemodContext) -> None: + super().__init__(context) + # Track all of the imports found in this transform + self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = [] + def leave_Module(self, original_node: libcst.Module) -> None: start = 1 if _skip_first(original_node) else 0 for stmt in original_node.body[start:]: @@ -32,8 +37,13 @@ def leave_Module(self, original_node: libcst.Module) -> None: m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()]), ): stmt = ensure_type(stmt, SimpleStatementLine) - imp = ensure_type(stmt.body[0], Union[ImportFrom, Import]) - self.all_imports.append(imp) + # Workaround for python 3.8 and 3.9, won't accept Union for isinstance + if m.matches(stmt.body[0], m.ImportFrom()): + imp = ensure_type(stmt.body[0], ImportFrom) + self.all_imports.append(imp) + if m.matches(stmt.body[0], m.Import()): + imp = ensure_type(stmt.body[0], Import) + self.all_imports.append(imp) else: break for imp in self.all_imports: @@ -44,12 +54,6 @@ def leave_Module(self, original_node: libcst.Module) -> None: imp = ensure_type(imp, ImportFrom) self._handle_ImportFrom(imp) - def visit_ImportFrom(self, node: libcst.ImportFrom) -> None: - pass - - def visit_Import(self, node: libcst.Import) -> None: - pass - class AddImportsVisitor(ContextAwareTransformer): """ @@ -206,13 +210,13 @@ def __init__( } # Track the list of imports found at the top of the file - self.all_imports: Set[Union[libcst.Import, libcst.ImportFrom]] = set() + self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = [] def visit_Module(self, node: libcst.Module) -> None: # Do a preliminary pass to gather the imports we already have at the top gatherer = _GatherTopImportsBeforeStatements(self.context) node.visit(gatherer) - self.all_imports = set(gatherer.all_imports) + self.all_imports = gatherer.all_imports self.module_imports = self.module_imports - gatherer.module_imports for module, alias in gatherer.module_aliases.items(): @@ -249,7 +253,7 @@ def leave_ImportFrom( # There's nothing to do here! return updated_node - # Ensure this is on of the imports at the top + # Ensure this is one of the imports at the top if original_node not in self.all_imports: return updated_node diff --git a/libcst/codemod/visitors/_gather_imports.py b/libcst/codemod/visitors/_gather_imports.py index a4a43d37a..6b187c53e 100644 --- a/libcst/codemod/visitors/_gather_imports.py +++ b/libcst/codemod/visitors/_gather_imports.py @@ -12,43 +12,9 @@ from libcst.helpers import get_absolute_module_from_package_for_import -class GatherImportsVisitor(ContextAwareVisitor): +class _GatherImportsMixin(ContextAwareVisitor): """ - Gathers all imports in a module and stores them as attributes on the instance. - Intended to be instantiated and passed to a :class:`~libcst.Module` - :meth:`~libcst.CSTNode.visit` method in order to gather up information about - imports on a module. Note that this is not a substitute for scope analysis or - qualified name support. Please see :ref:`libcst-scope-tutorial` for a more - robust way of determining the qualified name and definition for an arbitrary - node. - - After visiting a module the following attributes will be populated: - - module_imports - A sequence of strings representing modules that were imported directly, such as - in the case of ``import typing``. Each module directly imported but not aliased - will be included here. - object_mapping - A mapping of strings to sequences of strings representing modules where we - imported objects from, such as in the case of ``from typing import Optional``. - Each from import that was not aliased will be included here, where the keys of - the mapping are the module we are importing from, and the value is a - sequence of objects we are importing from the module. - module_aliases - A mapping of strings representing modules that were imported and aliased, - such as in the case of ``import typing as t``. Each module imported this - way will be represented as a key in this mapping, and the value will be - the local alias of the module. - alias_mapping - A mapping of strings to sequences of tuples representing modules where we - imported objects from and aliased using ``as`` syntax, such as in the case - of ``from typing import Optional as opt``. Each from import that was aliased - will be included here, where the keys of the mapping are the module we are - importing from, and the value is a tuple representing the original object - name and the alias. - all_imports - A collection of all :class:`~libcst.Import` and :class:`~libcst.ImportFrom` - statements that were encountered in the module. + A Mixin class for tracking visited imports. """ def __init__(self, context: CodemodContext) -> None: @@ -59,16 +25,9 @@ def __init__(self, context: CodemodContext) -> None: # Track the aliased imports in this transform self.module_aliases: Dict[str, str] = {} self.alias_mapping: Dict[str, List[Tuple[str, str]]] = {} - # Track all of the imports found in this transform - self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = [] # Track the import for every symbol introduced into the module self.symbol_mapping: Dict[str, ImportItem] = {} - def visit_Import(self, node: libcst.Import) -> None: - # Track this import statement for later analysis. - self.all_imports.append(node) - self._handle_Import(node) - def _handle_Import(self, node: libcst.Import) -> None: for name in node.names: alias = name.evaluated_alias @@ -82,11 +41,6 @@ def _handle_Import(self, node: libcst.Import) -> None: self.module_imports.add(name.evaluated_name) self.symbol_mapping[name.evaluated_name] = imp - def visit_ImportFrom(self, node: libcst.ImportFrom) -> None: - # Track this import statement for later analysis. - self.all_imports.append(node) - self._handle_ImportFrom(node) - def _handle_ImportFrom(self, node: libcst.ImportFrom) -> None: # Get the module we're importing as a string. module = get_absolute_module_from_package_for_import( @@ -132,3 +86,58 @@ def _handle_ImportFrom(self, node: libcst.ImportFrom) -> None: ) key = ia.evaluated_alias or ia.evaluated_name self.symbol_mapping[key] = imp + + +class GatherImportsVisitor(_GatherImportsMixin): + """ + Gathers all imports in a module and stores them as attributes on the instance. + Intended to be instantiated and passed to a :class:`~libcst.Module` + :meth:`~libcst.CSTNode.visit` method in order to gather up information about + imports on a module. Note that this is not a substitute for scope analysis or + qualified name support. Please see :ref:`libcst-scope-tutorial` for a more + robust way of determining the qualified name and definition for an arbitrary + node. + + After visiting a module the following attributes will be populated: + + module_imports + A sequence of strings representing modules that were imported directly, such as + in the case of ``import typing``. Each module directly imported but not aliased + will be included here. + object_mapping + A mapping of strings to sequences of strings representing modules where we + imported objects from, such as in the case of ``from typing import Optional``. + Each from import that was not aliased will be included here, where the keys of + the mapping are the module we are importing from, and the value is a + sequence of objects we are importing from the module. + module_aliases + A mapping of strings representing modules that were imported and aliased, + such as in the case of ``import typing as t``. Each module imported this + way will be represented as a key in this mapping, and the value will be + the local alias of the module. + alias_mapping + A mapping of strings to sequences of tuples representing modules where we + imported objects from and aliased using ``as`` syntax, such as in the case + of ``from typing import Optional as opt``. Each from import that was aliased + will be included here, where the keys of the mapping are the module we are + importing from, and the value is a tuple representing the original object + name and the alias. + all_imports + A collection of all :class:`~libcst.Import` and :class:`~libcst.ImportFrom` + statements that were encountered in the module. + """ + + def __init__(self, context: CodemodContext) -> None: + super().__init__(context) + # Track all of the imports found in this transform + self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = [] + + def visit_Import(self, node: libcst.Import) -> None: + # Track this import statement for later analysis. + self.all_imports.append(node) + self._handle_Import(node) + + def visit_ImportFrom(self, node: libcst.ImportFrom) -> None: + # Track this import statement for later analysis. + self.all_imports.append(node) + self._handle_ImportFrom(node)