From 2394596d666a962ffa4931a94e7c9f53cfeb9e92 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Sun, 31 Mar 2024 20:02:28 -1000 Subject: [PATCH] Optimize if-matches case (#1193) --- .../org/bykn/bosatsu/SourceConverter.scala | 25 +++++++++++++++++++ .../org/bykn/bosatsu/EvaluationTest.scala | 16 ++++++++++++ 2 files changed, 41 insertions(+) diff --git a/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala b/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala index 58e80cf7..e6e1a82e 100644 --- a/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala +++ b/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala @@ -348,6 +348,31 @@ final class SourceConverter( else RecursionKind.NonRecursive Expr.Let(boundName, lam, in, recursive = rec, decl) } + case IfElse(NonEmptyList((Matches(a, p), res), tail), elseCase) if p.names.isEmpty => + // if x matches p: res + // else: elseCase + // same as: match x: + // case p: res + // case _: elseCase + // + // we filter on p.names.isEmpty to ensure this is valid, if it isn't valid + // we want to give the most localized version of Matches to the unusued + // let checker to give the best error message. + val restDecl: OptIndent[Declaration] = + NonEmptyList.fromList(tail) match { + case None => elseCase + case Some(nel) => + val restRegion = nel.map(_._2.get.region).reduce[Region](_ + _) + // keep the OptIndent from the first item + nel.head._2.map(_ => IfElse(nel, elseCase)(restRegion)) + } + loop(Match( + RecursionKind.NonRecursive, + a, + OptIndent.same(NonEmptyList( + (p, res), + (Pattern.WildCard, restDecl) :: Nil + )))(decl.region)) case IfElse(ifCases, elseCase) => def loop0( ifs: NonEmptyList[(Expr[Declaration], Expr[Declaration])], diff --git a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala index 7ce60b41..17517b4b 100644 --- a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala @@ -4004,4 +4004,20 @@ external def foo[b](lst: List[a]) -> a () } } + + test("test nested if matches") { + runBosatsuTest( + List(""" +package Foo + +export fn + +def fn(x): + if x matches True: False + elif x matches False: True + else: False + +test = Assertion(fn(False), "") +"""), "Foo", 1) + } }