From 8a187b2459586934fa5ef589e8164ffe696d257c Mon Sep 17 00:00:00 2001 From: Amethyst Reese Date: Tue, 30 Apr 2024 18:57:23 -0700 Subject: [PATCH] Add helpers to find and add suppressions comments to nodes ghstack-source-id: 92333ed112203d05ca54ca1dd29126ef7fbe63de Pull Request resolved: https://github.com/Instagram/Fixit/pull/451 --- src/fixit/comments.py | 109 ++++++++++++++++++- src/fixit/ftypes.py | 32 +++++- src/fixit/tests/comments.py | 202 +++++++++++++++++++++++++++++++++++- src/fixit/tests/ftypes.py | 52 ++++++++++ 4 files changed, 392 insertions(+), 3 deletions(-) diff --git a/src/fixit/comments.py b/src/fixit/comments.py index 3e261fd8..5741cb72 100644 --- a/src/fixit/comments.py +++ b/src/fixit/comments.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Generator, Optional, Sequence +from typing import Generator, List, Optional, Sequence from libcst import ( BaseSuite, @@ -12,15 +12,21 @@ CSTNode, Decorator, EmptyLine, + ensure_type, IndentedBlock, LeftSquareBracket, + matchers as m, Module, + ParenthesizedWhitespace, RightSquareBracket, SimpleStatementSuite, + SimpleWhitespace, TrailingWhitespace, ) from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider +from .ftypes import LintIgnore, LintIgnoreStyle + def node_comments( node: CSTNode, metadata: MetadataWrapper @@ -111,3 +117,104 @@ def gen(node: CSTNode) -> Generator[Comment, None, None]: # to only include comments that are located on or before the line containing # the original node that we're searching from yield from (c for c in gen(node) if positions[c].end.line <= target_line) + + +def node_nearest_comment(node: CSTNode, metadata: MetadataWrapper) -> CSTNode: + """ + Return the nearest tree node where a suppression comment could be added. + """ + parent_nodes = metadata.resolve(ParentNodeProvider) + positions = metadata.resolve(PositionProvider) + node_line = positions[node].start.line + + while not isinstance(node, Module): + if hasattr(node, "comment"): + return node + + if hasattr(node, "trailing_whitespace"): + tw = ensure_type(node.trailing_whitespace, TrailingWhitespace) + if tw and positions[tw].start.line == node_line: + if tw.comment: + return tw.comment + else: + return tw + + if hasattr(node, "comma"): + if m.matches( + node.comma, + m.Comma( + whitespace_after=m.ParenthesizedWhitespace( + first_line=m.TrailingWhitespace() + ) + ), + ): + return ensure_type( + node.comma.whitespace_after.first_line, TrailingWhitespace + ) + + if hasattr(node, "rbracket"): + tw = ensure_type( + ensure_type( + node.rbracket.whitespace_before, + ParenthesizedWhitespace, + ).first_line, + TrailingWhitespace, + ) + if positions[tw].start.line == node_line: + return tw + + if hasattr(node, "leading_lines"): + return node + + parent = parent_nodes.get(node) + if parent is None: + break + node = parent + + raise RuntimeError("could not find nearest comment node") + + +def add_suppression_comment( + module: Module, + node: CSTNode, + metadata: MetadataWrapper, + name: str, + style: LintIgnoreStyle = LintIgnoreStyle.fixme, +) -> Module: + """ + Return a modified tree that includes a suppression comment for the given rule. + """ + # reuse an existing suppression directive if available rather than making a new one + for comment in node_comments(node, metadata): + lint_ignore = LintIgnore.parse(comment.value) + if lint_ignore and lint_ignore.style == style: + if name in lint_ignore.names: + return module # already suppressed + lint_ignore.names.add(name) + return module.with_deep_changes(comment, value=str(lint_ignore)) + + # no existing directives, find the "nearest" location and add a comment there + target = node_nearest_comment(node, metadata) + lint_ignore = LintIgnore(style, {name}) + + if isinstance(target, Comment): + lint_ignore.prefix = target.value.strip() + return module.with_deep_changes(target, value=str(lint_ignore)) + + if isinstance(target, TrailingWhitespace): + if target.comment: + lint_ignore.prefix = target.comment.value.strip() + return module.with_deep_changes(target.comment, value=str(lint_ignore)) + else: + return module.with_deep_changes( + target, + comment=Comment(str(lint_ignore)), + whitespace=SimpleWhitespace(" "), + ) + + if hasattr(target, "leading_lines"): + ll: List[EmptyLine] = list(target.leading_lines or ()) + ll.append(EmptyLine(comment=Comment(str(lint_ignore)))) + return module.with_deep_changes(target, leading_lines=ll) + + raise RuntimeError("failed to add suppression comment") diff --git a/src/fixit/ftypes.py b/src/fixit/ftypes.py index fd8aa9a0..a6f7f7b7 100644 --- a/src/fixit/ftypes.py +++ b/src/fixit/ftypes.py @@ -19,6 +19,7 @@ List, Optional, Sequence, + Set, Tuple, TypedDict, TypeVar, @@ -29,6 +30,7 @@ from libcst._add_slots import add_slots from libcst.metadata import CodePosition as CodePosition, CodeRange as CodeRange from packaging.version import Version +from typing_extensions import Self __all__ = ("Version",) @@ -74,7 +76,7 @@ class LintIgnoreStyle(Enum): LintIgnoreRegex = re.compile( r""" \#\s* # leading hash and whitespace - (lint-(?:ignore|fixme)) # directive + (?:lint-(ignore|fixme)) # directive (?: (?::\s*|\s+) # separator ( @@ -87,6 +89,34 @@ class LintIgnoreStyle(Enum): ) +@dataclass +class LintIgnore: + style: LintIgnoreStyle + names: Set[str] = field(default_factory=set) + prefix: str = "" + postfix: str = "" + + @classmethod + def parse(cls, value: str) -> Optional[Self]: + value = value.strip() + if match := LintIgnoreRegex.search(value): + style, raw_names = match.groups() + names = {n.strip() for n in raw_names.split(",")} if raw_names else set() + start, end = match.span() + prefix = value[:start].strip() + postfix = value[end:] + return cls(LintIgnoreStyle(style), names, prefix, postfix) + + return None + + def __str__(self) -> str: + if self.names: + directive = f"# lint-{self.style.value}: {', '.join(sorted(self.names))}" + else: + directive = f"# lint-{self.style.value}" + return f"{self.prefix} {directive}{self.postfix}".strip() + + QualifiedRuleRegex = re.compile( r""" ^ diff --git a/src/fixit/tests/comments.py b/src/fixit/tests/comments.py index 627c1119..0c9f9f39 100644 --- a/src/fixit/tests/comments.py +++ b/src/fixit/tests/comments.py @@ -4,12 +4,14 @@ # LICENSE file in the root directory of this source tree. from textwrap import dedent +from typing import Sequence, Tuple from unittest import TestCase import libcst.matchers as m from libcst import MetadataWrapper, parse_module -from ..comments import node_comments +from ..comments import add_suppression_comment, node_comments, node_nearest_comment +from ..ftypes import LintIgnoreStyle class CommentsTest(TestCase): @@ -73,3 +75,201 @@ class Foo: # trailing comment for node in m.findall(module, matcher): comments = [c.value for c in node_comments(node, wrapper)] self.assertEqual(sorted(expected), sorted(comments)) + break + else: + assert expected == (), f"no node matched by {matcher}" + + def test_node_nearest_comment(self) -> None: + test_cases: Sequence[Tuple[str, m.BaseMatcherNode, m.BaseMatcherNode]] = ( + ( + """ + print("hello") + """, + m.Call(func=m.Name("print")), + m.TrailingWhitespace(), + ), + ( + """ + print("hello") # here + """, + m.Call(func=m.Name("print")), + m.Comment("# here"), + ), + ( + """ + import sys + + # here + def foo(): + pass + """, + m.FunctionDef(name=m.Name("foo")), + m.FunctionDef(name=m.Name("foo")), + ), + ( + """ + def foo(): + pass # here + """, + m.Pass(), + m.Comment("# here"), + ), + ( + """ + items = [ + foo, # here + bar, + ] + """, + m.Element(value=m.Name("foo")), + m.TrailingWhitespace(comment=m.Comment("# here")), + ), + ( + """ + items = [ + foo, + bar, # here + ] + """, + m.Element(value=m.Name("bar")), + m.TrailingWhitespace(comment=m.Comment("# here")), + ), + ( + """ + import sys + # here + items = [ + foo, + bar, + ] + """, + m.List(), + m.SimpleStatementLine( + leading_lines=[m.EmptyLine(comment=m.Comment("# here"))] + ), + ), + ) + for idx, (code, target, expected) in enumerate(test_cases, start=1): + with self.subTest(f"nearest node {idx}"): + code = dedent(code) + module = parse_module(code) + wrapper = MetadataWrapper(module, unsafe_skip_copy=True) + + for target_node in m.findall(module, target): # noqa: B007 + break + else: + self.fail(f"no target node matched by {target}") + + comment = node_nearest_comment(target_node, wrapper) + self.assertTrue( + m.matches(comment, expected), + f"nearest comment did not match expected node\n----{code}----\ntarget: {target_node}\n----\nfound: {comment}", + ) + + def test_add_suppression_comment(self) -> None: + test_cases: Sequence[ + Tuple[str, m.BaseMatcherNode, str, LintIgnoreStyle, str] + ] = ( + ( + """ + print("hello") + """, + m.Call(func=m.Name("print")), + "NoPrint", + LintIgnoreStyle.fixme, + """ + print("hello") # lint-fixme: NoPrint + """, + ), + ( + """ + print("hello") # noqa + """, + m.Call(func=m.Name("print")), + "NoPrint", + LintIgnoreStyle.ignore, + """ + print("hello") # noqa # lint-ignore: NoPrint + """, + ), + ( + """ + print("hello") # noqa # lint-fixme: SomethingElse [whatever] + """, + m.Call(func=m.Name("print")), + "NoPrint", + LintIgnoreStyle.fixme, + """ + print("hello") # noqa # lint-fixme: NoPrint, SomethingElse [whatever] + """, + ), + ( + """ + items = [ + foo, + bar, + ] + """, + m.Element(value=m.Name("foo")), + "NoFoo", + LintIgnoreStyle.fixme, + """ + items = [ + foo, # lint-fixme: NoFoo + bar, + ] + """, + ), + ( + """ + items = [ + foo, + bar, + ] + """, + m.Element(value=m.Name("bar")), + "NoFoo", + LintIgnoreStyle.fixme, + """ + items = [ + foo, + bar, # lint-fixme: NoFoo + ] + """, + ), + ( + """ + items = [ + foo, + bar, + ] + """, + m.List(), + "SomethingWrong", + LintIgnoreStyle.fixme, + """ + # lint-fixme: SomethingWrong + items = [ + foo, + bar, + ] + """, + ), + ) + for idx, (code, matcher, name, style, expected) in enumerate( + test_cases, start=1 + ): + with self.subTest(f"add suppression {idx}"): + expected = dedent(expected) + code = dedent(code) + module = parse_module(code) + wrapper = MetadataWrapper(module, unsafe_skip_copy=True) + + for node in m.findall(module, matcher): # noqa: B007 + break + else: + self.fail(f"no node matched by {matcher}") + + new_module = add_suppression_comment(module, node, wrapper, name, style) + result = new_module.code + self.assertEqual(expected, result) diff --git a/src/fixit/tests/ftypes.py b/src/fixit/tests/ftypes.py index 292aac01..0a4579bc 100644 --- a/src/fixit/tests/ftypes.py +++ b/src/fixit/tests/ftypes.py @@ -7,6 +7,7 @@ from unittest import TestCase from .. import ftypes +from ..ftypes import LintIgnore, LintIgnoreStyle class TypesTest(TestCase): @@ -42,6 +43,57 @@ def test_ignore_comment_regex(self) -> None: "value unexpectedly matches lint-ignore regex", ) + def test_lint_ignore_parse(self) -> None: + for value, expected in ( + ("# lint-ignore", LintIgnore(LintIgnoreStyle.ignore)), + ("# lint-fixme: foo", LintIgnore(LintIgnoreStyle.fixme, {"foo"})), + ( + "# lint-ignore: foo, bar", + LintIgnore(LintIgnoreStyle.ignore, {"foo", "bar"}), + ), + ( + "# lint-ignore: foo, bar, foo, bar, baz", + LintIgnore(LintIgnoreStyle.ignore, {"foo", "bar", "baz"}), + ), + ( + "# type: ignore # lint-fixme: foo, bar # noqa", + LintIgnore( + LintIgnoreStyle.fixme, + {"foo", "bar"}, + prefix="# type: ignore", + postfix=" # noqa", + ), + ), + ): + with self.subTest(value): + result = LintIgnore.parse(value) + self.assertEqual(expected, result) + + def test_lint_ignore_roundtrip(self) -> None: + """ensure that well-formed/sorted ignores parse and stringify back exactly""" + for idx, (value, expected) in enumerate( + ( + ("# lint-ignore", LintIgnore(LintIgnoreStyle.ignore)), + ("# lint-fixme", LintIgnore(LintIgnoreStyle.fixme)), + ("# lint-ignore: foo", LintIgnore(LintIgnoreStyle.ignore, {"foo"})), + ("# lint-fixme: foo", LintIgnore(LintIgnoreStyle.fixme, {"foo"})), + ( + "# type: ignore # lint-fixme: bar, foo # noqa", + LintIgnore( + LintIgnoreStyle.fixme, + {"bar", "foo"}, + "# type: ignore", + " # noqa", + ), + ), + ), + start=1, + ): + with self.subTest(f"lint ignore {idx}"): + ignore = LintIgnore.parse(value) + self.assertEqual(expected, ignore) + self.assertEqual(value, str(ignore)) + def test_qualified_rule(self) -> None: valid: Set[ftypes.QualifiedRule] = set()