Skip to content

Commit

Permalink
Handle string type references in cast() (#418)
Browse files Browse the repository at this point in the history
* Handle string type references in cast()

* Directly visit the first argument of cast()

Co-authored-by: Zsolt Dollenstein <[email protected]>

Co-authored-by: Zsolt Dollenstein <[email protected]>
  • Loading branch information
Kronuz and zsol authored Nov 17, 2020
1 parent 2ef7302 commit 1100951
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
18 changes: 11 additions & 7 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,13 +773,19 @@ def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]:

def visit_Call(self, node: cst.Call) -> Optional[bool]:
self.__top_level_attribute_stack.append(None)
qnames = self.scope.get_qualified_names_for(node)
if any(qn.name in {"typing.NewType", "typing.TypeVar"} for qn in qnames):
qnames = {qn.name for qn in self.scope.get_qualified_names_for(node)}
if "typing.NewType" in qnames or "typing.TypeVar" in qnames:
node.func.visit(self)
self.__in_type_hint.add(node)
for arg in node.args[1:]:
arg.visit(self)
return False
if "typing.cast" in qnames:
node.func.visit(self)
self.__in_type_hint.add(node)
if len(node.args) > 0:
node.args[0].visit(self)
return False
return True

def leave_Call(self, original_node: cst.Call) -> None:
Expand Down Expand Up @@ -814,12 +820,10 @@ def _handle_string_annotation(
return False

def visit_Subscript(self, node: cst.Subscript) -> Optional[bool]:
qnames = self.scope.get_qualified_names_for(node.value)
if any(qn.name.startswith(("typing.", "typing_extensions.")) for qn in qnames):
qnames = {qn.name for qn in self.scope.get_qualified_names_for(node.value)}
if any(qn.startswith(("typing.", "typing_extensions.")) for qn in qnames):
self.__in_type_hint.add(node)
if any(
qn.name in {"typing.Literal", "typing_extensions.Literal"} for qn in qnames
):
if "typing.Literal" in qnames or "typing_extensions.Literal" in qnames:
self.__in_ignored_subscript.add(node)
return True

Expand Down
33 changes: 28 additions & 5 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,23 +1037,24 @@ def g():
def test_annotation_access(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
from typing import Literal, NewType, Optional, TypeVar, Callable
from a import A, B, C, D, E, F, G, H, I, J
from typing import Literal, NewType, Optional, TypeVar, Callable, cast
from a import A, B, C, D, D2, E, E2, F, G, G2, H, I, J, K, K2
def x(a: A):
pass
def y(b: "B"):
pass
def z(c: Literal["C"]):
pass
DType = TypeVar("DType", bound=D)
EType = TypeVar("EType", bound="E")
DType = TypeVar("D2", bound=D)
EType = TypeVar("E2", bound="E")
FType = TypeVar("F")
GType = NewType("GType", "Optional[G]")
GType = NewType("G2", "Optional[G]")
HType = Optional["H"]
IType = Callable[..., I]
class Test(Generic[J]):
pass
casted = cast("K", "K2")
"""
)
imp = ensure_type(
Expand Down Expand Up @@ -1084,13 +1085,21 @@ class Test(Generic[J]):
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)

assignment = list(scope["D2"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)

assignment = list(scope["E"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)

assignment = list(scope["E2"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)

assignment = list(scope["F"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)
Expand All @@ -1102,6 +1111,10 @@ class Test(Generic[J]):
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)

assignment = list(scope["G2"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)

assignment = list(scope["H"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
Expand All @@ -1121,6 +1134,16 @@ class Test(Generic[J]):
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)

assignment = list(scope["K"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)

assignment = list(scope["K2"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)

def test_node_of_scopes(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
Expand Down

0 comments on commit 1100951

Please sign in to comment.