Skip to content

Commit

Permalink
Add support for default arguments in product mirrors
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneFlesselle committed Feb 1, 2024
1 parent 95266f2 commit 3e1f445
Show file tree
Hide file tree
Showing 12 changed files with 264 additions and 13 deletions.
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ class Definitions {
@tu lazy val MirrorClass: ClassSymbol = requiredClass("scala.deriving.Mirror")
@tu lazy val Mirror_ProductClass: ClassSymbol = requiredClass("scala.deriving.Mirror.Product")
@tu lazy val Mirror_Product_fromProduct: Symbol = Mirror_ProductClass.requiredMethod(nme.fromProduct)
@tu lazy val Mirror_Product_defaultArgument: Symbol = Mirror_ProductClass.requiredMethod(nme.defaultArgument)
@tu lazy val Mirror_SumClass: ClassSymbol = requiredClass("scala.deriving.Mirror.Sum")
@tu lazy val Mirror_SingletonClass: ClassSymbol = requiredClass("scala.deriving.Mirror.Singleton")
@tu lazy val Mirror_SingletonProxyClass: ClassSymbol = requiredClass("scala.deriving.Mirror.SingletonProxy")
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ object StdNames {
val LiteralAnnotArg: N = "LiteralAnnotArg"
val Matchable: N = "Matchable"
val MatchCase: N = "MatchCase"
val MirroredElemHasDefaults: N = "MirroredElemHasDefaults"
val MirroredElemTypes: N = "MirroredElemTypes"
val MirroredElemLabels: N = "MirroredElemLabels"
val MirroredLabel: N = "MirroredLabel"
Expand Down Expand Up @@ -452,6 +453,7 @@ object StdNames {
val create: N = "create"
val currentMirror: N = "currentMirror"
val curried: N = "curried"
val defaultArgument: N = "defaultArgument"
val definitions: N = "definitions"
val delayedInit: N = "delayedInit"
val delayedInitArg: N = "delayedInit$body"
Expand Down
35 changes: 34 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import Decorators.*
import NameOps.*
import Annotations.Annotation
import typer.ProtoTypes.constrained
import ast.untpd
import ast.{tpd, untpd}

import util.Property
import util.Spans.Span
Expand Down Expand Up @@ -547,6 +547,30 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
New(classRefApplied, elems)
end fromProductBody

def defaultArgumentBody(caseClass: Symbol, index: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree =
val companionTree: Tree =
val companion: Symbol = caseClass.companionModule
val prefix: Type = optInfo.fold(NoPrefix)(_.pre)
ref(TermRef(prefix, companion.asTerm))

def defaultArgumentGetter(idx: Int): Tree =
val getterName = NameKinds.DefaultGetterName(nme.CONSTRUCTOR, idx)
val getterDenot = companionTree.tpe.member(getterName)
companionTree.select(TermRef(companionTree.tpe, getterName, getterDenot))

val withDefaultCases = for
(acc, idx) <- caseClass.caseAccessors.zipWithIndex if acc.is(HasDefault)
body = Typed(defaultArgumentGetter(idx), TypeTree(defn.AnyType)) // so match tree does try to find union of case types
yield CaseDef(Literal(Constant(idx)), EmptyTree, body)

val withoutDefaultCase =
val stringIndex = Apply(Select(index, nme.toString_), Nil)
val nsee = tpd.resolveConstructor(defn.NoSuchElementExceptionType, List(stringIndex))
CaseDef(Underscore(defn.IntType), EmptyTree, Throw(nsee))

Match(index, withDefaultCases :+ withoutDefaultCase)
end defaultArgumentBody

/** For an enum T:
*
* def ordinal(x: MirroredMonoType) = x.ordinal
Expand Down Expand Up @@ -616,6 +640,12 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
synthesizeDef(meth, vrefss => body(cls, vrefss.head.head))
}
}
def overrideMethod(name: TermName, info: Type, cls: Symbol, body: (Symbol, Tree) => Context ?=> Tree, isExperimental: Boolean = false): Unit = {
val meth = newSymbol(clazz, name, Synthetic | Method | Override, info, coord = clazz.coord)
if isExperimental then meth.addAnnotation(defn.ExperimentalAnnot)
meth.enteredAfter(thisPhase)
newBody = newBody :+ synthesizeDef(meth, vrefss => body(cls, vrefss.head.head))
}
val linked = clazz.linkedClass
lazy val monoType = {
val existing = clazz.info.member(tpnme.MirroredMonoType).symbol
Expand All @@ -633,6 +663,9 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
addParent(defn.Mirror_ProductClass.typeRef)
addMethod(nme.fromProduct, MethodType(defn.ProductClass.typeRef :: Nil, monoType.typeRef), cls,
fromProductBody(_, _, optInfo).ensureConforms(monoType.typeRef)) // t4758.scala or i3381.scala are examples where a cast is needed
if cls.primaryConstructor.hasDefaultParams then
overrideMethod(nme.defaultArgument, MethodType(defn.IntType :: Nil, defn.AnyType), cls,
defaultArgumentBody(_, _, optInfo), isExperimental = true)
}
def makeSumMirror(cls: Symbol, optInfo: Option[MirrorImpl.OfSum]) = {
addParent(defn.Mirror_SumClass.typeRef)
Expand Down
29 changes: 17 additions & 12 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -409,25 +409,30 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):

def makeProductMirror(pre: Type, cls: Symbol, tps: Option[List[Type]]): TreeWithErrors =
val accessors = cls.caseAccessors
val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString)))
val typeElems = tps.getOrElse(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr))
val nestedPairs = TypeOps.nestedPairs(typeElems)
val (monoType, elemsType) = mirroredType match
val Seq(elemLabels, elemHasDefaults, elemTypes1) =
Seq(
accessors.map(acc => ConstantType(Constant(acc.name.toString))),
accessors.map(acc => ConstantType(Constant(acc.is(HasDefault)))),
tps.getOrElse(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr))
).map(TypeOps.nestedPairs)
val (monoType, elemTypes) = mirroredType match
case mirroredType: HKTypeLambda =>
(mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = nestedPairs))
(mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = elemTypes1))
case _ =>
(mirroredType, nestedPairs)
val elemsLabels = TypeOps.nestedPairs(elemLabels)
checkRefinement(formal, tpnme.MirroredElemTypes, elemsType, span)
checkRefinement(formal, tpnme.MirroredElemLabels, elemsLabels, span)
(mirroredType, elemTypes1)

checkRefinement(formal, tpnme.MirroredElemTypes, elemTypes, span)
checkRefinement(formal, tpnme.MirroredElemLabels, elemLabels, span)
checkRefinement(formal, tpnme.MirroredElemHasDefaults, elemHasDefaults, span)
val mirrorType = formal.constrained_& {
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name)
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels))
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemTypes))
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemLabels))
.refinedWith(tpnme.MirroredElemHasDefaults, TypeAlias(elemHasDefaults))
}
val mirrorRef =
if cls.useCompanionAsProductMirror then companionPath(mirroredType, span)
else if defn.isTupleClass(cls) then newTupleMirror(typeElems.size) // TODO: cls == defn.PairClass when > 22
else if defn.isTupleClass(cls) then newTupleMirror(accessors.size) // TODO: cls == defn.PairClass when > 22
else anonymousMirror(monoType, MirrorImpl.OfProduct(pre), span)
withNoErrors(mirrorRef.cast(mirrorType).withSpan(span))
end makeProductMirror
Expand Down
11 changes: 11 additions & 0 deletions library/src/scala/deriving/Mirror.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package scala.deriving

import java.util.NoSuchElementException
import scala.annotation.experimental

/** Mirrors allows typelevel access to enums, case classes and objects, and their sealed parents.
*/
sealed trait Mirror {
Expand Down Expand Up @@ -27,6 +30,14 @@ object Mirror {

/** Create a new instance of type `T` with elements taken from product `p`. */
def fromProduct(p: scala.Product): MirroredMonoType

/** Whether each product element has a default value */
@experimental type MirroredElemHasDefaults <: Tuple

/** The default argument of the product argument at given `index` */
@experimental def defaultArgument(index: Int): Any =
throw NoSuchElementException(String.valueOf(index))

}

trait Singleton extends Product {
Expand Down
2 changes: 2 additions & 0 deletions project/MiMaFilters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ object MiMaFilters {
val LibraryForward: Map[String, Seq[ProblemFilter]] = Map(
// Additions that require a new minor version of the library
Build.previousDottyVersion -> Seq(
ProblemFilters.exclude[DirectMissingMethodProblem]("scala.compiletime.testing.Error.defaultArgument"),
),

// Additions since last LTS
Expand Down Expand Up @@ -62,6 +63,7 @@ object MiMaFilters {
),
)
val TastyCore: Seq[ProblemFilter] = Seq(
ProblemFilters.exclude[DirectMissingMethodProblem]("dotty.tools.tasty.TastyVersion.defaultArgument"),
)
val Interfaces: Seq[ProblemFilter] = Seq(
)
Expand Down
1 change: 1 addition & 0 deletions tests/run-macros/i7987.check
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ scala.deriving.Mirror.Product {
type MirroredLabel >: "Some" <: "Some"
type MirroredElemTypes >: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple] <: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple]
type MirroredElemLabels >: scala.*:["value", scala.Tuple$package.EmptyTuple] <: scala.*:["value", scala.Tuple$package.EmptyTuple]
type MirroredElemHasDefaults >: scala.*:[false, scala.Tuple$package.EmptyTuple] <: scala.*:[false, scala.Tuple$package.EmptyTuple]
}
25 changes: 25 additions & 0 deletions tests/run-macros/mirror-defaultArgument/MirrorOps.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import scala.deriving._
import scala.annotation.experimental
import scala.quoted._

object MirrorOps:

inline def overridesDefaultArgument[T]: Boolean = ${ overridesDefaultArgumentImpl[T] }

def overridesDefaultArgumentImpl[T](using Quotes, Type[T]): Expr[Boolean] =
import quotes.reflect.*
val cls = TypeRepr.of[T].classSymbol.get
val companion = cls.companionModule.moduleClass
val methods = companion.declaredMethods

val experAnnotType = Symbol.requiredClass("scala.annotation.experimental").typeRef

Expr {
methods.exists { m =>
m.name == "defaultArgument" &&
m.flags.is(Flags.Synthetic) &&
m.annotations.exists(_.tpe <:< experAnnotType)
}
}

end MirrorOps
13 changes: 13 additions & 0 deletions tests/run-macros/mirror-defaultArgument/test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import scala.deriving._
import scala.annotation.experimental
import scala.quoted._

import MirrorOps.*

object Test extends App:

case class WithDefault(x: Int, y: Int = 1)
assert(overridesDefaultArgument[WithDefault])

case class WithoutDefault(x: Int)
assert(!overridesDefaultArgument[WithoutDefault])
4 changes: 4 additions & 0 deletions tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ val experimentalDefinitionInLibrary = Set(
"scala.Tuple$.Reverse", // can be stabilized in 3.5
"scala.Tuple$.ReverseOnto", // can be stabilized in 3.5
"scala.runtime.Tuples$.reverse", // can be stabilized in 3.5

// New APIs: Mirror support for default arguments
"scala.deriving.Mirror$.Product.MirroredElemHasDefaults",
"scala.deriving.Mirror$.Product.defaultArgument",
)


Expand Down
53 changes: 53 additions & 0 deletions tests/run/mirror-defaultArgument.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import scala.deriving._
import scala.annotation.experimental

object Test extends App:

case class WithDefault(x: Int, y: Int = 1)
val m = summon[Mirror.Of[WithDefault]]
assert(m.defaultArgument(1) == 1)
try
m.defaultArgument(0)
throw IllegalStateException("There should be no default argument")
catch
case ex: NoSuchElementException => assert(ex.getMessage == "0") // Ok


case class WithCompanion(s: String = "hello")
case object WithCompanion // => mirrors must be anonymous

val m2 = summon[Mirror.Of[WithCompanion]]
assert(m2 ne WithCompanion)
assert(m2.defaultArgument(0) == "hello")


class Outer(val i: Int) {

case class Inner(x: Int, y: Int = i + 1)
case object Inner

val m3 = summon[Mirror.Of[Inner]]
assert(m3.defaultArgument(1) == i + 1)

def localTest(d: Double): Unit = {
case class Local(x: Int = i, y: Double = d, z: Double = i + d)
case object Local

val m4 = summon[Mirror.Of[Local]]
assert(m4.defaultArgument(0) == i)
assert(m4.defaultArgument(1) == d)
assert(m4.defaultArgument(2) == i + d)
}

}

val outer = Outer(3)
val m5 = summon[Mirror.Of[outer.Inner]]
assert(m5.defaultArgument(1) == 3 + 1)
outer.localTest(9d)


// new defaultArgument match tree should be able to unify different default value types
case class Foo[T](x: Int = 0, y: String = "hi")

end Test
101 changes: 101 additions & 0 deletions tests/run/typeclass-derivation-defaultArgument.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import scala.deriving.Mirror as M
import scala.deriving.*
import scala.Tuple.*
import scala.compiletime.*
import scala.compiletime.ops.int.S

trait Migration[-From, +To]:
def apply(x: From): To

object Migration:

extension [From](x: From)
def migrateTo[To](using m: Migration[From, To]): To = m(x)

given[T]: Migration[T, T] with
override def apply(x: T): T = x

type IndexOf[Elems <: Tuple, X] <: Int = Elems match {
case (X *: elems) => 0
case (_ *: elems) => S[IndexOf[elems, X]]
case EmptyTuple => Nothing
}

inline def migrateElem[F,T, ToIdx <: Int](from: M.ProductOf[F], to: M.ProductOf[T])(x: Product): Any =

type Label = Elem[to.MirroredElemLabels, ToIdx]
type FromIdx = IndexOf[from.MirroredElemLabels, Label]
inline constValueOpt[FromIdx] match

case Some(fromIdx) =>
type FromType = Elem[from.MirroredElemTypes, FromIdx]
type ToType = Elem[to.MirroredElemTypes, ToIdx]
summonFrom { case _: Migration[FromType, ToType] =>
x.productElement(fromIdx).asInstanceOf[FromType].migrateTo[ToType]
}

case None =>
type HasDefault = Elem[to.MirroredElemHasDefaults, ToIdx]
inline erasedValue[HasDefault] match
case _: true => to.defaultArgument(constValue[ToIdx])
case _: false => compiletime.error("An element has no equivalent or default")


inline def migrateElems[F,T, ToIdx <: Int](from: M.ProductOf[F], to: M.ProductOf[T])(x: Product): Seq[Any] =
inline erasedValue[ToIdx] match
case _: Tuple.Size[to.MirroredElemLabels] => Seq()
case _ => migrateElem[F,T,ToIdx](from, to)(x) +: migrateElems[F,T,S[ToIdx]](from, to)(x)

inline def migrateProduct[F,T](from: M.ProductOf[F], to: M.ProductOf[T])
(x: Product): T =
val elems = migrateElems[F, T, 0](from, to)(x)
to.fromProduct(new Product:
def canEqual(that: Any): Boolean = false
def productArity: Int = elems.length
def productElement(n: Int): Any = elems(n)
)

inline def migration[F,T](using from: M.Of[F], to: M.Of[T]): Migration[F,T] = (x: F) =>
inline from match
case fromP: M.ProductOf[F] => inline to match
case toP: M.ProductOf[T] => migrateProduct[F, T](fromP, toP)(x.asInstanceOf[Product])
case _: M.SumOf[T] => compiletime.error("Cannot migrate sums")
case _: M.SumOf[F] => compiletime.error("Cannot migrate sums")

end Migration


import Migration.*
object Test extends App:

case class A1(x: Int)
case class A2(x: Int)
given Migration[A1, A2] = migration
assert(A1(2).migrateTo[A2] == A2(2))

case class B1(x: Int, y: String)
case class B2(y: String, x: Int)
given Migration[B1, B2] = migration
assert(B1(5, "hi").migrateTo[B2] == B2("hi", 5))

case class C1(x: A1)
case class C2(x: A2)
given Migration[C1, C2] = migration
assert(C1(A1(0)).migrateTo[C2] == C2(A2(0)))

case class D1(x: Double)
case class D2(b: Boolean = true, x: Double)
given Migration[D1, D2] = migration
assert(D1(9).migrateTo[D2] == D2(true, 9))

case class E1(x: D1, y: D1)
case class E2(y: D2, s: String = "hi", x: D2)
given Migration[E1, E2] = migration
assert(E1(D1(1), D1(2)).migrateTo[E2] == E2(D2(true, 2), "hi", D2(true, 1)))

// should only use default when needed
case class F1(x: Int)
case class F2(x: Int = 3)
given Migration[F1, F2] = migration
assert(F1(7).migrateTo[F2] == F2(7))

0 comments on commit 3e1f445

Please sign in to comment.