Skip to content

Commit

Permalink
Add almost-working version with Python simplifier
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher authored and molpopgen committed Feb 5, 2024
1 parent e19a5a6 commit c559f64
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 23 deletions.
20 changes: 13 additions & 7 deletions python/tests/simplify.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2019-2022 Tskit Developers
# Copyright (c) 2019-2023 Tskit Developers
# Copyright (c) 2015-2018 University of Oxford
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
Expand Down Expand Up @@ -114,6 +114,8 @@ def __init__(
filter_nodes=True,
update_sample_flags=True,
):
# DELETE ME
self.parent_edges_processed = 0
self.ts = ts
self.n = len(sample)
self.reduce_to_site_topology = reduce_to_site_topology
Expand Down Expand Up @@ -397,6 +399,7 @@ def process_parent_edges(self, edges):
"""
Process all of the edges for a given parent.
"""
self.parent_edges_processed += len(edges)
assert len({e.parent for e in edges}) == 1
parent = edges[0].parent
S = []
Expand Down Expand Up @@ -535,6 +538,14 @@ def insert_input_roots(self):
offset += 1
self.sort_offset = offset

def finalise(self):
if self.keep_input_roots:
self.insert_input_roots()
self.finalise_sites()
self.finalise_references()
if self.sort_offset != -1:
self.tables.sort(edge_start=self.sort_offset)

def simplify(self):
if self.ts.num_edges > 0:
all_edges = list(self.ts.edges())
Expand All @@ -545,12 +556,7 @@ def simplify(self):
edges = []
edges.append(e)
self.process_parent_edges(edges)
if self.keep_input_roots:
self.insert_input_roots()
self.finalise_sites()
self.finalise_references()
if self.sort_offset != -1:
self.tables.sort(edge_start=self.sort_offset)
self.finalise()
ts = self.tables.tree_sequence()
return ts, self.node_id_map

Expand Down
174 changes: 158 additions & 16 deletions python/tests/test_forward_sims.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,147 @@
"""
Python implementation of the low-level supporting code for forward simulations.
"""
import collections
import itertools
import random

import numpy as np
import pytest

import tskit
from tests import simplify


def simplify_with_buffer(tables, parent_buffer, samples, verbose):
# Pretend this was done efficiently internally without any sorting
# by creating a simplifier object and adding the ancstry for the
# new parents appropriately before flushing through the rest of the
# edges.
for parent, edges in parent_buffer.items():
for left, right, child in edges:
class BirthBuffer:
def __init__(self):
self.edges = {}
self.parents = []

def add_edge(self, left, right, parent, child):
if parent not in self.edges:
self.parents.append(parent)
self.edges[parent] = []
self.edges[parent].append((child, left, right))

def clear(self):
self.edges = {}
self.parents = []

def __str__(self):
s = ""
for parent in self.parents:
for child, left, right in self.edges[parent]:
s += f"{parent}\t{child}\t{left:0.3f}\t{right:0.3f}\n"
return s


def add_younger_edges_to_simplifier(simplifier, t, tables, edge_offset):
parent_edges = []
while (
edge_offset < len(tables.edges)
and tables.nodes.time[tables.edges.parent[edge_offset]] <= t
):
print("edge offset = ", edge_offset)
if len(parent_edges) == 0:
last_parent = tables.edges.parent[edge_offset]
else:
last_parent = parent_edges[-1].parent
if last_parent == tables.edges.parent[edge_offset]:
parent_edges.append(tables.edges[edge_offset])
else:
print(
"Flush ", tables.nodes.time[parent_edges[-1].parent], len(parent_edges)
)
simplifier.process_parent_edges(parent_edges)
parent_edges = []
edge_offset += 1
if len(parent_edges) > 0:
print("Flush ", tables.nodes.time[parent_edges[-1].parent], len(parent_edges))
simplifier.process_parent_edges(parent_edges)
return edge_offset


def simplify_with_births(tables, births, alive, verbose):
total_edges = len(tables.edges)
for edges in births.edges.values():
total_edges += len(edges)
if verbose > 0:
print("Simplify with births")
# print(births)
print("total_input edges = ", total_edges)
print("alive = ", alive)
print("\ttable edges:", len(tables.edges))
print("\ttable nodes:", len(tables.nodes))

