Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Factor out distributed SMC logic #2303

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 22 additions & 101 deletions algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,57 +1016,12 @@ def initialise(self, ts):
lineage = root_lineages[node]
if lineage is not None:
seg = lineage.head
left_end = seg.left
while seg is not None:
self.set_segment_mass(seg)
lineage.tail = seg
seg = seg.next
self.add_lineage(lineage)

if self.model == "smc_k":
for node in range(ts.num_nodes):
lineage = root_lineages[node]
if lineage is not None:
seg = lineage.head
left_end = seg.left
pop = lineage.population
label = lineage.label
right_end = root_segments_tail[node].right
new_hull = self.alloc_hull(left_end, right_end, lineage)
# insert Hull
floor = self.P[pop].hulls_left[label].floor_key(new_hull)
insertion_order = 0
if floor is not None:
if floor.left == new_hull.left:
insertion_order = floor.insertion_order + 1
new_hull.insertion_order = insertion_order
self.P[pop].hulls_left[label][new_hull] = -1

# initialise the correct coalesceable pairs count
for pop in self.P:
for label, ost_left in enumerate(pop.hulls_left):
avl = ost_left.avl
ost_right = pop.hulls_right[label]
count = 0
for hull in avl.keys():
floor = ost_right.floor_key(HullEnd(hull.left))
num_ending_before_hull = 0
if floor is not None:
num_ending_before_hull = ost_right.rank[floor] + 1
num_pairs = count - num_ending_before_hull
avl[hull] = num_pairs
pop.coal_mass_index[label].set_value(hull.index, num_pairs)
# insert HullEnd
hull_end = HullEnd(hull.right)
floor = ost_right.floor_key(hull_end)
insertion_order = 0
if floor is not None:
if floor.x == hull.right:
insertion_order = floor.insertion_order + 1
hull_end.insertion_order = insertion_order
ost_right[hull_end] = -1
count += 1

def ancestors_remain(self):
"""
Returns True if the simulation is not finished, i.e., there is some ancestral
Expand Down Expand Up @@ -1198,6 +1153,15 @@ def store_edge(self, left, right, parent, child):
tskit.Edge(left=left, right=right, parent=parent, child=child)
)

def update_lineage_right(self, lineage):
if self.model == "smc_k":
# modify original hull
pop = lineage.population
hull = lineage.hull
old_right = hull.right
hull.right = min(lineage.tail.right + self.hull_offset, self.L)
self.P[pop].reset_hull_right(lineage.label, hull, old_right, hull.right)

def add_lineage(self, lineage):
pop = lineage.population
self.P[pop].add(lineage, lineage.label)
Expand All @@ -1208,6 +1172,15 @@ def add_lineage(self, lineage):
assert x.lineage == lineage
x = x.next

if self.model == "smc_k":
head = lineage.head
assert head.prev is None
hull = self.alloc_hull(head.left, head.right, lineage)
right = lineage.tail.right
hull.right = min(right + self.hull_offset, self.L)
pop = self.P[lineage.population]
pop.add_hull(lineage.label, hull)

def finalise(self):
"""
Finalises the simulation returns an msprime tree sequence object.
Expand Down Expand Up @@ -1836,22 +1809,11 @@ def hudson_recombination_event(self, label):
left_lineage.tail = x
lhs_tail = x

self.update_lineage_right(left_lineage)
right_lineage = self.alloc_lineage(alpha, left_lineage.population, label=label)
self.set_segment_mass(alpha)
self.add_lineage(right_lineage)

if self.model == "smc_k":
# modify original hull
pop = left_lineage.population
lhs_hull = lhs_tail.get_hull()
rhs_right = lhs_hull.right
lhs_hull.right = min(lhs_tail.right + self.hull_offset, self.L)
self.P[pop].reset_hull_right(label, lhs_hull, rhs_right, lhs_hull.right)

# create hull for alpha
alpha_hull = self.alloc_hull(alpha.left, rhs_right, right_lineage)
self.P[pop].add_hull(label, alpha_hull)

if self.additional_nodes.value & msprime.NODE_IS_RE_EVENT > 0:
self.store_node(left_lineage.population, flags=msprime.NODE_IS_RE_EVENT)
self.store_arg_edges(lhs_tail)
Expand Down Expand Up @@ -1890,11 +1852,8 @@ def wiuf_gene_conversion_within_event(self, label):
# lbp rbp
return None
self.num_gc_events += 1
hull = y.get_hull()
assert (self.model == "smc_k") == (hull is not None)
lineage = y.lineage
pop = lineage.population
reset_right = -1

