Skip to content

Commit

Permalink
Approximate MatchTypes with lub of case bodies, if non-recursive (#19761
Browse files Browse the repository at this point in the history
)
  • Loading branch information
dwijnand authored Mar 4, 2024
2 parents 469c980 + d687dee commit de6a090
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 9 deletions.
7 changes: 7 additions & 0 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2857,6 +2857,13 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
tp
case tp: HKTypeLambda =>
tp
case tp: ParamRef =>
val st = tp.superTypeNormalized
if st.exists then
disjointnessBoundary(st)
else
// workaround for when ParamRef#underlying returns NoType
defn.AnyType
case tp: TypeProxy =>
disjointnessBoundary(tp.superTypeNormalized)
case tp: WildcardType =>
Expand Down
10 changes: 9 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2375,7 +2375,15 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
report.error(MatchTypeScrutineeCannotBeHigherKinded(sel1Tpe), sel1.srcPos)
val pt1 = if (bound1.isEmpty) pt else bound1.tpe
val cases1 = tree.cases.mapconserve(typedTypeCase(_, sel1Tpe, pt1))
assignType(cpy.MatchTypeTree(tree)(bound1, sel1, cases1), bound1, sel1, cases1)
val bound2 = if tree.bound.isEmpty then
val lub = cases1.foldLeft(defn.NothingType: Type): (acc, case1) =>
if !acc.exists then NoType
else if case1.body.tpe.isProvisional then NoType
else acc | case1.body.tpe
if lub.exists then TypeTree(lub, inferred = true)
else bound1
else bound1
assignType(cpy.MatchTypeTree(tree)(bound2, sel1, cases1), bound2, sel1, cases1)
}

def typedByNameTypeTree(tree: untpd.ByNameTypeTree)(using Context): ByNameTypeTree = tree.result match
Expand Down
2 changes: 1 addition & 1 deletion tests/pos/13633.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ object Sums extends App:

type Reverse[A] = ReverseLoop[A, EmptyTuple]

type PlusTri[A, B, C] = (A, B, C) match
type PlusTri[A, B, C] <: Tuple = (A, B, C) match
case (false, false, false) => (false, false)
case (true, false, false) | (false, true, false) | (false, false, true) => (false, true)
case (true, true, false) | (true, false, true) | (false, true, true) => (true, false)
Expand Down
7 changes: 7 additions & 0 deletions tests/pos/Tuple.Drop.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import compiletime.ops.int.*

type Drop[T <: Tuple, N <: Int] <: Tuple = N match
case 0 => T
case S[n1] => T match
case EmptyTuple => EmptyTuple
case x *: xs => Drop[xs, n1]
7 changes: 7 additions & 0 deletions tests/pos/Tuple.Elem.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import compiletime.ops.int.*

type Elem[T <: Tuple, I <: Int] = T match
case h *: tail =>
I match
case 0 => h
case S[j] => Elem[tail, j]
11 changes: 11 additions & 0 deletions tests/pos/i19710.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import scala.util.NotGiven

type HasName1 = [n] =>> [x] =>> x match {
case n => true
case _ => false
}
@main def Test = {
summon[HasName1["foo"]["foo"] =:= true]
summon[NotGiven[HasName1["foo"]["bar"] =:= true]]
summon[Tuple.Filter[(1, "foo", 2, "bar"), HasName1["foo"]] =:= Tuple1["foo"]] // error
}
30 changes: 23 additions & 7 deletions tests/run-macros/type-show/Test_2.scala
Original file line number Diff line number Diff line change
@@ -1,18 +1,34 @@

object Test {
import TypeToolbox.*

def assertEql[A](obt: A, exp: A): Unit =
assert(obt == exp, s"\nexpected: $exp\nobtained: $obt")

def main(args: Array[String]): Unit = {
val x = 5
assert(show[x.type] == "x.type")
assert(show[Nil.type] == "scala.Nil.type")
assert(show[Int] == "scala.Int")
assert(show[Int => Int] == "scala.Function1[scala.Int, scala.Int]")
assert(show[(Int, String)] == "scala.Tuple2[scala.Int, scala.Predef.String]")
assert(show[[X] =>> X match { case Int => Int }] ==
assertEql(show[x.type], "x.type")
assertEql(show[Nil.type], "scala.Nil.type")
assertEql(show[Int], "scala.Int")
assertEql(show[Int => Int], "scala.Function1[scala.Int, scala.Int]")
assertEql(show[(Int, String)], "scala.Tuple2[scala.Int, scala.Predef.String]")
assertEql(show[[X] =>> X match { case Int => Int }],
"""[X >: scala.Nothing <: scala.Any] =>> X match {
| case scala.Int => scala.Int
|}""".stripMargin)
assert(showStructure[[X] =>> X match { case Int => Int }] == """TypeLambda(List(X), List(TypeBounds(TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Nothing"), TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Any"))), MatchType(TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Any"), ParamRef(binder, 0), List(MatchCase(TypeRef(TermRef(ThisType(TypeRef(NoPrefix(), "<root>")), "scala"), "Int"), TypeRef(TermRef(ThisType(TypeRef(NoPrefix(), "<root>")), "scala"), "Int")))))""")
assertEql(showStructure[[X] =>> X match { case Int => Int }],
"""TypeLambda("""+
"""List(X), """+
"""List(TypeBounds("""+
"""TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Nothing"), """+
"""TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Any"))), """+
"""MatchType("""+
"""TypeRef(TermRef(ThisType(TypeRef(NoPrefix(), "<root>")), "scala"), "Int"), """+ // match type bound
"""ParamRef(binder, 0), """+
"""List("""+
"""MatchCase("""+
"""TypeRef(TermRef(ThisType(TypeRef(NoPrefix(), "<root>")), "scala"), "Int"), """+
"""TypeRef(TermRef(ThisType(TypeRef(NoPrefix(), "<root>")), "scala"), "Int")))))""")

// TODO: more complex types:
// - implicit function types
Expand Down

0 comments on commit de6a090

Please sign in to comment.