Skip to content

Commit

Permalink
Adding tests and further refactors, note that I've disovered a potent…
Browse files Browse the repository at this point in the history
…ial bug, see the todo in the tests
  • Loading branch information
lbergelson committed Dec 6, 2023
1 parent 63593fc commit 1df4aa2
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*
* @author Samuel Lee <[email protected]>
*/
public final class SimpleCount implements Locatable, Feature {
public class SimpleCount implements Locatable, Feature {

private final SimpleInterval interval;
private final int count;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import htsjdk.samtools.util.OverlapDetector;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.SimpleCount;
import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypingEngine;
import org.broadinstitute.hellbender.utils.Utils;

import java.util.LinkedHashMap;
import java.util.Map;
Expand All @@ -20,11 +21,18 @@ public final class MultiPloidyGenotyperCache<T extends GenotypingEngine> {
private final OverlapDetector<SimpleCount> ploidyRegions;
private final int defaultPloidy;

public MultiPloidyGenotyperCache(IntFunction<T> ploidyToGenotyper, int defaultPloidy, OverlapDetector<SimpleCount> ploidyRegions){

/**
* Create a new genotyper cache
* @param ploidyToGenotyper a function to generate a new GenotypingEngine given a ploidy
* @param defaultPloidy the default ploidy value
* @param alternatePloidyRegions a set of regions with alternate ploidys
*/
public MultiPloidyGenotyperCache(IntFunction<T> ploidyToGenotyper, int defaultPloidy, OverlapDetector<SimpleCount> alternatePloidyRegions){
this.ploidyRegions = Utils.nonNull(alternatePloidyRegions);
this.ploidyToGenotyper = Utils.nonNull(ploidyToGenotyper);
this.defaultPloidy = defaultPloidy;
this.ploidyRegions = ploidyRegions;
this.ploidyToGenotyperMap = new LinkedHashMap<>();
this.ploidyToGenotyper = ploidyToGenotyper;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ public class RampedHaplotypeCallerEngine extends HaplotypeCallerEngine {
// haplotype caller phases, as consumers
final private List<Consumer<CallRegionContext>> phases =
Arrays.asList(
this::prepare,
this::assemble,
this::computeReadLikelihoods,
this::uncollapse,
this::filter,
this::genotype
p -> prepare(p),
p -> assemble(p),
p -> computeReadLikelihoods(p),
p -> uncollapse(p),
p -> filter(p),
p -> genotype(p)
);

public RampedHaplotypeCallerEngine(final HaplotypeCallerArgumentCollection hcArgs, AssemblyRegionArgumentCollection assemblyRegionArgs, boolean createBamOutIndex,
Expand Down Expand Up @@ -145,7 +145,7 @@ private void tearRamps() {
throw new RuntimeException(e);
}
}
private static class CallRegionContext {
private class CallRegionContext {

// params
final AssemblyRegion region;
Expand Down Expand Up @@ -232,7 +232,7 @@ private void prepare(final CallRegionContext context) {

context.VCpriors = new ArrayList<>();
if (hcArgs.standardArgs.genotypeArgs.supportVariants != null) {
context.VCpriors.addAll(context.features.getValues(hcArgs.standardArgs.genotypeArgs.supportVariants));
context.features.getValues(hcArgs.standardArgs.genotypeArgs.supportVariants).stream().forEach(context.VCpriors::add);
}

if (hcArgs.sampleNameToUse != null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package org.broadinstitute.hellbender.tools.walkers.haplotypecaller;

import htsjdk.samtools.util.OverlapDetector;
import org.broadinstitute.hellbender.GATKBaseTest;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.SimpleCount;
import org.broadinstitute.hellbender.tools.walkers.genotyper.MinimalGenotypingEngine;
import org.broadinstitute.hellbender.tools.walkers.genotyper.StandardCallerArgumentCollection;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.genotyper.SampleList;
import org.jetbrains.annotations.NotNull;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.util.ArrayList;
import java.util.List;

import static org.testng.Assert.*;

public class MultiPloidyGenotyperCacheUnitTest extends GATKBaseTest {


@DataProvider
public static Object[][] getGenotyperParams() {
return new Object[][]{
{"1:10-30", 2},
{"1:70-85", 3},
{"1:80-93", 3},
// {"1:89-92", 2}, TODO is this a bug?
{"1:90-95", 1},
{"1:85-120", 5},
{"2:200", 2},
{"2:1-500", 4}
};
}

@Test(dataProvider = "getGenotyperParams")
public void testGettingGenotypers(String interval, int expectedPloidy){
final var detector = getOverlapDetector();
final MultiPloidyGenotyperCache<MinimalGenotypingEngine> genotypingCache = new MultiPloidyGenotyperCache<>(this::getEngine, 2, detector);
final MinimalGenotypingEngine genotypingEngine = genotypingCache.getGenotypingEngine(new SimpleInterval(interval));
Assert.assertEquals(genotypingEngine.getGenotypeArgs().samplePloidy, expectedPloidy);
}

@NotNull
private OverlapDetector<SimpleCount> getOverlapDetector() {
final var counts = List.of(
getCount("1:80", 3),
getCount("1:90-95", 1),
getCount("1:100-110", 5),
getCount("2:85-90", 4));
return OverlapDetector.create(counts);
}

@Test
public void testReturnsTheSameGenotyper(){
final var detector = getOverlapDetector();
final MultiPloidyGenotyperCache<MinimalGenotypingEngine> genotypingCache = new MultiPloidyGenotyperCache<>(this::getEngine, 2, detector);
final MinimalGenotypingEngine genotypingEngine = genotypingCache.getGenotypingEngine(new SimpleInterval("1:10-20"));
final MinimalGenotypingEngine genotypingEngine2 = genotypingCache.getGenotypingEngine(new SimpleInterval("2:10-20"));
Assert.assertSame(genotypingEngine, genotypingEngine2);
}


private MinimalGenotypingEngine getEngine(int ploidy){
final var args = new StandardCallerArgumentCollection();
args.genotypeArgs.samplePloidy = ploidy;
final var samples = SampleList.singletonSampleList("sample1");
return new MinimalGenotypingEngine(args, samples);
}

private static SimpleCount getCount(String location, int count){
return new SimpleCount(new SimpleInterval(location), count);
}

}

0 comments on commit 1df4aa2

Please sign in to comment.