Skip to content

Commit

Permalink
Change the implementation of context bound expansion for poly functio…
Browse files Browse the repository at this point in the history
…ns to reuse some of the existing context bound expansion
  • Loading branch information
KacperFKorban committed Oct 2, 2024
1 parent 8b72b1e commit 458fd29
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 55 deletions.
58 changes: 23 additions & 35 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ object desugar {
*/
val ContextBoundParam: Property.Key[Unit] = Property.StickyKey()

/** An attachment key to indicate that a DefDef is a poly function apply
* method definition.
*/
val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey()

/** What static check should be applied to a Match? */
enum MatchCheck {
case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom
Expand Down Expand Up @@ -337,7 +342,8 @@ object desugar {
cpy.DefDef(meth)(
name = normalizeName(meth, tpt).asTermName,
paramss = paramssNoContextBounds),
evidenceParamBuf.toList)
evidenceParamBuf.toList
)
end elimContextBounds

def addDefaultGetters(meth: DefDef)(using Context): Tree =
Expand Down Expand Up @@ -508,7 +514,19 @@ object desugar {
case Nil =>
params :: Nil

cpy.DefDef(meth)(paramss = recur(meth.paramss))
if meth.hasAttachment(PolyFunctionApply) then
meth.removeAttachment(PolyFunctionApply)
val paramTpts = params.map(_.tpt)
val paramNames = params.map(_.name)
val paramsErased = params.map(_.mods.flags.is(Erased))
if ctx.mode.is(Mode.Type) then
val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.tpt, paramsErased)
cpy.DefDef(meth)(tpt = ctxFunction)
else
val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.rhs, paramsErased)
cpy.DefDef(meth)(rhs = ctxFunction)
else
cpy.DefDef(meth)(paramss = recur(meth.paramss))
end addEvidenceParams

/** The parameters generated from the contextual bounds of `meth`, as generated by `desugar.defDef` */
Expand Down Expand Up @@ -1209,38 +1227,6 @@ object desugar {
case _ => body
cpy.PolyFunction(tree)(tree.targs, stripped(tree.body)).asInstanceOf[PolyFunction]

/** Desugar [T_1 : B_1, ..., T_N : B_N] => (P_1, ..., P_M) => R
* Into [T_1, ..., T_N] => (P_1, ..., P_M) => (B_1, ..., B_N) ?=> R
*/
def expandPolyFunctionContextBounds(tree: PolyFunction)(using Context): PolyFunction =
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ Function(vparamTypes, res)) = tree: @unchecked
val newTParams = tparams.mapConserve {
case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) =>
cpy.TypeDef(td)(name, bounds)
case t => t
}
var idx = 0
val collectedContextBounds = tparams.collect {
case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) if ctxBounds.nonEmpty =>
name -> ctxBounds
}.flatMap { case (name, ctxBounds) =>
ctxBounds.map { ctxBound =>
val ContextBoundTypeTree(tycon, paramName, ownName) = ctxBound: @unchecked
if tree.isTerm then
ValDef(ownName, ctxBound, EmptyTree).withFlags(TermParam | Given)
else
ContextBoundTypeTree(tycon, paramName, EmptyTermName) // this has to be handled in Typer#typedFunctionType
}
}
val contextFunctionResult =
if collectedContextBounds.isEmpty then fun
else
val mods = EmptyModifiers.withFlags(Given)
val erasedParams = collectedContextBounds.map(_ => false)
Function(vparamTypes, FunctionWithMods(collectedContextBounds, res, mods, erasedParams)).withSpan(fun.span)
if collectedContextBounds.isEmpty then tree
else PolyFunction(newTParams, contextFunctionResult).withSpan(tree.span)

/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
*/
Expand All @@ -1263,7 +1249,9 @@ object desugar {
}.toList

RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
.withFlags(Synthetic)
.withAttachment(PolyFunctionApply, ())
)).withSpan(tree.span)
end makePolyFunctionType

