From 1abbaa51fe0ffd38f45b0a4754d68bdc979e452c Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Fri, 22 Sep 2023 09:22:24 -0300 Subject: [PATCH] Refactoring of GatherImportsVisitor --- libcst/codemod/visitors/_add_imports.py | 89 +++------------------- libcst/codemod/visitors/_gather_imports.py | 4 + 2 files changed, 16 insertions(+), 77 deletions(-) diff --git a/libcst/codemod/visitors/_add_imports.py b/libcst/codemod/visitors/_add_imports.py index 78349421a..0bcc3cbcb 100644 --- a/libcst/codemod/visitors/_add_imports.py +++ b/libcst/codemod/visitors/_add_imports.py @@ -10,32 +10,20 @@ from libcst import matchers as m, parse_statement from libcst._nodes.statement import Import, ImportFrom, SimpleStatementLine from libcst.codemod._context import CodemodContext -from libcst.codemod._visitor import ContextAwareTransformer, ContextAwareVisitor +from libcst.codemod._visitor import ContextAwareTransformer +from libcst.codemod.visitors._gather_imports import GatherImportsVisitor from libcst.codemod.visitors._imports import ImportItem from libcst.helpers import get_absolute_module_from_package_for_import from libcst.helpers.common import ensure_type -class _GatherTopImportsBeforeStatements(ContextAwareVisitor): +class _GatherTopImportsBeforeStatements(GatherImportsVisitor): """ Works similarly to GatherImportsVisitor, but only considers imports declared before any other statements of the module with the exception of docstrings and __strict__ flag. """ - def __init__(self, context: CodemodContext) -> None: - super().__init__(context) - # Track the available imports in this transform - self.module_imports: Set[str] = set() - self.object_mapping: Dict[str, Set[str]] = {} - # Track the aliased imports in this transform - self.module_aliases: Dict[str, str] = {} - self.alias_mapping: Dict[str, List[Tuple[str, str]]] = {} - # Track all of the imports found in this transform - self.all_imports: Set[Union[libcst.Import, libcst.ImportFrom]] = set() - # Track the import for every symbol introduced into the module - self.symbol_mapping: Dict[str, ImportItem] = {} - def leave_Module(self, original_node: libcst.Module) -> None: start = 1 if _skip_first(original_node) else 0 for stmt in original_node.body[start:]: @@ -45,75 +33,22 @@ def leave_Module(self, original_node: libcst.Module) -> None: ): stmt = ensure_type(stmt, SimpleStatementLine) imp = ensure_type(stmt.body[0], Union[ImportFrom, Import]) - self.all_imports.add(imp) + self.all_imports.append(imp) else: break for imp in self.all_imports: if m.matches(imp, m.Import()): imp = ensure_type(imp, Import) - self.handle_Import(imp) + self._handle_Import(imp) else: imp = ensure_type(imp, ImportFrom) - self.handle_ImportFrom(imp) - - def handle_Import(self, node: libcst.Import) -> None: - for name in node.names: - alias = name.evaluated_alias - imp = ImportItem(name.evaluated_name, alias=alias) - if alias is not None: - # Track this as an aliased module - self.module_aliases[name.evaluated_name] = alias - self.symbol_mapping[alias] = imp - else: - # Get the module we're importing as a string. - self.module_imports.add(name.evaluated_name) - self.symbol_mapping[name.evaluated_name] = imp + self._handle_ImportFrom(imp) - def handle_ImportFrom(self, node: libcst.ImportFrom) -> None: - # Get the module we're importing as a string. - module = get_absolute_module_from_package_for_import( - self.context.full_package_name, node - ) - if module is None: - # Can't get the absolute import from relative, so we can't - # support this. - return - nodenames = node.names - if isinstance(nodenames, libcst.ImportStar): - # We cover everything, no need to bother tracking other things - self.object_mapping[module] = set("*") - return - elif isinstance(nodenames, Sequence): - # Get the list of imports we're aliasing in this import - new_aliases = [ - (ia.evaluated_name, ia.evaluated_alias) - for ia in nodenames - if ia.asname is not None - ] - if new_aliases: - if module not in self.alias_mapping: - self.alias_mapping[module] = [] - # pyre-ignore We know that aliases are not None here. - self.alias_mapping[module].extend(new_aliases) - - # Get the list of imports we're importing in this import - new_objects = {ia.evaluated_name for ia in nodenames if ia.asname is None} - if new_objects: - if module not in self.object_mapping: - self.object_mapping[module] = set() - - # Make sure that we don't add to a '*' module - if "*" in self.object_mapping[module]: - self.object_mapping[module] = set("*") - return - - self.object_mapping[module].update(new_objects) - for ia in nodenames: - imp = ImportItem( - module, obj_name=ia.evaluated_name, alias=ia.evaluated_alias - ) - key = ia.evaluated_alias or ia.evaluated_name - self.symbol_mapping[key] = imp + def visit_ImportFrom(self, node: libcst.ImportFrom) -> None: + pass + + def visit_Import(self, node: libcst.Import) -> None: + pass class AddImportsVisitor(ContextAwareTransformer): @@ -277,7 +212,7 @@ def visit_Module(self, node: libcst.Module) -> None: # Do a preliminary pass to gather the imports we already have at the top gatherer = _GatherTopImportsBeforeStatements(self.context) node.visit(gatherer) - self.all_imports = gatherer.all_imports + self.all_imports = set(gatherer.all_imports) self.module_imports = self.module_imports - gatherer.module_imports for module, alias in gatherer.module_aliases.items(): diff --git a/libcst/codemod/visitors/_gather_imports.py b/libcst/codemod/visitors/_gather_imports.py index 4847afc1f..a4a43d37a 100644 --- a/libcst/codemod/visitors/_gather_imports.py +++ b/libcst/codemod/visitors/_gather_imports.py @@ -67,7 +67,9 @@ def __init__(self, context: CodemodContext) -> None: def visit_Import(self, node: libcst.Import) -> None: # Track this import statement for later analysis. self.all_imports.append(node) + self._handle_Import(node) + def _handle_Import(self, node: libcst.Import) -> None: for name in node.names: alias = name.evaluated_alias imp = ImportItem(name.evaluated_name, alias=alias) @@ -83,7 +85,9 @@ def visit_Import(self, node: libcst.Import) -> None: def visit_ImportFrom(self, node: libcst.ImportFrom) -> None: # Track this import statement for later analysis. self.all_imports.append(node) + self._handle_ImportFrom(node) + def _handle_ImportFrom(self, node: libcst.ImportFrom) -> None: # Get the module we're importing as a string. module = get_absolute_module_from_package_for_import( self.context.full_package_name, node