diff --git a/src/fixit/rules/no_string_type_annotation.py b/src/fixit/rules/no_string_type_annotation.py index 2b931be5..721c4c00 100644 --- a/src/fixit/rules/no_string_type_annotation.py +++ b/src/fixit/rules/no_string_type_annotation.py @@ -84,6 +84,14 @@ def foo() -> typing.Optional[typing.Literal["a", "b"]]: return "a" """ ), + Valid( + """ + import typing + + def foo() -> typing.Optional[typing.Literal["class", "function"]]: + return "class" + """ + ), ] INVALID = [ @@ -273,7 +281,12 @@ def visit_Subscript(self, node: cst.Subscript) -> None: metadata=m.MatchMetadataIfTrue( QualifiedNameProvider, lambda qualnames: any( - n.name == "typing_extensions.Literal" for n in qualnames + n.name + in ( + "typing.Literal", + "typing_extensions.Literal", + ) + for n in qualnames ), ) ), @@ -295,4 +308,8 @@ def visit_SimpleString(self, node: cst.SimpleString) -> None: value = node.evaluated_value if isinstance(value, bytes): value = value.decode("utf-8") - self.report(node, replacement=cst.parse_expression(value)) + try: + repl = cst.parse_expression(value) + self.report(node, replacement=repl) + except cst.ParserSyntaxError: + self.report(node)