Skip to content

Commit

Permalink
Use low-level tsinfer components directly
Browse files Browse the repository at this point in the history
Closes #381
  • Loading branch information
jeromekelleher committed Nov 5, 2024
1 parent 05461de commit 7c6ec48
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 49 deletions.
214 changes: 168 additions & 46 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import tqdm
import tskit
import tsinfer
import _tsinfer
import numpy as np
import zarr
import numba
Expand Down Expand Up @@ -49,21 +49,21 @@ def get_progress(iterable, title, phase, show_progress, total=None):
)


# NOTE(review): this span is a rendered diff hunk with indentation and +/-
# markers stripped. The uncommented lines look like the pre-commit version of
# the class and the "# ..." copies like the version the commit leaves in the
# file - confirm against the repository before relying on either.
#
# Progress monitor subclass that remembers a date and a phase and creates its
# progress bars through get_progress().
class TsinferProgressMonitor(tsinfer.progress.ProgressMonitor):
# Store the date/phase labels, then forward all remaining arguments to the
# base class constructor.
def __init__(self, date, phase, *args, **kwargs):
self.date = date
self.phase = phase
super().__init__(*args, **kwargs)
# class TsinferProgressMonitor(tsinfer.progress.ProgressMonitor):
# def __init__(self, date, phase, *args, **kwargs):
# self.date = date
# self.phase = phase
# super().__init__(*args, **kwargs)

# Create a progress bar for `total` items via get_progress(), passing None as
# the iterable, the stored date as title and the stored phase; the bar is kept
# on self.current_instance and returned. `key` is not used in this body.
# self.enabled is not set in this class - presumably provided by the base
# class; TODO confirm.
def get(self, key, total):
self.current_instance = get_progress(
None,
title=self.date,
phase=self.phase,
show_progress=self.enabled,
total=total,
)
return self.current_instance
# def get(self, key, total):
# self.current_instance = get_progress(
# None,
# title=self.date,
# phase=self.phase,
# show_progress=self.enabled,
# total=total,
# )
# return self.current_instance


class MatchDb:
Expand Down Expand Up @@ -1202,7 +1202,7 @@ def resize_copy(array, new_size):
return sd


