diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index 7b169e8..0e9514b 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -26,6 +26,7 @@ def single_tree_example_ts(): tables.mutations.add_row(site=j, derived_state="T", node=j) tables.sites.add_row(position=7, ancestral_state="FOOBAR") tables.mutations.add_row(site=6, derived_state="FOOBARD", node=6) + tables.compute_mutation_times() return tables.tree_sequence() @@ -299,6 +300,23 @@ def test_sample_and_other_flags(self, bit): assert (node.flags & (1 << bit)) != 0 +def test_preprocess_calculate_mutation_times(tmpdir): + ts = msprime.sim_ancestry(5, sequence_length=1e4, random_seed=42) + ts = msprime.sim_mutations(ts, rate=0.1, random_seed=43) + assert ts.num_mutations > 0 + # Wipe out mutation times + tables = ts.dump_tables() + tables.mutations.time = np.full_like(tables.mutations.time, tskit.UNKNOWN_TIME) + ts = tables.tree_sequence() + input_path = os.path.join(tmpdir, "test.trees") + ts.dump(input_path) + output_path = os.path.join(tmpdir, "test.trees") + with pytest.warns(UserWarning, match="All mutation times are unknown"): + preprocess.preprocess(input_path, output_path) + ts = tszip.load(output_path) + assert not np.any(tskit.is_unknown_time(ts.tables.mutations.time)) + + @pytest.mark.parametrize("use_tszip", [True, False]) def test_preprocess(tmpdir, use_tszip): input_path = os.path.join(tmpdir, "test_input.tszip") diff --git a/tsbrowse/preprocess.py b/tsbrowse/preprocess.py index fd7672b..26c9d28 100644 --- a/tsbrowse/preprocess.py +++ b/tsbrowse/preprocess.py @@ -1,6 +1,7 @@ import dataclasses import json import pathlib +import warnings import daiquiri import numba @@ -514,6 +515,17 @@ def preprocess(tszip_path, output_path, show_progress=False): total=2 + len(preprocessors), desc="Processing", disable=not show_progress ) as pbar: ts = tszip.load(tszip_path) + # Check if all mutation times are unknown, calulate them if so + if np.all(tskit.is_unknown_time(ts.mutations_time)): + warnings.warn( + "All mutation times are unknown. Calculating mutation times" + " from tree sequence", + stacklevel=1, + ) + tables = ts.dump_tables() + tables.compute_mutation_times() + ts = tables.tree_sequence() + pbar.update(1) tszip.compress(ts, output_path)