simplifier = simplify.Simplifier(tables.tree_sequence(), alive)
nodes_time = tables.nodes.time
# This should be almost sorted, because
parent_time = nodes_time[births.parents]
index = np.argsort(parent_time)
print(index)
offset = 0
for parent in np.array(births.parents)[index]:
offset = add_younger_edges_to_simplifier(
simplifier, nodes_time[parent], tables, offset
)
edges = [
tskit.Edge(left, right, parent, child)
for child, left, right in sorted(births.edges[parent])
]
# print("Adding parent from time", nodes_time[parent], len(edges))
# print("edges = ", edges)
simplifier.process_parent_edges(edges)
# simplifier.print_state()

# FIXME should probably reuse the add_younger_edges_to_simplifier function
# for this - doesn't quite seem to work though
for _, edges in itertools.groupby(tables.edges[offset:], lambda e: e.parent):
edges = list(edges)
simplifier.process_parent_edges(edges)

simplifier.check_state()
assert simplifier.parent_edges_processed == total_edges
# if simplifier.parent_edges_processed != total_edges:
# print("HERE!!!!", total_edges)
simplifier.finalise()

tables.nodes.replace_with(simplifier.tables.nodes)
tables.edges.replace_with(simplifier.tables.edges)

# This is needed because we call .tree_sequence here and later.
# Can be removed is we change the Simplifier to take a set of
# tables which it modifies, like the C version.
tables.drop_index()
# Just to check
tables.tree_sequence()

births.clear()
# Add back all the edges with an alive parent to the buffer, so that
# we store them contiguously
keep = np.ones(len(tables.edges), dtype=bool)
for u in alive:
u = simplifier.node_id_map[u]
for e in np.where(tables.edges.parent == u)[0]:
keep[e] = False
edge = tables.edges[e]
# print(edge)
births.add_edge(edge.left, edge.right, edge.parent, edge.child)

if verbose > 0:
print("Done")
print(births)
print("\ttable edges:", len(tables.edges))
print("\ttable nodes:", len(tables.nodes))


def simplify_with_births_easy(tables, births, alive, verbose):
for parent, edges in births.edges.items():
for child, left, right in edges:
tables.edges.add_row(left, right, parent, child)
tables.sort()
tables.simplify(samples)
# We've exhausted the parent buffer, so clear it out. In reality we'd
# do this more carefully, like KT does in the post_simplify step.
parent_buffer.clear()
tables.simplify(alive)
births.clear()

# print(tables.nodes.time[tables.edges.parent])


def wright_fisher(
Expand All @@ -52,7 +171,7 @@ def wright_fisher(
rng = random.Random(seed)
tables = tskit.TableCollection(L)
alive = [tables.nodes.add_row(time=T) for _ in range(N)]
parent_buffer = collections.defaultdict(list)
births = BirthBuffer()

t = T
while t > 0:
Expand All @@ -66,12 +185,16 @@ def wright_fisher(
a = rng.randint(0, N - 1)
b = rng.randint(0, N - 1)
x = rng.uniform(0, L)
parent_buffer[alive[a]].append((0, x, u))
parent_buffer[alive[b]].append((x, L, u))
# TODO Possibly more natural do this like
# births.add(u, parents=[a, b], breaks=[0, x, L])
births.add_edge(0, x, alive[a], u)
births.add_edge(x, L, alive[b], u)
alive = next_alive
if t % simplify_interval == 0 or t == 0:
simplify_with_buffer(tables, parent_buffer, alive, verbose=verbose)
simplify_with_births(tables, births, alive, verbose=verbose)
# simplify_with_births_easy(tables, births, alive, verbose=verbose)
alive = list(range(N))
# print(tables.tree_sequence())
return tables.tree_sequence()


Expand Down Expand Up @@ -115,3 +238,22 @@ def test_full_simulation(self):
ts = wright_fisher(N=5, T=500, death_proba=0.9, simplify_interval=1000)
for tree in ts.trees():
assert tree.num_roots == 1


class TestSimplifyIntervals:
@pytest.mark.parametrize("interval", [1, 10, 33, 100])
def test_non_overlapping_generations(self, interval):
N = 10
ts = wright_fisher(N, T=100, death_proba=1, simplify_interval=interval)
assert ts.num_samples == N

@pytest.mark.parametrize("interval", [1, 10, 33, 100])
@pytest.mark.parametrize("death_proba", [0.33, 0.5, 0.9])
def test_overlapping_generations(self, interval, death_proba):
N = 4
ts = wright_fisher(
N, T=20, death_proba=death_proba, simplify_interval=interval, verbose=1
)
assert ts.num_samples == N
print()
print(ts.draw_text())

0 comments on commit c559f64

Please sign in to comment.