def match_tsinfer(
def match_tsinfer_old(
samples,
ts,
mu,
Expand Down Expand Up @@ -1293,6 +1293,128 @@ def match_tsinfer(
return get_match_info(input_ts, sample_paths, sample_mutations)


def make_tsb(ts, mirror_coordinates=False):
    """
    Build a ``_tsinfer.TreeSequenceBuilder`` loaded with the nodes, edges and
    mutations of ``ts``.

    If ``mirror_coordinates`` is true, ``ts`` is first passed through
    ``mirror_ts_coordinates``.

    Returns a tuple ``(tsb, position_map)``. ``position_map`` is the array of
    site positions followed by the sequence length, with its first entry
    overwritten by 0; edge coordinates are handed to the builder as indexes
    into this array.

    Raises ``ValueError`` if any edge's left or right coordinate is not an
    entry of ``position_map``.
    """
    source = mirror_ts_coordinates(ts) if mirror_coordinates else ts

    def to_chars(column):
        # View the packed state column as one-byte strings, then convert to str.
        return column.view("S1").astype(str)

    tables = source.tables
    # Offsets of 0, 1, 2, ... mean every stored state is exactly one byte
    # wide, which is what the "S1" view above relies on.
    assert np.all(
        tables.sites.ancestral_state_offset == np.arange(source.num_sites + 1)
    )
    ancestral_state = alignments.encode_alignment(
        to_chars(tables.sites.ancestral_state)
    )
    assert np.all(
        tables.mutations.derived_state_offset
        == np.arange(source.num_mutations + 1)
    )
    derived_state = alignments.encode_alignment(
        to_chars(tables.mutations.derived_state)
    )
    del tables

    tsb = _tsinfer.TreeSequenceBuilder(
        num_alleles=np.full(
            source.num_sites, len(core.ALLELES), dtype=np.uint64
        ),
        max_nodes=source.num_nodes,
        max_edges=source.num_edges,
        ancestral_state=ancestral_state,
    )

    position_map = np.hstack([source.sites_position, [source.sequence_length]])
    # Note - bracketing by 0 on the left here.
    position_map[0] = 0

    def index_of(coordinates, error_message):
        # Translate genome coordinates into indexes into position_map,
        # rejecting any coordinate that is not an entry of the map.
        found = np.searchsorted(position_map, coordinates)
        if np.any(position_map[found] != coordinates):
            raise ValueError(error_message)
        return found

    left = index_of(source.edges_left, "Invalid left coordinates")
    right = index_of(source.edges_right, "Invalid right coordinates")

    # Edges are ordered by child ID, then by left index, so that the child
    # paths can be inserted efficiently.
    order = np.lexsort((left, source.edges_child))
    tsb.restore_nodes(source.nodes_time, source.nodes_flags)
    tsb.restore_edges(
        left[order].astype(np.int32),
        right[order].astype(np.int32),
        source.edges_parent[order],
        source.edges_child[order],
    )
    assert tsb.num_match_nodes == source.num_nodes

    tsb.restore_mutations(
        source.mutations_site,
        source.mutations_node,
        derived_state,
        source.mutations_parent,
    )
    return tsb, position_map


def match_tsinfer(
    samples,
    ts,
    mu,
    rho,
    *,
    likelihood_threshold=None,
    deletions_as_missing=False,
    num_threads=0,
    show_progress=False,
    progress_title=None,
    progress_phase=None,
    mirror_coordinates=False,
):
    """
    Match each sample's haplotype against ``ts`` with
    ``_tsinfer.AncestorMatcher`` and return the result of
    ``get_match_info(ts, sample_paths, sample_mutations)``.

    ``rho`` and ``mu`` are used as the per-site recombination and mismatch
    values for every site. When ``likelihood_threshold`` is None it defaults
    to ``rho**2 * mu**5``. If ``deletions_as_missing`` is true, DELETION
    entries of the haplotype are replaced by MISSING before matching. If
    ``mirror_coordinates`` is true, the matcher is built on the mirrored tree
    sequence, the haplotype is reversed for matching, and the resulting paths
    are mapped back with ``mirror``.

    ``num_threads``, ``show_progress``, ``progress_title`` and
    ``progress_phase`` are accepted but not used in this body.
    """
    tsb, coord_map = make_tsb(ts, mirror_coordinates)
    seq_len = int(ts.sequence_length)

    if likelihood_threshold is None:
        # TMP
        likelihood_threshold = rho**2 * mu**5

    matcher = _tsinfer.AncestorMatcher(
        tsb,
        recombination=np.full(ts.num_sites, rho),
        mismatch=np.full(ts.num_sites, mu),
        likelihood_threshold=likelihood_threshold,
    )

    def to_segment(left, right, parent):
        # Convert one (left index, right index, parent) triple from the
        # matcher into genome coordinates; when mirroring, the two ends swap
        # roles and are mapped back through mirror().
        if mirror_coordinates:
            start = mirror(int(coord_map[right]), seq_len)
            end = mirror(int(coord_map[left]), seq_len)
        else:
            start = int(coord_map[left])
            end = int(coord_map[right])
        return (start, end, int(parent))

    sample_paths = []
    sample_mutations = []
    for sample in samples:
        query = sample.haplotype.copy()
        if mirror_coordinates:
            query = query[::-1]
        if deletions_as_missing:
            query[query == DELETION] = MISSING
        missing_mask = query == MISSING

        # The matcher writes the matched haplotype into this buffer.
        matched = np.full(len(query), MISSING, dtype=np.int8)
        match_path = matcher.find_path(query, 0, len(query), matched)
        sample_paths.append(
            sorted(to_segment(*triple) for triple in zip(*match_path))
        )

        # Mask out the imputed sites
        matched[missing_mask] = MISSING
        if mirror_coordinates:
            # Flip both arrays back so indexes line up with ts site IDs.
            query = query[::-1]
            matched = matched[::-1]

        # TODO use the inherited state as well:
        # inherited_state = core.ALLELES[matched[site_id]]
        differing_sites = np.nonzero(query != matched)[0]
        sample_mutations.append(
            sorted(
                (ts.sites_position[site_id], core.ALLELES[query[site_id]])
                for site_id in differing_sites
            )
        )

    return get_match_info(ts, sample_paths, sample_mutations)


@dataclasses.dataclass(frozen=True)
class PathSegment:
left: int
Expand Down Expand Up @@ -1435,36 +1557,36 @@ def get_closest_mutation(node, site_id):
return matches


# NOTE(review): this span is a rendered diff hunk with indentation and +/-
# markers stripped. The uncommented lines look like the pre-commit version of
# the class and the "# ..." copies like the version the commit leaves in the
# file - confirm against the repository before relying on either.
class Matcher(tsinfer.SampleMatcher):
"""
NOTE: this is using undocumented internal APIs as a way of accessing
tsinfer's Li and Stephens matching engine. There are some awkward
workaround involved in dealing with tsinfer's internal representation
of the data, which are tightly coupled to implementation details within
tsinfer.
This implementation will be swapped out for tskit's LS engine in the
near future, using fully documented and supported APIs.
"""

# Run matching for the given sample indexes: open an "ms_match" progress
# instance sized to the number of samples, dispatch to the single-threaded
# routine when self.num_threads <= 0 and to the multi-threaded one otherwise,
# then close the progress instance. The routines are reached through their
# name-mangled form (_SampleMatcher__...), i.e. they are double-underscore
# private methods of SampleMatcher.
def _match_samples(self, sample_indexes):
# Some hacks here to work around the fact that tsinfer does a bunch
# of stuff we don't want here. All we want are the matched paths and
# mutations.
num_samples = len(sample_indexes)
self.match_progress = self.progress_monitor.get("ms_match", num_samples)
if self.num_threads <= 0:
self._SampleMatcher__match_samples_single_threaded(sample_indexes)
else:
self._SampleMatcher__match_samples_multi_threaded(sample_indexes)
self.match_progress.close()

# Add one node per sample to the tree sequence builder (add_node(0)), record
# the returned value in self.sample_id_map under the sample's ID, run the
# matching, and return self.results.
def run_match(self, samples):
builder = self.tree_sequence_builder
for sd_id in samples:
self.sample_id_map[sd_id] = builder.add_node(0)
self._match_samples(samples)
return self.results
# class Matcher(tsinfer.SampleMatcher):
# """
# NOTE: this is using undocumented internal APIs as a way of accessing
# tsinfer's Li and Stephens matching engine. There are some awkward
# workaround involved in dealing with tsinfer's internal representation
# of the data, which are tightly coupled to implementation details within
# tsinfer.

# This implementation will be swapped out for tskit's LS engine in the
# near future, using fully documented and supported APIs.
# """

# def _match_samples(self, sample_indexes):
# # Some hacks here to work around the fact that tsinfer does a bunch
# # of stuff we don't want here. All we want are the matched paths and
# # mutations.
# num_samples = len(sample_indexes)
# self.match_progress = self.progress_monitor.get("ms_match", num_samples)
# if self.num_threads <= 0:
# self._SampleMatcher__match_samples_single_threaded(sample_indexes)
# else:
# self._SampleMatcher__match_samples_multi_threaded(sample_indexes)
# self.match_progress.close()

# def run_match(self, samples):
# builder = self.tree_sequence_builder
# for sd_id in samples:
# self.sample_id_map[sd_id] = builder.add_node(0)
# self._match_samples(samples)
# return self.results


def attach_tree(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def recombinant_example_1(ts_map):
# SRR11597163 51 [(15324, 'T'), (29303, 'T')]
H = ts.genotype_matrix(samples=nodes, alleles=tuple("ACGT-")).T
bp = 10_000
h = H[0].copy()
h = H[0].copy().astype(np.int8)
h[bp:] = H[1][bp:]

s = sc2ts.Sample("frankentype", "2020-02-14", haplotype=h)
Expand Down Expand Up @@ -1153,9 +1153,9 @@ def test_example_1(self, fx_ts_map):

m = s.hmm_reruns["no_recombination"]
assert len(m.mutations) == 3
assert m.mutation_summary() == "[11083T>G, 15324C>T, 29303C>T]"
assert m.mutation_summary() == "[871A>G, 3027A>G, 3787C>T]"
assert len(m.path) == 1
assert m.path[0].parent == left_parent
assert m.path[0].parent == right_parent
assert m.path[0].left == 0
assert m.path[0].right == ts.sequence_length

Expand Down

0 comments on commit 7c6ec48

Please sign in to comment.