Skip to content

Commit

Permalink
Move find_qualified_names_for in the Assignment class. (#557)
Browse files Browse the repository at this point in the history
Move _NameUtil.find_qualified_name_for ... method inside Assignment classes.
  • Loading branch information
giomeg authored Nov 19, 2021
1 parent 9732f5e commit c48cc21
Showing 1 changed file with 100 additions and 116 deletions.
216 changes: 100 additions & 116 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,23 @@ def record_assignments(self, name: str) -> None:
self.__assignments |= previous_assignments


class QualifiedNameSource(Enum):
IMPORT = auto()
BUILTIN = auto()
LOCAL = auto()


@add_slots
@dataclass(frozen=True)
class QualifiedName:
#: Qualified name, e.g. ``a.b.c`` or ``fn.<locals>.var``.
name: str

#: Source of the name, either :attr:`QualifiedNameSource.IMPORT`, :attr:`QualifiedNameSource.BUILTIN`
#: or :attr:`QualifiedNameSource.LOCAL`.
source: QualifiedNameSource


class BaseAssignment(abc.ABC):
"""Abstract base class of :class:`Assignment` and :class:`BuitinAssignment`."""

Expand Down Expand Up @@ -175,6 +192,10 @@ def _index(self) -> int:
"""Return an integer that represents the order of assignments in `scope`"""
return -1

@abc.abstractmethod
def get_qualified_names_for(self, full_name: str) -> Set[QualifiedName]:
...


class Assignment(BaseAssignment):
"""An assignment records the name, CSTNode and its accesses."""
Expand All @@ -195,6 +216,26 @@ def __init__(
def _index(self) -> int:
return self.__index

def get_qualified_names_for(self, full_name: str) -> Set[QualifiedName]:
scope = self.scope
name_prefixes = []
while scope:
if isinstance(scope, ClassScope):
name_prefixes.append(scope.name)
elif isinstance(scope, FunctionScope):
name_prefixes.append(f"{scope.name}.<locals>")
elif isinstance(scope, ComprehensionScope):
name_prefixes.append("<comprehension>")
elif not isinstance(scope, (GlobalScope, BuiltinScope)):
raise Exception(f"Unexpected Scope: {scope}")

scope = scope.parent if scope.parent != scope else None

parts = [*reversed(name_prefixes)]
if full_name:
parts.append(full_name)
return {QualifiedName(".".join(parts), QualifiedNameSource.LOCAL)}


# even though we don't override the constructor.
class BuiltinAssignment(BaseAssignment):
Expand All @@ -205,7 +246,8 @@ class BuiltinAssignment(BaseAssignment):
`types <https://docs.python.org/3/library/stdtypes.html>`_.
"""

pass
def get_qualified_names_for(self, full_name: str) -> Set[QualifiedName]:
return {QualifiedName(f"builtins.{self.name}", QualifiedNameSource.BUILTIN)}


class ImportAssignment(Assignment):
Expand All @@ -224,6 +266,55 @@ def __init__(
super().__init__(name, scope, node, index)
self.as_name = as_name

def get_module_name_for_import(self) -> str:
module = ""
if isinstance(self.node, cst.ImportFrom):
module_attr = self.node.module
relative = self.node.relative
if module_attr:
module = get_full_name_for_node(module_attr) or ""
if relative:
module = "." * len(relative) + module
return module

def get_qualified_names_for(self, full_name: str) -> Set[QualifiedName]:
module = self.get_module_name_for_import()
results = set()
import_names = self.node.names
if not isinstance(import_names, cst.ImportStar):
for name in import_names:
real_name = get_full_name_for_node(name.name)
if not real_name:
continue
# real_name can contain `.` for dotted imports
# for these we want to find the longest prefix that matches full_name
parts = real_name.split(".")
real_names = [".".join(parts[:i]) for i in range(len(parts), 0, -1)]
for real_name in real_names:
as_name = real_name
if module and module.endswith("."):
# from . import a
# real_name should be ".a"
real_name = f"{module}{real_name}"
elif module:
real_name = f"{module}.{real_name}"
if name and name.asname:
eval_alias = name.evaluated_alias
if eval_alias is not None:
as_name = eval_alias
if full_name.startswith(as_name):
remaining_name = full_name.split(as_name, 1)[1].lstrip(".")
results.add(
QualifiedName(
f"{real_name}.{remaining_name}"
if remaining_name
else real_name,
QualifiedNameSource.IMPORT,
)
)
break
return results


class Assignments:
"""A container to provide all assignments in a scope."""
Expand Down Expand Up @@ -269,23 +360,6 @@ def __contains__(self, node: Union[str, cst.CSTNode]) -> bool:
return len(self[node]) > 0


class QualifiedNameSource(Enum):
IMPORT = auto()
BUILTIN = auto()
LOCAL = auto()


@add_slots
@dataclass(frozen=True)
class QualifiedName:
#: Qualified name, e.g. ``a.b.c`` or ``fn.<locals>.var``.
name: str

#: Source of the name, either :attr:`QualifiedNameSource.IMPORT`, :attr:`QualifiedNameSource.BUILTIN`
#: or :attr:`QualifiedNameSource.LOCAL`.
source: QualifiedNameSource


class _NameUtil:
@staticmethod
def get_name_for(node: Union[str, cst.CSTNode]) -> Optional[str]:
Expand All @@ -302,84 +376,6 @@ def get_name_for(node: Union[str, cst.CSTNode]) -> Optional[str]:
return _NameUtil.get_name_for(node.name)
return None

@staticmethod
def get_module_name_for_import_alike(
assignment_node: Union[cst.Import, cst.ImportFrom]
) -> str:
module = ""
if isinstance(assignment_node, cst.ImportFrom):
module_attr = assignment_node.module
relative = assignment_node.relative
if module_attr:
module = get_full_name_for_node(module_attr) or ""
if relative:
module = "." * len(relative) + module
return module

@staticmethod
def find_qualified_name_for_import_alike(
assignment_node: Union[cst.Import, cst.ImportFrom], full_name: str
) -> Set[QualifiedName]:
module = _NameUtil.get_module_name_for_import_alike(assignment_node)
results = set()
import_names = assignment_node.names
if not isinstance(import_names, cst.ImportStar):
for name in import_names:
real_name = get_full_name_for_node(name.name)
if not real_name:
continue
# real_name can contain `.` for dotted imports
# for these we want to find the longest prefix that matches full_name
parts = real_name.split(".")
real_names = [".".join(parts[:i]) for i in range(len(parts), 0, -1)]
for real_name in real_names:
as_name = real_name
if module and module.endswith("."):
# from . import a
# real_name should be ".a"
real_name = f"{module}{real_name}"
elif module:
real_name = f"{module}.{real_name}"
if name and name.asname:
eval_alias = name.evaluated_alias
if eval_alias is not None:
as_name = eval_alias
if full_name.startswith(as_name):
remaining_name = full_name.split(as_name, 1)[1].lstrip(".")
results.add(
QualifiedName(
f"{real_name}.{remaining_name}"
if remaining_name
else real_name,
QualifiedNameSource.IMPORT,
)
)
break
return results

@staticmethod
def find_qualified_name_for_non_import(
assignment: Assignment, remaining_name: str
) -> Set[QualifiedName]:
scope = assignment.scope
name_prefixes = []
while scope:
if isinstance(scope, ClassScope):
name_prefixes.append(scope.name)
elif isinstance(scope, FunctionScope):
name_prefixes.append(f"{scope.name}.<locals>")
elif isinstance(scope, ComprehensionScope):
name_prefixes.append("<comprehension>")
elif not isinstance(scope, (GlobalScope, BuiltinScope)):
raise Exception(f"Unexpected Scope: {scope}")

scope = scope.parent if scope.parent != scope else None

parts = [*reversed(name_prefixes)]
if remaining_name:
parts.append(remaining_name)
return {QualifiedName(".".join(parts), QualifiedNameSource.LOCAL)}


class Scope(abc.ABC):
"""
Expand Down Expand Up @@ -555,26 +551,14 @@ def f(self) -> "c":
assignments = self[prefix]
break
for assignment in assignments:
if isinstance(assignment, Assignment):
assignment_node = assignment.node
if isinstance(assignment_node, (cst.Import, cst.ImportFrom)):
names = _NameUtil.find_qualified_name_for_import_alike(
assignment_node, full_name
)
else:
names = _NameUtil.find_qualified_name_for_non_import(
assignment, full_name
)
if not isinstance(node, str) and _is_assignment(node, assignment_node):
return names
else:
results |= names
elif isinstance(assignment, BuiltinAssignment):
results.add(
QualifiedName(
f"builtins.{assignment.name}", QualifiedNameSource.BUILTIN
)
)
names = assignment.get_qualified_names_for(full_name)
if (
isinstance(assignment, Assignment)
and not isinstance(node, str)
and _is_assignment(node, assignment.node)
):
return names
results |= names
return results

@property
Expand Down

0 comments on commit c48cc21

Please sign in to comment.