Skip to content

Commit

Permalink
Extract node_comments to new module, return nodes instead of strings
Browse files Browse the repository at this point in the history
ghstack-source-id: adf9fc5529eb062d6e1bdc3e44424b9ec045b09d
Pull Request resolved: #450
  • Loading branch information
amyreese committed May 1, 2024
1 parent cca9f96 commit 3483c12
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 109 deletions.
113 changes: 113 additions & 0 deletions src/fixit/comments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Generator, Optional, Sequence

from libcst import (
BaseSuite,
Comma,
Comment,
CSTNode,
Decorator,
EmptyLine,
IndentedBlock,
LeftSquareBracket,
Module,
RightSquareBracket,
SimpleStatementSuite,
TrailingWhitespace,
)
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider


def node_comments(
node: CSTNode, metadata: MetadataWrapper
) -> Generator[Comment, None, None]:
"""
Yield all comments associated with the given node.
Includes comments from both leading comments and trailing inline comments.
"""
parent_nodes = metadata.resolve(ParentNodeProvider)
positions = metadata.resolve(PositionProvider)
target_line = positions[node].end.line

def gen(node: CSTNode) -> Generator[Comment, None, None]:
while not isinstance(node, Module):
# trailing_whitespace can either be a property of the node itself, or in
# case of blocks, be part of the block's body element
tw: Optional[TrailingWhitespace] = getattr(
node, "trailing_whitespace", None
)
if tw is None:
body: Optional[BaseSuite] = getattr(node, "body", None)
if isinstance(body, SimpleStatementSuite):
tw = body.trailing_whitespace
elif isinstance(body, IndentedBlock):
tw = body.header

if tw and tw.comment:
yield tw.comment

comma: Optional[Comma] = getattr(node, "comma", None)
if isinstance(comma, Comma):
tw = getattr(comma.whitespace_after, "first_line", None)
if tw and tw.comment:
yield tw.comment

rb: Optional[RightSquareBracket] = getattr(node, "rbracket", None)
if rb is not None:
tw = getattr(rb.whitespace_before, "first_line", None)
if tw and tw.comment:
yield tw.comment

el: Optional[Sequence[EmptyLine]] = None
lb: Optional[LeftSquareBracket] = getattr(node, "lbracket", None)
if lb is not None:
el = getattr(lb.whitespace_after, "empty_lines", None)
if el is not None:
for line in el:
if line.comment:
yield line.comment

el = getattr(node, "lines_after_decorators", None)
if el is not None:
for line in el:
if line.comment:
yield line.comment

ll: Optional[Sequence[EmptyLine]] = getattr(node, "leading_lines", None)
if ll is not None:
for line in ll:
if line.comment:
yield line.comment
if not isinstance(node, Decorator):
# stop looking once we've gone up far enough for leading_lines,
# even if there are no comment lines here at all
break

parent = parent_nodes.get(node)
if parent is None:
break
node = parent

# comments at the start of the file are part of the module header rather than
# part of the first statement's leading_lines, so we need to look there in case
# the reported node is part of the first statement.
if isinstance(node, Module):
for line in node.header:
if line.comment:
yield line.comment
else:
parent = parent_nodes.get(node)
if isinstance(parent, Module) and parent.body and parent.body[0] == node:
for line in parent.header:
if line.comment:
yield line.comment

# wrap this in a pass-through generator so that we can easily filter the results
# to only include comments that are located on or before the line containing
# the original node that we're searching from
yield from (c for c in gen(node) if positions[c].end.line <= target_line)
14 changes: 9 additions & 5 deletions src/fixit/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self, path: Path, source: FileContent) -> None:
self.source = source
self.module: Module = parse_module(source)
self.timings: Timings = defaultdict(lambda: 0)
self.wrapper = MetadataWrapper(self.module)

def collect_violations(
self,
Expand All @@ -79,10 +80,14 @@ def visit_hook(name: str) -> Iterator[None]:
self.timings[name] += duration_us

metadata_cache: Mapping[ProviderT, object] = {}
self.wrapper = MetadataWrapper(
self.module, unsafe_skip_copy=True, cache=metadata_cache
)
needs_repo_manager: Set[ProviderT] = set()

for rule in rules:
rule._visit_hook = visit_hook
rule._metadata_wrapper = self.wrapper
for provider in rule.get_inherited_dependencies():
if provider.gen_cache is not None:
# TODO: find a better way to declare this requirement in LibCST
Expand All @@ -95,12 +100,11 @@ def visit_hook(name: str) -> Iterator[None]:
providers=needs_repo_manager,
)
repo_manager.resolve_cache()
metadata_cache = repo_manager.get_cache_for_path(config.path.as_posix())
self.wrapper._cache = repo_manager.get_cache_for_path(
config.path.as_posix()
)

