Skip to content

Commit

Permalink
Rebase and add better tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lbergelson committed Aug 19, 2024
1 parent e96d439 commit 578a951
Show file tree
Hide file tree
Showing 12 changed files with 102 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ private StandardArgumentDefinitions(){}
public static final String INVALIDATE_PREVIOUS_FILTERS_LONG_NAME = "invalidate-previous-filters";
public static final String SORT_ORDER_LONG_NAME = "sort-order";
public static final String FLOW_ORDER_FOR_ANNOTATIONS = "flow-order-for-annotations";

public static final String VARIANT_OUTPUT_INTERVAL_FILTERING_MODE_LONG_NAME = "variant-output-interval-filtering-mode";

public static final String INPUT_SHORT_NAME = "I";
public static final String OUTPUT_SHORT_NAME = "O";
Expand Down
10 changes: 2 additions & 8 deletions src/main/java/org/broadinstitute/hellbender/engine/GATKTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.barclay.argparser.CommandLineException;
import org.broadinstitute.barclay.argparser.CommandLineException;
import org.broadinstitute.barclay.argparser.CommandLinePluginDescriptor;
import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
import org.broadinstitute.hellbender.cmdline.GATKPlugin.GATKAnnotationPluginDescriptor;
Expand Down Expand Up @@ -51,10 +50,6 @@
import org.broadinstitute.hellbender.utils.variant.writers.ShardingVCFWriter;
import org.broadinstitute.hellbender.utils.variant.writers.IntervalFilteringVcfWriter;

//TODO:
//UserException overloads
//VCF outs

