From d0461f9198b5b4793d22575b5bf6b0349720e169 Mon Sep 17 00:00:00 2001 From: Amethyst Reese Date: Mon, 16 Oct 2023 19:10:21 -0700 Subject: [PATCH] Refactor violation diffing into separate function ghstack-source-id: ea193e1ede6730af3793871cacd06047532199ba Pull Request resolved: https://github.com/Instagram/Fixit/pull/399 --- src/fixit/engine.py | 34 ++++++++++++-------- src/fixit/testing.py | 8 ++--- src/fixit/tests/__init__.py | 1 + src/fixit/tests/engine.py | 62 +++++++++++++++++++++++++++++++++++++ 4 files changed, 86 insertions(+), 19 deletions(-) create mode 100644 src/fixit/tests/engine.py diff --git a/src/fixit/engine.py b/src/fixit/engine.py index f9858792..6608cef3 100644 --- a/src/fixit/engine.py +++ b/src/fixit/engine.py @@ -28,6 +28,26 @@ LOG = logging.getLogger(__name__) +def diff_violation(path: Path, module: Module, violation: LintViolation) -> str: + """ + Generate string diff representation of a violation. + """ + + orig = module.code + mod = module.deep_replace( # type:ignore # LibCST#906 + violation.node, violation.replacement + ) + assert isinstance(mod, Module) + change = mod.code + + return unified_diff( + orig, + change, + path.name, + n=1, + ) + + class LintRunner: def __init__(self, path: Path, source: FileContent) -> None: self.path = path @@ -87,19 +107,7 @@ def visit_hook(name: str) -> Iterator[None]: count += 1 if violation.replacement: - orig = self.module.code - mod = self.module.deep_replace( # type:ignore # LibCST#906 - violation.node, violation.replacement - ) - assert isinstance(mod, Module) - change = mod.code - - diff = unified_diff( - orig, - change, - self.path.name, - n=1, - ) + diff = diff_violation(self.path, self.module, violation) violation = replace(violation, diff=diff) yield violation diff --git a/src/fixit/testing.py b/src/fixit/testing.py index eb7d55e6..74ff3e3d 100644 --- a/src/fixit/testing.py +++ b/src/fixit/testing.py @@ -10,9 +10,7 @@ from pathlib import Path from typing import Any, Callable, Collection, Dict, List, Mapping, Sequence, Type, Union -from moreorless import unified_diff - -from .engine import LintRunner +from .engine import LintRunner, diff_violation from .ftypes import Config from .rule import Invalid, LintRule, Valid @@ -112,9 +110,7 @@ def _test_method( if len(reports) == 1: # make sure we generated a reasonable diff - expected_diff = unified_diff( - source_code, expected_code, filename=path.name, n=1 - ) + expected_diff = diff_violation(path, runner.module, reports[0]) self.assertEqual(expected_diff, report.diff) diff --git a/src/fixit/tests/__init__.py b/src/fixit/tests/__init__.py index 761dbe0f..4021fdec 100644 --- a/src/fixit/tests/__init__.py +++ b/src/fixit/tests/__init__.py @@ -8,6 +8,7 @@ from fixit.testing import add_lint_rule_tests_to_module from .config import ConfigTest +from .engine import EngineTest from .ftypes import TypesTest from .rule import RuleTest, RunnerTest from .smoke import SmokeTest diff --git a/src/fixit/tests/engine.py b/src/fixit/tests/engine.py new file mode 100644 index 00000000..8e964ed7 --- /dev/null +++ b/src/fixit/tests/engine.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path +from textwrap import dedent +from typing import Any, cast, Optional, Set +from unittest import TestCase + +from libcst import ( + Call, + ensure_type, + Expr, + parse_module, + SimpleStatementLine, + SimpleString, +) +from libcst.metadata import CodePosition, CodeRange + +from ..engine import diff_violation +from ..ftypes import LintViolation + + +class EngineTest(TestCase): + def test_diff_violation(self): + src = dedent( + """\ + import sys + print("hello world") + """ + ) + path = Path("foo.py") + module = parse_module(src) + node = ensure_type( + ensure_type( + ensure_type(module.body[-1], SimpleStatementLine).body[0], Expr + ).value, + Call, + ).args[0] + repl = node.with_changes(value=SimpleString('"goodnight moon"')) + + violation = LintViolation( + "Fake", + CodeRange(CodePosition(1, 1), CodePosition(2, 2)), + message="some error", + node=node, + replacement=repl, + ) + + expected = dedent( + """\ + --- a/foo.py + +++ b/foo.py + @@ -1,2 +1,2 @@ + import sys + -print("hello world") + +print("goodnight moon") + """ + ) + result = diff_violation(path, module, violation) + self.assertEqual(expected, result)