diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 4bc427ee0687..34ddaacc5378 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -787,6 +787,7 @@ class Definitions { @tu lazy val MirrorClass: ClassSymbol = requiredClass("scala.deriving.Mirror") @tu lazy val Mirror_ProductClass: ClassSymbol = requiredClass("scala.deriving.Mirror.Product") @tu lazy val Mirror_Product_fromProduct: Symbol = Mirror_ProductClass.requiredMethod(nme.fromProduct) + @tu lazy val Mirror_Product_defaultArgument: Symbol = Mirror_ProductClass.requiredMethod(nme.defaultArgument) @tu lazy val Mirror_SumClass: ClassSymbol = requiredClass("scala.deriving.Mirror.Sum") @tu lazy val Mirror_SingletonClass: ClassSymbol = requiredClass("scala.deriving.Mirror.Singleton") @tu lazy val Mirror_SingletonProxyClass: ClassSymbol = requiredClass("scala.deriving.Mirror.SingletonProxy") diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 253a45ffd7a8..14ad70d74f5f 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -368,6 +368,7 @@ object StdNames { val LiteralAnnotArg: N = "LiteralAnnotArg" val Matchable: N = "Matchable" val MatchCase: N = "MatchCase" + val MirroredElemHasDefaults: N = "MirroredElemHasDefaults" val MirroredElemTypes: N = "MirroredElemTypes" val MirroredElemLabels: N = "MirroredElemLabels" val MirroredLabel: N = "MirroredLabel" @@ -452,6 +453,7 @@ object StdNames { val create: N = "create" val currentMirror: N = "currentMirror" val curried: N = "curried" + val defaultArgument: N = "defaultArgument" val definitions: N = "definitions" val delayedInit: N = "delayedInit" val delayedInitArg: N = "delayedInit$body" diff --git a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala index 6d2aedb9b47b..c0bfebdbe98a 100644 --- a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala +++ b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala @@ -9,7 +9,7 @@ import Decorators.* import NameOps.* import Annotations.Annotation import typer.ProtoTypes.constrained -import ast.untpd +import ast.{tpd, untpd} import util.Property import util.Spans.Span @@ -547,6 +547,30 @@ class SyntheticMembers(thisPhase: DenotTransformer) { New(classRefApplied, elems) end fromProductBody + def defaultArgumentBody(caseClass: Symbol, index: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree = + val companionTree: Tree = + val companion: Symbol = caseClass.companionModule + val prefix: Type = optInfo.fold(NoPrefix)(_.pre) + ref(TermRef(prefix, companion.asTerm)) + + def defaultArgumentGetter(idx: Int): Tree = + val getterName = NameKinds.DefaultGetterName(nme.CONSTRUCTOR, idx) + val getterDenot = companionTree.tpe.member(getterName) + companionTree.select(TermRef(companionTree.tpe, getterName, getterDenot)) + + val withDefaultCases = for + (acc, idx) <- caseClass.caseAccessors.zipWithIndex if acc.is(HasDefault) + body = Typed(defaultArgumentGetter(idx), TypeTree(defn.AnyType)) // so match tree does try to find union of case types + yield CaseDef(Literal(Constant(idx)), EmptyTree, body) + + val withoutDefaultCase = + val stringIndex = Apply(Select(index, nme.toString_), Nil) + val nsee = tpd.resolveConstructor(defn.NoSuchElementExceptionType, List(stringIndex)) + CaseDef(Underscore(defn.IntType), EmptyTree, Throw(nsee)) + + Match(index, withDefaultCases :+ withoutDefaultCase) + end defaultArgumentBody + /** For an enum T: * * def ordinal(x: MirroredMonoType) = x.ordinal @@ -616,6 +640,12 @@ class SyntheticMembers(thisPhase: DenotTransformer) { synthesizeDef(meth, vrefss => body(cls, vrefss.head.head)) } } + def overrideMethod(name: TermName, info: Type, cls: Symbol, body: (Symbol, Tree) => Context ?=> Tree, isExperimental: Boolean = false): Unit = { + val meth = newSymbol(clazz, name, Synthetic | Method | Override, info, coord = clazz.coord) + if isExperimental then meth.addAnnotation(defn.ExperimentalAnnot) + meth.enteredAfter(thisPhase) + newBody = newBody :+ synthesizeDef(meth, vrefss => body(cls, vrefss.head.head)) + } val linked = clazz.linkedClass lazy val monoType = { val existing = clazz.info.member(tpnme.MirroredMonoType).symbol @@ -633,6 +663,9 @@ class SyntheticMembers(thisPhase: DenotTransformer) { addParent(defn.Mirror_ProductClass.typeRef) addMethod(nme.fromProduct, MethodType(defn.ProductClass.typeRef :: Nil, monoType.typeRef), cls, fromProductBody(_, _, optInfo).ensureConforms(monoType.typeRef)) // t4758.scala or i3381.scala are examples where a cast is needed + if cls.primaryConstructor.hasDefaultParams then + overrideMethod(nme.defaultArgument, MethodType(defn.IntType :: Nil, defn.AnyType), cls, + defaultArgumentBody(_, _, optInfo), isExperimental = true) } def makeSumMirror(cls: Symbol, optInfo: Option[MirrorImpl.OfSum]) = { addParent(defn.Mirror_SumClass.typeRef) diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index c94724faf4d4..977621f959de 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -409,25 +409,30 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): def makeProductMirror(pre: Type, cls: Symbol, tps: Option[List[Type]]): TreeWithErrors = val accessors = cls.caseAccessors - val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString))) - val typeElems = tps.getOrElse(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr)) - val nestedPairs = TypeOps.nestedPairs(typeElems) - val (monoType, elemsType) = mirroredType match + val Seq(elemLabels, elemHasDefaults, elemTypes1) = + Seq( + accessors.map(acc => ConstantType(Constant(acc.name.toString))), + accessors.map(acc => ConstantType(Constant(acc.is(HasDefault)))), + tps.getOrElse(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr)) + ).map(TypeOps.nestedPairs) + val (monoType, elemTypes) = mirroredType match case mirroredType: HKTypeLambda => - (mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = nestedPairs)) + (mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = elemTypes1)) case _ => - (mirroredType, nestedPairs) - val elemsLabels = TypeOps.nestedPairs(elemLabels) - checkRefinement(formal, tpnme.MirroredElemTypes, elemsType, span) - checkRefinement(formal, tpnme.MirroredElemLabels, elemsLabels, span) + (mirroredType, elemTypes1) + + checkRefinement(formal, tpnme.MirroredElemTypes, elemTypes, span) + checkRefinement(formal, tpnme.MirroredElemLabels, elemLabels, span) + checkRefinement(formal, tpnme.MirroredElemHasDefaults, elemHasDefaults, span) val mirrorType = formal.constrained_& { mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name) - .refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType)) - .refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels)) + .refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemTypes)) + .refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemLabels)) + .refinedWith(tpnme.MirroredElemHasDefaults, TypeAlias(elemHasDefaults)) } val mirrorRef = if cls.useCompanionAsProductMirror then companionPath(mirroredType, span) - else if defn.isTupleClass(cls) then newTupleMirror(typeElems.size) // TODO: cls == defn.PairClass when > 22 + else if defn.isTupleClass(cls) then newTupleMirror(accessors.size) // TODO: cls == defn.PairClass when > 22 else anonymousMirror(monoType, MirrorImpl.OfProduct(pre), span) withNoErrors(mirrorRef.cast(mirrorType).withSpan(span)) end makeProductMirror diff --git a/library/src/scala/deriving/Mirror.scala b/library/src/scala/deriving/Mirror.scala index 57453a516567..b54786ad4208 100644 --- a/library/src/scala/deriving/Mirror.scala +++ b/library/src/scala/deriving/Mirror.scala @@ -1,5 +1,8 @@ package scala.deriving +import java.util.NoSuchElementException +import scala.annotation.experimental + /** Mirrors allows typelevel access to enums, case classes and objects, and their sealed parents. */ sealed trait Mirror { @@ -27,6 +30,14 @@ object Mirror { /** Create a new instance of type `T` with elements taken from product `p`. */ def fromProduct(p: scala.Product): MirroredMonoType + + /** Whether each product element has a default value */ + @experimental type MirroredElemHasDefaults <: Tuple + + /** The default argument of the product argument at given `index` */ + @experimental def defaultArgument(index: Int): Any = + throw NoSuchElementException(String.valueOf(index)) + } trait Singleton extends Product { diff --git a/project/MiMaFilters.scala b/project/MiMaFilters.scala index 696fbeec8a39..f3b3b28e3d3c 100644 --- a/project/MiMaFilters.scala +++ b/project/MiMaFilters.scala @@ -28,6 +28,7 @@ object MiMaFilters { val LibraryForward: Map[String, Seq[ProblemFilter]] = Map( // Additions that require a new minor version of the library Build.previousDottyVersion -> Seq( + ProblemFilters.exclude[DirectMissingMethodProblem]("scala.compiletime.testing.Error.defaultArgument"), ), // Additions since last LTS @@ -62,6 +63,7 @@ object MiMaFilters { ), ) val TastyCore: Seq[ProblemFilter] = Seq( + ProblemFilters.exclude[DirectMissingMethodProblem]("dotty.tools.tasty.TastyVersion.defaultArgument"), ) val Interfaces: Seq[ProblemFilter] = Seq( ) diff --git a/tests/run-macros/i7987.check b/tests/run-macros/i7987.check index 85a185c1d5c7..80e6e372c833 100644 --- a/tests/run-macros/i7987.check +++ b/tests/run-macros/i7987.check @@ -4,4 +4,5 @@ scala.deriving.Mirror.Product { type MirroredLabel >: "Some" <: "Some" type MirroredElemTypes >: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple] <: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple] type MirroredElemLabels >: scala.*:["value", scala.Tuple$package.EmptyTuple] <: scala.*:["value", scala.Tuple$package.EmptyTuple] + type MirroredElemHasDefaults >: scala.*:[false, scala.Tuple$package.EmptyTuple] <: scala.*:[false, scala.Tuple$package.EmptyTuple] } diff --git a/tests/run-macros/mirror-defaultArgument/MirrorOps.scala b/tests/run-macros/mirror-defaultArgument/MirrorOps.scala new file mode 100644 index 000000000000..75f00aff8f63 --- /dev/null +++ b/tests/run-macros/mirror-defaultArgument/MirrorOps.scala @@ -0,0 +1,25 @@ +import scala.deriving._ +import scala.annotation.experimental +import scala.quoted._ + +object MirrorOps: + + inline def overridesDefaultArgument[T]: Boolean = ${ overridesDefaultArgumentImpl[T] } + + def overridesDefaultArgumentImpl[T](using Quotes, Type[T]): Expr[Boolean] = + import quotes.reflect.* + val cls = TypeRepr.of[T].classSymbol.get + val companion = cls.companionModule.moduleClass + val methods = companion.declaredMethods + + val experAnnotType = Symbol.requiredClass("scala.annotation.experimental").typeRef + + Expr { + methods.exists { m => + m.name == "defaultArgument" && + m.flags.is(Flags.Synthetic) && + m.annotations.exists(_.tpe <:< experAnnotType) + } + } + +end MirrorOps diff --git a/tests/run-macros/mirror-defaultArgument/test.scala b/tests/run-macros/mirror-defaultArgument/test.scala new file mode 100644 index 000000000000..da2d29b27b20 --- /dev/null +++ b/tests/run-macros/mirror-defaultArgument/test.scala @@ -0,0 +1,13 @@ +import scala.deriving._ +import scala.annotation.experimental +import scala.quoted._ + +import MirrorOps.* + +object Test extends App: + + case class WithDefault(x: Int, y: Int = 1) + assert(overridesDefaultArgument[WithDefault]) + + case class WithoutDefault(x: Int) + assert(!overridesDefaultArgument[WithoutDefault]) diff --git a/tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala b/tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala index 12ea8eb26c47..e939494108bc 100644 --- a/tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala +++ b/tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala @@ -96,6 +96,10 @@ val experimentalDefinitionInLibrary = Set( "scala.Tuple$.Reverse", // can be stabilized in 3.5 "scala.Tuple$.ReverseOnto", // can be stabilized in 3.5 "scala.runtime.Tuples$.reverse", // can be stabilized in 3.5 + + // New APIs: Mirror support for default arguments + "scala.deriving.Mirror$.Product.MirroredElemHasDefaults", + "scala.deriving.Mirror$.Product.defaultArgument", ) diff --git a/tests/run/mirror-defaultArgument.scala b/tests/run/mirror-defaultArgument.scala new file mode 100644 index 000000000000..eaff19094128 --- /dev/null +++ b/tests/run/mirror-defaultArgument.scala @@ -0,0 +1,53 @@ +import scala.deriving._ +import scala.annotation.experimental + +object Test extends App: + + case class WithDefault(x: Int, y: Int = 1) + val m = summon[Mirror.Of[WithDefault]] + assert(m.defaultArgument(1) == 1) + try + m.defaultArgument(0) + throw IllegalStateException("There should be no default argument") + catch + case ex: NoSuchElementException => assert(ex.getMessage == "0") // Ok + + + case class WithCompanion(s: String = "hello") + case object WithCompanion // => mirrors must be anonymous + + val m2 = summon[Mirror.Of[WithCompanion]] + assert(m2 ne WithCompanion) + assert(m2.defaultArgument(0) == "hello") + + + class Outer(val i: Int) { + + case class Inner(x: Int, y: Int = i + 1) + case object Inner + + val m3 = summon[Mirror.Of[Inner]] + assert(m3.defaultArgument(1) == i + 1) + + def localTest(d: Double): Unit = { + case class Local(x: Int = i, y: Double = d, z: Double = i + d) + case object Local + + val m4 = summon[Mirror.Of[Local]] + assert(m4.defaultArgument(0) == i) + assert(m4.defaultArgument(1) == d) + assert(m4.defaultArgument(2) == i + d) + } + + } + + val outer = Outer(3) + val m5 = summon[Mirror.Of[outer.Inner]] + assert(m5.defaultArgument(1) == 3 + 1) + outer.localTest(9d) + + + // new defaultArgument match tree should be able to unify different default value types + case class Foo[T](x: Int = 0, y: String = "hi") + +end Test diff --git a/tests/run/typeclass-derivation-defaultArgument.scala b/tests/run/typeclass-derivation-defaultArgument.scala new file mode 100644 index 000000000000..e2648c6cad89 --- /dev/null +++ b/tests/run/typeclass-derivation-defaultArgument.scala @@ -0,0 +1,101 @@ +import scala.deriving.Mirror as M +import scala.deriving.* +import scala.Tuple.* +import scala.compiletime.* +import scala.compiletime.ops.int.S + +trait Migration[-From, +To]: + def apply(x: From): To + +object Migration: + + extension [From](x: From) + def migrateTo[To](using m: Migration[From, To]): To = m(x) + + given[T]: Migration[T, T] with + override def apply(x: T): T = x + + type IndexOf[Elems <: Tuple, X] <: Int = Elems match { + case (X *: elems) => 0 + case (_ *: elems) => S[IndexOf[elems, X]] + case EmptyTuple => Nothing + } + + inline def migrateElem[F,T, ToIdx <: Int](from: M.ProductOf[F], to: M.ProductOf[T])(x: Product): Any = + + type Label = Elem[to.MirroredElemLabels, ToIdx] + type FromIdx = IndexOf[from.MirroredElemLabels, Label] + inline constValueOpt[FromIdx] match + + case Some(fromIdx) => + type FromType = Elem[from.MirroredElemTypes, FromIdx] + type ToType = Elem[to.MirroredElemTypes, ToIdx] + summonFrom { case _: Migration[FromType, ToType] => + x.productElement(fromIdx).asInstanceOf[FromType].migrateTo[ToType] + } + + case None => + type HasDefault = Elem[to.MirroredElemHasDefaults, ToIdx] + inline erasedValue[HasDefault] match + case _: true => to.defaultArgument(constValue[ToIdx]) + case _: false => compiletime.error("An element has no equivalent or default") + + + inline def migrateElems[F,T, ToIdx <: Int](from: M.ProductOf[F], to: M.ProductOf[T])(x: Product): Seq[Any] = + inline erasedValue[ToIdx] match + case _: Tuple.Size[to.MirroredElemLabels] => Seq() + case _ => migrateElem[F,T,ToIdx](from, to)(x) +: migrateElems[F,T,S[ToIdx]](from, to)(x) + + inline def migrateProduct[F,T](from: M.ProductOf[F], to: M.ProductOf[T]) + (x: Product): T = + val elems = migrateElems[F, T, 0](from, to)(x) + to.fromProduct(new Product: + def canEqual(that: Any): Boolean = false + def productArity: Int = elems.length + def productElement(n: Int): Any = elems(n) + ) + + inline def migration[F,T](using from: M.Of[F], to: M.Of[T]): Migration[F,T] = (x: F) => + inline from match + case fromP: M.ProductOf[F] => inline to match + case toP: M.ProductOf[T] => migrateProduct[F, T](fromP, toP)(x.asInstanceOf[Product]) + case _: M.SumOf[T] => compiletime.error("Cannot migrate sums") + case _: M.SumOf[F] => compiletime.error("Cannot migrate sums") + +end Migration + + +import Migration.* +object Test extends App: + + case class A1(x: Int) + case class A2(x: Int) + given Migration[A1, A2] = migration + assert(A1(2).migrateTo[A2] == A2(2)) + + case class B1(x: Int, y: String) + case class B2(y: String, x: Int) + given Migration[B1, B2] = migration + assert(B1(5, "hi").migrateTo[B2] == B2("hi", 5)) + + case class C1(x: A1) + case class C2(x: A2) + given Migration[C1, C2] = migration + assert(C1(A1(0)).migrateTo[C2] == C2(A2(0))) + + case class D1(x: Double) + case class D2(b: Boolean = true, x: Double) + given Migration[D1, D2] = migration + assert(D1(9).migrateTo[D2] == D2(true, 9)) + + case class E1(x: D1, y: D1) + case class E2(y: D2, s: String = "hi", x: D2) + given Migration[E1, E2] = migration + assert(E1(D1(1), D1(2)).migrateTo[E2] == E2(D2(true, 2), "hi", D2(true, 1))) + + // should only use default when needed + case class F1(x: Int) + case class F2(x: Int = 3) + given Migration[F1, F2] = migration + assert(F1(7).migrateTo[F2] == F2(7)) +