wrapper = MetadataWrapper(
self.module, unsafe_skip_copy=True, cache=metadata_cache
)
wrapper.visit_batched(rules)
self.wrapper.visit_batched(rules)
count = 0
for rule in rules:
for violation in rule._violations:
Expand Down
6 changes: 6 additions & 0 deletions src/fixit/ftypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import platform
import re
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import (
Any,
Expand Down Expand Up @@ -65,6 +66,11 @@ class Valid:
code: str


class LintIgnoreStyle(Enum):
fixme = "fixme"
ignore = "ignore"


LintIgnoreRegex = re.compile(
r"""
\#\s* # leading hash and whitespace
Expand Down
108 changes: 6 additions & 102 deletions src/fixit/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,9 @@

import functools
from dataclasses import replace
from typing import (
ClassVar,
Collection,
Generator,
List,
Mapping,
Optional,
Sequence,
Set,
Union,
)
from typing import ClassVar, Collection, List, Mapping, Optional, Set, Union

from libcst import (
BaseSuite,
BatchableCSTVisitor,
Comma,
CSTNode,
Decorator,
EmptyLine,
IndentedBlock,
LeftSquareBracket,
Module,
RightSquareBracket,
SimpleStatementSuite,
TrailingWhitespace,
)
from libcst import BatchableCSTVisitor, CSTNode, MetadataWrapper, Module
from libcst.metadata import (
CodePosition,
CodeRange,
Expand All @@ -41,6 +18,7 @@
ProviderT,
)

from .comments import node_comments
from .ftypes import (
Invalid,
LintIgnoreRegex,
Expand Down Expand Up @@ -115,81 +93,7 @@ def __str__(self) -> str:
return f"{self.__class__.__module__}:{self.__class__.__name__}"

_visit_hook: Optional[VisitHook] = None

def node_comments(self, node: CSTNode) -> Generator[str, None, None]:
"""
Yield all comments associated with the given node.
Includes comments from both leading comments and trailing inline comments.
"""
while not isinstance(node, Module):
# trailing_whitespace can either be a property of the node itself, or in
# case of blocks, be part of the block's body element
tw: Optional[TrailingWhitespace] = getattr(
node, "trailing_whitespace", None
)
if tw is None:
body: Optional[BaseSuite] = getattr(node, "body", None)
if isinstance(body, SimpleStatementSuite):
tw = body.trailing_whitespace
elif isinstance(body, IndentedBlock):
tw = body.header

if tw and tw.comment:
yield tw.comment.value

comma: Optional[Comma] = getattr(node, "comma", None)
if isinstance(comma, Comma):
tw = getattr(comma.whitespace_after, "first_line", None)
if tw and tw.comment:
yield tw.comment.value

rb: Optional[RightSquareBracket] = getattr(node, "rbracket", None)
if rb is not None:
tw = getattr(rb.whitespace_before, "first_line", None)
if tw and tw.comment:
yield tw.comment.value

el: Optional[Sequence[EmptyLine]] = None
lb: Optional[LeftSquareBracket] = getattr(node, "lbracket", None)
if lb is not None:
el = getattr(lb.whitespace_after, "empty_lines", None)
if el is not None:
for line in el:
if line.comment:
yield line.comment.value

el = getattr(node, "lines_after_decorators", None)
if el is not None:
for line in el:
if line.comment:
yield line.comment.value

ll: Optional[Sequence[EmptyLine]] = getattr(node, "leading_lines", None)
if ll is not None:
for line in ll:
if line.comment:
yield line.comment.value
if not isinstance(node, Decorator):
# stop looking once we've gone up far enough for leading_lines,
# even if there are no comment lines here at all
break

node = self.get_metadata(ParentNodeProvider, node)

# comments at the start of the file are part of the module header rather than
# part of the first statement's leading_lines, so we need to look there in case
# the reported node is part of the first statement.
if isinstance(node, Module):
for line in node.header:
if line.comment:
yield line.comment.value
else:
parent = self.get_metadata(ParentNodeProvider, node)
if isinstance(parent, Module) and parent.body and parent.body[0] == node:
for line in parent.header:
if line.comment:
yield line.comment.value
_metadata_wrapper: MetadataWrapper = MetadataWrapper(Module([]))

def ignore_lint(self, node: CSTNode) -> bool:
"""
Expand All @@ -199,8 +103,8 @@ def ignore_lint(self, node: CSTNode) -> bool:
current rule by name, or if the directives have no rule names listed.
"""
rule_names = (self.name, self.name.lower())
for comment in self.node_comments(node):
if match := LintIgnoreRegex.search(comment):
for comment in node_comments(node, self._metadata_wrapper):
if match := LintIgnoreRegex.search(comment.value):
_style, names = match.groups()

# directive
Expand Down
1 change: 1 addition & 0 deletions src/fixit/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fixit.ftypes import Config, QualifiedRule

from fixit.testing import add_lint_rule_tests_to_module
from .comments import CommentsTest
from .config import ConfigTest
from .engine import EngineTest
from .ftypes import TypesTest
Expand Down
75 changes: 75 additions & 0 deletions src/fixit/tests/comments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from textwrap import dedent
from unittest import TestCase

import libcst.matchers as m
from libcst import MetadataWrapper, parse_module

from ..comments import node_comments


class CommentsTest(TestCase):
def test_node_comments(self) -> None:
for idx, (code, test_cases) in enumerate(
(
(
"""
# module-level comment
print("hello") # trailing comment
""",
(
(m.Call(func=m.Name("something")), ()),
(m.Call(), ["# module-level comment", "# trailing comment"]),
),
),
(
"""
import sys
# leading comment
print("hello") # trailing comment
""",
((m.Call(), ["# leading comment", "# trailing comment"]),),
),
(
"""
import sys
# leading comment
@alpha # first decorator comment
# between-decorator comment
@beta # second decorator comment
# after-decorator comment
class Foo: # trailing comment
pass
""",
(
(
m.ClassDef(),
[
"# leading comment",
"# after-decorator comment",
"# trailing comment",
],
),
(
m.Decorator(decorator=m.Name("alpha")),
["# leading comment", "# first decorator comment"],
),
),
),
),
start=1,
):
code = dedent(code)
module = parse_module(code)
wrapper = MetadataWrapper(module, unsafe_skip_copy=True)
for idx2, (matcher, expected) in enumerate(test_cases):
with self.subTest(f"node comments {idx}-{chr(ord('a')+idx2)}"):
for node in m.findall(module, matcher):
comments = [c.value for c in node_comments(node, wrapper)]
self.assertEqual(sorted(expected), sorted(comments))
Loading

0 comments on commit 3483c12

Please sign in to comment.