Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC and WIP: Add MypyTypeInferenceProvider #831

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions libcst/metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ExpressionContextProvider,
)
from libcst.metadata.full_repo_manager import FullRepoManager
from libcst.metadata.mypy_type_inference_provider import MypyTypeInferenceProvider
from libcst.metadata.name_provider import (
FullyQualifiedNameProvider,
QualifiedNameProvider,
Expand Down Expand Up @@ -74,6 +75,7 @@
"ClassScope",
"ComprehensionScope",
"ScopeProvider",
"MypyTypeInferenceProvider",
"ParentNodeProvider",
"QualifiedName",
"QualifiedNameSource",
Expand Down
96 changes: 96 additions & 0 deletions libcst/metadata/mypy_type_inference_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path
from typing import Dict, List, Mapping, Optional, TYPE_CHECKING

import libcst as cst
from libcst._position import CodeRange
from libcst.helpers import calculate_module_and_package
from libcst.metadata.base_provider import BatchableMetadataProvider
from libcst.metadata.position_provider import PositionProvider

# mypy is an optional dependency: probe for it once at import time and
# remember the outcome so callers can fail with a clear message later.
try:
    import mypy
except ImportError:
    MYPY_INSTALLED = False
else:
    MYPY_INSTALLED = True


if TYPE_CHECKING:
import mypy.nodes

import libcst.metadata.mypy_utils


def raise_on_mypy_non_installed() -> None:
    """
    Raise ``RuntimeError`` when the module-level ``MYPY_INSTALLED`` flag is
    false; return ``None`` otherwise.
    """
    if MYPY_INSTALLED:
        return
    raise RuntimeError("mypy is not installed, please install it")


class MypyTypeInferenceProvider(
    BatchableMetadataProvider["libcst.metadata.mypy_utils.MypyType"]
):
    """
    Access inferred type annotation through `mypy <http://mypy-lang.org/>`_.

    Metadata is set on ``Name``, ``Attribute`` and ``Call`` nodes whose
    position (as reported by ``PositionProvider``) is exactly equal to a code
    range recorded from the mypy AST of the same module. Nodes without a
    matching range get no metadata. The per-module cache passed to the
    constructor is produced by :meth:`gen_cache`.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)

    @classmethod
    def gen_cache(
        cls, root_path: Path, paths: List[str], timeout: Optional[int] = None
    ) -> Mapping[
        str, Optional["libcst.metadata.mypy_utils.MypyTypeInferenceProviderCache"]
    ]:
        """
        Run mypy on ``paths`` and build one cache entry per path.

        ``root_path`` is used to derive each file's module name, which is the
        key under which mypy's build graph is consulted. ``timeout`` is
        accepted but not used by this implementation.

        Returns a mapping from every entry of ``paths`` to its cache. The value
        is ``None`` when the build graph has no entry for the derived module
        name or when that entry carries no AST.

        Raises ``RuntimeError`` if mypy is not installed.
        """
        raise_on_mypy_non_installed()

        # Imported lazily so that this module can be imported without mypy.
        import mypy.build
        import mypy.main

        from libcst.metadata.mypy_utils import MypyTypeInferenceProviderCache

        targets, options = mypy.main.process_options(paths)
        # The ASTs are read from the build result below, so mypy must keep them.
        options.preserve_asts = True
        options.fine_grained_incremental = True
        options.use_fine_grained_cache = True
        mypy_result = mypy.build.build(targets, options=options)

        cache: Dict[str, Optional[MypyTypeInferenceProviderCache]] = {}
        for path in paths:
            module = calculate_module_and_package(str(root_path), path).name
            # Do not assume mypy filed the module under the name derived here,
            # nor that a tree was kept for it: fall back to "no cache" instead
            # of raising KeyError or storing a None tree.
            state = mypy_result.graph.get(module)
            tree = None if state is None else state.tree
            if tree is None:
                cache[path] = None
            else:
                cache[path] = MypyTypeInferenceProviderCache(
                    module_name=module,
                    mypy_file=tree,
                )
        return cache

    def __init__(
        self,
        cache: Optional["libcst.metadata.mypy_utils.MypyTypeInferenceProviderCache"],
    ) -> None:
        """
        Index the cached mypy AST by code range.

        With ``cache`` set to ``None`` the index stays empty, so the provider
        sets no metadata at all.
        """
        from libcst.metadata.mypy_utils import CodeRangeToMypyNodesBinder

        super().__init__(cache)
        # Values are the MypyType wrappers collected by the binder.
        self._mypy_node_locations: Dict[
            CodeRange, "libcst.metadata.mypy_utils.MypyType"
        ] = {}
        if cache is None:
            return
        code_range_to_mypy_nodes_binder = CodeRangeToMypyNodesBinder(cache.module_name)
        code_range_to_mypy_nodes_binder.visit_mypy_file(cache.mypy_file)
        self._mypy_node_locations = code_range_to_mypy_nodes_binder.locations

    def _parse_metadata(self, node: cst.CSTNode) -> None:
        """Set metadata on ``node`` if a type was recorded for its exact range."""
        code_range = self.get_metadata(PositionProvider, node)
        mypy_type = self._mypy_node_locations.get(code_range)
        if mypy_type is not None:
            self.set_metadata(node, mypy_type)

    def visit_Name(self, node: cst.Name) -> Optional[bool]:
        self._parse_metadata(node)

    def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]:
        self._parse_metadata(node)

    def visit_Call(self, node: cst.Call) -> Optional[bool]:
        self._parse_metadata(node)
146 changes: 146 additions & 0 deletions libcst/metadata/mypy_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Dict, Optional, Union

import mypy.build
import mypy.main
import mypy.modulefinder
import mypy.nodes
import mypy.options
import mypy.patterns
import mypy.traverser
import mypy.types
import mypy.typetraverser

from libcst._add_slots import add_slots
from libcst._position import CodePosition, CodeRange


@add_slots
@dataclass(frozen=True)
class MypyTypeInferenceProviderCache:
    """Immutable per-module cache entry: a module name and its mypy AST."""

    # Name of the module the AST belongs to.
    module_name: str
    # mypy AST of that module.
    mypy_file: mypy.nodes.MypyFile


@add_slots
@dataclass(frozen=True)
class MypyType:
    """
    Immutable wrapper around a mypy type or ``TypeInfo`` with a precomputed
    display name, which is also what ``str()`` returns.
    """

    # When true, the display name is wrapped as ``typing.Type[...]``.
    is_type_constructor: bool
    # The wrapped mypy object.
    mypy_type: Union[mypy.types.Type, mypy.nodes.TypeInfo]
    # Display name; computed in ``__post_init__``, not accepted by ``__init__``.
    fullname: str = field(init=False)

    def __post_init__(self) -> None:
        wrapped = self.mypy_type
        # A mypy type is rendered through ``str()``; anything else is expected
        # to expose ``fullname`` (a ``TypeInfo`` per the field annotation).
        name = (
            str(wrapped) if isinstance(wrapped, mypy.types.Type) else wrapped.fullname
        )
        if self.is_type_constructor:
            name = f"typing.Type[{name}]"
        # The dataclass is frozen, so plain attribute assignment is rejected.
        object.__setattr__(self, "fullname", name)

    def __str__(self) -> str:
        return self.fullname


class CodeRangeToMypyNodesBinder(
    mypy.traverser.TraverserVisitor, mypy.typetraverser.TypeTraverserVisitor
):
    """
    Traverse a mypy AST and collect, per source code range, the type found
    at that range.

    Results accumulate in ``self.locations``. Nothing is recorded for a node
    whose position information is missing or out of range (see
    :meth:`check_bounds`) or whose type is ``None``. A later record for the
    same range overwrites an earlier one.
    """

    def __init__(self, module_name: str) -> None:
        super().__init__()
        self.locations: Dict[CodeRange, MypyType] = {}
        self.in_type_alias_expr = False
        self.module_name = module_name

    # Helpers

    @staticmethod
    def get_code_range(o: mypy.nodes.Context) -> CodeRange:
        """
        Build a ``CodeRange`` from the node's start and end line/column.

        The values are copied without validation; call :meth:`check_bounds`
        first to make sure they are all usable.
        """
        return CodeRange(
            start=CodePosition(o.line, o.column),
            end=CodePosition(o.end_line, o.end_column),
        )

    @staticmethod
    def check_bounds(o: mypy.nodes.Context) -> bool:
        """
        Return whether all four position fields of ``o`` are set, with lines
        at least 1 and columns at least 0.
        """
        return (
            (o.line is not None)
            and (o.line >= 1)
            and (o.column is not None)
            and (o.column >= 0)
            and (o.end_line is not None)
            and (o.end_line >= 1)
            and (o.end_column is not None)
            and (o.end_column >= 0)
        )

    def record_type_location_using_code_range(
        self,
        code_range: CodeRange,
        t: Optional[Union[mypy.types.Type, mypy.nodes.TypeInfo]],
        is_type_constructor: bool,
    ) -> None:
        """Store ``t`` under ``code_range``; a ``None`` type is ignored."""
        if t is not None:
            self.locations[code_range] = MypyType(
                is_type_constructor=is_type_constructor, mypy_type=t
            )

    def record_type_location(
        self,
        o: mypy.nodes.Context,
        t: Optional[Union[mypy.types.Type, mypy.nodes.TypeInfo]],
        is_type_constructor: bool,
    ) -> None:
        """Store ``t`` under the range of ``o`` if that range is valid."""
        if self.check_bounds(o):
            self.record_type_location_using_code_range(
                code_range=self.get_code_range(o),
                t=t,
                is_type_constructor=is_type_constructor,
            )

    def record_location_by_name_expr(
        self, code_range: CodeRange, o: mypy.nodes.NameExpr, is_type_constructor: bool
    ) -> None:
        """
        Store the type that the name ``o`` resolves to under ``code_range``.

        A name bound to a ``Var`` records the variable's type and is never
        marked as a type constructor. A name bound to a ``TypeInfo`` records
        the ``TypeInfo`` itself with the requested ``is_type_constructor``.
        Names bound to anything else are ignored.
        """
        if isinstance(o.node, mypy.nodes.Var):
            self.record_type_location_using_code_range(
                code_range=code_range, t=o.node.type, is_type_constructor=False
            )
        elif isinstance(o.node, mypy.nodes.TypeInfo):
            self.record_type_location_using_code_range(
                code_range=code_range, t=o.node, is_type_constructor=is_type_constructor
            )

    # Actual visitors

    def visit_var(self, o: mypy.nodes.Var) -> None:
        super().visit_var(o)
        self.record_type_location(o=o, t=o.type, is_type_constructor=False)

    def visit_name_expr(self, o: mypy.nodes.NameExpr) -> None:
        super().visit_name_expr(o)
        # Implementation in base classes is omitted, record it if it is variable or class.
        # Skip nodes without usable position info, like record_type_location
        # does, so no range with missing or negative coordinates is stored.
        if self.check_bounds(o):
            self.record_location_by_name_expr(
                self.get_code_range(o), o, is_type_constructor=True
            )

    def visit_member_expr(self, o: mypy.nodes.MemberExpr) -> None:
        super().visit_member_expr(o)
        # Implementation in base classes is omitted, record it
        # o.def_var should not be None after mypy run, checking here just to be sure
        if o.def_var is not None:
            self.record_type_location(o=o, t=o.def_var.type, is_type_constructor=False)

    def visit_call_expr(self, o: mypy.nodes.CallExpr) -> None:
        super().visit_call_expr(o)
        # The range of the whole call is bound to what the callee name
        # resolves to; the range must be valid for the same reason as in
        # visit_name_expr.
        if isinstance(o.callee, mypy.nodes.NameExpr) and self.check_bounds(o):
            self.record_location_by_name_expr(
                code_range=self.get_code_range(o), o=o.callee, is_type_constructor=False
            )

    def visit_instance(self, o: mypy.types.Instance) -> None:
        super().visit_instance(o)
        self.record_type_location(o=o, t=o, is_type_constructor=False)
63 changes: 63 additions & 0 deletions libcst/metadata/tests/test_mypy_type_inference_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path

import libcst as cst
from libcst import MetadataWrapper
from libcst.metadata.mypy_type_inference_provider import MypyTypeInferenceProvider
from libcst.testing.utils import data_provider, UnitTest
from libcst.tests.test_pyre_integration import TEST_SUITE_PATH


def _test_simple_class_helper(test: UnitTest, wrapper: MetadataWrapper) -> None:
    """
    Resolve ``MypyTypeInferenceProvider`` on ``wrapper`` and assert the string
    form of the metadata for four nodes of the wrapped module.

    NOTE(review): the statement indices below presumably match the layout of
    the ``simple_class.py`` fixture - confirm against that file.
    """
    mypy_nodes = wrapper.resolve(MypyTypeInferenceProvider)
    module = wrapper.module

    # Drill down to the first statement of the first function found inside
    # the second top-level statement; it must be an annotated assignment.
    outer_body = cst.ensure_type(module.body[1].body, cst.IndentedBlock)
    first_function = cst.ensure_type(outer_body.body[0], cst.FunctionDef)
    first_line = cst.ensure_type(first_function.body.body[0], cst.SimpleStatementLine)
    assign = cst.ensure_type(first_line.body[0], cst.AnnAssign)

    self_number_attr = cst.ensure_type(assign.target, cst.Attribute)
    test.assertEqual(str(mypy_nodes[self_number_attr]), "builtins.int")

    # self
    test.assertEqual(
        str(mypy_nodes[self_number_attr.value]), "libcst.tests.pyre.simple_class.Item"
    )

    # Target of the plain assignment at top-level index 3.
    collector_line = cst.ensure_type(module.body[3], cst.SimpleStatementLine)
    collector_assign = cst.ensure_type(collector_line.body[0], cst.Assign)
    collector = collector_assign.targets[0].target
    test.assertEqual(
        str(mypy_nodes[collector]), "libcst.tests.pyre.simple_class.ItemCollector"
    )

    # Target of the annotated assignment at top-level index 4.
    items_line = cst.ensure_type(module.body[4], cst.SimpleStatementLine)
    items_assign = cst.ensure_type(items_line.body[0], cst.AnnAssign)
    items = items_assign.target
    test.assertEqual(
        str(mypy_nodes[items]), "typing.Sequence[libcst.tests.pyre.simple_class.Item]"
    )


class MypyTypeInferenceProviderTest(UnitTest):
    """Runs the provider end to end on the ``simple_class.py`` fixture."""

    @data_provider(
        ((TEST_SUITE_PATH / "simple_class.py", TEST_SUITE_PATH / "simple_class.json"),)
    )
    def test_simple_class_types(self, source_path: Path, data_path: Path) -> None:
        # ``data_path`` is supplied by the data provider but not read here.
        source_file = str(source_path)
        project_root = Path(__file__).parents[3]
        caches = MypyTypeInferenceProvider.gen_cache(project_root, [source_file])
        parsed_module = cst.parse_module(source_path.read_text())
        wrapper = MetadataWrapper(
            parsed_module,
            cache={MypyTypeInferenceProvider: caches[source_file]},
        )
        _test_simple_class_helper(self, wrapper)
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ hypothesis>=4.36.0
hypothesmith>=0.0.4
jupyter>=1.0.0
maturin>=0.8.3,<0.14
mypy>=0.991
nbsphinx>=0.4.2
prompt-toolkit>=2.0.9
pyre-check==0.9.9; platform_system != "Windows"
Expand Down