diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/PartiallyDeterminedHaplotypeComputationEngine.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/PartiallyDeterminedHaplotypeComputationEngine.java index 0f6a8ad90b2..c96d51e4e12 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/PartiallyDeterminedHaplotypeComputationEngine.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/PartiallyDeterminedHaplotypeComputationEngine.java @@ -1,8 +1,7 @@ package org.broadinstitute.hellbender.tools.walkers.haplotypecaller; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Range; -import com.google.common.collect.Sets; +import com.google.common.collect.*; import htsjdk.samtools.CigarElement; import htsjdk.samtools.CigarOperator; import htsjdk.samtools.util.Locatable; @@ -11,6 +10,8 @@ import org.apache.commons.lang3.ArrayUtils; import org.broadinstitute.gatk.nativebindings.smithwaterman.SWOverhangStrategy; import org.broadinstitute.gatk.nativebindings.smithwaterman.SWParameters; +import org.broadinstitute.hellbender.exceptions.GATKException; +import org.broadinstitute.hellbender.utils.SmallBitSet; import org.broadinstitute.hellbender.utils.Utils; import org.broadinstitute.hellbender.utils.haplotype.Event; import org.broadinstitute.hellbender.utils.haplotype.EventMap; @@ -97,18 +98,16 @@ public static AssemblyResultSet generatePDHaplotypes(final AssemblyResultSet sou SortedMap> variantsByStartPos = eventsInOrder.stream() .collect(Collectors.groupingBy(Event::getStart, TreeMap::new, Collectors.toList())); - List> disallowedPairs = smithWatermanRealignPairsOfVariantsForEquivalentEvents(referenceHaplotype, aligner, args.getHaplotypeToReferenceSWParameters(), debug, eventsInOrder); - dragenDisallowedGroupsMessage(referenceHaplotype.getStart(), debug, disallowedPairs); - - final List eventGroups = getEventGroupClusters(eventsInOrder, disallowedPairs); - Utils.printIf(debug,() -> "Event groups after merging:\n"+eventGroups.stream().map(eg -> eg.toDisplayString(referenceHaplotype.getStart())).collect(Collectors.joining("\n"))); + List> disallowedCombinations = smithWatermanRealignPairsOfVariantsForEquivalentEvents(referenceHaplotype, aligner, args.getHaplotypeToReferenceSWParameters(), debug, eventsInOrder); + dragenDisallowedGroupsMessage(referenceHaplotype.getStart(), debug, disallowedCombinations); + final List eventGroups = getEventGroupClusters(eventsInOrder, disallowedCombinations); // if any of our merged event groups is too large, abort. - if (eventGroups.stream().anyMatch(eg -> eg.size() > MAX_VAR_IN_EVENT_GROUP)) { + if (eventGroups == null) { Utils.printIf(debug, () -> "Found event group with too many variants! Aborting haplotype building"); return sourceSet; } - eventGroups.forEach(eg -> eg.populateBitset(disallowedPairs)); + Utils.printIf(debug,() -> "Event groups after merging:\n"+eventGroups.stream().map(eg -> eg.toDisplayString(referenceHaplotype.getStart())).collect(Collectors.joining("\n"))); Set outputHaplotypes = new LinkedHashSet<>(); if (pileupArgs.determinePDHaps) { @@ -137,6 +136,7 @@ We iterate over every ref position and select single alleles (including ref) fro */ for (int determinedAlleleIndex = (pileupArgs.determinePDHaps?0:-1); determinedAlleleIndex < allEventsHere.size(); determinedAlleleIndex++) { //note -1 for I here corresponds to the reference allele at this site final boolean isRef = determinedAlleleIndex == -1; + final Set determinedEvents = isRef ? Set.of() : Set.of(allEventsHere.get(determinedAlleleIndex)); final Event determinedEventToTest = allEventsHere.get(isRef ? 0 : determinedAlleleIndex); Utils.printIf(debug, () -> "Working with allele at site: "+(isRef? "[ref:"+(thisEventGroupStart-referenceHaplotype.getStart())+"]" : PartiallyDeterminedHaplotype.getDRAGENDebugEventString(referenceHaplotype.getStart()).apply(determinedEventToTest))); // This corresponds to the DRAGEN code for @@ -162,19 +162,17 @@ We iterate over every ref position and select single alleles (including ref) fro */ for(EventGroup group : eventGroups ) { if (group.causesBranching()) { - List>> groupVCs = group.getVariantGroupsForEvent(allEventsHere, determinedAlleleIndex, true); + List> branchingSets = group.setsForBranching(allEventsHere, determinedEvents, true); // Combinatorially expand the branches as necessary List> newBranchesToAdd = new ArrayList<>(); for (Set excludedVars : branchExcludeAlleles) { //For every exclude group, fork it by each subset we have: - for (int i = 1; i < groupVCs.size(); i++) { //NOTE: iterate starting at 1 here because we special case that branch at the end - Set newSet = new HashSet<>(excludedVars); - groupVCs.get(i).stream().filter(t -> !t.b).forEach(t -> newSet.add(t.a)); - newBranchesToAdd.add(newSet); + for (int i = 1; i < branchingSets.size(); i++) { //NOTE: iterate starting at 1 here because we special case that branch at the end + newBranchesToAdd.add(Sets.union(excludedVars, branchingSets.get(i)).immutableCopy()); } // Be careful since this event group might have returned nothing - if (!groupVCs.isEmpty()) { - groupVCs.get(0).stream().filter(t -> !t.b).forEach(t -> excludedVars.add(t.a)); + if (!branchingSets.isEmpty()) { + excludedVars.addAll(branchingSets.get(0)); } } branchExcludeAlleles.addAll(newBranchesToAdd); @@ -387,6 +385,9 @@ private static List> smithWatermanRealignPairsOfVariantsForEquivalen * Partition events into clusters that must be considered together, either because they overlap or because they belong to the * same mutually exclusive pair or trio. To find this clustering we calculate the connected components of an undirected graph * with an edge connecting events that overlap or are mutually excluded. + * + * return null if any event group exceeds the allowed size -- this tells the calling code to fall back on the original + * GATK assembly. */ private static List getEventGroupClusters(List eventsInOrder, List> disallowedPairsAndTrios) { final Graph graph = new SimpleGraph<>(DefaultEdge.class); @@ -411,9 +412,9 @@ private static List getEventGroupClusters(List eventsInOrder, } } - return new ConnectivityInspector<>(graph).connectedSets().stream() - .map(set -> set.stream().sorted(HAPLOTYPE_SNP_FIRST_COMPARATOR).toList()).map(EventGroup::new) - .toList(); + final List> components = new ConnectivityInspector<>(graph).connectedSets(); + return components.stream().anyMatch(comp -> comp.size() > MAX_VAR_IN_EVENT_GROUP) ? null : + components.stream().map(component -> new EventGroup(component, disallowedPairsAndTrios)).toList(); } /** @@ -649,22 +650,25 @@ static PartiallyDeterminedHaplotype createNewPDHaplotypeFromEvents(final Haploty // A helper class for managing mutually exclusive event clusters and the logic arround forming valid events vs eachother. private static class EventGroup { - List eventsInBitmapOrder; - HashSet eventSet; - //From Illumina (there is a LOT of math that will eventually go into these)/ - BitSet allowedEvents = null; + private final ImmutableList eventsInOrder; + private final ImmutableMap eventIndices; + private final BitSet allowedEvents; // Optimization to save ourselves recomputing the subsets at every point its necessary to do so. - List>> cachedEventLists = null; - - public EventGroup(final Collection events) { - eventsInBitmapOrder = new ArrayList<>(); - eventSet = new HashSet<>(); - - for (final Event event : events) { - eventsInBitmapOrder.add(event); - eventSet.add(event); + List> cachedEventSets = null; + + public EventGroup(final Collection events, List> disallowedCombinations) { + Utils.validate(events.size() <= MAX_VAR_IN_EVENT_GROUP, () -> "Too many events (" + events.size() + ") for populating bitset."); + eventsInOrder = events.stream().sorted(HAPLOTYPE_SNP_FIRST_COMPARATOR).collect(ImmutableList.toImmutableList()); + eventIndices = IntStream.range(0, events.size()).boxed().collect(ImmutableMap.toImmutableMap(eventsInOrder::get, n -> n)); + allowedEvents = new BitSet(1 << eventsInOrder.size()); + + final List> overlappingMutexes = disallowedCombinations.stream() + .filter(mutext -> mutext.stream().anyMatch(eventIndices::containsKey)).toList(); + for (final List mutex : overlappingMutexes) { + Utils.validate(mutex.stream().allMatch(eventIndices::containsKey), () -> "Mutex group " + mutex + " only partially overlaps event group " + this); } + populateBitset(overlappingMutexes); } /** @@ -679,61 +683,41 @@ public EventGroup(final Collection events) { * Iterate through pairs of Variants that overlap and mark off any pairings including this. * Iterate through the mutex variants and ensure pairs containing all mutex variant groups are marked as true * - * @param disallowedEvents Pairs of events disallowed + * @param mutexes Groups of mutually forbidden events. Note that when this is called we have already ensured + * that each mutex group comprises only events contained in this EventGroup. */ - public void populateBitset(List> disallowedEvents) { - Utils.validate(size() <= MAX_VAR_IN_EVENT_GROUP, () -> "Too many events (" + size() + ") for populating bitset."); - if (eventsInBitmapOrder.size() < 2) { + private void populateBitset(List> mutexes) { + if (eventsInOrder.size() < 2) { return; } - allowedEvents = new BitSet(eventsInBitmapOrder.size()); - allowedEvents.flip(1, 1 << eventsInBitmapOrder.size()); // initialize all events as being allowed and then disallow them in turn . + allowedEvents.set(1, 1 << eventsInOrder.size()); - // Ensure the list is in positional order before commencing. - eventsInBitmapOrder.sort(HAPLOTYPE_SNP_FIRST_COMPARATOR); - List bitmasks = new ArrayList<>(); + final List forbiddenCombinations = new ArrayList<>(); // Mark as disallowed all events that overlap each other, excluding pairs of SNPs - for (int i = 0; i < eventsInBitmapOrder.size(); i++) { - final Event first = eventsInBitmapOrder.get(i); - for (int j = i+1; j < eventsInBitmapOrder.size(); j++) { - final Event second = eventsInBitmapOrder.get(j); + for (int i = 0; i < eventsInOrder.size(); i++) { + final Event first = eventsInOrder.get(i); + for (int j = i+1; j < eventsInOrder.size(); j++) { + final Event second = eventsInOrder.get(j); if (!(first.isSNP() && second.isSNP()) && eventsOverlapForPDHapsCode(first, second)) { - bitmasks.add(1 << i | 1 << j); + forbiddenCombinations.add(new SmallBitSet(i,j)); } } } - // mark as disallowed any sets of variants from the bitmask. - for (List disallowed : disallowedEvents) { - // - if (disallowed.stream().anyMatch(v -> eventSet.contains(v))){ - int bitmask = 0; - for (Event v : disallowed) { - int indexOfV = eventsInBitmapOrder.indexOf(v); - if (indexOfV < 0) { - throw new RuntimeException("Something went wrong in event group merging, variant "+v+" is missing from the event group despite being in a mutex pair: "+disallowed+"\n"+this); - } - bitmask += 1 << eventsInBitmapOrder.indexOf(v); - } - bitmasks.add(bitmask); - } - } - // Now iterate through the list and disallow all events with every bitmask + // make SmallBitSet of the event indices of each mutex + mutexes.stream().map(mutex -> new SmallBitSet(mutex.stream().map(eventIndices::get).toList())).forEach(forbiddenCombinations::add); + + // Now forbid all subsets that contain forbidden combinations //TODO This method is potentially very inefficient! We don't technically have to iterate over every i, //TODO we know there is an optimization involving minimizing the number of checks necessary here by iterating //TODO using the bitmask values themselves for the loop - if (!bitmasks.isEmpty()) { - events: - for (int i = 1; i < allowedEvents.length(); i++) { - for (final int mask : bitmasks) { - if ((i & mask) == mask) { // are the bits form the mask true? - allowedEvents.set(i, false); - continue events; - // Once i is set we don't need to keep checking bitmasks - } + if (!forbiddenCombinations.isEmpty()) { + for (final SmallBitSet subset = new SmallBitSet().increment(); !subset.hasElementGreaterThan(eventsInOrder.size()); subset.increment()) { + if (forbiddenCombinations.stream().anyMatch(subset::contains)) { + allowedEvents.set(subset.index(), false); } } } @@ -745,79 +729,60 @@ public void populateBitset(List> disallowedEvents) { * @param disallowSubsets * @return */ - public List>> getVariantGroupsForEvent(final List allEventsHere, final int determinedAlleleIndex, final boolean disallowSubsets) { - // If we are dealing with an external to this list event - int eventMask = 0; - int maskValues = 0; - for(int i = 0; i < allEventsHere.size(); i++) { - if (eventSet.contains(allEventsHere.get(i))) { - int index = eventsInBitmapOrder.indexOf(allEventsHere.get(i)); - eventMask = eventMask | (1 << index); - maskValues = maskValues | ((i == determinedAlleleIndex ? 1 : 0) << index); - } - } + public List> setsForBranching(final List locusEvents, final Set determinedEvents, final boolean disallowSubsets) { + final SmallBitSet locusOverlapSet = overlapSet(locusEvents); + final SmallBitSet determinedOverlapSet = overlapSet(determinedEvents); + // Special case (if we are determining bases outside of this mutex cluster we can reuse the work from previous iterations) - if (eventMask == 0 && cachedEventLists != null) { - return cachedEventLists; + if (locusOverlapSet.isEmpty() && cachedEventSets != null) { + return cachedEventSets; } - List ints = new ArrayList<>(); - // Iterate from the BACK of the list (i.e. ~supersets -> subsets) + final List allowedAndDetermined = new ArrayList<>(); + // Iterate from the full set (containing every event) to the empty set (no events), which lets us output the largest possible subsets // NOTE: we skip over 0 here since that corresponds to ref-only events, handle those externally to this code - outerLoop: - for (int i = allowedEvents.length(); i > 0; i--) { - // If the event is allowed AND if we are looking for a particular event to be present or absent. - if (allowedEvents.get(i) && (eventMask == 0 || ((i & eventMask) == maskValues))) { + for (final SmallBitSet subset = SmallBitSet.fullSet(eventsInOrder.size()); !subset.isEmpty(); subset.decrement()) { + if (allowedEvents.get(subset.index()) && subset.intersection(locusOverlapSet).equals(determinedOverlapSet)) { // Only check for subsets if we need to - if (disallowSubsets) { - for (Integer group : ints) { - // if the current i is a subset of an existing group - if ((i | group) == group) { - continue outerLoop; - } - } + if (!disallowSubsets || allowedAndDetermined.stream().noneMatch(group -> group.contains(subset))) { + allowedAndDetermined.add(subset.copy()); // copy subset since the decrement() mutates it in-place } - ints.add(i); } } // Now that we have all the mutex groups, unpack them into lists of variants - List>> output = new ArrayList<>(); - for (Integer grp : ints) { - List> newGrp = new ArrayList<>(); - for (int i = 0; i < eventsInBitmapOrder.size(); i++) { - // if the corresponding bit is 1, set it as such, otherwise set it as 0. - newGrp.add(new Tuple<>(eventsInBitmapOrder.get(i), ((1<> output = new ArrayList<>(); + for (SmallBitSet grp : allowedAndDetermined) { + Set newGrp = new HashSet<>(); + for (int i = 0; i < eventsInOrder.size(); i++) { + if (!grp.get(i)) { + newGrp.add(eventsInOrder.get(i)); + } } output.add(newGrp); } // Cache the result - if(eventMask==0) { - cachedEventLists = Collections.unmodifiableList(output); + if(locusOverlapSet.isEmpty()) { + cachedEventSets = Collections.unmodifiableList(output); } return output; } + // create the SmallBitSet of those elements from some collection of events that overlap this EventGroup + private SmallBitSet overlapSet(final Collection events) { + return new SmallBitSet(events.stream().map(e -> eventIndices.getOrDefault(e, -1)).filter(n -> n != -1).toList()); + } + public boolean causesBranching() { - return eventsInBitmapOrder.size() > 1; + return eventsInOrder.size() > 1; } //Print The event group in Illumina indexed ordering: public String toDisplayString(int startPos) { - return "EventGroup: " + formatEventsLikeDragenLogs(eventsInBitmapOrder, startPos); - } - - public boolean contains(final Event event) { - return eventSet.contains(event); + return "EventGroup: " + formatEventsLikeDragenLogs(eventsInOrder, startPos); } - public int size() { return eventsInBitmapOrder.size(); } - - public void addEvent(final Event event) { - eventsInBitmapOrder.add(event); - eventSet.add(event); - allowedEvents = null; - } + public int size() { return eventsInOrder.size(); } } private static List growEventGroup(final List group, final Event event) { diff --git a/src/main/java/org/broadinstitute/hellbender/utils/SmallBitSet.java b/src/main/java/org/broadinstitute/hellbender/utils/SmallBitSet.java new file mode 100644 index 00000000000..80b75d3562b --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/utils/SmallBitSet.java @@ -0,0 +1,152 @@ +package org.broadinstitute.hellbender.utils; + +import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeAlleleCounts; +import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeIndexCalculator; +import org.broadinstitute.hellbender.utils.param.ParamUtils; + +import java.util.Collection; +import java.util.Iterator; + +/** + * Small BitSet with a capacity of 30 elements, corresponding to the number of bits in an int. + * + * Union and intersection are implemented as extremely fast bitwise | and & operators + * + * This class is very much like the standard library BitSet class in java.util, but instead of using + * a long[] array it uses a single (32-bit) int, which is much faster. This limits the capacity but + * for small sets it makes sense. + * + * This class uses the binary representation of sets, in which set {i1, i2, . . .in} maps to the + * integer 2^(i_1) + 2^(i_2) . . . + 2^(i_n). + * + * This is efficient because 2^n = 1 << n is simply the integer 1 with n bit-shifts, and because + * addition of different powers of 2 is equivalent to the bitwise OR. + * + * Equivalently, the singleton set n in binary is represented as the integer where the nth bit is 1 + * and all others are zero, and the set {i1, i2, . . .in} is represented as the integer where + * bits i1, i2. . .in are 1 and all others are zero. + * + */ +public class SmallBitSet { + public static final int MAX_ELEMENTS = 30; + private int bits; + + public SmallBitSet() { bits = 0;} + + public SmallBitSet copy() { + final SmallBitSet result = new SmallBitSet(); + result.bits = this.bits; + return result; + } + + // construct a singleton set + public SmallBitSet(final int element) { + bits = elementIndex(validateElement(element)); + } + + // construct a two-element set + public SmallBitSet(final int element1, final int element2) { + bits = elementIndex(validateElement(element1)) | elementIndex(validateElement(element2)); + } + + // construct a three-element set + public SmallBitSet(final int element1, final int element2, final int element3) { + bits = elementIndex(validateElement(element1)) | elementIndex(validateElement(element2)) | elementIndex(validateElement(element3)); + } + + public SmallBitSet(final Collection elements) { + bits = 0; + for (final int element : elements) { + bits |= elementIndex(validateElement(element)); + } + } + + // create a full bit set of all 1s in binary up to a certain number of elements i.e. 00000000000111111.... + public static SmallBitSet fullSet(final int numElements) { + validateElement(numElements); + final SmallBitSet result = new SmallBitSet(); + result.bits = (1 << numElements) - 1; + return result; + } + + // convert to the next bitset in the canonical ordering, which conveniently is just adding 1 to the underlying int. + // Useful for iterating over all possible subsets in order from empty to full. + // Calling code is responsible for starting iteration at 0 (empty bitset) and stopping iteration at 2^n - 1 for a full bitset of n elements. + public SmallBitSet increment() { + bits++; + return this; + } + + // same as above, but in the reverse order. Useful for iterating from a full bitset to the empty bitset. + public SmallBitSet decrement() { + bits--; + return this; + } + + // the bits as an integer define a unique index within the set of bitsets + // that is, bitsets can be enumerated as {}, {0}, {1}, {0, 1}, {2}, {0, 2} . . . + public int index() { return bits; } + + // intersection is equivalent to bitwise AND + public SmallBitSet intersection(final SmallBitSet other) { + final SmallBitSet result = new SmallBitSet(); + result.bits = this.bits & other.bits; + return result; + } + + // union is equivalent to bitwise OR + public SmallBitSet union(final SmallBitSet other) { + final SmallBitSet result = new SmallBitSet(); + result.bits = this.bits | other.bits; + return result; + } + + public boolean contains(final SmallBitSet other) { + return (this.bits & other.bits) == other.bits; + } + + public void add(final int element) { + bits |= elementIndex(element); + } + + public void remove(final int element) { + bits &= ~(elementIndex(element)); + } + + public void flip(final int element) { + bits ^= elementIndex(element); + } + + public boolean get(final int element) { + return (bits & elementIndex(element)) != 0; + } + + public boolean isEmpty() { return bits == 0; } + + public boolean hasElementGreaterThan(final int element) { return bits >= 1 << element; } + + private static int elementIndex(final int element) { + return 1 << element; + } + + @Override + public boolean equals(Object o) { + if (o == this) + return true; + if (!(o instanceof SmallBitSet)) + return false; + SmallBitSet other = (SmallBitSet) o; + return other.bits == this.bits; + } + + @Override + public int hashCode() { + return bits; + } + + private static int validateElement(final int element) { + ParamUtils.inRange(element, 0, MAX_ELEMENTS - 1, "Element indices must be non-negative and less than max capacity of SmallBitSet."); + return element; + } + +} \ No newline at end of file diff --git a/src/test/java/org/broadinstitute/hellbender/utils/SmallBitSetUnitTest.java b/src/test/java/org/broadinstitute/hellbender/utils/SmallBitSetUnitTest.java new file mode 100644 index 00000000000..beb4ad6c924 --- /dev/null +++ b/src/test/java/org/broadinstitute/hellbender/utils/SmallBitSetUnitTest.java @@ -0,0 +1,100 @@ +package org.broadinstitute.hellbender.utils; + +import com.google.common.collect.Streams; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.List; + +public class SmallBitSetUnitTest { + private static final SmallBitSet EMPTY = new SmallBitSet(); + private static final int ELEMENT_1 = 5; + private static final int ELEMENT_2 = 13; + private static final int ELEMENT_3 = 7; + private static final int DIFFERENT_ELEMENT = 20; + private static final SmallBitSet ONE_ELEMENT = new SmallBitSet(ELEMENT_1); + private static final SmallBitSet TWO_ELEMENTS = new SmallBitSet(ELEMENT_1, ELEMENT_2); + private static final SmallBitSet THREE_ELEMENTS = new SmallBitSet(ELEMENT_1, ELEMENT_2, ELEMENT_3); + + @Test + public void testConstructors() { + for (int i = 0; i < SmallBitSet.MAX_ELEMENTS; i++) { + Assert.assertFalse(EMPTY.get(i)); + Assert.assertEquals(i == ELEMENT_1, ONE_ELEMENT.get(i)); + Assert.assertEquals(i == ELEMENT_1 || i == ELEMENT_2, TWO_ELEMENTS.get(i)); + Assert.assertEquals(i == ELEMENT_1 || i == ELEMENT_2 || i == ELEMENT_3, THREE_ELEMENTS.get(i)); + } + } + + @Test + public void testIntersection() { + Assert.assertEquals((new SmallBitSet(1,3,5)).intersection(new SmallBitSet(7,9,11)), EMPTY); + + Assert.assertEquals((new SmallBitSet(1,3,5)).intersection(new SmallBitSet(5,7,9)), + (new SmallBitSet(5))); + + Assert.assertEquals((new SmallBitSet(List.of(0, 20, 5, 15, 10))) + .intersection(new SmallBitSet(List.of(0, 7, 21, 14, 20))), new SmallBitSet(0, 20)); + } + + @Test + public void testUnion() { + Assert.assertEquals((new SmallBitSet(1,3,5)).union(new SmallBitSet(7,9,11)), + new SmallBitSet(List.of(1,3,5,7,9,11))); + + Assert.assertEquals((new SmallBitSet(1,3,5)).union(new SmallBitSet(5,7,9)), + new SmallBitSet(List.of(1,3,5,7,9))); + } + + @Test + public void testContains() { + Assert.assertTrue(THREE_ELEMENTS.contains(TWO_ELEMENTS)); + Assert.assertTrue(THREE_ELEMENTS.contains(ONE_ELEMENT)); + Assert.assertTrue(TWO_ELEMENTS.contains(ONE_ELEMENT)); + Assert.assertTrue(THREE_ELEMENTS.contains(THREE_ELEMENTS)); + Assert.assertTrue(TWO_ELEMENTS.contains(TWO_ELEMENTS)); + Assert.assertTrue(ONE_ELEMENT.contains(ONE_ELEMENT)); + Assert.assertFalse(TWO_ELEMENTS.contains(THREE_ELEMENTS)); + Assert.assertFalse(ONE_ELEMENT.contains(THREE_ELEMENTS)); + Assert.assertFalse(ONE_ELEMENT.contains(TWO_ELEMENTS)); + + Assert.assertFalse(THREE_ELEMENTS.contains(new SmallBitSet(ELEMENT_1, ELEMENT_2, DIFFERENT_ELEMENT))); + } + + @Test + public void testAdd() { + final SmallBitSet copy = EMPTY.copy(); + copy.add(ELEMENT_1); + Assert.assertEquals(copy, ONE_ELEMENT); + copy.add(ELEMENT_2); + Assert.assertEquals(copy, TWO_ELEMENTS); + copy.add(ELEMENT_3); + Assert.assertEquals(copy, THREE_ELEMENTS); + } + + @Test + public void testRemove() { + final SmallBitSet copy = THREE_ELEMENTS.copy(); + copy.remove(ELEMENT_3); + Assert.assertEquals(copy, TWO_ELEMENTS); + copy.remove(ELEMENT_2); + Assert.assertEquals(copy, ONE_ELEMENT); + copy.remove(ELEMENT_1); + Assert.assertEquals(copy, EMPTY); + } + + @Test + public void testFlip() { + final SmallBitSet bitset = new SmallBitSet(List.of(5, 10, 15, 20)); + bitset.flip(10); + Assert.assertEquals(bitset, new SmallBitSet(5, 15, 20)); + bitset.flip(7); + Assert.assertEquals(bitset, new SmallBitSet(List.of(5, 7, 15, 20))); + } + + @Test + public void testGet() { + List.of(ELEMENT_1, ELEMENT_2, ELEMENT_3).forEach(el -> Assert.assertTrue(THREE_ELEMENTS.get(el))); + Assert.assertFalse(THREE_ELEMENTS.get(DIFFERENT_ELEMENT)); + } +} \ No newline at end of file