diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index 9e6556ce5..0c7df8d85 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -874,7 +874,6 @@ def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]: def visit_Call(self, node: cst.Call) -> Optional[bool]: self.__top_level_attribute_stack.append(None) self.__in_type_hint_stack.append(False) - self.__in_annotation_stack.append(False) qnames = {qn.name for qn in self.scope.get_qualified_names_for(node)} if "typing.NewType" in qnames or "typing.TypeVar" in qnames: node.func.visit(self) @@ -896,7 +895,6 @@ def visit_Call(self, node: cst.Call) -> Optional[bool]: def leave_Call(self, original_node: cst.Call) -> None: self.__top_level_attribute_stack.pop() self.__in_type_hint_stack.pop() - self.__in_annotation_stack.pop() def visit_Annotation(self, node: cst.Annotation) -> Optional[bool]: self.__in_annotation_stack.append(True) diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index 85d9266b9..4f84f439b 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -6,7 +6,7 @@ import sys from textwrap import dedent -from typing import Mapping, Tuple, cast +from typing import Mapping, Tuple, cast, Sequence from unittest import mock import libcst as cst @@ -1768,78 +1768,87 @@ def test_walrus_accesses(self) -> None: ), ) - def test_cast(self) -> None: - def assert_parsed(code, *calls): - parse = cst.parse_module - with mock.patch("libcst.parse_module") as parse_mock: - parse_mock.side_effect = parse - get_scope_metadata_provider(dedent(code)) - calls = [mock.call(dedent(code))] + list(calls) - self.assertEqual(parse_mock.call_count, len(calls)) - parse_mock.assert_has_calls(calls) - - assert_parsed( - """ - from typing import TypeVar - TypeVar("Name", "int") - """, - mock.call("int"), - ) - - assert_parsed( - """ - from typing import Dict - Dict["str", "int"] - """, - mock.call("str"), - mock.call("int"), - ) - - assert_parsed( - """ - from typing import Dict, cast - cast(Dict[str, str], {})["3rr0r"] - """ - ) - - assert_parsed( - """ - from typing import cast - cast(str, "foo") - """, - ) - - assert_parsed( - """ - from typing import cast - cast("int", "foo") - """, - mock.call("int"), - ) - - assert_parsed( - """ - from typing import TypeVar - TypeVar("Name", func("int")) - """, - ) - - assert_parsed( - """ - from typing import Literal - Literal[\"G\"] - """, - ) - - assert_parsed( - r""" - from typing import TypeVar, Optional - from a import G - TypeVar("G2", bound="Optional[\"G\"]") - """, - mock.call('Optional["G"]'), - mock.call("G"), - ) + @data_provider( + { + "TypeVar": { + "code": """ + from typing import TypeVar + TypeVar("Name", "int") + """, + "calls": [mock.call("int")], + }, + "Dict": { + "code": """ + from typing import Dict + Dict["str", "int"] + """, + "calls": [mock.call("str"), mock.call("int")], + }, + "cast_no_annotation": { + "code": """ + from typing import Dict, cast + cast(Dict[str, str], {})["3rr0r"] + """, + "calls": [], + }, + "cast_second_arg": { + "code": """ + from typing import cast + cast(str, "foo") + """, + "calls": [], + }, + "cast_first_arg": { + "code": """ + from typing import cast + cast("int", "foo") + """, + "calls": [ + mock.call("int"), + ], + }, + "typevar_func": { + "code": """ + from typing import TypeVar + TypeVar("Name", func("int")) + """, + "calls": [], + }, + "literal": { + "code": """ + from typing import Literal + Literal[\"G\"] + """, + "calls": [], + }, + "nested_str": { + "code": r""" + from typing import TypeVar, Optional + from a import G + TypeVar("G2", bound="Optional[\"G\"]") + """, + "calls": [mock.call('Optional["G"]'), mock.call("G")], + }, + "class_self_ref": { + "code": """ + from typing import TypeVar + class HelperClass: + value: TypeVar("THelperClass", bound="HelperClass") + """, + "calls": [mock.call("HelperClass")], + }, + } + ) + def test_parse_string_annotations( + self, *, code: str, calls: Sequence[mock._Call] + ) -> None: + parse = cst.parse_module + with mock.patch("libcst.parse_module") as parse_mock: + parse_mock.side_effect = parse + get_scope_metadata_provider(dedent(code)) + calls = [mock.call(dedent(code))] + list(calls) + self.assertEqual(parse_mock.call_count, len(calls)) + parse_mock.assert_has_calls(calls) def test_builtin_scope(self) -> None: m, scopes = get_scope_metadata_provider( @@ -1907,3 +1916,21 @@ def foo(): global_pow_accesses = list(global_pow_assignment.references) self.assertEqual(len(global_pow_accesses), 2) + + def test_annotation_access_in_typevar_bound(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + from typing import TypeVar + class Test: + var: TypeVar("T", bound="Test") + """ + ) + imp = ensure_type( + ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.ImportFrom + ) + scope = scopes[imp] + assignment = list(scope["Test"])[0] + self.assertIsInstance(assignment, Assignment) + self.assertEqual(len(assignment.references), 1) + references = list(assignment.references) + self.assertTrue(references[0].is_annotation)