Skip to content

Commit

Permalink
AddImportsVisitor: add imports before the first non-import statement (I…
Browse files Browse the repository at this point in the history
…nstagram#1024)

* AddImportsVisitor will now only add at the top of module

- Also added new tests to cover these cases

* Fixed an issue with from imports

* Added a couple tests for AddImportsVisitor

* Refactoring of GatherImportsVisitor

* Refactors, typos and typing changes
  • Loading branch information
andrecsilva authored and manmartgarc committed Oct 3, 2023
1 parent aaeebe2 commit bee920d
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 78 deletions.
120 changes: 88 additions & 32 deletions libcst/codemod/visitors/_add_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,51 @@

import libcst
from libcst import matchers as m, parse_statement
from libcst._nodes.statement import Import, ImportFrom, SimpleStatementLine
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareTransformer
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
from libcst.codemod.visitors._gather_imports import _GatherImportsMixin
from libcst.codemod.visitors._imports import ImportItem
from libcst.helpers import get_absolute_module_from_package_for_import
from libcst.helpers.common import ensure_type


class _GatherTopImportsBeforeStatements(_GatherImportsMixin):
"""
Works similarly to GatherImportsVisitor, but only considers imports
declared before any other statements of the module with the exception
of docstrings and __strict__ flag.
"""

def __init__(self, context: CodemodContext) -> None:
super().__init__(context)
# Track all of the imports found in this transform
self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []

def leave_Module(self, original_node: libcst.Module) -> None:
start = 1 if _skip_first(original_node) else 0
for stmt in original_node.body[start:]:
if m.matches(
stmt,
m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()]),
):
stmt = ensure_type(stmt, SimpleStatementLine)
# Workaround for python 3.8 and 3.9, won't accept Union for isinstance
if m.matches(stmt.body[0], m.ImportFrom()):
imp = ensure_type(stmt.body[0], ImportFrom)
self.all_imports.append(imp)
if m.matches(stmt.body[0], m.Import()):
imp = ensure_type(stmt.body[0], Import)
self.all_imports.append(imp)
else:
break
for imp in self.all_imports:
if m.matches(imp, m.Import()):
imp = ensure_type(imp, Import)
self._handle_Import(imp)
else:
imp = ensure_type(imp, ImportFrom)
self._handle_ImportFrom(imp)


class AddImportsVisitor(ContextAwareTransformer):
Expand Down Expand Up @@ -169,12 +209,12 @@ def __init__(
for module in sorted(from_imports_aliases)
}

# Track the list of imports found in the file
# Track the list of imports found at the top of the file
self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []

def visit_Module(self, node: libcst.Module) -> None:
# Do a preliminary pass to gather the imports we already have
gatherer = GatherImportsVisitor(self.context)
# Do a preliminary pass to gather the imports we already have at the top
gatherer = _GatherTopImportsBeforeStatements(self.context)
node.visit(gatherer)
self.all_imports = gatherer.all_imports

Expand Down Expand Up @@ -213,6 +253,10 @@ def leave_ImportFrom(
# There's nothing to do here!
return updated_node

# Ensure this is one of the imports at the top
if original_node not in self.all_imports:
return updated_node

# Get the module we're importing as a string, see if we have work to do.
module = get_absolute_module_from_package_for_import(
self.context.full_package_name, updated_node
Expand Down Expand Up @@ -260,39 +304,26 @@ def _split_module(
statement_before_import_location = 0
import_add_location = 0

# never insert an import before initial __strict__ flag
if m.matches(
orig_module,
m.Module(
body=[
m.SimpleStatementLine(
body=[
m.Assign(
targets=[m.AssignTarget(target=m.Name("__strict__"))]
)
]
),
m.ZeroOrMore(),
]
),
):
statement_before_import_location = import_add_location = 1

# This works under the principle that while we might modify node contents,
# we have yet to modify the number of statements. So we can match on the
# original tree but break up the statements of the modified tree. If we
# change this assumption in this visitor, we will have to change this code.
for i, statement in enumerate(orig_module.body):
if i == 0 and m.matches(
statement, m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())])

# Finds the location to add imports. It is the end of the first import block that occurs before any other statement (save for docstrings)

# Never insert an import before initial __strict__ flag or docstring
if _skip_first(orig_module):
statement_before_import_location = import_add_location = 1

for i, statement in enumerate(
orig_module.body[statement_before_import_location:]
):
if m.matches(
statement, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])
):
statement_before_import_location = import_add_location = 1
elif isinstance(statement, libcst.SimpleStatementLine):
for possible_import in statement.body:
for last_import in self.all_imports:
if possible_import is last_import:
import_add_location = i + 1
break
import_add_location = i + statement_before_import_location + 1
else:
break

return (
list(updated_module.body[:statement_before_import_location]),
Expand Down Expand Up @@ -414,3 +445,28 @@ def leave_Module(
*statements_after_imports,
)
)


def _skip_first(orig_module: libcst.Module) -> bool:
# Is there a __strict__ flag or docstring at the top?
if m.matches(
orig_module,
m.Module(
body=[
m.SimpleStatementLine(
body=[
m.Assign(targets=[m.AssignTarget(target=m.Name("__strict__"))])
]
),
m.ZeroOrMore(),
]
)
| m.Module(
body=[
m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]),
m.ZeroOrMore(),
]
),
):
return True
return False
105 changes: 59 additions & 46 deletions libcst/codemod/visitors/_gather_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,43 +12,9 @@
from libcst.helpers import get_absolute_module_from_package_for_import


