Skip to content

Commit

Permalink
Add supports for type cast and filtering type for field and method ow…
Browse files Browse the repository at this point in the history
…ner in global initialization checker (#19612)

This PR adds support for more precise analysis of type casting in global
initialization checker, which filters impossible classes after reference
casting and preserves only possible classes. It also automatically
filter the set of possible classes during field selection or method
invocation and only preserves owners of the field/method
  • Loading branch information
olhotak authored Mar 8, 2024
2 parents 21cfb15 + 229062c commit 8642df1
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 8 deletions.
39 changes: 31 additions & 8 deletions compiler/src/dotty/tools/dotc/transform/init/Objects.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import scala.collection.immutable.ListSet
import scala.collection.mutable
import scala.annotation.tailrec
import scala.annotation.constructorOnly
import dotty.tools.dotc.core.Flags.AbstractOrTrait

/** Check initialization safety of static objects
*
Expand Down Expand Up @@ -203,6 +204,7 @@ object Objects:

/**
* Represents a lambda expression
* @param klass The enclosing class of the anonymous function's creation site
*/
case class Fun(code: Tree, thisV: ThisValue, klass: ClassSymbol, env: Env.Data) extends ValueElement:
def show(using Context) = "Fun(" + code.show + ", " + thisV.show + ", " + klass.show + ")"
Expand Down Expand Up @@ -599,6 +601,26 @@ object Objects:

case _ => a

def filterType(tpe: Type)(using Context): Value =
tpe match
case t @ SAMType(_, _) if a.isInstanceOf[Fun] => a // if tpe is SAMType and a is Fun, allow it
case _ =>
val baseClasses = tpe.baseClasses
if baseClasses.isEmpty then a
else filterClass(baseClasses.head) // could have called ClassSymbol, but it does not handle OrType and AndType

def filterClass(sym: Symbol)(using Context): Value =
if !sym.isClass then a
else
val klass = sym.asClass
a match
case Cold => Cold
case ref: Ref => if ref.klass.isSubClass(klass) then ref else Bottom
case ValueSet(values) => values.map(v => v.filterClass(klass)).join
case arr: OfArray => if defn.ArrayClass.isSubClass(klass) then arr else Bottom
case fun: Fun =>
if klass.isOneOf(AbstractOrTrait) && klass.baseClasses.exists(defn.isFunctionClass) then fun else Bottom

extension (value: Ref | Cold.type)
def widenRefOrCold(height : Int)(using Context) : Ref | Cold.type = value.widen(height).asInstanceOf[ThisValue]

Expand All @@ -617,7 +639,7 @@ object Objects:
* @param needResolve Whether the target of the call needs resolution?
*/
def call(value: Value, meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, needResolve: Boolean = true): Contextual[Value] = log("call " + meth.show + ", this = " + value.show + ", args = " + args.map(_.value.show), printer, (_: Value).show) {
value match
value.filterClass(meth.owner) match
case Cold =>
report.warning("Using cold alias. " + Trace.show, Trace.position)
Bottom
Expand Down Expand Up @@ -733,7 +755,6 @@ object Objects:
* @param args Arguments of the constructor call (all parameter blocks flatten to a list).
*/
def callConstructor(value: Value, ctor: Symbol, args: List[ArgInfo]): Contextual[Value] = log("call " + ctor.show + ", args = " + args.map(_.value.show), printer, (_: Value).show) {

value match
case ref: Ref =>
if ctor.hasSource then
Expand Down Expand Up @@ -768,7 +789,7 @@ object Objects:
* @param needResolve Whether the target of the selection needs resolution?
*/
def select(value: Value, field: Symbol, receiver: Type, needResolve: Boolean = true): Contextual[Value] = log("select " + field.show + ", this = " + value.show, printer, (_: Value).show) {
value match
value.filterClass(field.owner) match
case Cold =>
report.warning("Using cold alias", Trace.position)
Bottom
Expand Down Expand Up @@ -839,12 +860,12 @@ object Objects:
* @param rhsTyp The type of the right-hand side.
*/
def assign(lhs: Value, field: Symbol, rhs: Value, rhsTyp: Type): Contextual[Value] = log("Assign" + field.show + " of " + lhs.show + ", rhs = " + rhs.show, printer, (_: Value).show) {
lhs match
lhs.filterClass(field.owner) match
case fun: Fun =>
report.warning("[Internal error] unexpected tree in assignment, fun = " + fun.code.show + Trace.show, Trace.position)

case arr: OfArray =>
report.warning("[Internal error] unexpected tree in assignment, array = " + arr.show + Trace.show, Trace.position)
report.warning("[Internal error] unexpected tree in assignment, array = " + arr.show + " field = " + field + Trace.show, Trace.position)

case Cold =>
report.warning("Assigning to cold aliases is forbidden. " + Trace.show, Trace.position)
Expand Down Expand Up @@ -876,8 +897,7 @@ object Objects:
* @param args The arguments passsed to the constructor.
*/
def instantiate(outer: Value, klass: ClassSymbol, ctor: Symbol, args: List[ArgInfo]): Contextual[Value] = log("instantiating " + klass.show + ", outer = " + outer + ", args = " + args.map(_.value.show), printer, (_: Value).show) {
outer match

outer.filterClass(klass.owner) match
case _ : Fun | _: OfArray =>
report.warning("[Internal error] unexpected outer in instantiating a class, outer = " + outer.show + ", class = " + klass.show + ", " + Trace.show, Trace.position)
Bottom
Expand Down Expand Up @@ -1091,6 +1111,9 @@ object Objects:
instantiate(outer, cls, ctor, args)
}

case TypeCast(elem, tpe) =>
eval(elem, thisV, klass).filterType(tpe)

case Apply(ref, arg :: Nil) if ref.symbol == defn.InitRegionMethod =>
val regions2 = Regions.extend(expr.sourcePos)
if Regions.exists(expr.sourcePos) then
Expand Down Expand Up @@ -1549,7 +1572,7 @@ object Objects:
report.warning("The argument should be a constant integer value", arg)
res.widen(1)
case _ =>
res.widen(1)
if res.isInstanceOf[Fun] then res.widen(2) else res.widen(1)

argInfos += ArgInfo(widened, trace.add(arg.tree), arg.tree)
}
Expand Down
7 changes: 7 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/init/Util.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ object Util:
case _ =>
None

object TypeCast:
def unapply(tree: Tree)(using Context): Option[(Tree, Type)] =
tree match
case TypeApply(Select(qual, _), typeArgs) if tree.symbol.isTypeCast =>
Some(qual, typeArgs.head.tpe)
case _ => None

def resolve(cls: ClassSymbol, sym: Symbol)(using Context): Symbol = log("resove " + cls + ", " + sym, printer, (_: Symbol).show):
if sym.isEffectivelyFinal then sym
else sym.matchingMember(cls.appliedRef)
Expand Down
18 changes: 18 additions & 0 deletions tests/init-global/neg/TypeCast.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
object A {
val f: Int = 10
def m() = f
}
object B {
val f: Int = g()
def g(): Int = f // error
}
object C {
val a: A.type | B.type = if ??? then A else B
def cast[T](a: Any): T = a.asInstanceOf[T]
val c: A.type = cast[A.type](a) // abstraction for c is {A, B}
val d = c.f // treat as c.asInstanceOf[owner of f].f
val e = c.m() // treat as c.asInstanceOf[owner of f].m()
val c2: B.type = cast[B.type](a)
val g = c2.f // no error here
}

9 changes: 9 additions & 0 deletions tests/init-global/pos/TypeCast1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class A:
class B(val b: Int)

object O:
val o: A | Array[Int] = new Array[Int](10)
o match
case a: A => new a.B(10)
case arr: Array[Int] => arr(5)

9 changes: 9 additions & 0 deletions tests/init-global/pos/TypeCast2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class A:
class B(val b: Int)

object O:
val o: A | (Int => Int) = (x: Int) => x + 1
o match
case a: A => new a.B(10)
case f: (_ => _) => f.asInstanceOf[Int => Int](5)

8 changes: 8 additions & 0 deletions tests/init-global/pos/TypeCast3.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class A:
var x: Int = 10

object O:
val o: A | (Int => Int) = (x: Int) => x + 1
o match
case a: A => a.x = 20
case f: (_ => _) => f.asInstanceOf[Int => Int](5)
9 changes: 9 additions & 0 deletions tests/init-global/pos/TypeCast4.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class A:
var x: Int = 10

object O:
val o: A | Array[Int] = new Array[Int](10)
o match
case a: A => a.x = 20
case arr: Array[Int] => arr(5)

15 changes: 15 additions & 0 deletions tests/init-global/pos/i18882.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
class A:
var a = 20

class B:
var b = 20

object O:
val o: A | B = new A
if o.isInstanceOf[A] then
o.asInstanceOf[A].a += 1
else
o.asInstanceOf[B].b += 1 // o.asInstanceOf[B] is treated as bottom
o match
case o: A => o.a += 1
case o: B => o.b += 1

0 comments on commit 8642df1

Please sign in to comment.