Skip to content

Commit

Permalink
Updating the branch to reviewer comments and added some slightly more…
Browse files Browse the repository at this point in the history
… useful tests
  • Loading branch information
jamesemery committed Sep 17, 2024
1 parent cec1af1 commit 27beb55
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ private StandardArgumentDefinitions(){}
public static final String INVALIDATE_PREVIOUS_FILTERS_LONG_NAME = "invalidate-previous-filters";
public static final String SORT_ORDER_LONG_NAME = "sort-order";
public static final String FLOW_ORDER_FOR_ANNOTATIONS = "flow-order-for-annotations";
public static final String VARIANT_OUTPUT_INTERVAL_FILTERING_MODE_LONG_NAME = "variant-output-interval-filtering-mode";
public static final String VARIANT_OUTPUT_INTERVAL_FILTERING_MODE_LONG_NAME = "variant-output-filtering";

public static final String INPUT_SHORT_NAME = "I";
public static final String OUTPUT_SHORT_NAME = "O";
Expand Down
19 changes: 6 additions & 13 deletions src/main/java/org/broadinstitute/hellbender/engine/GATKTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,6 @@ public abstract class GATKTool extends CommandLineProgram {
doc = "If true, don't emit genotype fields when writing vcf file output.", optional = true)
public boolean outputSitesOnlyVCFs = false;

@Argument(fullName = StandardArgumentDefinitions.VARIANT_OUTPUT_INTERVAL_FILTERING_MODE_LONG_NAME,
doc = "Restrict the output variants to ones that match the specified intervals according to the specified matching mode.",
optional = true)
@Advanced
public IntervalFilteringVcfWriter.Mode outputVariantIntervalFilteringMode = getDefaultVariantOutputFilterMode();


