diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 15e3c90d6f72..1f1dd8b5a57a 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -52,6 +52,11 @@ object desugar { */ val ContextBoundParam: Property.Key[Unit] = Property.StickyKey() + /** An attachment key to indicate that a DefDef is a poly function apply + * method definition. + */ + val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey() + /** What static check should be applied to a Match? */ enum MatchCheck { case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom @@ -337,7 +342,8 @@ object desugar { cpy.DefDef(meth)( name = normalizeName(meth, tpt).asTermName, paramss = paramssNoContextBounds), - evidenceParamBuf.toList) + evidenceParamBuf.toList + ) end elimContextBounds def addDefaultGetters(meth: DefDef)(using Context): Tree = @@ -508,7 +514,19 @@ object desugar { case Nil => params :: Nil - cpy.DefDef(meth)(paramss = recur(meth.paramss)) + if meth.hasAttachment(PolyFunctionApply) then + meth.removeAttachment(PolyFunctionApply) + val paramTpts = params.map(_.tpt) + val paramNames = params.map(_.name) + val paramsErased = params.map(_.mods.flags.is(Erased)) + if ctx.mode.is(Mode.Type) then + val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.tpt, paramsErased) + cpy.DefDef(meth)(tpt = ctxFunction) + else + val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.rhs, paramsErased) + cpy.DefDef(meth)(rhs = ctxFunction) + else + cpy.DefDef(meth)(paramss = recur(meth.paramss)) end addEvidenceParams /** The parameters generated from the contextual bounds of `meth`, as generated by `desugar.defDef` */ @@ -1209,38 +1227,6 @@ object desugar { case _ => body cpy.PolyFunction(tree)(tree.targs, stripped(tree.body)).asInstanceOf[PolyFunction] - /** Desugar [T_1 : B_1, ..., T_N : B_N] => (P_1, ..., P_M) => R - * Into [T_1, ..., T_N] => (P_1, ..., P_M) => (B_1, ..., B_N) ?=> R - */ - def expandPolyFunctionContextBounds(tree: PolyFunction)(using Context): PolyFunction = - val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ Function(vparamTypes, res)) = tree: @unchecked - val newTParams = tparams.mapConserve { - case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) => - cpy.TypeDef(td)(name, bounds) - case t => t - } - var idx = 0 - val collectedContextBounds = tparams.collect { - case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) if ctxBounds.nonEmpty => - name -> ctxBounds - }.flatMap { case (name, ctxBounds) => - ctxBounds.map { ctxBound => - val ContextBoundTypeTree(tycon, paramName, ownName) = ctxBound: @unchecked - if tree.isTerm then - ValDef(ownName, ctxBound, EmptyTree).withFlags(TermParam | Given) - else - ContextBoundTypeTree(tycon, paramName, EmptyTermName) // this has to be handled in Typer#typedFunctionType - } - } - val contextFunctionResult = - if collectedContextBounds.isEmpty then fun - else - val mods = EmptyModifiers.withFlags(Given) - val erasedParams = collectedContextBounds.map(_ => false) - Function(vparamTypes, FunctionWithMods(collectedContextBounds, res, mods, erasedParams)).withSpan(fun.span) - if collectedContextBounds.isEmpty then tree - else PolyFunction(newTParams, contextFunctionResult).withSpan(tree.span) - /** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R * Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R } */ @@ -1263,7 +1249,9 @@ object desugar { }.toList RefinedTypeTree(ref(defn.PolyFunctionType), List( - DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic) + DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree) + .withFlags(Synthetic) + .withAttachment(PolyFunctionApply, ()) )).withSpan(tree.span) end makePolyFunctionType diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index 6167db62fbe0..0849e57b8c7d 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -877,16 +877,16 @@ class Namer { typer: Typer => protected def addAnnotations(sym: Symbol): Unit = original match { case original: untpd.MemberDef => lazy val annotCtx = annotContext(original, sym) - original.setMods: + original.setMods: original.mods.withAnnotations : - original.mods.annotations.mapConserve: annotTree => + original.mods.annotations.mapConserve: annotTree => val cls = typedAheadAnnotationClass(annotTree)(using annotCtx) if (cls eq sym) report.error(em"An annotation class cannot be annotated with iself", annotTree.srcPos) annotTree else - val ann = - if cls.is(JavaDefined) then Checking.checkNamedArgumentForJavaAnnotation(annotTree, cls.asClass) + val ann = + if cls.is(JavaDefined) then Checking.checkNamedArgumentForJavaAnnotation(annotTree, cls.asClass) else annotTree val ann1 = Annotation.deferred(cls)(typedAheadExpr(ann)(using annotCtx)) sym.addAnnotation(ann1) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 4df9e85044b7..637984fe22ba 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -40,7 +40,7 @@ import annotation.tailrec import Implicits.* import util.Stats.record import config.Printers.{gadts, typr} -import config.Feature, Feature.{migrateTo3, sourceVersion, warnOnMigration} +import config.Feature, Feature.{migrateTo3, modularity, sourceVersion, warnOnMigration} import config.SourceVersion.* import rewrites.Rewrites, Rewrites.patch import staging.StagingLevel @@ -53,6 +53,7 @@ import config.MigrationVersion import transform.CheckUnused.OriginalName import scala.annotation.constructorOnly +import dotty.tools.dotc.ast.desugar.PolyFunctionApply object Typer { @@ -1142,7 +1143,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer if templ1.parents.isEmpty && isFullyDefined(pt, ForceDegree.flipBottom) && isSkolemFree(pt) - && isEligible(pt.underlyingClassRef(refinementOK = Feature.enabled(Feature.modularity))) + && isEligible(pt.underlyingClassRef(refinementOK = Feature.enabled(modularity))) then templ1 = cpy.Template(templ)(parents = untpd.TypeTree(pt) :: Nil) for case parent: RefTree <- templ1.parents do @@ -1717,11 +1718,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer typedFunctionType(desugar.makeFunctionWithValDefs(tree, pt), pt) else val funSym = defn.FunctionSymbol(numArgs, isContextual, isImpure) - val args1 = args.mapConserve { - case cb: untpd.ContextBoundTypeTree => typed(cb) - case t => t - } - val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args1 :+ body), pt) + // val args1 = args.mapConserve { + // case cb: untpd.ContextBoundTypeTree => typed(cb) + // case t => t + // } + val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args :+ body), pt) // if there are any erased classes, we need to re-do the typecheck. result match case r: AppliedTypeTree if r.args.exists(_.tpe.isErasedClass) => @@ -1930,10 +1931,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree = val tree1 = desugar.normalizePolyFunction(tree) - val tree2 = if Feature.enabled(Feature.modularity) then desugar.expandPolyFunctionContextBounds(tree1) - else tree1 - if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree2), pt) - else typedPolyFunctionValue(tree2, pt) + if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt) + else typedPolyFunctionValue(tree1, pt) def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree = val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked @@ -1958,7 +1957,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val resultTpt = untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) => mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef))) - val desugared = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span) + val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span) + defdef.putAttachment(PolyFunctionApply, ()) typed(desugared, pt) else val msg = @@ -1966,7 +1966,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer |Expected type should be a polymorphic function with the same number of type and value parameters.""" errorTree(EmptyTree, msg, tree.srcPos) case _ => - val desugared = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span) + val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span) + defdef.putAttachment(PolyFunctionApply, ()) typed(desugared, pt) end typedPolyFunctionValue @@ -2463,12 +2464,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer if tycon.tpe.typeParams.nonEmpty then val tycon0 = tycon.withType(tycon.tpe.etaCollapse) typed(untpd.AppliedTypeTree(spliced(tycon0), tparam :: Nil)) - else if Feature.enabled(Feature.modularity) && tycon.tpe.member(tpnme.Self).symbol.isAbstractOrParamType then + else if Feature.enabled(modularity) && tycon.tpe.member(tpnme.Self).symbol.isAbstractOrParamType then val tparamSplice = untpd.TypedSplice(typedExpr(tparam)) typed(untpd.RefinedTypeTree(spliced(tycon), List(untpd.TypeDef(tpnme.Self, tparamSplice)))) else def selfNote = - if Feature.enabled(Feature.modularity) then + if Feature.enabled(modularity) then " and\ndoes not have an abstract type member named `Self` either" else "" errorTree(tree, @@ -3602,6 +3603,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = { val defn.FunctionOf(formals, _, true) = pt.dropDependentRefinement: @unchecked + println(i"make contextual function $tree / $pt") val paramNamesOrNil = pt match case RefinedType(_, _, rinfo: MethodType) => rinfo.paramNames case _ => Nil @@ -4697,7 +4699,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer cpy.Ident(qual)(qual.symbol.name.sourceModuleName.toTypeName) case _ => errorTree(tree, em"cannot convert from $tree to an instance creation expression") - val tycon = ctorResultType.underlyingClassRef(refinementOK = Feature.enabled(Feature.modularity)) + val tycon = ctorResultType.underlyingClassRef(refinementOK = Feature.enabled(modularity)) typed( untpd.Select( untpd.New(untpd.TypedSplice(tpt.withType(tycon))), diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala index 7db41628e57d..a5a035754b08 100644 --- a/tests/pos/contextbounds-for-poly-functions.scala +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -7,10 +7,18 @@ trait Ord[X]: trait Show[X]: def show(x: X): String +val less0: [X: Ord] => (X, X) => Boolean = ??? + val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +val less1_type_test: [X: Ord] => (X, X) => Boolean = + [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + val less2 = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 +val less2_type_test: [X: Ord as ord] => (X, X) => Boolean = + [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 + type CtxFunctionRef = Ord[Int] ?=> Boolean type ComparerRef = [X] => (x: X, y: X) => Ord[X] ?=> Boolean type Comparer = [X: Ord] => (x: X, y: X) => Boolean @@ -20,12 +28,31 @@ val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 // type Comparer2 = [X: Ord] => Cmp[X] // val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +// type CmpWeak[X] = (x: X, y: X) => Boolean +// type Comparer2Weak = [X: Ord] => (x: X) => CmpWeak[X] +// val less4: Comparer2Weak = [X: Ord] => (x: X) => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +val less5_type_test: [X: [X] =>> Ord[X]] => (X, X) => Boolean = + [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + val less6 = [X: {Ord, Show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +val less6_type_test: [X: {Ord, Show}] => (X, X) => Boolean = + [X: {Ord, Show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + val less7 = [X: {Ord as ord, Show}] => (x: X, y: X) => ord.compare(x, y) < 0 +val less7_type_test: [X: {Ord as ord, Show}] => (X, X) => Boolean = + [X: {Ord as ord, Show}] => (x: X, y: X) => ord.compare(x, y) < 0 + val less8 = [X: {Ord, Show as show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +val less8_type_test: [X: {Ord, Show as show}] => (X, X) => Boolean = + [X: {Ord, Show as show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 + val less9 = [X: {Ord as ord, Show as show}] => (x: X, y: X) => ord.compare(x, y) < 0 + +val less9_type_test: [X: {Ord as ord, Show as show}] => (X, X) => Boolean = + [X: {Ord as ord, Show as show}] => (x: X, y: X) => ord.compare(x, y) < 0