From 9addf1b0444416a4fba7477713b0f5404de6fb13 Mon Sep 17 00:00:00 2001 From: Chris Norman Date: Wed, 13 Jun 2018 16:31:33 -0400 Subject: [PATCH] Engine read caching. --- .../ReadInputArgumentCollection.java | 11 + .../hellbender/engine/FeatureCache.java | 248 ---------- .../hellbender/engine/FeatureDataSource.java | 59 +-- .../hellbender/engine/GATKTool.java | 14 +- .../hellbender/engine/ReadsDataSource.java | 35 +- .../DrivingFeatureInputCacheStrategy.java | 143 ++++++ .../engine/cache/LocatableCache.java | 239 ++++++++++ .../engine/cache/LocatableCacheStrategy.java | 59 +++ .../cache/SideReadInputCacheStrategy.java | 176 +++++++ .../FilterAlignmentArtifacts.java | 5 + .../tools/walkers/vqsr/CNNScoreVariants.java | 5 + .../engine/FeatureDataSourceUnitTest.java | 266 ----------- .../engine/VariantWalkerIntegrationTest.java | 106 +++++ .../engine/cache/LocatableCacheUnitTest.java | 430 ++++++++++++++++++ 14 files changed, 1235 insertions(+), 561 deletions(-) delete mode 100644 src/main/java/org/broadinstitute/hellbender/engine/FeatureCache.java create mode 100644 src/main/java/org/broadinstitute/hellbender/engine/cache/DrivingFeatureInputCacheStrategy.java create mode 100644 src/main/java/org/broadinstitute/hellbender/engine/cache/LocatableCache.java create mode 100644 src/main/java/org/broadinstitute/hellbender/engine/cache/LocatableCacheStrategy.java create mode 100644 src/main/java/org/broadinstitute/hellbender/engine/cache/SideReadInputCacheStrategy.java create mode 100644 src/test/java/org/broadinstitute/hellbender/engine/cache/LocatableCacheUnitTest.java diff --git a/src/main/java/org/broadinstitute/hellbender/cmdline/argumentcollections/ReadInputArgumentCollection.java b/src/main/java/org/broadinstitute/hellbender/cmdline/argumentcollections/ReadInputArgumentCollection.java index 843f0fbf611..ddebd35b725 100644 --- a/src/main/java/org/broadinstitute/hellbender/cmdline/argumentcollections/ReadInputArgumentCollection.java +++ b/src/main/java/org/broadinstitute/hellbender/cmdline/argumentcollections/ReadInputArgumentCollection.java @@ -1,6 +1,7 @@ package org.broadinstitute.hellbender.cmdline.argumentcollections; import htsjdk.samtools.ValidationStringency; +import org.broadinstitute.barclay.argparser.Advanced; import org.broadinstitute.barclay.argparser.Argument; import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions; import org.broadinstitute.hellbender.utils.io.IOUtils; @@ -37,6 +38,11 @@ public abstract class ReadInputArgumentCollection implements Serializable { optional = true) protected List readIndices; + @Advanced + @Argument(fullName = "reads-look-ahead-window-size", shortName="reads-look-ahead-window-size", + doc = "Window size for reads look-ahead query caching") + protected int readsLookaheadWindowSize = 100 * 1024; + /** * Get the list of BAM/SAM/CRAM files specified at the command line. * Paths are the preferred format, as this can handle both local disk and NIO direct access to cloud storage. @@ -73,4 +79,9 @@ public List getReadIndexPaths() { * at the command line. 
*/ public ValidationStringency getReadValidationStringency() { return readValidationStringency; }; + + /** + * Get the look ahead buffer size to be used for read queries + */ + public int getReadsLookaheadWindowSize() { return readsLookaheadWindowSize; }; } diff --git a/src/main/java/org/broadinstitute/hellbender/engine/FeatureCache.java b/src/main/java/org/broadinstitute/hellbender/engine/FeatureCache.java deleted file mode 100644 index b33b35ecd0f..00000000000 --- a/src/main/java/org/broadinstitute/hellbender/engine/FeatureCache.java +++ /dev/null @@ -1,248 +0,0 @@ -package org.broadinstitute.hellbender.engine; - -import htsjdk.tribble.Feature; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.broadinstitute.hellbender.exceptions.GATKException; -import org.broadinstitute.hellbender.utils.SimpleInterval; - -import java.util.*; - -/** - * FeatureCache: helper class for {@link FeatureDataSource} to manage the cache of Feature records used - * during query operations initiated via {@link FeatureDataSource#query(org.broadinstitute.hellbender.utils.SimpleInterval)} - * and/or {@link FeatureDataSource#queryAndPrefetch(org.broadinstitute.hellbender.utils.SimpleInterval)}. - * - * Strategy is to pre-fetch a large number of records AFTER each query interval that produces - * a cache miss. This optimizes for the use case of intervals with gradually increasing start - * positions, as many subsequent queries will find their records wholly contained in the cache - * before we have another cache miss. Performance will be poor for random/non-localized access - * patterns, or intervals with decreasing start positions. - * - * Usage: - * -Test whether each query interval is a cache hit via {@link #cacheHit(org.broadinstitute.hellbender.utils.SimpleInterval)} - * - * -If it is a cache hit, trim the cache to the start position of the interval (discarding records that - * end before the start of the new interval) via {@link #trimToNewStartPosition(int)}, then retrieve - * records up to the desired endpoint using {@link #getCachedFeaturesUpToStopPosition(int)}. - * - * -If it is a cache miss, reset the cache using {@link #fill(java.util.Iterator, org.broadinstitute.hellbender.utils.SimpleInterval)}, pre-fetching - * a large number of records after the query interval in addition to those actually requested. - * - * @param Type of Feature record we are caching - */ -class FeatureCache { - private static final Logger logger = LogManager.getLogger(FeatureCache.class); - - /** - * Our cache of Features, optimized for insertion/removal at both ends. - */ - private final Deque cache; - - /** - * Our cache currently contains Feature records overlapping this interval - */ - private SimpleInterval cachedInterval; - - /** - * Number of times we called {@link #cacheHit(SimpleInterval)} and it returned true - */ - private int numCacheHits = 0; - - /** - * Number of times we called {@link #cacheHit(SimpleInterval)} and it returned false - */ - private int numCacheMisses = 0; - - /** - * Initial capacity of our cache (will grow by doubling if needed) - */ - private static final int INITIAL_CAPACITY = 1024; - - /** - * When we trim our cache to a new start position, this is the maximum number of - * Features we expect to need to place into temporary storage for the duration of - * the trim operation. Performance only suffers slightly if our estimate is wrong. 
- */ - private static final int EXPECTED_MAX_OVERLAPPING_FEATURES_DURING_CACHE_TRIM = 128; - - /** - * Create an initially-empty FeatureCache with default initial capacity - */ - public FeatureCache() { - cache = new ArrayDeque<>(INITIAL_CAPACITY); - } - - /** - * Get the name of the contig on which the Features in our cache are located - * - * @return the name of the contig on which the Features in our cache are located - */ - public String getContig() { - return cachedInterval.getContig(); - } - - /** - * Get the start position of the interval that all Features in our cache overlap - * - * @return the start position of the interval that all Features in our cache overlap - */ - public int getCacheStart() { - return cachedInterval.getStart(); - } - - /** - * Get the stop position of the interval that all Features in our cache overlap - * - * @return the stop position of the interval that all Features in our cache overlap - */ - public int getCacheEnd() { - return cachedInterval.getEnd(); - } - - /** - * Does our cache currently contain no Features? - * - * @return true if our cache contains no Features, otherwise false - */ - public boolean isEmpty() { - return cache.isEmpty(); - } - - /** - * @return Number of times we called {@link #cacheHit(SimpleInterval)} and it returned true - */ - public int getNumCacheHits() { - return numCacheHits; - } - - /** - * @return Number of times we called {@link #cacheHit(SimpleInterval)} and it returned false - */ - public int getNumCacheMisses() { - return numCacheMisses; - } - - /** - * Clear our cache and fill it with the records from the provided iterator, preserving their - * relative ordering, and update our contig/start/stop to reflect the new interval that all - * records in our cache overlap. - * - * Typically each fill operation should involve significant lookahead beyond the region - * requested so that future queries will be cache hits. - * - * @param featureIter iterator from which to pull Features with which to populate our cache - * (replacing existing cache contents) - * @param interval all Features from featureIter overlap this interval - */ - public void fill( final Iterator featureIter, final SimpleInterval interval ) { - cache.clear(); - while ( featureIter.hasNext() ) { - cache.add(featureIter.next()); - } - - cachedInterval = interval; - } - - /** - * Determines whether all records overlapping the provided interval are already contained in our cache. - * - * @param interval the interval to check against the contents of our cache - * @return true if all records overlapping the provided interval are already contained in our cache, otherwise false - */ - public boolean cacheHit( final SimpleInterval interval ) { - final boolean cacheHit = cachedInterval != null && cachedInterval.contains(interval); - - if ( cacheHit ) { - ++numCacheHits; - } - else { - ++numCacheMisses; - } - - return cacheHit; - } - - /** - * Trims the cache to the specified new start position by discarding all records that end before it - * while preserving relative ordering of records. - * - * @param newStart new start position on the current contig to which to trim the cache - */ - public void trimToNewStartPosition( final int newStart ) { - if ( newStart > cachedInterval.getEnd() ) { - throw new GATKException(String.format("BUG: attempted to trim Feature cache to an improper new start position (%d). 
Cache stop = %d", - newStart, cachedInterval.getEnd())); - } - - List overlappingFeaturesBeforeNewStart = new ArrayList<>(EXPECTED_MAX_OVERLAPPING_FEATURES_DURING_CACHE_TRIM); - - // In order to trim the cache to the new start position, we need to find - // all Features in the cache that start before the new start position, - // and discard those that don't overlap the new start while keeping those - // that do overlap. We can stop once we find a Feature that starts on or - // after the new start position, since the Features are assumed to be sorted - // by start position. - while ( ! cache.isEmpty() && cache.getFirst().getStart() < newStart ) { - CACHED_FEATURE featureBeforeNewStart = cache.removeFirst(); - - if ( featureBeforeNewStart.getEnd() >= newStart ) { - overlappingFeaturesBeforeNewStart.add(featureBeforeNewStart); - } - } - - // Add back the Features that started before the new start but overlapped it - // in the reverse of the order in which we encountered them so that their original - // relative ordering in the cache is restored. - for ( int i = overlappingFeaturesBeforeNewStart.size() - 1; i >= 0; --i ) { - cache.addFirst(overlappingFeaturesBeforeNewStart.get(i)); - } - - // Record our new start boundary - cachedInterval = new SimpleInterval(cachedInterval.getContig(), newStart, cachedInterval.getEnd()); - } - - /** - * Returns (but does not remove) all cached Features that overlap the region from the start - * of our cache (cacheStart) to the specified stop position. - * - * @param stopPosition Endpoint of the interval that returned Features must overlap - * @return all cached Features that overlap the region from the start of our cache to the specified stop position - */ - public List getCachedFeaturesUpToStopPosition( final int stopPosition ) { - List matchingFeatures = new ArrayList<>(cache.size()); - - // Find (but do not remove from our cache) all Features that start before or on the provided stop position - for ( CACHED_FEATURE candidateFeature : cache ) { - if ( candidateFeature.getStart() > stopPosition ) { - break; // No more possible matches among the remaining cached Features, so stop looking - } - matchingFeatures.add(candidateFeature); - } - return matchingFeatures; - } - - /** - * Print statistics about the cache hit rate for debugging. - */ - public void printCacheStatistics() { - printCacheStatistics(""); - } - - /** - * Print statistics about the cache hit rate for debugging. - * @param sourceName The source for the features in this cache. - */ - public void printCacheStatistics(final String sourceName) { - - final String sourceNameString = sourceName.isEmpty() ? "" : "for data source " + sourceName; - - final int totalQueries = getNumCacheHits() + getNumCacheMisses(); - logger.debug(String.format("Cache hit rate %s was %.2f%% (%d out of %d total queries)", - sourceNameString, - totalQueries > 0 ? 
((double)getNumCacheHits() / totalQueries) * 100.0 : 0.0, - getNumCacheHits(), - totalQueries)); - } -} - diff --git a/src/main/java/org/broadinstitute/hellbender/engine/FeatureDataSource.java b/src/main/java/org/broadinstitute/hellbender/engine/FeatureDataSource.java index c0b8e8be231..6ad733967e7 100644 --- a/src/main/java/org/broadinstitute/hellbender/engine/FeatureDataSource.java +++ b/src/main/java/org/broadinstitute/hellbender/engine/FeatureDataSource.java @@ -11,6 +11,8 @@ import htsjdk.variant.vcf.VCFHeader; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.broadinstitute.hellbender.engine.cache.DrivingFeatureInputCacheStrategy; +import org.broadinstitute.hellbender.engine.cache.LocatableCache; import org.broadinstitute.hellbender.exceptions.GATKException; import org.broadinstitute.hellbender.exceptions.UserException; import org.broadinstitute.hellbender.tools.IndexFeatureFile; @@ -94,16 +96,7 @@ public final class FeatureDataSource implements GATKDataSourc * improve performance of the common access pattern involving multiple queries across nearby intervals * with gradually increasing start positions. */ - private final FeatureCache queryCache; - - /** - * When we experience a cache miss (ie., a query interval not fully contained within our cache) and need - * to re-populate the Feature cache from disk to satisfy a query, this controls the number of extra bases - * AFTER the end of our interval to fetch. Should be sufficiently large so that typically a significant number - * of subsequent queries will be cache hits (ie., query intervals fully contained within our cache) before - * we have another cache miss and need to go to disk again. - */ - private final int queryLookaheadBases; + private final LocatableCache queryCache; /** * Holds information about the path this datasource reads from. @@ -278,8 +271,7 @@ public FeatureDataSource(final FeatureInput featureInput, final int queryLook this.currentIterator = null; this.intervalsForTraversal = null; - this.queryCache = new FeatureCache<>(); - this.queryLookaheadBases = queryLookaheadBases; + this.queryCache = new LocatableCache<>(getName(), new DrivingFeatureInputCacheStrategy<>(queryLookaheadBases, this::refillQueryCache)); } /** @@ -553,50 +545,28 @@ public List queryAndPrefetch( final SimpleInterval interval ) { "If it's a file, please index it using the bundled tool " + IndexFeatureFile.class.getSimpleName()); } - // If the query can be satisfied using existing cache contents, prepare for retrieval - // by discarding all Features at the beginning of the cache that end before the start - // of our query interval. - if ( queryCache.cacheHit(interval) ) { - queryCache.trimToNewStartPosition(interval.getStart()); - } - // Otherwise, we have a cache miss, so go to disk to refill our cache. - else { - refillQueryCache(interval); - } - - // Return the subset of our cache that overlaps our query interval - return queryCache.getCachedFeaturesUpToStopPosition(interval.getEnd()); + return queryCache.queryAndPrefetch(interval); } /** - * Refill our cache from disk after a cache miss. Will prefetch Features overlapping an additional - * queryLookaheadBases bases after the end of the provided interval, in addition to those overlapping + * Called by the cache strategy to refill our cache from disk after a cache miss. Will prefetch Features overlapping + * an additional queryLookaheadBases bases after the end of the provided interval, in addition to those overlapping * the interval itself. 
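+     *
+     * For example (an illustrative sketch, not part of the change): a walker calling
+     * queryAndPrefetch(new SimpleInterval("1", 100, 200)) against a source constructed with a 100,000-base
+     * lookahead will, on the first cache miss, cause the strategy to invoke this method with the
+     * already-expanded interval 1:100-100,200.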
*
-     * Calling this has the side effect of invalidating (closing) any currently-open iteration over
+     * This has the side effect of invalidating (closing) any currently-open iteration over
     * this data source.
     *
-     * @param interval the query interval that produced a cache miss
+     * @param cacheInterval the query interval to be cached
     */
-    private void refillQueryCache( final SimpleInterval interval ) {
+    private Iterator<T> refillQueryCache(final SimpleInterval cacheInterval) {
        // Tribble documentation states that having multiple iterators open simultaneously over the same FeatureReader
        // results in undefined behavior
        closeOpenIterationIfNecessary();
-
-        // Expand the end of our query by the configured number of bases, in anticipation of probable future
-        // queries with slightly larger start/stop positions.
-        //
-        // Note that it doesn't matter if we go off the end of the contig in the process, since
-        // our reader's query operation is not aware of (and does not care about) contig boundaries.
-        // Note: we use addExact to blow up on overflow rather than propagate negative results downstream
-        final SimpleInterval queryInterval = new SimpleInterval(interval.getContig(), interval.getStart(), Math.addExact(interval.getEnd(), queryLookaheadBases));
-
-        // Query iterator over our reader will be immediately closed after re-populating our cache
-        try ( CloseableTribbleIterator<T> queryIter = featureReader.query(queryInterval.getContig(), queryInterval.getStart(), queryInterval.getEnd()) ) {
-            queryCache.fill(queryIter, queryInterval);
+        try {
+            return featureReader.query(cacheInterval.getContig(), cacheInterval.getStart(), cacheInterval.getEnd());
        } catch ( IOException e ) {
-            throw new GATKException("Error querying file " + featureInput + " over interval " + interval, e);
+            throw new GATKException("Error querying file " + featureInput + " over interval " + cacheInterval, e);
        }
    }
@@ -626,8 +596,7 @@ public Object getHeader() {
    public void close() {
        closeOpenIterationIfNecessary();

-        logger.debug(String.format("Cache statistics for FeatureInput %s:", featureInput));
-        queryCache.printCacheStatistics();
+        logger.debug(String.format("Cache statistics for FeatureInput %s: %s", featureInput, queryCache.getCacheStatistics()));

        try {
            if ( featureReader != null ) {
diff --git a/src/main/java/org/broadinstitute/hellbender/engine/GATKTool.java b/src/main/java/org/broadinstitute/hellbender/engine/GATKTool.java
index 07c65fadf93..bb91b21c8dd 100644
--- a/src/main/java/org/broadinstitute/hellbender/engine/GATKTool.java
+++ b/src/main/java/org/broadinstitute/hellbender/engine/GATKTool.java
@@ -386,6 +386,9 @@ else if (hasCramInput()) {
            reads = new ReadsDataSource(readArguments.getReadPaths(), readArguments.getReadIndexPaths(), factory,
                cloudPrefetchBuffer,
                (cloudIndexPrefetchBuffer < 0 ? cloudPrefetchBuffer : cloudIndexPrefetchBuffer));
+            if (useReadCaching()) {
+                reads.enableReadCaching(readArguments.getReadsLookaheadWindowSize());
+            }
        }
        else {
            reads = null;
@@ -510,7 +513,16 @@ public boolean requiresReads() {
    }

    /**
-     * Does this tool require intervals? Traversals types and/or tools that do should override to return true.
+     * Should this tool use read caching to optimize forward-only queries on coordinate-sorted read inputs?
+     *
+     * @return true if read caching should be used, otherwise false
+     */
+    public boolean useReadCaching() {
+        return false;
+    }
+
+    /**
+     * Does this tool require intervals? Traversal types and/or tools that do should override to return true.
*
     * @return true if this tool requires intervals, otherwise false
     */
diff --git a/src/main/java/org/broadinstitute/hellbender/engine/ReadsDataSource.java b/src/main/java/org/broadinstitute/hellbender/engine/ReadsDataSource.java
index d7b8c453771..27ebb7802d2 100644
--- a/src/main/java/org/broadinstitute/hellbender/engine/ReadsDataSource.java
+++ b/src/main/java/org/broadinstitute/hellbender/engine/ReadsDataSource.java
@@ -8,6 +8,8 @@ import java.util.function.Function;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
+import org.broadinstitute.hellbender.engine.cache.LocatableCache;
+import org.broadinstitute.hellbender.engine.cache.SideReadInputCacheStrategy;
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.exceptions.GATKException;
@@ -73,6 +75,11 @@ public final class ReadsDataSource implements GATKDataSource, AutoClos
     */
    private final SamFileHeaderMerger headerMerger;

+    /**
+     * Used to cache reads around the current query interval when caching is enabled
+     */
+    private LocatableCache<GATKRead> queryCache = null;
+
    /**
     * Are indices available for all files?
     */
@@ -247,6 +254,28 @@ public ReadsDataSource( final List samPaths, final List samIndices,
        headerMerger = samPaths.size() > 1 ? createHeaderMerger() : null;
    }

+    /**
+     * Enable look-ahead caching for reads using {@code lookAheadBases} look-ahead bases.
+     * @param lookAheadBases number of additional bases to cache
+     */
+    public void enableReadCaching(int lookAheadBases) {
+        Utils.validate(queryCache == null, "Can't reset the cache/window look ahead size");
+        this.queryCache = new LocatableCache<>(
+                getName(),
+                new SideReadInputCacheStrategy<>(
+                        lookAheadBases,
+                        (SimpleInterval newCacheInterval) -> prepareIteratorsForTraversal(Arrays.asList(newCacheInterval), false)
+                )
+        );
+    }
+
+    /**
+     * Return a display name for this data source, built from (up to the first few of) the backing paths.
+     */
+    public String getName() {
+        final int DISPLAY_NAME_THRESHOLD = 3; // don't try to display more than this many paths
+        return backingPaths.values().stream().limit(DISPLAY_NAME_THRESHOLD)
+                .map(p -> p.toString())
+                .collect(Collectors.joining(",", "Reads: ", backingPaths.values().size() > DISPLAY_NAME_THRESHOLD ? "..." : ""));
+    }
+
    /**
     * Are indices available for all files?
     */
@@ -343,7 +372,11 @@ public Iterator query( final SimpleInterval interval ) {
            raiseExceptionForMissingIndex("Cannot query reads data source by interval unless all files are indexed");
        }

-        return prepareIteratorsForTraversal(Arrays.asList(interval));
+        if (queryCache != null) {
+            return queryCache.queryAndPrefetch(interval).iterator();
+        } else {
+            return prepareIteratorsForTraversal(Arrays.asList(interval));
+        }
    }

    /**
diff --git a/src/main/java/org/broadinstitute/hellbender/engine/cache/DrivingFeatureInputCacheStrategy.java b/src/main/java/org/broadinstitute/hellbender/engine/cache/DrivingFeatureInputCacheStrategy.java
new file mode 100644
index 00000000000..44eca708d0a
--- /dev/null
+++ b/src/main/java/org/broadinstitute/hellbender/engine/cache/DrivingFeatureInputCacheStrategy.java
@@ -0,0 +1,143 @@
+package org.broadinstitute.hellbender.engine.cache;
+
+import htsjdk.tribble.Feature;
+import org.broadinstitute.hellbender.exceptions.GATKException;
+import org.broadinstitute.hellbender.utils.SimpleInterval;
+
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.Iterator;
+import java.util.List;
+import java.util.function.Function;
+
+/**
+ * A {@code LocatableCacheStrategy} used to cache primary Feature inputs.
+ *
+ * Strategy is to pre-fetch a large number of records AFTER each query interval that produces
+ * a cache miss. This optimizes for the use case of intervals with gradually increasing start
+ * positions, as many subsequent queries will find their records wholly contained in the cache
+ * before we have another cache miss. Performance will be poor for random/non-localized access
+ * patterns, or intervals with decreasing start positions.
+ *
+ * @param <CACHED_FEATURE> type of Feature being cached
+ */
+public class DrivingFeatureInputCacheStrategy<CACHED_FEATURE extends Feature> implements LocatableCacheStrategy<CACHED_FEATURE> {
+
+    /**
+     * When we trim our cache to a new start position, this is the maximum number of
+     * {@code CACHED_FEATURE} objects we expect to need to place into temporary storage for the duration of
+     * the trim operation. Performance only suffers slightly if our estimate is wrong.
+     */
+    private static final int EXPECTED_MAX_OVERLAPPING_FEATURES_DURING_CACHE_TRIM = 128;
+
+    /**
+     * When we experience a cache miss (ie., a query interval not fully contained within our cache) and need
+     * to re-populate the Feature cache from disk to satisfy a query, this controls the number of extra bases
+     * AFTER the end of our interval to fetch. Should be sufficiently large so that typically a significant number
+     * of subsequent queries will be cache hits (ie., query intervals fully contained within our cache) before
+     * we have another cache miss and need to go to disk again.
+     */
+    private final int queryLookaheadBases;
+
+    /**
+     * Function called on a cache miss to provide results for a cached interval to populate the cache.
+     */
+    private final Function<SimpleInterval, Iterator<CACHED_FEATURE>> queryResultsProvider;
+
+    /**
+     * @param queryLookaheadBases number of additional base positions beyond the requested query interval to cache
+     * @param queryResultsProvider a {@code Function} that can be called to get an iterator for a given interval suitable
+     *                             for re-populating the cache
+     */
+    public DrivingFeatureInputCacheStrategy(
+            final int queryLookaheadBases,
+            final Function<SimpleInterval, Iterator<CACHED_FEATURE>> queryResultsProvider) {
+        this.queryLookaheadBases = queryLookaheadBases;
+        this.queryResultsProvider = queryResultsProvider;
+    }
+
+    /**
+     * Given a requested queryInterval, return a new (expanded) interval representing the interval for which items
+     * should be cached.
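+     *
+     * For example (illustrative numbers): with a 1,000-base lookahead, a query of 1:500-600 expands to a
+     * cached interval of 1:500-1,600.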
+     *
+     * @param queryInterval the interval being queried
+     * @return the interval to be cached
+     */
+    @Override
+    public SimpleInterval getCacheIntervalFromQueryInterval(final SimpleInterval queryInterval) {
+        return new SimpleInterval(queryInterval.getContig(), queryInterval.getStart(), Math.addExact(queryInterval.getEnd(), queryLookaheadBases));
+    }
+
+    /**
+     * Return the features from the cache that overlap the given interval
+     * @param cache the cache from which to pull items
+     * @param queryInterval the interval being queried
+     * @return {@code List} of {@code CACHED_FEATURE} objects that overlap {@code queryInterval}
+     */
+    @Override
+    public List<CACHED_FEATURE> queryCache(final Deque<CACHED_FEATURE> cache, final SimpleInterval queryInterval) {
+        List<CACHED_FEATURE> matchingFeatures = new ArrayList<>(cache.size());
+
+        // Find (but do not remove from our cache) all Features that start before or on the provided stop position
+        for ( CACHED_FEATURE candidateFeature : cache ) {
+            if ( candidateFeature.getStart() > queryInterval.getEnd() ) {
+                break; // No more possible matches among the remaining cached Features, so stop looking
+            }
+            matchingFeatures.add(candidateFeature);
+        }
+        return matchingFeatures;
+    }
+
+    /**
+     * Refill the cache with items overlapping {@code queryInterval}.
+     * @param queryInterval the query interval which returned items should overlap
+     * @return Iterator of items matching the query interval
+     */
+    @Override
+    public Iterator<CACHED_FEATURE> refillCache(SimpleInterval queryInterval) {
+        return queryResultsProvider.apply(queryInterval);
+    }
+
+    /**
+     * Trims the cache to the specified new start position by discarding all records that end before it
+     * while preserving relative ordering of records.
+     *
+     * @param cache the cache from which to pull items
+     * @param cachedInterval the currently cached interval
+     * @param interval new interval to which to trim the cache
+     * @return the newly cached interval
+     */
+    @Override
+    public SimpleInterval trimCache(final Deque<CACHED_FEATURE> cache, final SimpleInterval cachedInterval, final SimpleInterval interval ) {
+        if ( interval.getStart() > cachedInterval.getEnd() ) {
+            throw new GATKException(String.format("BUG: attempted to trim Feature cache to an improper new start position (%d). Cache stop = %d",
+                    interval.getStart(), cachedInterval.getEnd()));
+        }
+
+        final List<CACHED_FEATURE> overlappingFeaturesBeforeNewStart = new ArrayList<>(EXPECTED_MAX_OVERLAPPING_FEATURES_DURING_CACHE_TRIM);
+
+        // In order to trim the cache to the new start position, we need to find
+        // all Features in the cache that start before the new start position,
+        // and discard those that don't overlap the new start while keeping those
+        // that do overlap. We can stop once we find a Feature that starts on or
+        // after the new start position, since the Features are assumed to be sorted
+        // by start position.
+        while ( ! cache.isEmpty() && cache.getFirst().getStart() < interval.getStart() ) {
+            CACHED_FEATURE featureBeforeNewStart = cache.removeFirst();
+
+            if ( featureBeforeNewStart.getEnd() >= interval.getStart() ) {
+                overlappingFeaturesBeforeNewStart.add(featureBeforeNewStart);
+            }
+        }
+
+        // Add back the Features that started before the new start but overlapped it
+        // in the reverse of the order in which we encountered them so that their original
+        // relative ordering in the cache is restored.
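+        // For example, with a cache of [F1(1,100), F2(5,10), F3(60,70)] trimmed to a new start of 50
+        // (hypothetical features): F1 and F2 are pulled off the front, F2 ends before 50 and is discarded,
+        // and F1 is re-added below, leaving [F1(1,100), F3(60,70)].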
+ for ( int i = overlappingFeaturesBeforeNewStart.size() - 1; i >= 0; --i ) { + cache.addFirst(overlappingFeaturesBeforeNewStart.get(i)); + } + + // Record our new start boundary + return new SimpleInterval(cachedInterval.getContig(), interval.getStart(), cachedInterval.getEnd()); + } + +} diff --git a/src/main/java/org/broadinstitute/hellbender/engine/cache/LocatableCache.java b/src/main/java/org/broadinstitute/hellbender/engine/cache/LocatableCache.java new file mode 100644 index 00000000000..cc95bed3b85 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/engine/cache/LocatableCache.java @@ -0,0 +1,239 @@ +package org.broadinstitute.hellbender.engine.cache; + +import com.google.common.annotations.VisibleForTesting; +import htsjdk.samtools.util.Locatable; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.broadinstitute.hellbender.utils.SimpleInterval; +import org.broadinstitute.hellbender.utils.Utils; + +import java.util.*; + +/** + * Implementation of a locatable cache with customizable cache strategy. + * + * Usage: + * -Test whether each query interval is a cache hit via {@link #cacheHit(org.broadinstitute.hellbender.utils.SimpleInterval)} + * + * -If it is a cache hit, trim the cache to the start position of the interval (discarding records that + * end before the start of the new interval) via {@link # trimCache(SimpleInterval)}, then retrieve + * records up to the desired endpoint using {@link #getCachedLocatables(SimpleInterval)}. + * + * -If it is a cache miss, reset the cache using {@link #refillQueryCache(org.broadinstitute.hellbender.utils.SimpleInterval)}, + * pre-fetching a large number of records after the query interval in addition to those actually requested. + * + * @param Type of Locatable record we are caching + */ +public class LocatableCache { + private static final Logger logger = LogManager.getLogger(LocatableCache.class); + + /** + * Display name for this cache + */ + private final String sourceDisplayName; + + /** + * Our cache of Features, optimized for insertion/removal at both ends. + */ + private final Deque cache; + + private final LocatableCacheStrategy cachingStrategy; + + /** + * Our cache currently contains Feature records overlapping this interval + */ + private SimpleInterval cachedInterval; + + /** + * Number of times we called {@link #cacheHit(SimpleInterval)} and it returned true + */ + private int numCacheHits = 0; + + /** + * Number of times we called {@link #cacheHit(SimpleInterval)} and it returned false + */ + private int numCacheMisses = 0; + + /** + * Initial capacity of our cache (will grow by doubling if needed) + */ + private static final int INITIAL_CAPACITY = 1024; + + /** + * Create an initially-empty LocatableCache with default initial capacity + * + * @param sourceName display name for this cache + * @param strategy {@link LocatableCacheStrategy} for curating the cache + */ + public LocatableCache(final String sourceName, final LocatableCacheStrategy strategy) { + Utils.nonNull(sourceName); + Utils.nonNull(strategy); + + sourceDisplayName = sourceName; + cachingStrategy = strategy; + cache = new ArrayDeque<>(INITIAL_CAPACITY); + } + + /** + * Clear our cache and fill it with the records from the provided iterator, preserving their + * relative ordering, and update our contig/start/stop to reflect the new interval that all + * records in our cache overlap. 
+ * + * Typically each fill operation should involve significant lookahead beyond the region + * requested so that future queries will be cache hits. + * + * @param locatableIter iterator from which to pull Locatables with which to populate our cache + * (replacing existing cache contents) + * @param interval all Locatables from locatableIter overlap this interval + */ + private void fill(final Iterator locatableIter, final SimpleInterval interval ) { + cache.clear(); + while ( locatableIter.hasNext() ) { + cache.add(locatableIter.next()); + } + cachedInterval = interval; + } + + /** + * Returns a List of all Locatables in this data source that overlap the provided interval. + * + * @param interval retrieve all Locatables overlapping this interval + * @return a {@code List} of all Locatables in this cache that overlap the provided interval + */ + public List queryAndPrefetch(final SimpleInterval interval) { + // If the query can be fully satisfied using existing cache contents, prepare for retrieval + // by discarding all Locatables at the beginning of the cache that end before the start + // of our query interval. + if (cacheHit(interval) ) { + cachedInterval = cachingStrategy.trimCache(cache, cachedInterval, interval); + } + // Otherwise, we have at least a partial cache miss, so go to disk to refill our cache. + else { + refillQueryCache(interval); + } + + // Return the subset of our cache that overlaps our query interval + return getCachedLocatables(interval); + } + + /** + * Refill our cache from disk after a cache miss. Will prefetch Locatables overlapping an additional + * queryLookaheadBases bases after the end of the provided interval, in addition to those overlapping + * the interval itself. + * + * Calling this has the side effect of invalidating (closing) any currently-open iteration over + * this data source. + * + * @param queryInterval the query interval that produced a cache miss + */ + private void refillQueryCache( final SimpleInterval queryInterval) + { + // Expand the end of our query by the configured number of bases, in anticipation of probable future + // queries with slightly larger start/stop positions. + // + // Note that it doesn't matter if we go off the end of the contig in the process, since + // our reader's query operation is not aware of (and does not care about) contig boundaries. 
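+        // For example (illustrative numbers): a lookahead-expanded query of 1:999,000-1,101,000 is safe to
+        // issue even if contig 1 ends at position 1,000,000; the reader simply returns no records past the
+        // contig end.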
+ // Note: we use addExact to blow up on overflow rather than propagate negative results downstream + final SimpleInterval cacheInterval = cachingStrategy.getCacheIntervalFromQueryInterval(queryInterval); + + // Query iterator over our reader will be immediately closed after re-populating our cache + Iterator cacheableIterator = cachingStrategy.refillCache(cacheInterval); + fill(cacheableIterator, cacheInterval); + } + + /** + * Get the name of the contig on which the Locatables in our cache are located + * + * @return the name of the contig on which the Locatables in our cache are located + */ + public String getContig() { + return cachedInterval.getContig(); + } + + /** + * Get the start position of the interval that all Locatables in our cache overlap + * + * @return the start position of the interval that all Locatables in our cache overlap + */ + public int getCacheStart() { + return cachedInterval.getStart(); + } + + /** + * Get the stop position of the interval that all Locatables in our cache overlap + * + * @return the stop position of the interval that all Locatables in our cache overlap + */ + public int getCacheEnd() { + return cachedInterval.getEnd(); + } + + /** + * Does our cache currently contain no Locatables? + * + * @return true if our cache contains no Locatables, otherwise false + */ + public boolean isEmpty() { + return cache.isEmpty(); + } + + /** + * @return Number of times we called {@link #cacheHit(SimpleInterval)} and it returned true + */ + private int getNumCacheHits() { + return numCacheHits; + } + + /** + * @return Number of times we called {@link #cacheHit(SimpleInterval)} and it returned false + */ + private int getNumCacheMisses() { + return numCacheMisses; + } + + /** + * Determines whether all records overlapping the provided interval are already contained in our cache. + * + * @param interval the interval to check against the contents of our cache + * @return true if all records overlapping the provided interval are already contained in our cache, otherwise false + */ + @VisibleForTesting + boolean cacheHit( final SimpleInterval interval ) { + final boolean cacheHit = cachedInterval != null && cachedInterval.contains(interval); + + if ( cacheHit ) { + ++numCacheHits; + } + else { + ++numCacheMisses; + } + + return cacheHit; + } + + /** + * Returns (but does not remove) all cached Locatables that overlap the region from the start + * of our cache (cacheStart) to the specified stop position. + * + * @param interval Endpoint of the interval that returned Locatables must overlap + * @return all cached Locatables that overlap the region from the start of our cache to the specified stop position + */ + @VisibleForTesting + List getCachedLocatables(final SimpleInterval interval ) { + return cachingStrategy.queryCache(cache, interval); + } + + /** + * Print statistics about the cache hit rate for debugging. + */ + public String getCacheStatistics() { + + final int totalQueries = getNumCacheHits() + getNumCacheMisses(); + return String.format("Cache hit rate %s was %.2f%% (%d out of %d total queries)", + sourceDisplayName, + totalQueries > 0 ? 
((double)getNumCacheHits() / totalQueries) * 100.0 : 0.0,
+                getNumCacheHits(),
+                totalQueries);
+    }
+}
+
diff --git a/src/main/java/org/broadinstitute/hellbender/engine/cache/LocatableCacheStrategy.java b/src/main/java/org/broadinstitute/hellbender/engine/cache/LocatableCacheStrategy.java
new file mode 100644
index 00000000000..584a0699559
--- /dev/null
+++ b/src/main/java/org/broadinstitute/hellbender/engine/cache/LocatableCacheStrategy.java
@@ -0,0 +1,59 @@
+package org.broadinstitute.hellbender.engine.cache;
+
+import htsjdk.samtools.util.Locatable;
+import org.broadinstitute.hellbender.utils.SimpleInterval;
+
+import java.util.Deque;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Interface implemented by cache strategy objects for {@code LocatableCache}.
+ *
+ * {@code LocatableCacheStrategy} implementations determine the following policies for the cache:
+ *
+ * <ul>
+ *   <li>{@link #getCacheIntervalFromQueryInterval}: how to map a requested interval to a (larger) interval to be cached</li>
+ *   <li>{@link #refillCache}: how to (re)populate the cache</li>
+ *   <li>{@link #queryCache}: how to query the cache</li>
+ *   <li>{@link #trimCache}: how to trim the cache</li>
+ * </ul>
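+ *
+ * <p>A minimal sketch of a custom strategy (illustrative only; {@code FixedWindowStrategy} and its
+ * {@code reader} function are hypothetical, not part of this change):</p>
+ * <pre>{@code
+ * class FixedWindowStrategy implements LocatableCacheStrategy<GATKRead> {
+ *     private final int lookAheadBases;
+ *     private final Function<SimpleInterval, Iterator<GATKRead>> reader;
+ *
+ *     FixedWindowStrategy(int lookAheadBases, Function<SimpleInterval, Iterator<GATKRead>> reader) {
+ *         this.lookAheadBases = lookAheadBases;
+ *         this.reader = reader;
+ *     }
+ *
+ *     public SimpleInterval getCacheIntervalFromQueryInterval(SimpleInterval query) {
+ *         // cache a fixed window past the end of every query
+ *         return new SimpleInterval(query.getContig(), query.getStart(),
+ *                                   Math.addExact(query.getEnd(), lookAheadBases));
+ *     }
+ *
+ *     public Iterator<GATKRead> refillCache(SimpleInterval cacheInterval) {
+ *         return reader.apply(cacheInterval);
+ *     }
+ *
+ *     public List<GATKRead> queryCache(Deque<GATKRead> cache, SimpleInterval query) {
+ *         List<GATKRead> hits = new ArrayList<>();
+ *         for (GATKRead read : cache) {
+ *             if (read.getStart() > query.getEnd()) break; // cache is sorted by start
+ *             hits.add(read);
+ *         }
+ *         return hits;
+ *     }
+ *
+ *     public SimpleInterval trimCache(Deque<GATKRead> cache, SimpleInterval cached, SimpleInterval newInterval) {
+ *         // pull items starting before the new start off the front, re-adding (in order) only those
+ *         // that still overlap it; long items must be kept even though they start early
+ *         Deque<GATKRead> keep = new ArrayDeque<>();
+ *         while (!cache.isEmpty() && cache.getFirst().getStart() < newInterval.getStart()) {
+ *             GATKRead read = cache.removeFirst();
+ *             if (read.getEnd() >= newInterval.getStart()) keep.addLast(read);
+ *         }
+ *         keep.descendingIterator().forEachRemaining(cache::addFirst);
+ *         return new SimpleInterval(cached.getContig(), newInterval.getStart(), cached.getEnd());
+ *     }
+ * }
+ * }</pre>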
+ */
+interface LocatableCacheStrategy<CACHED_LOCATABLE extends Locatable> {
+
+    /**
+     * Given a query interval, return an expanded interval representing the new interval to be cached.
+     * @param queryInterval the interval being queried
+     * @return the new interval to be cached
+     */
+    SimpleInterval getCacheIntervalFromQueryInterval(final SimpleInterval queryInterval);
+
+    /**
+     * Return a {@code List} of objects from the current cache that overlap the requested interval.
+     *
+     * The overlapping objects should be returned but not removed from the cache (items are removed
+     * by the {@link #trimCache} method).
+     *
+     * @param cache the cache object
+     * @param queryInterval the requested query interval
+     * @return {@code List} of cached objects overlapping the query interval
+     */
+    List<CACHED_LOCATABLE> queryCache(final Deque<CACHED_LOCATABLE> cache, final SimpleInterval queryInterval);
+
+    /**
+     * Return an iterator of items that overlap {@code queryInterval}, to be used to fill the cache.
+     * @param queryInterval the query interval which returned items should overlap
+     * @return {@code Iterator} of items overlapping {@code queryInterval}
+     */
+    Iterator<CACHED_LOCATABLE> refillCache(final SimpleInterval queryInterval);
+
+    /**
+     * Remove items that are no longer needed from the cache by discarding those that cannot overlap
+     * {@code newInterval} or any subsequent (forward) query interval.
+     *
+     * @param cache cache object
+     * @param cachedInterval the currently cached interval
+     * @param newInterval the new cached interval being requested
+     * @return the new cached interval resulting from the trim operation
+     */
+    SimpleInterval trimCache(final Deque<CACHED_LOCATABLE> cache, final SimpleInterval cachedInterval, final SimpleInterval newInterval);
+}
\ No newline at end of file
diff --git a/src/main/java/org/broadinstitute/hellbender/engine/cache/SideReadInputCacheStrategy.java b/src/main/java/org/broadinstitute/hellbender/engine/cache/SideReadInputCacheStrategy.java
new file mode 100644
index 00000000000..df2d98aa89a
--- /dev/null
+++ b/src/main/java/org/broadinstitute/hellbender/engine/cache/SideReadInputCacheStrategy.java
@@ -0,0 +1,176 @@
+package org.broadinstitute.hellbender.engine.cache;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.broadinstitute.hellbender.exceptions.GATKException;
+import org.broadinstitute.hellbender.utils.SimpleInterval;
+import org.broadinstitute.hellbender.utils.read.GATKRead;
+
+import java.util.*;
+import java.util.function.Function;
+
+/**
+ * A {@code LocatableCacheStrategy} used to cache {@code GATKRead} objects that are side inputs. Assumes serial
+ * queries will be over increasing intervals. Performance will suffer if a subsequent query interval covers territory
+ * from a position earlier than a previous query.
+ *
+ * NOTE: this implementation uses the same criteria as BAMFileReader to determine whether to return the unmapped
+ * read in a mate pair where one read is mapped and overlaps the query interval and one read is unmapped but placed
+ * (the unmapped mate is only returned if the *start position* overlaps the query; otherwise only the mapped read
+ * of the pair is returned). For cache trimming purposes however, the two reads will always be trimmed at the same
+ * time.
+ *
+ * @param <CACHED_READ> Type of Locatable being cached.
+ */
+public class SideReadInputCacheStrategy<CACHED_READ extends GATKRead> implements LocatableCacheStrategy<CACHED_READ> {
+    private static final Logger logger = LogManager.getLogger(SideReadInputCacheStrategy.class);
+    private static final int EXPECTED_MAX_OVERLAPPING_READS_DURING_CACHE_TRIM = 128;
+
+    private final int lookAheadBases;
+    private final Function<SimpleInterval, Iterator<CACHED_READ>> queryResultsProvider;
+
+    /**
+     * @param lookAheadBases number of bases beyond the requested interval to cache
+     * @param queryResultsProvider {@code Function} that takes a query interval and returns an iterator of
+     *                             {@code Locatable} over that interval
+     */
+    public SideReadInputCacheStrategy(final int lookAheadBases, final Function<SimpleInterval, Iterator<CACHED_READ>> queryResultsProvider) {
+        this.lookAheadBases = lookAheadBases;
+        this.queryResultsProvider = queryResultsProvider;
+    }
+
+    @Override
+    public SimpleInterval getCacheIntervalFromQueryInterval(SimpleInterval queryInterval) {
+        return new SimpleInterval(queryInterval.getContig(), queryInterval.getStart(), Math.addExact(queryInterval.getEnd(), lookAheadBases));
+    }
+
+    /**
+     * Return reads from the cache that overlap a query interval.
+     *
+     * NOTE: this implementation uses the same criteria as BAMFileReader to determine whether to return the unmapped
+     * read in a mate pair where one read is mapped and overlaps the query interval and one read is unmapped but placed
+     * (the unmapped mate is only returned if the *start position* overlaps the query; otherwise only the mapped read
+     * of the pair is returned). For cache trimming purposes however, the two reads will always be trimmed at the same
+     * time.
+     *
+     * @param cache the cache object
+     * @param queryInterval interval being queried
+     * @return reads that overlap the query interval
+     */
+    @Override
+    public List<CACHED_READ> queryCache(final Deque<CACHED_READ> cache, final SimpleInterval queryInterval) {
+        List<CACHED_READ> matchingReads = new ArrayList<>(cache.size());
+
+        // Find (but do not remove from our cache) all reads that start before or on the provided stop position
+        for ( CACHED_READ candidateRead : cache ) {
+            if ( candidateRead.getAssignedStart() > queryInterval.getEnd() ) {
+                break; // No more possible matches among the remaining cached Reads, so stop looking
+            }
+            if (candidateRead.isPaired() && candidateRead.isUnmapped()) {
+                // in order to keep the results identical to those that would have been returned without caching,
+                // use the start position of the unmapped read
+                if (candidateRead.getAssignedStart() >= queryInterval.getStart()) {
+                    matchingReads.add(candidateRead);
+                }
+            } else {
+                matchingReads.add(candidateRead);
+            }
+        }
+        return matchingReads;
+    }
+
+    @Override
+    public Iterator<CACHED_READ> refillCache(SimpleInterval queryInterval) {
+        return queryResultsProvider.apply(queryInterval);
+    }
+
+    @Override
+    public SimpleInterval trimCache(final Deque<CACHED_READ> cache, final SimpleInterval cachedInterval, final SimpleInterval interval) {
+        if ( interval.getStart() > cachedInterval.getEnd() ) {
+            throw new GATKException(String.format("BUG: attempted to trim cache to an improper new start position (%d). Cache stop = %d",
+                    interval.getStart(), cachedInterval.getEnd()));
+        }
+
+        List<CACHED_READ> overlappingReadsBeforeNewStart = new ArrayList<>(EXPECTED_MAX_OVERLAPPING_READS_DURING_CACHE_TRIM);
+
+        // In order to trim the cache to the new start position, we need to find all reads in the cache that start
+        // before the new start position, and discard those that don't overlap the new start while keeping those
+        // that do overlap. We can stop once we find a read that starts on or after the new start position, since
+        // the reads are assumed to be sorted by start position.
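+        //
+        // (Illustration, hypothetical reads: with cached reads starting at 100, 110, and 130 and a new
+        // start of 120, only the reads at 100 and 110 are examined; each is kept only if its mapped end
+        // is at or beyond 120, and the read at 130 is never touched.)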
+        //
+        // For mate pairs where one read is unmapped, we need to keep the pairs together, so we use the territory
+        // covered by the mapped mate to determine whether or not to trim the pair.
+        while ( ! cache.isEmpty() && cache.getFirst().getAssignedStart() < interval.getStart() ) {
+            CACHED_READ readBeforeNewStart = cache.removeFirst();
+            CACHED_READ matedRead = null;
+            if (readBeforeNewStart.isUnmapped()) {
+                // we found an unmapped read whose mate has not been seen yet
+                matedRead = findMappedMateForUnmappedRead(readBeforeNewStart, cache);
+            }
+            else if (readBeforeNewStart.isPaired() && readBeforeNewStart.mateIsUnmapped()) {
+                // we found a mapped read that has an unmapped mate
+                matedRead = findUnmappedMateForMappedRead(readBeforeNewStart, cache);
+            }
+            // Our trim criteria for a pair, one of which is unmapped, should be based on the alignment end from the
+            // MAPPED mate of the pair if there is one. Since the mates can appear in either order, determine which
+            // read is the mapped one and use its value for the end calculation.
+            int mappedEnd = readBeforeNewStart.isUnmapped() && matedRead != null ?
+                    matedRead.getEnd() :
+                    readBeforeNewStart.getEnd();
+            if ( mappedEnd >= interval.getStart() ) {
+                overlappingReadsBeforeNewStart.add(readBeforeNewStart);
+                if (matedRead != null) {
+                    // mated reads should travel together
+                    overlappingReadsBeforeNewStart.add(matedRead);
+                }
+            }
+        }
+
+        // Add back the reads that started before the new start but overlapped it in the reverse of the order in
+        // which we encountered them so that their original relative ordering in the cache is restored.
+        for ( int i = overlappingReadsBeforeNewStart.size() - 1; i >= 0; --i ) {
+            cache.addFirst(overlappingReadsBeforeNewStart.get(i));
+        }
+
+        // Record our new start boundary
+        return new SimpleInterval(cachedInterval.getContig(), interval.getStart(), cachedInterval.getEnd());
+    }
+
+    // Find the unmapped read's mapped mate (if it's next in the cache)
+    private CACHED_READ findMappedMateForUnmappedRead(
+            final GATKRead unmappedRead,
+            final Deque<CACHED_READ> cache) {
+        if (!unmappedRead.isPaired()) {
+            // We should never find an unmapped read in the cache that is not paired
+            // TODO: is throwing too strict here?
+            throw new GATKException(String.format("Found unmapped, unpaired read: '%s' with no mate", unmappedRead));
+        }
+        if (!cache.isEmpty()) {
+            final GATKRead nextRead = cache.getFirst();
+            if (nextRead.isPaired() && !nextRead.isUnmapped() && nextRead.getName().equals(unmappedRead.getName())) {
+                return cache.removeFirst();
+            }
+        }
+        // It's possible that the input source actually doesn't contain the mated read, which we tolerate
+        logger.info(String.format("An unmapped read '%s' with no corresponding, mapped, paired mate was found", unmappedRead));
+        return null;
+    }
+
+    // Find the read's unmapped mate (if it's next in the cache).
+ private CACHED_READ findUnmappedMateForMappedRead( + final GATKRead readWithUnmappedMate, + final Deque cache) + { + if (cache.isEmpty()) { + logger.info(String.format("The unmapped mate of mapped read '%s' is missing from the input", readWithUnmappedMate)); + } else { + final GATKRead nextRead = cache.getFirst(); + if (!nextRead.isUnmapped() || !nextRead.getName().equals(readWithUnmappedMate.getName())) { + logger.info(String.format("A mapped read '%s' with no corresponding paired, unmapped mate was found", readWithUnmappedMate)); + } + else { + return cache.removeFirst(); + } + } + return null; + } + +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/realignmentfilter/FilterAlignmentArtifacts.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/realignmentfilter/FilterAlignmentArtifacts.java index 64c2d3de055..de94651764f 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/realignmentfilter/FilterAlignmentArtifacts.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/realignmentfilter/FilterAlignmentArtifacts.java @@ -128,6 +128,11 @@ public List getDefaultReadFilters() { @Override public boolean requiresReads() { return true; } + @Override + public boolean useReadCaching() { + return true; + } + @Override public void onTraversalStart() { realignmentEngine = new RealignmentEngine(realignmentArgumentCollection); diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariants.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariants.java index 366c906dc9d..39de7fff13e 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariants.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariants.java @@ -197,6 +197,11 @@ protected String[] customCommandLineValidation() { return null; } + @Override + public boolean useReadCaching() { + return true; + } + @Override public boolean requiresReference() { return true; diff --git a/src/test/java/org/broadinstitute/hellbender/engine/FeatureDataSourceUnitTest.java b/src/test/java/org/broadinstitute/hellbender/engine/FeatureDataSourceUnitTest.java index b55ed87ac2f..0d4bfd0d780 100644 --- a/src/test/java/org/broadinstitute/hellbender/engine/FeatureDataSourceUnitTest.java +++ b/src/test/java/org/broadinstitute/hellbender/engine/FeatureDataSourceUnitTest.java @@ -1,12 +1,10 @@ package org.broadinstitute.hellbender.engine; import htsjdk.samtools.SAMSequenceDictionary; -import htsjdk.tribble.Feature; import htsjdk.variant.variantcontext.VariantContext; import htsjdk.variant.vcf.VCFFileReader; import htsjdk.variant.vcf.VCFHeader; import org.apache.commons.lang3.tuple.Pair; -import org.broadinstitute.hellbender.exceptions.GATKException; import org.broadinstitute.hellbender.exceptions.UserException; import org.broadinstitute.hellbender.utils.SimpleInterval; import org.broadinstitute.hellbender.GATKBaseTest; @@ -412,268 +410,4 @@ public void testQueryGVCF( final SimpleInterval queryInterval, final List initializeFeatureCache( final List features, final String cacheContig, final int cacheStart, final int cacheEnd ) { - FeatureCache cache = new FeatureCache<>(); - - cache.fill(features.iterator(), new SimpleInterval(cacheContig, cacheStart, cacheEnd)); - return cache; - } - - @DataProvider(name = "FeatureCacheFillDataProvider") - public Object[][] getFeatureCacheFillData() { - return new Object[][] { - { Arrays.asList(new ArtificialTestFeature("1", 1, 100), new 
ArtificialTestFeature("1", 50, 150), - new ArtificialTestFeature("1", 200, 300), new ArtificialTestFeature("1", 350, 400)), - "1", 1, 400 }, - { Arrays.asList(new ArtificialTestFeature("1", 1, 100)), "1", 1, 100 }, - { Collections.emptyList(), "1", 1, 1 } - }; - } - - @Test(dataProvider = "FeatureCacheFillDataProvider") - public void testCacheFill( final List features, final String cacheContig, final int cacheStart, final int cacheEnd) { - FeatureCache cache = initializeFeatureCache(features, cacheContig, cacheStart, cacheEnd); - - List cachedFeatures = cache.getCachedFeaturesUpToStopPosition(cacheEnd); - Assert.assertEquals(cache.getContig(), cacheContig, "Wrong contig reported by cache after fill"); - Assert.assertEquals(cache.getCacheStart(), cacheStart, "Wrong start position reported by cache after fill"); - Assert.assertEquals(cache.getCacheEnd(), cacheEnd, "Wrong stop position reported by cache after fill"); - Assert.assertEquals(cachedFeatures, features, "Wrong Features in cache after fill()"); - } - - @DataProvider(name = "FeatureCacheHitDetectionDataProvider") - public Object[][] getFeatureCacheHitDetectionData() { - List features = Arrays.asList(new ArtificialTestFeature("1", 1, 100), - new ArtificialTestFeature("1", 50, 150), - new ArtificialTestFeature("1", 200, 300)); - FeatureCache cache = initializeFeatureCache(features, "1", 50, 250); - - return new Object[][] { - // Exact match for cache boundaries - { cache, new SimpleInterval("1", 50, 250), true }, - // Interval completely contained within cache boundaries - { cache, new SimpleInterval("1", 100, 200), true }, - // Interval left-aligned with cache boundaries - { cache, new SimpleInterval("1", 50, 100), true }, - // Interval right-aligned with cache boundaries - { cache, new SimpleInterval("1", 200, 250), true }, - // Interval overlaps, but is off the left edge of cache boundaries - { cache, new SimpleInterval("1", 49, 100), false }, - // Interval overlaps, but is off the right edge of cache boundaries - { cache, new SimpleInterval("1", 200, 251), false }, - // Interval does not overlap and is to the left of cache boundaries - { cache, new SimpleInterval("1", 1, 40), false }, - // Interval does not overlap and is to the right of cache boundaries - { cache, new SimpleInterval("1", 300, 350), false }, - // Interval is on different contig - { cache, new SimpleInterval("2", 50, 250), false } - }; - } - - @Test(dataProvider = "FeatureCacheHitDetectionDataProvider") - public void testCacheHitDetection( final FeatureCache cache, - final SimpleInterval testInterval, final boolean cacheHitExpectedResult ) { - Assert.assertEquals(cache.cacheHit(testInterval), cacheHitExpectedResult, - "Cache hit detection failed for interval " + testInterval); - } - - @DataProvider(name = "FeatureCacheTrimmingDataProvider") - public Object[][] getFeatureCacheTrimmingData() { - // Features are required to always be sorted by start position, but stop positions need not be sorted. - // This complicates cache trimming. 
- List feats = Arrays.asList( - new ArtificialTestFeature("1", 1, 1), // Feature 0 - new ArtificialTestFeature("1", 1, 100), // Feature 1 - new ArtificialTestFeature("1", 1, 1), // Feature 2 - new ArtificialTestFeature("1", 1, 50), // Feature 3 - new ArtificialTestFeature("1", 1, 3), // Feature 4 - new ArtificialTestFeature("1", 1, 5), // Feature 5 - new ArtificialTestFeature("1", 5, 5), // Feature 6 - new ArtificialTestFeature("1", 5, 50), // Feature 7 - new ArtificialTestFeature("1", 5, 10), // Feature 8 - new ArtificialTestFeature("1", 50, 100), // Feature 9 - new ArtificialTestFeature("1", 50, 50), // Feature 10 - new ArtificialTestFeature("1", 50, 200), // Feature 11 - new ArtificialTestFeature("1", 100, 100), // Feature 12 - new ArtificialTestFeature("1", 100, 110), // Feature 13 - new ArtificialTestFeature("1", 100, 200), // Feature 14 - new ArtificialTestFeature("1", 100, 150), // Feature 15 - new ArtificialTestFeature("1", 100, 199) // Feature 16 - ); - FeatureCache cache = initializeFeatureCache(feats, "1", 1, 200); - - // Pairing of start position to which to trim the cache with the List of Features we expect to see - // in the cache after trimming - List>> trimOperations = Arrays.asList( - Pair.of(1, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9), feats.get(10), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))), - Pair.of(2, Arrays.asList(feats.get(1), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9), feats.get(10), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))), - Pair.of(3, Arrays.asList(feats.get(1), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9), feats.get(10), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))), - Pair.of(4, Arrays.asList(feats.get(1), feats.get(3), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9), feats.get(10), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))), - Pair.of(5, Arrays.asList(feats.get(1), feats.get(3), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9), feats.get(10), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))), - Pair.of(6, Arrays.asList(feats.get(1), feats.get(3), feats.get(7), feats.get(8), feats.get(9), feats.get(10), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))), - Pair.of(10, Arrays.asList(feats.get(1), feats.get(3), feats.get(7), feats.get(8), feats.get(9), feats.get(10), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))), - Pair.of(11, Arrays.asList(feats.get(1), feats.get(3), feats.get(7), feats.get(9), feats.get(10), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))), - Pair.of(50, Arrays.asList(feats.get(1), feats.get(3), feats.get(7), feats.get(9), feats.get(10), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))), - Pair.of(51, Arrays.asList(feats.get(1), feats.get(9), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))), - Pair.of(100, Arrays.asList(feats.get(1), feats.get(9), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))), - Pair.of(101, 
-                Pair.of(111, Arrays.asList(feats.get(11), feats.get(14), feats.get(15), feats.get(16))),
-                Pair.of(151, Arrays.asList(feats.get(11), feats.get(14), feats.get(16))),
-                Pair.of(151, Arrays.asList(feats.get(11), feats.get(14), feats.get(16))),
-                Pair.of(200, Arrays.asList(feats.get(11), feats.get(14)))
-        );
-
-        return new Object[][] {
-                { cache, trimOperations }
-        };
-    }
-
-    @Test(dataProvider = "FeatureCacheTrimmingDataProvider")
-    public void testCacheTrimming( final FeatureCache<Feature> cache, final List<Pair<Integer, List<Feature>>> trimOperations ) {
-        // Repeatedly trim the cache to ever-increasing start positions, and verify after each trim operation
-        // that the cache holds the correct Features in the correct order
-        for ( Pair<Integer, List<Feature>> trimOperation : trimOperations ) {
-            final int trimPosition = trimOperation.getLeft();
-            final List<Feature> expectedFeatures = trimOperation.getRight();
-
-            cache.trimToNewStartPosition(trimPosition);
-
-            final List<Feature> actualFeatures = cache.getCachedFeaturesUpToStopPosition(cache.getCacheEnd());
-            Assert.assertEquals(actualFeatures, expectedFeatures, "Wrong Features in cache after trimming start position to " + trimPosition);
-        }
-    }
-
-    @DataProvider(name = "FeatureCacheRetrievalDataProvider")
-    public Object[][] getFeatureCacheRetrievalData() {
-        List<Feature> feats = Arrays.asList(
-                new ArtificialTestFeature("1", 1, 1),    // Feature 0
-                new ArtificialTestFeature("1", 1, 100),  // Feature 1
-                new ArtificialTestFeature("1", 5, 5),    // Feature 2
-                new ArtificialTestFeature("1", 10, 10),  // Feature 3
-                new ArtificialTestFeature("1", 10, 100), // Feature 4
-                new ArtificialTestFeature("1", 50, 50),  // Feature 5
-                new ArtificialTestFeature("1", 51, 55),  // Feature 6
-                new ArtificialTestFeature("1", 52, 52),  // Feature 7
-                new ArtificialTestFeature("1", 55, 60),  // Feature 8
-                new ArtificialTestFeature("1", 75, 75),  // Feature 9
-                new ArtificialTestFeature("1", 80, 100)  // Feature 10
-        );
-        FeatureCache<Feature> cache = initializeFeatureCache(feats, "1", 1, 100);
-
-        // Pairing of end position with which to bound cache retrieval with the List of Features we expect to see
-        // after retrieval
-        List<Pair<Integer, List<Feature>>> retrievalOperations = Arrays.asList(
-                Pair.of(100, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9), feats.get(10))),
-                Pair.of(80, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9), feats.get(10))),
-                Pair.of(79, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9))),
-                Pair.of(80, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9), feats.get(10))),
-                Pair.of(75, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9))),
-                Pair.of(74, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8))),
-                Pair.of(54, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7))),
-                Pair.of(52, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7))),
-                Pair.of(51, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6))),
-                Pair.of(50, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5))),
-                Pair.of(49, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4))),
-                Pair.of(10, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4))),
-                Pair.of(9, Arrays.asList(feats.get(0), feats.get(1), feats.get(2))),
-                Pair.of(5, Arrays.asList(feats.get(0), feats.get(1), feats.get(2))),
-                Pair.of(4, Arrays.asList(feats.get(0), feats.get(1))),
-                Pair.of(1, Arrays.asList(feats.get(0), feats.get(1)))
-        );
-
-        return new Object[][] {
-                { cache, retrievalOperations }
-        };
-    }
-
-    @Test(dataProvider = "FeatureCacheRetrievalDataProvider")
-    public void testCacheFeatureRetrieval( final FeatureCache<Feature> cache, final List<Pair<Integer, List<Feature>>> retrievalOperations ) {
-        for ( Pair<Integer, List<Feature>> retrievalOperation: retrievalOperations ) {
-            final int stopPosition = retrievalOperation.getLeft();
-            final List<Feature> expectedFeatures = retrievalOperation.getRight();
-
-            final List<Feature> actualFeatures = cache.getCachedFeaturesUpToStopPosition(stopPosition);
-            Assert.assertEquals(actualFeatures, expectedFeatures, "Wrong Features returned in retrieval operation with stop position " + stopPosition);
-        }
-    }
-
-    /**
-     * Test caching a region with no Features. This should work (we should avoid going to disk
-     * to look for new records when querying within such a region).
-     */
-    @Test
-    public void testHandleCachingOfEmptyRegion() {
-        FeatureCache<Feature> cache = new FeatureCache<>();
-        List<Feature> emptyRegion = new ArrayList<>();
-
-        cache.fill(emptyRegion.iterator(), new SimpleInterval("1", 1, 100));
-
-        Assert.assertTrue(cache.isEmpty(), "Cache should be empty");
-        Assert.assertTrue(cache.cacheHit(new SimpleInterval("1", 1, 100)), "Unexpected cache miss");
-        Assert.assertTrue(cache.cacheHit(new SimpleInterval("1", 2, 99)), "Unexpected cache miss");
-
-        Assert.assertEquals(cache.getCachedFeaturesUpToStopPosition(100), emptyRegion, "Should get back empty List for empty region");
-        cache.trimToNewStartPosition(2);
-        Assert.assertTrue(cache.cacheHit(new SimpleInterval("1", 2, 100)), "Unexpected cache miss");
-        Assert.assertEquals(cache.getCachedFeaturesUpToStopPosition(100), emptyRegion, "Should get back empty List for empty region");
-    }
-
-    /*********************************************************
-     * End of direct testing on the FeatureCache inner class
-     *********************************************************/
-}
diff --git a/src/test/java/org/broadinstitute/hellbender/engine/VariantWalkerIntegrationTest.java b/src/test/java/org/broadinstitute/hellbender/engine/VariantWalkerIntegrationTest.java
index aefbd6c9e7b..5766edff267 100644
--- a/src/test/java/org/broadinstitute/hellbender/engine/VariantWalkerIntegrationTest.java
+++ b/src/test/java/org/broadinstitute/hellbender/engine/VariantWalkerIntegrationTest.java
@@ -11,13 +11,17 @@
 import org.broadinstitute.hellbender.cmdline.TestProgramGroup;
 import org.broadinstitute.hellbender.exceptions.UserException;
 import org.broadinstitute.hellbender.tools.examples.ExampleVariantWalker;
+import org.broadinstitute.hellbender.utils.SimpleInterval;
 import org.broadinstitute.hellbender.utils.io.IOUtils;
 import org.broadinstitute.hellbender.utils.read.GATKRead;
 import org.broadinstitute.hellbender.utils.test.ArgumentsBuilder;
+import org.broadinstitute.hellbender.utils.test.IntegrationTestSpec;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
 import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
@@ -157,4 +161,106 @@ public void testReadFilterOn() throws Exception {
         tool.instanceMain(args);
     }
 
+    @CommandLineProgramProperties(
+            summary = "TestGATKToolWithFeaturesAndCachedReads",
+            oneLineSummary = "TestGATKToolWithFeaturesAndCachedReads",
+            programGroup = TestProgramGroup.class
+    )
+    private static final class TestGATKToolWithFeaturesAndCachedReads extends VariantWalker {
+
+        @Argument(fullName="outputFileName",optional=false)
+        private String outputFile;
+
+        @Argument(fullName="enable-reads-caching")
+        private boolean enableReadsCaching = false;
+
+        private FileWriter outputWriter;
+
+        @Override
+        public boolean requiresReads() { return true; }
+
+        @Override
+        public boolean useReadCaching() {
+            return enableReadsCaching;
+        }
+
+        @Override
+        public void onTraversalStart() {
+            try {
+                outputWriter = new FileWriter(new File(outputFile));
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
+
+        @Override
+        public Object onTraversalSuccess() {
+            try {
+                outputWriter.close();
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+            return null;
+        }
+
+        @Override
+        public void apply(
+                VariantContext variant,
+                ReadsContext readsContext,
+                ReferenceContext referenceContext,
+                FeatureContext featureContext )
+        {
+            final Iterator<GATKRead> it = readsContext.iterator();
+            try {
+                outputWriter.write("Variant loc: " + (new SimpleInterval(variant)).toString() + "\n");
+                while (it.hasNext()) {
+                    final GATKRead read = it.next();
+                    final String readString = read.isUnmapped() ?
+                            "(Unmapped): " + read.getName() :
+                            read.getName();
+                    outputWriter.write("Read: " + readString + "\n");
+                }
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
+    }
+
+    @Test
+    public void testReadCachingWithUnmappedReads() throws IOException {
+        // Query every read for every variant in the input, with caching off and then with caching on, and
+        // compare the outputs to make sure we get the same results either way, including for unmapped, placed reads.
+        //
+        // NOTE: the part of this test that doesn't use caching is relatively slow, but covers lots of interesting
+        // cases that are handled by the cache, including different relative order of mapped/unmapped mate pairs, and a
+        // mapped/unmapped mate pair that is split by the query for the first variant (the first variant triggers
+        // the cache to be filled using a start pos that overlaps the mapped, but not the unmapped, read of the pair,
+        // causing them to be separated by leaving the unmapped mate out of the initial result set used
+        // to prime the cache).
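+        //
+        // The comparison below runs the same test tool twice over identical inputs, once per
+        // caching mode: each run writes one line per variant plus one line per overlapping read,
+        // and the two output files are then asserted to be identical.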
+        final File vcfFile = new File(largeFileTestDir + "1000G.phase3.broad.withGenotypes.chr20.10100000.vcf");
+        final File bamFile = new File(largeFileTestDir + "CEUTrio.HiSeq.WGS.b37.NA12878.20.21.bam");
+
+        final GATKTool toolWithOutCaching = new TestGATKToolWithFeaturesAndCachedReads();
+        final File readsWithoutCaching = createTempFile("readsWithoutCaching", ".txt");
+        final String[] noCachingArgs = {
+                "--variant", vcfFile.getCanonicalPath(),
+                "--input", bamFile.getCanonicalPath(),
+                "--outputFileName", readsWithoutCaching.getCanonicalPath(),
+                "--enable-reads-caching", "false"
+        };
+        toolWithOutCaching.instanceMain(noCachingArgs);
+
+        final File readsWithCaching = createTempFile("readsWithCaching", ".txt");
+        final String[] cachingArgs = {
+                "--variant", vcfFile.getCanonicalPath(),
+                "--input", bamFile.getCanonicalPath(),
+                "--outputFileName", readsWithCaching.getCanonicalPath(),
+                "--enable-reads-caching", "true"
+        };
+        final GATKTool toolWithCaching = new TestGATKToolWithFeaturesAndCachedReads();
+        toolWithCaching.instanceMain(cachingArgs);
+
+        IntegrationTestSpec.assertEqualTextFiles(readsWithCaching, readsWithoutCaching);
+    }
+
 }
diff --git a/src/test/java/org/broadinstitute/hellbender/engine/cache/LocatableCacheUnitTest.java b/src/test/java/org/broadinstitute/hellbender/engine/cache/LocatableCacheUnitTest.java
new file mode 100644
index 00000000000..5239f6e8b4c
--- /dev/null
+++ b/src/test/java/org/broadinstitute/hellbender/engine/cache/LocatableCacheUnitTest.java
@@ -0,0 +1,430 @@
+package org.broadinstitute.hellbender.engine.cache;
+
+import htsjdk.samtools.SAMFileHeader;
+import htsjdk.tribble.Feature;
+import org.apache.commons.lang3.tuple.Pair;
+import org.broadinstitute.hellbender.engine.FeatureDataSource;
+import org.broadinstitute.hellbender.utils.SimpleInterval;
+import org.broadinstitute.hellbender.utils.read.ArtificialReadUtils;
+import org.broadinstitute.hellbender.utils.read.GATKRead;
+import org.testng.Assert;
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+public class LocatableCacheUnitTest {
+
+    /********************************************************************************************
+     * LocatableCache tests using DrivingFeatureInputCacheStrategy to test basic cache operations
+     ********************************************************************************************/
+
+    @SuppressWarnings("overrides") // because I don't want to implement hashCode() but do need an equals() here
+    private static class ArtificialTestFeature implements Feature {
+        private String chr;
+        private int start;
+        private int end;
+
+        public ArtificialTestFeature( final String chr, final int start, final int end ) {
+            this.chr = chr;
+            this.start = start;
+            this.end = end;
+        }
+        // suppress the deprecation warning; this method is required because it's part of the implemented interface
+        @Override
+        @SuppressWarnings("deprecation")
+        @Deprecated
+        public String getChr() {
+            return getContig();
+        }
+
+        @Override
+        public String getContig() {
+            return chr;
+        }
+
+        @Override
+        public int getStart() {
+            return start;
+        }
+
+        @Override
+        public int getEnd() {
+            return end;
+        }
+
+        @Override
+        public boolean equals( Object other ) {
+            if ( other == null || ! (other instanceof ArtificialTestFeature) ) {
+                return false;
+            }
+
+            ArtificialTestFeature otherFeature = (ArtificialTestFeature)other;
+            return chr.equals(otherFeature.getContig()) && start == otherFeature.getStart() && end == otherFeature.getEnd();
+        }
+
+        @Override
+        public String toString() {
+            return chr + ":" + start + "-" + end;   // (to improve output on test failures involving this class)
+        }
+    }
+
+    private LocatableCache<Feature> initializeFeatureCache(final List<Feature> features, final String cacheContig, final int cacheStart, final int cacheEnd, final int cacheLookAhead ) {
+        LocatableCache<Feature> cache = new LocatableCache<>(
+                "test",
+                new DrivingFeatureInputCacheStrategy<>(
+                        cacheLookAhead,
+                        (SimpleInterval newCacheInterval) -> features.iterator())
+        );
+
+        cache.queryAndPrefetch(new SimpleInterval(cacheContig, cacheStart, cacheEnd));
+        return cache;
+    }
+
+    @DataProvider(name = "FeatureCacheInitialQueryDataProvider")
+    public Object[][] getFeatureCacheInitialQueryData() {
+        return new Object[][] {
+                { Arrays.asList(
+                        new ArtificialTestFeature("1", 1, 100),
+                        new ArtificialTestFeature("1", 50, 150),
+                        new ArtificialTestFeature("1", 200, 300),
+                        new ArtificialTestFeature("1", 350, 400)),
+                        "1", 1, 400 },
+                { Arrays.asList(new ArtificialTestFeature("1", 1, 100)), "1", 1, 100 },
+                { Collections.emptyList(), "1", 1, 1 }
+        };
+    }
+
+    @Test(dataProvider = "FeatureCacheInitialQueryDataProvider")
+    public void testFeatureCacheInitialQuery( final List<Feature> features, final String cacheContig, final int cacheStart, final int cacheEnd) {
+        LocatableCache<Feature> cache = initializeFeatureCache(features, cacheContig, cacheStart, cacheEnd, FeatureDataSource.DEFAULT_QUERY_LOOKAHEAD_BASES);
+
+        List<Feature> cachedFeatures = cache.getCachedLocatables(new SimpleInterval(cache.getContig(), cache.getCacheStart(), cacheEnd));
+        Assert.assertEquals(cache.getContig(), cacheContig, "Wrong contig reported by cache after fill");
+        Assert.assertEquals(cache.getCacheStart(), cacheStart, "Wrong start position reported by cache after fill");
+        Assert.assertEquals(cache.getCacheEnd(), cacheEnd + FeatureDataSource.DEFAULT_QUERY_LOOKAHEAD_BASES, "Wrong stop position reported by cache after fill");
+        Assert.assertEquals(cachedFeatures, features, "Wrong Features in cache after fill()");
+    }
+
+    @DataProvider(name = "FeatureCacheHitDetectionDataProvider")
+    public Object[][] getFeatureCacheHitDetectionData() {
+        List<Feature> features = Arrays.asList(new ArtificialTestFeature("1", 1, 100),
+                                               new ArtificialTestFeature("1", 50, 150),
+                                               new ArtificialTestFeature("1", 200, 300));
+        // initialize cache with 0 lookahead bases
+        LocatableCache<Feature> cache = initializeFeatureCache(features, "1", 50, 250, 0);
+
+        return new Object[][] {
+                // Exact match for cache boundaries
+                { cache, new SimpleInterval("1", 50, 250), true },
+                // Interval completely contained within cache boundaries
+                { cache, new SimpleInterval("1", 100, 200), true },
+                // Interval left-aligned with cache boundaries
+                { cache, new SimpleInterval("1", 50, 100), true },
+                // Interval right-aligned with cache boundaries
+                { cache, new SimpleInterval("1", 200, 250), true },
+                // Interval overlaps, but is off the left edge of cache boundaries
+                { cache, new SimpleInterval("1", 49, 100), false },
+                // Interval overlaps, but is off the right edge of cache boundaries
+                { cache, new SimpleInterval("1", 200, 251), false },
+                // Interval does not overlap and is to the left of cache boundaries
+                { cache, new SimpleInterval("1", 1, 40), false },
+                // Interval does not overlap and is to the right of cache boundaries
+                { cache, new SimpleInterval("1", 300, 350), false },
cache, new SimpleInterval("1", 300, 350), false }, + // Interval is on different contig + { cache, new SimpleInterval("2", 50, 250), false } + }; + } + + @Test(dataProvider = "FeatureCacheHitDetectionDataProvider") + public void testFeatureCacheHitDetection( final LocatableCache cache, + final SimpleInterval testInterval, final boolean cacheHitExpectedResult ) { + Assert.assertEquals(cache.cacheHit(testInterval), cacheHitExpectedResult, + "Cache hit detection failed for interval " + testInterval); + } + + @DataProvider(name = "FeatureCacheTrimmingDataProvider") + public Object[][] getFeatureCacheTrimmingData() { + // Features are required to always be sorted by start position, but stop positions need not be sorted. + // This complicates cache trimming. + List feats = Arrays.asList( + new ArtificialTestFeature("1", 1, 1), // Feature 0 + new ArtificialTestFeature("1", 1, 100), // Feature 1 + new ArtificialTestFeature("1", 1, 1), // Feature 2 + new ArtificialTestFeature("1", 1, 50), // Feature 3 + new ArtificialTestFeature("1", 1, 3), // Feature 4 + new ArtificialTestFeature("1", 1, 5), // Feature 5 + new ArtificialTestFeature("1", 5, 5), // Feature 6 + new ArtificialTestFeature("1", 5, 50), // Feature 7 + new ArtificialTestFeature("1", 5, 10), // Feature 8 + new ArtificialTestFeature("1", 50, 100), // Feature 9 + new ArtificialTestFeature("1", 50, 50), // Feature 10 + new ArtificialTestFeature("1", 50, 200), // Feature 11 + new ArtificialTestFeature("1", 100, 100), // Feature 12 + new ArtificialTestFeature("1", 100, 110), // Feature 13 + new ArtificialTestFeature("1", 100, 200), // Feature 14 + new ArtificialTestFeature("1", 100, 150), // Feature 15 + new ArtificialTestFeature("1", 100, 199) // Feature 16 + ); + // initialize the feature cache with 0 lookahead bases + LocatableCache cache = initializeFeatureCache(feats, "1", 1, 200, 0); + + // Pairing of start position to which to trimCache the cache with the List of Features we expect to see + // in the cache after trimming + List>> trimOperations = Arrays.asList( + Pair.of(1, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9), feats.get(10), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))), + Pair.of(2, Arrays.asList(feats.get(1), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9), feats.get(10), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))), + Pair.of(3, Arrays.asList(feats.get(1), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9), feats.get(10), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))), + Pair.of(4, Arrays.asList(feats.get(1), feats.get(3), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9), feats.get(10), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))), + Pair.of(5, Arrays.asList(feats.get(1), feats.get(3), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9), feats.get(10), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))), + Pair.of(6, Arrays.asList(feats.get(1), feats.get(3), feats.get(7), feats.get(8), feats.get(9), feats.get(10), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))), + Pair.of(10, Arrays.asList(feats.get(1), feats.get(3), feats.get(7), 
+                Pair.of(11, Arrays.asList(feats.get(1), feats.get(3), feats.get(7), feats.get(9), feats.get(10), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))),
+                Pair.of(50, Arrays.asList(feats.get(1), feats.get(3), feats.get(7), feats.get(9), feats.get(10), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))),
+                Pair.of(51, Arrays.asList(feats.get(1), feats.get(9), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))),
+                Pair.of(100, Arrays.asList(feats.get(1), feats.get(9), feats.get(11), feats.get(12), feats.get(13), feats.get(14), feats.get(15), feats.get(16))),
+                Pair.of(101, Arrays.asList(feats.get(11), feats.get(13), feats.get(14), feats.get(15), feats.get(16))),
+                Pair.of(111, Arrays.asList(feats.get(11), feats.get(14), feats.get(15), feats.get(16))),
+                Pair.of(151, Arrays.asList(feats.get(11), feats.get(14), feats.get(16))),
+                Pair.of(151, Arrays.asList(feats.get(11), feats.get(14), feats.get(16))),
+                Pair.of(200, Arrays.asList(feats.get(11), feats.get(14)))
+        );
+
+        return new Object[][] {
+                { cache, trimOperations }
+        };
+    }
+
+    @Test(dataProvider = "FeatureCacheTrimmingDataProvider")
+    public void testFeatureCacheTrimming( final LocatableCache<Feature> cache, final List<Pair<Integer, List<Feature>>> trimOperations ) {
+        // Repeatedly trim the cache to ever-increasing start positions, and verify after each trim operation
+        // that the cache holds the correct Features in the correct order
+        for ( Pair<Integer, List<Feature>> trimOperation : trimOperations ) {
+            final int trimPosition = trimOperation.getLeft();
+            final List<Feature> expectedFeatures = trimOperation.getRight();
+
+            cache.queryAndPrefetch(new SimpleInterval(cache.getContig(), trimPosition, cache.getCacheEnd()));
+
+            final List<Feature> actualFeatures = cache.getCachedLocatables(new SimpleInterval(cache.getContig(), cache.getCacheStart(), cache.getCacheEnd()));
+            Assert.assertEquals(actualFeatures, expectedFeatures, "Wrong Features in cache after trimming start position to " + trimPosition);
+        }
+    }
+
+    @DataProvider(name = "FeatureCacheRetrievalDataProvider")
+    public Object[][] getFeatureCacheRetrievalData() {
+        List<Feature> feats = Arrays.asList(
+                new ArtificialTestFeature("1", 1, 1),    // Feature 0
+                new ArtificialTestFeature("1", 1, 100),  // Feature 1
+                new ArtificialTestFeature("1", 5, 5),    // Feature 2
+                new ArtificialTestFeature("1", 10, 10),  // Feature 3
+                new ArtificialTestFeature("1", 10, 100), // Feature 4
+                new ArtificialTestFeature("1", 50, 50),  // Feature 5
+                new ArtificialTestFeature("1", 51, 55),  // Feature 6
+                new ArtificialTestFeature("1", 52, 52),  // Feature 7
+                new ArtificialTestFeature("1", 55, 60),  // Feature 8
+                new ArtificialTestFeature("1", 75, 75),  // Feature 9
+                new ArtificialTestFeature("1", 80, 100)  // Feature 10
+        );
+        LocatableCache<Feature> cache = initializeFeatureCache(feats, "1", 1, 100, FeatureDataSource.DEFAULT_QUERY_LOOKAHEAD_BASES);
+
+        // Pairing of end position with which to bound cache retrieval with the List of Features we expect to see
+        // after retrieval
+        List<Pair<Integer, List<Feature>>> retrievalOperations = Arrays.asList(
+                Pair.of(100, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9), feats.get(10))),
+                Pair.of(80, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9), feats.get(10))),
+                Pair.of(79, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9))),
+                Pair.of(80, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9), feats.get(10))),
+                Pair.of(75, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8), feats.get(9))),
+                Pair.of(74, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7), feats.get(8))),
+                Pair.of(54, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7))),
+                Pair.of(52, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6), feats.get(7))),
+                Pair.of(51, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5), feats.get(6))),
+                Pair.of(50, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4), feats.get(5))),
+                Pair.of(49, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4))),
+                Pair.of(10, Arrays.asList(feats.get(0), feats.get(1), feats.get(2), feats.get(3), feats.get(4))),
+                Pair.of(9, Arrays.asList(feats.get(0), feats.get(1), feats.get(2))),
+                Pair.of(5, Arrays.asList(feats.get(0), feats.get(1), feats.get(2))),
+                Pair.of(4, Arrays.asList(feats.get(0), feats.get(1))),
+                Pair.of(1, Arrays.asList(feats.get(0), feats.get(1)))
+        );
+
+        return new Object[][] {
+                { cache, retrievalOperations }
+        };
+    }
+
+    @Test(dataProvider = "FeatureCacheRetrievalDataProvider")
+    public void testFeatureCacheFeatureRetrieval( final LocatableCache<Feature> cache, final List<Pair<Integer, List<Feature>>> retrievalOperations ) {
+        for ( Pair<Integer, List<Feature>> retrievalOperation: retrievalOperations ) {
+            final int stopPosition = retrievalOperation.getLeft();
+            final List<Feature> expectedFeatures = retrievalOperation.getRight();
+
+            final List<Feature> actualFeatures = cache.getCachedLocatables(new SimpleInterval(cache.getContig(), cache.getCacheStart(), stopPosition));
+            Assert.assertEquals(actualFeatures, expectedFeatures, "Wrong Features returned in retrieval operation with stop position " + stopPosition);
+        }
+    }
+
+    /**
+     * Test caching a region with no Features. This should work (we should avoid going to disk
+     * to look for new records when querying within such a region).
+     */
+    @Test
+    public void testHandleCachingOfEmptyRegion() {
+        List<Feature> emptyRegion = new ArrayList<>();
+        // initialize cache with 0 lookahead bases
+        LocatableCache<Feature> cache = new LocatableCache<>(
+                "test",
+                new DrivingFeatureInputCacheStrategy<>(
+                        0, (SimpleInterval newInterval) -> emptyRegion.iterator()));
+
+        // prime the cache with an empty iterator
+        cache.queryAndPrefetch(new SimpleInterval("1", 1, 100));
+
+        Assert.assertTrue(cache.isEmpty(), "Cache should be empty");
+        Assert.assertTrue(cache.cacheHit(new SimpleInterval("1", 1, 100)), "Unexpected cache miss");
+        Assert.assertTrue(cache.cacheHit(new SimpleInterval("1", 2, 99)), "Unexpected cache miss");
+
+        // query the empty region to force a (no-op) cache trim, after which we still expect a cacheHit
+        Assert.assertEquals(
+                cache.queryAndPrefetch(
+                        new SimpleInterval(cache.getContig(), 2, cache.getCacheEnd())),
+                emptyRegion);
+        Assert.assertTrue(cache.cacheHit(new SimpleInterval("1", 2, 100)), "Unexpected cache miss");
+
+        // now query again and make sure we don't do another query
+        Assert.assertEquals(
+                cache.queryAndPrefetch(
+                        new SimpleInterval(cache.getContig(), cache.getCacheStart(), 100)),
+                emptyRegion,
+                "Should get back empty List for empty region");
+    }
+
+    /***************************************************************************************************
+     * End of LocatableCache tests using DrivingFeatureInputCacheStrategy to test basic cache operations
+     ***************************************************************************************************/
+
+    /********************************************************
+     * LocatableCache tests using SideReadInputCacheStrategy
+     ********************************************************/
+
+    @DataProvider(name = "ReadsWithUnmappedMates")
+    public Object[][] getReadsWithUnmappedMates() {
+        // hit every code path in SideReadInputCacheStrategy:
+        //
+        //   mate pair appears at an early locus/late locus (i.e., before/after other reads)
+        //   mate pairs with mapped first/unmapped first
+        //   query that trims/doesn't trim the pair from the cache
+
+        final SAMFileHeader samHeader = ArtificialReadUtils.createArtificialSamHeader();
+
+        final SimpleInterval earlyInterval = new SimpleInterval("1", 1, 10);
+        final SimpleInterval lateInterval = new SimpleInterval("1", 20, 30);
+
+        final GATKRead earlyMappedMate = ArtificialReadUtils.createArtificialRead(samHeader, "earlyMatePair", 0, earlyInterval.getStart(), 10);
+        earlyMappedMate.setIsPaired(true);
+        final GATKRead earlyUnmappedMate = ArtificialReadUtils.createArtificialRead(samHeader, "earlyMatePair", 0, earlyInterval.getStart(), 10);
+        earlyUnmappedMate.setIsPaired(true);
+        earlyUnmappedMate.setIsUnmapped();
+
+        final GATKRead lateMappedMate = ArtificialReadUtils.createArtificialRead(samHeader, "lateMatePair", 0, lateInterval.getStart(), 10);
+        lateMappedMate.setIsPaired(true);
+        final GATKRead lateUnmappedMate = ArtificialReadUtils.createArtificialRead(samHeader, "lateMatePair", 0, lateInterval.getStart(), 10);
+        lateUnmappedMate.setIsPaired(true);
+        lateUnmappedMate.setIsUnmapped();
+
+        final GATKRead earlyOtherRead1 = ArtificialReadUtils.createArtificialRead(samHeader, "earlyOtherRead1", 0, earlyInterval.getStart(), 10);
+        final GATKRead earlyOtherRead2 = ArtificialReadUtils.createArtificialRead(samHeader, "earlyOtherRead2", 0, earlyInterval.getStart(), 10);
+        final GATKRead lateOtherRead1 = ArtificialReadUtils.createArtificialRead(samHeader, "lateOtherRead1", 0, lateInterval.getStart(), 10);
+        final GATKRead lateOtherRead2 = ArtificialReadUtils.createArtificialRead(samHeader, "lateOtherRead2", 0, lateInterval.getStart(), 10);
ArtificialReadUtils.createArtificialRead(samHeader, "lateOtherRead2", 0, lateInterval.getStart(), 10); + + // add in coord sorted order: + // unmapped mate, followed by mapped mate, followed by other reads + final List earlyUnmappedMateFirst = new ArrayList<>(); + earlyUnmappedMateFirst.add(earlyUnmappedMate); // unmapped mate + earlyUnmappedMateFirst.add(earlyMappedMate); + earlyUnmappedMateFirst.add(lateOtherRead1); + earlyUnmappedMateFirst.add(lateOtherRead2); + + // mapped mate, followed by unmapped mate, followed by other reads + final List earlyUnmappedMateSecond = new ArrayList<>(); + earlyUnmappedMateSecond.add(earlyMappedMate); + earlyUnmappedMateSecond.add(earlyUnmappedMate); // unmapped mate + earlyUnmappedMateSecond.add(lateOtherRead1); + earlyUnmappedMateSecond.add(lateOtherRead2); + + // mate pair appears after other reads, unmapped mate first + final List lateUnmappedMateFirst = new ArrayList<>(); + lateUnmappedMateFirst.add(earlyOtherRead1); + lateUnmappedMateFirst.add(earlyOtherRead2); + lateUnmappedMateFirst.add(lateUnmappedMate); // unmapped mate + lateUnmappedMateFirst.add(lateMappedMate); + + // mate pair appears after other reads, mapped mate first + final List lateUnmappedMateSecond = new ArrayList<>(); + lateUnmappedMateSecond.add(earlyOtherRead1); + lateUnmappedMateSecond.add(earlyOtherRead2); + lateUnmappedMateSecond.add(lateMappedMate); + lateUnmappedMateSecond.add(lateUnmappedMate); // unmapped mate + + return new Object[][]{ + // cache contents, lookaheadBases, initial cache triggering interval, + // query interval, triggering interval result size, index of unmapped, query result size, empty at end + + // issue a query past the end of the mate pair to force them both to be trimmed from the cache + { earlyUnmappedMateFirst, 50, earlyInterval, new SimpleInterval("1", 20, 50), 2, 0, 2, false}, + + // issue a query that overlaps both reads in the pair to cause them both to remain in the cache + { earlyUnmappedMateSecond, 50, earlyInterval, new SimpleInterval("1", 1, 50), 2, 1, 4, false }, + + // issue a query that overlaps only the MAPPED mate in the pair to cause them both to remain in the cache + { earlyUnmappedMateSecond, 50, earlyInterval, new SimpleInterval("1", 8, 50), 2, 1, 3, false }, + + // issue a query that includes the mate pair to force them both to be trimmed from the cache + { earlyUnmappedMateSecond, 50, earlyInterval, new SimpleInterval("1", 30, 40), 2, 1, 0, true }, + + // issue a query past the end of mate pair to force them both to be trimmed from the cache + { lateUnmappedMateFirst, 50, earlyInterval, new SimpleInterval("1", 40, 50), 2, -1, 0, true }, + + // issue a query that overlaps only the MAPPED mate, so both remain in the cache + { lateUnmappedMateSecond, 50, earlyInterval, new SimpleInterval("1", 25, 50), 2, -1, 1, false } + }; + } + + @Test(dataProvider = "ReadsWithUnmappedMates") + public void testReadCacheTrimmingWithUnmappedMates( + final List reads, + final int lookAheadBases, + final SimpleInterval triggeringQueryInterval, + final SimpleInterval queryInterval, + final int returnedByTriggeringInterval, + final int indexOfUnmapped, + final int returnedByQueryIntervalQuery, + final boolean expectedEmpty + ) { + final LocatableCache cache = + new LocatableCache<>( + "test", + new SideReadInputCacheStrategy<>( + lookAheadBases, + (SimpleInterval interval) -> reads.iterator() + ) + ); + + List resultReads = cache.queryAndPrefetch(triggeringQueryInterval); + Assert.assertEquals(resultReads.size(), returnedByTriggeringInterval); + 
+        Assert.assertTrue(
+                indexOfUnmapped == -1 ||
+                        (resultReads.get(indexOfUnmapped).isPaired() && resultReads.get(indexOfUnmapped).isUnmapped()));
+
+        // issue the query and see how many records are returned
+        resultReads = cache.queryAndPrefetch(queryInterval);
+        Assert.assertEquals(resultReads.size(), returnedByQueryIntervalQuery);
+
+        Assert.assertEquals(cache.isEmpty(), expectedEmpty);
+    }
+
+}