/**
* Master sequence dictionary to be used instead of all other dictionaries (if provided).
*/
Expand Down Expand Up @@ -429,10 +422,10 @@ public int getDefaultCloudIndexPrefetchBufferSize() {
public String getProgressMeterRecordLabel() { return ProgressMeter.DEFAULT_RECORD_LABEL; }

/**
* @return Default interval filtering mode for variant output. Subclasses may override this to set a different default.
* @return {@link IntervalFilteringVcfWriter.Mode#ANYWHERE} to apply no interval filtering to variants sent to the variant writer. Subclasses may override this to enforce other filtering schemes.
*/
public IntervalFilteringVcfWriter.Mode getDefaultVariantOutputFilterMode(){
return null;
public IntervalFilteringVcfWriter.Mode getVariantFilteringOutputModeIfApplicable(){
return IntervalFilteringVcfWriter.Mode.ANYWHERE;
}

protected List<SimpleInterval> transformTraversalIntervals(final List<SimpleInterval> getIntervals, final SAMSequenceDictionary sequenceDictionary) {
Expand Down Expand Up @@ -751,7 +744,7 @@ protected void onStartup() {

checkToolRequirements();

if (outputVariantIntervalFilteringMode != null && userIntervals == null){
if ((getVariantFilteringOutputModeIfApplicable() != IntervalFilteringVcfWriter.Mode.ANYWHERE ) && userIntervals == null){
throw new CommandLineException.MissingArgument("-L or -XL", "Intervals are required if --" + StandardArgumentDefinitions.VARIANT_OUTPUT_INTERVAL_FILTERING_MODE_LONG_NAME + " was specified.");
}

Expand Down Expand Up @@ -949,11 +942,11 @@ public VariantContextWriter createVCFWriter(final Path outPath) {
options.toArray(new Options[0]));
}

return outputVariantIntervalFilteringMode== null ?
return getVariantFilteringOutputModeIfApplicable() == IntervalFilteringVcfWriter.Mode.ANYWHERE ?
unfilteredWriter :
new IntervalFilteringVcfWriter(unfilteredWriter,
intervalArgumentCollection.getIntervals(getBestAvailableSequenceDictionary()),
outputVariantIntervalFilteringMode);
getVariantFilteringOutputModeIfApplicable());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.vcf.VCFHeader;
import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.engine.filters.CountingReadFilter;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.variant.writers.IntervalFilteringVcfWriter;

import java.util.Spliterator;

Expand All @@ -31,6 +33,28 @@ public abstract class VariantWalker extends VariantWalkerBase {
private FeatureDataSource<VariantContext> drivingVariants;
private FeatureInput<VariantContext> drivingVariantsFeatureInput;

// Interval filtering mode applied to variants written by this walker. The initial value is taken
// from getDefaultVariantOutputFilterMode(), which subclasses may override; that call runs while
// this field is being initialized, i.e. before any subclass field initializers have executed.
@Argument(fullName = StandardArgumentDefinitions.VARIANT_OUTPUT_INTERVAL_FILTERING_MODE_LONG_NAME,
doc = "Restrict the output variants to ones that match the specified intervals according to the specified matching mode.",
optional = true)
@Advanced
public IntervalFilteringVcfWriter.Mode outputVariantIntervalFilteringMode = getDefaultVariantOutputFilterMode();

/**
 * Supplies the initial value of {@code outputVariantIntervalFilteringMode}.
 *
 * <p>This method is invoked from that field's initializer, so an override executes during
 * construction, before the overriding subclass's own fields have been initialized. Overrides
 * should therefore return a constant and must not read subclass instance state.
 *
 * @return Default interval filtering mode for variant output; this implementation returns
 *         {@link IntervalFilteringVcfWriter.Mode#ANYWHERE}. Subclasses may override this to set a
 *         different default.
 */
public IntervalFilteringVcfWriter.Mode getDefaultVariantOutputFilterMode(){
return IntervalFilteringVcfWriter.Mode.ANYWHERE;
}

/**
 * Resolves the interval filtering mode to apply to this walker's variant output.
 *
 * @return the value of {@code outputVariantIntervalFilteringMode} when it is non-null, otherwise
 *         whatever the superclass implementation returns
 */
@Override
public IntervalFilteringVcfWriter.Mode getVariantFilteringOutputModeIfApplicable() {
    final IntervalFilteringVcfWriter.Mode requestedMode = outputVariantIntervalFilteringMode;
    return requestedMode == null
            ? super.getVariantFilteringOutputModeIfApplicable()
            : requestedMode;
}

@Override
protected SAMSequenceDictionary getSequenceDictionaryForDrivingVariants() { return drivingVariants.getSequenceDictionary(); }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.variantcontext.writer.VariantContextWriter;
import htsjdk.variant.vcf.VCFHeader;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.barclay.argparser.CommandLineParser;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
Expand Down Expand Up @@ -32,6 +34,8 @@ boolean test(final OverlapDetector<? extends Locatable> detector, final VariantC
final SimpleInterval startPosition = new SimpleInterval(query.getContig(), query.getStart(), query.getStart());
return detector.overlapsAny(startPosition);
}
@Override
String getName() {return "STARTS_IN";}
},

/**
Expand All @@ -43,6 +47,8 @@ boolean test(final OverlapDetector<? extends Locatable> detector, final VariantC
final SimpleInterval endPosition = new SimpleInterval(query.getContig(), query.getEnd(), query.getEnd());
return detector.overlapsAny(endPosition);
}
@Override
String getName() {return "ENDS_IN";}
},

/**
Expand All @@ -53,12 +59,13 @@ boolean test(final OverlapDetector<? extends Locatable> detector, final VariantC
boolean test(final OverlapDetector<? extends Locatable> detector, final VariantContext query) {
return detector.overlapsAny(query);
}
@Override
String getName() {return "OVERLAPS";}
},

// TODO finish this exception here...
/**
* Matches if the entirety of the query is contained within one of the intervals. Note that adjacent intervals
* may be merged into a single interval depending on the values in
* may be merged into a single interval depending on the specified "--interval-merging-rule".
*/
CONTAINED("contained completely within a contiguous block of intervals without overlap") {
@Override
Expand All @@ -71,6 +78,8 @@ boolean test(final OverlapDetector<? extends Locatable> detector, final VariantC
}
return false;
}
@Override
String getName() {return "CONTAINED";}
},

/**
Expand All @@ -81,6 +90,8 @@ boolean test(final OverlapDetector<? extends Locatable> detector, final VariantC
boolean test(final OverlapDetector<? extends Locatable> detector, final VariantContext query) {
return true;
}
@Override
String getName() {return "ANYWHERE";}
};

private final String doc;
Expand All @@ -91,6 +102,7 @@ boolean test(final OverlapDetector<? extends Locatable> detector, final VariantC
* @return true iff the variant matches the given intervals
*/
abstract boolean test(final OverlapDetector<? extends Locatable> detector, final VariantContext query);
abstract String getName();

private Mode(String doc){
this.doc = doc;
Expand All @@ -103,9 +115,11 @@ public String getHelpDoc() {
}
}

private final VariantContextWriter writer;
private final VariantContextWriter underlyingWriter;
private final OverlapDetector<SimpleInterval> detector;
private final Mode mode;
// Number of variants this writer instance has dropped because they failed the mode test.
// Kept per instance (not static) so that each writer reports only its own count in close()
// when several filtering writers exist in the same JVM.
private int filteredCount = 0;
protected final Logger logger = LogManager.getLogger(this.getClass());

/**
* @param writer the writer to wrap
Expand All @@ -117,29 +131,32 @@ public IntervalFilteringVcfWriter(final VariantContextWriter writer, final List<
Utils.nonEmpty(intervals);
Utils.nonNull(mode);

this.writer = writer;
this.underlyingWriter = writer;
this.detector = OverlapDetector.create(intervals);
this.mode = mode;
}

@Override
public void writeHeader(final VCFHeader header) {
writer.writeHeader(header);
underlyingWriter.writeHeader(header);
}

@Override
public void setHeader(final VCFHeader header) {
writer.setHeader(header);
underlyingWriter.setHeader(header);
}

@Override
public void close() {
writer.close();
underlyingWriter.close();
if (filteredCount > 0) {
logger.info("Removed " + filteredCount + " variants from the output according to '"+mode.getName()+"' variant interval filtering rule.");
}
}

@Override
public boolean checkError() {
return writer.checkError();
return underlyingWriter.checkError();
}

/**
Expand All @@ -149,7 +166,9 @@ public boolean checkError() {
@Override
public void add(final VariantContext vc) {
if(mode.test(detector, vc)) {
writer.add(vc);
underlyingWriter.add(vc);
} else {
filteredCount++;
}
}

Expand Down
Loading

0 comments on commit 27beb55

Please sign in to comment.