diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 1f1dd8b5a57a..6edf3846dfb3 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -55,7 +55,7 @@ object desugar { /** An attachment key to indicate that a DefDef is a poly function apply * method definition. */ - val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey() + val PolyFunctionApply: Property.Key[List[ValDef]] = Property.StickyKey() /** What static check should be applied to a Match? */ enum MatchCheck { @@ -514,17 +514,25 @@ object desugar { case Nil => params :: Nil + // TODO(kπ) is this enough? SHould this be a TreeTraverse-thing? + def pushDownEvidenceParams(tree: Tree): Tree = tree match + case Function(params, body) => + cpy.Function(tree)(params, pushDownEvidenceParams(body)) + case Block(stats, expr) => + cpy.Block(tree)(stats, pushDownEvidenceParams(expr)) + case tree => + val paramTpts = params.map(_.tpt) + val paramNames = params.map(_.name) + val paramsErased = params.map(_.mods.flags.is(Erased)) + makeContextualFunction(paramTpts, paramNames, tree, paramsErased).withSpan(tree.span) + if meth.hasAttachment(PolyFunctionApply) then meth.removeAttachment(PolyFunctionApply) - val paramTpts = params.map(_.tpt) - val paramNames = params.map(_.name) - val paramsErased = params.map(_.mods.flags.is(Erased)) + // (kπ): deffer this until we can type the result? if ctx.mode.is(Mode.Type) then - val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.tpt, paramsErased) - cpy.DefDef(meth)(tpt = ctxFunction) + cpy.DefDef(meth)(tpt = meth.tpt.withAttachment(PolyFunctionApply, params)) else - val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.rhs, paramsErased) - cpy.DefDef(meth)(rhs = ctxFunction) + cpy.DefDef(meth)(rhs = pushDownEvidenceParams(meth.rhs)) else cpy.DefDef(meth)(paramss = recur(meth.paramss)) end addEvidenceParams @@ -1251,7 +1259,7 @@ object desugar { RefinedTypeTree(ref(defn.PolyFunctionType), List( DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree) .withFlags(Synthetic) - .withAttachment(PolyFunctionApply, ()) + .withAttachment(PolyFunctionApply, List.empty) )).withSpan(tree.span) end makePolyFunctionType diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 637984fe22ba..24d9c7d591e1 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -53,7 +53,6 @@ import config.MigrationVersion import transform.CheckUnused.OriginalName import scala.annotation.constructorOnly -import dotty.tools.dotc.ast.desugar.PolyFunctionApply object Typer { @@ -1958,7 +1957,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) => mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef))) val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span) - defdef.putAttachment(PolyFunctionApply, ()) + defdef.putAttachment(desugar.PolyFunctionApply, List.empty) typed(desugared, pt) else val msg = @@ -1967,7 +1966,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer errorTree(EmptyTree, msg, tree.srcPos) case _ => val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span) - defdef.putAttachment(PolyFunctionApply, ()) + defdef.putAttachment(desugar.PolyFunctionApply, List.empty) typed(desugared, pt) end typedPolyFunctionValue @@ -3580,30 +3579,57 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case xtree => typedUnnamed(xtree) val unsimplifiedType = result.tpe - simplify(result, pt, locked) - result.tpe.stripTypeVar match + val result1 = simplify(result, pt, locked) + result1.tpe.stripTypeVar match case e: ErrorType if !unsimplifiedType.isErroneous => errorTree(xtree, e.msg, xtree.srcPos) - case _ => result + case _ => result1 catch case ex: TypeError => handleTypeError(ex) } } + private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = tpe.dealias match { + case tpe: MethodType => + MethodType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span)) + case tpe: PolyType => + PolyType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span)) + case tpe: RefinedType => + // TODO(kπ): Doesn't seem right, but the PolyFunction ends up being a refinement + RefinedType(pushDownDeferredEvidenceParams(tpe.parent, params, span), tpe.refinedName, pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span)) + case tpe @ AppliedType(tycon, args) if defn.isFunctionType(tpe) && args.size > 1 => + AppliedType(tpe.tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span)) + case tpe => + val paramNames = params.map(_.name) + val paramTpts = params.map(_.tpt) + val paramsErased = params.map(_.mods.flags.is(Erased)) + val ctxFunction = desugar.makeContextualFunction(paramTpts, paramNames, untpd.TypedSplice(TypeTree(tpe.dealias)), paramsErased).withSpan(span) + typed(ctxFunction).tpe + } + + private def addDownDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = { + tree.getAttachment(desugar.PolyFunctionApply) match + case Some(params) if params.nonEmpty => + tree.removeAttachment(desugar.PolyFunctionApply) + val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span) + TypeTree(tpe).withSpan(tree.span) -> tpe + case _ => tree -> pt + } + /** Interpolate and simplify the type of the given tree. */ - protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): tree.type = - if !tree.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying - if !tree.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied - || tree.isDef // ... unless tree is a definition + protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = + val (tree1, pt1) = addDownDeferredEvidenceParams(tree, pt) + if !tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying + if !tree1.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied + || tree1.isDef // ... unless tree is a definition then - interpolateTypeVars(tree, pt, locked) - val simplified = tree.tpe.simplified - if !MatchType.thatReducesUsingGadt(tree.tpe) then // needs a GADT cast. i15743 + interpolateTypeVars(tree1, pt1, locked) + val simplified = tree1.tpe.simplified + if !MatchType.thatReducesUsingGadt(tree1.tpe) then // needs a GADT cast. i15743 tree.overwriteType(simplified) - tree + tree1 protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = { val defn.FunctionOf(formals, _, true) = pt.dropDependentRefinement: @unchecked - println(i"make contextual function $tree / $pt") val paramNamesOrNil = pt match case RefinedType(_, _, rinfo: MethodType) => rinfo.paramNames case _ => Nil diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala index a5a035754b08..adaf6c035406 100644 --- a/tests/pos/contextbounds-for-poly-functions.scala +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -28,9 +28,12 @@ val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 // type Comparer2 = [X: Ord] => Cmp[X] // val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 -// type CmpWeak[X] = (x: X, y: X) => Boolean -// type Comparer2Weak = [X: Ord] => (x: X) => CmpWeak[X] -// val less4: Comparer2Weak = [X: Ord] => (x: X) => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +type CmpWeak[X] = X => Boolean +type Comparer2Weak = [X: Ord] => X => CmpWeak[X] +val less4_0: [X: Ord] => X => X => Boolean = + [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 +val less4: Comparer2Weak = + [X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0 val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0