Skip to content

Commit

Permalink
Support string annotations for type aliases (#401)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kronuz authored Oct 12, 2020
1 parent 6731aa5 commit 21d37b9
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 13 deletions.
2 changes: 1 addition & 1 deletion libcst/_parser/types/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class TestConfig(UnitTest):
@data_provider(
{
"empty": (lambda: PartialParserConfig(),),
"empty": (PartialParserConfig,),
"python_version_a": (lambda: PartialParserConfig(python_version="3.7"),),
"python_version_b": (lambda: PartialParserConfig(python_version="3.7.1"),),
"encoding": (lambda: PartialParserConfig(encoding="latin-1"),),
Expand Down
26 changes: 16 additions & 10 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,10 @@ def __init__(self, provider: "ScopeProvider") -> None:
self.scope: Scope = GlobalScope()
self.__deferred_accesses: List[Tuple[Access, Optional[cst.Attribute]]] = []
self.__top_level_attribute_stack: List[Optional[cst.Attribute]] = [None]
self.__in_annotation: Set[Union[cst.Call, cst.Annotation]] = set()
self.__in_annotation: Set[
Union[cst.Call, cst.Annotation, cst.Subscript]
] = set()
self.__in_ignored_subscript: Set[cst.Subscript] = set()

@contextmanager
def _new_scope(
Expand Down Expand Up @@ -699,10 +702,8 @@ def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]:

def visit_Call(self, node: cst.Call) -> Optional[bool]:
self.__top_level_attribute_stack.append(None)
if any(
qn.name == "typing.TypeVar"
for qn in self.scope.get_qualified_names_for(node)
):
qnames = self.scope.get_qualified_names_for(node)
if any(qn.name in {"typing.NewType", "typing.TypeVar"} for qn in qnames):
node.func.visit(self)
self.__in_annotation.add(node)
for arg in node.args[1:]:
Expand Down Expand Up @@ -731,21 +732,26 @@ def visit_ConcatenatedString(self, node: cst.ConcatenatedString) -> Optional[boo
def _handle_string_annotation(
self, node: Union[cst.SimpleString, cst.ConcatenatedString]
) -> None:
if self.__in_annotation:
if self.__in_annotation and not self.__in_ignored_subscript:
value = node.evaluated_value
if value:
mod = cst.parse_module(value)
mod.visit(self)

def visit_Subscript(self, node: cst.Subscript) -> Optional[bool]:
qnames = self.scope.get_qualified_names_for(node.value)
if any(qn.name.startswith(("typing.", "typing_extensions.")) for qn in qnames):
self.__in_annotation.add(node)
if any(
qn.name in ("typing.Literal", "typing_extensions.Literal")
for qn in self.scope.get_qualified_names_for(node.value)
qn.name in {"typing.Literal", "typing_extensions.Literal"} for qn in qnames
):
node.value.visit(self)
return False
self.__in_ignored_subscript.add(node)
return True

def leave_Subscript(self, original_node: cst.Subscript) -> None:
self.__in_annotation.discard(original_node)
self.__in_ignored_subscript.discard(original_node)

def visit_Name(self, node: cst.Name) -> Optional[bool]:
# not all Name have ExpressionContext
context = self.provider.get_metadata(ExpressionContextProvider, node, None)
Expand Down
18 changes: 16 additions & 2 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,8 +1018,8 @@ def g():
def test_annotation_access(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
from typing import Literal, TypeVar
from a import A, B, C, D, E, F
from typing import Literal, NewType, Optional, TypeVar
from a import A, B, C, D, E, F, G, H
def x(a: A):
pass
def y(b: "B"):
Expand All @@ -1029,6 +1029,8 @@ def z(c: Literal["C"]):
DType = TypeVar("DType", bound=D)
EType = TypeVar("EType", bound="E")
FType = TypeVar("F")
GType = NewType("GType", "Optional[G]")
HType = Optional["H"]
"""
)
imp = ensure_type(
Expand Down Expand Up @@ -1068,6 +1070,18 @@ def z(c: Literal["C"]):
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)

assignment = list(scope["G"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)

assignment = list(scope["H"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)

def test_node_of_scopes(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
Expand Down
1 change: 1 addition & 0 deletions libcst/tests/test_type_enforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class MyExampleClassWithMetaclass(metaclass=MyExampleMetaclass):
pass


# lint-ignore: NoNamedTupleRule
class NamedTupleSubclass(NamedTuple):
a: str
b: int
Expand Down

0 comments on commit 21d37b9

Please sign in to comment.