Skip to content

Commit

Permalink
fix various Match statement visitation errors (#1161)
Browse files Browse the repository at this point in the history
Fixes #1160.

This PR also

- fixes `whitespace_before_colon` being swallowed during visitation on `MatchCase`s
- adds a new type of roundtrip test that catches issues of this class: the test applies a noop transformer to exercise the visitation API and compares the result with the original source.
- adds a few more cases to the match fixture
  • Loading branch information
zsol authored Jun 12, 2024
1 parent 9f6e276 commit 8b97600
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 9 deletions.
12 changes: 6 additions & 6 deletions libcst/_nodes/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -2854,17 +2854,16 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "CSTNode":
self, "whitespace_after_case", self.whitespace_after_case, visitor
),
pattern=visit_required(self, "pattern", self.pattern, visitor),
# pyre-fixme[6]: Expected `SimpleWhitespace` for 4th param but got
# `Optional[SimpleWhitespace]`.
whitespace_before_if=visit_optional(
whitespace_before_if=visit_required(
self, "whitespace_before_if", self.whitespace_before_if, visitor
),
# pyre-fixme[6]: Expected `SimpleWhitespace` for 5th param but got
# `Optional[SimpleWhitespace]`.
whitespace_after_if=visit_optional(
whitespace_after_if=visit_required(
self, "whitespace_after_if", self.whitespace_after_if, visitor
),
guard=visit_optional(self, "guard", self.guard, visitor),
whitespace_before_colon=visit_required(
self, "whitespace_before_colon", self.whitespace_before_colon, visitor
),
body=visit_required(self, "body", self.body, visitor),
)

Expand Down Expand Up @@ -3382,6 +3381,7 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "MatchClass":
whitespace_after_kwds=visit_required(
self, "whitespace_after_kwds", self.whitespace_after_kwds, visitor
),
rpar=visit_sequence(self, "rpar", self.rpar, visitor),
)

def _codegen_impl(self, state: CodegenState) -> None:
Expand Down
24 changes: 21 additions & 3 deletions libcst/tests/test_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,43 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from pathlib import Path
from unittest import TestCase

from libcst import parse_module
from libcst import CSTTransformer, parse_module
from libcst._parser.entrypoints import is_native

fixtures: Path = Path(__file__).parent.parent.parent / "native/libcst/tests/fixtures"


class NOOPTransformer(CSTTransformer):
pass


class RoundTripTests(TestCase):
def test_clean_roundtrip(self) -> None:
def _get_fixtures(self) -> list[Path]:
if not is_native():
self.skipTest("pure python parser doesn't work with this")
self.assertTrue(fixtures.exists(), f"{fixtures} should exist")
files = list(fixtures.iterdir())
self.assertGreater(len(files), 0)
for file in files:
return files

def test_clean_roundtrip(self) -> None:
for file in self._get_fixtures():
with self.subTest(file=str(file)):
src = file.read_text(encoding="utf-8")
mod = parse_module(src)
self.maxDiff = None
self.assertEqual(mod.code, src)

def test_transform_roundtrip(self) -> None:
transformer = NOOPTransformer()
self.maxDiff = None
for file in self._get_fixtures():
with self.subTest(file=str(file)):
src = file.read_text(encoding="utf-8")
mod = parse_module(src)
new_mod = mod.visit(transformer)
self.assertEqual(src, new_mod.code)
2 changes: 2 additions & 0 deletions native/libcst/tests/fixtures/malicious_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,6 @@
case x,y , * more :pass
case y.z: pass
case 1, 2: pass
case ( Foo ( ) ) : pass
case (lol) if ( True , ) :pass

0 comments on commit 8b97600

Please sign in to comment.