Skip to content

Commit

Permalink
Fixes incorrectly missing annotations (#561)
Browse files Browse the repository at this point in the history
Co-authored-by: Zsolt Dollenstein <[email protected]>
  • Loading branch information
lpetre and zsol authored Nov 23, 2021
1 parent 3895925 commit 58b447d
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 75 deletions.
2 changes: 0 additions & 2 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,6 @@ def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]:
def visit_Call(self, node: cst.Call) -> Optional[bool]:
self.__top_level_attribute_stack.append(None)
self.__in_type_hint_stack.append(False)
self.__in_annotation_stack.append(False)
qnames = {qn.name for qn in self.scope.get_qualified_names_for(node)}
if "typing.NewType" in qnames or "typing.TypeVar" in qnames:
node.func.visit(self)
Expand All @@ -896,7 +895,6 @@ def visit_Call(self, node: cst.Call) -> Optional[bool]:
def leave_Call(self, original_node: cst.Call) -> None:
self.__top_level_attribute_stack.pop()
self.__in_type_hint_stack.pop()
self.__in_annotation_stack.pop()

def visit_Annotation(self, node: cst.Annotation) -> Optional[bool]:
self.__in_annotation_stack.append(True)
Expand Down
173 changes: 100 additions & 73 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import sys
from textwrap import dedent
from typing import Mapping, Tuple, cast
from typing import Mapping, Tuple, cast, Sequence
from unittest import mock

import libcst as cst
Expand Down Expand Up @@ -1768,78 +1768,87 @@ def test_walrus_accesses(self) -> None:
),
)

def test_cast(self) -> None:
def assert_parsed(code, *calls):
parse = cst.parse_module
with mock.patch("libcst.parse_module") as parse_mock:
parse_mock.side_effect = parse
get_scope_metadata_provider(dedent(code))
calls = [mock.call(dedent(code))] + list(calls)
self.assertEqual(parse_mock.call_count, len(calls))
parse_mock.assert_has_calls(calls)

assert_parsed(
"""
from typing import TypeVar
TypeVar("Name", "int")
""",
mock.call("int"),
)

assert_parsed(
"""
from typing import Dict
Dict["str", "int"]
""",
mock.call("str"),
mock.call("int"),
)

assert_parsed(
"""
from typing import Dict, cast
cast(Dict[str, str], {})["3rr0r"]
"""
)

assert_parsed(
"""
from typing import cast
cast(str, "foo")
""",
)

assert_parsed(
"""
from typing import cast
cast("int", "foo")
""",
mock.call("int"),
)

assert_parsed(
"""
from typing import TypeVar
TypeVar("Name", func("int"))
""",
)

assert_parsed(
"""
from typing import Literal
Literal[\"G\"]
""",
)

assert_parsed(
r"""
from typing import TypeVar, Optional
from a import G
TypeVar("G2", bound="Optional[\"G\"]")
""",
mock.call('Optional["G"]'),
mock.call("G"),
)
@data_provider(
{
"TypeVar": {
"code": """
from typing import TypeVar
TypeVar("Name", "int")
""",
"calls": [mock.call("int")],
},
"Dict": {
"code": """
from typing import Dict
Dict["str", "int"]
""",
"calls": [mock.call("str"), mock.call("int")],
},
"cast_no_annotation": {
"code": """
from typing import Dict, cast
cast(Dict[str, str], {})["3rr0r"]
""",
"calls": [],
},
"cast_second_arg": {
"code": """
from typing import cast
cast(str, "foo")
""",
"calls": [],
},
"cast_first_arg": {
"code": """
from typing import cast
cast("int", "foo")
""",
"calls": [
mock.call("int"),
],
},
"typevar_func": {
"code": """
from typing import TypeVar
TypeVar("Name", func("int"))
""",
"calls": [],
},
"literal": {
"code": """
from typing import Literal
Literal[\"G\"]
""",
"calls": [],
},
"nested_str": {
"code": r"""
from typing import TypeVar, Optional
from a import G
TypeVar("G2", bound="Optional[\"G\"]")
""",
"calls": [mock.call('Optional["G"]'), mock.call("G")],
},
"class_self_ref": {
"code": """
from typing import TypeVar
class HelperClass:
value: TypeVar("THelperClass", bound="HelperClass")
""",
"calls": [mock.call("HelperClass")],
},
}
)
def test_parse_string_annotations(
self, *, code: str, calls: Sequence[mock._Call]
) -> None:
parse = cst.parse_module
with mock.patch("libcst.parse_module") as parse_mock:
parse_mock.side_effect = parse
get_scope_metadata_provider(dedent(code))
calls = [mock.call(dedent(code))] + list(calls)
self.assertEqual(parse_mock.call_count, len(calls))
parse_mock.assert_has_calls(calls)

def test_builtin_scope(self) -> None:
m, scopes = get_scope_metadata_provider(
Expand Down Expand Up @@ -1907,3 +1916,21 @@ def foo():

global_pow_accesses = list(global_pow_assignment.references)
self.assertEqual(len(global_pow_accesses), 2)

def test_annotation_access_in_typevar_bound(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
from typing import TypeVar
class Test:
var: TypeVar("T", bound="Test")
"""
)
imp = ensure_type(
ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.ImportFrom
)
scope = scopes[imp]
assignment = list(scope["Test"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)

0 comments on commit 58b447d

Please sign in to comment.