class GatherImportsVisitor(ContextAwareVisitor):
class _GatherImportsMixin(ContextAwareVisitor):
"""
Gathers all imports in a module and stores them as attributes on the instance.
Intended to be instantiated and passed to a :class:`~libcst.Module`
:meth:`~libcst.CSTNode.visit` method in order to gather up information about
imports on a module. Note that this is not a substitute for scope analysis or
qualified name support. Please see :ref:`libcst-scope-tutorial` for a more
robust way of determining the qualified name and definition for an arbitrary
node.
After visiting a module the following attributes will be populated:
module_imports
A sequence of strings representing modules that were imported directly, such as
in the case of ``import typing``. Each module directly imported but not aliased
will be included here.
object_mapping
A mapping of strings to sequences of strings representing modules where we
imported objects from, such as in the case of ``from typing import Optional``.
Each from import that was not aliased will be included here, where the keys of
the mapping are the module we are importing from, and the value is a
sequence of objects we are importing from the module.
module_aliases
A mapping of strings representing modules that were imported and aliased,
such as in the case of ``import typing as t``. Each module imported this
way will be represented as a key in this mapping, and the value will be
the local alias of the module.
alias_mapping
A mapping of strings to sequences of tuples representing modules where we
imported objects from and aliased using ``as`` syntax, such as in the case
of ``from typing import Optional as opt``. Each from import that was aliased
will be included here, where the keys of the mapping are the module we are
importing from, and the value is a tuple representing the original object
name and the alias.
all_imports
A collection of all :class:`~libcst.Import` and :class:`~libcst.ImportFrom`
statements that were encountered in the module.
A Mixin class for tracking visited imports.
"""

def __init__(self, context: CodemodContext) -> None:
Expand All @@ -59,15 +25,10 @@ def __init__(self, context: CodemodContext) -> None:
# Track the aliased imports in this transform
self.module_aliases: Dict[str, str] = {}
self.alias_mapping: Dict[str, List[Tuple[str, str]]] = {}
# Track all of the imports found in this transform
self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []
# Track the import for every symbol introduced into the module
self.symbol_mapping: Dict[str, ImportItem] = {}

def visit_Import(self, node: libcst.Import) -> None:
# Track this import statement for later analysis.
self.all_imports.append(node)

def _handle_Import(self, node: libcst.Import) -> None:
for name in node.names:
alias = name.evaluated_alias
imp = ImportItem(name.evaluated_name, alias=alias)
Expand All @@ -80,10 +41,7 @@ def visit_Import(self, node: libcst.Import) -> None:
self.module_imports.add(name.evaluated_name)
self.symbol_mapping[name.evaluated_name] = imp

def visit_ImportFrom(self, node: libcst.ImportFrom) -> None:
# Track this import statement for later analysis.
self.all_imports.append(node)

def _handle_ImportFrom(self, node: libcst.ImportFrom) -> None:
# Get the module we're importing as a string.
module = get_absolute_module_from_package_for_import(
self.context.full_package_name, node
Expand Down Expand Up @@ -128,3 +86,58 @@ def visit_ImportFrom(self, node: libcst.ImportFrom) -> None:
)
key = ia.evaluated_alias or ia.evaluated_name
self.symbol_mapping[key] = imp


class GatherImportsVisitor(_GatherImportsMixin):
"""
Gathers all imports in a module and stores them as attributes on the instance.
Intended to be instantiated and passed to a :class:`~libcst.Module`
:meth:`~libcst.CSTNode.visit` method in order to gather up information about
imports on a module. Note that this is not a substitute for scope analysis or
qualified name support. Please see :ref:`libcst-scope-tutorial` for a more
robust way of determining the qualified name and definition for an arbitrary
node.
After visiting a module the following attributes will be populated:
module_imports
A sequence of strings representing modules that were imported directly, such as
in the case of ``import typing``. Each module directly imported but not aliased
will be included here.
object_mapping
A mapping of strings to sequences of strings representing modules where we
imported objects from, such as in the case of ``from typing import Optional``.
Each from import that was not aliased will be included here, where the keys of
the mapping are the module we are importing from, and the value is a
sequence of objects we are importing from the module.
module_aliases
A mapping of strings representing modules that were imported and aliased,
such as in the case of ``import typing as t``. Each module imported this
way will be represented as a key in this mapping, and the value will be
the local alias of the module.
alias_mapping
A mapping of strings to sequences of tuples representing modules where we
imported objects from and aliased using ``as`` syntax, such as in the case
of ``from typing import Optional as opt``. Each from import that was aliased
will be included here, where the keys of the mapping are the module we are
importing from, and the value is a tuple representing the original object
name and the alias.
all_imports
A collection of all :class:`~libcst.Import` and :class:`~libcst.ImportFrom`
statements that were encountered in the module.
"""

def __init__(self, context: CodemodContext) -> None:
super().__init__(context)
# Track all of the imports found in this transform
self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []

def visit_Import(self, node: libcst.Import) -> None:
# Track this import statement for later analysis.
self.all_imports.append(node)
self._handle_Import(node)

def visit_ImportFrom(self, node: libcst.ImportFrom) -> None:
# Track this import statement for later analysis.
self.all_imports.append(node)
self._handle_ImportFrom(node)
Loading

0 comments on commit bee920d

Please sign in to comment.