Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Three fixes to SAM type handling #21596

Merged
merged 3 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 44 additions & 19 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -315,24 +315,41 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
def TypeDef(sym: TypeSymbol)(using Context): TypeDef =
ta.assignType(untpd.TypeDef(sym.name, TypeTree(sym.info)), sym)

def ClassDef(cls: ClassSymbol, constr: DefDef, body: List[Tree], superArgs: List[Tree] = Nil)(using Context): TypeDef = {
/** Create a class definition
* @param cls the class symbol of the created class
* @param constr its primary constructor
* @param body the statements in its template
* @param superArgs the arguments to pass to the superclass constructor
* @param adaptVarargs if true, allow matching a vararg superclass constructor
* with a missing argument in superArgs, and synthesize an
* empty repeated parameter in the supercall in this case
*/
def ClassDef(cls: ClassSymbol, constr: DefDef, body: List[Tree],
superArgs: List[Tree] = Nil, adaptVarargs: Boolean = false)(using Context): TypeDef =
val firstParent :: otherParents = cls.info.parents: @unchecked

def adaptedSuperArgs(ctpe: Type): List[Tree] = ctpe match
case ctpe: PolyType =>
adaptedSuperArgs(ctpe.instantiate(firstParent.argTypes))
case ctpe: MethodType
if ctpe.paramInfos.length == superArgs.length + 1 =>
// last argument must be a vararg, otherwise isApplicable would have failed
superArgs :+
repeated(Nil, TypeTree(ctpe.paramInfos.last.argInfos.head, inferred = true))
case _ =>
superArgs

val superRef =
if (cls.is(Trait)) TypeTree(firstParent)
else {
def isApplicable(ctpe: Type): Boolean = ctpe match {
case ctpe: PolyType =>
isApplicable(ctpe.instantiate(firstParent.argTypes))
case ctpe: MethodType =>
(superArgs corresponds ctpe.paramInfos)(_.tpe <:< _)
case _ =>
false
}
val constr = firstParent.decl(nme.CONSTRUCTOR).suchThat(constr => isApplicable(constr.info))
New(firstParent, constr.symbol.asTerm, superArgs)
}
if cls.is(Trait) then TypeTree(firstParent)
else
val parentConstr = firstParent.applicableConstructors(superArgs.tpes, adaptVarargs) match
case Nil => assert(false, i"no applicable parent constructor of $firstParent for supercall arguments $superArgs")
case constr :: Nil => constr
case _ => assert(false, i"multiple applicable parent constructors of $firstParent for supercall arguments $superArgs")
New(firstParent, parentConstr.asTerm, adaptedSuperArgs(parentConstr.info))

ClassDefWithParents(cls, constr, superRef :: otherParents.map(TypeTree(_)), body)
}
end ClassDef

