diff --git a/internal/zinc-persist/src/main/scala/sbt/internal/inc/consistent/ConsistentFileAnalysisStore.scala b/internal/zinc-persist/src/main/scala/sbt/internal/inc/consistent/ConsistentFileAnalysisStore.scala index 5cb69fcab..418eae932 100644 --- a/internal/zinc-persist/src/main/scala/sbt/internal/inc/consistent/ConsistentFileAnalysisStore.scala +++ b/internal/zinc-persist/src/main/scala/sbt/internal/inc/consistent/ConsistentFileAnalysisStore.scala @@ -12,29 +12,25 @@ package sbt.internal.inc.consistent * additional information regarding copyright ownership. */ -import java.io.{ File, FileInputStream, FileOutputStream } -import java.util.Optional import sbt.io.{ IO, Using } +import xsbti.compile.analysis.ReadWriteMappers import xsbti.compile.{ AnalysisContents, AnalysisStore => XAnalysisStore } +import java.io.{ File, FileInputStream, FileOutputStream } +import java.util.Optional import scala.util.control.Exception.allCatch -import xsbti.compile.analysis.ReadWriteMappers - -import scala.concurrent.ExecutionContext object ConsistentFileAnalysisStore { def text( file: File, mappers: ReadWriteMappers, sort: Boolean = true, - ec: ExecutionContext = ExecutionContext.global, parallelism: Int = Runtime.getRuntime.availableProcessors() ): XAnalysisStore = new AStore( file, new ConsistentAnalysisFormat(mappers, sort), SerializerFactory.text, - ec, parallelism ) @@ -59,14 +55,12 @@ object ConsistentFileAnalysisStore { file: File, mappers: ReadWriteMappers, sort: Boolean, - ec: ExecutionContext = ExecutionContext.global, parallelism: Int = Runtime.getRuntime.availableProcessors() ): XAnalysisStore = new AStore( file, new ConsistentAnalysisFormat(mappers, sort), SerializerFactory.binary, - ec, parallelism ) @@ -74,7 +68,6 @@ object ConsistentFileAnalysisStore { file: File, format: ConsistentAnalysisFormat, sf: SerializerFactory[S, D], - ec: ExecutionContext = ExecutionContext.global, parallelism: Int = Runtime.getRuntime.availableProcessors() ) extends XAnalysisStore { @@ -85,7 +78,7 @@ object ConsistentFileAnalysisStore { if (!file.getParentFile.exists()) file.getParentFile.mkdirs() val fout = new FileOutputStream(tmpAnalysisFile) try { - val gout = new ParallelGzipOutputStream(fout, ec, parallelism) + val gout = new ParallelGzipOutputStream(fout, parallelism) val ser = sf.serializerFor(gout) format.write(ser, analysis, setup) gout.close() diff --git a/internal/zinc-persist/src/main/scala/sbt/internal/inc/consistent/ParallelGzipOutputStream.scala b/internal/zinc-persist/src/main/scala/sbt/internal/inc/consistent/ParallelGzipOutputStream.scala index bdce23d10..0419abf82 100644 --- a/internal/zinc-persist/src/main/scala/sbt/internal/inc/consistent/ParallelGzipOutputStream.scala +++ b/internal/zinc-persist/src/main/scala/sbt/internal/inc/consistent/ParallelGzipOutputStream.scala @@ -1,123 +1,219 @@ +// Original code by Stefan Zeiger (see https://github.com/szeiger/zinc/blob/1d296b2fbeaae1cf14e4c00db0bbc2203f9783a4/internal/zinc-persist/src/main/scala/sbt/internal/inc/consistent/NewParallelGzipOutputStream.scala) +// Modified by Rex Kerr to use Java threads directly rather than Future package sbt.internal.inc.consistent - import java.io.{ ByteArrayOutputStream, FilterOutputStream, OutputStream } import java.util.zip.{ CRC32, Deflater, DeflaterOutputStream } +import java.util.concurrent.{ SynchronousQueue, ArrayBlockingQueue, LinkedTransferQueue } + import scala.annotation.tailrec -import scala.concurrent.duration.Duration -import scala.concurrent.{ Await, ExecutionContext, Future } -import 
scala.collection.mutable /** * Parallel gzip compression. Algorithm based on https://github.com/shevek/parallelgzip * with additional optimization and simplification. This is essentially a block-buffered * stream but instead of writing a full block to the underlying output, it is passed to a - * thread pool for compression and the Futures of compressed blocks are collected when - * flushing. + * thread pool for compression and the compressed blocks are collected when flushing. */ object ParallelGzipOutputStream { private val blockSize = 64 * 1024 private val compression = Deflater.DEFAULT_COMPRESSION - private class BufOut(size: Int) extends ByteArrayOutputStream(size) { - def writeTo(buf: Array[Byte]): Unit = System.arraycopy(this.buf, 0, buf, 0, count) + // Holds an input buffer to load data and an output buffer to write + // the compressed data into. Compressing clears the input buffer. + // Compressed data can be retrieved with `output.writeTo(OutputStream)`. + private final class Block(var index: Int) { + val input = new Array[Byte](blockSize) + var inputN = 0 + val output = new ByteArrayOutputStream(blockSize + (blockSize >> 3)) + val deflater = new Deflater(compression, true) + val dos = new DeflaterOutputStream(output, deflater, true) + + def compress(): Unit = { + deflater.reset() + output.reset() + dos.write(input, 0, inputN) + dos.flush() + inputN = 0 + } } - private class Worker { - private[this] val defl = new Deflater(compression, true) - private[this] val buf = new BufOut(blockSize + (blockSize >> 3)) - private[this] val out = new DeflaterOutputStream(buf, defl, true) - def compress(b: Block): Unit = { - defl.reset() - buf.reset() - out.write(b.data, 0, b.length) - out.flush() - b.length = buf.size - if (b.length > b.data.length) b.data = new Array[Byte](b.length) - buf.writeTo(b.data) + // Waits for data to appear in a SynchronousQueue. + // When it does, compress it and pass it along. Also put self in a pool of workers awaiting more work. + // If data does not appear but a `None` appears instead, cease running (and do not add self to work queue). + private final class Worker( + val workers: ArrayBlockingQueue[Worker], + val compressed: LinkedTransferQueue[Either[Int, Block]] + ) extends Thread { + val work = new SynchronousQueue[Option[Block]] + + @tailrec + def loop(): Unit = { + work.take() match { + case Some(block) => + block.compress() + compressed.put(Right(block)) + workers.put(this) + loop() + case _ => + } } - } - private val localWorker = new ThreadLocal[Worker] { - override def initialValue = new Worker + override def run(): Unit = { + loop() + } } - private class Block { - var data = new Array[Byte](blockSize + (blockSize >> 3)) - var length = 0 + // Waits for data to appear in a LinkedTransferQueue. + // When it does, place it into a sorted tree and, if the data is in order, write it out. + // Once the data has been written, place it into a cache for completed buffers. + // If data does not appear but an integer appears instead, set a mark to quit once + // that many blocks have been written. 
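+  // Overall pipeline (sketch, for orientation): the writer thread hands filled blocks
+  // to a Worker via its SynchronousQueue; each Worker sends compressed blocks to the
+  // Scribe via a LinkedTransferQueue; the Scribe writes them to `out` in index order
+  // and recycles written blocks into the `completed` pool for reuse.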
+  private final class Scribe(out: OutputStream, val completed: LinkedTransferQueue[Block])
+      extends Thread {
+    val work = new LinkedTransferQueue[Either[Int, Block]]
+    private val tree = new collection.mutable.TreeMap[Int, Block]
+    private var next = 0
+    private var stopAt = Int.MaxValue
+
+    @tailrec
+    def loop(): Unit = {
+      work.take() match {
+        case Right(block) =>
+          tree(block.index) = block
+        case Left(limit) =>
+          stopAt = limit
+      }
+      while (tree.nonEmpty && tree.head._2.index == next) {
+        val block = tree.remove(next).get
+        block.output.writeTo(out)
+        completed.put(block)
+        next += 1
+      }
+      if (next < stopAt) loop()
+    }
+
+    override def run(): Unit = {
+      loop()
+    }
   }
 
   private val header = Array[Byte](0x1f.toByte, 0x8b.toByte, Deflater.DEFLATED, 0, 0, 0, 0, 0, 0, 0)
 }
 
-final class ParallelGzipOutputStream(out: OutputStream, ec: ExecutionContext, parallelism: Int)
+/**
+ * Implements a parallel chunked compression algorithm (using a minimum of two extra threads).
+ * Note that the methods in this class are not themselves thread-safe; the class has
+ * "interior concurrency" (cf. interior mutability). In particular, the behavior of a write
+ * that is concurrent with, or comes after, a close operation is undefined.
+ */
+final class ParallelGzipOutputStream(out: OutputStream, parallelism: Int)
     extends FilterOutputStream(out) {
   import ParallelGzipOutputStream._
 
-  private final val crc = new CRC32
-  private final val queueLimit = parallelism * 3
-  // preferred on 2.13: new mutable.ArrayDeque[Future[Block]](queueLimit)
-  private final val pending = mutable.Queue.empty[Future[Block]]
-  private var current: Block = new Block
-  private var free: Block = _
-  private var total = 0L
+  private val crc = new CRC32
+  private var totalBlocks = 0
+  private var totalCount = 0L
+
+  private val bufferLimit = parallelism * 3
+  private var bufferCount = 1
+  private var current = new Block(0)
+
+  private val workerCount = math.max(1, parallelism - 1)
+  private val workers = new ArrayBlockingQueue[Worker](workerCount)
+  private val buffers = new LinkedTransferQueue[Block]()
 
   out.write(header)
 
+  private val scribe = new Scribe(out, buffers)
+  scribe.start()
+
+  while (workers.remainingCapacity() > 0) {
+    val w = new Worker(workers, scribe.work)
+    workers.put(w)
+    w.start()
+  }
+
   override def write(b: Int): Unit = write(Array[Byte]((b & 0xff).toByte))
   override def write(b: Array[Byte]): Unit = write(b, 0, b.length)
 
-  @tailrec override def write(b: Array[Byte], off: Int, len: Int): Unit = {
-    val copy = math.min(len, blockSize - current.length)
+  @tailrec
+  override def write(b: Array[Byte], off: Int, len: Int): Unit = {
+    val copy = math.min(len, blockSize - current.inputN)
     crc.update(b, off, copy)
-    total += copy
-    System.arraycopy(b, off, current.data, current.length, copy)
-    current.length += copy
+    totalCount += copy
+    System.arraycopy(b, off, current.input, current.inputN, copy)
+    current.inputN += copy
     if (copy < len) {
       submit()
       write(b, off + copy, len - copy)
     }
   }
 
-  private[this] def submit(): Unit = {
-    flushUntil(queueLimit - 1)
-    val finalBlock = current
-    pending += Future { localWorker.get.compress(finalBlock); finalBlock }(ec)
-    if (free != null) {
-      current = free
-      free = null
-    } else current = new Block()
+  private def submit(): Unit = {
+    val w = workers.take()
+    w.work.put(Some(current))
+    totalBlocks += 1
+    current = buffers.poll()
+    if (current eq null) {
+      if (bufferCount < bufferLimit) {
+        current = new Block(totalBlocks)
+        bufferCount += 1
+      } else {
+        current = buffers.take()
+      }
+    }
+    current.index = totalBlocks
   }
 
-  private def flushUntil(remaining: Int): Unit =
-    while (pending.length > remaining || pending.headOption.exists(_.isCompleted)) {
-      val b = Await.result(pending.dequeue(), Duration.Inf)
-      out.write(b.data, 0, b.length)
-      b.length = 0
-      free = b
+  private def flushImpl(shutdown: Boolean): Unit = {
+    val fetched = new Array[Block](bufferCount - 1)
+    var n = 0
+    // If we have all the buffers, all pending work is done.
+    while (n < fetched.length) {
+      fetched(n) = buffers.take()
+      n += 1
     }
+    if (shutdown) {
+      // Send a stop signal to the workers and the scribe.
+      n = workerCount
+      while (n > 0) {
+        workers.take().work.put(None)
+        n -= 1
+      }
+      scribe.work.put(Left(totalBlocks))
+    } else {
+      // Put all the buffers back so we can keep accepting data.
+      n = 0
+      while (n < fetched.length) {
+        buffers.put(fetched(n))
+        n += 1
+      }
+    }
+  }
 
+  /**
+   * Blocks until all pending data is written. Note that flushing works against the
+   * parallelism of this class; it is preferable to write all data and then close the
+   * stream. Note also that a flushed stream lacks the trailing CRC and length, so it is
+   * not yet a valid gzip file; there is consequently little point in flushing early.
+   */
   override def flush(): Unit = {
-    if (current.length > 0) submit()
-    flushUntil(0)
+    if (current.inputN > 0) submit()
+    flushImpl(false)
     super.flush()
   }
 
   override def close(): Unit = {
-    flush()
+    if (current.inputN > 0) submit()
+    flushImpl(true)
+
     val buf = new Array[Byte](10)
-    def int(i: Int, off: Int): Unit = {
-      buf(off) = ((i & 0xff).toByte)
-      buf(off + 1) = (((i >>> 8) & 0xff).toByte)
-      buf(off + 2) = (((i >>> 16) & 0xff).toByte)
-      buf(off + 3) = (((i >>> 24) & 0xff).toByte)
-    }
-    buf(0) = 3
-    int(crc.getValue.toInt, 2)
-    int((total & 0xffffffffL).toInt, 6)
+    val bb = java.nio.ByteBuffer.wrap(buf)
+    bb.order(java.nio.ByteOrder.LITTLE_ENDIAN)
+    bb.putShort(3)
+    bb.putInt(crc.getValue.toInt)
+    bb.putInt((totalCount & 0xffffffffL).toInt)
     out.write(buf)
+
     out.close()
-    total = Integer.MIN_VALUE
-    free = null
   }
 }
diff --git a/internal/zinc-persist/src/test/scala/sbt/inc/consistent/ConsistentAnalysisFormatIntegrationSuite.scala b/internal/zinc-persist/src/test/scala/sbt/inc/consistent/ConsistentAnalysisFormatIntegrationSuite.scala
index 1a8b7299e..4b953fa29 100644
--- a/internal/zinc-persist/src/test/scala/sbt/inc/consistent/ConsistentAnalysisFormatIntegrationSuite.scala
+++ b/internal/zinc-persist/src/test/scala/sbt/inc/consistent/ConsistentAnalysisFormatIntegrationSuite.scala
@@ -1,11 +1,11 @@
 package sbt.inc.consistent
 
-import java.io.File
+import java.io.{ File, FileInputStream }
 import java.util.Arrays
 import org.scalatest.funsuite.AnyFunSuite
 import sbt.internal.inc.consistent.ConsistentFileAnalysisStore
 import sbt.internal.inc.{ Analysis, FileAnalysisStore }
-import sbt.io.IO
+import sbt.io.{ IO, Using }
 import xsbti.compile.{ AnalysisContents, AnalysisStore }
 import xsbti.compile.analysis.ReadWriteMappers
 
@@ -50,6 +50,22 @@ class ConsistentAnalysisFormatIntegrationSuite extends AnyFunSuite {
     }
   }
 
+  test("compression ratio") {
+    for (d <- data) {
+      assert(d.exists())
+      val api = read(FileAnalysisStore.text(d))
+      val file = write("cbin1.zip", api)
+      val uncompressedSize = Using.gzipInputStream(new FileInputStream(file)) { in =>
+        val content = IO.readBytes(in)
+        content.length
+      }
+      val compressedSize = d.length()
+      val compressionRatio = compressedSize.toDouble / uncompressedSize.toDouble
+      assert(compressionRatio < 0.85)
+      // compression ratio for each data file: 0.8185090254676337, 0.7247774786370688, 0.8346021341469837
+    }
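+    // The loose 0.85 bound leaves headroom: exact Deflater output (and hence the
+    // ratios noted above) can vary slightly across JDK versions.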
+ } + def read(store: AnalysisStore): AnalysisContents = { val api = store.unsafeGet() // Force loading of companion file and check that the companion data is present: diff --git a/internal/zinc-persist/src/test/scala/sbt/inc/consistent/ConsistentAnalysisFormatSuite.scala b/internal/zinc-persist/src/test/scala/sbt/inc/consistent/ConsistentAnalysisFormatSuite.scala index d9085323f..97f9c9a12 100644 --- a/internal/zinc-persist/src/test/scala/sbt/inc/consistent/ConsistentAnalysisFormatSuite.scala +++ b/internal/zinc-persist/src/test/scala/sbt/inc/consistent/ConsistentAnalysisFormatSuite.scala @@ -1,21 +1,13 @@ package sbt.inc.consistent -import java.io.{ - BufferedInputStream, - BufferedReader, - ByteArrayInputStream, - ByteArrayOutputStream, - StringReader, - StringWriter -} -import java.util.zip.GZIPInputStream -import java.util.Arrays -import scala.util.Random import org.scalatest.funsuite.AnyFunSuite -import sbt.internal.inc.consistent._ +import sbt.internal.inc.consistent.* import sbt.io.IO -import scala.concurrent.ExecutionContext +import java.io.* +import java.util.Arrays +import java.util.zip.GZIPInputStream +import scala.util.Random class ConsistentAnalysisFormatSuite extends AnyFunSuite { @@ -88,24 +80,4 @@ class ConsistentAnalysisFormatSuite extends AnyFunSuite { writeTo(SerializerFactory.binary.serializerFor(out)) readFrom(SerializerFactory.binary.deserializerFor(new ByteArrayInputStream(out.toByteArray))) } - - test("ParallelGzip") { - val bs = 64 * 1024 - val rnd = new Random(0L) - for { - threads <- Seq(1, 8) - size <- Seq(0, bs - 1, bs, bs + 1, bs * 8 - 1, bs * 8, bs * 8 + 1) - } { - val a = new Array[Byte](size) - rnd.nextBytes(a) - val bout = new ByteArrayOutputStream() - val gout = new ParallelGzipOutputStream(bout, ExecutionContext.global, parallelism = threads) - gout.write(a) - gout.close() - val gin = - new BufferedInputStream(new GZIPInputStream(new ByteArrayInputStream(bout.toByteArray))) - val a2 = IO.readBytes(gin) - assert(Arrays.equals(a, a2), s"threads = $threads, size = $size") - } - } } diff --git a/internal/zinc-persist/src/test/scala/sbt/inc/consistent/ParallelGzipOutputStreamSpecification.scala b/internal/zinc-persist/src/test/scala/sbt/inc/consistent/ParallelGzipOutputStreamSpecification.scala new file mode 100644 index 000000000..cf8b43a0f --- /dev/null +++ b/internal/zinc-persist/src/test/scala/sbt/inc/consistent/ParallelGzipOutputStreamSpecification.scala @@ -0,0 +1,186 @@ +package sbt.inc.consistent + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import java.io.{ BufferedInputStream, ByteArrayInputStream, ByteArrayOutputStream } +import java.util.zip.GZIPInputStream +import java.nio.file.{ Files, Paths, StandardOpenOption } +import sbt.internal.inc.consistent.ParallelGzipOutputStream +import sbt.io.IO +import sbt.io.Using + +import java.util.Arrays +import collection.parallel.CollectionConverters.* +import scala.util.Random +import scala.concurrent.{ Await, Future } +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.duration.* + +class ParallelGzipOutputStreamSpecification extends AnyFlatSpec with Matchers { + val defaultSize: Int = 64 * 1024 + val sizes: Seq[Int] = Seq( + 0, + 1, + 3, + 32, + 127, + 1025, + defaultSize - 1, + defaultSize, + defaultSize + 1, + defaultSize * 8 - 1, + defaultSize * 8, + defaultSize * 8 + 1 + ) + val numberOfGzipStreams: Seq[Int] = Seq(1, 2, 4, 8, 15) + val parallelisms: Seq[Int] = 1 to 17 + + def decompress(data: Array[Byte]): Array[Byte] = { 
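+    // Decompress with the JDK gzip decoder (via sbt.io.Using) so round-trips are
+    // checked against an independent reference implementation.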
+    Using.gzipInputStream(new ByteArrayInputStream(data))(IO.readBytes)
+  }
+
+  def compress(data: Array[Byte], parallelism: Int, testSetup: String): Array[Byte] = {
+    val bout = new ByteArrayOutputStream()
+    val gout = new ParallelGzipOutputStream(bout, parallelism)
+    try {
+      gout.write(data)
+    } catch {
+      case e: Exception =>
+        handleFailure(Array[Byte](), data, testSetup, "Compression failed", Some(e))
+    } finally {
+      gout.close()
+    }
+    bout.toByteArray
+  }
+
+  def writeToFile(data: Array[Byte], fileName: String): Unit = {
+    val outputDir = Paths.get("../../../test-gzip-output")
+    if (!Files.exists(outputDir)) {
+      Files.createDirectories(outputDir)
+    }
+    val path = outputDir.resolve(fileName)
+    Files.write(path, data, StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING)
+  }
+
+  // Needed on Windows to produce a valid filename.
+  def sanitizedFilename(fileName: String): String = {
+    fileName.replaceAll("[^a-zA-Z0-9-_.]", "_")
+  }
+
+  def handleFailure(
+      compressed: Array[Byte],
+      data: Array[Byte],
+      testSetup: String,
+      errorCause: String,
+      errorOpt: Option[Exception] = None,
+  ): Unit = {
+    val compressedFileName = sanitizedFilename(s"compressed_$testSetup.gz")
+    val dataFileName = sanitizedFilename(s"data_$testSetup.bin")
+    writeToFile(compressed, compressedFileName)
+    writeToFile(data, dataFileName)
+
+    errorOpt match {
+      case Some(error) =>
+        fail(s"$errorCause. See $compressedFileName and $dataFileName", error)
+      case _ => fail(s"$errorCause. See $compressedFileName and $dataFileName")
+    }
+  }
+
+  def verifyRoundTrip(data: Array[Byte], parallelism: Int, testSetup: String): Unit = {
+    val compressed = compress(data, parallelism, testSetup)
+    try {
+      val decompressed = decompress(compressed)
+      if (!Arrays.equals(data, decompressed)) {
+        handleFailure(compressed, data, testSetup, "Compression and decompression mismatch.")
+      }
+    } catch {
+      case e: Exception =>
+        handleFailure(
+          compressed,
+          data,
+          testSetup,
+          "Decompression failed",
+          Some(e),
+        )
+    }
+  }
+
+  def randomArray(size: Int): Array[Byte] = {
+    val rnd = new Random(0L)
+    val data = new Array[Byte](size)
+    rnd.nextBytes(data)
+    data
+  }
+
+  it should "compress and decompress data correctly" in {
+    for {
+      parallelism <- parallelisms
+      size <- sizes
+    } {
+      val data = randomArray(size)
+      verifyRoundTrip(data, parallelism, s"parallelism = $parallelism, size = $size")
+    }
+  }
+
+  it should "handle highly redundant data correctly" in {
+    for {
+      parallelism <- parallelisms
+      size <- sizes
+    } {
+      val data = Array.fill(size)(0.toByte)
+      verifyRoundTrip(data, parallelism, s"parallelism = $parallelism, size = $size, redundant")
+    }
+  }
+
+  it should "handle large data sizes" in {
+    val largeData = randomArray(64 * 1024 * 1024) // 64 MB
+    for (parallelism <- parallelisms) {
+      verifyRoundTrip(largeData, parallelism, s"parallelism = $parallelism, large data size")
+    }
+  }
+
+  it should "handle very large parallelism" in {
+    val data = randomArray(defaultSize * 16)
+    val maxNumberOfThreads = 200
+    verifyRoundTrip(data, maxNumberOfThreads, s"parallelism = $maxNumberOfThreads, large data")
+  }
+
+  it should "handle multiple ParallelGzipOutputStream concurrently" in {
+    for {
+      numberOfGzipStream <- numberOfGzipStreams
+      parallelism <- parallelisms
+      size <- sizes
+    } {
+      val verifications = Future.traverse(1 to numberOfGzipStream)(_ =>
+        Future {
+          val data = randomArray(size)
+          verifyRoundTrip(
+            data,
+            parallelism,
+            s"numberOfStreams: $numberOfGzipStream, parallelism = $parallelism, size = 
$size, multiple" + ) + } + ) + Await.result(verifications, 60.seconds) + } + } + + it should "handle multiple ParallelGzipOutputStream with varying config concurrently" in { + val verifications = Future.traverse(for { + parallelism <- parallelisms.take(10) + size <- sizes + } yield (parallelism, size)) { case (parallelism, size) => + Future { + val data = randomArray(size) + verifyRoundTrip( + data, + parallelism, + s"parallelism = $parallelism, size = $size, varying" + ) + } + } + Await.result(verifications, 60.seconds) + } +} diff --git a/zinc/src/main/scala/sbt/internal/inc/MixedAnalyzingCompiler.scala b/zinc/src/main/scala/sbt/internal/inc/MixedAnalyzingCompiler.scala index 2d37b6658..b16b0fecf 100644 --- a/zinc/src/main/scala/sbt/internal/inc/MixedAnalyzingCompiler.scala +++ b/zinc/src/main/scala/sbt/internal/inc/MixedAnalyzingCompiler.scala @@ -509,7 +509,6 @@ object MixedAnalyzingCompiler { useConsistent = false, mappers = ReadWriteMappers.getEmptyMappers(), sort = true, - ec = ExecutionContext.global, parallelism = Runtime.getRuntime.availableProcessors(), ) @@ -519,7 +518,6 @@ object MixedAnalyzingCompiler { useConsistent: Boolean, mappers: ReadWriteMappers, sort: Boolean, - ec: ExecutionContext, parallelism: Int, ): AnalysisStore = { val fileStore = (useTextAnalysis, useConsistent) match { @@ -530,7 +528,6 @@ object MixedAnalyzingCompiler { file = analysisFile.toFile, mappers = mappers, sort = sort, - ec = ec, parallelism = parallelism, ) case (true, false) => @@ -540,7 +537,6 @@ object MixedAnalyzingCompiler { file = analysisFile.toFile, mappers = mappers, sort = sort, - ec = ec, parallelism = parallelism, ) }
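
Usage note: after this change, callers choose only a `parallelism` level; the `ExecutionContext` parameter is gone because the stream now manages its own threads (up to `parallelism - 1` workers plus one scribe). A minimal sketch against the `binary` overload shown above (the file path and empty mappers are illustrative assumptions, not part of the patch):

    import java.io.File
    import sbt.internal.inc.consistent.ConsistentFileAnalysisStore
    import xsbti.compile.analysis.ReadWriteMappers

    val store = ConsistentFileAnalysisStore.binary(
      file = new File("inc_compile.zip"), // illustrative path
      mappers = ReadWriteMappers.getEmptyMappers(),
      sort = true,
      parallelism = Runtime.getRuntime.availableProcessors()
    )
    // Read back, if present (AnalysisStore.get returns java.util.Optional):
    store.get().ifPresent(contents => println(contents.getMiniSetup))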