Skip to content

Commit

Permalink
Require array element types to be sealed
Browse files Browse the repository at this point in the history
  • Loading branch information
odersky committed Oct 31, 2023
1 parent 1065bd1 commit ef0e3ac
Show file tree
Hide file tree
Showing 13 changed files with 109 additions and 26 deletions.
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,12 @@ extension (tp: Type)
case _: TypeRef | _: AppliedType => tp.typeSymbol.hasAnnotation(defn.CapabilityAnnot)
case _ => false

def isSealed(using Context): Boolean = tp match
case tp: TypeParamRef => tp.underlying.isSealed
case tp: TypeBounds => tp.hi.hasAnnotation(defn.Caps_SealedAnnot)
case tp: TypeRef => tp.symbol.is(Sealed) || tp.info.isSealed // TODO: drop symbol flag?
case _ => false

/** Drop @retains annotations everywhere */
def dropAllRetains(using Context): Type = // TODO we should drop retains from inferred types before unpickling
val tm = new TypeMap:
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,7 @@ object CaptureSet:
upper.isAlwaysEmpty || upper.isConst && upper.elems.size == 1 && upper.elems.contains(r1)
if variance > 0 || isExact then upper
else if variance < 0 then CaptureSet.empty
else if ctx.mode.is(Mode.Printing) then upper
else assert(false, i"trying to add $upper from $r via ${tm.getClass} in a non-variant setting")

/** Apply `f` to each element in `xs`, and join result sets with `++` */
Expand Down
31 changes: 24 additions & 7 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ object CheckCaptures:
val check = new TypeTraverser:

