Merge pull request #371 from jeromekelleher/add-max-muts-per-sample
Add max muts per sample
jeromekelleher authored Oct 17, 2024
2 parents 6c464cb + 2e26a29 commit 6900009
Showing 8 changed files with 186 additions and 57 deletions.
pyproject.toml (4 changes: 1 addition & 3 deletions)
@@ -12,9 +12,7 @@ dependencies = [
# FIXME
"tsinfer @ git+https://github.com/jeromekelleher/tsinfer.git@experimental-hmm",
"pyfaidx",
# FIXME - reinstate when 0.5.9 is released
# "tskit>=0.5.9",
"tskit @ git+https://github.com/tskit-dev/tskit.git@main#subdirectory=python",
"tskit>=0.6.0",
"tszip",
"pandas",
"numba",
sc2ts/alignments.py (3 changes: 2 additions & 1 deletion)
@@ -40,8 +40,9 @@ class AlignmentStore(collections.abc.Mapping):
def __init__(self, path, mode="r"):
map_size = 1024**4
self.path = path
readonly = mode == "r"
self.env = lmdb.Environment(
str(path), subdir=False, readonly=mode == "r", map_size=map_size
str(path), subdir=False, readonly=readonly, map_size=map_size, lock=not readonly
)
logger.debug(f"Opened AlignmentStore at {path} mode={mode}")

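The lock=not readonly part of this change means that read-only handles skip LMDB's file locking entirely, so many reader processes can open the same store without contending on the lock file; write handles still take the lock. A minimal sketch of the pattern (the example.db path and open_store name are illustrative, not part of sc2ts):

import lmdb

def open_store(path, mode="r"):
    # Read-only handles disable LMDB file locking so concurrent
    # readers never contend; write handles keep the lock.
    readonly = mode == "r"
    return lmdb.Environment(
        str(path), subdir=False, readonly=readonly,
        map_size=1024**4, lock=not readonly,
    )

env = open_store("example.db", mode="r")  # hypothetical path
with env.begin() as txn:
    print(txn.stat()["entries"])  # number of key-value records
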
sc2ts/cli.py (12 changes: 12 additions & 0 deletions)
@@ -387,6 +387,16 @@ def summarise_base(ts, date, progress):
type=int,
help="Minimum number of shared mutations for reconsidered sample groups",
)
@click.option(
"--max-mutations-per-sample",
default=10,
show_default=True,
type=int,
help=(
"Maximum average number of mutations per sample in an inferred retrospective "
"group tree"
),
)
@click.option(
"--retrospective-window",
default=30,
@@ -451,6 +461,7 @@ def extend(
hmm_cost_threshold,
min_group_size,
min_root_mutations,
max_mutations_per_sample,
retrospective_window,
deletions_as_missing,
max_daily_samples,
@@ -494,6 +505,7 @@ def extend(
hmm_cost_threshold=hmm_cost_threshold,
min_group_size=min_group_size,
min_root_mutations=min_root_mutations,
max_mutations_per_sample=max_mutations_per_sample,
retrospective_window=retrospective_window,
deletions_as_missing=deletions_as_missing,
max_daily_samples=max_daily_samples,
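For context, the new option bounds the average mutational cost of a retrospective group tree: a group whose inferred tree implies more mutations per sample than the threshold is rejected rather than attached to the ARG. A minimal sketch of the acceptance rule (function and argument names here are illustrative, not the sc2ts internals):

def accept_group(num_mutations, num_samples, max_mutations_per_sample=10):
    # Reject groups whose tree needs an implausibly large number of
    # mutations per sample to explain the data.
    return num_mutations / num_samples <= max_mutations_per_sample

assert accept_group(num_mutations=25, num_samples=5)       # 5.0 per sample
assert not accept_group(num_mutations=120, num_samples=5)  # 24.0 per sample
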
sc2ts/inference.py (166 changes: 116 additions & 50 deletions)
@@ -69,6 +69,7 @@ def get(self, key, total):
class MatchDb:
def __init__(self, path):
uri = f"file:{path}"
self.path = path
self.uri = uri
self.conn = sqlite3.connect(uri, uri=True)
self.conn.row_factory = metadata.dict_factory
@@ -280,6 +281,7 @@ def initial_ts(problematic_sites=list()):
"date": core.REFERENCE_DATE,
"samples_strain": [core.REFERENCE_STRAIN],
"num_exact_matches": {},
"retro_groups": [],
}
}

@@ -421,9 +423,7 @@ def match_samples(
num_threads=None,
):
run_batch = samples

mu, rho = solve_num_mismatches(num_mismatches)

for k in range(2):
# To catch k mismatches we need a likelihood threshold of mu**k
likelihood_threshold = mu**k - 1e-15
@@ -486,8 +486,9 @@ def check_base_ts(ts):
md = ts.metadata
assert "sc2ts" in md
sc2ts_md = md["sc2ts"]
assert "date" in sc2ts_md
assert len(sc2ts_md["samples_strain"]) == ts.num_samples
# Avoid parsing the metadata again to get the date.
return sc2ts_md["date"]


def preprocess_worker(strains, alignment_store_path, keep_sites):
Expand Down Expand Up @@ -516,6 +517,8 @@ def preprocess(
show_progress=False,
num_workers=0,
):
if len(strains) == 0:
return []
num_workers = max(1, num_workers)
splits = min(len(strains), 2 * num_workers)
work = np.array_split(strains, splits)
@@ -545,6 +548,8 @@ def extend(
hmm_cost_threshold=None,
min_group_size=None,
min_root_mutations=None,
min_different_dates=None,
max_mutations_per_sample=None,
deletions_as_missing=None,
max_daily_samples=None,
show_progress=False,
@@ -561,17 +566,21 @@
min_group_size = 10
if min_root_mutations is None:
min_root_mutations = 2
if max_mutations_per_sample is None:
max_mutations_per_sample = 100
if min_different_dates is None:
min_different_dates = 3
if retrospective_window is None:
retrospective_window = 30
if max_missing_sites is None:
max_missing_sites = np.inf
if deletions_as_missing is None:
deletions_as_missing = False

check_base_ts(base_ts)
previous_date = check_base_ts(base_ts)
logger.info(
f"Extend {date}; ts:nodes={base_ts.num_nodes};samples={base_ts.num_samples};"
f"mutations={base_ts.num_mutations};date={base_ts.metadata['sc2ts']['date']}"
f"mutations={base_ts.num_mutations};date={previous_date}"
)

metadata_matches = {md["strain"]: md for md in metadata_db.get(date)}
@@ -619,41 +628,38 @@
logger.info(f"Subset from {len(samples)} to {max_daily_samples}")
samples = rng.sample(samples, max_daily_samples)

if len(samples) == 0:
logger.warning(f"Nothing to do for {date}")
return base_ts

logger.info(
f"Got alignments for {len(samples)} of {len(metadata_matches)} in metadata"
)
ts = increment_time(date, base_ts)
if len(samples) > 0:
logger.info(
f"Got alignments for {len(samples)} of {len(metadata_matches)} in metadata"
)

samples = match_samples(
date,
samples,
base_ts=base_ts,
num_mismatches=num_mismatches,
deletions_as_missing=deletions_as_missing,
show_progress=show_progress,
num_threads=num_threads,
)
samples = match_samples(
date,
samples,
base_ts=base_ts,
num_mismatches=num_mismatches,
deletions_as_missing=deletions_as_missing,
show_progress=show_progress,
num_threads=num_threads,
)

match_db.add(samples, date, num_mismatches)
match_db.create_mask_table(base_ts)
ts = increment_time(date, base_ts)
match_db.add(samples, date, num_mismatches)
match_db.create_mask_table(base_ts)

ts = add_exact_matches(ts=ts, match_db=match_db, date=date)
ts = add_exact_matches(ts=ts, match_db=match_db, date=date)

logger.info(f"Update ARG with low-cost samples for {date}")
ts, _ = add_matching_results(
f"match_date=='{date}' and hmm_cost>0 and hmm_cost<={hmm_cost_threshold}",
ts=ts,
match_db=match_db,
date=date,
min_group_size=1,
additional_node_flags=core.NODE_IN_SAMPLE_GROUP,
show_progress=show_progress,
phase="close",
)
logger.info(f"Update ARG with low-cost samples for {date}")
ts, _ = add_matching_results(
f"match_date=='{date}' and hmm_cost>0 and hmm_cost<={hmm_cost_threshold}",
ts=ts,
match_db=match_db,
date=date,
min_group_size=1,
additional_node_flags=core.NODE_IN_SAMPLE_GROUP,
show_progress=show_progress,
phase="close",
)

logger.info("Looking for retrospective matches")
assert min_group_size is not None
@@ -664,18 +670,19 @@
match_db=match_db,
date=date,
min_group_size=min_group_size,
min_different_dates=3, # TODO parametrise
min_different_dates=min_different_dates,
min_root_mutations=min_root_mutations,
max_mutations_per_sample=max_mutations_per_sample,
additional_node_flags=core.NODE_IN_RETROSPECTIVE_SAMPLE_GROUP,
show_progress=show_progress,
phase="retro",
)
for group in groups:
logger.warning(f"Add retro group {dict(group.pango_count)}")
return update_top_level_metadata(ts, date)
logger.warning(f"Add retro group {dict(group.pango_count)}: {group.tree_quality_metrics.summary()}")
return update_top_level_metadata(ts, date, groups)


def update_top_level_metadata(ts, date):
def update_top_level_metadata(ts, date, retro_groups):
tables = ts.dump_tables()
md = tables.metadata
md["sc2ts"]["date"] = date
Expand All @@ -685,6 +692,12 @@ def update_top_level_metadata(ts, date):
node = ts.node(u)
samples_strain.append(node.metadata["strain"])
md["sc2ts"]["samples_strain"] = samples_strain
existing_retro_groups = md["sc2ts"].get("retro_groups", [])
for group in retro_groups:
d = group.tree_quality_metrics.asdict()
d["group_id"] = group.sample_hash
existing_retro_groups.append(d)
md["sc2ts"]["retro_groups"] = existing_retro_groups
tables.metadata = md
return tables.tree_sequence()

@@ -779,6 +792,41 @@ def add_exact_matches(match_db, ts, date):
return tables.tree_sequence()


@dataclasses.dataclass
class GroupTreeQualityMetrics:
"""
    Set of metrics used to assess the quality of an inferred sample group tree.
"""
strains: List[str]
pango_lineages: List[str]
dates: List[str]
num_nodes: int
num_root_mutations: int
num_mutations: int
num_recurrent_mutations: int
depth: int

def asdict(self):
return dataclasses.asdict(self)

@property
def num_samples(self):
return len(self.strains)

@property
def mean_mutations_per_sample(self):
return self.num_mutations / self.num_samples

def summary(self):
return (
f"samples={self.num_samples} "
f"depth={self.depth} total_muts={self.num_mutations} "
f"root_muts={self.num_root_mutations} "
f"muts_per_sample={self.mean_mutations_per_sample} "
f"recurrent_muts={self.num_recurrent_mutations} "
)


@dataclasses.dataclass
class SampleGroup:
"""
@@ -791,6 +839,7 @@ class SampleGroup:
immediate_reversions: List = None
additional_keys: Dict = None
sample_hash: str = None
tree_quality_metrics: GroupTreeQualityMetrics = None

def __post_init__(self):
m = hashlib.md5()
@@ -827,6 +876,21 @@ def summary(self):
f"strains={self.strains}"
)

def add_tree_quality_metrics(self, ts):
tree = ts.first()
assert ts.num_trees == 1
self.tree_quality_metrics = GroupTreeQualityMetrics(
strains=self.strains,
pango_lineages=[s.pango for s in self.samples],
dates=[s.date for s in self.samples],
num_nodes=ts.num_nodes,
num_mutations=ts.num_mutations,
num_root_mutations=int(np.sum(ts.mutations_node == tree.root)),
num_recurrent_mutations=int(np.sum(ts.mutations_parent != -1)),
depth=max(tree.depth(u) for u in ts.samples()),
)
return self.tree_quality_metrics


def add_matching_results(
where_clause,
@@ -836,6 +900,7 @@
min_group_size=1,
min_different_dates=1,
min_root_mutations=0,
max_mutations_per_sample=np.inf,
additional_node_flags=None,
show_progress=False,
additional_group_metadata_keys=list(),
@@ -899,22 +964,23 @@
binary_ts = tree_ops.infer_binary(flat_ts)
poly_ts = tree_ops.trim_branches(binary_ts)
assert poly_ts.num_samples == flat_ts.num_samples
tree = poly_ts.first()
num_root_mutations = np.sum(poly_ts.mutations_node == tree.root)
num_recurrent_mutations = np.sum(poly_ts.mutations_parent != -1)
if num_root_mutations < min_root_mutations:
tqm = group.add_tree_quality_metrics(poly_ts)
if tqm.num_root_mutations < min_root_mutations:
logger.debug(
f"Skipping root_mutations={tqm.num_root_mutations}: "
f"{group.summary()}"
)
continue
if tqm.mean_mutations_per_sample > max_mutations_per_sample:
logger.debug(
f"Skipping root_mutations={num_root_mutations}: "
f"Skipping mutation_per_sample={tqm.mutations_per_sample}: exceeds threshold "
f"{group.summary()}"
)
continue
attach_depth = max(tree.depth(u) for u in poly_ts.samples())
nodes = attach_tree(ts, tables, group, poly_ts, date, additional_node_flags)
logger.debug(
f"Attach {phase} "
f"depth={attach_depth} total_muts={poly_ts.num_mutations} "
f"root_muts={num_root_mutations} "
f"recurrent_muts={num_recurrent_mutations} attach_nodes={len(nodes)} "
f"Attach {phase} metrics:{tqm.summary()}"
f"attach_nodes={len(nodes)} "
f"group={group.summary()}"
)
attach_nodes.extend(nodes)
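Since update_top_level_metadata now appends each accepted group's metrics to the retro_groups list, they can be read back from the top-level metadata of a saved tree sequence. A hedged sketch, assuming JSON-decoded metadata and a hypothetical output filename:

import tskit

ts = tskit.load("sc2ts_output.ts")  # hypothetical filename
for group in ts.metadata["sc2ts"]["retro_groups"]:
    # Each entry is GroupTreeQualityMetrics.asdict() plus the
    # "group_id" sample hash added in update_top_level_metadata.
    mean = group["num_mutations"] / len(group["strains"])
    print(group["group_id"], len(group["strains"]), round(mean, 2))
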
sc2ts/tree_ops.py (4 changes: 4 additions & 0 deletions)
@@ -216,6 +216,7 @@ def infer_binary(ts):
node=mut.node,
derived_state=mut.derived_state,
)
tables.compute_mutation_parents()
new_ts = tables.tree_sequence()
# print(new_ts.draw_text())
return new_ts
@@ -245,6 +246,9 @@ def trim_branches(ts):
tables.edges.add_row(0, ts.sequence_length, parent=p, child=c)

tables.sort()
# FIXME not sure this compute_mutation_parents is needed, check
tables.build_index()
tables.compute_mutation_parents()
# Get rid of unreferenced nodes
tables.simplify()
return tables.tree_sequence()
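The build_index/compute_mutation_parents pair matters because mutations_parent is what the new num_recurrent_mutations metric reads: once edges are rewritten, each mutation's parent mutation has to be recomputed from the new topology. A toy single-site illustration of the tskit tables API (not sc2ts data):

import tskit

tables = tskit.TableCollection(sequence_length=1)
child = tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
parent = tables.nodes.add_row(time=1)
tables.edges.add_row(0, 1, parent=parent, child=child)
site = tables.sites.add_row(position=0, ancestral_state="A")
# Two mutations at one site on the same lineage: the one on the
# child node should end up pointing at the one above it.
tables.mutations.add_row(site=site, node=parent, derived_state="T")
tables.mutations.add_row(site=site, node=child, derived_state="A")
tables.sort()
tables.build_index()  # edge index is required by the next call
tables.compute_mutation_parents()
assert tables.mutations[1].parent == 0
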