Tutorial: tree algorithms with numba #63
I've been working on this: it's absolutely awesome what numba can do, and it works beautifully with the array-based tree representation. I'll post some comments here on what I've done. I'm using a tree sequence with 1 million samples here as the test case:

    import msprime
    import numba
    import numpy as np
    import tskit

    ts1m = msprime.sim_ancestry(1e6, ploidy=1, random_seed=42)

MRCAs

First up, compute MRCAs (using the nice new algorithm, tskit-dev/tskit#1313):

    ts = ts1m
    tree = ts.first()
    parent = np.zeros(ts.num_nodes, dtype=np.int32)
    time = ts.tables.nodes.time
    for u in range(ts.num_nodes):
        parent[u] = tree.parent(u)

    @numba.jit(nopython=True)
    def get_mrca_numba(u, v):
        tu = time[u]
        tv = time[v]
        # Always step up from whichever node is currently younger.
        while u != v:
            if tu < tv:
                u = parent[u]
                if u == tskit.NULL:
                    return tskit.NULL
                tu = time[u]
            else:
                v = parent[v]
                if v == tskit.NULL:
                    return tskit.NULL
                tv = time[v]
        return u

I'm putting the

Timings:
Whoa! The numba jit version is a little bit faster than the library version (which is the updated, non-malloc version)! This is a fast function though, so the overhead of the Python C interface is probably what's creating the difference. But still - I didn't see that coming.
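The version above reads `parent` and `time` as module-level globals. A minimal sketch of the same walk with the arrays passed as arguments instead, run on a tiny seven-node tree built by hand. The arrays, the `NULL = -1` constant (standing in for `tskit.NULL`) and the fallback decorator for machines without numba are assumptions for illustration, not code from this thread:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs (slowly) without numba installed.
    def njit(func):
        return func

NULL = -1  # stands in for tskit.NULL

@njit
def get_mrca(parent, time, u, v):
    tu = time[u]
    tv = time[v]
    # Step up from whichever node is currently younger until the paths meet.
    while u != v:
        if tu < tv:
            u = parent[u]
            if u == NULL:
                return NULL
            tu = time[u]
        else:
            v = parent[v]
            if v == NULL:
                return NULL
            tv = time[v]
    return u

# Hand-built tree: leaves 0-3, node 4 = (0, 1), node 5 = (2, 3), root 6 = (4, 5).
parent = np.array([4, 4, 5, 5, 6, 6, NULL], dtype=np.int32)
time = np.array([0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0])

print(get_mrca(parent, time, 0, 1), get_mrca(parent, time, 0, 2))
```

Passing the arrays explicitly means the compiled function does not depend on whatever the globals held at compile time, at the cost of a longer argument list.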
Total branch length

Let's see how we do with a longer-running function that takes a bit more computation. We'll compute the total branch length by doing a simple top-down traversal. First we get the tree arrays:

    ts = ts1m
    tree = ts.first()
    parent = np.zeros(ts.num_nodes, dtype=np.int32)
    left_child = np.zeros(ts.num_nodes, dtype=np.int32)
    right_sib = np.zeros(ts.num_nodes, dtype=np.int32)
    time = ts.tables.nodes.time
    for u in range(ts.num_nodes):
        parent[u] = tree.parent(u)
        left_child[u] = tree.left_child(u)
        right_sib[u] = tree.right_sib(u)

    @numba.njit()
    def total_branch_length_numba(root):
        tbl = 0
        stack = [root]
        while len(stack) > 0:
            u = stack.pop()
            # Add the branch above each child of u, then visit the child later.
            v = left_child[u]
            while v != tskit.NULL:
                tbl += time[u] - time[v]
                stack.append(v)
                v = right_sib[v]
        return tbl

    @numba.njit()
    def total_branch_length_numba_recursive(u):
        tbl = 0
        v = left_child[u]
        while v != tskit.NULL:
            tbl += (time[u] - time[v]) + total_branch_length_numba_recursive(v)
            v = right_sib[v]
        return tbl

Timings:
So, the C version in

(For the record, here's what the C code looks like:)

    int
    tsk_tree_get_total_branch_length(
        const tsk_tree_t *self, tsk_id_t root, double *total_branch_length)
    {
        int ret = 0;
        tsk_id_t u, v;
        int stack_top;
        double tbl = 0;
        const double *restrict time = self->tree_sequence->tables->nodes.time;
        const tsk_id_t *restrict right_child = self->right_child;
        const tsk_id_t *restrict left_sib = self->left_sib;
        tsk_id_t *stack = malloc(self->num_nodes * sizeof(*stack));

        if (stack == NULL) {
            ret = TSK_ERR_NO_MEMORY;
            goto out;
        }
        ret = tsk_tree_check_node(self, root);
        if (ret != 0) {
            goto out;
        }
        stack_top = 0;
        stack[stack_top] = root;
        while (stack_top >= 0) {
            u = stack[stack_top];
            stack_top--;
            for (v = right_child[u]; v != TSK_NULL; v = left_sib[v]) {
                tbl += time[u] - time[v];
                stack_top++;
                stack[stack_top] = v;
            }
        }
        *total_branch_length = tbl;
    out:
        tsk_safe_free(stack);
        return ret;
    }

This isn't in the library currently - I might add the function, if we think it's worthwhile.
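One difference between the two implementations is the stack: the numba version grows a Python list, while the C version preallocates an array with one slot per node. Below is a sketch of the preallocated variant in Python, run on a small hand-built tree. The tree arrays, the `NULL = -1` constant (standing in for `tskit.NULL`) and the fallback decorator are assumptions for illustration, and no claim is made here about which variant is faster:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs without numba installed.
    def njit(func):
        return func

NULL = -1  # stands in for tskit.NULL

@njit
def total_branch_length(root, left_child, right_sib, time):
    # Each node is pushed at most once, so num_nodes slots are enough.
    stack = np.empty(left_child.shape[0], dtype=np.int32)
    top = 0
    stack[top] = root
    tbl = 0.0
    while top >= 0:
        u = stack[top]
        top -= 1
        v = left_child[u]
        while v != NULL:
            tbl += time[u] - time[v]
            top += 1
            stack[top] = v
            v = right_sib[v]
    return tbl

# Hand-built tree: leaves 0-3, node 4 = (0, 1), node 5 = (2, 3), root 6 = (4, 5).
left_child = np.array([NULL, NULL, NULL, NULL, 0, 2, 4], dtype=np.int32)
right_sib = np.array([1, NULL, 3, NULL, 5, NULL, NULL], dtype=np.int32)
time = np.array([0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0])

total = total_branch_length(6, left_child, right_sib, time)
print(total)  # branches: 1 + 1 + 2 + 2 + 2 + 1 = 9.0
```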
Postorder sum

Since preorder via recursion was fast, let's try postorder (which is much easier to do via recursion). Propagating a sum up the tree is a fundamental operation.

    @numba.njit()
    def postorder_sum_numba(u, x):
        v = left_child[u]
        while v != tskit.NULL:
            # Finish the subtree below v before adding its total into u.
            postorder_sum_numba(v, x)
            x[u] += x[v]
            v = right_sib[v]

    def count_nodes():
        a = np.zeros(ts.num_nodes)
        a[ts.samples()] = 1
        postorder_sum_numba(tree.root, a)
        return a

Here, we just count the number of sample nodes that are below each node and return this array, but it could be anything. Note in particular that we're using numpy indexing here, and this should work for n-d arrays.

Timings:
Holy jeebus, we did a recursive postorder traversal summing an array in less time than it took to do a top-down sum of a simple value! Again, this is just over twice the time it took for the optimised C code to sum the total branch length.
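For trees deep enough that recursion becomes a concern, the same sum can be propagated without recursion. This is an alternative technique, not what the thread benchmarks: record a preorder with an explicit stack, then walk that order backwards, so every node is fully summed before it is added into its parent. A minimal sketch on a hand-built tree; the arrays, `NULL = -1` (standing in for `tskit.NULL`) and the fallback decorator are assumptions for illustration:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs without numba installed.
    def njit(func):
        return func

NULL = -1  # stands in for tskit.NULL

@njit
def postorder_sum(root, parent, left_child, right_sib, x):
    n = parent.shape[0]
    order = np.empty(n, dtype=np.int32)
    stack = np.empty(n, dtype=np.int32)
    num_visited = 0
    top = 0
    stack[top] = root
    # Pass 1: preorder, so each node is recorded before its descendants.
    while top >= 0:
        u = stack[top]
        top -= 1
        order[num_visited] = u
        num_visited += 1
        v = left_child[u]
        while v != NULL:
            top += 1
            stack[top] = v
            v = right_sib[v]
    # Pass 2: reversed preorder visits descendants first; index 0 is the root.
    for j in range(num_visited - 1, 0, -1):
        u = order[j]
        x[parent[u]] += x[u]

# Hand-built tree: leaves 0-3, node 4 = (0, 1), node 5 = (2, 3), root 6 = (4, 5).
parent = np.array([4, 4, 5, 5, 6, 6, NULL], dtype=np.int32)
left_child = np.array([NULL, NULL, NULL, NULL, 0, 2, 4], dtype=np.int32)
right_sib = np.array([1, NULL, 3, NULL, 5, NULL, NULL], dtype=np.int32)

x = np.zeros(7)
x[:4] = 1  # one unit per leaf
postorder_sum(6, parent, left_child, right_sib, x)
print(x)
```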
OK, let's try something a bit more complicated. We can compute the Sankoff parsimony score of an assignment of genotypes to a particular tree.

    @numba.njit()
    def _sankoff_score_numba(parent, cost_matrix, S):
        num_alleles = cost_matrix.shape[0]
        child = left_child[parent]
        while child != tskit.NULL:
            _sankoff_score_numba(child, cost_matrix, S)
            # For each parent allele j, add the cheapest way to explain this child.
            for j in range(num_alleles):
                min_cost = np.inf
                for k in range(num_alleles):
                    min_cost = min(min_cost, cost_matrix[k, j] + S[child, k])
                S[parent, j] += min_cost
            child = right_sib[child]

    def sankoff_score_numba(genotypes, cost_matrix):
        num_alleles = cost_matrix.shape[0]
        S = np.zeros((tree.num_nodes, num_alleles))
        samples = tree.tree_sequence.samples()
        S[samples, :] = np.inf
        for allele in range(num_alleles):
            samples_with_allele = samples[genotypes == allele]
            S[samples_with_allele, allele] = 0
        _sankoff_score_numba(tree.root, cost_matrix, S)
        return S

    # Simple 2-allele cost matrix.
    cost_matrix = np.array([[0, 0.5], [0.5, 0]])
    genotypes = np.zeros(ts.num_samples, dtype=np.int8)
    # Assign something to the genotypes so we're not summing 0s
    genotypes[::2] = 1

Timings:
Woo, 0.2 seconds! (This is a tree with 1 million samples, remember). It takes about 30 seconds to run the same calculation in Biopython - but the algorithm is implemented in Python, so it's not a fair comparison.
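To make the inner double loop concrete, here is the update for a single parent with two leaf children, one carrying allele 0 and one carrying allele 1, written with numpy broadcasting rather than explicit loops. The helper name `sankoff_child_contribution` is hypothetical; the cost matrix is the one used above:

```python
import numpy as np

def sankoff_child_contribution(cost_matrix, child_costs):
    # Entry j is the minimum over k of cost_matrix[k, j] + child_costs[k],
    # i.e. the cheapest way to explain this child if the parent has allele j.
    return np.min(cost_matrix + child_costs[:, np.newaxis], axis=0)

cost_matrix = np.array([[0, 0.5], [0.5, 0]])

# Leaves cost 0 for the allele they carry and infinity for any other allele.
leaf_with_allele_0 = np.array([0.0, np.inf])
leaf_with_allele_1 = np.array([np.inf, 0.0])

parent_costs = np.zeros(2)
for child_costs in (leaf_with_allele_0, leaf_with_allele_1):
    parent_costs += sankoff_child_contribution(cost_matrix, child_costs)

print(parent_costs)        # cost of each allele at the parent
print(parent_costs.min())  # parsimony score if this parent is the root
```

Whichever allele the parent takes, exactly one of the two children needs a change costing 0.5, so both entries come out equal.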
Parsimonious assignments

    ts = ts1m
    tree = ts.first()
    parent = np.zeros(ts.num_nodes, dtype=np.int32)
    right_child = np.zeros(ts.num_nodes, dtype=np.int32)
    left_sib = np.zeros(ts.num_nodes, dtype=np.int32)
    time = ts.tables.nodes.time
    flags = ts.tables.nodes.flags
    for u in range(ts.num_nodes):
        parent[u] = tree.parent(u)
        right_child[u] = tree.right_child(u)
        left_sib[u] = tree.left_sib(u)

    @numba.njit()
    def _hartigan_postorder(parent, optimal_set):
        num_alleles = optimal_set.shape[1]
        allele_count = np.zeros(num_alleles, dtype=np.int32)
        child = right_child[parent]
        while child != tskit.NULL:
            _hartigan_postorder(child, optimal_set)
            allele_count += optimal_set[child]
            child = left_sib[child]
        if flags[parent] == 0:  # Bad! This should just be checking the sample bit.
            max_allele_count = np.max(allele_count)
            for j in range(num_alleles):
                if allele_count[j] == max_allele_count:
                    optimal_set[parent, j] = 1

    @numba.njit()
    def _hartigan_preorder(node, state, optimal_set):
        mutations = []
        if optimal_set[node, state] == 0:
            state = np.argmax(optimal_set[node])
            mutations.append((node, state))
        v = right_child[node]
        while v != tskit.NULL:
            v_muts = _hartigan_preorder(v, state, optimal_set)
            mutations.extend(v_muts)
            v = left_sib[v]
        return mutations

    def hartigan_map_mutations_numba(tree, genotypes, alleles):
        # Simple version assuming non missing data and one root
        num_alleles = np.max(genotypes) + 1
        num_nodes = tree.tree_sequence.num_nodes
        optimal_set = np.zeros((num_nodes + 1, num_alleles), dtype=np.int8)
        for allele, u in zip(genotypes, tree.tree_sequence.samples()):
            optimal_set[u, allele] = 1
        _hartigan_postorder(tree.root, optimal_set)
        ancestral_state = np.argmax(optimal_set[tree.root])
        ll_mutations = _hartigan_preorder(tree.root, ancestral_state, optimal_set)
        mutations = []
        for node, derived_state in ll_mutations:
            mutations.append(
                tskit.Mutation(
                    node=node,
                    derived_state=alleles[derived_state],
                    # Note we're taking a short-cut here and not bothering with mutation parent.
                    # Could be done easily enough.
                )
            )
        return alleles[ancestral_state], mutations

    genotypes = np.zeros(ts.num_samples, dtype=np.int8)
    # This is an easy one so we won't be allocating a lot of memory.
    genotypes[1] = 1

Timings:

The tskit version takes about 1/3 of a second on a million-leaf tree.
The numba version takes about a second, which is pretty great!
That was a very easy parsimony job though; it only needed one mutation. Let's try something harder so we're doing memory allocations:

    genotypes = np.zeros(ts.num_samples, dtype=np.int8)
    # Assign something to the genotypes so we're not summing 0s
    genotypes[::2] = 1
    genotypes

    array([1, 0, 1, ..., 0, 1, 0], dtype=int8)
So, still within a factor of 2-3 of the highly optimised tskit C code! Wow, numba kicks ass!
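On the "Bad!" comment in `_hartigan_postorder`: comparing `flags` to 0 only works while no other flag bits are ever set. A small sketch of testing just the sample bit instead, using a made-up flags array; `NODE_IS_SAMPLE` is defined locally here with the same value as `tskit.NODE_IS_SAMPLE` so the snippet runs without tskit:

```python
import numpy as np

NODE_IS_SAMPLE = 1  # same value as tskit.NODE_IS_SAMPLE

# Made-up flags: nodes 0, 1 and 3 are samples; nodes 3 and 4 also carry
# an extra (hypothetical) bit with value 2.
flags = np.array([1, 1, 0, 3, 2], dtype=np.uint32)

is_sample = (flags & NODE_IS_SAMPLE) != 0
non_sample_by_bit = ~is_sample
non_sample_by_zero_test = flags == 0

print(non_sample_by_bit)        # node 4 is correctly treated as a non-sample
print(non_sample_by_zero_test)  # node 4 is missed by the == 0 test
```

Inside the jitted function the equivalent per-node check would be `(flags[parent] & 1) == 0` in place of `flags[parent] == 0`.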
An update here: tskit-dev/tskit#1320 adds support for the tree arrays, which works very well. I'll need to do a bit of experimentation to see what's the best way of passing around these array references (i.e., to make sure they are considered "const" by numba), but it's all solid.
Once we have direct numpy access to the tree arrays in Python (tskit-dev/tskit#1299) I think we should be able to do quite performant traversal algorithms in python using numba. We can illustrate this with a couple of examples:
node_time array if we create it for each function call - we could work around this for the moment by using a copy of the node time array that we keep lying around). This is easy because we only go up the tree.