Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract node_comments to new module, return nodes instead of strings #450

Open
wants to merge 3 commits into
base: gh/amyreese/1/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"packaging >= 21",
"tomli >= 2.0; python_version < '3.11'",
"trailrunner >= 1.2",
"typing_extensions >= 4.0",
]

dynamic = ["version"]
Expand Down
113 changes: 113 additions & 0 deletions src/fixit/comments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Generator, Optional, Sequence

from libcst import (
BaseSuite,
Comma,
Comment,
CSTNode,
Decorator,
EmptyLine,
IndentedBlock,
LeftSquareBracket,
Module,
RightSquareBracket,
SimpleStatementSuite,
TrailingWhitespace,
)
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider


def node_comments(
    node: CSTNode, metadata: MetadataWrapper
) -> Generator[Comment, None, None]:
    """
    Yield the comment nodes attached to ``node``.

    Leading comment lines and trailing inline comments are both considered.
    Comments are looked up on ``node`` and then on its ancestors (via
    ``ParentNodeProvider``) until a node carrying ``leading_lines`` that is not
    a decorator is reached; the module header is consulted when the search
    ends at the module or at its first statement.  Only comments whose
    position ends on or before the last line of ``node`` are yielded.
    """
    parents = metadata.resolve(ParentNodeProvider)
    positions = metadata.resolve(PositionProvider)
    last_line = positions[node].end.line

    def inline(
        whitespace: Optional[TrailingWhitespace],
    ) -> Generator[Comment, None, None]:
        # the comment held by one trailing-whitespace node, when there is one
        if whitespace and whitespace.comment:
            yield whitespace.comment

    def from_lines(
        lines: Optional[Sequence[EmptyLine]],
    ) -> Generator[Comment, None, None]:
        # the comments held by a run of empty lines; a missing run yields nothing
        if lines is not None:
            for entry in lines:
                if entry.comment:
                    yield entry.comment

    def walk(current: CSTNode) -> Generator[Comment, None, None]:
        while not isinstance(current, Module):
            # the trailing whitespace lives either on the node itself or, for
            # block statements, on the element stored in the node's body
            trailing: Optional[TrailingWhitespace] = getattr(
                current, "trailing_whitespace", None
            )
            if trailing is None:
                suite: Optional[BaseSuite] = getattr(current, "body", None)
                if isinstance(suite, SimpleStatementSuite):
                    trailing = suite.trailing_whitespace
                elif isinstance(suite, IndentedBlock):
                    trailing = suite.header
            yield from inline(trailing)

            # comment on the first line following an element's comma
            comma: Optional[Comma] = getattr(current, "comma", None)
            if isinstance(comma, Comma):
                yield from inline(getattr(comma.whitespace_after, "first_line", None))

            # comment on the first line of whitespace before a closing bracket
            rbracket: Optional[RightSquareBracket] = getattr(current, "rbracket", None)
            if rbracket is not None:
                yield from inline(
                    getattr(rbracket.whitespace_before, "first_line", None)
                )

            # comment lines directly after an opening bracket
            lbracket: Optional[LeftSquareBracket] = getattr(current, "lbracket", None)
            if lbracket is not None:
                yield from from_lines(
                    getattr(lbracket.whitespace_after, "empty_lines", None)
                )

            yield from from_lines(getattr(current, "lines_after_decorators", None))

            leading: Optional[Sequence[EmptyLine]] = getattr(
                current, "leading_lines", None
            )
            if leading is not None:
                yield from from_lines(leading)
                if not isinstance(current, Decorator):
                    # we have climbed far enough to reach leading_lines, so the
                    # search ends here even when none of them hold a comment
                    break

            parent = parents.get(current)
            if parent is None:
                break
            current = parent

        # comments at the very top of the file belong to the module header, not
        # to the first statement's leading_lines, so check the header whenever
        # the search ended at the module or at the module's first statement
        if isinstance(current, Module):
            yield from from_lines(current.header)
        else:
            owner = parents.get(current)
            if isinstance(owner, Module) and owner.body and owner.body[0] == current:
                yield from from_lines(owner.header)

    # keep only the comments located on or before the line on which the node we
    # started from ends
    for comment in walk(node):
        if positions[comment].end.line <= last_line:
            yield comment
14 changes: 9 additions & 5 deletions src/fixit/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self, path: Path, source: FileContent) -> None:
self.source = source
self.module: Module = parse_module(source)
self.timings: Timings = defaultdict(lambda: 0)
self.wrapper = MetadataWrapper(self.module)
Comment on lines 55 to +57
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you can always get back to the module with self.wrapper.module


