From 767dc68237080d84346ccf21f06d3ac393b5c14b Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Tue, 26 Sep 2023 18:23:40 +0200 Subject: [PATCH 01/13] Add callback object to Context capable of recording progress [Cherry-picked 721888d95f599dbc6aa38a9fe4c72ae545ebc194] --- .../src/dotty/tools/dotc/core/Contexts.scala | 19 +++++++++- .../dotc/sbt/interfaces/ProgressCallback.java | 21 +++++++++++ .../src/dotty/tools/xsbt/CompilerBridge.java | 2 +- .../tools/xsbt/CompilerBridgeDriver.java | 9 ++++- .../tools/xsbt/ProgressCallbackImpl.java | 37 +++++++++++++++++++ .../xsbt/ScalaCompilerForUnitTesting.scala | 3 +- .../test/xsbt/TestCompileProgress.scala | 5 +++ 7 files changed, 90 insertions(+), 6 deletions(-) create mode 100644 compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java create mode 100644 sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java create mode 100644 sbt-bridge/test/xsbt/TestCompileProgress.scala diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index 8a7f2ff4e051..3404efebf215 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -34,7 +34,7 @@ import scala.annotation.internal.sharable import DenotTransformers.DenotTransformer import dotty.tools.dotc.profile.Profiler -import dotty.tools.dotc.sbt.interfaces.IncrementalCallback +import dotty.tools.dotc.sbt.interfaces.{IncrementalCallback, ProgressCallback} import util.Property.Key import util.Store import plugins._ @@ -53,8 +53,9 @@ object Contexts { private val (notNullInfosLoc, store8) = store7.newLocation[List[NotNullInfo]]() private val (importInfoLoc, store9) = store8.newLocation[ImportInfo | Null]() private val (typeAssignerLoc, store10) = store9.newLocation[TypeAssigner](TypeAssigner) + private val (progressCallbackLoc, store11) = store10.newLocation[ProgressCallback | Null]() - private val initialStore = store10 + private val initialStore = store11 /** The current context */ inline def ctx(using ctx: Context): Context = ctx @@ -177,6 +178,19 @@ object Contexts { val local = incCallback local != null && local.enabled || forceRun + /** The Zinc compile progress callback implementation if we are run from Zinc, null otherwise */ + def progressCallback: ProgressCallback | Null = store(progressCallbackLoc) + + /** Run `op` if there exists a Zinc progress callback */ + inline def withProgressCallback(inline op: ProgressCallback => Unit): Unit = + val local = progressCallback + if local != null then op(local) + + def cancelSignalRecorded: Boolean = + val local = progressCallback + val noSignalRecieved = local == null || !local.isCancelled + !noSignalRecieved // if true then cancel request was recorded + /** The current plain printer */ def printerFn: Context => Printer = store(printerFnLoc) @@ -675,6 +689,7 @@ object Contexts { def setCompilerCallback(callback: CompilerCallback): this.type = updateStore(compilerCallbackLoc, callback) def setIncCallback(callback: IncrementalCallback): this.type = updateStore(incCallbackLoc, callback) + def setProgressCallback(callback: ProgressCallback): this.type = updateStore(progressCallbackLoc, callback) def setPrinterFn(printer: Context => Printer): this.type = updateStore(printerFnLoc, printer) def setSettings(settingsState: SettingsState): this.type = updateStore(settingsStateLoc, settingsState) def setRun(run: Run | Null): this.type = updateStore(runLoc, run) diff --git a/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java b/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java new file mode 100644 index 000000000000..8f81ea5f99a2 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java @@ -0,0 +1,21 @@ +package dotty.tools.dotc.sbt.interfaces; + +import dotty.tools.dotc.CompilationUnit; + +public interface ProgressCallback { + /** Record that the cancellation signal has been recieved during the Zinc run. */ + default void cancel() {} + + /** Report on if there was a cancellation signal for the current Zinc run. */ + default boolean isCancelled() { return false; } + + /** Record that a unit has started compiling in the given phase. */ + default void informUnitStarting(String phase, CompilationUnit unit) {} + + /** Record the current compilation progress. + * @param current `completedPhaseCount * totalUnits + completedUnitsInCurrPhase + completedLate` + * @param total `totalPhases * totalUnits + totalLate` + * @return true if the compilation should continue (if false, then subsequent calls to `isCancelled()` will return true) + */ + default boolean progress(int current, int total, String currPhase, String nextPhase) { return true; } +} diff --git a/sbt-bridge/src/dotty/tools/xsbt/CompilerBridge.java b/sbt-bridge/src/dotty/tools/xsbt/CompilerBridge.java index 92b8062700c4..6e2095a9df1e 100644 --- a/sbt-bridge/src/dotty/tools/xsbt/CompilerBridge.java +++ b/sbt-bridge/src/dotty/tools/xsbt/CompilerBridge.java @@ -19,6 +19,6 @@ public final class CompilerBridge implements CompilerInterface2 { public void run(VirtualFile[] sources, DependencyChanges changes, String[] options, Output output, AnalysisCallback callback, Reporter delegate, CompileProgress progress, Logger log) { CompilerBridgeDriver driver = new CompilerBridgeDriver(options, output); - driver.run(sources, callback, log, delegate); + driver.run(sources, callback, log, delegate, progress); } } diff --git a/sbt-bridge/src/dotty/tools/xsbt/CompilerBridgeDriver.java b/sbt-bridge/src/dotty/tools/xsbt/CompilerBridgeDriver.java index c5c2e0adaef4..2d54d4e83404 100644 --- a/sbt-bridge/src/dotty/tools/xsbt/CompilerBridgeDriver.java +++ b/sbt-bridge/src/dotty/tools/xsbt/CompilerBridgeDriver.java @@ -21,6 +21,7 @@ import xsbti.Problem; import xsbti.*; import xsbti.compile.Output; +import xsbti.compile.CompileProgress; import java.io.IOException; import java.io.InputStream; @@ -82,7 +83,8 @@ private static void reportMissingFile(DelegatingReporter reporter, SourceFile so reporter.reportBasicWarning(message); } - synchronized public void run(VirtualFile[] sources, AnalysisCallback callback, Logger log, Reporter delegate) { + synchronized public void run( + VirtualFile[] sources, AnalysisCallback callback, Logger log, Reporter delegate, CompileProgress progress) { VirtualFile[] sortedSources = new VirtualFile[sources.length]; System.arraycopy(sources, 0, sortedSources, 0, sources.length); Arrays.sort(sortedSources, (x0, x1) -> x0.id().compareTo(x1.id())); @@ -111,6 +113,8 @@ synchronized public void run(VirtualFile[] sources, AnalysisCallback callback, L return sourceFile.path(); }); + ProgressCallbackImpl progressCallback = new ProgressCallbackImpl(progress); + IncrementalCallback incCallback = new IncrementalCallback(callback, sourceFile -> asVirtualFile(sourceFile, reporter, lookup) ); @@ -121,7 +125,8 @@ synchronized public void run(VirtualFile[] sources, AnalysisCallback callback, L Contexts.Context initialCtx = initCtx() .fresh() .setReporter(reporter) - .setIncCallback(incCallback); + .setIncCallback(incCallback) + .setProgressCallback(progressCallback); Contexts.Context context = setup(args, initialCtx).map(t -> t._2).getOrElse(() -> initialCtx); diff --git a/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java b/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java new file mode 100644 index 000000000000..420f4c02a03b --- /dev/null +++ b/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java @@ -0,0 +1,37 @@ +package dotty.tools.xsbt; + +import dotty.tools.dotc.sbt.interfaces.ProgressCallback; +import dotty.tools.dotc.CompilationUnit; + +import xsbti.compile.CompileProgress; + +public final class ProgressCallbackImpl implements ProgressCallback { + private boolean _cancelled = false; + private final CompileProgress _progress; + + public ProgressCallbackImpl(CompileProgress progress) { + _progress = progress; + } + + @Override + public void cancel() { + _cancelled = true; + } + + @Override + public boolean isCancelled() { + return _cancelled; + } + + @Override + public void informUnitStarting(String phase, CompilationUnit unit) { + _progress.startUnit(phase, unit.source().file().path()); + } + + @Override + public boolean progress(int current, int total, String currPhase, String nextPhase) { + boolean shouldAdvance = _progress.advance(current, total, currPhase, nextPhase); + if (!shouldAdvance) cancel(); + return shouldAdvance; + } +} diff --git a/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala b/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala index 51f10e90f932..5726bf1b7c44 100644 --- a/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala +++ b/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala @@ -124,6 +124,7 @@ class ScalaCompilerForUnitTesting { def compileSrcs(groupedSrcs: List[List[String]]): (Seq[VirtualFile], TestCallback) = { val temp = IO.createTemporaryDirectory val analysisCallback = new TestCallback + val testProgress = new TestCompileProgress val classesDir = new File(temp, "classes") classesDir.mkdir() @@ -148,7 +149,7 @@ class ScalaCompilerForUnitTesting { output, analysisCallback, new TestReporter, - new CompileProgress {}, + testProgress, new TestLogger ) diff --git a/sbt-bridge/test/xsbt/TestCompileProgress.scala b/sbt-bridge/test/xsbt/TestCompileProgress.scala new file mode 100644 index 000000000000..6e385a2ed4bd --- /dev/null +++ b/sbt-bridge/test/xsbt/TestCompileProgress.scala @@ -0,0 +1,5 @@ +package xsbt + +import xsbti.compile.CompileProgress + +class TestCompileProgress extends CompileProgress From df2e033bfdf79ee234015a797ac0d413f515cc7c Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Wed, 27 Sep 2023 17:40:17 +0200 Subject: [PATCH 02/13] Record progress in the current run Test that the callbacks are called with expected values [Cherry-picked 2b7a09e90501299537a439945d749b8e6ccabf70] --- compiler/src/dotty/tools/dotc/Run.scala | 131 ++++++++++- .../src/dotty/tools/dotc/core/Phases.scala | 2 + .../tools/dotc/fromtasty/ReadTasty.scala | 7 +- .../tools/dotc/parsing/ParserPhase.scala | 5 +- .../tools/dotc/transform/init/Checker.scala | 10 +- .../src/dotty/tools/dotc/typer/Namer.scala | 24 +- .../dotty/tools/dotc/typer/TyperPhase.scala | 12 +- compiler/test/dotty/tools/DottyTest.scala | 16 +- .../tools/dotc/sbt/ProgressCallbackTest.scala | 206 ++++++++++++++++++ .../tools/xsbt/ProgressCallbackImpl.java | 2 +- .../tasty/inspector/TastyInspector.scala | 1 + .../scala/quoted/staging/QuoteCompiler.scala | 1 + .../tasty/inspector/TastyInspector.scala | 1 + 13 files changed, 388 insertions(+), 30 deletions(-) create mode 100644 compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala diff --git a/compiler/src/dotty/tools/dotc/Run.scala b/compiler/src/dotty/tools/dotc/Run.scala index 3e7bba86dcf4..632baadcc8cc 100644 --- a/compiler/src/dotty/tools/dotc/Run.scala +++ b/compiler/src/dotty/tools/dotc/Run.scala @@ -12,7 +12,9 @@ import typer.Typer import typer.ImportInfo.withRootImports import Decorators._ import io.AbstractFile -import Phases.unfusedPhases +import Phases.{unfusedPhases, Phase} + +import sbt.interfaces.ProgressCallback import util._ import reporting.{Suppression, Action, Profile, ActiveProfile, NoProfile} @@ -32,6 +34,9 @@ import scala.collection.mutable import scala.util.control.NonFatal import scala.io.Codec +import Run.Progress +import scala.compiletime.uninitialized + /** A compiler run. Exports various methods to compile source files */ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with ConstraintRunInfo { @@ -155,7 +160,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint } /** The source files of all late entered symbols, as a set */ - private var lateFiles = mutable.Set[AbstractFile]() + private val lateFiles = mutable.Set[AbstractFile]() /** A cache for static references to packages and classes */ val staticRefs = util.EqHashMap[Name, Denotation](initialCapacity = 1024) @@ -163,6 +168,43 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint /** Actions that need to be performed at the end of the current compilation run */ private var finalizeActions = mutable.ListBuffer[() => Unit]() + private var _progress: Progress | Null = null // Set if progress reporting is enabled + + /** Only safe to call if progress is being tracked. */ + private inline def trackProgress(using Context)(inline op: Context ?=> Progress => Unit): Unit = + val local = _progress + if local != null then + op(using ctx)(local) + + def doBeginUnit(unit: CompilationUnit)(using Context): Unit = + trackProgress: progress => + progress.informUnitStarting(unit) + + def doAdvanceUnit()(using Context): Unit = + trackProgress: progress => + progress.unitc += 1 // trace that we completed a unit in the current phase + progress.refreshProgress() + + def doAdvanceLate()(using Context): Unit = + trackProgress: progress => + progress.latec += 1 // trace that we completed a late compilation + progress.refreshProgress() + + private def doEnterPhase(currentPhase: Phase)(using Context): Unit = + trackProgress: progress => + progress.enterPhase(currentPhase) + + private def doAdvancePhase(currentPhase: Phase, wasRan: Boolean)(using Context): Unit = + trackProgress: progress => + progress.unitc = 0 // reset unit count in current phase + progress.seen += 1 // trace that we've seen a phase + if wasRan then + // add an extra traversal now that we completed a phase + progress.traversalc += 1 + else + // no phase was ran, remove a traversal from expected total + progress.runnablePhases -= 1 + /** Will be set to true if any of the compiled compilation units contains * a pureFunctions language import. */ @@ -233,13 +275,15 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint if ctx.settings.YnoDoubleBindings.value then ctx.base.checkNoDoubleBindings = true - def runPhases(using Context) = { + def runPhases(allPhases: Array[Phase])(using Context) = { var lastPrintedTree: PrintedTree = NoPrintedTree val profiler = ctx.profiler var phasesWereAdjusted = false - for (phase <- ctx.base.allPhases) - if (phase.isRunnable) + for phase <- allPhases do + doEnterPhase(phase) + val phaseWillRun = phase.isRunnable + if phaseWillRun then Stats.trackTime(s"phase time ms/$phase") { val start = System.currentTimeMillis val profileBefore = profiler.beforePhase(phase) @@ -260,14 +304,21 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint if !Feature.ccEnabledSomewhere then ctx.base.unlinkPhaseAsDenotTransformer(Phases.checkCapturesPhase.prev) ctx.base.unlinkPhaseAsDenotTransformer(Phases.checkCapturesPhase) - + end if + end if + end if + doAdvancePhase(phase, wasRan = phaseWillRun) + end for profiler.finished() } val runCtx = ctx.fresh runCtx.setProfiler(Profiler()) unfusedPhases.foreach(_.initContext(runCtx)) - runPhases(using runCtx) + val fusedPhases = runCtx.base.allPhases + runCtx.withProgressCallback: cb => + _progress = Progress(cb, this, fusedPhases.length) + runPhases(allPhases = fusedPhases)(using runCtx) if (!ctx.reporter.hasErrors) Rewrites.writeBack() suppressions.runFinished(hasErrors = ctx.reporter.hasErrors) @@ -293,10 +344,9 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint .withRootImports def process()(using Context) = - ctx.typer.lateEnterUnit(doTypeCheck => - if typeCheck then - if compiling then finalizeActions += doTypeCheck - else doTypeCheck() + ctx.typer.lateEnterUnit(typeCheck)(doTypeCheck => + if compiling then finalizeActions += doTypeCheck + else doTypeCheck() ) process()(using unitCtx) @@ -399,7 +449,66 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint } object Run { + + /**Computes the next MegaPhase for the given phase.*/ + def nextMegaPhase(phase: Phase)(using Context): Phase = phase.megaPhase.next.megaPhase + + private class Progress(cb: ProgressCallback, private val run: Run, val initialPhases: Int): + private[Run] var runnablePhases: Int = initialPhases // track how many phases we expect to run + private[Run] var unitc: Int = 0 // current unit count in the current phase + private[Run] var latec: Int = 0 // current late unit count + private[Run] var traversalc: Int = 0 // completed traversals over all files + private[Run] var seen: Int = 0 // how many phases we've seen so far + + private var currPhase: Phase = uninitialized // initialized by enterPhase + private var currPhaseName: String = uninitialized // initialized by enterPhase + private var nextPhaseName: String = uninitialized // initialized by enterPhase + + private def phaseNameFor(phase: Phase): String = + if phase.exists then phase.phaseName + else "" + + private[Run] def enterPhase(newPhase: Phase)(using Context): Unit = + if newPhase ne currPhase then + currPhase = newPhase + currPhaseName = phaseNameFor(newPhase) + nextPhaseName = phaseNameFor(Run.nextMegaPhase(newPhase)) + if seen > 0 then + refreshProgress() + + + /** Counts the number of completed full traversals over files, plus the number of units in the current phase */ + private def currentProgress()(using Context): Int = + traversalc * run.files.size + unitc + latec + + /**Total progress is computed as the sum of + * - the number of traversals we expect to make over all files + * - the number of late compilations + */ + private def totalProgress()(using Context): Int = + runnablePhases * run.files.size + run.lateFiles.size + + private def requireInitialized(): Unit = + require((currPhase: Phase | Null) != null, "enterPhase was not called") + + private[Run] def informUnitStarting(unit: CompilationUnit)(using Context): Unit = + requireInitialized() + cb.informUnitStarting(currPhaseName, unit) + + private[Run] def refreshProgress()(using Context): Unit = + requireInitialized() + cb.progress(currentProgress(), totalProgress(), currPhaseName, nextPhaseName) + extension (run: Run | Null) + def beginUnit(unit: CompilationUnit)(using Context): Unit = + if run != null then run.doBeginUnit(unit) + + def advanceUnit()(using Context): Unit = + if run != null then run.doAdvanceUnit() + + def advanceLate()(using Context): Unit = + if run != null then run.doAdvanceLate() + def enrichedErrorMessage: Boolean = if run == null then false else run.myEnrichedErrorMessage def enrichErrorMessage(errorMessage: String)(using Context): String = if run == null then diff --git a/compiler/src/dotty/tools/dotc/core/Phases.scala b/compiler/src/dotty/tools/dotc/core/Phases.scala index 2a3828004525..e1a99f84debc 100644 --- a/compiler/src/dotty/tools/dotc/core/Phases.scala +++ b/compiler/src/dotty/tools/dotc/core/Phases.scala @@ -324,10 +324,12 @@ object Phases { def runOn(units: List[CompilationUnit])(using runCtx: Context): List[CompilationUnit] = units.map { unit => given unitCtx: Context = runCtx.fresh.setPhase(this.start).setCompilationUnit(unit).withRootImports + ctx.run.beginUnit(unit) try run catch case ex: Throwable if !ctx.run.enrichedErrorMessage => println(ctx.run.enrichErrorMessage(s"unhandled exception while running $phaseName on $unit")) throw ex + finally ctx.run.advanceUnit() unitCtx.compilationUnit } diff --git a/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala b/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala index 86ae99b3e0f9..8a0d5f68410a 100644 --- a/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala +++ b/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala @@ -22,7 +22,12 @@ class ReadTasty extends Phase { ctx.settings.fromTasty.value override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] = - withMode(Mode.ReadPositions)(units.flatMap(readTASTY(_))) + withMode(Mode.ReadPositions)(units.flatMap(applyPhase(_))) + + private def applyPhase(unit: CompilationUnit)(using Context): Option[CompilationUnit] = + ctx.run.beginUnit(unit) + try readTASTY(unit) + finally ctx.run.advanceUnit() def readTASTY(unit: CompilationUnit)(using Context): Option[CompilationUnit] = unit match { case unit: TASTYCompilationUnit => diff --git a/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala b/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala index 7caff4996b85..a7360da6ada9 100644 --- a/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala +++ b/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala @@ -43,10 +43,13 @@ class Parser extends Phase { override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] = { val unitContexts = for unit <- units yield + ctx.run.beginUnit(unit) report.inform(s"parsing ${unit.source}") ctx.fresh.setCompilationUnit(unit).withRootImports - unitContexts.foreach(parse(using _)) + for given Context <- unitContexts do + try parse + finally ctx.run.advanceUnit() record("parsedTrees", ast.Trees.ntrees) unitContexts.map(_.compilationUnit) diff --git a/compiler/src/dotty/tools/dotc/transform/init/Checker.scala b/compiler/src/dotty/tools/dotc/transform/init/Checker.scala index 1efb3c88149e..a6966ee84a2d 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Checker.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Checker.scala @@ -18,6 +18,7 @@ import Phases._ import scala.collection.mutable import Semantic._ +import dotty.tools.unsupported class Checker extends Phase: @@ -33,16 +34,17 @@ class Checker extends Phase: override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] = val checkCtx = ctx.fresh.setPhase(this.start) val traverser = new InitTreeTraverser() - units.foreach { unit => traverser.traverse(unit.tpdTree) } + for unit <- units do + checkCtx.run.beginUnit(unit) + try traverser.traverse(unit.tpdTree) + finally ctx.run.advanceUnit() val classes = traverser.getClasses() Semantic.checkClasses(classes)(using checkCtx) units - def run(using Context): Unit = - // ignore, we already called `Semantic.check()` in `runOn` - () + def run(using Context): Unit = unsupported("run") class InitTreeTraverser extends TreeTraverser: private val classes: mutable.ArrayBuffer[ClassSymbol] = new mutable.ArrayBuffer diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index 36ffbd2e64a4..cbc2796c3895 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -722,20 +722,27 @@ class Namer { typer: Typer => * Will call the callback with an implementation of type checking * That will set the tpdTree and root tree for the compilation unit. */ - def lateEnterUnit(typeCheckCB: (() => Unit) => Unit)(using Context) = + def lateEnterUnit(typeCheck: Boolean)(typeCheckCB: (() => Unit) => Unit)(using Context) = val unit = ctx.compilationUnit /** Index symbols in unit.untpdTree with lateCompile flag = true */ def lateEnter()(using Context): Context = val saved = lateCompile lateCompile = true - try index(unit.untpdTree :: Nil) finally lateCompile = saved + try + index(unit.untpdTree :: Nil) + finally + lateCompile = saved + if !typeCheck then ctx.run.advanceLate() /** Set the tpdTree and root tree of the compilation unit */ def lateTypeCheck()(using Context) = - unit.tpdTree = typer.typedExpr(unit.untpdTree) - val phase = new transform.SetRootTree() - phase.run + try + unit.tpdTree = typer.typedExpr(unit.untpdTree) + val phase = new transform.SetRootTree() + phase.run + finally + if typeCheck then ctx.run.advanceLate() unit.untpdTree = if (unit.isJava) new JavaParser(unit.source).parse() @@ -746,9 +753,10 @@ class Namer { typer: Typer => // inline body annotations are set in namer, capturing the current context // we need to prepare the context for inlining. lateEnter() - typeCheckCB { () => - lateTypeCheck() - } + if typeCheck then + typeCheckCB { () => + lateTypeCheck() + } } } end lateEnterUnit diff --git a/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala b/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala index f0218413d6ab..f85ebcb3b1e2 100644 --- a/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala +++ b/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala @@ -63,13 +63,15 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { for unit <- units yield val newCtx0 = ctx.fresh.setPhase(this.start).setCompilationUnit(unit) val newCtx = PrepareInlineable.initContext(newCtx0) + newCtx.run.beginUnit(unit) report.inform(s"typing ${unit.source}") if (addRootImports) newCtx.withRootImports else newCtx - unitContexts.foreach(enterSyms(using _)) + for given Context <- unitContexts do + enterSyms ctx.base.parserPhase match { case p: ParserPhase => @@ -81,9 +83,13 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { case _ => } - unitContexts.foreach(typeCheck(using _)) + for given Context <- unitContexts do + typeCheck + record("total trees after typer", ast.Trees.ntrees) - unitContexts.foreach(javaCheck(using _)) // after typechecking to avoid cycles + for given Context <- unitContexts do + try javaCheck // after typechecking to avoid cycles + finally ctx.run.advanceUnit() val newUnits = unitContexts.map(_.compilationUnit).filterNot(discardAfterTyper) ctx.run.nn.checkSuspendedUnits(newUnits) diff --git a/compiler/test/dotty/tools/DottyTest.scala b/compiler/test/dotty/tools/DottyTest.scala index 54cf0e0c177c..229891806a76 100644 --- a/compiler/test/dotty/tools/DottyTest.scala +++ b/compiler/test/dotty/tools/DottyTest.scala @@ -44,9 +44,14 @@ trait DottyTest extends ContextEscapeDetection { fc.setProperty(ContextDoc, new ContextDocstrings) } + protected def defaultCompiler: Compiler = Compiler() + private def compilerWithChecker(phase: String)(assertion: (tpd.Tree, Context) => Unit) = new Compiler { + + private val baseCompiler = defaultCompiler + override def phases = { - val allPhases = super.phases + val allPhases = baseCompiler.phases val targetPhase = allPhases.flatten.find(p => p.phaseName == phase).get val groupsBefore = allPhases.takeWhile(x => !x.contains(targetPhase)) val lastGroup = allPhases.find(x => x.contains(targetPhase)).get.takeWhile(x => !(x eq targetPhase)) @@ -67,6 +72,15 @@ trait DottyTest extends ContextEscapeDetection { run.runContext } + def checkAfterCompile(checkAfterPhase: String, sources: List[String])(assertion: Context => Unit): Context = { + val c = defaultCompiler + val run = c.newRun + run.compileFromStrings(sources) + val rctx = run.runContext + assertion(rctx) + rctx + } + def checkTypes(source: String, typeStrings: String*)(assertion: (List[Type], Context) => Unit): Unit = checkTypes(source, List(typeStrings.toList)) { (tpess, ctx) => (tpess: @unchecked) match { case List(tpes) => assertion(tpes, ctx) diff --git a/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala b/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala new file mode 100644 index 000000000000..b7845d47b6f9 --- /dev/null +++ b/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala @@ -0,0 +1,206 @@ +package dotty.tools.dotc.sbt + +import dotty.tools.DottyTest +import dotty.tools.dotc.core.Contexts.FreshContext +import dotty.tools.dotc.sbt.ProgressCallbackTest.* + +import org.junit.Assert.* +import org.junit.Test + +import dotty.tools.toOption +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Contexts.ctx +import dotty.tools.dotc.CompilationUnit +import dotty.tools.dotc.Compiler +import dotty.tools.dotc.Run +import dotty.tools.dotc.core.Phases.Phase +import dotty.tools.io.VirtualDirectory + +final class ProgressCallbackTest extends DottyTest: + + @Test + def testCallback: Unit = + val source1 = """class Foo""" + val source2 = """class Bar""" + + inspectProgress(List(source1, source2), terminalPhase = None): progressCallback => + // (1) assert that the way we compute next phase in `Run.doAdvancePhase` is correct + assertNextPhaseIsNext() + + // (1) given correct computation, check that the recorded progression is monotonic + assertMonotonicProgression(progressCallback) + + // (1) given monotonic progression, check that the recorded progression has full coverage + assertFullCoverage(progressCallback) + + // (2) next check that for each unit, we record the expected phases that it should progress through + assertExpectedPhases(progressCallback) + + // (2) therefore we can now cross-reference the recorded progression with the recorded phases per unit + assertTotalUnits(progressCallback) + + // (3) finally, check that the callback was not cancelled + assertFalse(progressCallback.isCancelled) + end testCallback + + // TODO: test lateCompile, test cancellation + + /** Assert that the computed `next` phase matches the real next phase */ + def assertNextPhaseIsNext()(using Context): Unit = + val allPhases = ctx.base.allPhases + for case Array(p1, p2) <- allPhases.sliding(2) do + val p1Next = Run.nextMegaPhase(p1) // used to compute the next phase in `Run.doAdvancePhase` + assertEquals(p1Next.phaseName, p2.phaseName) + + /** Assert that the recorded progression of phases are all in the real progression, and that order is preserved */ + def assertMonotonicProgression(progressCallback: TestProgressCallback)(using Context): Unit = + val allPhasePlan = ctx.base.allPhases + for case List( + PhaseTransition(curr1, next1), + PhaseTransition(curr2, next2) + ) <- progressCallback.progressPhasesFinal.sliding(2) do + val curr1Index = indexFor(allPhasePlan, curr1) + val curr2Index = indexFor(allPhasePlan, curr2) + val next1Index = indexFor(allPhasePlan, next1) + val next2Index = indexFor(allPhasePlan, next2) + assertTrue(s"Phase $curr1 comes before $curr2", curr1Index < curr2Index) + assertTrue(s"Phase $next1 comes before $next2", next1Index < next2Index) + assertTrue(s"Phase $curr1 comes before $next1", curr1Index < next1Index) + assertTrue(s"Phase $curr2 comes before $next2", curr2Index < next2Index) + assertTrue(s"Predicted next phase $next1 was next current $curr2", next1Index == curr2Index) + + /** Assert that the recorded progression of phases contains every phase in the plan */ + def assertFullCoverage(progressCallback: TestProgressCallback)(using Context): Unit = + val (allPhasePlan, expectedCurrPhases, expectedNextPhases) = + val allPhases = ctx.base.allPhases.map(_.phaseName) + val firstPhase = allPhases.head + val expectedCurrPhases = allPhases.toSet + val expectedNextPhases = expectedCurrPhases - firstPhase ++ syntheticNextPhases + (allPhases.toList, expectedCurrPhases, expectedNextPhases) + + for (expectedCurr, recordedCurr) <- allPhasePlan.zip(progressCallback.progressPhasesFinal.map(_.curr)) do + assertEquals(s"Phase $recordedCurr was not expected", expectedCurr, recordedCurr) + + val (seenCurrPhases, seenNextPhases) = + val (currs0, nexts0) = progressCallback.progressPhasesFinal.unzip(Tuple.fromProductTyped) + (currs0.toSet, nexts0.toSet) + + val missingCurrPhases = expectedCurrPhases.diff(seenCurrPhases) + val extraCurrPhases = seenCurrPhases.diff(expectedCurrPhases) + assertTrue(s"these phases were not visited ${missingCurrPhases}", missingCurrPhases.isEmpty) + assertTrue(s"these phases were visited, but not in the real plan ${extraCurrPhases}", extraCurrPhases.isEmpty) + + val missingNextPhases = expectedNextPhases.diff(seenNextPhases) + val extraNextPhases = seenNextPhases.diff(expectedNextPhases) + assertTrue(s"these phases were not planned to visit, but were expected ${missingNextPhases}", missingNextPhases.isEmpty) + assertTrue(s"these phases were planned to visit, but were not in the real plan ${extraNextPhases}", extraNextPhases.isEmpty) + + + /** Assert that the phases recorded per unit match the actual phases ran on them */ + def assertExpectedPhases(progressCallback: TestProgressCallback)(using Context): Unit = + val expectedPhases = runnablePhases() + for (_, visitedPhases) <- progressCallback.unitPhases do + val uniquePhases = visitedPhases.toSet + assertEquals("some phases were visited twice!", visitedPhases.size, uniquePhases.size) + val unvisitedPhases = expectedPhases.filterNot(visitedPhases.contains) + val extraPhases = visitedPhases.filterNot(expectedPhases.contains) + assertTrue(s"these phases were not visited ${unvisitedPhases}", unvisitedPhases.isEmpty) + assertTrue(s"these phases were visited, but not expected ${extraPhases}", extraPhases.isEmpty) + + /** Assert that the number of total units of work matches the number of files * the runnable phases */ + def assertTotalUnits(progressCallback: TestProgressCallback)(using Context): Unit = + val expectedPhases = runnablePhases() + var fileTraversals = 0 // files * phases + for (_, phases) <- progressCallback.unitPhases do + fileTraversals += phases.size + val expectedTotal = fileTraversals + progressCallback.totalEvents match + case Nil => fail("No total events recorded") + case TotalEvent(total, _) :: _ => + assertEquals(expectedTotal, total) + + def inspectProgress(sources: List[String], terminalPhase: Option[String] = Some("typer"))(op: Context ?=> TestProgressCallback => Unit) = + // given Context = getCtx + val sources0 = sources.map(_.linesIterator.map(_.trim.nn).filterNot(_.isEmpty).mkString("\n|").stripMargin) + val terminalPhase0 = terminalPhase.getOrElse(defaultCompiler.phases.last.last.phaseName) + checkAfterCompile(terminalPhase0, sources0) { case given Context => + ctx.progressCallback match + case cb: TestProgressCallback => op(cb) + case _ => + fail(s"Expected TestProgressCallback but got ${ctx.progressCallback}") + ??? + } + + override protected def initializeCtx(fc: FreshContext): Unit = + super.initializeCtx( + fc.setProgressCallback(TestProgressCallback()) + .setSetting(fc.settings.outputDir, new VirtualDirectory("")) + ) + +object ProgressCallbackTest: + + case class TotalEvent(total: Int, atPhase: String) + case class ProgressEvent(curr: Int, total: Int, currPhase: String, nextPhase: String) + case class PhaseTransition(curr: String, next: String) + + def runnablePhases()(using Context) = + ctx.base.allPhases.filter(_.isRunnable).map(_.phaseName).toList + + private val syntheticNextPhases = List("") + + /** Flatten the terminal phases into linear order */ + private val terminalIndices = + syntheticNextPhases.zipWithIndex.toMap + + /** Asserts that the computed phase name exists in the real phase plan */ + def indexFor(allPhasePlan: Array[Phase], phaseName: String): Int = + val i = allPhasePlan.indexWhere(_.phaseName == phaseName) + if i < 0 then // not found in real phase plan + terminalIndices.get(phaseName) match + case Some(index) => allPhasePlan.size + index // append to end of phase plan + case None => + fail(s"Phase $phaseName not found") + -1 + else + i + + final class TestProgressCallback extends interfaces.ProgressCallback: + private var _cancelled: Boolean = false + private var _unitPhases: Map[CompilationUnit, List[String]] = Map.empty + private var _totalEvents: List[TotalEvent] = List.empty + private var _progressPhases: List[PhaseTransition] = List.empty + private var _shouldCancelNow: TestProgressCallback => Boolean = _ => false + + def totalEvents = _totalEvents + def unitPhases = _unitPhases + def progressPhasesFinal = _progressPhases.reverse + + def withCancelNow(f: TestProgressCallback => Boolean): this.type = + _shouldCancelNow = f + this + + override def cancel(): Unit = _cancelled = true + override def isCancelled(): Boolean = _cancelled + + override def informUnitStarting(phase: String, unit: CompilationUnit): Unit = + _unitPhases += (unit -> (unitPhases.getOrElse(unit, Nil) :+ phase)) + + override def progress(current: Int, total: Int, currPhase: String, nextPhase: String): Boolean = + // record the total and current phase whenever the total changes + _totalEvents = _totalEvents match + case Nil => TotalEvent(total, currPhase) :: Nil + case events @ (head :: _) if head.total != total => TotalEvent(total, currPhase) :: events + case events => events + + // record the current and next phase whenever the current phase changes + _progressPhases = _progressPhases match + case all @ PhaseTransition(head, _) :: rest => + if head != currPhase then + PhaseTransition(currPhase, nextPhase) :: all + else + all + case Nil => PhaseTransition(currPhase, nextPhase) :: Nil + + !_shouldCancelNow(this) + +end ProgressCallbackTest diff --git a/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java b/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java index 420f4c02a03b..ce9f7debbfa8 100644 --- a/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java +++ b/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java @@ -6,7 +6,7 @@ import xsbti.compile.CompileProgress; public final class ProgressCallbackImpl implements ProgressCallback { - private boolean _cancelled = false; + private boolean _cancelled = false; // TODO: atomic boolean? private final CompileProgress _progress; public ProgressCallbackImpl(CompileProgress progress) { diff --git a/scaladoc/src/scala/tasty/inspector/TastyInspector.scala b/scaladoc/src/scala/tasty/inspector/TastyInspector.scala index 00aa6c5e3771..14e5f019b433 100644 --- a/scaladoc/src/scala/tasty/inspector/TastyInspector.scala +++ b/scaladoc/src/scala/tasty/inspector/TastyInspector.scala @@ -69,6 +69,7 @@ object TastyInspector: override def phaseName: String = "tastyInspector" override def runOn(units: List[CompilationUnit])(using ctx0: Context): List[CompilationUnit] = + // NOTE: although this is a phase, do not expect this to be ran with an xsbti.CompileProgress val ctx = QuotesCache.init(ctx0.fresh) runOnImpl(units)(using ctx) diff --git a/staging/src/scala/quoted/staging/QuoteCompiler.scala b/staging/src/scala/quoted/staging/QuoteCompiler.scala index eee2dacdc5f5..9fee0e41efd1 100644 --- a/staging/src/scala/quoted/staging/QuoteCompiler.scala +++ b/staging/src/scala/quoted/staging/QuoteCompiler.scala @@ -62,6 +62,7 @@ private class QuoteCompiler extends Compiler: def phaseName: String = "quotedFrontend" override def runOn(units: List[CompilationUnit])(implicit ctx: Context): List[CompilationUnit] = + // NOTE: although this is a phase, there is no need to track xsbti.CompileProgress here. units.flatMap { case exprUnit: ExprCompilationUnit => val ctx1 = ctx.fresh.setPhase(this.start).setCompilationUnit(exprUnit) diff --git a/tasty-inspector/src/scala/tasty/inspector/TastyInspector.scala b/tasty-inspector/src/scala/tasty/inspector/TastyInspector.scala index 4c6440530ba2..e70d2d4f6dc5 100644 --- a/tasty-inspector/src/scala/tasty/inspector/TastyInspector.scala +++ b/tasty-inspector/src/scala/tasty/inspector/TastyInspector.scala @@ -66,6 +66,7 @@ object TastyInspector: override def phaseName: String = "tastyInspector" override def runOn(units: List[CompilationUnit])(using ctx0: Context): List[CompilationUnit] = + // NOTE: although this is a phase, do not expect this to be ran with an xsbti.CompileProgress val ctx = QuotesCache.init(ctx0.fresh) runOnImpl(units)(using ctx) From 8a79d07d76ed2e91476924576e5fb880ac6623f3 Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Thu, 19 Oct 2023 17:50:09 +0200 Subject: [PATCH 03/13] trace subphases [Cherry-picked ef9fabc8425417427094ce81fe7f3c4f29c4e75e] --- compiler/src/dotty/tools/dotc/Run.scala | 89 +++++++++++++------ .../src/dotty/tools/dotc/core/Phases.scala | 11 ++- .../tools/dotc/fromtasty/ReadTasty.scala | 2 +- .../tools/dotc/parsing/ParserPhase.scala | 5 +- .../tools/dotc/transform/init/Checker.scala | 2 +- .../dotty/tools/dotc/typer/TyperPhase.scala | 22 +++-- .../tools/dotc/sbt/ProgressCallbackTest.scala | 70 +++++++-------- 7 files changed, 126 insertions(+), 75 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/Run.scala b/compiler/src/dotty/tools/dotc/Run.scala index 632baadcc8cc..b3f059ab2735 100644 --- a/compiler/src/dotty/tools/dotc/Run.scala +++ b/compiler/src/dotty/tools/dotc/Run.scala @@ -176,13 +176,13 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint if local != null then op(using ctx)(local) - def doBeginUnit(unit: CompilationUnit)(using Context): Unit = + def doBeginUnit()(using Context): Unit = trackProgress: progress => - progress.informUnitStarting(unit) + progress.informUnitStarting(ctx.compilationUnit) def doAdvanceUnit()(using Context): Unit = trackProgress: progress => - progress.unitc += 1 // trace that we completed a unit in the current phase + progress.unitc += 1 // trace that we completed a unit in the current (sub)phase progress.refreshProgress() def doAdvanceLate()(using Context): Unit = @@ -196,14 +196,23 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint private def doAdvancePhase(currentPhase: Phase, wasRan: Boolean)(using Context): Unit = trackProgress: progress => - progress.unitc = 0 // reset unit count in current phase - progress.seen += 1 // trace that we've seen a phase + progress.unitc = 0 // reset unit count in current (sub)phase + progress.subtraversalc = 0 // reset subphase index to initial + progress.seen += 1 // trace that we've seen a (sub)phase if wasRan then - // add an extra traversal now that we completed a phase + // add an extra traversal now that we completed a (sub)phase progress.traversalc += 1 else - // no phase was ran, remove a traversal from expected total - progress.runnablePhases -= 1 + // no subphases were ran, remove traversals from expected total + progress.totalTraversals -= currentPhase.traversals + + private def doAdvanceSubPhase()(using Context): Unit = + trackProgress: progress => + progress.unitc = 0 // reset unit count in current (sub)phase + progress.seen += 1 // trace that we've seen a (sub)phase + progress.traversalc += 1 // add an extra traversal now that we completed a (sub)phase + progress.subtraversalc += 1 // record that we've seen a subphase + progress.tickSubphase() /** Will be set to true if any of the compiled compilation units contains * a pureFunctions language import. @@ -317,7 +326,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint unfusedPhases.foreach(_.initContext(runCtx)) val fusedPhases = runCtx.base.allPhases runCtx.withProgressCallback: cb => - _progress = Progress(cb, this, fusedPhases.length) + _progress = Progress(cb, this, fusedPhases.map(_.traversals).sum) runPhases(allPhases = fusedPhases)(using runCtx) if (!ctx.reporter.hasErrors) Rewrites.writeBack() @@ -450,31 +459,52 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint object Run { - /**Computes the next MegaPhase for the given phase.*/ - def nextMegaPhase(phase: Phase)(using Context): Phase = phase.megaPhase.next.megaPhase + class SubPhases(val phase: Phase): + require(phase.exists) + + val all = IArray.from(phase.subPhases.map(sub => s"${phase.phaseName} ($sub)")) + + def next(using Context): Option[SubPhases] = + val next0 = phase.megaPhase.next.megaPhase + if next0.exists then Some(SubPhases(next0)) + else None - private class Progress(cb: ProgressCallback, private val run: Run, val initialPhases: Int): - private[Run] var runnablePhases: Int = initialPhases // track how many phases we expect to run - private[Run] var unitc: Int = 0 // current unit count in the current phase + def subPhase(index: Int) = + if index < all.size then all(index) + else phase.phaseName + + private class Progress(cb: ProgressCallback, private val run: Run, val initialTraversals: Int): + private[Run] var totalTraversals: Int = initialTraversals // track how many phases we expect to run + private[Run] var unitc: Int = 0 // current unit count in the current (sub)phase private[Run] var latec: Int = 0 // current late unit count private[Run] var traversalc: Int = 0 // completed traversals over all files + private[Run] var subtraversalc: Int = 0 // completed subphases in the current phase private[Run] var seen: Int = 0 // how many phases we've seen so far private var currPhase: Phase = uninitialized // initialized by enterPhase + private var subPhases: SubPhases = uninitialized // initialized by enterPhase private var currPhaseName: String = uninitialized // initialized by enterPhase private var nextPhaseName: String = uninitialized // initialized by enterPhase - private def phaseNameFor(phase: Phase): String = - if phase.exists then phase.phaseName - else "" - + /** Enter into a new real phase, setting the current and next (sub)phases */ private[Run] def enterPhase(newPhase: Phase)(using Context): Unit = if newPhase ne currPhase then currPhase = newPhase - currPhaseName = phaseNameFor(newPhase) - nextPhaseName = phaseNameFor(Run.nextMegaPhase(newPhase)) - if seen > 0 then - refreshProgress() + subPhases = SubPhases(newPhase) + tickSubphase() + + /** Compute the current (sub)phase name and next (sub)phase name */ + private[Run] def tickSubphase()(using Context): Unit = + val index = subtraversalc + val s = subPhases + currPhaseName = s.subPhase(index) + nextPhaseName = + if index + 1 < s.all.size then s.subPhase(index + 1) + else s.next match + case None => "" + case Some(next0) => next0.subPhase(0) + if seen > 0 then + refreshProgress() /** Counts the number of completed full traversals over files, plus the number of units in the current phase */ @@ -486,26 +516,35 @@ object Run { * - the number of late compilations */ private def totalProgress()(using Context): Int = - runnablePhases * run.files.size + run.lateFiles.size + totalTraversals * run.files.size + run.lateFiles.size private def requireInitialized(): Unit = require((currPhase: Phase | Null) != null, "enterPhase was not called") + /** trace that we are beginning a unit in the current (sub)phase */ private[Run] def informUnitStarting(unit: CompilationUnit)(using Context): Unit = requireInitialized() cb.informUnitStarting(currPhaseName, unit) + /** trace the current progress out of the total, in the current (sub)phase, reporting the next (sub)phase */ private[Run] def refreshProgress()(using Context): Unit = requireInitialized() cb.progress(currentProgress(), totalProgress(), currPhaseName, nextPhaseName) extension (run: Run | Null) - def beginUnit(unit: CompilationUnit)(using Context): Unit = - if run != null then run.doBeginUnit(unit) + /** record that the current phase has begun for the compilation unit of the current Context */ + def beginUnit()(using Context): Unit = + if run != null then run.doBeginUnit() + + /** advance the unit count and record progress in the current phase */ def advanceUnit()(using Context): Unit = if run != null then run.doAdvanceUnit() + def advanceSubPhase()(using Context): Unit = + if run != null then run.doAdvanceSubPhase() + + /** advance the late count and record progress in the current phase */ def advanceLate()(using Context): Unit = if run != null then run.doAdvanceLate() diff --git a/compiler/src/dotty/tools/dotc/core/Phases.scala b/compiler/src/dotty/tools/dotc/core/Phases.scala index e1a99f84debc..4743033ee0e5 100644 --- a/compiler/src/dotty/tools/dotc/core/Phases.scala +++ b/compiler/src/dotty/tools/dotc/core/Phases.scala @@ -317,6 +317,10 @@ object Phases { /** List of names of phases that should precede this phase */ def runsAfter: Set[String] = Set.empty + /** for purposes of progress tracking, overridden in TyperPhase */ + def subPhases: List[String] = Nil + final def traversals: Int = if subPhases.isEmpty then 1 else subPhases.length + /** @pre `isRunnable` returns true */ def run(using Context): Unit @@ -324,7 +328,7 @@ object Phases { def runOn(units: List[CompilationUnit])(using runCtx: Context): List[CompilationUnit] = units.map { unit => given unitCtx: Context = runCtx.fresh.setPhase(this.start).setCompilationUnit(unit).withRootImports - ctx.run.beginUnit(unit) + ctx.run.beginUnit() try run catch case ex: Throwable if !ctx.run.enrichedErrorMessage => println(ctx.run.enrichErrorMessage(s"unhandled exception while running $phaseName on $unit")) @@ -438,12 +442,15 @@ object Phases { final def iterator: Iterator[Phase] = Iterator.iterate(this)(_.next) takeWhile (_.hasNext) - final def monitor(doing: String)(body: => Unit)(using Context): Unit = + /** run the body as one iteration of a (sub)phase (see Run.Progress), Enrich crash messages */ + final def monitor(doing: String)(body: Context ?=> Unit)(using Context): Unit = + ctx.run.beginUnit() try body catch case NonFatal(ex) if !ctx.run.enrichedErrorMessage => report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing ${ctx.compilationUnit}")) throw ex + finally ctx.run.advanceUnit() override def toString: String = phaseName } diff --git a/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala b/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala index 8a0d5f68410a..82fe1ef13c10 100644 --- a/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala +++ b/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala @@ -25,7 +25,7 @@ class ReadTasty extends Phase { withMode(Mode.ReadPositions)(units.flatMap(applyPhase(_))) private def applyPhase(unit: CompilationUnit)(using Context): Option[CompilationUnit] = - ctx.run.beginUnit(unit) + ctx.run.beginUnit() try readTASTY(unit) finally ctx.run.advanceUnit() diff --git a/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala b/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala index a7360da6ada9..3b23847db7f5 100644 --- a/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala +++ b/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala @@ -43,13 +43,12 @@ class Parser extends Phase { override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] = { val unitContexts = for unit <- units yield - ctx.run.beginUnit(unit) report.inform(s"parsing ${unit.source}") ctx.fresh.setCompilationUnit(unit).withRootImports for given Context <- unitContexts do - try parse - finally ctx.run.advanceUnit() + parse + record("parsedTrees", ast.Trees.ntrees) unitContexts.map(_.compilationUnit) diff --git a/compiler/src/dotty/tools/dotc/transform/init/Checker.scala b/compiler/src/dotty/tools/dotc/transform/init/Checker.scala index a6966ee84a2d..92069f834cff 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Checker.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Checker.scala @@ -35,7 +35,7 @@ class Checker extends Phase: val checkCtx = ctx.fresh.setPhase(this.start) val traverser = new InitTreeTraverser() for unit <- units do - checkCtx.run.beginUnit(unit) + checkCtx.run.beginUnit() try traverser.traverse(unit.tpdTree) finally ctx.run.advanceUnit() val classes = traverser.getClasses() diff --git a/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala b/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala index f85ebcb3b1e2..a15ab8afee39 100644 --- a/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala +++ b/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala @@ -58,20 +58,25 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { protected def discardAfterTyper(unit: CompilationUnit)(using Context): Boolean = unit.isJava || unit.suspended + /** Keep synchronised with `monitor` subcalls */ + override def subPhases: List[String] = List("indexing", "typechecking", "checking java") + override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] = val unitContexts = for unit <- units yield val newCtx0 = ctx.fresh.setPhase(this.start).setCompilationUnit(unit) val newCtx = PrepareInlineable.initContext(newCtx0) - newCtx.run.beginUnit(unit) report.inform(s"typing ${unit.source}") if (addRootImports) newCtx.withRootImports else newCtx - for given Context <- unitContexts do - enterSyms + try + for given Context <- unitContexts do + enterSyms + finally + ctx.run.advanceSubPhase() // tick from "typer (indexing)" to "typer (typechecking)" ctx.base.parserPhase match { case p: ParserPhase => @@ -83,13 +88,16 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { case _ => } - for given Context <- unitContexts do - typeCheck + try + for given Context <- unitContexts do + typeCheck + finally + ctx.run.advanceSubPhase() // tick from "typer (typechecking)" to "typer (java checking)" record("total trees after typer", ast.Trees.ntrees) + for given Context <- unitContexts do - try javaCheck // after typechecking to avoid cycles - finally ctx.run.advanceUnit() + javaCheck // after typechecking to avoid cycles val newUnits = unitContexts.map(_.compilationUnit).filterNot(discardAfterTyper) ctx.run.nn.checkSuspendedUnits(newUnits) diff --git a/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala b/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala index b7845d47b6f9..c2ab4c41318c 100644 --- a/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala +++ b/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala @@ -15,6 +15,7 @@ import dotty.tools.dotc.Compiler import dotty.tools.dotc.Run import dotty.tools.dotc.core.Phases.Phase import dotty.tools.io.VirtualDirectory +import dotty.tools.dotc.NoCompilationUnit final class ProgressCallbackTest extends DottyTest: @@ -49,30 +50,30 @@ final class ProgressCallbackTest extends DottyTest: def assertNextPhaseIsNext()(using Context): Unit = val allPhases = ctx.base.allPhases for case Array(p1, p2) <- allPhases.sliding(2) do - val p1Next = Run.nextMegaPhase(p1) // used to compute the next phase in `Run.doAdvancePhase` - assertEquals(p1Next.phaseName, p2.phaseName) + val p1Next = Run.SubPhases(p1).next.get.phase // used to compute the next phase in `Run.doAdvancePhase` + assertEquals(p1Next, p2) /** Assert that the recorded progression of phases are all in the real progression, and that order is preserved */ def assertMonotonicProgression(progressCallback: TestProgressCallback)(using Context): Unit = - val allPhasePlan = ctx.base.allPhases + val allPhasePlan = ctx.base.allPhases.flatMap(asSubphases) ++ syntheticNextPhases for case List( PhaseTransition(curr1, next1), PhaseTransition(curr2, next2) ) <- progressCallback.progressPhasesFinal.sliding(2) do - val curr1Index = indexFor(allPhasePlan, curr1) - val curr2Index = indexFor(allPhasePlan, curr2) - val next1Index = indexFor(allPhasePlan, next1) - val next2Index = indexFor(allPhasePlan, next2) - assertTrue(s"Phase $curr1 comes before $curr2", curr1Index < curr2Index) - assertTrue(s"Phase $next1 comes before $next2", next1Index < next2Index) - assertTrue(s"Phase $curr1 comes before $next1", curr1Index < next1Index) - assertTrue(s"Phase $curr2 comes before $next2", curr2Index < next2Index) - assertTrue(s"Predicted next phase $next1 was next current $curr2", next1Index == curr2Index) + val curr1Index = indexOrFail(allPhasePlan, curr1) + val curr2Index = indexOrFail(allPhasePlan, curr2) + val next1Index = indexOrFail(allPhasePlan, next1) + val next2Index = indexOrFail(allPhasePlan, next2) + assertTrue(s"Phase `$curr1` did not come before `$curr2`", curr1Index < curr2Index) + assertTrue(s"Phase `$next1` did not come before `$next2`", next1Index < next2Index) + assertTrue(s"Phase `$curr1` did not come before `$next1`", curr1Index < next1Index) + assertTrue(s"Phase `$curr2` did not come before `$next2`", curr2Index < next2Index) + assertTrue(s"Predicted next phase `$next1` didn't match the following current `$curr2`", next1Index == curr2Index) /** Assert that the recorded progression of phases contains every phase in the plan */ def assertFullCoverage(progressCallback: TestProgressCallback)(using Context): Unit = val (allPhasePlan, expectedCurrPhases, expectedNextPhases) = - val allPhases = ctx.base.allPhases.map(_.phaseName) + val allPhases = ctx.base.allPhases.flatMap(asSubphases) val firstPhase = allPhases.head val expectedCurrPhases = allPhases.toSet val expectedNextPhases = expectedCurrPhases - firstPhase ++ syntheticNextPhases @@ -98,22 +99,23 @@ final class ProgressCallbackTest extends DottyTest: /** Assert that the phases recorded per unit match the actual phases ran on them */ def assertExpectedPhases(progressCallback: TestProgressCallback)(using Context): Unit = - val expectedPhases = runnablePhases() - for (_, visitedPhases) <- progressCallback.unitPhases do + val expectedPhases = runnablePhases().flatMap(asSubphases) + for (unit, visitedPhases) <- progressCallback.unitPhases do val uniquePhases = visitedPhases.toSet - assertEquals("some phases were visited twice!", visitedPhases.size, uniquePhases.size) + assert(unit != NoCompilationUnit, s"unexpected NoCompilationUnit for phases $uniquePhases") + val duplicatePhases = visitedPhases.view.groupBy(identity).values.filter(_.size > 1).map(_.head) + assertEquals(s"some phases were visited twice for $unit! ${duplicatePhases.toList}", visitedPhases.size, uniquePhases.size) val unvisitedPhases = expectedPhases.filterNot(visitedPhases.contains) val extraPhases = visitedPhases.filterNot(expectedPhases.contains) - assertTrue(s"these phases were not visited ${unvisitedPhases}", unvisitedPhases.isEmpty) - assertTrue(s"these phases were visited, but not expected ${extraPhases}", extraPhases.isEmpty) + assertTrue(s"these phases were not visited for $unit ${unvisitedPhases}", unvisitedPhases.isEmpty) + assertTrue(s"these phases were visited for $unit, but not expected ${extraPhases}", extraPhases.isEmpty) /** Assert that the number of total units of work matches the number of files * the runnable phases */ def assertTotalUnits(progressCallback: TestProgressCallback)(using Context): Unit = - val expectedPhases = runnablePhases() var fileTraversals = 0 // files * phases for (_, phases) <- progressCallback.unitPhases do fileTraversals += phases.size - val expectedTotal = fileTraversals + val expectedTotal = fileTraversals // assume that no late enters occur progressCallback.totalEvents match case Nil => fail("No total events recorded") case TotalEvent(total, _) :: _ => @@ -143,26 +145,22 @@ object ProgressCallbackTest: case class ProgressEvent(curr: Int, total: Int, currPhase: String, nextPhase: String) case class PhaseTransition(curr: String, next: String) - def runnablePhases()(using Context) = - ctx.base.allPhases.filter(_.isRunnable).map(_.phaseName).toList + def asSubphases(phase: Phase): IArray[String] = + val subPhases = Run.SubPhases(phase).all + if subPhases.isEmpty then IArray(phase.phaseName) + else subPhases - private val syntheticNextPhases = List("") + def runnablePhases()(using Context): IArray[Phase] = + IArray.from(ctx.base.allPhases.filter(_.isRunnable)) - /** Flatten the terminal phases into linear order */ - private val terminalIndices = - syntheticNextPhases.zipWithIndex.toMap + private val syntheticNextPhases = List("") /** Asserts that the computed phase name exists in the real phase plan */ - def indexFor(allPhasePlan: Array[Phase], phaseName: String): Int = - val i = allPhasePlan.indexWhere(_.phaseName == phaseName) - if i < 0 then // not found in real phase plan - terminalIndices.get(phaseName) match - case Some(index) => allPhasePlan.size + index // append to end of phase plan - case None => - fail(s"Phase $phaseName not found") - -1 - else - i + def indexOrFail(allPhasePlan: Array[String], phaseName: String): Int = + val i = allPhasePlan.indexOf(phaseName) + if i < 0 then + fail(s"Phase $phaseName not found") + i final class TestProgressCallback extends interfaces.ProgressCallback: private var _cancelled: Boolean = false From 4acd64699a56efaac748b08d5889902e9eaaba05 Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Fri, 20 Oct 2023 16:44:15 +0200 Subject: [PATCH 04/13] change encoding MegaPhase name in progress tracking also add sbt-bridge test for CompileProgress [Cherry-picked c0190c2585b18f4a8df6fa4711d0f753e873f7a2] --- compiler/src/dotty/tools/dotc/Run.scala | 10 +++- .../tools/dotc/transform/MegaPhase.scala | 6 ++ .../xsbt/CompileProgressSpecification.scala | 56 +++++++++++++++++++ .../xsbt/ExtractUsedNamesSpecification.scala | 3 +- .../xsbt/ScalaCompilerForUnitTesting.scala | 37 ++++++++---- .../test/xsbt/TestCompileProgress.scala | 5 -- .../test/xsbti/TestCompileProgress.scala | 30 ++++++++++ 7 files changed, 129 insertions(+), 18 deletions(-) create mode 100644 sbt-bridge/test/xsbt/CompileProgressSpecification.scala delete mode 100644 sbt-bridge/test/xsbt/TestCompileProgress.scala create mode 100644 sbt-bridge/test/xsbti/TestCompileProgress.scala diff --git a/compiler/src/dotty/tools/dotc/Run.scala b/compiler/src/dotty/tools/dotc/Run.scala index b3f059ab2735..5a033b664652 100644 --- a/compiler/src/dotty/tools/dotc/Run.scala +++ b/compiler/src/dotty/tools/dotc/Run.scala @@ -36,6 +36,7 @@ import scala.io.Codec import Run.Progress import scala.compiletime.uninitialized +import dotty.tools.dotc.transform.MegaPhase /** A compiler run. Exports various methods to compile source files */ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with ConstraintRunInfo { @@ -462,7 +463,11 @@ object Run { class SubPhases(val phase: Phase): require(phase.exists) - val all = IArray.from(phase.subPhases.map(sub => s"${phase.phaseName} ($sub)")) + private def baseName: String = phase match + case phase: MegaPhase => phase.shortPhaseName + case phase => phase.phaseName + + val all = IArray.from(phase.subPhases.map(sub => s"$baseName ($sub)")) def next(using Context): Option[SubPhases] = val next0 = phase.megaPhase.next.megaPhase @@ -471,7 +476,8 @@ object Run { def subPhase(index: Int) = if index < all.size then all(index) - else phase.phaseName + else baseName + private class Progress(cb: ProgressCallback, private val run: Run, val initialTraversals: Int): private[Run] var totalTraversals: Int = initialTraversals // track how many phases we expect to run diff --git a/compiler/src/dotty/tools/dotc/transform/MegaPhase.scala b/compiler/src/dotty/tools/dotc/transform/MegaPhase.scala index 58c3cd7c65ed..fe70a1659036 100644 --- a/compiler/src/dotty/tools/dotc/transform/MegaPhase.scala +++ b/compiler/src/dotty/tools/dotc/transform/MegaPhase.scala @@ -145,6 +145,12 @@ class MegaPhase(val miniPhases: Array[MiniPhase]) extends Phase { if (miniPhases.length == 1) miniPhases(0).phaseName else miniPhases.map(_.phaseName).mkString("MegaPhase{", ", ", "}") + /** Used in progress reporting to avoid super long phase names, also the precision is not so important here */ + lazy val shortPhaseName: String = + if (miniPhases.length == 1) miniPhases(0).phaseName + else + s"MegaPhase{${miniPhases.head.phaseName},...,${miniPhases.last.phaseName}}" + private var relaxedTypingCache: Boolean = _ private var relaxedTypingKnown = false diff --git a/sbt-bridge/test/xsbt/CompileProgressSpecification.scala b/sbt-bridge/test/xsbt/CompileProgressSpecification.scala new file mode 100644 index 000000000000..2297f8bb441e --- /dev/null +++ b/sbt-bridge/test/xsbt/CompileProgressSpecification.scala @@ -0,0 +1,56 @@ +package xsbt + +import org.junit.{ Test, Ignore } +import org.junit.Assert._ + +/**Only does some rudimentary checks to assert compat with sbt. + * More thorough tests are found in compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala + */ +class CompileProgressSpecification { + + @Test + def multipleFilesVisitSamePhases = { + val srcA = """class A""" + val srcB = """class B""" + val compilerForTesting = new ScalaCompilerForUnitTesting + val Seq(phasesA, phasesB) = compilerForTesting.extractEnteredPhases(srcA, srcB) + assertTrue("expected some phases, was empty", phasesA.nonEmpty) + assertEquals(phasesA, phasesB) + } + + @Test + def multipleFiles = { + val srcA = """class A""" + val srcB = """class B""" + val compilerForTesting = new ScalaCompilerForUnitTesting + val allPhases = compilerForTesting.extractProgressPhases(srcA, srcB) + assertTrue("expected some phases, was empty", allPhases.nonEmpty) + val someExpectedPhases = // just check some "fundamental" phases, don't put all phases to avoid brittleness + Set( + "parser", + "typer (indexing)", "typer (typechecking)", "typer (checking java)", + "sbt-deps", + "extractSemanticDB", + "posttyper", + "sbt-api", + "SetRootTree", + "pickler", + "inlining", + "postInlining", + "staging", + "splicing", + "pickleQuotes", + "MegaPhase{pruneErasedDefs,...,arrayConstructors}", + "erasure", + "constructors", + "genSJSIR", + "genBCode" + ) + val missingExpectedPhases = someExpectedPhases -- allPhases.toSet + val msgIfMissing = + s"missing expected phases: $missingExpectedPhases. " + + s"Either the compiler phases changed, or the encoding of Run.SubPhases.subPhase" + assertTrue(msgIfMissing, missingExpectedPhases.isEmpty) + } + +} diff --git a/sbt-bridge/test/xsbt/ExtractUsedNamesSpecification.scala b/sbt-bridge/test/xsbt/ExtractUsedNamesSpecification.scala index 819bedec3cbc..2b2b7d26c716 100644 --- a/sbt-bridge/test/xsbt/ExtractUsedNamesSpecification.scala +++ b/sbt-bridge/test/xsbt/ExtractUsedNamesSpecification.scala @@ -1,6 +1,7 @@ package xsbt import xsbti.UseScope +import ScalaCompilerForUnitTesting.Callbacks import org.junit.{ Test, Ignore } import org.junit.Assert._ @@ -226,7 +227,7 @@ class ExtractUsedNamesSpecification { def findPatMatUsages(in: String): Set[String] = { val compilerForTesting = new ScalaCompilerForUnitTesting - val (_, callback) = + val (_, Callbacks(callback, _)) = compilerForTesting.compileSrcs(List(List(sealedClass, in))) val clientNames = callback.usedNamesAndScopes.view.filterKeys(!_.startsWith("base.")) diff --git a/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala b/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala index 5726bf1b7c44..520b7f7053da 100644 --- a/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala +++ b/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala @@ -13,6 +13,10 @@ import dotty.tools.io.PlainFile.toPlainFile import dotty.tools.xsbt.CompilerBridge import TestCallback.ExtractedClassDependencies +import ScalaCompilerForUnitTesting.Callbacks + +object ScalaCompilerForUnitTesting: + case class Callbacks(analysis: TestCallback, progress: TestCompileProgress) /** * Provides common functionality needed for unit tests that require compiling @@ -20,12 +24,23 @@ import TestCallback.ExtractedClassDependencies */ class ScalaCompilerForUnitTesting { + def extractEnteredPhases(srcs: String*): Seq[List[String]] = { + val (tempSrcFiles, Callbacks(_, testProgress)) = compileSrcs(srcs: _*) + val run = testProgress.runs.head + tempSrcFiles.map(src => run.unitPhases(src.id)) + } + + def extractProgressPhases(srcs: String*): List[String] = { + val (_, Callbacks(_, testProgress)) = compileSrcs(srcs: _*) + testProgress.runs.head.phases + } + /** * Compiles given source code using Scala compiler and returns API representation * extracted by ExtractAPI class. */ def extractApiFromSrc(src: String): Seq[ClassLike] = { - val (Seq(tempSrcFile), analysisCallback) = compileSrcs(src) + val (Seq(tempSrcFile), Callbacks(analysisCallback, _)) = compileSrcs(src) analysisCallback.apis(tempSrcFile) } @@ -34,7 +49,7 @@ class ScalaCompilerForUnitTesting { * extracted by ExtractAPI class. */ def extractApisFromSrcs(srcs: List[String]*): Seq[Seq[ClassLike]] = { - val (tempSrcFiles, analysisCallback) = compileSrcs(srcs.toList) + val (tempSrcFiles, Callbacks(analysisCallback, _)) = compileSrcs(srcs.toList) tempSrcFiles.map(analysisCallback.apis) } @@ -52,7 +67,7 @@ class ScalaCompilerForUnitTesting { assertDefaultScope: Boolean = true ): Map[String, Set[String]] = { // we drop temp src file corresponding to the definition src file - val (Seq(_, tempSrcFile), analysisCallback) = compileSrcs(definitionSrc, actualSrc) + val (Seq(_, tempSrcFile), Callbacks(analysisCallback, _)) = compileSrcs(definitionSrc, actualSrc) if (assertDefaultScope) for { (className, used) <- analysisCallback.usedNamesAndScopes @@ -70,7 +85,7 @@ class ScalaCompilerForUnitTesting { * Only the names used in the last src file are returned. */ def extractUsedNamesFromSrc(sources: String*): Map[String, Set[String]] = { - val (srcFiles, analysisCallback) = compileSrcs(sources: _*) + val (srcFiles, Callbacks(analysisCallback, _)) = compileSrcs(sources: _*) srcFiles .map { srcFile => val classesInSrc = analysisCallback.classNames(srcFile).map(_._1) @@ -92,7 +107,7 @@ class ScalaCompilerForUnitTesting { * file system-independent way of testing dependencies between source code "files". */ def extractDependenciesFromSrcs(srcs: List[List[String]]): ExtractedClassDependencies = { - val (_, testCallback) = compileSrcs(srcs) + val (_, Callbacks(testCallback, _)) = compileSrcs(srcs) val memberRefDeps = testCallback.classDependencies collect { case (target, src, DependencyByMemberRef) => (src, target) @@ -121,7 +136,7 @@ class ScalaCompilerForUnitTesting { * The sequence of temporary files corresponding to passed snippets and analysis * callback is returned as a result. */ - def compileSrcs(groupedSrcs: List[List[String]]): (Seq[VirtualFile], TestCallback) = { + def compileSrcs(groupedSrcs: List[List[String]]): (Seq[VirtualFile], Callbacks) = { val temp = IO.createTemporaryDirectory val analysisCallback = new TestCallback val testProgress = new TestCompileProgress @@ -130,8 +145,8 @@ class ScalaCompilerForUnitTesting { val bridge = new CompilerBridge - val files = for ((compilationUnit, unitId) <- groupedSrcs.zipWithIndex) yield { - val srcFiles = compilationUnit.toSeq.zipWithIndex.map { + val files = for ((compilationUnits, unitId) <- groupedSrcs.zipWithIndex) yield { + val srcFiles = compilationUnits.toSeq.zipWithIndex.map { (src, i) => val fileName = s"Test-$unitId-$i.scala" prepareSrcFile(temp, fileName, src) @@ -153,12 +168,14 @@ class ScalaCompilerForUnitTesting { new TestLogger ) + testProgress.completeRun() + srcFiles } - (files.flatten.toSeq, analysisCallback) + (files.flatten.toSeq, Callbacks(analysisCallback, testProgress)) } - def compileSrcs(srcs: String*): (Seq[VirtualFile], TestCallback) = { + def compileSrcs(srcs: String*): (Seq[VirtualFile], Callbacks) = { compileSrcs(List(srcs.toList)) } diff --git a/sbt-bridge/test/xsbt/TestCompileProgress.scala b/sbt-bridge/test/xsbt/TestCompileProgress.scala deleted file mode 100644 index 6e385a2ed4bd..000000000000 --- a/sbt-bridge/test/xsbt/TestCompileProgress.scala +++ /dev/null @@ -1,5 +0,0 @@ -package xsbt - -import xsbti.compile.CompileProgress - -class TestCompileProgress extends CompileProgress diff --git a/sbt-bridge/test/xsbti/TestCompileProgress.scala b/sbt-bridge/test/xsbti/TestCompileProgress.scala new file mode 100644 index 000000000000..9753a6e15b4c --- /dev/null +++ b/sbt-bridge/test/xsbti/TestCompileProgress.scala @@ -0,0 +1,30 @@ +package xsbti + +import xsbti.compile.CompileProgress + +import scala.collection.mutable + +class TestCompileProgress extends CompileProgress: + class Run: + private[TestCompileProgress] val _phases: mutable.Set[String] = mutable.LinkedHashSet.empty + private[TestCompileProgress] val _unitPhases: mutable.Map[String, mutable.Set[String]] = mutable.LinkedHashMap.empty + + def phases: List[String] = _phases.toList + def unitPhases: collection.MapView[String, List[String]] = _unitPhases.view.mapValues(_.toList) + + private val _runs: mutable.ListBuffer[Run] = mutable.ListBuffer.empty + private var _currentRun: Run = new Run + + def runs: List[Run] = _runs.toList + + def completeRun(): Unit = + _runs += _currentRun + _currentRun = new Run + + override def startUnit(phase: String, unitPath: String): Unit = + _currentRun._unitPhases.getOrElseUpdate(unitPath, mutable.LinkedHashSet.empty) += phase + + override def advance(current: Int, total: Int, prevPhase: String, nextPhase: String): Boolean = + _currentRun._phases += prevPhase + _currentRun._phases += nextPhase + true From ba3f0b0cf479a88bd67bfd1b7570db28b8f83851 Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Fri, 20 Oct 2023 17:28:54 +0200 Subject: [PATCH 05/13] test progress with late compilation [Cherry-picked 84f5cdfcd0603ae11bb80d31319e6b780fb3fdab] --- .../xsbt/CompileProgressSpecification.scala | 24 +++++++++++++++++++ .../xsbt/ScalaCompilerForUnitTesting.scala | 17 +++++++++++-- .../test/xsbti/TestCompileProgress.scala | 3 +++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/sbt-bridge/test/xsbt/CompileProgressSpecification.scala b/sbt-bridge/test/xsbt/CompileProgressSpecification.scala index 2297f8bb441e..32b4f58effdb 100644 --- a/sbt-bridge/test/xsbt/CompileProgressSpecification.scala +++ b/sbt-bridge/test/xsbt/CompileProgressSpecification.scala @@ -8,6 +8,30 @@ import org.junit.Assert._ */ class CompileProgressSpecification { + @Test + def totalIsMoreWhenSourcePath = { + val srcA = """class A""" + val srcB = """class B""" + val extraC = """trait C""" // will only exist in the `-sourcepath`, causing a late compile + val extraD = """trait D""" // will only exist in the `-sourcepath`, causing a late compile + val srcE = """class E extends C""" // depends on class in the sourcepath + val srcF = """class F extends C, D""" // depends on classes in the sourcepath + + val compilerForTesting = new ScalaCompilerForUnitTesting + + val totalA = compilerForTesting.extractTotal(srcA)() + assertTrue("expected more than 1 unit of work for a single file", totalA > 1) + + val totalB = compilerForTesting.extractTotal(srcA, srcB)() + assertEquals("expected twice the work for two sources", totalA * 2, totalB) + + val totalC = compilerForTesting.extractTotal(srcA, srcE)(extraC) + assertEquals("expected 2x+1 the work for two sources, and 1 late compile", totalA * 2 + 1, totalC) + + val totalD = compilerForTesting.extractTotal(srcA, srcF)(extraC, extraD) + assertEquals("expected 2x+2 the work for two sources, and 2 late compiles", totalA * 2 + 2, totalD) + } + @Test def multipleFilesVisitSamePhases = { val srcA = """class A""" diff --git a/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala b/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala index 520b7f7053da..87bc45744e21 100644 --- a/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala +++ b/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala @@ -30,6 +30,12 @@ class ScalaCompilerForUnitTesting { tempSrcFiles.map(src => run.unitPhases(src.id)) } + def extractTotal(srcs: String*)(extraSourcePath: String*): Int = { + val (tempSrcFiles, Callbacks(_, testProgress)) = compileSrcs(List(srcs.toList), extraSourcePath.toList) + val run = testProgress.runs.head + run.total + } + def extractProgressPhases(srcs: String*): List[String] = { val (_, Callbacks(_, testProgress)) = compileSrcs(srcs: _*) testProgress.runs.head.phases @@ -136,7 +142,7 @@ class ScalaCompilerForUnitTesting { * The sequence of temporary files corresponding to passed snippets and analysis * callback is returned as a result. */ - def compileSrcs(groupedSrcs: List[List[String]]): (Seq[VirtualFile], Callbacks) = { + def compileSrcs(groupedSrcs: List[List[String]], sourcePath: List[String] = Nil): (Seq[VirtualFile], Callbacks) = { val temp = IO.createTemporaryDirectory val analysisCallback = new TestCallback val testProgress = new TestCompileProgress @@ -146,6 +152,11 @@ class ScalaCompilerForUnitTesting { val bridge = new CompilerBridge val files = for ((compilationUnits, unitId) <- groupedSrcs.zipWithIndex) yield { + val extraFiles = sourcePath.toSeq.zipWithIndex.map { + case (src, i) => + val fileName = s"Extra-$unitId-$i.scala" + prepareSrcFile(temp, fileName, src) + } val srcFiles = compilationUnits.toSeq.zipWithIndex.map { (src, i) => val fileName = s"Test-$unitId-$i.scala" @@ -157,10 +168,12 @@ class ScalaCompilerForUnitTesting { val output = new SingleOutput: def getOutputDirectory() = classesDir + val maybeSourcePath = if extraFiles.isEmpty then Nil else List("-sourcepath", temp.getAbsolutePath.toString) + bridge.run( virtualSrcFiles, new TestDependencyChanges, - Array("-Yforce-sbt-phases", "-classpath", classesDirPath, "-usejavacp", "-d", classesDirPath), + Array("-Yforce-sbt-phases", "-classpath", classesDirPath, "-usejavacp", "-d", classesDirPath) ++ maybeSourcePath, output, analysisCallback, new TestReporter, diff --git a/sbt-bridge/test/xsbti/TestCompileProgress.scala b/sbt-bridge/test/xsbti/TestCompileProgress.scala index 9753a6e15b4c..d5dc81dfda24 100644 --- a/sbt-bridge/test/xsbti/TestCompileProgress.scala +++ b/sbt-bridge/test/xsbti/TestCompileProgress.scala @@ -8,9 +8,11 @@ class TestCompileProgress extends CompileProgress: class Run: private[TestCompileProgress] val _phases: mutable.Set[String] = mutable.LinkedHashSet.empty private[TestCompileProgress] val _unitPhases: mutable.Map[String, mutable.Set[String]] = mutable.LinkedHashMap.empty + private[TestCompileProgress] var _latestTotal: Int = 0 def phases: List[String] = _phases.toList def unitPhases: collection.MapView[String, List[String]] = _unitPhases.view.mapValues(_.toList) + def total: Int = _latestTotal private val _runs: mutable.ListBuffer[Run] = mutable.ListBuffer.empty private var _currentRun: Run = new Run @@ -27,4 +29,5 @@ class TestCompileProgress extends CompileProgress: override def advance(current: Int, total: Int, prevPhase: String, nextPhase: String): Boolean = _currentRun._phases += prevPhase _currentRun._phases += nextPhase + _currentRun._latestTotal = total true From a8a2aed4825ea566338e71f4cb3301db98e63aef Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Mon, 23 Oct 2023 10:30:51 +0200 Subject: [PATCH 06/13] fix assertions in tests, fix compile error [Cherry-picked 6f6539bb4f212d01402e1f41e2b62c71eb873f63] --- compiler/test/dotty/tools/DottyTest.scala | 2 +- .../test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala | 8 ++++---- sbt-bridge/test/xsbt/CompileProgressSpecification.scala | 1 - 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/compiler/test/dotty/tools/DottyTest.scala b/compiler/test/dotty/tools/DottyTest.scala index 229891806a76..7ccbc09a4c92 100644 --- a/compiler/test/dotty/tools/DottyTest.scala +++ b/compiler/test/dotty/tools/DottyTest.scala @@ -44,7 +44,7 @@ trait DottyTest extends ContextEscapeDetection { fc.setProperty(ContextDoc, new ContextDocstrings) } - protected def defaultCompiler: Compiler = Compiler() + protected def defaultCompiler: Compiler = new Compiler() private def compilerWithChecker(phase: String)(assertion: (tpd.Tree, Context) => Unit) = new Compiler { diff --git a/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala b/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala index c2ab4c41318c..82cee9928271 100644 --- a/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala +++ b/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala @@ -145,10 +145,10 @@ object ProgressCallbackTest: case class ProgressEvent(curr: Int, total: Int, currPhase: String, nextPhase: String) case class PhaseTransition(curr: String, next: String) - def asSubphases(phase: Phase): IArray[String] = - val subPhases = Run.SubPhases(phase).all - if subPhases.isEmpty then IArray(phase.phaseName) - else subPhases + def asSubphases(phase: Phase): IndexedSeq[String] = + val subPhases = Run.SubPhases(phase) + val indices = 0 until phase.traversals + indices.map(subPhases.subPhase) def runnablePhases()(using Context): IArray[Phase] = IArray.from(ctx.base.allPhases.filter(_.isRunnable)) diff --git a/sbt-bridge/test/xsbt/CompileProgressSpecification.scala b/sbt-bridge/test/xsbt/CompileProgressSpecification.scala index 32b4f58effdb..45f9daa70e05 100644 --- a/sbt-bridge/test/xsbt/CompileProgressSpecification.scala +++ b/sbt-bridge/test/xsbt/CompileProgressSpecification.scala @@ -54,7 +54,6 @@ class CompileProgressSpecification { "parser", "typer (indexing)", "typer (typechecking)", "typer (checking java)", "sbt-deps", - "extractSemanticDB", "posttyper", "sbt-api", "SetRootTree", From cf60f88530817b35c1c9ac974a50f0afd4adc2c8 Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Mon, 23 Oct 2023 18:07:02 +0200 Subject: [PATCH 07/13] add in cooperative cancellation, test that it works [Cherry-picked 7fc4341dffdb32fb573b13c53c2611789a1ba640] --- compiler/src/dotty/tools/dotc/Run.scala | 66 ++++++--- .../src/dotty/tools/dotc/core/Phases.scala | 52 ++++--- .../tools/dotc/fromtasty/ReadTasty.scala | 14 +- .../tools/dotc/parsing/ParserPhase.scala | 11 +- .../dotc/sbt/interfaces/ProgressCallback.java | 2 +- .../tools/dotc/transform/init/Checker.scala | 18 ++- .../dotty/tools/dotc/typer/TyperPhase.scala | 41 ++++-- .../tools/dotc/sbt/ProgressCallbackTest.scala | 137 ++++++++++++++---- .../tools/xsbt/ProgressCallbackImpl.java | 4 +- 9 files changed, 247 insertions(+), 98 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/Run.scala b/compiler/src/dotty/tools/dotc/Run.scala index 5a033b664652..89ca7dec64ce 100644 --- a/compiler/src/dotty/tools/dotc/Run.scala +++ b/compiler/src/dotty/tools/dotc/Run.scala @@ -177,9 +177,18 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint if local != null then op(using ctx)(local) - def doBeginUnit()(using Context): Unit = - trackProgress: progress => - progress.informUnitStarting(ctx.compilationUnit) + private inline def foldProgress[T](using Context)(inline default: T)(inline op: Context ?=> Progress => T): T = + val local = _progress + if local != null then + op(using ctx)(local) + else + default + + def didEnterUnit()(using Context): Boolean = + foldProgress(true /* should progress by default */)(_.tryEnterUnit(ctx.compilationUnit)) + + def didEnterFinal()(using Context): Boolean = + foldProgress(true /* should progress by default */)(p => !p.checkCancellation()) def doAdvanceUnit()(using Context): Unit = trackProgress: progress => @@ -195,6 +204,13 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint trackProgress: progress => progress.enterPhase(currentPhase) + /** interrupt the thread and set cancellation state */ + private def cancelInterrupted(): Unit = + try + trackProgress(_.cancel()) + finally + Thread.currentThread().nn.interrupt() + private def doAdvancePhase(currentPhase: Phase, wasRan: Boolean)(using Context): Unit = trackProgress: progress => progress.unitc = 0 // reset unit count in current (sub)phase @@ -213,7 +229,8 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint progress.seen += 1 // trace that we've seen a (sub)phase progress.traversalc += 1 // add an extra traversal now that we completed a (sub)phase progress.subtraversalc += 1 // record that we've seen a subphase - progress.tickSubphase() + if !progress.isCancelled() then + progress.tickSubphase() /** Will be set to true if any of the compiled compilation units contains * a pureFunctions language import. @@ -297,7 +314,8 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint Stats.trackTime(s"phase time ms/$phase") { val start = System.currentTimeMillis val profileBefore = profiler.beforePhase(phase) - units = phase.runOn(units) + try units = phase.runOn(units) + catch case _: InterruptedException => cancelInterrupted() profiler.afterPhase(phase, profileBefore) if (ctx.settings.Xprint.value.containsPhase(phase)) for (unit <- units) @@ -332,7 +350,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint if (!ctx.reporter.hasErrors) Rewrites.writeBack() suppressions.runFinished(hasErrors = ctx.reporter.hasErrors) - while (finalizeActions.nonEmpty) { + while (finalizeActions.nonEmpty && didEnterFinal()) { val action = finalizeActions.remove(0) action() } @@ -480,6 +498,8 @@ object Run { private class Progress(cb: ProgressCallback, private val run: Run, val initialTraversals: Int): + export cb.{cancel, isCancelled} + private[Run] var totalTraversals: Int = initialTraversals // track how many phases we expect to run private[Run] var unitc: Int = 0 // current unit count in the current (sub)phase private[Run] var latec: Int = 0 // current late unit count @@ -514,34 +534,46 @@ object Run { /** Counts the number of completed full traversals over files, plus the number of units in the current phase */ - private def currentProgress()(using Context): Int = - traversalc * run.files.size + unitc + latec + private def currentProgress(): Int = + traversalc * work() + unitc + latec /**Total progress is computed as the sum of * - the number of traversals we expect to make over all files * - the number of late compilations */ - private def totalProgress()(using Context): Int = - totalTraversals * run.files.size + run.lateFiles.size + private def totalProgress(): Int = + totalTraversals * work() + run.lateFiles.size + + private def work(): Int = run.files.size private def requireInitialized(): Unit = require((currPhase: Phase | Null) != null, "enterPhase was not called") - /** trace that we are beginning a unit in the current (sub)phase */ - private[Run] def informUnitStarting(unit: CompilationUnit)(using Context): Unit = - requireInitialized() - cb.informUnitStarting(currPhaseName, unit) + private[Run] def checkCancellation(): Boolean = + if Thread.interrupted() then cancel() + isCancelled() + + /** trace that we are beginning a unit in the current (sub)phase, unless cancelled */ + private[Run] def tryEnterUnit(unit: CompilationUnit): Boolean = + if checkCancellation() then false + else + requireInitialized() + cb.informUnitStarting(currPhaseName, unit) + true /** trace the current progress out of the total, in the current (sub)phase, reporting the next (sub)phase */ private[Run] def refreshProgress()(using Context): Unit = requireInitialized() - cb.progress(currentProgress(), totalProgress(), currPhaseName, nextPhaseName) + val total = totalProgress() + if total > 0 && !cb.progress(currentProgress(), total, currPhaseName, nextPhaseName) then + cancel() extension (run: Run | Null) /** record that the current phase has begun for the compilation unit of the current Context */ - def beginUnit()(using Context): Unit = - if run != null then run.doBeginUnit() + def enterUnit()(using Context): Boolean = + if run != null then run.didEnterUnit() + else true // don't check cancellation if we're not tracking progress /** advance the unit count and record progress in the current phase */ def advanceUnit()(using Context): Unit = diff --git a/compiler/src/dotty/tools/dotc/core/Phases.scala b/compiler/src/dotty/tools/dotc/core/Phases.scala index 4743033ee0e5..1483e1497c63 100644 --- a/compiler/src/dotty/tools/dotc/core/Phases.scala +++ b/compiler/src/dotty/tools/dotc/core/Phases.scala @@ -326,16 +326,20 @@ object Phases { /** @pre `isRunnable` returns true */ def runOn(units: List[CompilationUnit])(using runCtx: Context): List[CompilationUnit] = - units.map { unit => + val buf = List.newBuilder[CompilationUnit] + for unit <- units do given unitCtx: Context = runCtx.fresh.setPhase(this.start).setCompilationUnit(unit).withRootImports - ctx.run.beginUnit() - try run - catch case ex: Throwable if !ctx.run.enrichedErrorMessage => - println(ctx.run.enrichErrorMessage(s"unhandled exception while running $phaseName on $unit")) - throw ex - finally ctx.run.advanceUnit() - unitCtx.compilationUnit - } + if ctx.run.enterUnit() then + try run + catch case ex: Throwable if !ctx.run.enrichedErrorMessage => + println(ctx.run.enrichErrorMessage(s"unhandled exception while running $phaseName on $unit")) + throw ex + finally ctx.run.advanceUnit() + buf += unitCtx.compilationUnit + end if + end for + buf.result() + end runOn /** Convert a compilation unit's tree to a string; can be overridden */ def show(tree: untpd.Tree)(using Context): String = @@ -443,14 +447,28 @@ object Phases { Iterator.iterate(this)(_.next) takeWhile (_.hasNext) /** run the body as one iteration of a (sub)phase (see Run.Progress), Enrich crash messages */ - final def monitor(doing: String)(body: Context ?=> Unit)(using Context): Unit = - ctx.run.beginUnit() - try body - catch - case NonFatal(ex) if !ctx.run.enrichedErrorMessage => - report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing ${ctx.compilationUnit}")) - throw ex - finally ctx.run.advanceUnit() + final def monitor(doing: String)(body: Context ?=> Unit)(using Context): Boolean = + if ctx.run.enterUnit() then + try {body; true} + catch + case NonFatal(ex) if !ctx.run.enrichedErrorMessage => + report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing ${ctx.compilationUnit}")) + throw ex + finally ctx.run.advanceUnit() + else + false + + /** run the body as one iteration of a (sub)phase (see Run.Progress), Enrich crash messages */ + final def monitorOpt[T](doing: String)(body: Context ?=> Option[T])(using Context): Option[T] = + if ctx.run.enterUnit() then + try body + catch + case NonFatal(ex) if !ctx.run.enrichedErrorMessage => + report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing ${ctx.compilationUnit}")) + throw ex + finally ctx.run.advanceUnit() + else + None override def toString: String = phaseName } diff --git a/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala b/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala index 82fe1ef13c10..e64d02a31b00 100644 --- a/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala +++ b/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala @@ -22,12 +22,14 @@ class ReadTasty extends Phase { ctx.settings.fromTasty.value override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] = - withMode(Mode.ReadPositions)(units.flatMap(applyPhase(_))) + withMode(Mode.ReadPositions) { + val unitContexts = units.map(unit => ctx.fresh.setCompilationUnit(unit)) + unitContexts.flatMap(applyPhase()(using _)) + } - private def applyPhase(unit: CompilationUnit)(using Context): Option[CompilationUnit] = - ctx.run.beginUnit() - try readTASTY(unit) - finally ctx.run.advanceUnit() + private def applyPhase()(using Context): Option[CompilationUnit] = monitorOpt(phaseName): + val unit = ctx.compilationUnit + readTASTY(unit) def readTASTY(unit: CompilationUnit)(using Context): Option[CompilationUnit] = unit match { case unit: TASTYCompilationUnit => @@ -82,7 +84,7 @@ class ReadTasty extends Phase { } } case unit => - Some(unit) + Some(unit) } def run(using Context): Unit = unsupported("run") diff --git a/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala b/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala index 3b23847db7f5..d8c1f5f17adf 100644 --- a/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala +++ b/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala @@ -22,7 +22,7 @@ class Parser extends Phase { */ private[dotc] var firstXmlPos: SourcePosition = NoSourcePosition - def parse(using Context) = monitor("parser") { + def parse(using Context): Boolean = monitor("parser") { val unit = ctx.compilationUnit unit.untpdTree = if (unit.isJava) new JavaParsers.JavaParser(unit.source).parse() @@ -46,12 +46,15 @@ class Parser extends Phase { report.inform(s"parsing ${unit.source}") ctx.fresh.setCompilationUnit(unit).withRootImports - for given Context <- unitContexts do - parse + val unitContexts0 = + for + given Context <- unitContexts + if parse + yield ctx record("parsedTrees", ast.Trees.ntrees) - unitContexts.map(_.compilationUnit) + unitContexts0.map(_.compilationUnit) } def run(using Context): Unit = unsupported("run") diff --git a/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java b/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java index 8f81ea5f99a2..d1e076c75bfa 100644 --- a/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java +++ b/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java @@ -15,7 +15,7 @@ default void informUnitStarting(String phase, CompilationUnit unit) {} /** Record the current compilation progress. * @param current `completedPhaseCount * totalUnits + completedUnitsInCurrPhase + completedLate` * @param total `totalPhases * totalUnits + totalLate` - * @return true if the compilation should continue (if false, then subsequent calls to `isCancelled()` will return true) + * @return true if the compilation should continue (callers are expected to cancel if this returns false) */ default boolean progress(int current, int total, String currPhase, String nextPhase) { return true; } } diff --git a/compiler/src/dotty/tools/dotc/transform/init/Checker.scala b/compiler/src/dotty/tools/dotc/transform/init/Checker.scala index 92069f834cff..ea4bc7a15f24 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Checker.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Checker.scala @@ -31,18 +31,26 @@ class Checker extends Phase: override def isEnabled(using Context): Boolean = super.isEnabled && ctx.settings.YcheckInit.value + def traverse(traverser: InitTreeTraverser)(using Context): Boolean = monitor(phaseName): + val unit = ctx.compilationUnit + traverser.traverse(unit.tpdTree) + override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] = val checkCtx = ctx.fresh.setPhase(this.start) val traverser = new InitTreeTraverser() - for unit <- units do - checkCtx.run.beginUnit() - try traverser.traverse(unit.tpdTree) - finally ctx.run.advanceUnit() + val unitContexts = units.map(unit => checkCtx.fresh.setCompilationUnit(unit)) + + val unitContexts0 = + for + given Context <- unitContexts + if traverse(traverser) + yield ctx + val classes = traverser.getClasses() Semantic.checkClasses(classes)(using checkCtx) - units + unitContexts0.map(_.compilationUnit) def run(using Context): Unit = unsupported("run") diff --git a/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala b/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala index a15ab8afee39..210e457a7764 100644 --- a/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala +++ b/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala @@ -31,13 +31,13 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { // Run regardless of parsing errors override def isRunnable(implicit ctx: Context): Boolean = true - def enterSyms(using Context): Unit = monitor("indexing") { + def enterSyms(using Context): Boolean = monitor("indexing") { val unit = ctx.compilationUnit ctx.typer.index(unit.untpdTree) typr.println("entered: " + unit.source) } - def typeCheck(using Context): Unit = monitor("typechecking") { + def typeCheck(using Context): Boolean = monitor("typechecking") { val unit = ctx.compilationUnit try if !unit.suspended then @@ -49,7 +49,7 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { catch case _: CompilationUnit.SuspendException => () } - def javaCheck(using Context): Unit = monitor("checking java") { + def javaCheck(using Context): Boolean = monitor("checking java") { val unit = ctx.compilationUnit if unit.isJava then JavaChecks.check(unit.tpdTree) @@ -72,11 +72,14 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { else newCtx - try - for given Context <- unitContexts do - enterSyms - finally - ctx.run.advanceSubPhase() // tick from "typer (indexing)" to "typer (typechecking)" + val unitContexts0 = + try + for + given Context <- unitContexts + if enterSyms + yield ctx + finally + ctx.run.advanceSubPhase() // tick from "typer (indexing)" to "typer (typechecking)" ctx.base.parserPhase match { case p: ParserPhase => @@ -88,18 +91,24 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { case _ => } - try - for given Context <- unitContexts do - typeCheck - finally - ctx.run.advanceSubPhase() // tick from "typer (typechecking)" to "typer (java checking)" + val unitContexts1 = + try + for + given Context <- unitContexts0 + if typeCheck + yield ctx + finally + ctx.run.advanceSubPhase() // tick from "typer (typechecking)" to "typer (java checking)" record("total trees after typer", ast.Trees.ntrees) - for given Context <- unitContexts do - javaCheck // after typechecking to avoid cycles + val unitContexts2 = + for + given Context <- unitContexts1 + if javaCheck // after typechecking to avoid cycles + yield ctx - val newUnits = unitContexts.map(_.compilationUnit).filterNot(discardAfterTyper) + val newUnits = unitContexts2.map(_.compilationUnit).filterNot(discardAfterTyper) ctx.run.nn.checkSuspendedUnits(newUnits) newUnits diff --git a/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala b/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala index 82cee9928271..e6e67b997aae 100644 --- a/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala +++ b/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala @@ -16,6 +16,7 @@ import dotty.tools.dotc.Run import dotty.tools.dotc.core.Phases.Phase import dotty.tools.io.VirtualDirectory import dotty.tools.dotc.NoCompilationUnit +import dotty.tools.dotc.interactive.Interactive.Include.all final class ProgressCallbackTest extends DottyTest: @@ -25,26 +26,86 @@ final class ProgressCallbackTest extends DottyTest: val source2 = """class Bar""" inspectProgress(List(source1, source2), terminalPhase = None): progressCallback => - // (1) assert that the way we compute next phase in `Run.doAdvancePhase` is correct - assertNextPhaseIsNext() + locally: + // (1) assert that the way we compute next phase in `Run.doAdvancePhase` is correct + assertNextPhaseIsNext() + + locally: + // (1) given correct computation, check that the recorded progression of phases is monotonic + assertMonotonicProgression(progressCallback) + + locally: + // (1) given monotonic progression, check that the recorded progression of phases is complete + val expectedCurr = allSubPhases + val expectedNext = expectedCurr.tail ++ syntheticNextPhases + assertProgressPhases(progressCallback, expectedCurr, expectedNext) + + locally: + // (2) next check that for each unit, we record all the "runnable" phases that could go through + assertExpectedPhasesForUnits(progressCallback, expectedPhases = runnableSubPhases) + + locally: + // (2) therefore we can now cross-reference the recorded progression with the recorded phases per unit + assertTotalUnits(progressCallback) + + locally: + // (3) finally, check that the callback was not cancelled + assertFalse(progressCallback.isCancelled) + end testCallback - // (1) given correct computation, check that the recorded progression is monotonic - assertMonotonicProgression(progressCallback) + // TODO: test cancellation - // (1) given monotonic progression, check that the recorded progression has full coverage - assertFullCoverage(progressCallback) + @Test + def cancelMidTyper: Unit = + inspectCancellationAtPhase("typer (typechecking)") - // (2) next check that for each unit, we record the expected phases that it should progress through - assertExpectedPhases(progressCallback) + @Test + def cancelErasure: Unit = + inspectCancellationAtPhase("erasure") - // (2) therefore we can now cross-reference the recorded progression with the recorded phases per unit - assertTotalUnits(progressCallback) + @Test + def cancelPickler: Unit = + inspectCancellationAtPhase("pickler") - // (3) finally, check that the callback was not cancelled - assertFalse(progressCallback.isCancelled) - end testCallback + def cancelOnEnter(targetPhase: String)(testCallback: TestProgressCallback): Boolean = + testCallback.latestProgress.exists(_.currPhase == targetPhase) - // TODO: test lateCompile, test cancellation + def inspectCancellationAtPhase(targetPhase: String): Unit = + val source1 = """class Foo""" + + inspectProgress(List(source1), cancellation = Some(cancelOnEnter(targetPhase))): progressCallback => + locally: + // (1) assert that the compiler was cancelled + assertTrue("should have cancelled", progressCallback.isCancelled) + + locally: + // (2) assert that compiler visited all the subphases before cancellation, + // and does not visit any after. + // (2.2) first extract the surrounding phases of the target + val (befores, target +: next +: _) = allSubPhases.span(_ != targetPhase): @unchecked + // (2.3) we expect to see the subphases before&including target reported as a "current" phase, so extract here + val expectedCurr = befores :+ target + // (2.4) we expect to see next after target reported as a "next" phase, so extract here + val expectedNext = expectedCurr.tail :+ next + assertProgressPhases(progressCallback, expectedCurr, expectedNext) + + locally: + // (3) assert that the compilation units were only entered in the phases before cancellation + val (befores, target +: next +: _) = runnableSubPhases.span(_ != targetPhase): @unchecked + assertExpectedPhasesForUnits(progressCallback, expectedPhases = befores) + + locally: + // (4) assert that the final progress recorded is at the target phase, + // and progress is equal to the number of phases before the target. + val (befores, target +: next +: _) = runnableSubPhases.span(_ != targetPhase): @unchecked + // (4.1) we expect cancellation to occur *as we enter* the target phase, + // so no units should be visited in this phase. Therefore progress + // should be equal to the number of phases before the target. (as we have 1 unit) + val expectedProgress = befores.size + progressCallback.latestProgress match + case Some(ProgressEvent(`expectedProgress`, _, `target`, `next`)) => + case other => fail(s"did not match expected progress, found $other") + end inspectCancellationAtPhase /** Assert that the computed `next` phase matches the real next phase */ def assertNextPhaseIsNext()(using Context): Unit = @@ -71,12 +132,13 @@ final class ProgressCallbackTest extends DottyTest: assertTrue(s"Predicted next phase `$next1` didn't match the following current `$curr2`", next1Index == curr2Index) /** Assert that the recorded progression of phases contains every phase in the plan */ - def assertFullCoverage(progressCallback: TestProgressCallback)(using Context): Unit = + def assertProgressPhases(progressCallback: TestProgressCallback, + currExpected: Seq[String], nextExpected: Seq[String])(using Context): Unit = val (allPhasePlan, expectedCurrPhases, expectedNextPhases) = - val allPhases = ctx.base.allPhases.flatMap(asSubphases) + val allPhases = currExpected val firstPhase = allPhases.head val expectedCurrPhases = allPhases.toSet - val expectedNextPhases = expectedCurrPhases - firstPhase ++ syntheticNextPhases + val expectedNextPhases = nextExpected.toSet //expectedCurrPhases - firstPhase ++ syntheticNextPhases (allPhases.toList, expectedCurrPhases, expectedNextPhases) for (expectedCurr, recordedCurr) <- allPhasePlan.zip(progressCallback.progressPhasesFinal.map(_.curr)) do @@ -98,8 +160,7 @@ final class ProgressCallbackTest extends DottyTest: /** Assert that the phases recorded per unit match the actual phases ran on them */ - def assertExpectedPhases(progressCallback: TestProgressCallback)(using Context): Unit = - val expectedPhases = runnablePhases().flatMap(asSubphases) + def assertExpectedPhasesForUnits(progressCallback: TestProgressCallback, expectedPhases: Seq[String])(using Context): Unit = for (unit, visitedPhases) <- progressCallback.unitPhases do val uniquePhases = visitedPhases.toSet assert(unit != NoCompilationUnit, s"unexpected NoCompilationUnit for phases $uniquePhases") @@ -121,18 +182,26 @@ final class ProgressCallbackTest extends DottyTest: case TotalEvent(total, _) :: _ => assertEquals(expectedTotal, total) - def inspectProgress(sources: List[String], terminalPhase: Option[String] = Some("typer"))(op: Context ?=> TestProgressCallback => Unit) = - // given Context = getCtx + def inspectProgress( + sources: List[String], + terminalPhase: Option[String] = Some("typer"), + cancellation: Option[TestProgressCallback => Boolean] = None)( + op: Context ?=> TestProgressCallback => Unit)(using Context) = + for cancelNow <- cancellation do + testProgressCallback.withCancelNow(cancelNow) val sources0 = sources.map(_.linesIterator.map(_.trim.nn).filterNot(_.isEmpty).mkString("\n|").stripMargin) val terminalPhase0 = terminalPhase.getOrElse(defaultCompiler.phases.last.last.phaseName) checkAfterCompile(terminalPhase0, sources0) { case given Context => - ctx.progressCallback match - case cb: TestProgressCallback => op(cb) - case _ => - fail(s"Expected TestProgressCallback but got ${ctx.progressCallback}") - ??? + op(testProgressCallback) } + private def testProgressCallback(using Context): TestProgressCallback = + ctx.progressCallback match + case cb: TestProgressCallback => cb + case _ => + fail(s"Expected TestProgressCallback but got ${ctx.progressCallback}") + ??? + override protected def initializeCtx(fc: FreshContext): Unit = super.initializeCtx( fc.setProgressCallback(TestProgressCallback()) @@ -150,8 +219,11 @@ object ProgressCallbackTest: val indices = 0 until phase.traversals indices.map(subPhases.subPhase) - def runnablePhases()(using Context): IArray[Phase] = - IArray.from(ctx.base.allPhases.filter(_.isRunnable)) + def runnableSubPhases(using Context): IndexedSeq[String] = + ctx.base.allPhases.filter(_.isRunnable).flatMap(asSubphases).toIndexedSeq + + def allSubPhases(using Context): IndexedSeq[String] = + ctx.base.allPhases.flatMap(asSubphases).toIndexedSeq private val syntheticNextPhases = List("") @@ -163,15 +235,20 @@ object ProgressCallbackTest: i final class TestProgressCallback extends interfaces.ProgressCallback: + import collection.immutable, immutable.SeqMap + private var _cancelled: Boolean = false - private var _unitPhases: Map[CompilationUnit, List[String]] = Map.empty + private var _unitPhases: SeqMap[CompilationUnit, List[String]] = immutable.SeqMap.empty // preserve order private var _totalEvents: List[TotalEvent] = List.empty + private var _latestProgress: Option[ProgressEvent] = None private var _progressPhases: List[PhaseTransition] = List.empty private var _shouldCancelNow: TestProgressCallback => Boolean = _ => false def totalEvents = _totalEvents + def latestProgress = _latestProgress def unitPhases = _unitPhases def progressPhasesFinal = _progressPhases.reverse + def currentPhase = _progressPhases.headOption.map(_.curr) def withCancelNow(f: TestProgressCallback => Boolean): this.type = _shouldCancelNow = f @@ -190,6 +267,8 @@ object ProgressCallbackTest: case events @ (head :: _) if head.total != total => TotalEvent(total, currPhase) :: events case events => events + _latestProgress = Some(ProgressEvent(current, total, currPhase, nextPhase)) + // record the current and next phase whenever the current phase changes _progressPhases = _progressPhases match case all @ PhaseTransition(head, _) :: rest => diff --git a/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java b/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java index ce9f7debbfa8..f5fb78f12bb1 100644 --- a/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java +++ b/sbt-bridge/src/dotty/tools/xsbt/ProgressCallbackImpl.java @@ -30,8 +30,6 @@ public void informUnitStarting(String phase, CompilationUnit unit) { @Override public boolean progress(int current, int total, String currPhase, String nextPhase) { - boolean shouldAdvance = _progress.advance(current, total, currPhase, nextPhase); - if (!shouldAdvance) cancel(); - return shouldAdvance; + return _progress.advance(current, total, currPhase, nextPhase); } } From fef5807c8e20e7c7dd58689fb51dc0f890b1369c Mon Sep 17 00:00:00 2001 From: Wojciech Mazur Date: Sun, 23 Jun 2024 13:19:50 +0200 Subject: [PATCH 08/13] simplify monitor [Cherry-picked 7ccdd40c2ce61ecf078b29f04ca8832faf66bc18][modified] --- compiler/src/dotty/tools/dotc/Run.scala | 24 +++++++------- .../src/dotty/tools/dotc/core/Phases.scala | 31 +++++++++---------- .../tools/dotc/fromtasty/ReadTasty.scala | 12 +++---- .../tools/dotc/transform/init/Checker.scala | 16 +++++----- 4 files changed, 41 insertions(+), 42 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/Run.scala b/compiler/src/dotty/tools/dotc/Run.scala index 89ca7dec64ce..0be2b530ab56 100644 --- a/compiler/src/dotty/tools/dotc/Run.scala +++ b/compiler/src/dotty/tools/dotc/Run.scala @@ -171,11 +171,8 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint private var _progress: Progress | Null = null // Set if progress reporting is enabled - /** Only safe to call if progress is being tracked. */ private inline def trackProgress(using Context)(inline op: Context ?=> Progress => Unit): Unit = - val local = _progress - if local != null then - op(using ctx)(local) + foldProgress(())(op) private inline def foldProgress[T](using Context)(inline default: T)(inline op: Context ?=> Progress => T): T = val local = _progress @@ -184,11 +181,11 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint else default - def didEnterUnit()(using Context): Boolean = - foldProgress(true /* should progress by default */)(_.tryEnterUnit(ctx.compilationUnit)) + def didEnterUnit(unit: CompilationUnit)(using Context): Boolean = + foldProgress(true /* should progress by default */)(_.tryEnterUnit(unit)) - def didEnterFinal()(using Context): Boolean = - foldProgress(true /* should progress by default */)(p => !p.checkCancellation()) + def canProgress()(using Context): Boolean = + foldProgress(true /* not cancelled by default */)(p => !p.checkCancellation()) def doAdvanceUnit()(using Context): Unit = trackProgress: progress => @@ -350,7 +347,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint if (!ctx.reporter.hasErrors) Rewrites.writeBack() suppressions.runFinished(hasErrors = ctx.reporter.hasErrors) - while (finalizeActions.nonEmpty && didEnterFinal()) { + while (finalizeActions.nonEmpty && canProgress()) { val action = finalizeActions.remove(0) action() } @@ -571,8 +568,13 @@ object Run { extension (run: Run | Null) /** record that the current phase has begun for the compilation unit of the current Context */ - def enterUnit()(using Context): Boolean = - if run != null then run.didEnterUnit() + def enterUnit(unit: CompilationUnit)(using Context): Boolean = + if run != null then run.didEnterUnit(unit) + else true // don't check cancellation if we're not tracking progress + + /** check progress cancellation, true if not cancelled */ + def enterRegion()(using Context): Boolean = + if run != null then run.canProgress() else true // don't check cancellation if we're not tracking progress /** advance the unit count and record progress in the current phase */ diff --git a/compiler/src/dotty/tools/dotc/core/Phases.scala b/compiler/src/dotty/tools/dotc/core/Phases.scala index 1483e1497c63..31e07001a4a2 100644 --- a/compiler/src/dotty/tools/dotc/core/Phases.scala +++ b/compiler/src/dotty/tools/dotc/core/Phases.scala @@ -329,7 +329,7 @@ object Phases { val buf = List.newBuilder[CompilationUnit] for unit <- units do given unitCtx: Context = runCtx.fresh.setPhase(this.start).setCompilationUnit(unit).withRootImports - if ctx.run.enterUnit() then + if ctx.run.enterUnit(unit) then try run catch case ex: Throwable if !ctx.run.enrichedErrorMessage => println(ctx.run.enrichErrorMessage(s"unhandled exception while running $phaseName on $unit")) @@ -446,29 +446,26 @@ object Phases { final def iterator: Iterator[Phase] = Iterator.iterate(this)(_.next) takeWhile (_.hasNext) - /** run the body as one iteration of a (sub)phase (see Run.Progress), Enrich crash messages */ + /** Cancellable region, if not cancelled, run the body in the context of the current compilation unit. + * Enrich crash messages. + */ final def monitor(doing: String)(body: Context ?=> Unit)(using Context): Boolean = - if ctx.run.enterUnit() then + val unit = ctx.compilationUnit + if ctx.run.enterUnit(unit) then try {body; true} - catch - case NonFatal(ex) if !ctx.run.enrichedErrorMessage => - report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing ${ctx.compilationUnit}")) - throw ex + catch case NonFatal(ex) if !ctx.run.enrichedErrorMessage => + report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing $unit")) + throw ex finally ctx.run.advanceUnit() else false - /** run the body as one iteration of a (sub)phase (see Run.Progress), Enrich crash messages */ - final def monitorOpt[T](doing: String)(body: Context ?=> Option[T])(using Context): Option[T] = - if ctx.run.enterUnit() then - try body - catch - case NonFatal(ex) if !ctx.run.enrichedErrorMessage => - report.echo(ctx.run.enrichErrorMessage(s"exception occurred while $doing ${ctx.compilationUnit}")) - throw ex - finally ctx.run.advanceUnit() + /** Do not run if compile progress has been cancelled */ + final def cancellable(body: Context ?=> Unit)(using Context): Boolean = + if ctx.run.enterRegion() then + {body; true} else - None + false override def toString: String = phaseName } diff --git a/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala b/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala index e64d02a31b00..d3ff1776e621 100644 --- a/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala +++ b/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala @@ -12,7 +12,6 @@ import NameOps._ import ast.Trees.Tree import Phases.Phase - /** Load trees from TASTY files */ class ReadTasty extends Phase { @@ -23,13 +22,14 @@ class ReadTasty extends Phase { override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] = withMode(Mode.ReadPositions) { - val unitContexts = units.map(unit => ctx.fresh.setCompilationUnit(unit)) - unitContexts.flatMap(applyPhase()(using _)) + val nextUnits = collection.mutable.ListBuffer.empty[CompilationUnit] + val unitContexts = units.view.map(ctx.fresh.setCompilationUnit) + for given Context <- unitContexts if addTasty(nextUnits += _) do () + nextUnits.toList } - private def applyPhase()(using Context): Option[CompilationUnit] = monitorOpt(phaseName): - val unit = ctx.compilationUnit - readTASTY(unit) + def addTasty(fn: CompilationUnit => Unit)(using Context): Boolean = monitor(phaseName): + readTASTY(ctx.compilationUnit).foreach(fn) def readTASTY(unit: CompilationUnit)(using Context): Option[CompilationUnit] = unit match { case unit: TASTYCompilationUnit => diff --git a/compiler/src/dotty/tools/dotc/transform/init/Checker.scala b/compiler/src/dotty/tools/dotc/transform/init/Checker.scala index ea4bc7a15f24..23ccee675220 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Checker.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Checker.scala @@ -40,17 +40,17 @@ class Checker extends Phase: val traverser = new InitTreeTraverser() val unitContexts = units.map(unit => checkCtx.fresh.setCompilationUnit(unit)) - val unitContexts0 = - for - given Context <- unitContexts - if traverse(traverser) - yield ctx + val units0 = + for given Context <- unitContexts if traverse(traverser) yield ctx.compilationUnit - val classes = traverser.getClasses() + cancellable { + val classes = traverser.getClasses() - Semantic.checkClasses(classes)(using checkCtx) + Semantic.checkClasses(classes)(using checkCtx) + } - unitContexts0.map(_.compilationUnit) + units0 + end runOn def run(using Context): Unit = unsupported("run") From ddc1756673fa9aaf0fac9a0bf1136cb8e6bb485d Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Wed, 25 Oct 2023 18:33:59 +0200 Subject: [PATCH 09/13] add progress tracking/cancellation to extractSemanticDB [Cherry-picked 51abd428d252d42e24928c5feb495c87ef601fcd] --- .../dotc/semanticdb/ExtractSemanticDB.scala | 65 +++++++++++-------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/semanticdb/ExtractSemanticDB.scala b/compiler/src/dotty/tools/dotc/semanticdb/ExtractSemanticDB.scala index f1b4e4637eb8..b2f16eb741cb 100644 --- a/compiler/src/dotty/tools/dotc/semanticdb/ExtractSemanticDB.scala +++ b/compiler/src/dotty/tools/dotc/semanticdb/ExtractSemanticDB.scala @@ -29,6 +29,7 @@ import dotty.tools.dotc.{semanticdb => s} import dotty.tools.io.{AbstractFile, JarArchive} import dotty.tools.dotc.semanticdb.DiagnosticOps.* import scala.util.{Using, Failure, Success} +import java.nio.file.Path /** Extract symbol references and uses to semanticdb files. @@ -65,41 +66,49 @@ class ExtractSemanticDB private (phaseMode: ExtractSemanticDB.PhaseMode) extends val appendDiagnostics = phaseMode == ExtractSemanticDB.PhaseMode.AppendDiagnostics if (appendDiagnostics) val warnings = ctx.reporter.allWarnings.groupBy(w => w.pos.source) - units.flatMap { unit => - warnings.get(unit.source).map { ws => - val unitCtx = ctx.fresh.setCompilationUnit(unit).withRootImports - val outputDir = - ExtractSemanticDB.semanticdbPath( - unit.source, - ExtractSemanticDB.semanticdbOutDir(using unitCtx), - sourceRoot - ) - (outputDir, ws.map(_.toSemanticDiagnostic)) + val buf = mutable.ListBuffer.empty[(Path, Seq[Diagnostic])] + units.foreach { unit => + val unitCtx = ctx.fresh.setCompilationUnit(unit).withRootImports + monitor(phaseName) { + warnings.get(unit.source).foreach { ws => + val outputDir = + ExtractSemanticDB.semanticdbPath( + unit.source, + ExtractSemanticDB.semanticdbOutDir(using unitCtx), + sourceRoot + ) + buf += ((outputDir, ws.map(_.toSemanticDiagnostic))) + } + }(using unitCtx) + } + cancellable { + buf.toList.asJava.parallelStream().forEach { case (out, warnings) => + ExtractSemanticDB.appendDiagnostics(warnings, out) } - }.asJava.parallelStream().forEach { case (out, warnings) => - ExtractSemanticDB.appendDiagnostics(warnings, out) } else val writeSemanticdbText = ctx.settings.semanticdbText.value units.foreach { unit => val unitCtx = ctx.fresh.setCompilationUnit(unit).withRootImports - val outputDir = - ExtractSemanticDB.semanticdbPath( + monitor(phaseName) { + val outputDir = + ExtractSemanticDB.semanticdbPath( + unit.source, + ExtractSemanticDB.semanticdbOutDir(using unitCtx), + sourceRoot + ) + val extractor = ExtractSemanticDB.Extractor() + extractor.extract(unit.tpdTree)(using unitCtx) + ExtractSemanticDB.write( unit.source, - ExtractSemanticDB.semanticdbOutDir(using unitCtx), - sourceRoot + extractor.occurrences.toList, + extractor.symbolInfos.toList, + extractor.synthetics.toList, + outputDir, + sourceRoot, + writeSemanticdbText ) - val extractor = ExtractSemanticDB.Extractor() - extractor.extract(unit.tpdTree)(using unitCtx) - ExtractSemanticDB.write( - unit.source, - extractor.occurrences.toList, - extractor.symbolInfos.toList, - extractor.synthetics.toList, - outputDir, - sourceRoot, - writeSemanticdbText - ) + }(using unitCtx) } units } @@ -611,4 +620,4 @@ object ExtractSemanticDB: traverse(vparam.tpt) tparams.foreach(tp => traverse(tp.rhs)) end Extractor -end ExtractSemanticDB \ No newline at end of file +end ExtractSemanticDB From 2f275cab2f5a3ec7780779294509bd04569de9ad Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Wed, 25 Oct 2023 19:15:57 +0200 Subject: [PATCH 10/13] use more descriptive names [Cherry-picked e92a834ae155b53326f97a60949ace50f1d85a73] --- compiler/src/dotty/tools/dotc/Run.scala | 38 +++++++++---------- .../src/dotty/tools/dotc/core/Contexts.scala | 5 --- .../dotc/sbt/interfaces/ProgressCallback.java | 2 +- 3 files changed, 20 insertions(+), 25 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/Run.scala b/compiler/src/dotty/tools/dotc/Run.scala index 0be2b530ab56..098c8342cfca 100644 --- a/compiler/src/dotty/tools/dotc/Run.scala +++ b/compiler/src/dotty/tools/dotc/Run.scala @@ -189,12 +189,12 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint def doAdvanceUnit()(using Context): Unit = trackProgress: progress => - progress.unitc += 1 // trace that we completed a unit in the current (sub)phase + progress.currentUnitCount += 1 // trace that we completed a unit in the current (sub)phase progress.refreshProgress() def doAdvanceLate()(using Context): Unit = trackProgress: progress => - progress.latec += 1 // trace that we completed a late compilation + progress.currentLateUnitCount += 1 // trace that we completed a late compilation progress.refreshProgress() private def doEnterPhase(currentPhase: Phase)(using Context): Unit = @@ -210,22 +210,22 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint private def doAdvancePhase(currentPhase: Phase, wasRan: Boolean)(using Context): Unit = trackProgress: progress => - progress.unitc = 0 // reset unit count in current (sub)phase - progress.subtraversalc = 0 // reset subphase index to initial - progress.seen += 1 // trace that we've seen a (sub)phase + progress.currentUnitCount = 0 // reset unit count in current (sub)phase + progress.currentCompletedSubtraversalCount = 0 // reset subphase index to initial + progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase if wasRan then // add an extra traversal now that we completed a (sub)phase - progress.traversalc += 1 + progress.completedTraversalCount += 1 else // no subphases were ran, remove traversals from expected total progress.totalTraversals -= currentPhase.traversals private def doAdvanceSubPhase()(using Context): Unit = trackProgress: progress => - progress.unitc = 0 // reset unit count in current (sub)phase - progress.seen += 1 // trace that we've seen a (sub)phase - progress.traversalc += 1 // add an extra traversal now that we completed a (sub)phase - progress.subtraversalc += 1 // record that we've seen a subphase + progress.currentUnitCount = 0 // reset unit count in current (sub)phase + progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase + progress.completedTraversalCount += 1 // add an extra traversal now that we completed a (sub)phase + progress.currentCompletedSubtraversalCount += 1 // record that we've seen a subphase if !progress.isCancelled() then progress.tickSubphase() @@ -497,12 +497,12 @@ object Run { private class Progress(cb: ProgressCallback, private val run: Run, val initialTraversals: Int): export cb.{cancel, isCancelled} - private[Run] var totalTraversals: Int = initialTraversals // track how many phases we expect to run - private[Run] var unitc: Int = 0 // current unit count in the current (sub)phase - private[Run] var latec: Int = 0 // current late unit count - private[Run] var traversalc: Int = 0 // completed traversals over all files - private[Run] var subtraversalc: Int = 0 // completed subphases in the current phase - private[Run] var seen: Int = 0 // how many phases we've seen so far + var totalTraversals: Int = initialTraversals // track how many phases we expect to run + var currentUnitCount: Int = 0 // current unit count in the current (sub)phase + var currentLateUnitCount: Int = 0 // current late unit count + var completedTraversalCount: Int = 0 // completed traversals over all files + var currentCompletedSubtraversalCount: Int = 0 // completed subphases in the current phase + var seenPhaseCount: Int = 0 // how many phases we've seen so far private var currPhase: Phase = uninitialized // initialized by enterPhase private var subPhases: SubPhases = uninitialized // initialized by enterPhase @@ -518,7 +518,7 @@ object Run { /** Compute the current (sub)phase name and next (sub)phase name */ private[Run] def tickSubphase()(using Context): Unit = - val index = subtraversalc + val index = currentCompletedSubtraversalCount val s = subPhases currPhaseName = s.subPhase(index) nextPhaseName = @@ -526,13 +526,13 @@ object Run { else s.next match case None => "" case Some(next0) => next0.subPhase(0) - if seen > 0 then + if seenPhaseCount > 0 then refreshProgress() /** Counts the number of completed full traversals over files, plus the number of units in the current phase */ private def currentProgress(): Int = - traversalc * work() + unitc + latec + completedTraversalCount * work() + currentUnitCount + currentLateUnitCount /**Total progress is computed as the sum of * - the number of traversals we expect to make over all files diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index 3404efebf215..20b553149edb 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -186,11 +186,6 @@ object Contexts { val local = progressCallback if local != null then op(local) - def cancelSignalRecorded: Boolean = - val local = progressCallback - val noSignalRecieved = local == null || !local.isCancelled - !noSignalRecieved // if true then cancel request was recorded - /** The current plain printer */ def printerFn: Context => Printer = store(printerFnLoc) diff --git a/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java b/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java index d1e076c75bfa..39f5ca39962b 100644 --- a/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java +++ b/compiler/src/dotty/tools/dotc/sbt/interfaces/ProgressCallback.java @@ -3,7 +3,7 @@ import dotty.tools.dotc.CompilationUnit; public interface ProgressCallback { - /** Record that the cancellation signal has been recieved during the Zinc run. */ + /** Record that the cancellation signal has been received during the Zinc run. */ default void cancel() {} /** Report on if there was a cancellation signal for the current Zinc run. */ From 917965b4208964994c2cb81ed3ab5859ca52b4d8 Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Wed, 25 Oct 2023 19:52:18 +0200 Subject: [PATCH 11/13] use explicit unitContext [Cherry-picked 550c316608629b9096161deaf8494a69e1e9a09e] --- .../dotty/tools/dotc/fromtasty/ReadTasty.scala | 2 +- .../dotty/tools/dotc/parsing/ParserPhase.scala | 6 +++--- .../tools/dotc/transform/init/Checker.scala | 2 +- .../dotty/tools/dotc/typer/TyperPhase.scala | 18 +++++++++--------- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala b/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala index d3ff1776e621..455b6c89a0ba 100644 --- a/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala +++ b/compiler/src/dotty/tools/dotc/fromtasty/ReadTasty.scala @@ -24,7 +24,7 @@ class ReadTasty extends Phase { withMode(Mode.ReadPositions) { val nextUnits = collection.mutable.ListBuffer.empty[CompilationUnit] val unitContexts = units.view.map(ctx.fresh.setCompilationUnit) - for given Context <- unitContexts if addTasty(nextUnits += _) do () + for unitContext <- unitContexts if addTasty(nextUnits += _)(using unitContext) do () nextUnits.toList } diff --git a/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala b/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala index d8c1f5f17adf..bcabfbd03a1d 100644 --- a/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala +++ b/compiler/src/dotty/tools/dotc/parsing/ParserPhase.scala @@ -48,9 +48,9 @@ class Parser extends Phase { val unitContexts0 = for - given Context <- unitContexts - if parse - yield ctx + unitContext <- unitContexts + if parse(using unitContext) + yield unitContext record("parsedTrees", ast.Trees.ntrees) diff --git a/compiler/src/dotty/tools/dotc/transform/init/Checker.scala b/compiler/src/dotty/tools/dotc/transform/init/Checker.scala index 23ccee675220..523a82dcd947 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Checker.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Checker.scala @@ -41,7 +41,7 @@ class Checker extends Phase: val unitContexts = units.map(unit => checkCtx.fresh.setCompilationUnit(unit)) val units0 = - for given Context <- unitContexts if traverse(traverser) yield ctx.compilationUnit + for unitContext <- unitContexts if traverse(traverser)(using unitContext) yield unitContext.compilationUnit cancellable { val classes = traverser.getClasses() diff --git a/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala b/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala index 210e457a7764..10796dce2e7c 100644 --- a/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala +++ b/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala @@ -75,9 +75,9 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { val unitContexts0 = try for - given Context <- unitContexts - if enterSyms - yield ctx + unitContext <- unitContexts + if enterSyms(using unitContext) + yield unitContext finally ctx.run.advanceSubPhase() // tick from "typer (indexing)" to "typer (typechecking)" @@ -94,9 +94,9 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { val unitContexts1 = try for - given Context <- unitContexts0 - if typeCheck - yield ctx + unitContext <- unitContexts0 + if typeCheck(using unitContext) + yield unitContext finally ctx.run.advanceSubPhase() // tick from "typer (typechecking)" to "typer (java checking)" @@ -104,9 +104,9 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { val unitContexts2 = for - given Context <- unitContexts1 - if javaCheck // after typechecking to avoid cycles - yield ctx + unitContext <- unitContexts1 + if javaCheck(using unitContext) // after typechecking to avoid cycles + yield unitContext val newUnits = unitContexts2.map(_.compilationUnit).filterNot(discardAfterTyper) ctx.run.nn.checkSuspendedUnits(newUnits) From 0404c006d7e5f2a7aec5104758e98ac624042454 Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Wed, 25 Oct 2023 20:20:35 +0200 Subject: [PATCH 12/13] refactor semanticdb [Cherry-picked b072662b867026f65c739f4ff7a284908cda8734] --- .../dotc/semanticdb/ExtractSemanticDB.scala | 83 ++++++++++--------- 1 file changed, 46 insertions(+), 37 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/semanticdb/ExtractSemanticDB.scala b/compiler/src/dotty/tools/dotc/semanticdb/ExtractSemanticDB.scala index b2f16eb741cb..07f3fcea2e88 100644 --- a/compiler/src/dotty/tools/dotc/semanticdb/ExtractSemanticDB.scala +++ b/compiler/src/dotty/tools/dotc/semanticdb/ExtractSemanticDB.scala @@ -25,6 +25,7 @@ import scala.jdk.CollectionConverters._ import scala.PartialFunction.condOpt import typer.ImportInfo.withRootImports +import dotty.tools.dotc.reporting.Diagnostic.Warning import dotty.tools.dotc.{semanticdb => s} import dotty.tools.io.{AbstractFile, JarArchive} import dotty.tools.dotc.semanticdb.DiagnosticOps.* @@ -61,56 +62,64 @@ class ExtractSemanticDB private (phaseMode: ExtractSemanticDB.PhaseMode) extends // Check not needed since it does not transform trees override def isCheckable: Boolean = false + private def computeDiagnostics( + sourceRoot: String, + warnings: Map[SourceFile, List[Warning]], + append: ((Path, List[Diagnostic])) => Unit)(using Context): Boolean = monitor(phaseName) { + val unit = ctx.compilationUnit + warnings.get(unit.source).foreach { ws => + val outputDir = + ExtractSemanticDB.semanticdbPath( + unit.source, + ExtractSemanticDB.semanticdbOutDir, + sourceRoot + ) + append((outputDir, ws.map(_.toSemanticDiagnostic))) + } + } + + private def extractSemanticDB(sourceRoot: String, writeSemanticdbText: Boolean)(using Context): Boolean = + monitor(phaseName) { + val unit = ctx.compilationUnit + val outputDir = + ExtractSemanticDB.semanticdbPath( + unit.source, + ExtractSemanticDB.semanticdbOutDir, + sourceRoot + ) + val extractor = ExtractSemanticDB.Extractor() + extractor.extract(unit.tpdTree) + ExtractSemanticDB.write( + unit.source, + extractor.occurrences.toList, + extractor.symbolInfos.toList, + extractor.synthetics.toList, + outputDir, + sourceRoot, + writeSemanticdbText + ) + } + override def runOn(units: List[CompilationUnit])(using ctx: Context): List[CompilationUnit] = { val sourceRoot = ctx.settings.sourceroot.value val appendDiagnostics = phaseMode == ExtractSemanticDB.PhaseMode.AppendDiagnostics + val unitContexts = units.map(ctx.fresh.setCompilationUnit(_).withRootImports) if (appendDiagnostics) val warnings = ctx.reporter.allWarnings.groupBy(w => w.pos.source) val buf = mutable.ListBuffer.empty[(Path, Seq[Diagnostic])] - units.foreach { unit => - val unitCtx = ctx.fresh.setCompilationUnit(unit).withRootImports - monitor(phaseName) { - warnings.get(unit.source).foreach { ws => - val outputDir = - ExtractSemanticDB.semanticdbPath( - unit.source, - ExtractSemanticDB.semanticdbOutDir(using unitCtx), - sourceRoot - ) - buf += ((outputDir, ws.map(_.toSemanticDiagnostic))) - } - }(using unitCtx) - } + val units0 = + for unitCtx <- unitContexts if computeDiagnostics(sourceRoot, warnings, buf += _)(using unitCtx) + yield unitCtx.compilationUnit cancellable { buf.toList.asJava.parallelStream().forEach { case (out, warnings) => ExtractSemanticDB.appendDiagnostics(warnings, out) } } + units0 else val writeSemanticdbText = ctx.settings.semanticdbText.value - units.foreach { unit => - val unitCtx = ctx.fresh.setCompilationUnit(unit).withRootImports - monitor(phaseName) { - val outputDir = - ExtractSemanticDB.semanticdbPath( - unit.source, - ExtractSemanticDB.semanticdbOutDir(using unitCtx), - sourceRoot - ) - val extractor = ExtractSemanticDB.Extractor() - extractor.extract(unit.tpdTree)(using unitCtx) - ExtractSemanticDB.write( - unit.source, - extractor.occurrences.toList, - extractor.symbolInfos.toList, - extractor.synthetics.toList, - outputDir, - sourceRoot, - writeSemanticdbText - ) - }(using unitCtx) - } - units + for unitCtx <- unitContexts if extractSemanticDB(sourceRoot, writeSemanticdbText)(using unitCtx) + yield unitCtx.compilationUnit } def run(using Context): Unit = unsupported("run") From 0de90f3efc7697f7513ccd805a060ea101ce3c1a Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Wed, 25 Oct 2023 21:19:20 +0200 Subject: [PATCH 13/13] simplify subphase traversal [Cherry-picked b51077277085319da721b8aa2aaa0b4fafcbafa5] --- compiler/src/dotty/tools/dotc/Run.scala | 40 ++++++++++------- .../src/dotty/tools/dotc/core/Phases.scala | 9 +++- .../dotty/tools/dotc/typer/TyperPhase.scala | 44 +++++++++---------- .../tools/dotc/sbt/ProgressCallbackTest.scala | 2 +- .../xsbt/CompileProgressSpecification.scala | 2 +- 5 files changed, 56 insertions(+), 41 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/Run.scala b/compiler/src/dotty/tools/dotc/Run.scala index 098c8342cfca..40a343fb1267 100644 --- a/compiler/src/dotty/tools/dotc/Run.scala +++ b/compiler/src/dotty/tools/dotc/Run.scala @@ -220,14 +220,15 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint // no subphases were ran, remove traversals from expected total progress.totalTraversals -= currentPhase.traversals - private def doAdvanceSubPhase()(using Context): Unit = + private def tryAdvanceSubPhase()(using Context): Unit = trackProgress: progress => - progress.currentUnitCount = 0 // reset unit count in current (sub)phase - progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase - progress.completedTraversalCount += 1 // add an extra traversal now that we completed a (sub)phase - progress.currentCompletedSubtraversalCount += 1 // record that we've seen a subphase - if !progress.isCancelled() then - progress.tickSubphase() + if progress.canAdvanceSubPhase then + progress.currentUnitCount = 0 // reset unit count in current (sub)phase + progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase + progress.completedTraversalCount += 1 // add an extra traversal now that we completed a (sub)phase + progress.currentCompletedSubtraversalCount += 1 // record that we've seen a subphase + if !progress.isCancelled() then + progress.tickSubphase() /** Will be set to true if any of the compiled compilation units contains * a pureFunctions language import. @@ -475,6 +476,9 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint object Run { + case class SubPhase(val name: String): + override def toString: String = name + class SubPhases(val phase: Phase): require(phase.exists) @@ -482,13 +486,15 @@ object Run { case phase: MegaPhase => phase.shortPhaseName case phase => phase.phaseName - val all = IArray.from(phase.subPhases.map(sub => s"$baseName ($sub)")) + val all = IArray.from(phase.subPhases.map(sub => s"$baseName[$sub]")) def next(using Context): Option[SubPhases] = val next0 = phase.megaPhase.next.megaPhase if next0.exists then Some(SubPhases(next0)) else None + def size: Int = all.size + def subPhase(index: Int) = if index < all.size then all(index) else baseName @@ -510,14 +516,17 @@ object Run { private var nextPhaseName: String = uninitialized // initialized by enterPhase /** Enter into a new real phase, setting the current and next (sub)phases */ - private[Run] def enterPhase(newPhase: Phase)(using Context): Unit = + def enterPhase(newPhase: Phase)(using Context): Unit = if newPhase ne currPhase then currPhase = newPhase subPhases = SubPhases(newPhase) tickSubphase() + def canAdvanceSubPhase: Boolean = + currentCompletedSubtraversalCount + 1 < subPhases.size + /** Compute the current (sub)phase name and next (sub)phase name */ - private[Run] def tickSubphase()(using Context): Unit = + def tickSubphase()(using Context): Unit = val index = currentCompletedSubtraversalCount val s = subPhases currPhaseName = s.subPhase(index) @@ -546,12 +555,12 @@ object Run { private def requireInitialized(): Unit = require((currPhase: Phase | Null) != null, "enterPhase was not called") - private[Run] def checkCancellation(): Boolean = + def checkCancellation(): Boolean = if Thread.interrupted() then cancel() isCancelled() /** trace that we are beginning a unit in the current (sub)phase, unless cancelled */ - private[Run] def tryEnterUnit(unit: CompilationUnit): Boolean = + def tryEnterUnit(unit: CompilationUnit): Boolean = if checkCancellation() then false else requireInitialized() @@ -559,7 +568,7 @@ object Run { true /** trace the current progress out of the total, in the current (sub)phase, reporting the next (sub)phase */ - private[Run] def refreshProgress()(using Context): Unit = + def refreshProgress()(using Context): Unit = requireInitialized() val total = totalProgress() if total > 0 && !cb.progress(currentProgress(), total, currPhaseName, nextPhaseName) then @@ -581,8 +590,9 @@ object Run { def advanceUnit()(using Context): Unit = if run != null then run.doAdvanceUnit() - def advanceSubPhase()(using Context): Unit = - if run != null then run.doAdvanceSubPhase() + /** if there exists another subphase, switch to it and record progress */ + def enterNextSubphase()(using Context): Unit = + if run != null then run.tryAdvanceSubPhase() /** advance the late count and record progress in the current phase */ def advanceLate()(using Context): Unit = diff --git a/compiler/src/dotty/tools/dotc/core/Phases.scala b/compiler/src/dotty/tools/dotc/core/Phases.scala index 31e07001a4a2..d6a49186b539 100644 --- a/compiler/src/dotty/tools/dotc/core/Phases.scala +++ b/compiler/src/dotty/tools/dotc/core/Phases.scala @@ -318,7 +318,7 @@ object Phases { def runsAfter: Set[String] = Set.empty /** for purposes of progress tracking, overridden in TyperPhase */ - def subPhases: List[String] = Nil + def subPhases: List[Run.SubPhase] = Nil final def traversals: Int = if subPhases.isEmpty then 1 else subPhases.length /** @pre `isRunnable` returns true */ @@ -460,6 +460,13 @@ object Phases { else false + inline def runSubPhase[T](id: Run.SubPhase)(inline body: (Run.SubPhase, Context) ?=> T)(using Context): T = + given Run.SubPhase = id + try + body + finally + ctx.run.enterNextSubphase() + /** Do not run if compile progress has been cancelled */ final def cancellable(body: Context ?=> Unit)(using Context): Boolean = if ctx.run.enterRegion() then diff --git a/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala b/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala index 10796dce2e7c..857ed1bad4d9 100644 --- a/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala +++ b/compiler/src/dotty/tools/dotc/typer/TyperPhase.scala @@ -3,6 +3,7 @@ package dotc package typer import core._ +import Run.SubPhase import Phases._ import Contexts._ import Symbols._ @@ -31,13 +32,13 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { // Run regardless of parsing errors override def isRunnable(implicit ctx: Context): Boolean = true - def enterSyms(using Context): Boolean = monitor("indexing") { + def enterSyms(using Context)(using subphase: SubPhase): Boolean = monitor(subphase.name) { val unit = ctx.compilationUnit ctx.typer.index(unit.untpdTree) typr.println("entered: " + unit.source) } - def typeCheck(using Context): Boolean = monitor("typechecking") { + def typeCheck(using Context)(using subphase: SubPhase): Boolean = monitor(subphase.name) { val unit = ctx.compilationUnit try if !unit.suspended then @@ -49,7 +50,7 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { catch case _: CompilationUnit.SuspendException => () } - def javaCheck(using Context): Boolean = monitor("checking java") { + def javaCheck(using Context)(using subphase: SubPhase): Boolean = monitor(subphase.name) { val unit = ctx.compilationUnit if unit.isJava then JavaChecks.check(unit.tpdTree) @@ -58,10 +59,11 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { protected def discardAfterTyper(unit: CompilationUnit)(using Context): Boolean = unit.isJava || unit.suspended - /** Keep synchronised with `monitor` subcalls */ - override def subPhases: List[String] = List("indexing", "typechecking", "checking java") + override val subPhases: List[SubPhase] = List( + SubPhase("indexing"), SubPhase("typechecking"), SubPhase("checkingJava")) override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] = + val List(Indexing @ _, Typechecking @ _, CheckingJava @ _) = subPhases: @unchecked val unitContexts = for unit <- units yield val newCtx0 = ctx.fresh.setPhase(this.start).setCompilationUnit(unit) @@ -72,14 +74,12 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { else newCtx - val unitContexts0 = - try - for - unitContext <- unitContexts - if enterSyms(using unitContext) - yield unitContext - finally - ctx.run.advanceSubPhase() // tick from "typer (indexing)" to "typer (typechecking)" + val unitContexts0 = runSubPhase(Indexing) { + for + unitContext <- unitContexts + if enterSyms(using unitContext) + yield unitContext + } ctx.base.parserPhase match { case p: ParserPhase => @@ -91,23 +91,21 @@ class TyperPhase(addRootImports: Boolean = true) extends Phase { case _ => } - val unitContexts1 = - try - for - unitContext <- unitContexts0 - if typeCheck(using unitContext) - yield unitContext - finally - ctx.run.advanceSubPhase() // tick from "typer (typechecking)" to "typer (java checking)" + val unitContexts1 = runSubPhase(Typechecking) { + for + unitContext <- unitContexts0 + if typeCheck(using unitContext) + yield unitContext + } record("total trees after typer", ast.Trees.ntrees) - val unitContexts2 = + val unitContexts2 = runSubPhase(CheckingJava) { for unitContext <- unitContexts1 if javaCheck(using unitContext) // after typechecking to avoid cycles yield unitContext - + } val newUnits = unitContexts2.map(_.compilationUnit).filterNot(discardAfterTyper) ctx.run.nn.checkSuspendedUnits(newUnits) newUnits diff --git a/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala b/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala index e6e67b997aae..489dc0f1759c 100644 --- a/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala +++ b/compiler/test/dotty/tools/dotc/sbt/ProgressCallbackTest.scala @@ -57,7 +57,7 @@ final class ProgressCallbackTest extends DottyTest: @Test def cancelMidTyper: Unit = - inspectCancellationAtPhase("typer (typechecking)") + inspectCancellationAtPhase("typer[typechecking]") @Test def cancelErasure: Unit = diff --git a/sbt-bridge/test/xsbt/CompileProgressSpecification.scala b/sbt-bridge/test/xsbt/CompileProgressSpecification.scala index 45f9daa70e05..bcdac0547e75 100644 --- a/sbt-bridge/test/xsbt/CompileProgressSpecification.scala +++ b/sbt-bridge/test/xsbt/CompileProgressSpecification.scala @@ -52,7 +52,7 @@ class CompileProgressSpecification { val someExpectedPhases = // just check some "fundamental" phases, don't put all phases to avoid brittleness Set( "parser", - "typer (indexing)", "typer (typechecking)", "typer (checking java)", + "typer[indexing]", "typer[typechecking]", "typer[checkingJava]", "sbt-deps", "posttyper", "sbt-api",