Skip to content

Commit

Permalink
Refactoring of GatherImportsVisitor
Browse files Browse the repository at this point in the history
  • Loading branch information
andrecsilva committed Sep 25, 2023
1 parent 0931988 commit 6d6c59f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 77 deletions.
89 changes: 12 additions & 77 deletions libcst/codemod/visitors/_add_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,20 @@
from libcst import matchers as m, parse_statement
from libcst._nodes.statement import Import, ImportFrom, SimpleStatementLine
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareTransformer, ContextAwareVisitor
from libcst.codemod._visitor import ContextAwareTransformer
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
from libcst.codemod.visitors._imports import ImportItem
from libcst.helpers import get_absolute_module_from_package_for_import
from libcst.helpers.common import ensure_type


class _GatherTopImportsBeforeStatements(ContextAwareVisitor):
class _GatherTopImportsBeforeStatements(GatherImportsVisitor):
"""
Works similarly to GatherImportsVisitor, but only considers imports
declared before any other statements of the module with the exception
of docstrings and __strict__ flag.
"""

def __init__(self, context: CodemodContext) -> None:
super().__init__(context)
# Track the available imports in this transform
self.module_imports: Set[str] = set()
self.object_mapping: Dict[str, Set[str]] = {}
# Track the aliased imports in this transform
self.module_aliases: Dict[str, str] = {}
self.alias_mapping: Dict[str, List[Tuple[str, str]]] = {}
# Track all of the imports found in this transform
self.all_imports: Set[Union[libcst.Import, libcst.ImportFrom]] = set()
# Track the import for every symbol introduced into the module
self.symbol_mapping: Dict[str, ImportItem] = {}

def leave_Module(self, original_node: libcst.Module) -> None:
start = 1 if _skip_first(original_node) else 0
for stmt in original_node.body[start:]:
Expand All @@ -45,75 +33,22 @@ def leave_Module(self, original_node: libcst.Module) -> None:
):
stmt = ensure_type(stmt, SimpleStatementLine)
imp = ensure_type(stmt.body[0], Union[ImportFrom, Import])
self.all_imports.add(imp)
self.all_imports.append(imp)
else:
break
for imp in self.all_imports:
if m.matches(imp, m.Import()):
imp = ensure_type(imp, Import)
self.handle_Import(imp)
self._handle_Import(imp)
else:
imp = ensure_type(imp, ImportFrom)
self.handle_ImportFrom(imp)

def handle_Import(self, node: libcst.Import) -> None:
for name in node.names:
alias = name.evaluated_alias
imp = ImportItem(name.evaluated_name, alias=alias)
if alias is not None:
# Track this as an aliased module
self.module_aliases[name.evaluated_name] = alias
self.symbol_mapping[alias] = imp
else:
# Get the module we're importing as a string.
self.module_imports.add(name.evaluated_name)
self.symbol_mapping[name.evaluated_name] = imp
self._handle_ImportFrom(imp)

def handle_ImportFrom(self, node: libcst.ImportFrom) -> None:
# Get the module we're importing as a string.
module = get_absolute_module_from_package_for_import(
self.context.full_package_name, node
)
if module is None:
# Can't get the absolute import from relative, so we can't
# support this.
return
nodenames = node.names
if isinstance(nodenames, libcst.ImportStar):
# We cover everything, no need to bother tracking other things
self.object_mapping[module] = set("*")
return
elif isinstance(nodenames, Sequence):
# Get the list of imports we're aliasing in this import
new_aliases = [
(ia.evaluated_name, ia.evaluated_alias)
for ia in nodenames
if ia.asname is not None
]
if new_aliases:
if module not in self.alias_mapping:
self.alias_mapping[module] = []
# pyre-ignore We know that aliases are not None here.
self.alias_mapping[module].extend(new_aliases)

# Get the list of imports we're importing in this import
new_objects = {ia.evaluated_name for ia in nodenames if ia.asname is None}
if new_objects:
if module not in self.object_mapping:
self.object_mapping[module] = set()

# Make sure that we don't add to a '*' module
if "*" in self.object_mapping[module]:
self.object_mapping[module] = set("*")
return

self.object_mapping[module].update(new_objects)
for ia in nodenames:
imp = ImportItem(
module, obj_name=ia.evaluated_name, alias=ia.evaluated_alias
)
key = ia.evaluated_alias or ia.evaluated_name
self.symbol_mapping[key] = imp
def visit_ImportFrom(self, node: libcst.ImportFrom) -> None:
pass

def visit_Import(self, node: libcst.Import) -> None:
pass


class AddImportsVisitor(ContextAwareTransformer):
Expand Down Expand Up @@ -277,7 +212,7 @@ def visit_Module(self, node: libcst.Module) -> None:
# Do a preliminary pass to gather the imports we already have at the top
gatherer = _GatherTopImportsBeforeStatements(self.context)
node.visit(gatherer)
self.all_imports = gatherer.all_imports
self.all_imports = set(gatherer.all_imports)

self.module_imports = self.module_imports - gatherer.module_imports
for module, alias in gatherer.module_aliases.items():
Expand Down
4 changes: 4 additions & 0 deletions libcst/codemod/visitors/_gather_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def __init__(self, context: CodemodContext) -> None:
def visit_Import(self, node: libcst.Import) -> None:
# Track this import statement for later analysis.
self.all_imports.append(node)
self._handle_Import(node)

def _handle_Import(self, node: libcst.Import) -> None:
for name in node.names:
alias = name.evaluated_alias
imp = ImportItem(name.evaluated_name, alias=alias)
Expand All @@ -83,7 +85,9 @@ def visit_Import(self, node: libcst.Import) -> None:
def visit_ImportFrom(self, node: libcst.ImportFrom) -> None:
# Track this import statement for later analysis.
self.all_imports.append(node)
self._handle_ImportFrom(node)

def _handle_ImportFrom(self, node: libcst.ImportFrom) -> None:
# Get the module we're importing as a string.
module = get_absolute_module_from_package_for_import(
self.context.full_package_name, node
Expand Down

0 comments on commit 6d6c59f

Please sign in to comment.