def collect_violations(
self,
Expand All @@ -79,10 +80,14 @@ def visit_hook(name: str) -> Iterator[None]:
self.timings[name] += duration_us

metadata_cache: Mapping[ProviderT, object] = {}
self.wrapper = MetadataWrapper(
self.module, unsafe_skip_copy=True, cache=metadata_cache
)
needs_repo_manager: Set[ProviderT] = set()

for rule in rules:
rule._visit_hook = visit_hook
rule._metadata_wrapper = self.wrapper
for provider in rule.get_inherited_dependencies():
if provider.gen_cache is not None:
# TODO: find a better way to declare this requirement in LibCST
Expand All @@ -95,12 +100,11 @@ def visit_hook(name: str) -> Iterator[None]:
providers=needs_repo_manager,
)
repo_manager.resolve_cache()
metadata_cache = repo_manager.get_cache_for_path(config.path.as_posix())
self.wrapper._cache = repo_manager.get_cache_for_path(
config.path.as_posix()
)

wrapper = MetadataWrapper(
self.module, unsafe_skip_copy=True, cache=metadata_cache
)
wrapper.visit_batched(rules)
self.wrapper.visit_batched(rules)
count = 0
for rule in rules:
for violation in rule._violations:
Expand Down
5 changes: 5 additions & 0 deletions src/fixit/ftypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ class Valid:
code: str


class LintIgnoreStyle(Enum):
    """
    Enumerates the two lint-ignore styles, ``fixme`` and ``ignore``.

    NOTE(review): presumably these correspond to the style group captured by
    ``LintIgnoreRegex`` -- confirm against the regex and its callers.
    """

    # each member's value is the same lowercase word as its name
    fixme = "fixme"
    ignore = "ignore"