Expand Down
8 changes: 4 additions & 4 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -877,16 +877,16 @@ class Namer { typer: Typer =>
protected def addAnnotations(sym: Symbol): Unit = original match {
case original: untpd.MemberDef =>
lazy val annotCtx = annotContext(original, sym)
original.setMods:
original.setMods:
original.mods.withAnnotations :
original.mods.annotations.mapConserve: annotTree =>
original.mods.annotations.mapConserve: annotTree =>
val cls = typedAheadAnnotationClass(annotTree)(using annotCtx)
if (cls eq sym)
report.error(em"An annotation class cannot be annotated with iself", annotTree.srcPos)
annotTree
else
val ann =
if cls.is(JavaDefined) then Checking.checkNamedArgumentForJavaAnnotation(annotTree, cls.asClass)
val ann =
if cls.is(JavaDefined) then Checking.checkNamedArgumentForJavaAnnotation(annotTree, cls.asClass)
else annotTree
val ann1 = Annotation.deferred(cls)(typedAheadExpr(ann)(using annotCtx))
sym.addAnnotation(ann1)
Expand Down
34 changes: 18 additions & 16 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import annotation.tailrec
import Implicits.*
import util.Stats.record
import config.Printers.{gadts, typr}
import config.Feature, Feature.{migrateTo3, sourceVersion, warnOnMigration}
import config.Feature, Feature.{migrateTo3, modularity, sourceVersion, warnOnMigration}
import config.SourceVersion.*
import rewrites.Rewrites, Rewrites.patch
import staging.StagingLevel
Expand All @@ -53,6 +53,7 @@ import config.MigrationVersion
import transform.CheckUnused.OriginalName

import scala.annotation.constructorOnly
import dotty.tools.dotc.ast.desugar.PolyFunctionApply

object Typer {

Expand Down Expand Up @@ -1142,7 +1143,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
if templ1.parents.isEmpty
&& isFullyDefined(pt, ForceDegree.flipBottom)
&& isSkolemFree(pt)
&& isEligible(pt.underlyingClassRef(refinementOK = Feature.enabled(Feature.modularity)))
&& isEligible(pt.underlyingClassRef(refinementOK = Feature.enabled(modularity)))
then
templ1 = cpy.Template(templ)(parents = untpd.TypeTree(pt) :: Nil)
for case parent: RefTree <- templ1.parents do
Expand Down Expand Up @@ -1717,11 +1718,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
typedFunctionType(desugar.makeFunctionWithValDefs(tree, pt), pt)
else
val funSym = defn.FunctionSymbol(numArgs, isContextual, isImpure)
val args1 = args.mapConserve {
case cb: untpd.ContextBoundTypeTree => typed(cb)
case t => t
}
val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args1 :+ body), pt)
// val args1 = args.mapConserve {
// case cb: untpd.ContextBoundTypeTree => typed(cb)
// case t => t
// }
val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args :+ body), pt)
// if there are any erased classes, we need to re-do the typecheck.
result match
case r: AppliedTypeTree if r.args.exists(_.tpe.isErasedClass) =>
Expand Down Expand Up @@ -1930,10 +1931,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
val tree1 = desugar.normalizePolyFunction(tree)
val tree2 = if Feature.enabled(Feature.modularity) then desugar.expandPolyFunctionContextBounds(tree1)
else tree1
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree2), pt)
else typedPolyFunctionValue(tree2, pt)
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt)
else typedPolyFunctionValue(tree1, pt)

def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
Expand All @@ -1958,15 +1957,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val resultTpt =
untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) =>
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
val desugared = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span)
val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span)
defdef.putAttachment(PolyFunctionApply, ())
typed(desugared, pt)
else
val msg =
em"""|Provided polymorphic function value doesn't match the expected type $dpt.
|Expected type should be a polymorphic function with the same number of type and value parameters."""
errorTree(EmptyTree, msg, tree.srcPos)
case _ =>
val desugared = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span)
val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span)
defdef.putAttachment(PolyFunctionApply, ())
typed(desugared, pt)
end typedPolyFunctionValue

