diff --git a/libcst/metadata/__init__.py b/libcst/metadata/__init__.py
index 66e7e5251..ab33b2bd3 100644
--- a/libcst/metadata/__init__.py
+++ b/libcst/metadata/__init__.py
@@ -17,6 +17,7 @@
     ExpressionContextProvider,
 )
 from libcst.metadata.full_repo_manager import FullRepoManager
+from libcst.metadata.mypy_type_inference_provider import MypyTypeInferenceProvider
 from libcst.metadata.name_provider import (
     FullyQualifiedNameProvider,
     QualifiedNameProvider,
@@ -74,6 +75,7 @@
     "ClassScope",
     "ComprehensionScope",
     "ScopeProvider",
+    "MypyTypeInferenceProvider",
     "ParentNodeProvider",
     "QualifiedName",
     "QualifiedNameSource",
diff --git a/libcst/metadata/mypy_type_inference_provider.py b/libcst/metadata/mypy_type_inference_provider.py
new file mode 100644
index 000000000..29e9a6b99
--- /dev/null
+++ b/libcst/metadata/mypy_type_inference_provider.py
@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from pathlib import Path
+from typing import Dict, List, Mapping, Optional, TYPE_CHECKING
+
+import libcst as cst
+from libcst._position import CodeRange
+from libcst.helpers import calculate_module_and_package
+from libcst.metadata.base_provider import BatchableMetadataProvider
+from libcst.metadata.position_provider import PositionProvider
+
+# mypy is an optional dependency: remember whether it can be imported so that
+# a clear error is raised only when the provider is actually used.
+try:
+    import mypy
+
+    MYPY_INSTALLED = True
+except ImportError:
+    MYPY_INSTALLED = False
+
+
+if TYPE_CHECKING:
+    import mypy.nodes
+
+    import libcst.metadata.mypy_utils
+
+
+def raise_on_mypy_non_installed() -> None:
+    """Raise :class:`RuntimeError` when mypy could not be imported."""
+    if not MYPY_INSTALLED:
+        raise RuntimeError("mypy is not installed, please install it")
+
+
+class MypyTypeInferenceProvider(
+    BatchableMetadataProvider["libcst.metadata.mypy_utils.MypyType"]
+):
+    """
+    Access inferred type annotation through `mypy <https://mypy-lang.org/>`_.
+
+    Types are attached to :class:`~libcst.Name`, :class:`~libcst.Attribute` and
+    :class:`~libcst.Call` nodes whose code range matches a range collected from
+    the mypy AST.
+    """
+
+    METADATA_DEPENDENCIES = (PositionProvider,)
+
+    @classmethod
+    def gen_cache(
+        cls, root_path: Path, paths: List[str], timeout: Optional[int] = None
+    ) -> Mapping[
+        str, Optional["libcst.metadata.mypy_utils.MypyTypeInferenceProviderCache"]
+    ]:
+        """
+        Run mypy on ``paths`` and return one cache entry per path.
+
+        ``root_path`` is used to compute the module name of each path.
+        ``timeout`` is accepted but not used here.
+
+        Raises :class:`RuntimeError` if mypy is not installed.
+        """
+        raise_on_mypy_non_installed()
+
+        import mypy.build
+        import mypy.main
+
+        from libcst.metadata.mypy_utils import MypyTypeInferenceProviderCache
+
+        targets, options = mypy.main.process_options(paths)
+        # NOTE(review): preserve_asts presumably keeps ``.tree`` available below.
+        options.preserve_asts = True
+        options.fine_grained_incremental = True
+        options.use_fine_grained_cache = True
+        mypy_result = mypy.build.build(targets, options=options)
+        cache = {}
+        for path in paths:
+            module = calculate_module_and_package(str(root_path), path).name
+            cache[path] = MypyTypeInferenceProviderCache(
+                module_name=module,
+                mypy_file=mypy_result.graph[module].tree,
+            )
+        return cache
+
+    def __init__(
+        self,
+        cache: Optional["libcst.metadata.mypy_utils.MypyTypeInferenceProviderCache"],
+    ) -> None:
+        # Imported lazily because mypy_utils imports mypy at module level.
+        from libcst.metadata.mypy_utils import CodeRangeToMypyNodesBinder
+
+        super().__init__(cache)
+        # The binder's ``locations`` holds MypyType wrappers, not raw mypy nodes.
+        self._mypy_node_locations: Dict[
+            CodeRange, "libcst.metadata.mypy_utils.MypyType"
+        ] = {}
+        if cache is None:
+            return
+        code_range_to_mypy_nodes_binder = CodeRangeToMypyNodesBinder(cache.module_name)
+        code_range_to_mypy_nodes_binder.visit_mypy_file(cache.mypy_file)
+        self._mypy_node_locations = code_range_to_mypy_nodes_binder.locations
+
+    def _parse_metadata(self, node: cst.CSTNode) -> None:
+        """Attach the type recorded for ``node``'s code range, if there is one."""
+        code_range = self.get_metadata(PositionProvider, node)
+        mypy_type = self._mypy_node_locations.get(code_range)
+        if mypy_type is not None:
+            self.set_metadata(node, mypy_type)
+
+    def visit_Name(self, node: cst.Name) -> Optional[bool]:
+        self._parse_metadata(node)
+
+    def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]:
+        self._parse_metadata(node)
+
+    def visit_Call(self, node: cst.Call) -> Optional[bool]:
+        self._parse_metadata(node)
diff --git a/libcst/metadata/mypy_utils.py b/libcst/metadata/mypy_utils.py
new file mode 100644
index
000000000..84ca4d85d
--- /dev/null
+++ b/libcst/metadata/mypy_utils.py
@@ -0,0 +1,172 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from dataclasses import dataclass, field
+from typing import Dict, Optional, Union
+
+import mypy.build
+import mypy.main
+import mypy.modulefinder
+import mypy.nodes
+import mypy.options
+import mypy.patterns
+import mypy.traverser
+import mypy.types
+import mypy.typetraverser
+
+from libcst._add_slots import add_slots
+from libcst._position import CodePosition, CodeRange
+
+
+@add_slots
+@dataclass(frozen=True)
+class MypyTypeInferenceProviderCache:
+    """Cache entry for one module: its name and its mypy AST."""
+
+    module_name: str
+    mypy_file: mypy.nodes.MypyFile
+
+
+@add_slots
+@dataclass(frozen=True)
+class MypyType:
+    """
+    A mypy type (or class definition) together with its printable name.
+
+    ``fullname`` is computed in ``__post_init__``; it is wrapped in
+    ``typing.Type[...]`` when ``is_type_constructor`` is true.
+    """
+
+    is_type_constructor: bool
+    mypy_type: Union[mypy.types.Type, mypy.nodes.TypeInfo]
+    fullname: str = field(init=False)
+
+    def __post_init__(self) -> None:
+        if isinstance(self.mypy_type, mypy.types.Type):
+            fullname = str(self.mypy_type)
+        else:
+            fullname = self.mypy_type.fullname
+        if self.is_type_constructor:
+            fullname = f"typing.Type[{fullname}]"
+        # The dataclass is frozen, so the field is set through object.__setattr__.
+        object.__setattr__(self, "fullname", fullname)
+
+    def __str__(self) -> str:
+        return self.fullname
+
+
+class CodeRangeToMypyNodesBinder(
+    mypy.traverser.TraverserVisitor, mypy.typetraverser.TypeTraverserVisitor
+):
+    """
+    Visitor that fills ``locations`` with the type recorded for each source
+    range it encounters.
+    """
+
+    def __init__(self, module_name: str) -> None:
+        super().__init__()
+        self.locations: Dict[CodeRange, MypyType] = {}
+        self.in_type_alias_expr = False
+        self.module_name = module_name
+
+    # Helpers
+
+    @staticmethod
+    def get_code_range(o: mypy.nodes.Context) -> CodeRange:
+        """Build a ``CodeRange`` from the position fields of ``o``."""
+        return CodeRange(
+            start=CodePosition(o.line, o.column),
+            end=CodePosition(o.end_line, o.end_column),
+        )
+
+    @staticmethod
+    def check_bounds(o: mypy.nodes.Context) -> bool:
+        """Return True when every position field of ``o`` is set and in range."""
+        return (
+            (o.line is not None)
+            and (o.line >= 1)
+            and (o.column is not None)
+            and (o.column >= 0)
+            and (o.end_line is not None)
+            and (o.end_line >= 1)
+            and (o.end_column is not None)
+            and (o.end_column >= 0)
+        )
+
+    def record_type_location_using_code_range(
+        self,
+        code_range: CodeRange,
+        t: Optional[Union[mypy.types.Type, mypy.nodes.TypeInfo]],
+        is_type_constructor: bool,
+    ) -> None:
+        """Store ``t`` under ``code_range``; a ``None`` type is ignored."""
+        if t is not None:
+            self.locations[code_range] = MypyType(
+                is_type_constructor=is_type_constructor, mypy_type=t
+            )
+
+    def record_type_location(
+        self,
+        o: mypy.nodes.Context,
+        t: Optional[Union[mypy.types.Type, mypy.nodes.TypeInfo]],
+        is_type_constructor: bool,
+    ) -> None:
+        """Store ``t`` at the position of ``o`` when that position is valid."""
+        if self.check_bounds(o):
+            self.record_type_location_using_code_range(
+                code_range=self.get_code_range(o),
+                t=t,
+                is_type_constructor=is_type_constructor,
+            )
+
+    def record_location_by_name_expr(
+        self, code_range: CodeRange, o: mypy.nodes.NameExpr, is_type_constructor: bool
+    ) -> None:
+        """
+        Record what the name ``o`` refers to: the variable's type, or the class
+        itself. ``is_type_constructor`` is only honoured for classes.
+        """
+        if isinstance(o.node, mypy.nodes.Var):
+            self.record_type_location_using_code_range(
+                code_range=code_range, t=o.node.type, is_type_constructor=False
+            )
+        elif isinstance(o.node, mypy.nodes.TypeInfo):
+            self.record_type_location_using_code_range(
+                code_range=code_range, t=o.node, is_type_constructor=is_type_constructor
+            )
+
+    # Actual visitors
+
+    def visit_var(self, o: mypy.nodes.Var) -> None:
+        super().visit_var(o)
+        self.record_type_location(o=o, t=o.type, is_type_constructor=False)
+
+    def visit_name_expr(self, o: mypy.nodes.NameExpr) -> None:
+        super().visit_name_expr(o)
+        # Implementation in base classes is omitted, record it if it is variable or class
+        # Nodes without a valid position are skipped, as in record_type_location.
+        if self.check_bounds(o):
+            self.record_location_by_name_expr(
+                self.get_code_range(o), o, is_type_constructor=True
+            )
+
+    def visit_member_expr(self, o: mypy.nodes.MemberExpr) -> None:
+        super().visit_member_expr(o)
+        # Implementation in base classes is omitted, record it
+        # o.def_var should not be None after mypy run, checking here just to be sure
+        if o.def_var is not None:
+            self.record_type_location(o=o, t=o.def_var.type, is_type_constructor=False)
+
+    def visit_call_expr(self, o: mypy.nodes.CallExpr) -> None:
+        super().visit_call_expr(o)
+        # Record the callee's type under the range of the whole call; calls
+        # without a valid position are skipped, as in record_type_location.
+        if isinstance(o.callee, mypy.nodes.NameExpr) and self.check_bounds(o):
+            self.record_location_by_name_expr(
+                code_range=self.get_code_range(o), o=o.callee, is_type_constructor=False
+            )
+
+    def visit_instance(self, o: mypy.types.Instance) -> None:
+        super().visit_instance(o)
+        self.record_type_location(o=o, t=o, is_type_constructor=False)
diff --git a/libcst/metadata/tests/test_mypy_type_inference_provider.py b/libcst/metadata/tests/test_mypy_type_inference_provider.py
new file mode 100644
index 000000000..57fcdb354
--- /dev/null
+++ b/libcst/metadata/tests/test_mypy_type_inference_provider.py
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+from pathlib import Path
+from unittest import skipIf
+
+import libcst as cst
+from libcst import MetadataWrapper
+from libcst.metadata.mypy_type_inference_provider import MypyTypeInferenceProvider
+from libcst.testing.utils import data_provider, UnitTest
+from libcst.tests.test_pyre_integration import TEST_SUITE_PATH
+
+
+def _test_simple_class_helper(test: UnitTest, wrapper: MetadataWrapper) -> None:
+    mypy_nodes = wrapper.resolve(MypyTypeInferenceProvider)
+    m = wrapper.module
+    assign = cst.ensure_type(
+        cst.ensure_type(
+            cst.ensure_type(
+                cst.ensure_type(m.body[1].body, cst.IndentedBlock).body[0],
+                cst.FunctionDef,
+            ).body.body[0],
+            cst.SimpleStatementLine,
+        ).body[0],
+        cst.AnnAssign,
+    )
+    self_number_attr = cst.ensure_type(assign.target, cst.Attribute)
+    test.assertEqual(str(mypy_nodes[self_number_attr]), "builtins.int")
+
+    # self
+    test.assertEqual(
+        str(mypy_nodes[self_number_attr.value]), "libcst.tests.pyre.simple_class.Item"
+    )
+    collector_assign = cst.ensure_type(
+        cst.ensure_type(m.body[3], cst.SimpleStatementLine).body[0], cst.Assign
+    )
+    collector = collector_assign.targets[0].target
+    test.assertEqual(
+        str(mypy_nodes[collector]), "libcst.tests.pyre.simple_class.ItemCollector"
+    )
+    items_assign = cst.ensure_type(
+        cst.ensure_type(m.body[4], cst.SimpleStatementLine).body[0], cst.AnnAssign
+    )
+    items = items_assign.target
+    test.assertEqual(
+        str(mypy_nodes[items]), "typing.Sequence[libcst.tests.pyre.simple_class.Item]"
+    )
+
+
+class MypyTypeInferenceProviderTest(UnitTest):
+    @data_provider(
+        ((TEST_SUITE_PATH / "simple_class.py", TEST_SUITE_PATH / "simple_class.json"),)
+    )
+    def test_simple_class_types(self, source_path: Path, data_path: Path) -> None:
+        file = str(source_path)
+        repo_root = Path(__file__).parents[3]
+        cache = MypyTypeInferenceProvider.gen_cache(repo_root, [file])
+        wrapper = MetadataWrapper(
+            cst.parse_module(source_path.read_text()),
+            cache={MypyTypeInferenceProvider: cache[file]},
+        )
+        _test_simple_class_helper(self, wrapper)
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 09bcd66b3..cc802708e 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -7,6 +7,7 @@ hypothesis>=4.36.0
 hypothesmith>=0.0.4
 jupyter>=1.0.0
 maturin>=0.8.3,<0.14
+mypy>=0.991
 nbsphinx>=0.4.2
 prompt-toolkit>=2.0.9
 pyre-check==0.9.9; platform_system != "Windows"