Skip to content

Commit

Permalink
Refactor some tuple methods (#19032)
Browse files Browse the repository at this point in the history
 - Move from Definitions to TypeUtils
- Unify TypeUtils.tupleElementTypes and Definitions.tupleTypes. They do
the same thing.

I'd like to move more things out of Definitions and into TypeUtils and
SymUtils. Then I'd like to move these files to the core package, and
make their operations accessible automatically by having the companion
objects of Types and Symbols inherit from them. This is a first step
into that direction.
  • Loading branch information
nicolasstucki authored Nov 24, 2023
2 parents 00e9e6b + 125321e commit 55c2002
Show file tree
Hide file tree
Showing 105 changed files with 144 additions and 206 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.core.StdNames.{nme, str}
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.transform.Erasure
import dotty.tools.dotc.transform.SymUtils.*
import dotty.tools.dotc.util.Spans.*
import dotty.tools.dotc.core.Contexts.*
import dotty.tools.dotc.core.Phases.*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.core.Contexts.*
import dotty.tools.dotc.util.Spans.*
import dotty.tools.dotc.report
import dotty.tools.dotc.transform.SymUtils.*


/*
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import dotty.tools.dotc.core.Contexts.*
import dotty.tools.dotc.core.Phases.*
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.core.Phases.Phase
import dotty.tools.dotc.transform.SymUtils.*

import dotty.tools.dotc.core.StdNames
import dotty.tools.dotc.core.Phases

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/backend/jvm/CodeGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import dotty.tools.dotc.core.Phases.Phase

import scala.collection.mutable
import scala.jdk.CollectionConverters.*
import dotty.tools.dotc.transform.SymUtils.*

import dotty.tools.dotc.interfaces
import dotty.tools.dotc.report

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import scala.language.unsafeNulls

import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.Flags.*
import dotty.tools.dotc.transform.SymUtils.*

import java.io.{File => _}

import scala.reflect.ClassTag
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/backend/sjs/JSCodeGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import StdNames.*
import TypeErasure.ErasedValueType

import dotty.tools.dotc.transform.{Erasure, ValueClasses}
import dotty.tools.dotc.transform.SymUtils.*

import dotty.tools.dotc.util.SourcePosition
import dotty.tools.dotc.report

Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/CompilationUnit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import ast.{tpd, untpd}
import tpd.{Tree, TreeTraverser}
import ast.Trees.{Import, Ident}
import typer.Nullables
import transform.SymUtils.*
import core.Decorators.*
import config.{SourceVersion, Feature}
import StdNames.nme
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package ast
import core.*
import util.Spans.*, Types.*, Contexts.*, Constants.*, Names.*, NameOps.*, Flags.*
import Symbols.*, StdNames.*, Trees.*, ContextOps.*
import Decorators.*, transform.SymUtils.*
import Decorators.*
import Annotations.Annotation
import NameKinds.{UniqueName, ContextBoundParamName, ContextFunctionParamName, DefaultGetterName, WildcardParamName}
import typer.{Namer, Checking}
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import Flags.*, Trees.*, Types.*, Contexts.*
import Names.*, StdNames.*, NameOps.*, Symbols.*
import typer.ConstFold
import reporting.trace
import dotty.tools.dotc.transform.SymUtils.*

import Decorators.*
import Constants.Constant
import scala.collection.mutable
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import core.*
import Types.*, Contexts.*, Flags.*
import Symbols.*, Annotations.*, Trees.*, Symbols.*, Constants.Constant
import Decorators.*
import dotty.tools.dotc.transform.SymUtils.*


/** A map that applies three functions and a substitution together to a tree and
* makes sure they are coordinated so that the result is well-typed. The functions are
Expand Down
2 changes: 0 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ package ast

import dotty.tools.dotc.transform.{ExplicitOuter, Erasure}
import typer.ProtoTypes
import transform.SymUtils.*
import transform.TypeUtils.*
import core.*
import Scopes.newScope
import util.Spans.*, Types.*, Contexts.*, Constants.*, Names.*, Flags.*, NameOps.*
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import typer.Checking.{checkBounds, checkAppliedTypesIn}
import typer.ErrorReporting.{Addenda, err}
import typer.ProtoTypes.{AnySelectionProto, LhsProto}
import util.{SimpleIdentitySet, EqHashMap, EqHashSet, SrcPos, Property}
import transform.SymUtils.*
import transform.{Recheck, PreRecheck, CapturedVars}
import Recheck.*
import scala.collection.mutable
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/config/JavaPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import classpath.AggregateClassPath
import core.*
import Symbols.*, Types.*, Contexts.*, StdNames.*
import Flags.*
import transform.ExplicitOuter, transform.SymUtils.*
import transform.ExplicitOuter

class JavaPlatform extends Platform {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import Flags.*
import config.Config
import config.Printers.typr
import typer.ProtoTypes.{newTypeVar, representedParamRef}
import transform.TypeUtils.isTransparent
import UnificationDirection.*
import NameKinds.AvoidNameKind
import util.SimpleIdentitySet
Expand Down
20 changes: 0 additions & 20 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1749,26 +1749,6 @@ class Definitions {
else TypeOps.nestedPairs(elems)
}

def tupleTypes(tp: Type, bound: Int = Int.MaxValue)(using Context): Option[List[Type]] = {
@tailrec def rec(tp: Type, acc: List[Type], bound: Int): Option[List[Type]] = tp.normalized.dealias match {
case _ if bound < 0 => Some(acc.reverse)
case tp: AppliedType if PairClass == tp.classSymbol => rec(tp.args(1), tp.args.head :: acc, bound - 1)
case tp: AppliedType if isTupleNType(tp) => Some(acc.reverse ::: tp.args)
case tp: TermRef if tp.symbol == defn.EmptyTupleModule => Some(acc.reverse)
case _ => None
}
rec(tp.stripTypeVar, Nil, bound)
}

def isSmallGenericTuple(tp: Type)(using Context): Boolean =
if tp.derivesFrom(defn.PairClass) && !defn.isTupleNType(tp.widenDealias) then
// If this is a generic tuple we need to cast it to make the TupleN/ members accessible.
// This works only for generic tuples of known size up to 22.
defn.tupleTypes(tp.widenTermRefExpr) match
case Some(elems) if elems.length <= Definitions.MaxTupleArity => true
case _ => false
else false

def isProductSubType(tp: Type)(using Context): Boolean = tp.derivesFrom(ProductClass)

/** Is `tp` (an alias) of either a scala.FunctionN or a scala.ContextFunctionN
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import scala.util.control.NonFatal
import config.Config
import reporting.*
import collection.mutable
import transform.TypeUtils.*
import cc.{CapturingType, derivedCapturingType}

import scala.annotation.internal.sharable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package dotty.tools.dotc
package transform
package core

import core.*
import Types.*
Expand All @@ -11,18 +11,18 @@ import NameOps.*
import StdNames.*
import NameKinds.*
import Flags.*
import ValueClasses.isDerivedValueClass
import Decorators.*
import Constants.Constant
import Annotations.Annotation
import Phases.*
import ast.tpd.Literal
import transform.Mixin

import dotty.tools.dotc.transform.sjs.JSSymUtils.sjsNeedsField

import scala.annotation.tailrec

object SymUtils:
class SymUtils:

extension (self: Symbol)

Expand Down Expand Up @@ -79,6 +79,14 @@ object SymUtils:
self.is(Enum, butNot = Case) &&
self.info.parents.exists(p => p.typeSymbol == defn.JavaEnumClass)

def isDerivedValueClass(using Context): Boolean = self.isClass && {
val d = self.denot
!d.isRefinementClass &&
d.isValueClass &&
(d.initial.symbol ne defn.AnyValClass) && // Compare the initial symbol because AnyVal does not exist after erasure
!d.isPrimitiveValueClass
}

/** Is this a case class for which a product mirror is generated?
* Excluded are value classes, abstract classes and case classes with more than one
* parameter section.
Expand All @@ -100,7 +108,7 @@ object SymUtils:
if (!self.is(CaseClass)) "it is not a case class"
else if (self.is(Abstract)) "it is an abstract class"
else if (self.primaryConstructor.info.paramInfoss.length != 1) "it takes more than one parameter list"
else if (isDerivedValueClass(self)) "it is a value class"
else if self.isDerivedValueClass then "it is a value class"
else if (!(companionMirror || canAccessCtor)) s"the constructor of $self is inaccessible from the calling scope."
else ""
end whyNotGenericProduct
Expand Down
3 changes: 1 addition & 2 deletions compiler/src/dotty/tools/dotc/core/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import util.Spans.*
import DenotTransformers.*
import StdNames.*
import NameOps.*
import transform.SymUtils.*
import NameKinds.LazyImplicitName
import ast.tpd
import tpd.{Tree, TreeProvider, TreeOps}
Expand All @@ -36,7 +35,7 @@ import dotty.tools.dotc.classpath.FileUtils.isScalaBinary
import scala.compiletime.uninitialized
import dotty.tools.tasty.TastyVersion

object Symbols {
object Symbols extends SymUtils {

implicit def eqSymbol: CanEqual[Symbol, Symbol] = CanEqual.derived

Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeApplications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import Names.*
import StdNames.nme
import Flags.{Module, Provisional}
import dotty.tools.dotc.config.Config
import dotty.tools.dotc.transform.TypeUtils.isErasedValueType

object TypeApplications {

Expand Down
2 changes: 0 additions & 2 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ import TypeErasure.{erasedLub, erasedGlb}
import TypeApplications.*
import Variances.{Variance, variancesConform}
import Constants.Constant
import transform.TypeUtils.*
import transform.SymUtils.*
import scala.util.control.NonFatal
import typer.ProtoTypes.constrained
import typer.Applications.productSelectorTypes
Expand Down
12 changes: 5 additions & 7 deletions compiler/src/dotty/tools/dotc/core/TypeErasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import TypeOps.makePackageObjPrefixExplicit
import backend.sjs.JSDefinitions
import transform.ExplicitOuter.*
import transform.ValueClasses.*
import transform.TypeUtils.*
import transform.ContextFunctionResults.*
import unpickleScala2.Scala2Erasure
import Decorators.*
Expand Down Expand Up @@ -72,7 +71,7 @@ end SourceLanguage
object TypeErasure {

private def erasureDependsOnArgs(sym: Symbol)(using Context) =
sym == defn.ArrayClass || sym == defn.PairClass || isDerivedValueClass(sym)
sym == defn.ArrayClass || sym == defn.PairClass || sym.isDerivedValueClass

/** The arity of this tuple type, which can be made up of EmptyTuple, TupleX and `*:` pairs.
*
Expand Down Expand Up @@ -126,7 +125,7 @@ object TypeErasure {
case tp: TypeRef =>
val sym = tp.symbol
sym.isClass &&
(!erasureDependsOnArgs(sym) || isDerivedValueClass(sym)) &&
(!erasureDependsOnArgs(sym) || sym.isDerivedValueClass) &&
!defn.specialErasure.contains(sym) &&
!defn.isSyntheticFunctionClass(sym)
case _: TermRef =>
Expand Down Expand Up @@ -404,7 +403,6 @@ object TypeErasure {
tp1 // After erasure, T | Nothing is just T and C | Null is just C, if C is a reference type.
else tp1 match {
case JavaArrayType(elem1) =>
import dotty.tools.dotc.transform.TypeUtils.*
tp2 match {
case JavaArrayType(elem2) =>
if (elem1.isPrimitiveValueType || elem2.isPrimitiveValueType)
Expand Down Expand Up @@ -632,15 +630,15 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
case tp: TypeRef =>
val sym = tp.symbol
if !sym.isClass then this(checkedSuperType(tp))
else if semiEraseVCs && isDerivedValueClass(sym) then eraseDerivedValueClass(tp)
else if semiEraseVCs && sym.isDerivedValueClass then eraseDerivedValueClass(tp)
else if defn.isSyntheticFunctionClass(sym) then defn.functionTypeErasure(sym)
else eraseNormalClassRef(tp)
case tp: AppliedType =>
val tycon = tp.tycon
if (tycon.isRef(defn.ArrayClass)) eraseArray(tp)
else if (tycon.isRef(defn.PairClass)) erasePair(tp)
else if (tp.isRepeatedParam) apply(tp.translateFromRepeated(toArray = sourceLanguage.isJava))
else if (semiEraseVCs && isDerivedValueClass(tycon.classSymbol)) eraseDerivedValueClass(tp)
else if (semiEraseVCs && tycon.classSymbol.isDerivedValueClass) eraseDerivedValueClass(tp)
else this(checkedSuperType(tp))
case tp: TermRef =>
this(underlyingOfTermRef(tp))
Expand Down Expand Up @@ -900,7 +898,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
if (!info.exists) assert(false, i"undefined: $tp with symbol $sym")
return sigName(info)
}
if (semiEraseVCs && isDerivedValueClass(sym)) {
if (semiEraseVCs && sym.isDerivedValueClass) {
val erasedVCRef = eraseDerivedValueClass(tp)
if (erasedVCRef.exists) return sigName(erasedVCRef)
}
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import ast.tpd.*
import reporting.trace
import config.Printers.typr
import config.Feature
import transform.SymUtils.*
import typer.ProtoTypes.*
import typer.ForceDegree
import typer.Inferencing.*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package dotty.tools
package dotc
package transform
package core

import core.*
import TypeErasure.ErasedValueType
import Types.*, Contexts.*, Symbols.*, Flags.*, Decorators.*
import Names.Name

object TypeUtils {
class TypeUtils {
/** A decorator that provides methods on types
* that are needed in the transformer pipeline.
*/
Expand Down Expand Up @@ -45,22 +44,45 @@ object TypeUtils {
case ps => ps.reduceLeft(AndType(_, _))
}

/** The element types of this tuple type, which can be made up of EmptyTuple, TupleX and `*:` pairs */
def tupleElementTypes(using Context): Option[List[Type]] = self.dealias match {
case AppliedType(tycon, hd :: tl :: Nil) if tycon.isRef(defn.PairClass) =>
tl.tupleElementTypes.map(hd :: _)
case self: SingletonType =>
if self.termSymbol == defn.EmptyTupleModule then Some(Nil) else None
case AndType(tp1, tp2) =>
// We assume that we have the following property:
// (T1, T2, ..., Tn) & (U1, U2, ..., Un) = (T1 & U1, T2 & U2, ..., Tn & Un)
tp1.tupleElementTypes.zip(tp2.tupleElementTypes).map { case (t1, t2) => t1.intersect(t2) }
case OrType(tp1, tp2) =>
None // We can't combine the type of two tuples
case _ =>
if defn.isTupleClass(self.typeSymbol) then Some(self.dealias.argInfos)
else None
}
/** The element types of this tuple type, which can be made up of EmptyTuple, TupleX and `*:` pairs
*/
def tupleElementTypes(using Context): Option[List[Type]] =
tupleElementTypesUpTo(Int.MaxValue)

/** The element types of this tuple type, which can be made up of EmptyTuple, TupleX and `*:` pairs
* @param bound The maximum number of elements that needs generating minus 1
* The generation will stop once more than bound elems have been generated
* @param normalize If true, normalize and dealias at each step.
* If false, never normalize and dealias only to find *:
* and EmptyTuple types. This is useful for printing.
*/
def tupleElementTypesUpTo(bound: Int, normalize: Boolean = true)(using Context): Option[List[Type]] =
def recur(tp: Type, bound: Int): Option[List[Type]] =
if bound < 0 then Some(Nil)
else (if normalize then tp.normalized else tp).dealias match
case AppliedType(tycon, hd :: tl :: Nil) if tycon.isRef(defn.PairClass) =>
recur(tl, bound - 1).map(hd :: _)
case tp: AppliedType if defn.isTupleNType(tp) && normalize =>
Some(tp.args) // if normalize is set, use the dealiased tuple
// otherwise rely on the default case below to print unaliased tuples.
case tp: SingletonType =>
if tp.termSymbol == defn.EmptyTupleModule then Some(Nil) else None
case _ =>
if defn.isTupleClass(tp.typeSymbol) && !normalize then Some(tp.dealias.argInfos)
else None
recur(self.stripTypeVar, bound)

/** Is this a generic tuple that would fit into the range 1..22,
* but is not already an instance of one of Tuple1..22?
* In this case we need to cast it to make the TupleN/ members accessible.
* This works only for generic tuples of known size up to 22.
*/
def isSmallGenericTuple(using Context): Boolean =
self.derivesFrom(defn.PairClass)
&& !defn.isTupleNType(self.widenDealias)
&& self.widenTermRefExpr.tupleElementTypesUpTo(Definitions.MaxTupleArity).match
case Some(elems) if elems.length <= Definitions.MaxTupleArity => true
case _ => false

/** The `*:` equivalent of an instance of a Tuple class */
def toNestedPairs(using Context): Type =
Expand Down
Loading

0 comments on commit 55c2002

Please sign in to comment.