Expand Down Expand Up @@ -2463,12 +2464,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
if tycon.tpe.typeParams.nonEmpty then
val tycon0 = tycon.withType(tycon.tpe.etaCollapse)
typed(untpd.AppliedTypeTree(spliced(tycon0), tparam :: Nil))
else if Feature.enabled(Feature.modularity) && tycon.tpe.member(tpnme.Self).symbol.isAbstractOrParamType then
else if Feature.enabled(modularity) && tycon.tpe.member(tpnme.Self).symbol.isAbstractOrParamType then
val tparamSplice = untpd.TypedSplice(typedExpr(tparam))
typed(untpd.RefinedTypeTree(spliced(tycon), List(untpd.TypeDef(tpnme.Self, tparamSplice))))
else
def selfNote =
if Feature.enabled(Feature.modularity) then
if Feature.enabled(modularity) then
" and\ndoes not have an abstract type member named `Self` either"
else ""
errorTree(tree,
Expand Down Expand Up @@ -3602,6 +3603,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = {
val defn.FunctionOf(formals, _, true) = pt.dropDependentRefinement: @unchecked
println(i"make contextual function $tree / $pt")
val paramNamesOrNil = pt match
case RefinedType(_, _, rinfo: MethodType) => rinfo.paramNames
case _ => Nil
Expand Down Expand Up @@ -4697,7 +4699,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
cpy.Ident(qual)(qual.symbol.name.sourceModuleName.toTypeName)
case _ =>
errorTree(tree, em"cannot convert from $tree to an instance creation expression")
val tycon = ctorResultType.underlyingClassRef(refinementOK = Feature.enabled(Feature.modularity))
val tycon = ctorResultType.underlyingClassRef(refinementOK = Feature.enabled(modularity))
typed(
untpd.Select(
untpd.New(untpd.TypedSplice(tpt.withType(tycon))),
Expand Down
27 changes: 27 additions & 0 deletions tests/pos/contextbounds-for-poly-functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,18 @@ trait Ord[X]:
trait Show[X]:
def show(x: X): String

val less0: [X: Ord] => (X, X) => Boolean = ???

val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

val less1_type_test: [X: Ord] => (X, X) => Boolean =
[X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

val less2 = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0

val less2_type_test: [X: Ord as ord] => (X, X) => Boolean =
[X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0

type CtxFunctionRef = Ord[Int] ?=> Boolean
type ComparerRef = [X] => (x: X, y: X) => Ord[X] ?=> Boolean
type Comparer = [X: Ord] => (x: X, y: X) => Boolean
Expand All @@ -20,12 +28,31 @@ val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0
// type Comparer2 = [X: Ord] => Cmp[X]
// val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

// type CmpWeak[X] = (x: X, y: X) => Boolean
// type Comparer2Weak = [X: Ord] => (x: X) => CmpWeak[X]
// val less4: Comparer2Weak = [X: Ord] => (x: X) => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

val less5_type_test: [X: [X] =>> Ord[X]] => (X, X) => Boolean =
[X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

val less6 = [X: {Ord, Show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

val less6_type_test: [X: {Ord, Show}] => (X, X) => Boolean =
[X: {Ord, Show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

val less7 = [X: {Ord as ord, Show}] => (x: X, y: X) => ord.compare(x, y) < 0

val less7_type_test: [X: {Ord as ord, Show}] => (X, X) => Boolean =
[X: {Ord as ord, Show}] => (x: X, y: X) => ord.compare(x, y) < 0

val less8 = [X: {Ord, Show as show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

val less8_type_test: [X: {Ord, Show as show}] => (X, X) => Boolean =
[X: {Ord, Show as show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

val less9 = [X: {Ord as ord, Show as show}] => (x: X, y: X) => ord.compare(x, y) < 0

val less9_type_test: [X: {Ord as ord, Show as show}] => (X, X) => Boolean =
[X: {Ord as ord, Show as show}] => (x: X, y: X) => ord.compare(x, y) < 0

0 comments on commit 458fd29

Please sign in to comment.