From 0cb08618165db5228128f5c71194ca0e2cf0ed24 Mon Sep 17 00:00:00 2001 From: David Benjamin Date: Tue, 27 Jun 2023 12:55:52 -0400 Subject: [PATCH] graph method for PDHMM event groups that unifies finding/merging and overlap/mutual exclusion (#8366) --- ...yDeterminedHaplotypeComputationEngine.java | 82 +++++++++---------- 1 file changed, 40 insertions(+), 42 deletions(-) diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/PartiallyDeterminedHaplotypeComputationEngine.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/PartiallyDeterminedHaplotypeComputationEngine.java index f4fb70d66a4..2306144c2f7 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/PartiallyDeterminedHaplotypeComputationEngine.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/haplotypecaller/PartiallyDeterminedHaplotypeComputationEngine.java @@ -20,6 +20,10 @@ import org.broadinstitute.hellbender.utils.read.CigarBuilder; import org.broadinstitute.hellbender.utils.read.CigarUtils; import org.broadinstitute.hellbender.utils.smithwaterman.SmithWatermanAligner; +import org.jgrapht.Graph; +import org.jgrapht.alg.ConnectivityInspector; +import org.jgrapht.graph.DefaultEdge; +import org.jgrapht.graph.SimpleGraph; import java.util.*; import java.util.stream.Collectors; @@ -94,44 +98,11 @@ public static AssemblyResultSet generatePDHaplotypes(final AssemblyResultSet sou .collect(Collectors.groupingBy(e -> dragenStart(e), LinkedHashMap::new, Collectors.toList())); SortedMap> variantsByStartPos = eventsInOrder.stream() .collect(Collectors.groupingBy(Event::getStart, TreeMap::new, Collectors.toList())); - List eventGroups = new ArrayList<>(); - int lastEventEnd = -1; - for (Event vc : eventsInOrder) { - // Break everything into independent groups (don't worry about transitivitiy right now) - Double eventKey = dragenStart(vc) - referenceHaplotype.getStart(); - if (eventKey <= lastEventEnd + 0.5) { - eventGroups.get(eventGroups.size()-1).addEvent(vc); - } else { - eventGroups.add(new EventGroup(vc)); - } - int newEnd = (vc.getEnd() - referenceHaplotype.getStart()); - lastEventEnd = Math.max(newEnd, lastEventEnd); - } - eventGroupsMessage(referenceHaplotype, debug, eventsByDRAGENCoordinates); - - // Iterate over all events starting with all indels List> disallowedPairs = smithWatermanRealignPairsOfVariantsForEquivalentEvents(referenceHaplotype, aligner, args.getHaplotypeToReferenceSWParameters(), debug, eventsInOrder); dragenDisallowedGroupsMessage(referenceHaplotype.getStart(), debug, disallowedPairs); - Utils.printIf(debug, () -> "Event groups before merging:\n"+eventGroups.stream().map(eg -> eg.toDisplayString(referenceHaplotype.getStart())).collect(Collectors.joining("\n"))); - - //Now that we have the disallowed groups, lets merge any of them from separate groups: - //TODO this is not an efficient way of doing this - for (List pair : disallowedPairs) { - EventGroup eventGrpLeft = null; - for (Event event : pair) { - EventGroup grpForEvent = eventGroups.stream().filter(grp -> grp.contains(event)).findFirst().get(); - // If the event isn't in the same event group as its predecessor, merge this group with that one and - if (eventGrpLeft != grpForEvent) { - if (eventGrpLeft == null) { - eventGrpLeft = grpForEvent; - } else { - eventGrpLeft.mergeEvent(grpForEvent); - eventGroups.remove(grpForEvent); - } - } - } - } + + final List eventGroups = getEventGroupClusters(eventsInOrder, disallowedPairs); Utils.printIf(debug,() -> "Event groups after merging:\n"+eventGroups.stream().map(eg -> eg.toDisplayString(referenceHaplotype.getStart())).collect(Collectors.joining("\n"))); // if any of our merged event groups is too large, abort. @@ -414,6 +385,39 @@ private static List> smithWatermanRealignPairsOfVariantsForEquivalen return disallowedPairs; } + /** + * Partition events into clusters that must be considered together, either because they overlap or because they belong to the + * same mutually exclusive pair or trio. To find this clustering we calculate the connected components of an undirected graph + * with an edge connecting events that overlap or are mutually excluded. + */ + private static List getEventGroupClusters(List eventsInOrder, List> disallowedPairsAndTrios) { + final Graph graph = new SimpleGraph<>(DefaultEdge.class); + eventsInOrder.forEach(graph::addVertex); + + // edges due to overlapping position + for (int e1 = 0; e1 < eventsInOrder.size(); e1++) { + final Event event1 = eventsInOrder.get(e1); + for (int e2 = e1 + 1; e2 < eventsInOrder.size() && eventsInOrder.get(e2).getStart() <= event1.getEnd() + 1; e2++) { + final Event event2 = eventsInOrder.get(e2); + if (eventsOverlapForPDHapsCode(event1, event2)) { + graph.addEdge(event1, event2); + } + } + } + + // edges due to mutual exclusion + for (final List excludedGroup : disallowedPairsAndTrios) { + graph.addEdge(excludedGroup.get(0), excludedGroup.get(1)); + if (excludedGroup.size() == 3) { + graph.addEdge(excludedGroup.get(1), excludedGroup.get(2)); + } + } + + return new ConnectivityInspector<>(graph).connectedSets().stream() + .map(set -> set.stream().sorted(HAPLOTYPE_SNP_FIRST_COMPARATOR).toList()).map(EventGroup::new) + .toList(); + } + /** * Overlaps method to handle indels and snps correctly. Specifically for this branching codes purposes, * indels don't overlap on their anchor bases and insertions don't overlap anything except deletions spanning them or other insertions @@ -714,7 +718,7 @@ private static class EventGroup { // Optimization to save ourselves recomputing the subsets at every point its necessary to do so. List>> cachedEventLists = null; - public EventGroup(final Event ... events) { + public EventGroup(final Collection events) { eventsInBitmapOrder = new ArrayList<>(); eventSet = new HashSet<>(); @@ -910,12 +914,6 @@ private static String formatEventsLikeDragenLogs(final Collection events, .collect(Collectors.joining(delimiter)); } - private static void eventGroupsMessage(final Haplotype referenceHaplotype, final boolean debug, final Map> eventsByDRAGENCoordinates) { - Utils.printIf(debug, () -> eventsByDRAGENCoordinates.entrySet().stream() - .map(e -> String.format("%.1f", e.getKey()) + " -> " + formatEventsLikeDragenLogs(e.getValue(), referenceHaplotype.getStart(),",")) - .collect(Collectors.joining("\n"))); - } - private static void removeBadPileupEventsMessage(final boolean debug, final AssemblyResultSet assemblyResultSet, final Set badPileupEvents) { if (debug) { final Set intersection = Sets.intersection(assemblyResultSet.getVariationEvents(0), badPileupEvents);