extension (tparam: Symbol) def isParametricIn(carrier: Symbol): Boolean =
val encl = carrier.owner.enclosingMethodOrClass
val encl = carrier.maybeOwner.enclosingMethodOrClass
if encl.isClass then tparam.isParametricIn(encl)
else
def recur(encl: Symbol): Boolean =
Expand All @@ -160,11 +160,9 @@ object CheckCaptures:
def traverse(t: Type) =
t.dealiasKeepAnnots match
case t: TypeRef =>
capt.println(i"disallow $t, $tp, $what, ${t.symbol.is(Sealed)}")
capt.println(i"disallow $t, $tp, $what, ${t.isSealed}")
t.info match
case TypeBounds(_, hi)
if !t.symbol.is(Sealed) && !hi.hasAnnotation(defn.Caps_SealedAnnot)
&& !t.symbol.isParametricIn(carrier) =>
case TypeBounds(_, hi) if !t.isSealed && !t.symbol.isParametricIn(carrier) =>
if hi.isAny then
report.error(
em"""$what cannot $have $tp since
Expand Down Expand Up @@ -543,8 +541,8 @@ class CheckCaptures extends Recheck, SymTransformer:
val TypeApply(fn, args) = tree
val polyType = atPhase(thisPhase.prev):
fn.tpe.widen.asInstanceOf[TypeLambda]
for case (arg: TypeTree, pinfo, pname) <- args.lazyZip(polyType.paramInfos).lazyZip((polyType.paramNames)) do
if pinfo.bounds.hi.hasAnnotation(defn.Caps_SealedAnnot) then
for case (arg: TypeTree, formal, pname) <- args.lazyZip(polyType.paramRefs).lazyZip((polyType.paramNames)) do
if formal.isSealed then
def where = if fn.symbol.exists then i" in an argument of ${fn.symbol}" else ""
disallowRootCapabilitiesIn(arg.knownType, fn.symbol,
i"Sealed type variable $pname", "be instantiated to",
Expand Down Expand Up @@ -1315,6 +1313,23 @@ class CheckCaptures extends Recheck, SymTransformer:
traverseChildren(tp)
check.traverse(info)

def checkArraysAreSealedIn(tp: Type, pos: SrcPos)(using Context): Unit =
val check = new TypeTraverser:
def traverse(t: Type): Unit =
t match
case AppliedType(tycon, arg :: Nil) if tycon.typeSymbol == defn.ArrayClass =>
if !(pos.span.isSynthetic && ctx.reporter.errorsReported) then
CheckCaptures.disallowRootCapabilitiesIn(arg, NoSymbol,
"Array", "have element type",
"Since arrays are mutable, they have to be treated like variables,\nso their element type must be sealed.",
pos)
traverseChildren(t)
case defn.RefinedFunctionOf(rinfo: MethodType) =>
traverse(rinfo)
case _ =>
traverseChildren(t)
check.traverse(tp)

/** Perform the following kinds of checks
* - Check all explicitly written capturing types for well-formedness using `checkWellFormedPost`.
* - Check that arguments of TypeApplys and AppliedTypes conform to their bounds.
Expand All @@ -1340,6 +1355,8 @@ class CheckCaptures extends Recheck, SymTransformer:
case _ =>
case _: ValOrDefDef | _: TypeDef =>
checkNoLocalRootIn(tree.symbol, tree.symbol.info, tree.symbol.srcPos)
case tree: TypeTree =>
checkArraysAreSealedIn(tree.tpe, tree.srcPos)
case _ =>
end check
end checker
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/transform/Recheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -596,9 +596,9 @@ abstract class Recheck extends Phase, SymTransformer:

/** Show tree with rechecked types instead of the types stored in the `.tpe` field */
override def show(tree: untpd.Tree)(using Context): String =
atPhase(thisPhase) {
super.show(addRecheckedTypes.transform(tree.asInstanceOf[tpd.Tree]))
}
atPhase(thisPhase):
withMode(Mode.Printing):
super.show(addRecheckedTypes.transform(tree.asInstanceOf[tpd.Tree]))
end Recheck

/** A class that can be used to test basic rechecking without any customaization */
Expand Down
26 changes: 26 additions & 0 deletions tests/neg-custom-args/captures/buffers.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
-- Error: tests/neg-custom-args/captures/buffers.scala:11:6 ------------------------------------------------------------
11 | var elems: Array[A] = new Array[A](10) // error // error
| ^
| mutable variable elems cannot have type Array[A] since
| that type refers to the type variable A, which is not sealed.
-- Error: tests/neg-custom-args/captures/buffers.scala:16:38 -----------------------------------------------------------
16 | def make[A: ClassTag](xs: A*) = new ArrayBuffer: // error
| ^^^^^^^^^^^
| Sealed type variable A cannot be instantiated to box A^? since
| that type refers to the type variable A, which is not sealed.
| This is often caused by a local capability in an argument of constructor ArrayBuffer
| leaking as part of its result.
-- Error: tests/neg-custom-args/captures/buffers.scala:11:13 -----------------------------------------------------------
11 | var elems: Array[A] = new Array[A](10) // error // error
| ^^^^^^^^
| Array cannot have element type A since
| that type refers to the type variable A, which is not sealed.
| Since arrays are mutable, they have to be treated like variables,
| so their element type must be sealed.
-- Error: tests/neg-custom-args/captures/buffers.scala:22:9 ------------------------------------------------------------
22 | val x: Array[A] = new Array[A](10) // error
| ^^^^^^^^
| Array cannot have element type A since
| that type refers to the type variable A, which is not sealed.
| Since arrays are mutable, they have to be treated like variables,
| so their element type must be sealed.
30 changes: 30 additions & 0 deletions tests/neg-custom-args/captures/buffers.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import reflect.ClassTag

class Buffer[A]

class ArrayBuffer[sealed A: ClassTag] extends Buffer[A]:
var elems: Array[A] = new Array[A](10)
def add(x: A): this.type = ???
def at(i: Int): A = ???

class ArrayBufferBAD[A: ClassTag] extends Buffer[A]:
var elems: Array[A] = new Array[A](10) // error // error
def add(x: A): this.type = ???
def at(i: Int): A = ???

object ArrayBuffer:
def make[A: ClassTag](xs: A*) = new ArrayBuffer: // error
elems = xs.toArray
def apply[sealed A: ClassTag](xs: A*) = new ArrayBuffer:
elems = xs.toArray // ok

class EncapsArray[A: ClassTag]:
val x: Array[A] = new Array[A](10) // error








19 changes: 10 additions & 9 deletions tests/pos-special/stdlib/collection/IterableOnce.scala
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ final class IterableOnceExtensionMethods[A](private val it: IterableOnce[A]) ext
def toBuffer[sealed B >: A]: mutable.Buffer[B] = mutable.ArrayBuffer.from(it)

@deprecated("Use .iterator.toArray", "2.13.0")
def toArray[B >: A: ClassTag]: Array[B] = it match {
def toArray[sealed B >: A: ClassTag]: Array[B] = it match {
case it: Iterable[B] => it.toArray[B]
case _ => it.iterator.toArray[B]
}
Expand Down Expand Up @@ -272,10 +272,11 @@ object IterableOnce {
math.max(math.min(math.min(len, srcLen), destLen - start), 0)

/** Calls `copyToArray` on the given collection, regardless of whether or not it is an `Iterable`. */
@inline private[collection] def copyElemsToArray[A, B >: A](elems: IterableOnce[A]^,
xs: Array[B],
start: Int = 0,
len: Int = Int.MaxValue): Int =
@inline private[collection] def copyElemsToArray[A, sealed B >: A](
elems: IterableOnce[A]^,
xs: Array[B],
start: Int = 0,
len: Int = Int.MaxValue): Int =
elems match {
case src: Iterable[A] => src.copyToArray[B](xs, start, len)
case src => src.iterator.copyToArray[B](xs, start, len)
Expand Down Expand Up @@ -889,7 +890,7 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A]^ =>
* @note Reuse: $consumesIterator
*/
@deprecatedOverriding("This should always forward to the 3-arg version of this method", since = "2.13.4")
def copyToArray[B >: A](xs: Array[B]): Int = copyToArray(xs, 0, Int.MaxValue)
def copyToArray[sealed B >: A](xs: Array[B]): Int = copyToArray(xs, 0, Int.MaxValue)

/** Copy elements to an array, returning the number of elements written.
*
Expand All @@ -906,7 +907,7 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A]^ =>
* @note Reuse: $consumesIterator
*/
@deprecatedOverriding("This should always forward to the 3-arg version of this method", since = "2.13.4")
def copyToArray[B >: A](xs: Array[B], start: Int): Int = copyToArray(xs, start, Int.MaxValue)
def copyToArray[sealed B >: A](xs: Array[B], start: Int): Int = copyToArray(xs, start, Int.MaxValue)

/** Copy elements to an array, returning the number of elements written.
*
Expand All @@ -923,7 +924,7 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A]^ =>
*
* @note Reuse: $consumesIterator
*/
def copyToArray[B >: A](xs: Array[B], start: Int, len: Int): Int = {
def copyToArray[sealed B >: A](xs: Array[B], start: Int, len: Int): Int = {
val it = iterator
var i = start
val end = start + math.min(len, xs.length - start)
Expand Down Expand Up @@ -1318,7 +1319,7 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A]^ =>
*
* Implementation note: DO NOT call [[Array.from]] from this method.
*/
def toArray[B >: A: ClassTag]: Array[B] =
def toArray[sealed B >: A: ClassTag]: Array[B] =
if (knownSize >= 0) {
val destination = new Array[B](knownSize)
copyToArray(destination, 0)
Expand Down
2 changes: 1 addition & 1 deletion tests/pos-special/stdlib/collection/Iterator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite
}
// segment must have data, and must be complete unless they allow partial
val ok = index > 0 && (partial || index == size)
if (ok) buffer = builder.result().asInstanceOf[Array[B]]
if (ok) buffer = builder.result().asInstanceOf[Array[B @uncheckedCaptures]]
else prev = null
ok
}
Expand Down
3 changes: 2 additions & 1 deletion tests/pos-special/stdlib/collection/SeqView.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package collection
import scala.annotation.nowarn
import language.experimental.captureChecking
import caps.unsafe.unsafeAssumePure
import scala.annotation.unchecked.uncheckedCaptures

/** !!! Scala 2 difference: Need intermediate trait SeqViewOps to collect the
* necessary functionality over which SeqViews are defined, and at the same
Expand Down Expand Up @@ -195,7 +196,7 @@ object SeqView {
// contains items of another type, we'd get a CCE anyway)
// - the cast doesn't actually do anything in the runtime because the
// type of A is not known and Array[_] is Array[AnyRef]
immutable.ArraySeq.unsafeWrapArray(arr.asInstanceOf[Array[A]])
immutable.ArraySeq.unsafeWrapArray(arr.asInstanceOf[Array[A @uncheckedCaptures]])
}
}
evaluated = true
Expand Down
2 changes: 1 addition & 1 deletion tests/pos-special/stdlib/collection/StringOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,7 @@ final class StringOps(private val s: String) extends AnyVal {
else if (s.equalsIgnoreCase("false")) false
else throw new IllegalArgumentException("For input string: \""+s+"\"")

def toArray[B >: Char](implicit tag: ClassTag[B]): Array[B] =
def toArray[sealed B >: Char](implicit tag: ClassTag[B]): Array[B] =
if (tag == ClassTag.Char) s.toCharArray.asInstanceOf[Array[B]]
else new WrappedString(s).toArray[B]

Expand Down
4 changes: 2 additions & 2 deletions tests/pos-special/stdlib/collection/mutable/ArrayBuffer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ class ArrayBuffer[sealed A] private (initialElements: Array[AnyRef], initialSize
@nowarn("""cat=deprecation&origin=scala\.collection\.Iterable\.stringPrefix""")
override protected[this] def stringPrefix = "ArrayBuffer"

override def copyToArray[B >: A](xs: Array[B], start: Int, len: Int): Int = {
override def copyToArray[sealed B >: A](xs: Array[B], start: Int, len: Int): Int = {
val copied = IterableOnce.elemsToCopyToArray(length, xs.length, start, len)
if(copied > 0) {
Array.copy(array, 0, xs, start, copied)
Expand All @@ -258,7 +258,7 @@ class ArrayBuffer[sealed A] private (initialElements: Array[AnyRef], initialSize
override def sortInPlace[B >: A]()(implicit ord: Ordering[B]): this.type = {
if (length > 1) {
mutationCount += 1
scala.util.Sorting.stableSort(array.asInstanceOf[Array[B]], 0, length)
scala.util.Sorting.stableSort(array.asInstanceOf[Array[B @uncheckedCaptures]], 0, length)
}
this
}
Expand Down
3 changes: 2 additions & 1 deletion tests/pos-special/stdlib/collection/mutable/Buffer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package mutable

import scala.annotation.nowarn
import language.experimental.captureChecking
import scala.annotation.unchecked.uncheckedCaptures


/** A `Buffer` is a growable and shrinkable `Seq`. */
Expand Down Expand Up @@ -185,7 +186,7 @@ trait IndexedBuffer[A] extends IndexedSeq[A]
// There's scope for a better implementation which copies elements in place.
var i = 0
val s = size
val newElems = new Array[IterableOnce[A]^](s)
val newElems = new Array[(IterableOnce[A]^) @uncheckedCaptures](s)
while (i < s) { newElems(i) = f(this(i)); i += 1 }
clear()
i = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ final class StringBuilder(val underlying: java.lang.StringBuilder) extends Abstr

override def toString: String = result()

override def toArray[B >: Char](implicit ct: scala.reflect.ClassTag[B]) =
override def toArray[sealed B >: Char](implicit ct: scala.reflect.ClassTag[B]) =
ct.runtimeClass match {
case java.lang.Character.TYPE => toCharArray.asInstanceOf[Array[B]]
case _ => super.toArray
Expand Down

0 comments on commit ef0e3ac

Please sign in to comment.