# Process left break
insert_alpha = True
Expand All @@ -1913,7 +1872,6 @@ def wiuf_gene_conversion_within_event(self, label):
insert_alpha = False
else:
x.next = None
reset_right = x.right
y.prev = None
alpha = y
tail = x
Expand All @@ -1935,15 +1893,11 @@ def wiuf_gene_conversion_within_event(self, label):
y.right = left_breakpoint
self.set_segment_mass(y)
tail = y
reset_right = left_breakpoint
self.set_segment_mass(alpha)

# Find the segment z that the right breakpoint falls in
z = alpha
hull_left = z.left
hull_right = -1
while z is not None and right_breakpoint >= z.right:
hull_right = z.right
z = z.next

head = None
Expand All @@ -1967,7 +1921,6 @@ def wiuf_gene_conversion_within_event(self, label):
z.right = right_breakpoint
z.next = None
self.set_segment_mass(z)
hull_right = right_breakpoint
else:
# tail z
# ======
Expand All @@ -1985,12 +1938,6 @@ def wiuf_gene_conversion_within_event(self, label):
tail.next = head
head.prev = tail
self.set_segment_mass(head)
else:
# rbp lies beyond segment chain, regular recombination logic applies
if insert_alpha and self.model == "smc_k":
assert reset_right > 0
reset_right = min(reset_right + self.hull_offset, self.L)
self.P[pop].reset_hull_right(label, hull, hull.right, reset_right)

# y z
# | ========== ... ===== |
Expand All @@ -2005,12 +1952,8 @@ def wiuf_gene_conversion_within_event(self, label):
if new_individual_head is not None:
# FIXME when doing the smc_k update
lineage.reset_segments()
self.update_lineage_right(lineage)
new_lineage = self.alloc_lineage(new_individual_head, pop)
if self.model == "smc_k":
assert hull_left < hull_right
hull_right = min(self.L, hull_right + self.hull_offset)
hull = self.alloc_hull(hull_left, hull_right, new_lineage)
self.P[new_lineage.population].add_hull(new_lineage.label, hull)
self.add_lineage(new_lineage)

def wiuf_gene_conversion_left_event(self, label):
Expand Down Expand Up @@ -2042,8 +1985,6 @@ def wiuf_gene_conversion_left_event(self, label):
x = y.prev
lineage = y.lineage
pop = lineage.population
lhs_hull = y.get_hull()
assert (self.model == "smc_k") == (lhs_hull is not None)
if y.left < bp:
# x y
# ===== =====|====
Expand All @@ -2061,7 +2002,6 @@ def wiuf_gene_conversion_left_event(self, label):
y.next = None
y.right = bp
self.set_segment_mass(y)
right = y.right
else:
# x y
# ===== | =========
Expand All @@ -2075,19 +2015,10 @@ def wiuf_gene_conversion_left_event(self, label):
x.next = None
y.prev = None
alpha = y
right = x.right

if self.model == "smc_k":
# lhs logic is identical to the lhs recombination event
lhs_old_right = lhs_hull.right
lhs_new_right = min(self.L, right + self.hull_offset)
self.P[pop].reset_hull_right(label, lhs_hull, lhs_old_right, lhs_new_right)

# rhs
hull = self.alloc_hull(alpha.left, lhs_old_right, alpha)
self.P[pop].add_hull(label, hull)

# FIXME
lineage.reset_segments()
self.update_lineage_right(lineage)

self.set_segment_mass(alpha)
assert alpha.prev is None
Expand Down Expand Up @@ -2574,16 +2505,6 @@ def insert_merged_lineage(
# assert tail == new_lineage.tail
self.add_lineage(new_lineage)

if self.model == "smc_k":
merged_head = new_lineage.head
assert merged_head.prev is None
hull = self.alloc_hull(merged_head.left, merged_head.right, new_lineage)
while merged_head is not None:
right = merged_head.right
merged_head = merged_head.next
hull.right = min(right + self.hull_offset, self.L)
pop = self.P[new_lineage.population]
pop.add_hull(new_lineage.label, hull)
return new_lineage

def print_state(self, verify=False):
Expand Down
Loading
Loading