LintIgnoreRegex = re.compile(
r"""
\#\s* # leading hash and whitespace
Expand Down
108 changes: 6 additions & 102 deletions src/fixit/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,9 @@

import functools
from dataclasses import replace
from typing import (
ClassVar,
Collection,
Generator,
List,
Mapping,
Optional,
Sequence,
Set,
Union,
)
from typing import ClassVar, Collection, List, Mapping, Optional, Set, Union

from libcst import (
BaseSuite,
BatchableCSTVisitor,
Comma,
CSTNode,
Decorator,
EmptyLine,
IndentedBlock,
LeftSquareBracket,
Module,
RightSquareBracket,
SimpleStatementSuite,
TrailingWhitespace,
)
from libcst import BatchableCSTVisitor, CSTNode, MetadataWrapper, Module
from libcst.metadata import (
CodePosition,
CodeRange,
Expand All @@ -41,6 +18,7 @@
ProviderT,
)

from .comments import node_comments
from .ftypes import (
Invalid,
LintIgnoreRegex,
Expand Down Expand Up @@ -115,81 +93,7 @@ def __str__(self) -> str:
return f"{self.__class__.__module__}:{self.__class__.__name__}"

_visit_hook: Optional[VisitHook] = None

def node_comments(self, node: CSTNode) -> Generator[str, None, None]:
    """
    Yield the text of every comment attached to the given node.

    Leading comment lines and trailing inline comments are both considered.
    The node and then its ancestors (via ``ParentNodeProvider`` metadata) are
    inspected until a node carrying ``leading_lines`` that is not a decorator
    is reached; the module header is consulted when the search ends at the
    module or at its first statement.
    """

    def inline(
        whitespace: Optional[TrailingWhitespace],
    ) -> Generator[str, None, None]:
        # the comment text held by one trailing-whitespace node, when present
        if whitespace and whitespace.comment:
            yield whitespace.comment.value

    def from_lines(
        lines: Optional[Sequence[EmptyLine]],
    ) -> Generator[str, None, None]:
        # the comment text held by a run of empty lines; a missing run yields
        # nothing
        if lines is not None:
            for entry in lines:
                if entry.comment:
                    yield entry.comment.value

    current = node
    while not isinstance(current, Module):
        # the trailing whitespace lives either on the node itself or, for block
        # statements, on the element stored in the node's body
        trailing: Optional[TrailingWhitespace] = getattr(
            current, "trailing_whitespace", None
        )
        if trailing is None:
            suite: Optional[BaseSuite] = getattr(current, "body", None)
            if isinstance(suite, SimpleStatementSuite):
                trailing = suite.trailing_whitespace
            elif isinstance(suite, IndentedBlock):
                trailing = suite.header
        yield from inline(trailing)

        # comment on the first line following an element's comma
        comma: Optional[Comma] = getattr(current, "comma", None)
        if isinstance(comma, Comma):
            yield from inline(getattr(comma.whitespace_after, "first_line", None))

        # comment on the first line of whitespace before a closing bracket
        rbracket: Optional[RightSquareBracket] = getattr(current, "rbracket", None)
        if rbracket is not None:
            yield from inline(getattr(rbracket.whitespace_before, "first_line", None))

        # comment lines directly after an opening bracket
        lbracket: Optional[LeftSquareBracket] = getattr(current, "lbracket", None)
        if lbracket is not None:
            yield from from_lines(
                getattr(lbracket.whitespace_after, "empty_lines", None)
            )

        yield from from_lines(getattr(current, "lines_after_decorators", None))

        leading: Optional[Sequence[EmptyLine]] = getattr(
            current, "leading_lines", None
        )
        if leading is not None:
            yield from from_lines(leading)
            if not isinstance(current, Decorator):
                # we have climbed far enough to reach leading_lines, so the
                # search ends here even when none of them hold a comment
                break

        current = self.get_metadata(ParentNodeProvider, current)

    # comments at the very top of the file belong to the module header, not to
    # the first statement's leading_lines, so check the header whenever the
    # search ended at the module or at the module's first statement
    if isinstance(current, Module):
        yield from from_lines(current.header)
    else:
        owner = self.get_metadata(ParentNodeProvider, current)
        if isinstance(owner, Module) and owner.body and owner.body[0] == current:
            yield from from_lines(owner.header)
_metadata_wrapper: MetadataWrapper = MetadataWrapper(Module([]))

def ignore_lint(self, node: CSTNode) -> bool:
"""
Expand All @@ -199,8 +103,8 @@ def ignore_lint(self, node: CSTNode) -> bool:
current rule by name, or if the directives have no rule names listed.
"""
rule_names = (self.name, self.name.lower())
for comment in self.node_comments(node):
if match := LintIgnoreRegex.search(comment):
for comment in node_comments(node, self._metadata_wrapper):
if match := LintIgnoreRegex.search(comment.value):
_style, names = match.groups()

# directive
Expand Down
1 change: 1 addition & 0 deletions src/fixit/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fixit.ftypes import Config, QualifiedRule

from fixit.testing import add_lint_rule_tests_to_module
from .comments import CommentsTest
from .config import ConfigTest
from .engine import EngineTest
from .ftypes import TypesTest
Expand Down
75 changes: 75 additions & 0 deletions src/fixit/tests/comments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from textwrap import dedent
from unittest import TestCase

import libcst.matchers as m
from libcst import MetadataWrapper, parse_module

from ..comments import node_comments


class CommentsTest(TestCase):
    """Checks ``node_comments`` against small modules parsed with libcst."""

    def test_node_comments(self) -> None:
        # Each fixture pairs a source snippet with (matcher, expected comment
        # values) tuples; every node found by a matcher must produce exactly
        # the expected comment values, compared without regard to order.
        fixtures = (
            (
                """
                # module-level comment
                print("hello") # trailing comment
                """,
                (
                    (m.Call(func=m.Name("something")), ()),
                    (m.Call(), ["# module-level comment", "# trailing comment"]),
                ),
            ),
            (
                """
                import sys

                # leading comment
                print("hello") # trailing comment
                """,
                ((m.Call(), ["# leading comment", "# trailing comment"]),),
            ),
            (
                """
                import sys

                # leading comment
                @alpha # first decorator comment
                # between-decorator comment
                @beta # second decorator comment
                # after-decorator comment
                class Foo: # trailing comment
                    pass
                """,
                (
                    (
                        m.ClassDef(),
                        [
                            "# leading comment",
                            "# after-decorator comment",
                            "# trailing comment",
                        ],
                    ),
                    (
                        m.Decorator(decorator=m.Name("alpha")),
                        ["# leading comment", "# first decorator comment"],
                    ),
                ),
            ),
        )

        for number, (source, expectations) in enumerate(fixtures, start=1):
            tree = parse_module(dedent(source))
            # skip the copy so that nodes found in ``tree`` are the same
            # objects the wrapper computes metadata for
            wrapper = MetadataWrapper(tree, unsafe_skip_copy=True)
            for offset, (matcher, expected) in enumerate(expectations):
                letter = chr(ord("a") + offset)
                with self.subTest(f"node comments {number}-{letter}"):
                    for found in m.findall(tree, matcher):
                        values = [c.value for c in node_comments(found, wrapper)]
                        self.assertEqual(sorted(expected), sorted(values))
Loading
Loading