def ClassDefWithParents(cls: ClassSymbol, constr: DefDef, parents: List[Tree], body: List[Tree])(using Context): TypeDef = {
val selfType =
Expand All @@ -359,13 +376,18 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
* @param parents a non-empty list of class types
* @param termForwarders a non-empty list of forwarding definitions specified by their name and the definition they forward to.
* @param typeMembers a possibly-empty list of type members specified by their name and their right hand side.
* @param adaptVarargs if true, allow matching a vararg superclass constructor
* with a missing argument in superArgs, and synthesize an
* empty repeated parameter in the supercall in this case
*
* The class has the same owner as the first function in `termForwarders`.
* Its position is the union of all symbols in `termForwarders`.
*/
def AnonClass(parents: List[Type], termForwarders: List[(TermName, TermSymbol)],
typeMembers: List[(TypeName, TypeBounds)] = Nil)(using Context): Block = {
AnonClass(termForwarders.head._2.owner, parents, termForwarders.map(_._2.span).reduceLeft(_ union _)) { cls =>
def AnonClass(parents: List[Type],
termForwarders: List[(TermName, TermSymbol)],
typeMembers: List[(TypeName, TypeBounds)],
adaptVarargs: Boolean)(using Context): Block = {
AnonClass(termForwarders.head._2.owner, parents, termForwarders.map(_._2.span).reduceLeft(_ union _), adaptVarargs) { cls =>
def forwarder(name: TermName, fn: TermSymbol) = {
val fwdMeth = fn.copy(cls, name, Synthetic | Method | Final).entered.asTerm
for overridden <- fwdMeth.allOverriddenSymbols do
Expand All @@ -385,6 +407,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
* with the specified owner and position.
*/
def AnonClass(owner: Symbol, parents: List[Type], coord: Coord)(body: ClassSymbol => List[Tree])(using Context): Block =
AnonClass(owner, parents, coord, adaptVarargs = false)(body)

private def AnonClass(owner: Symbol, parents: List[Type], coord: Coord, adaptVarargs: Boolean)(body: ClassSymbol => List[Tree])(using Context): Block =
val parents1 =
if (parents.head.classSymbol.is(Trait)) {
val head = parents.head.parents.head
Expand All @@ -393,7 +418,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
else parents
val cls = newNormalizedClassSymbol(owner, tpnme.ANON_CLASS, Synthetic | Final, parents1, coord = coord)
val constr = newConstructor(cls, Synthetic, Nil, Nil).entered
val cdef = ClassDef(cls, DefDef(constr), body(cls))
val cdef = ClassDef(cls, DefDef(constr), body(cls), Nil, adaptVarargs)
Block(cdef :: Nil, New(cls.typeRef, Nil))

def Import(expr: Tree, selectors: List[untpd.ImportSelector])(using Context): Import =
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Phases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ object Phases {
def sbtExtractAPIPhase(using Context): Phase = ctx.base.sbtExtractAPIPhase
def picklerPhase(using Context): Phase = ctx.base.picklerPhase
def inliningPhase(using Context): Phase = ctx.base.inliningPhase
def stagingPhase(using Context): Phase = ctx.base.stagingPhase
def stagingPhase(using Context): Phase = ctx.base.stagingPhase
def splicingPhase(using Context): Phase = ctx.base.splicingPhase
def firstTransformPhase(using Context): Phase = ctx.base.firstTransformPhase
def refchecksPhase(using Context): Phase = ctx.base.refchecksPhase
Expand Down
26 changes: 26 additions & 0 deletions compiler/src/dotty/tools/dotc/core/TypeUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Names.{Name, TermName}
import Constants.Constant

import Names.Name
import StdNames.nme
import config.Feature

class TypeUtils:
Expand Down Expand Up @@ -189,5 +190,30 @@ class TypeUtils:
def stripRefinement: Type = self match
case self: RefinedOrRecType => self.parent.stripRefinement
case seld => self

/** The constructors of this tyoe that that are applicable to `argTypes`, without needing
KacperFKorban marked this conversation as resolved.
Show resolved Hide resolved
* an implicit conversion.
* @param adaptVarargs if true, allow a constructor with just a varargs argument to
* match an empty argument list.
*/
def applicableConstructors(argTypes: List[Type], adaptVarargs: Boolean)(using Context): List[Symbol] =
def isApplicable(constr: Symbol): Boolean =
def recur(ctpe: Type): Boolean = ctpe match
case ctpe: PolyType =>
if argTypes.isEmpty then recur(ctpe.resultType) // no need to know instances
else recur(ctpe.instantiate(self.argTypes))
case ctpe: MethodType =>
KacperFKorban marked this conversation as resolved.
Show resolved Hide resolved
var paramInfos = ctpe.paramInfos
if adaptVarargs && paramInfos.length == argTypes.length + 1
&& atPhaseNoLater(Phases.elimRepeatedPhase)(constr.info.isVarArgsMethod)
then // accept missing argument for varargs parameter
paramInfos = paramInfos.init
argTypes.corresponds(paramInfos)(_ <:< _)
case _ =>
false
recur(constr.info)

self.decl(nme.CONSTRUCTOR).altsWith(isApplicable).map(_.symbol)

end TypeUtils

21 changes: 11 additions & 10 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5944,17 +5944,18 @@ object Types extends TypeUtils {

def samClass(tp: Type)(using Context): Symbol = tp match
case tp: ClassInfo =>
def zeroParams(tp: Type): Boolean = tp.stripPoly match
case mt: MethodType => mt.paramInfos.isEmpty && !mt.resultType.isInstanceOf[MethodType]
case et: ExprType => true
case _ => false
val cls = tp.cls
val validCtor =
val ctor = cls.primaryConstructor
// `ContextFunctionN` does not have constructors
!ctor.exists || zeroParams(ctor.info)
val isInstantiable = !cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType)
if validCtor && isInstantiable then tp.cls
def takesNoArgs(tp: Type) =
!tp.classSymbol.primaryConstructor.exists
// e.g. `ContextFunctionN` does not have constructors
|| tp.applicableConstructors(Nil, adaptVarargs = true).lengthCompare(1) == 0
// we require a unique constructor so that SAM expansion is deterministic
val noArgsNeeded: Boolean =
takesNoArgs(tp)
&& (!tp.cls.is(Trait) || takesNoArgs(tp.parents.head))
def isInstantiable =
!tp.cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType)
if noArgsNeeded && isInstantiable then tp.cls
else NoSymbol
case tp: AppliedType =>
samClass(tp.superType)
Expand Down
10 changes: 6 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ class ExpandSAMs extends MiniPhase:
val tpe1 = collectAndStripRefinements(tpe)
val Seq(samDenot) = tpe1.possibleSamMethods
cpy.Block(tree)(stats,
AnonClass(List(tpe1),
List(samDenot.symbol.asTerm.name -> fn.symbol.asTerm),
refinements.toList
)
transformFollowingDeep:
AnonClass(List(tpe1),
List(samDenot.symbol.asTerm.name -> fn.symbol.asTerm),
refinements.toList,
adaptVarargs = true
)
)
}
case _ =>
Expand Down
10 changes: 10 additions & 0 deletions tests/neg/i15855.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// crash.scala
import scala.language.implicitConversions
KacperFKorban marked this conversation as resolved.
Show resolved Hide resolved

class MyFunction(args: String)

trait MyFunction0[+R] extends MyFunction {
def apply(): R
}

def fromFunction0[R](f: Function0[R]): MyFunction0[R] = () => f() // error
20 changes: 20 additions & 0 deletions tests/run/i15855.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// crash.scala
import scala.language.implicitConversions

class MyFunction(args: String*)

trait MyFunction0[+R] extends MyFunction {
def apply(): R
}

abstract class MyFunction1[R](args: R*):
def apply(): R

def fromFunction0[R](f: Function0[R]): MyFunction0[R] = () => f()
def fromFunction1[R](f: Function0[R]): MyFunction1[R] = () => f()

@main def Test =
val m0: MyFunction0[Int] = fromFunction0(() => 1)
val m1: MyFunction1[Int] = fromFunction1(() => 2)
assert(m0() == 1)
assert(m1() == 2)
Loading