/**
* Base class for all GATK tools. Tool authors that want to write a "GATK" tool but not use one of
* the pre-packaged Walker traversals should feel free to extend this class directly. All other
Expand Down Expand Up @@ -136,8 +131,7 @@ public abstract class GATKTool extends CommandLineProgram {
doc = "If true, don't emit genotype fields when writing vcf file output.", optional = true)
public boolean outputSitesOnlyVCFs = false;

public static final String VARIANT_OUTPUT_INTERVAL_FILTERING_MODE = "variant-output-interval-filtering-mode";
@Argument(fullName = VARIANT_OUTPUT_INTERVAL_FILTERING_MODE,
@Argument(fullName = StandardArgumentDefinitions.VARIANT_OUTPUT_INTERVAL_FILTERING_MODE_LONG_NAME,
doc = "Restrict the output variants to ones that match the specified intervals according to the specified matching mode.",
optional = true)
@Advanced
Expand Down Expand Up @@ -758,7 +752,7 @@ protected void onStartup() {
checkToolRequirements();

if (outputVariantIntervalFilteringMode != null && userIntervals == null){
throw new CommandLineException.MissingArgument("-L or -XL", "Intervals are required if --" + VARIANT_OUTPUT_INTERVAL_FILTERING_MODE + " was specified.");
throw new CommandLineException.MissingArgument("-L or -XL", "Intervals are required if --" + StandardArgumentDefinitions.VARIANT_OUTPUT_INTERVAL_FILTERING_MODE_LONG_NAME + " was specified.");
}

initializeProgressMeter(getProgressMeterRecordLabel());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.barclay.argparser.CommandLineException;
import org.broadinstitute.barclay.argparser.CommandLinePluginDescriptor;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.GATKPlugin.GATKAnnotationPluginDescriptor;
Expand All @@ -18,7 +19,6 @@
import org.broadinstitute.hellbender.cmdline.argumentcollections.DbsnpArgumentCollection;
import org.broadinstitute.hellbender.cmdline.programgroups.ShortVariantDiscoveryProgramGroup;
import org.broadinstitute.hellbender.engine.FeatureContext;
import org.broadinstitute.hellbender.engine.GATKTool;
import org.broadinstitute.hellbender.engine.GATKPath;
import org.broadinstitute.hellbender.engine.ReadsContext;
import org.broadinstitute.hellbender.engine.ReferenceContext;
Expand All @@ -41,13 +41,14 @@
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.variant.GATKVariantContextUtils;
import org.broadinstitute.hellbender.utils.variant.writers.IntervalFilteringVcfWriter;
import org.broadinstitute.hellbender.tools.walkers.annotator.allelespecific.ReducibleAnnotation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -129,7 +130,7 @@ public final class GenotypeGVCFs extends VariantLocusWalker {
/**
* Import all data between specified intervals. Improves performance using large lists of intervals, as in exome
* sequencing, especially if GVCF data only exists for specified intervals. Use with
* --{@value GATKTool#VARIANT_OUTPUT_INTERVAL_FILTERING_MODE} if input GVCFs contain calls outside the specified intervals.
* --{@value StandardArgumentDefinitions#VARIANT_OUTPUT_INTERVAL_FILTERING_MODE_LONG_NAME} if input GVCFs contain calls outside the specified intervals.
*/
@Argument(fullName = GenomicsDBImport.MERGE_INPUT_INTERVALS_LONG_NAME,
shortName = GenomicsDBImport.MERGE_INPUT_INTERVALS_LONG_NAME,
Expand Down Expand Up @@ -271,7 +272,7 @@ public void onTraversalStart() {

final VCFHeader inputVCFHeader = getHeaderForVariants();

final Collection<Annotation> variantAnnotations = makeVariantAnnotations();
final Collection<Annotation> variantAnnotations = makeVariantAnnotations();
final Set<Annotation> annotationsToKeep = getAnnotationsToKeep();
annotationEngine = new VariantAnnotatorEngine(variantAnnotations, dbsnp.dbsnp, Collections.emptyList(), false, keepCombined, annotationsToKeep);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,6 @@ public void apply(VariantContext variant, ReadsContext reads, ReferenceContext r
//return early if there's no non-symbolic ALT since GDB already did the merging
if ( !variant.isVariant() || !GATKVariantContextUtils.isProperlyPolymorphic(variant)
|| variant.getAttributeAsInt(VCFConstants.DEPTH_KEY,0) == 0 )
// todo this change is a slight de-optimization since we will now process some sites which were previously ignored
{
if (keepAllSites) {
VariantContextBuilder builder = new VariantContextBuilder(mqCalculator.finalizeRawMQ(variant)); //don't fill in QUAL here because there's no alt data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public enum Mode implements CommandLineParser.ClpEnum {
STARTS_IN("starts within any of the given intervals"){

@Override
boolean test(OverlapDetector<? extends Locatable> detector, final VariantContext query) {
boolean test(final OverlapDetector<? extends Locatable> detector, final VariantContext query) {
final SimpleInterval startPosition = new SimpleInterval(query.getContig(), query.getStart(), query.getStart());
return detector.overlapsAny(startPosition);
}
Expand Down Expand Up @@ -90,12 +90,13 @@ boolean test(final OverlapDetector<? extends Locatable> detector, final VariantC
* @param query The variant being tested
* @return true iff the variant matches the given intervals
*/
abstract boolean test(OverlapDetector<? extends Locatable> detector, VariantContext query);
abstract boolean test(final OverlapDetector<? extends Locatable> detector, final VariantContext query);

/**
 * Creates a mode constant carrying its help text.
 *
 * @param doc description of this mode; stored as-is and handed back by {@link #getHelpDoc()}
 */
private Mode(String doc){
this.doc = doc;

}

@Override
public String getHelpDoc() {
return doc;
Expand All @@ -111,7 +112,7 @@ public String getHelpDoc() {
* @param intervals the intervals to compare against; note that these are not merged, so if they should be merged then the input list should be preprocessed
* @param mode the matching mode to use
*/
public IntervalFilteringVcfWriter(final VariantContextWriter writer, List<SimpleInterval> intervals, Mode mode) {
public IntervalFilteringVcfWriter(final VariantContextWriter writer, final List<SimpleInterval> intervals, final Mode mode) {
Utils.nonNull(writer);
Utils.nonEmpty(intervals);
Utils.nonNull(mode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.cmdline.TestProgramGroup;
import org.broadinstitute.hellbender.testutils.ArgumentsBuilder;
import org.broadinstitute.hellbender.testutils.BaseTest;
import org.broadinstitute.hellbender.testutils.VariantContextTestUtils;
import org.broadinstitute.hellbender.tools.walkers.mutect.Mutect2;
import org.broadinstitute.hellbender.tools.walkers.variantutils.SelectVariants;
import org.broadinstitute.hellbender.testutils.VariantContextTestUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.variant.writers.IntervalFilteringVcfWriter;
Expand Down Expand Up @@ -107,6 +107,17 @@ public void testSharding() {
oneLineSummary = "Test tool",
programGroup = TestProgramGroup.class)
public static class VariantEmitter extends GATKTool{
static final SimpleInterval INT1 = new SimpleInterval("1",10, 15);
static final SimpleInterval INT2 = new SimpleInterval("1",100, 105);
static final SimpleInterval INT3 = new SimpleInterval("1",1000, 1005);
static final SimpleInterval INT4 = new SimpleInterval("1",10000, 10005);
static final SimpleInterval INT5 = new SimpleInterval("2",20, 25);
static final SimpleInterval INT6 = new SimpleInterval("2",200, 205);
static final SimpleInterval INT7 = new SimpleInterval("2",2000, 2005);
static final SimpleInterval INT8 = new SimpleInterval("2",20000, 20005);

static final List<Locatable> INTERVALS = List.of(INT1, INT2, INT3, INT4, INT5, INT6, INT7, INT8);

@Argument(fullName = StandardArgumentDefinitions.OUTPUT_LONG_NAME)
File output;

Expand All @@ -122,46 +133,48 @@ public void onTraversalStart() {
final VariantContextBuilder vcb = new VariantContextBuilder();
vcb.alleles("AAAAAA", "A").chr("1");

vcfWriter.add(vcb.start(10).stop(15).make());
vcfWriter.add(vcb.start(100).stop(105).make());
vcfWriter.add(vcb.start(1000).stop(1005).make());
vcfWriter.add(vcb.start(10000).stop(10005).make());

vcb.chr("2");
vcfWriter.add(vcb.start(20).stop(25).make());
vcfWriter.add(vcb.start(200).stop(205).make());
vcfWriter.add(vcb.start(2000).stop(2005).make());
vcfWriter.add(vcb.start(20000).stop(20005).make());
for(final Locatable interval : INTERVALS){
vcfWriter.add(vcb.loc(interval.getContig(),interval.getStart(), interval.getEnd()).make());
}
}
}
}

@DataProvider
public Object[][] getIntervalsAndOverlapMode(){
final SimpleInterval chr1Interval = new SimpleInterval("1", 101, 10001);
final SimpleInterval chr2Interval = new SimpleInterval("2", 201, 20001);
return new Object[][]{
{Arrays.asList(new SimpleInterval("1", 101, 10001), new SimpleInterval("2", 201, 20001)), IntervalFilteringVcfWriter.Mode.ANYWHERE, 8},
{Arrays.asList(new SimpleInterval("1", 101, 10001), new SimpleInterval("2", 201, 20001)), IntervalFilteringVcfWriter.Mode.OVERLAPS, 6},
{Arrays.asList(new SimpleInterval("1", 101, 10001), new SimpleInterval("2", 201, 20001)), IntervalFilteringVcfWriter.Mode.STARTS_IN, 4},
{Arrays.asList(new SimpleInterval("1", 101, 10001), new SimpleInterval("2", 201, 20001)), IntervalFilteringVcfWriter.Mode.ENDS_IN, 4},
{Arrays.asList(new SimpleInterval("1", 101, 10001), new SimpleInterval("2", 201, 20001)), IntervalFilteringVcfWriter.Mode.CONTAINED, 2},
{Arrays.asList(new SimpleInterval("1", 101, 10001), new SimpleInterval("2", 201, 20001)), null, 8},
{Arrays.asList(chr1Interval, chr2Interval), IntervalFilteringVcfWriter.Mode.ANYWHERE, VariantEmitter.INTERVALS },
{Arrays.asList(chr1Interval, chr2Interval), IntervalFilteringVcfWriter.Mode.OVERLAPS, List.of(VariantEmitter.INT2, VariantEmitter.INT3, VariantEmitter.INT4, VariantEmitter.INT6, VariantEmitter.INT7, VariantEmitter.INT8)},
{Arrays.asList(chr1Interval, chr2Interval), IntervalFilteringVcfWriter.Mode.STARTS_IN, List.of(VariantEmitter.INT3, VariantEmitter.INT4, VariantEmitter.INT7, VariantEmitter.INT8)},
{Arrays.asList(chr1Interval, chr2Interval), IntervalFilteringVcfWriter.Mode.ENDS_IN, List.of(VariantEmitter.INT2, VariantEmitter.INT3, VariantEmitter.INT6, VariantEmitter.INT7)},
{Arrays.asList(chr1Interval, chr2Interval), IntervalFilteringVcfWriter.Mode.CONTAINED, List.of(VariantEmitter.INT3, VariantEmitter.INT7)},
{Arrays.asList(chr1Interval, chr2Interval), null, VariantEmitter.INTERVALS},
};
}

@Test(dataProvider = "getIntervalsAndOverlapMode")
public void testVcfOutputFilterMode(List<? extends Locatable> intervals, IntervalFilteringVcfWriter.Mode mode, int variantsIncluded){
public void testVcfOutputFilterMode(List<? extends Locatable> intervals, IntervalFilteringVcfWriter.Mode mode, List<Locatable> expected){
final ArgumentsBuilder args = new ArgumentsBuilder();
final File out = createTempFile("out", ".vcf");
args.addOutput(out);
intervals.forEach(args::addInterval);
args.addReference(b37Reference);
if( mode != null) {
args.add(GATKTool.VARIANT_OUTPUT_INTERVAL_FILTERING_MODE, mode);
args.add(StandardArgumentDefinitions.VARIANT_OUTPUT_INTERVAL_FILTERING_MODE_LONG_NAME, mode);
}

runCommandLine(args, VariantEmitter.class.getSimpleName());
final Pair<VCFHeader, List<VariantContext>> vcfHeaderListPair = VariantContextTestUtils.readEntireVCFIntoMemory(out.toString());

Assert.assertEquals(vcfHeaderListPair.getRight().size(), variantsIncluded);
final List<VariantContext> actual = vcfHeaderListPair.getRight();
Assert.assertEquals(actual.size(), expected.size());
BaseTest.assertCondition(actual, expected, (left, right) -> {
Assert.assertEquals(left.getContig(), right.getContig());
Assert.assertEquals(left.getStart(), right.getStart());
Assert.assertEquals(left.getEnd(), right.getEnd());
} );

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.broadinstitute.barclay.argparser.CommandLineException;
import org.broadinstitute.hellbender.CommandLineProgramTest;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.engine.GATKTool;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.testutils.ArgumentsBuilder;
import org.broadinstitute.hellbender.testutils.GenomicsDBTestUtils;
Expand Down Expand Up @@ -85,7 +84,7 @@ public Object[][] gvcfsToGenotype() {
//combine not supported yet, see https://github.com/broadinstitute/gatk/issues/2429 and https://github.com/broadinstitute/gatk/issues/2584
//{"combine.single.sample.pipeline.1.vcf", null, Arrays.asList("-V", getTestFile("combine.single.sample.pipeline.2.vcf").toString() , "-V", getTestFile("combine.single.sample.pipeline.3.vcf").toString()), b37_reference_20_21},

{getTestFile("leadingDeletion.g.vcf"), getTestFile("leadingDeletionRestrictToStartExpected.vcf"), Arrays.asList("-L", "20:69512-69513", "--"+GATKTool.VARIANT_OUTPUT_INTERVAL_FILTERING_MODE, IntervalFilteringVcfWriter.Mode.STARTS_IN.toString()), b37_reference_20_21},
{getTestFile("leadingDeletion.g.vcf"), getTestFile("leadingDeletionRestrictToStartExpected.vcf"), Arrays.asList("-L", "20:69512-69513", "--"+ StandardArgumentDefinitions.VARIANT_OUTPUT_INTERVAL_FILTERING_MODE_LONG_NAME, IntervalFilteringVcfWriter.Mode.STARTS_IN.toString()), b37_reference_20_21},
{getTestFile("leadingDeletion.g.vcf"), getTestFile("leadingDeletionExpected.vcf"), Arrays.asList("-L", "20:69512-69513"), b37_reference_20_21},
{getTestFile(BASE_PAIR_GVCF), getTestFile( BASE_PAIR_EXPECTED), NO_EXTRA_ARGS, b37_reference_20_21}, //base pair level gvcf
{getTestFile("testUpdatePGT.gvcf"), getTestFile( "testUpdatePGT.gatk3.7_30_ga4f720357.output.vcf"), NO_EXTRA_ARGS, b37_reference_20_21}, //testUpdatePGT
Expand Down Expand Up @@ -435,7 +434,7 @@ public void testGenotypeGVCFsMultiIntervalGDBQuery(File input, File expected, Li
args.addOutput(output);
intervals.forEach(args::addInterval);
args.add(GenomicsDBImport.MERGE_INPUT_INTERVALS_LONG_NAME, true);
args.add(GATKTool.VARIANT_OUTPUT_INTERVAL_FILTERING_MODE, IntervalFilteringVcfWriter.Mode.STARTS_IN); //note that this will restrict calls to just the specified intervals
args.add(StandardArgumentDefinitions.VARIANT_OUTPUT_INTERVAL_FILTERING_MODE_LONG_NAME, IntervalFilteringVcfWriter.Mode.STARTS_IN); //note that this will restrict calls to just the specified intervals

runAndCheckGenomicsDBOutput(args, expected, output);

Expand Down Expand Up @@ -572,7 +571,7 @@ public void testIntervalsAndOnlyOutputCallsStartingInIntervalsAreMutuallyRequire
.addVCF(getTestFile("leadingDeletion.g.vcf"))
.addReference(new File(b37_reference_20_21))
.addOutput( createTempFile("tmp",".vcf"))
.add(GATKTool.VARIANT_OUTPUT_INTERVAL_FILTERING_MODE, IntervalFilteringVcfWriter.Mode.STARTS_IN);
.add(StandardArgumentDefinitions.VARIANT_OUTPUT_INTERVAL_FILTERING_MODE_LONG_NAME, IntervalFilteringVcfWriter.Mode.STARTS_IN);

Assert.assertThrows(CommandLineException.MissingArgument.class, () -> runCommandLine(args));
args.add("L", "20:69512-69513");
Expand Down
Loading

0 comments on commit 578a951

Please sign in to comment.