Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type inference for functions like fold #18780

Merged
merged 4 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Flags.*
import config.Config
import config.Printers.typr
import typer.ProtoTypes.{newTypeVar, representedParamRef}
import transform.TypeUtils.isTransparent
import UnificationDirection.*
import NameKinds.AvoidNameKind
import util.SimpleIdentitySet
Expand Down Expand Up @@ -566,13 +567,6 @@ trait ConstraintHandling {
inst
end approximation

private def isTransparent(tp: Type, traitOnly: Boolean)(using Context): Boolean = tp match
case AndType(tp1, tp2) =>
isTransparent(tp1, traitOnly) && isTransparent(tp2, traitOnly)
case _ =>
val cls = tp.underlyingClassRef(refinementOK = false).typeSymbol
cls.isTransparentClass && (!traitOnly || cls.is(Trait))

/** If `tp` is an intersection such that some operands are transparent trait instances
* and others are not, replace as many transparent trait instances as possible with Any
* as long as the result is still a subtype of `bound`. But fall back to the
Expand All @@ -585,7 +579,7 @@ trait ConstraintHandling {
var dropped: List[Type] = List() // the types dropped so far, last one on top

def dropOneTransparentTrait(tp: Type): Type =
if isTransparent(tp, traitOnly = true) && !kept.contains(tp) then
if tp.isTransparent(traitOnly = true) && !kept.contains(tp) then
dropped = tp :: dropped
defn.AnyType
else tp match
Expand Down Expand Up @@ -658,7 +652,7 @@ trait ConstraintHandling {
def widenOr(tp: Type) =
if widenUnions then
val tpw = tp.widenUnion
if (tpw ne tp) && !isTransparent(tpw, traitOnly = false) && (tpw <:< bound) then tpw else tp
if (tpw ne tp) && !tpw.isTransparent() && (tpw <:< bound) then tpw else tp
else tp.hardenUnions

def widenSingle(tp: Type) =
Expand Down
14 changes: 10 additions & 4 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4908,6 +4908,9 @@ object Types {
tp
}

def typeToInstantiateWith(fromBelow: Boolean)(using Context): Type =
TypeComparer.instanceType(origin, fromBelow, widenUnions, nestingLevel)

/** Instantiate variable from the constraints over its `origin`.
* If `fromBelow` is true, the variable is instantiated to the lub
* of its lower bounds in the current constraint; otherwise it is
Expand All @@ -4916,8 +4919,9 @@ object Types {
* is also a singleton type.
*/
def instantiate(fromBelow: Boolean)(using Context): Type =
val tp = TypeComparer.instanceType(origin, fromBelow, widenUnions, nestingLevel)
val tp = typeToInstantiateWith(fromBelow)
if myInst.exists then // The line above might have triggered instantiation of the current type variable
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check used to be necessary so I would prefer to keep it now that the assert in instantiateWith has been added back, unless we have some strong reason to believe it no longer is.

Member
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stray line?

odersky marked this conversation as resolved.
Show resolved Hide resolved
myInst
else
instantiateWith(tp)
Expand Down Expand Up @@ -5812,11 +5816,13 @@ object Types {
protected def derivedLambdaType(tp: LambdaType)(formals: List[tp.PInfo], restpe: Type): Type =
tp.derivedLambdaType(tp.paramNames, formals, restpe)

protected def mapArg(arg: Type, tparam: ParamInfo): Type = arg match
case arg: TypeBounds => this(arg)
case arg => atVariance(variance * tparam.paramVarianceSign)(this(arg))

protected def mapArgs(args: List[Type], tparams: List[ParamInfo]): List[Type] = args match
case arg :: otherArgs if tparams.nonEmpty =>
val arg1 = arg match
case arg: TypeBounds => this(arg)
case arg => atVariance(variance * tparams.head.paramVarianceSign)(this(arg))
val arg1 = mapArg(arg, tparams.head)
val otherArgs1 = mapArgs(otherArgs, tparams.tail)
if ((arg1 eq arg) && (otherArgs1 eq otherArgs)) args
else arg1 :: otherArgs1
Expand Down
16 changes: 11 additions & 5 deletions compiler/src/dotty/tools/dotc/transform/TypeUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,9 @@ package transform

import core.*
import TypeErasure.ErasedValueType
import Types.*
import Contexts.*
import Symbols.*
import Types.*, Contexts.*, Symbols.*, Flags.*, Decorators.*
import Names.Name

import dotty.tools.dotc.core.Decorators.*

object TypeUtils {
/** A decorator that provides methods on types
* that are needed in the transformer pipeline.
Expand Down Expand Up @@ -98,5 +94,15 @@ object TypeUtils {
def takesImplicitParams(using Context): Boolean = self.stripPoly match
case mt: MethodType => mt.isImplicitMethod || mt.resType.takesImplicitParams
case _ => false

/** Is this a type deriving only from transparent classes?
* @param traitOnly if true, all class symbols must be transparent traits
*/
def isTransparent(traitOnly: Boolean = false)(using Context): Boolean = self match
case AndType(tp1, tp2) =>
tp1.isTransparent(traitOnly) && tp2.isTransparent(traitOnly)
case _ =>
val cls = self.underlyingClassRef(refinementOK = false).typeSymbol
cls.isTransparentClass && (!traitOnly || cls.is(Trait))
}
}
111 changes: 91 additions & 20 deletions compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import ProtoTypes.*
import NameKinds.UniqueName
import util.Spans.*
import util.{Stats, SimpleIdentityMap, SimpleIdentitySet, SrcPos}
import Decorators.*
import transform.TypeUtils.isTransparent
import Decorators._
import config.Printers.{gadts, typr}
import annotation.tailrec
import reporting.*
Expand Down Expand Up @@ -60,7 +61,9 @@ object Inferencing {
def instantiateSelected(tp: Type, tvars: List[Type])(using Context): Unit =
if (tvars.nonEmpty)
IsFullyDefinedAccumulator(
ForceDegree.Value(tvars.contains, IfBottom.flip), minimizeSelected = true
new ForceDegree.Value(IfBottom.flip):
override def appliesTo(tvar: TypeVar) = tvars.contains(tvar),
minimizeSelected = true
).process(tp)

/** Instantiate any type variables in `tp` whose bounds contain a reference to
Expand Down Expand Up @@ -154,15 +157,66 @@ object Inferencing {
* their lower bound. Record whether successful.
* 2nd Phase: If first phase was successful, instantiate all remaining type variables
* to their upper bound.
*
* Instance types can be improved by replacing covariant occurrences of Nothing
* with fresh type variables, if `force` allows this in its `canImprove` implementation.
*/
private class IsFullyDefinedAccumulator(force: ForceDegree.Value, minimizeSelected: Boolean = false)
(using Context) extends TypeAccumulator[Boolean] {

private def instantiate(tvar: TypeVar, fromBelow: Boolean): Type = {
/** Replace toplevel-covariant occurrences (i.e. covariant without double flips)
smarter marked this conversation as resolved.
Show resolved Hide resolved
* of Nothing by fresh type variables. Double-flips are not covered to be
* conservative and save a bit of time on traversals; we could probably
* generalize that if we see use cases.
* For singleton types and references to module classes: try to
* improve the widened type. For module classes, the widened type
* is the intersection of all its non-transparent parent types.
*/
private def improve(tvar: TypeVar) = new TypeMap:
def apply(t: Type) = trace(i"improve $t", show = true):
def tryWidened(widened: Type): Type =
val improved = apply(widened)
if improved ne widened then improved else mapOver(t)
if variance > 0 then
t match
case t: TypeRef =>
if t.symbol == defn.NothingClass then
newTypeVar(TypeBounds.empty, nestingLevel = tvar.nestingLevel)
smarter marked this conversation as resolved.
Show resolved Hide resolved
else if t.symbol.is(ModuleClass) then
tryWidened(t.parents.filter(!_.isTransparent())
.foldLeft(defn.AnyType: Type)(TypeComparer.andType(_, _)))
else
mapOver(t)
case t: TermRef =>
tryWidened(t.widen)
case _ =>
mapOver(t)
else t

// Don't map Nothing arguments for higher-kinded types; we'd get the wrong kind */
override def mapArg(arg: Type, tparam: ParamInfo): Type =
if tparam.paramInfo.isLambdaSub then arg
else super.mapArg(arg, tparam)
end improve

/** Instantiate type variable with possibly improved computed instance type.
* @return true if variable was instantiated with improved type, which
* in this case should not be instantiated further, false otherwise.
*/
private def instantiate(tvar: TypeVar, fromBelow: Boolean): Boolean =
if fromBelow && force.canImprove(tvar) then
val inst = tvar.typeToInstantiateWith(fromBelow = true)
if apply(true, inst) then
// need to recursively check before improving, since improving adds type vars
// which should not be instantiated at this point
val better = improve(tvar)(inst)
if better <:< TypeComparer.fullUpperBound(tvar.origin) then
typr.println(i"forced instantiation of invariant ${tvar.origin} = $inst, improved to $better")
tvar.instantiateWith(better)
return true
val inst = tvar.instantiate(fromBelow)
typr.println(i"forced instantiation of ${tvar.origin} = $inst")
inst
}
false

private var toMaximize: List[TypeVar] = Nil

Expand All @@ -178,26 +232,27 @@ object Inferencing {
&& ctx.typerState.constraint.contains(tvar)
&& {
var fail = false
var skip = false
val direction = instDirection(tvar.origin)
if minimizeSelected then
if direction <= 0 && tvar.hasLowerBound then
instantiate(tvar, fromBelow = true)
skip = instantiate(tvar, fromBelow = true)
else if direction >= 0 && tvar.hasUpperBound then
instantiate(tvar, fromBelow = false)
skip = instantiate(tvar, fromBelow = false)
// else hold off instantiating unbounded unconstrained variable
else if direction != 0 then
instantiate(tvar, fromBelow = direction < 0)
skip = instantiate(tvar, fromBelow = direction < 0)
else if variance >= 0 && tvar.hasLowerBound then
instantiate(tvar, fromBelow = true)
skip = instantiate(tvar, fromBelow = true)
else if (variance > 0 || variance == 0 && !tvar.hasUpperBound)
&& force.ifBottom == IfBottom.ok
then // if variance == 0, prefer upper bound if one is given
instantiate(tvar, fromBelow = true)
skip = instantiate(tvar, fromBelow = true)
else if variance >= 0 && force.ifBottom == IfBottom.fail then
fail = true
else
toMaximize = tvar :: toMaximize
!fail && foldOver(x, tvar)
!fail && (skip || foldOver(x, tvar))
}
case tp => foldOver(x, tp)
}
Expand Down Expand Up @@ -467,7 +522,7 @@ object Inferencing {
*
* we want to instantiate U to x.type right away. No need to wait further.
*/
private def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap[TypeVar] = {
def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap[TypeVar] = {
Stats.record("variances")
val constraint = ctx.typerState.constraint

Expand Down Expand Up @@ -769,14 +824,30 @@ trait Inferencing { this: Typer =>
}

/** An enumeration controlling the degree of forcing in "is-fully-defined" checks. */
@sharable object ForceDegree {
class Value(val appliesTo: TypeVar => Boolean, val ifBottom: IfBottom):
override def toString = s"ForceDegree.Value(.., $ifBottom)"
val none: Value = new Value(_ => false, IfBottom.ok) { override def toString = "ForceDegree.none" }
val all: Value = new Value(_ => true, IfBottom.ok) { override def toString = "ForceDegree.all" }
val failBottom: Value = new Value(_ => true, IfBottom.fail) { override def toString = "ForceDegree.failBottom" }
val flipBottom: Value = new Value(_ => true, IfBottom.flip) { override def toString = "ForceDegree.flipBottom" }
}
@sharable object ForceDegree:
class Value(val ifBottom: IfBottom):

/** Does `tv` need to be instantiated? */
def appliesTo(tv: TypeVar): Boolean = true

/** Should we try to improve the computed instance type by replacing bottom types
* with fresh type variables?
*/
def canImprove(tv: TypeVar): Boolean = false

override def toString = s"ForceDegree.Value($ifBottom)"
end Value

val none: Value = new Value(IfBottom.ok):
override def appliesTo(tv: TypeVar) = false
override def toString = "ForceDegree.none"
val all: Value = new Value(IfBottom.ok):
override def toString = "ForceDegree.all"
val failBottom: Value = new Value(IfBottom.fail):
override def toString = "ForceDegree.failBottom"
val flipBottom: Value = new Value(IfBottom.flip):
override def toString = "ForceDegree.flipBottom"
end ForceDegree

enum IfBottom:
case ok, fail, flip
13 changes: 12 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1622,14 +1622,25 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
case _ =>

if desugared.isEmpty then
val forceDegree =
if pt.isValueType then
// Allow variables that appear invariantly in `pt` to be improved by mapping
// bottom types in their instance types to fresh type variables
new ForceDegree.Value(IfBottom.fail):
val tvmap = variances(pt)
override def canImprove(tvar: TypeVar) =
tvmap.computedVariance(tvar) == (0: Integer)
else
ForceDegree.failBottom

val inferredParams: List[untpd.ValDef] =
for ((param, i) <- params.zipWithIndex) yield
if (!param.tpt.isEmpty) param
else
val (formalBounds, isErased) = protoFormal(i)
val formal = formalBounds.loBound
val isBottomFromWildcard = (formalBounds ne formal) && formal.isExactlyNothing
val knownFormal = isFullyDefined(formal, ForceDegree.failBottom)
val knownFormal = isFullyDefined(formal, forceDegree)
// If the expected formal is a TypeBounds wildcard argument with Nothing as lower bound,
// try to prioritize inferring from target. See issue 16405 (tests/run/16405.scala)
val paramType =
Expand Down
7 changes: 7 additions & 0 deletions tests/neg/foldinf-ill-kinded.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- [E007] Type Mismatch Error: tests/neg/foldinf-ill-kinded.scala:9:16 -------------------------------------------------
9 | ys.combine(x) // error
| ^^^^^^^^^^^^^
| Found: Foo[List]
| Required: Foo[Nothing]
|
| longer explanation available when compiling with `-explain`
10 changes: 10 additions & 0 deletions tests/neg/foldinf-ill-kinded.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class Foo[+T[_]]:
def combine[T1[x] >: T[x]](x: T1[Int]): Foo[T1] = new Foo
object Foo:
def empty: Foo[Nothing] = new Foo

object X:
def test(xs: List[List[Int]]): Unit =
xs.foldLeft(Foo.empty)((ys, x) =>
ys.combine(x) // error
)
34 changes: 34 additions & 0 deletions tests/pos/folds.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

object Test:
extension [A](xs: List[A])
def foldl[B](acc: B)(f: (A, B) => B): B = ???

val xs = List(1, 2, 3)

val _ = xs.foldl(List())((y, ys) => y :: ys)

val _ = xs.foldl(Nil)((y, ys) => y :: ys)

def partition[a](xs: List[a], pred: a => Boolean): Tuple2[List[a], List[a]] = {
xs.foldRight/*[Tuple2[List[a], List[a]]]*/((List(), List())) {
(x, p) => if (pred (x)) (x :: p._1, p._2) else (p._1, x :: p._2)
}
}

def snoc[A](xs: List[A], x: A) = x :: xs

def reverse[A](xs: List[A]) =
xs.foldLeft(Nil)(snoc)

def reverse2[A](xs: List[A]) =
xs.foldLeft(List())(snoc)

val ys: Seq[Int] = xs
ys.foldLeft(Seq())((ys, y) => y +: ys)
ys.foldLeft(Nil)((ys, y) => y +: ys)

def dup[A](xs: List[A]) =
xs.foldRight(Nil)((x, xs) => x :: x :: xs)

def toSet[A](xs: Seq[A]) =
xs.foldLeft(Set.empty)(_ + _)
Loading