diff --git a/README.md b/README.md index 56a8e99..6ac028f 100644 --- a/README.md +++ b/README.md @@ -8,13 +8,19 @@ It is particularly useful to help evaluate ARGs that have been inferred using to [KwARG](https://github.com/a-ignatieva/kwarg), [Threads](https://pypi.org/project/threads-arg/), etc. +To view a tskit tree sequence or tszip file first pre-process it: + +`python -m tsbrowse preprocess /path/to/trees-file` + +This will write a `.tsbrowse` file + To launch the app use: -`python -m tsbrowse /path/to/trees-file` +`python -m tsbrowse serve /path/to/tsbrowse-file` On WSL, it may be necessary to disable Numba's CUDA support: -`NUMBA_DISABLE_CUDA=1 python -m tsbrowse /path/to/trees-file` +`NUMBA_DISABLE_CUDA=1 python -m tsbrowse serve /path/to/tsbrowse-file` ## Installation diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..22e1631 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,31 @@ +import os + +import tszip +from click.testing import CliRunner + +from . import test_preprocess +from tsbrowse import __main__ as main + + +def test_preprocess_cli(tmpdir): + tszip_path = os.path.join(tmpdir, "test_input.tszip") + default_output_path = os.path.join(tmpdir, "test_input.tsbrowse") + custom_output_path = os.path.join(tmpdir, "custom_input.tsbrowse") + + ts = test_preprocess.single_tree_example_ts() + tszip.compress(ts, tszip_path) + + runner = CliRunner() + result = runner.invoke(main.cli, ["preprocess", tszip_path]) + assert result.exit_code == 0 + assert os.path.exists(default_output_path) + tszip.load(default_output_path).tables.assert_equals(ts.tables) + + result = runner.invoke( + main.cli, ["preprocess", tszip_path, "--output", custom_output_path] + ) + assert result.exit_code == 0 + assert os.path.exists(custom_output_path) + tszip.load(custom_output_path).tables.assert_equals(ts.tables) + + # TODO Load into model and check that the model is correct diff --git a/tests/test_data_model.py b/tests/test_data_model.py index 8d89f00..1bc7d90 100644 --- a/tests/test_data_model.py +++ b/tests/test_data_model.py @@ -1,327 +1,52 @@ -import logging - import msprime -import numpy as np -import numpy.testing as nt import pytest import tskit +import tszip +import zarr from tsbrowse import model +from tsbrowse import preprocess -def single_tree_example_ts(): - # 2.00┊ 6 ┊ - # ┊ ┏━┻━┓ ┊ - # 1.00┊ 4 5 ┊ - # ┊ ┏┻┓ ┏┻┓ ┊ - # 0.00┊ 0 1 2 3 ┊ - # 0 10 - ts = tskit.Tree.generate_balanced(4, span=10).tree_sequence - tables = ts.dump_tables() - for j in range(6): - tables.sites.add_row(position=j + 1, ancestral_state="A") - tables.mutations.add_row(site=j, derived_state="T", node=j) - tables.sites.add_row(position=7, ancestral_state="FOOBAR") - tables.mutations.add_row(site=6, derived_state="FOOBARD", node=6) - return tables.tree_sequence() - - -def single_tree_recurrent_mutation_example_ts(): - # 2.00 ┊ 6 ┊ - # ┊ ┏━━━━━━━┻━━━━━━━┓ ┊ - # ┊ 4:A→T x x 5:A→T ┊ - # ┊ | x 6:A→G ┊ - # 1.00 ┊ 4 5 ┊ - # ┊ ┏━━━━┻━━━━┓ ┏━━━━┻━━━━┓ ┊ - # ┊ 0:A→T x 1:A→T x x 2:A→T x 3:A→T ┊ - # ┊ | | | | ┊ - # 0.00 ┊ 0 1 2 3 ┊ - # 0 10 - ts = tskit.Tree.generate_balanced(4, span=10).tree_sequence - tables = ts.dump_tables() - for j in range(6): - tables.sites.add_row(position=j + 1, ancestral_state="A") - tables.mutations.add_row(site=j, derived_state="T", node=j) - tables.mutations.add_row(site=j, derived_state="G", node=j, parent=j) +def test_model(tmpdir): + ts = msprime.sim_ancestry( + recombination_rate=1e-3, samples=10, sequence_length=1_000, random_seed=42 + ) + ts = msprime.sim_mutations(ts, 
rate=1e-2, random_seed=43) + tables = ts.tables + tables.nodes.metadata_schema = tskit.MetadataSchema({"codec": "json"}) ts = tables.tree_sequence() - return tables.tree_sequence() - - -def multiple_trees_example_ts(): - # 2.00┊ 4 ┊ 4 ┊ - # ┊ ┏━┻┓ ┊ ┏┻━┓ ┊ - # 1.00┊ ┃ 3 ┊ 3 ┃ ┊ - # ┊ ┃ ┏┻┓ ┊ ┏┻┓ ┃ ┊ - # 0.00┊ 0 1 2 ┊ 0 1 2 ┊ - # 0 5 10 - ts = tskit.Tree.generate_balanced(3, span=10).tree_sequence - tables = ts.dump_tables() - tables.edges[1] = tables.edges[1].replace(right=5) - tables.edges[2] = tables.edges[2].replace(right=5) - tables.edges.add_row(5, 10, 3, 0) - tables.edges.add_row(5, 10, 4, 2) - tables.sort() - return tables.tree_sequence() - - -def single_tree_with_polytomies_example_ts(): - # 3.00┊ 8 ┊ - # ┊ ┏━━━━━━╋━━━━━━━┓ ┊ - # 2.00┊ ┃ 7 ┃ ┊ - # ┊ ┃ ┏━━━╋━━━━┓ ┃ ┊ - # 1.00┊ 5 ┃ 6 ┃ ┃ ┊ - # ┊ ┏┻┓ ┃ ┏━╋━━┓ ┃ ┃ ┊ - # 0.00┊ 0 1 2 3 4 11 9 10 ┊ - # 0 10 - ts = tskit.Tree.generate_balanced(5, span=10).tree_sequence - tables = ts.dump_tables() - tables.nodes.add_row(flags=1, time=0) - tables.edges.add_row(0, 10, 7, 9) - tables.nodes.add_row(flags=1, time=0) - tables.edges.add_row(0, 10, 8, 10) - tables.nodes.add_row(flags=1, time=0) - tables.edges.add_row(0, 10, 6, 11) - tables.sort() - return tables.tree_sequence() - - -def multi_tree_with_polytomies_example_ts(): - # 3.00┊ 8 ┊ 8 ┊ - # ┊ ┏━━┻━┓ ┊ ┏━━┻━━┓ ┊ - # 2.00┊ ┃ 7 ┊ ┃ 7 ┊ - # ┊ ┃ ┏━┻━┓ ┊ ┃ ┏━━╋━━┓ ┊ - # 1.00┊ 5 ┃ 6 ┊ 5 ┃ 6 ┃ ┊ - # ┊ ┏┻┓ ┃ ┏━╋━┓ ┊ ┏┻┓ ┃ ┏┻┓ ┃ ┊ - # 0.00┊ 0 1 2 3 4 9 ┊ 0 1 2 3 4 9 ┊ - # 0 5 10 - ts = tskit.Tree.generate_balanced(5, span=10).tree_sequence - tables = ts.dump_tables() - tables.nodes.add_row(flags=1, time=0) - tables.edges.add_row(0, 5, 6, 9) - tables.edges.add_row(5, 10, 7, 9) - tables.sort() - return tables.tree_sequence() - - -class TestMutationDataTable: - def test_single_tree_example(self): - ts = single_tree_example_ts() - tsm = model.TSModel(ts) - df = tsm.mutations_df - assert len(df) == 7 - nt.assert_array_equal(df.id, list(range(7))) - nt.assert_array_equal(df.node, list(range(7))) - nt.assert_array_equal(df.position, list(range(1, 8))) - nt.assert_array_equal(df.time, [0, 0, 0, 0, 1, 1, 2]) - nt.assert_array_equal(df.derived_state, ["T"] * 6 + ["FOOBARD"]) - nt.assert_array_equal(df.inherited_state, ["A"] * 6 + ["FOOBAR"]) - nt.assert_array_equal(df.num_parents, [0] * 7) - nt.assert_array_equal(df.num_descendants, [1] * 4 + [2] * 2 + [4]) - nt.assert_array_equal(df.num_inheritors, [1] * 4 + [2] * 2 + [4]) - - def test_single_tree_recurrent_mutation_example(self): - ts = single_tree_recurrent_mutation_example_ts() - tsm = model.TSModel(ts) - df = tsm.mutations_df - assert len(df) == 7 - nt.assert_array_equal(df.id, list(range(7))) - nt.assert_array_equal(df.node, [0, 1, 2, 3, 4, 5, 5]) - nt.assert_array_equal(df.position, [1, 2, 3, 4, 5, 6, 6]) - nt.assert_array_equal(df.time, [0, 0, 0, 0, 1, 1, 1]) - nt.assert_array_equal(df.derived_state, ["T"] * 6 + ["G"]) - nt.assert_array_equal(df.inherited_state, ["A"] * 6 + ["T"]) - nt.assert_array_equal(df.num_parents, [0] * 6 + [1]) - nt.assert_array_equal(df.num_descendants, [1] * 4 + [2] * 3) - nt.assert_array_equal(df.num_inheritors, [1] * 4 + [2, 0, 2]) - - -class TestEdgeDataTable: - def test_single_tree_example(self): - ts = single_tree_example_ts() - tsm = model.TSModel(ts) - df = tsm.edges_df - assert len(df) == 6 - nt.assert_array_equal(df.left, [0, 0, 0, 0, 0, 0]) - nt.assert_array_equal(df.right, [10, 10, 10, 10, 10, 10]) - nt.assert_array_equal(df.parent, [4, 4, 5, 5, 6, 6]) - nt.assert_array_equal(df.child, [0, 1, 2, 3, 4, 5]) - 
nt.assert_array_equal(df.child_time, [0, 0, 0, 0, 1, 1]) - nt.assert_array_equal(df.parent_time, [1, 1, 1, 1, 2, 2]) - - def test_multiple_trees_example(self): - ts = multiple_trees_example_ts() - tsm = model.TSModel(ts) - df = tsm.edges_df - assert len(df) == 6 - nt.assert_array_equal(df.left, [5, 0, 0, 0, 5, 0]) - nt.assert_array_equal(df.right, [10, 10, 5, 5, 10, 10]) - nt.assert_array_equal(df.parent, [3, 3, 3, 4, 4, 4]) - nt.assert_array_equal(df.child, [0, 1, 2, 0, 2, 3]) - nt.assert_array_equal(df.child_time, [0, 0, 0, 0, 0, 1]) - nt.assert_array_equal(df.parent_time, [1, 1, 1, 2, 2, 2]) - - -class TestNodeDataTable: - def test_single_tree_example(self): - ts = single_tree_example_ts() - tsm = model.TSModel(ts) - df = tsm.nodes_df - assert len(df) == 7 - nt.assert_array_equal(df.time, [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 2.0]) - nt.assert_array_equal(df.num_mutations, [1, 1, 1, 1, 1, 1, 1]) - nt.assert_array_equal(df.ancestors_span, [10, 10, 10, 10, 10, 10, -np.inf]) - nt.assert_array_equal(df.node_flags, [1, 1, 1, 1, 0, 0, 0]) - - def test_multiple_tree_example(self): - ts = multiple_trees_example_ts() - tsm = model.TSModel(ts) - df = tsm.nodes_df - assert len(df) == 5 - nt.assert_array_equal(df.time, [0.0, 0.0, 0.0, 1.0, 2.0]) - nt.assert_array_equal(df.num_mutations, [0, 0, 0, 0, 0]) - nt.assert_array_equal(df.ancestors_span, [10, 10, 10, 10, -np.inf]) - nt.assert_array_equal(df.node_flags, [1, 1, 1, 0, 0]) - - -def compute_mutation_counts(ts): - pop_mutation_count = np.zeros((ts.num_populations, ts.num_mutations), dtype=int) - for pop in ts.populations(): - for tree in ts.trees(tracked_samples=ts.samples(population=pop.id)): - for mut in tree.mutations(): - count = tree.num_tracked_samples(mut.node) - pop_mutation_count[pop.id, mut.id] = count - return pop_mutation_count - - -class TestMutationFrequencies: - def example_ts(self): - demography = msprime.Demography() - demography.add_population(name="A", initial_size=10_000) - demography.add_population(name="B", initial_size=5_000) - demography.add_population(name="C", initial_size=1_000) - demography.add_population_split(time=1000, derived=["A", "B"], ancestral="C") - return msprime.sim_ancestry( - samples={"A": 1, "B": 1}, - demography=demography, - random_seed=12, - sequence_length=10_000, - ) - - def check_ts(self, ts): - C1 = compute_mutation_counts(ts) - C2 = model.compute_population_mutation_counts(ts) - nt.assert_array_equal(C1, C2) - tsm = model.TSModel(ts) - df = tsm.mutations_df - nt.assert_array_equal(df["pop_A_freq"], C1[0] / ts.num_samples) - nt.assert_array_equal(df["pop_B_freq"], C1[1] / ts.num_samples) - nt.assert_array_equal(df["pop_C_freq"], C1[2] / ts.num_samples) - - def test_all_nodes(self): - ts = self.example_ts() - tables = ts.dump_tables() - for u in range(ts.num_nodes - 1): - site_id = tables.sites.add_row(u, "A") - tables.mutations.add_row(site=site_id, node=u, derived_state="T") - ts = tables.tree_sequence() - self.check_ts(ts) - - @pytest.mark.parametrize("seed", range(1, 7)) - def test_simulated_mutations(self, seed): - ts = msprime.sim_mutations(self.example_ts(), rate=1e-6, random_seed=seed) - assert ts.num_mutations > 0 - self.check_ts(ts) - - def test_no_metadata_schema(self): - ts = msprime.sim_mutations(self.example_ts(), rate=1e-6, random_seed=43) - assert ts.num_mutations > 0 - tables = ts.dump_tables() - tables.populations.metadata_schema = tskit.MetadataSchema(None) - self.check_ts(tables.tree_sequence()) - - def test_no_populations(self): - tables = single_tree_example_ts().dump_tables() - 
tables.populations.add_row(b"{}") - tsm = model.TSModel(tables.tree_sequence()) - with pytest.raises(ValueError, match="must be assigned to populations"): - tsm.mutations_df - - -class TestNodeIsSample: - def test_simple_example(self): - ts = single_tree_example_ts() - is_sample = model.node_is_sample(ts) - for node in ts.nodes(): - assert node.is_sample() == is_sample[node.id] - - @pytest.mark.parametrize("bit", [1, 2, 17, 31]) - def test_sample_and_other_flags(self, bit): - tables = single_tree_example_ts().dump_tables() - flags = tables.nodes.flags - tables.nodes.flags = flags | (1 << bit) - ts = tables.tree_sequence() - is_sample = model.node_is_sample(ts) - for node in ts.nodes(): - assert node.is_sample() == is_sample[node.id] - assert (node.flags & (1 << bit)) != 0 - - -class TestTreesDataTable: - def test_single_tree_example(self): - ts = single_tree_example_ts() - tsm = model.TSModel(ts) - df = tsm.trees_df - assert len(df) == 1 - nt.assert_array_equal(df.left, 0) - nt.assert_array_equal(df.right, 10) - nt.assert_array_equal(df.total_branch_length, 6.0) - # nt.assert_array_equal(df.mean_internal_arity, 2.0) - nt.assert_array_equal(df.max_internal_arity, 2.0) - - def test_single_tree_with_polytomies_example(self): - ts = single_tree_with_polytomies_example_ts() - tsm = model.TSModel(ts) - df = tsm.trees_df - assert len(df) == 1 - nt.assert_array_equal(df.left, 0) - nt.assert_array_equal(df.right, 10) - nt.assert_array_equal(df.total_branch_length, 16.0) - # nt.assert_array_equal(df.mean_internal_arity, 2.75) - nt.assert_array_equal(df.max_internal_arity, 3.0) - - def test_multi_tree_with_polytomies_example(self): - ts = multi_tree_with_polytomies_example_ts() - tsm = model.TSModel(ts) - df = tsm.trees_df - assert len(df) == 2 - nt.assert_array_equal(df.left, [0, 5]) - nt.assert_array_equal(df.right, [5, 10]) - nt.assert_array_equal(df.total_branch_length, [11.0, 12.0]) - # nt.assert_array_equal(df.mean_internal_arity, [2.25, 2.25]) - nt.assert_array_equal(df.max_internal_arity, [3.0, 3.0]) - - -def test_cache(caplog, tmpdir): - caplog.set_level(logging.INFO) - ts = multiple_trees_example_ts() - tsm = model.TSModel(ts) - # Use the logging out put to determine cache usage - t1 = tsm.trees_df - t2 = tsm.trees_df - assert t1.equals(t2) - assert "No uuid, not caching trees_df" in caplog.text - - ts.dump(tmpdir / "cache.trees") - ts = tskit.load(tmpdir / "cache.trees") - tsm = model.TSModel(ts) - # Use the logging out put to determine cache usage - caplog.clear() - t1 = tsm.trees_df - assert "Calculating" in caplog.text - caplog.clear() - ts2 = tskit.load(tmpdir / "cache.trees") - tsm2 = model.TSModel(ts2) - t2 = tsm2.trees_df - assert "Fetching" in caplog.text - assert t1.equals(t2) + tszip.compress(ts, tmpdir / "test.tszip") + preprocess.preprocess(tmpdir / "test.tszip", tmpdir / "test.tsbrowse") + tsm = model.TSModel(tmpdir / "test.tsbrowse") + + assert tsm.ts == ts + assert tsm.name == "test" + assert tsm.file_uuid == ts.file_uuid + assert len(tsm.summary_df) == 9 + assert len(tsm.edges_df) == ts.num_edges + assert len(tsm.trees_df) == ts.num_trees + assert len(tsm.mutations_df) == ts.num_mutations + assert len(tsm.nodes_df) == ts.num_nodes + assert len(tsm.sites_df) == ts.num_sites + + +def test_model_errors(tmpdir): + # Write an empty zarr ZipStore + with zarr.ZipStore(tmpdir / "test.tsbrowse", mode="w") as z: + g = zarr.group(store=z) + g.attrs["foo"] = "bar" + with pytest.raises(ValueError, match="File is not a tsbrowse file"): + model.TSModel(tmpdir / "test.tsbrowse") + + ts = 
msprime.sim_ancestry( + recombination_rate=1e-3, samples=2, sequence_length=1000, random_seed=42 + ) + tszip.compress(ts, tmpdir / "test.tszip") + preprocess.preprocess(tmpdir / "test.tszip", tmpdir / "test.tsbrowse") + with zarr.ZipStore(tmpdir / "test.tsbrowse", mode="w") as z: + g = zarr.group(store=z) + g.attrs["tsbrowse"] = {"data_version": 0} + with pytest.raises(ValueError, match="File .* has version .*"): + model.TSModel(tmpdir / "test.tsbrowse") diff --git a/tests/test_pages.py b/tests/test_pages.py index 6492f8d..060e4d8 100644 --- a/tests/test_pages.py +++ b/tests/test_pages.py @@ -1,19 +1,21 @@ import panel import pytest import tskit +import tszip -from tests import test_data_model +from tests import test_preprocess from tsbrowse import model from tsbrowse import pages +from tsbrowse import preprocess # TODO give these some pytest metadata so they are named. examples = [ # No sites tskit.Tree.generate_balanced(5).tree_sequence, - test_data_model.single_tree_example_ts(), - test_data_model.single_tree_recurrent_mutation_example_ts(), - test_data_model.multiple_trees_example_ts(), - test_data_model.single_tree_with_polytomies_example_ts(), + test_preprocess.single_tree_example_ts(), + test_preprocess.single_tree_recurrent_mutation_example_ts(), + test_preprocess.multiple_trees_example_ts(), + test_preprocess.single_tree_with_polytomies_example_ts(), ] display_pages = [ @@ -30,7 +32,9 @@ class TestPages: @pytest.mark.parametrize("ts", examples) @pytest.mark.parametrize("page", display_pages) - def test_is_panel_layout_instance(self, ts, page): - tsm = model.TSModel(ts) + def test_is_panel_layout_instance(self, ts, page, tmpdir): + tszip.compress(ts, tmpdir / "test.tszip") + preprocess.preprocess(tmpdir / "test.tszip", tmpdir / "test.tsbrowse") + tsm = model.TSModel(tmpdir / "test.tsbrowse") ui = page.page(tsm) assert isinstance(ui, panel.layout.base.Panel) diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py new file mode 100644 index 0000000..7b169e8 --- /dev/null +++ b/tests/test_preprocess.py @@ -0,0 +1,344 @@ +import os + +import msprime +import numpy as np +import numpy.testing as nt +import pytest +import tskit +import tszip +import zarr + +from tsbrowse import preprocess +from tsbrowse import TSBROWSE_DATA_VERSION + + +def single_tree_example_ts(): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 10 + ts = tskit.Tree.generate_balanced(4, span=10).tree_sequence + tables = ts.dump_tables() + for j in range(6): + tables.sites.add_row(position=j + 1, ancestral_state="A") + tables.mutations.add_row(site=j, derived_state="T", node=j) + tables.sites.add_row(position=7, ancestral_state="FOOBAR") + tables.mutations.add_row(site=6, derived_state="FOOBARD", node=6) + return tables.tree_sequence() + + +def single_tree_recurrent_mutation_example_ts(): + # 2.00 ┊ 6 ┊ + # ┊ ┏━━━━━━━┻━━━━━━━┓ ┊ + # ┊ 4:A→T x x 5:A→T ┊ + # ┊ | x 6:A→G ┊ + # 1.00 ┊ 4 5 ┊ + # ┊ ┏━━━━┻━━━━┓ ┏━━━━┻━━━━┓ ┊ + # ┊ 0:A→T x 1:A→T x x 2:A→T x 3:A→T ┊ + # ┊ | | | | ┊ + # 0.00 ┊ 0 1 2 3 ┊ + # 0 10 + ts = tskit.Tree.generate_balanced(4, span=10).tree_sequence + tables = ts.dump_tables() + for j in range(6): + tables.sites.add_row(position=j + 1, ancestral_state="A") + tables.mutations.add_row(site=j, derived_state="T", node=j) + tables.mutations.add_row(site=j, derived_state="G", node=j, parent=j) + ts = tables.tree_sequence() + return tables.tree_sequence() + + +def multiple_trees_example_ts(): + # 2.00┊ 4 ┊ 4 ┊ + # ┊ ┏━┻┓ ┊ ┏┻━┓ ┊ + # 1.00┊ ┃ 3 ┊ 3 ┃ ┊ + # ┊ ┃ ┏┻┓ ┊ ┏┻┓ 
┃ ┊ + # 0.00┊ 0 1 2 ┊ 0 1 2 ┊ + # 0 5 10 + ts = tskit.Tree.generate_balanced(3, span=10).tree_sequence + tables = ts.dump_tables() + tables.edges[1] = tables.edges[1].replace(right=5) + tables.edges[2] = tables.edges[2].replace(right=5) + tables.edges.add_row(5, 10, 3, 0) + tables.edges.add_row(5, 10, 4, 2) + tables.sort() + return tables.tree_sequence() + + +def single_tree_with_polytomies_example_ts(): + # 3.00┊ 8 ┊ + # ┊ ┏━━━━━━╋━━━━━━━┓ ┊ + # 2.00┊ ┃ 7 ┃ ┊ + # ┊ ┃ ┏━━━╋━━━━┓ ┃ ┊ + # 1.00┊ 5 ┃ 6 ┃ ┃ ┊ + # ┊ ┏┻┓ ┃ ┏━╋━━┓ ┃ ┃ ┊ + # 0.00┊ 0 1 2 3 4 11 9 10 ┊ + # 0 10 + ts = tskit.Tree.generate_balanced(5, span=10).tree_sequence + tables = ts.dump_tables() + tables.nodes.add_row(flags=1, time=0) + tables.edges.add_row(0, 10, 7, 9) + tables.nodes.add_row(flags=1, time=0) + tables.edges.add_row(0, 10, 8, 10) + tables.nodes.add_row(flags=1, time=0) + tables.edges.add_row(0, 10, 6, 11) + tables.sort() + return tables.tree_sequence() + + +def multi_tree_with_polytomies_example_ts(): + # 3.00┊ 8 ┊ 8 ┊ + # ┊ ┏━━┻━┓ ┊ ┏━━┻━━┓ ┊ + # 2.00┊ ┃ 7 ┊ ┃ 7 ┊ + # ┊ ┃ ┏━┻━┓ ┊ ┃ ┏━━╋━━┓ ┊ + # 1.00┊ 5 ┃ 6 ┊ 5 ┃ 6 ┃ ┊ + # ┊ ┏┻┓ ┃ ┏━╋━┓ ┊ ┏┻┓ ┃ ┏┻┓ ┃ ┊ + # 0.00┊ 0 1 2 3 4 9 ┊ 0 1 2 3 4 9 ┊ + # 0 5 10 + ts = tskit.Tree.generate_balanced(5, span=10).tree_sequence + tables = ts.dump_tables() + tables.nodes.add_row(flags=1, time=0) + tables.edges.add_row(0, 5, 6, 9) + tables.edges.add_row(5, 10, 7, 9) + tables.sort() + return tables.tree_sequence() + + +class TestMutationDataTable: + def test_single_tree_example(self): + ts = single_tree_example_ts() + m = preprocess.mutations(ts) + for array in m.values(): + assert len(array) == 7 + nt.assert_array_equal(m["inherited_state"], ["A"] * 6 + ["FOOBAR"]) + nt.assert_array_equal(m["num_parents"], [0] * 7) + nt.assert_array_equal(m["num_descendants"], [1] * 4 + [2] * 2 + [4]) + nt.assert_array_equal(m["num_inheritors"], [1] * 4 + [2] * 2 + [4]) + + def test_single_tree_recurrent_mutation_example(self): + ts = single_tree_recurrent_mutation_example_ts() + m = preprocess.mutations(ts) + for array in m.values(): + assert len(array) == 7 + nt.assert_array_equal(m["inherited_state"], ["A"] * 6 + ["T"]) + nt.assert_array_equal(m["num_parents"], [0] * 6 + [1]) + nt.assert_array_equal(m["num_descendants"], [1] * 4 + [2] * 3) + nt.assert_array_equal(m["num_inheritors"], [1] * 4 + [2, 0, 2]) + + +class TestNodeDataTable: + def test_single_tree_example(self): + ts = single_tree_example_ts() + n = preprocess.nodes(ts) + for array in n.values(): + assert len(array) == 7 + nt.assert_array_equal(n["num_mutations"], [1, 1, 1, 1, 1, 1, 1]) + nt.assert_array_equal(n["ancestors_span"], [10, 10, 10, 10, 10, 10, -np.inf]) + + def test_multiple_tree_example(self): + ts = multiple_trees_example_ts() + n = preprocess.nodes(ts) + for array in n.values(): + assert len(array) == 5 + nt.assert_array_equal(n["num_mutations"], [0, 0, 0, 0, 0]) + nt.assert_array_equal(n["ancestors_span"], [10, 10, 10, 10, -np.inf]) + + +class TestEdgeDataTable: + def test_single_tree_example(self): + ts = single_tree_example_ts() + e = preprocess.edges(ts) + for array in e.values(): + assert len(array) == 6 + nt.assert_array_equal(e["child_time"], [0, 0, 0, 0, 1, 1]) + nt.assert_array_equal(e["parent_time"], [1, 1, 1, 1, 2, 2]) + nt.assert_array_equal(e["branch_length"], [1, 1, 1, 1, 1, 1]) + nt.assert_array_equal(e["span"], [10, 10, 10, 10, 10, 10]) + + def test_multiple_trees_example(self): + ts = multiple_trees_example_ts() + e = preprocess.edges(ts) + for array in e.values(): + assert len(array) == 6 + 
nt.assert_array_equal(e["child_time"], [0, 0, 0, 0, 0, 1]) + nt.assert_array_equal(e["parent_time"], [1, 1, 1, 2, 2, 2]) + nt.assert_array_equal(e["branch_length"], [1, 1, 1, 2, 2, 1]) + nt.assert_array_equal(e["span"], [5, 10, 5, 5, 5, 10]) + + +class TestSiteDataTable: + def test_single_tree_example(self): + ts = single_tree_example_ts() + s = preprocess.sites(ts) + for array in s.values(): + assert len(array) == 7 + nt.assert_array_equal(s["num_mutations"], [1, 1, 1, 1, 1, 1, 1]) + + def test_single_tree_recurrent_mutation_example(self): + ts = single_tree_recurrent_mutation_example_ts() + s = preprocess.sites(ts) + for array in s.values(): + assert len(array) == 6 + nt.assert_array_equal(s["num_mutations"], [1, 1, 1, 1, 1, 2]) + + +class TestTreesDataTable: + def test_single_tree_example(self): + ts = single_tree_example_ts() + t = preprocess.trees(ts) + for array in t.values(): + assert len(array) == 1 + nt.assert_array_equal(t["left"], 0) + nt.assert_array_equal(t["right"], 10) + nt.assert_array_equal(t["total_branch_length"], 6.0) + # nt.assert_array_equal(t['mean_internal_arity'], 2.0) + nt.assert_array_equal(t["max_internal_arity"], 2.0) + + def test_single_tree_with_polytomies_example(self): + ts = single_tree_with_polytomies_example_ts() + t = preprocess.trees(ts) + for array in t.values(): + assert len(array) == 1 + nt.assert_array_equal(t["left"], 0) + nt.assert_array_equal(t["right"], 10) + nt.assert_array_equal(t["total_branch_length"], 16.0) + # nt.assert_array_equal(t['mean_internal_arity'], 2.75) + nt.assert_array_equal(t["max_internal_arity"], 3.0) + + def test_multi_tree_with_polytomies_example(self): + ts = multi_tree_with_polytomies_example_ts() + t = preprocess.trees(ts) + for array in t.values(): + assert len(array) == 2 + nt.assert_array_equal(t["left"], [0, 5]) + nt.assert_array_equal(t["right"], [5, 10]) + nt.assert_array_equal(t["total_branch_length"], [11.0, 12.0]) + # nt.assert_array_equal(t['mean_internal_arity'], [2.25, 2.25]) + nt.assert_array_equal(t["max_internal_arity"], [3.0, 3.0]) + + +class TestMutationFrequencies: + def example_ts(self): + demography = msprime.Demography() + demography.add_population(name="A", initial_size=10_000) + demography.add_population(name="B", initial_size=5_000) + demography.add_population(name="C", initial_size=1_000) + demography.add_population_split(time=1000, derived=["A", "B"], ancestral="C") + return msprime.sim_ancestry( + samples={"A": 1, "B": 1}, + demography=demography, + random_seed=12, + sequence_length=10_000, + ) + + def compute_mutation_counts(self, ts): + pop_mutation_count = np.zeros((ts.num_populations, ts.num_mutations), dtype=int) + for pop in ts.populations(): + for tree in ts.trees(tracked_samples=ts.samples(population=pop.id)): + for mut in tree.mutations(): + count = tree.num_tracked_samples(mut.node) + pop_mutation_count[pop.id, mut.id] = count + return pop_mutation_count + + def check_ts(self, ts): + C1 = self.compute_mutation_counts(ts) + C2 = preprocess.compute_population_mutation_counts(ts) + nt.assert_array_equal(C1, C2) + m = preprocess.mutations(ts) + nt.assert_array_equal(m["pop_A_freq"], C1[0] / ts.num_samples) + nt.assert_array_equal(m["pop_B_freq"], C1[1] / ts.num_samples) + nt.assert_array_equal(m["pop_C_freq"], C1[2] / ts.num_samples) + + def test_all_nodes(self): + ts = self.example_ts() + tables = ts.dump_tables() + for u in range(ts.num_nodes - 1): + site_id = tables.sites.add_row(u, "A") + tables.mutations.add_row(site=site_id, node=u, derived_state="T") + ts = tables.tree_sequence() 
+ self.check_ts(ts) + + @pytest.mark.parametrize("seed", range(1, 7)) + def test_simulated_mutations(self, seed): + ts = msprime.sim_mutations(self.example_ts(), rate=1e-6, random_seed=seed) + assert ts.num_mutations > 0 + self.check_ts(ts) + + def test_no_metadata_schema(self): + ts = msprime.sim_mutations(self.example_ts(), rate=1e-6, random_seed=43) + assert ts.num_mutations > 0 + tables = ts.dump_tables() + tables.populations.metadata_schema = tskit.MetadataSchema(None) + self.check_ts(tables.tree_sequence()) + + def test_no_populations(self): + tables = single_tree_example_ts().dump_tables() + tables.populations.add_row(b"{}") + with pytest.raises(ValueError, match="must be assigned to populations"): + preprocess.mutations(tables.tree_sequence()) + + +class TestNodeIsSample: + def test_simple_example(self): + ts = single_tree_example_ts() + is_sample = preprocess.node_is_sample(ts) + for node in ts.nodes(): + assert node.is_sample() == is_sample[node.id] + + @pytest.mark.parametrize("bit", [1, 2, 17, 31]) + def test_sample_and_other_flags(self, bit): + tables = single_tree_example_ts().dump_tables() + flags = tables.nodes.flags + tables.nodes.flags = flags | (1 << bit) + ts = tables.tree_sequence() + is_sample = preprocess.node_is_sample(ts) + for node in ts.nodes(): + assert node.is_sample() == is_sample[node.id] + assert (node.flags & (1 << bit)) != 0 + + +@pytest.mark.parametrize("use_tszip", [True, False]) +def test_preprocess(tmpdir, use_tszip): + input_path = os.path.join(tmpdir, "test_input.tszip") + output_path = os.path.join(tmpdir, "test_output.tsbrowse") + + ts = single_tree_example_ts() + if use_tszip: + tszip.compress(ts, input_path) + else: + ts.dump(input_path) + + preprocess.preprocess(input_path, output_path) + + assert os.path.exists(output_path) + # Check that the output is still a valid tszip + tszip.load(output_path).tables.assert_equals(ts.tables) + + # Check that the file contains the expected arrays + with zarr.ZipStore(output_path, mode="r") as zarr_store: + root = zarr.group(store=zarr_store) + assert root.attrs["tsbrowse"]["data_version"] == TSBROWSE_DATA_VERSION + for array_name in [ + "mutations/position", + "mutations/inherited_state", + "mutations/num_descendants", + "mutations/num_inheritors", + "mutations/num_parents", + "nodes/num_mutations", + "nodes/ancestors_span", + "trees/left", + "trees/right", + "trees/total_branch_length", + "trees/mean_internal_arity", + "trees/max_internal_arity", + "trees/num_sites", + "trees/num_mutations", + "edges/parent_time", + "edges/child_time", + "edges/branch_length", + "edges/span", + "sites/num_mutations", + ]: + assert array_name in root diff --git a/tests/test_raster.py b/tests/test_raster.py index a09adb7..3bf62d2 100644 --- a/tests/test_raster.py +++ b/tests/test_raster.py @@ -1,7 +1,9 @@ import msprime import pytest +import tszip from tsbrowse import pages +from tsbrowse import preprocess from tsbrowse import raster @@ -17,11 +19,12 @@ def ts(): class TestRaster: @pytest.mark.parametrize("page", pages.PAGES_MAP.values()) - def test_mutation_scatter(self, page, ts, tmp_path): - print(ts) + def test_mutation_scatter(self, page, ts, tmp_path, tmpdir): + tszip.compress(ts, tmpdir / "test.tszip") + preprocess.preprocess(tmpdir / "test.tszip", tmpdir / "test.tsbrowse") raster.raster_component( page.page, - ts, + tmpdir / "test.tsbrowse", tmp_path / "image.png", ) assert (tmp_path / "image.png").exists() diff --git a/tsbrowse/__init__.py b/tsbrowse/__init__.py index 0371aec..8ad60f8 100644 --- 
a/tsbrowse/__init__.py +++ b/tsbrowse/__init__.py @@ -5,3 +5,5 @@ __version__ = _version.version except ImportError: pass + +TSBROWSE_DATA_VERSION = "1" diff --git a/tsbrowse/__main__.py b/tsbrowse/__main__.py index 089d71c..e9ee7bb 100644 --- a/tsbrowse/__main__.py +++ b/tsbrowse/__main__.py @@ -1,12 +1,10 @@ -import pathlib import time import traceback +from pathlib import Path import click import daiquiri import holoviews as hv -import tskit -import tszip from holoviews import opts # Need to import daiquiri and set up logging before importing panel @@ -17,21 +15,11 @@ from . import model # noqa from . import pages # noqa from . import config # noqa +from . import preprocess as preprocess_ # noqa logger = daiquiri.getLogger("tsbrowse") -def load_data(path): - logger.info(f"Loading {path}") - try: - ts = tskit.load(path) - except tskit.FileFormatError: - ts = tszip.decompress(path) - - tsm = model.TSModel(ts, path.name) - return tsm - - def get_app(tsm): pn.extension(sizing_mode="stretch_width") pn.extension("tabulator") @@ -124,7 +112,13 @@ def setup_logging(log_level, no_log_filter): logger.setLevel("CRITICAL") -@click.command() +@click.group() +def cli(): + """Command line interface for tsbrowse.""" + pass + + +@cli.command() @click.argument("path", type=click.Path(exists=True, dir_okay=False)) @click.option("--annotations-file", type=click.Path(exists=True, dir_okay=False)) @click.option("--port", default=8080, help="Port to serve on") @@ -140,13 +134,13 @@ def setup_logging(log_level, no_log_filter): is_flag=True, help="Do not filter the output log (advanced debugging only)", ) -def main(path, port, show, log_level, no_log_filter, annotations_file): +def serve(path, port, show, log_level, no_log_filter, annotations_file): """ Run the tsbrowse server. """ setup_logging(log_level, no_log_filter) - tsm = load_data(pathlib.Path(path)) + tsm = model.TSModel(path) if annotations_file: config.ANNOTATIONS_FILE = annotations_file @@ -158,5 +152,26 @@ def app(): pn.serve(app, port=port, show=show, verbose=False) +@cli.command() +@click.argument("tszip_path", type=click.Path(exists=True, dir_okay=False)) +@click.option( + "--output", + type=click.Path(dir_okay=False), + default=None, + help="Optional output filename, defaults to tszip_path with .tsbrowse extension", +) +def preprocess(tszip_path, output): + """ + Preprocess a tskit tree sequence or tszip file, producing a .tsbrowse file. + """ + tszip_path = Path(tszip_path) + if output is None: + output = tszip_path.with_suffix(".tsbrowse") + + preprocess_.preprocess(tszip_path, output, show_progress=True) + logger.info(f"Preprocessing completed. Output saved to: {output}") + print(f"Preprocessing completed. 
You can now view with `tsbrowse serve {output}`") + + if __name__ == "__main__": - main() + cli() diff --git a/tsbrowse/cache.py b/tsbrowse/cache.py deleted file mode 100644 index d85264d..0000000 --- a/tsbrowse/cache.py +++ /dev/null @@ -1,39 +0,0 @@ -import functools -import pathlib - -import appdirs -import daiquiri -import diskcache - -logger = daiquiri.getLogger("cache") - - -def get_cache_dir(): - cache_dir = pathlib.Path(appdirs.user_cache_dir("tsbrowse", "tsbrowse")) - cache_dir.mkdir(exist_ok=True, parents=True) - return cache_dir - - -cache = diskcache.Cache(get_cache_dir()) - - -def disk_cache(version): - def decorator(func): - @functools.wraps(func) - def wrapper(self, *args, **kwargs): - uuid = self.file_uuid - if uuid is None: - logger.info(f"No uuid, not caching {func.__name__}") - return func(self, *args, **kwargs) - key = f"{self.file_uuid}-{func.__name__}-{version}" - if key in cache: - logger.info(f"Fetching {key} from cache") - return cache[key] - logger.info(f"Calculating {key} and caching") - result = func(self, *args, **kwargs) - cache[key] = result - return result - - return wrapper - - return decorator diff --git a/tsbrowse/jit.py b/tsbrowse/jit.py index 7ec9c6e..391359a 100644 --- a/tsbrowse/jit.py +++ b/tsbrowse/jit.py @@ -31,8 +31,8 @@ } -def numba_njit(**numba_kwargs): - def _numba_njit(func): +def numba_jit(**numba_kwargs): + def _numba_jit(func): @functools.wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) # pragma: no cover @@ -43,7 +43,7 @@ def wrapper(*args, **kwargs): else: return func - return _numba_njit + return _numba_jit def numba_jitclass(spec): diff --git a/tsbrowse/model.py b/tsbrowse/model.py index eb273a4..e81ab21 100644 --- a/tsbrowse/model.py +++ b/tsbrowse/model.py @@ -1,354 +1,16 @@ -import dataclasses -import json +import pathlib from functools import cached_property import daiquiri -import numba import numpy as np import pandas as pd -import tskit +import tszip +import zarr -from . import jit -from .cache import disk_cache +from . 
import TSBROWSE_DATA_VERSION logger = daiquiri.getLogger("tsbrowse") -spec = [ - ("num_edges", numba.int64), - ("sequence_length", numba.float64), - ("edges_left", numba.float64[:]), - ("edges_right", numba.float64[:]), - ("edge_insertion_order", numba.int32[:]), - ("edge_removal_order", numba.int32[:]), - ("edge_insertion_index", numba.int64), - ("edge_removal_index", numba.int64), - ("interval", numba.float64[:]), - ("in_range", numba.int64[:]), - ("out_range", numba.int64[:]), -] - - -@jit.numba_jitclass(spec) -class TreePosition: - def __init__( - self, - num_edges, - sequence_length, - edges_left, - edges_right, - edge_insertion_order, - edge_removal_order, - ): - self.num_edges = num_edges - self.sequence_length = sequence_length - self.edges_left = edges_left - self.edges_right = edges_right - self.edge_insertion_order = edge_insertion_order - self.edge_removal_order = edge_removal_order - self.edge_insertion_index = 0 - self.edge_removal_index = 0 - self.interval = np.zeros(2) - self.in_range = np.zeros(2, dtype=np.int64) - self.out_range = np.zeros(2, dtype=np.int64) - - def next(self): # noqa - left = self.interval[1] - j = self.in_range[1] - k = self.out_range[1] - self.in_range[0] = j - self.out_range[0] = k - M = self.num_edges - edges_left = self.edges_left - edges_right = self.edges_right - out_order = self.edge_removal_order - in_order = self.edge_insertion_order - - while k < M and edges_right[out_order[k]] == left: - k += 1 - while j < M and edges_left[in_order[j]] == left: - j += 1 - self.out_range[1] = k - self.in_range[1] = j - - right = self.sequence_length - if j < M: - right = min(right, edges_left[in_order[j]]) - if k < M: - right = min(right, edges_right[out_order[k]]) - self.interval[:] = [left, right] - return j < M or left < self.sequence_length - - -# Helper function to make it easier to communicate with the numba class -def alloc_tree_position(ts): - return TreePosition( - num_edges=ts.num_edges, - sequence_length=ts.sequence_length, - edges_left=ts.edges_left, - edges_right=ts.edges_right, - edge_insertion_order=ts.indexes_edge_insertion_order, - edge_removal_order=ts.indexes_edge_removal_order, - ) - - -@jit.numba_njit() -def _compute_per_tree_stats( - tree_pos, num_trees, num_nodes, nodes_time, edges_parent, edges_child -): - tbl = np.zeros(num_trees) - num_internal_nodes = np.zeros(num_trees) - max_arity = np.zeros(num_trees, dtype=np.int32) - num_children = np.zeros(num_nodes, dtype=np.int32) - nodes_with_arity = np.zeros(num_nodes, dtype=np.int32) - - current_tbl = 0 - tree_index = 0 - current_num_internal_nodes = 0 - current_max_arity = 0 - while tree_pos.next(): - for j in range(tree_pos.out_range[0], tree_pos.out_range[1]): - e = tree_pos.edge_removal_order[j] - p = edges_parent[e] - nodes_with_arity[num_children[p]] -= 1 - if ( - num_children[p] == current_max_arity - and nodes_with_arity[num_children[p]] == 1 - ): - current_max_arity -= 1 - - num_children[p] -= 1 - if num_children[p] == 0: - current_num_internal_nodes -= 1 - else: - nodes_with_arity[num_children[p]] += 1 - c = edges_child[e] - branch_length = nodes_time[p] - nodes_time[c] - current_tbl -= branch_length - - for j in range(tree_pos.in_range[0], tree_pos.in_range[1]): - e = tree_pos.edge_insertion_order[j] - p = edges_parent[e] - if num_children[p] == 0: - current_num_internal_nodes += 1 - else: - nodes_with_arity[num_children[p]] -= 1 - num_children[p] += 1 - nodes_with_arity[num_children[p]] += 1 - if num_children[p] > current_max_arity: - current_max_arity = num_children[p] - c = 
edges_child[e] - branch_length = nodes_time[p] - nodes_time[c] - current_tbl += branch_length - tbl[tree_index] = current_tbl - num_internal_nodes[tree_index] = current_num_internal_nodes - max_arity[tree_index] = current_max_arity - tree_index += 1 - # print("tree", tree_index, nodes_with_arity) - - return tbl, num_internal_nodes, max_arity - - -def compute_per_tree_stats(ts): - """ - Returns the per-tree statistics - """ - tree_pos = alloc_tree_position(ts) - return _compute_per_tree_stats( - tree_pos, - ts.num_trees, - ts.num_nodes, - ts.nodes_time, - ts.edges_parent, - ts.edges_child, - ) - - -@jit.numba_njit() -def _compute_mutation_parent_counts(mutations_parent): - N = mutations_parent.shape[0] - num_parents = np.zeros(N, dtype=np.int32) - - for j in range(N): - u = j - while mutations_parent[u] != -1: - num_parents[j] += 1 - u = mutations_parent[u] - return num_parents - - -@jit.numba_njit() -def _compute_mutation_inheritance_counts( - tree_pos, - num_nodes, - num_mutations, - edges_parent, - edges_child, - samples, - mutations_position, - mutations_node, - mutations_parent, -): - parent = np.zeros(num_nodes, dtype=np.int32) - 1 - num_samples = np.zeros(num_nodes, dtype=np.int32) - num_samples[samples] = 1 - mutations_num_descendants = np.zeros(num_mutations, dtype=np.int32) - mutations_num_inheritors = np.zeros(num_mutations, dtype=np.int32) - - mut_id = 0 - - while tree_pos.next(): - for j in range(tree_pos.out_range[0], tree_pos.out_range[1]): - e = tree_pos.edge_removal_order[j] - c = edges_child[e] - p = edges_parent[e] - parent[c] = -1 - u = p - while u != -1: - num_samples[u] -= num_samples[c] - u = parent[u] - - for j in range(tree_pos.in_range[0], tree_pos.in_range[1]): - e = tree_pos.edge_insertion_order[j] - p = edges_parent[e] - c = edges_child[e] - parent[c] = p - u = p - while u != -1: - num_samples[u] += num_samples[c] - u = parent[u] - left, right = tree_pos.interval - while mut_id < num_mutations and mutations_position[mut_id] < right: - assert mutations_position[mut_id] >= left - mutation_node = mutations_node[mut_id] - descendants = num_samples[mutation_node] - mutations_num_descendants[mut_id] = descendants - mutations_num_inheritors[mut_id] = descendants - # Subtract this number of descendants from the parent mutation. 
We are - # guaranteed to list parents mutations before their children - mut_parent = mutations_parent[mut_id] - if mut_parent != -1: - mutations_num_inheritors[mut_parent] -= descendants - mut_id += 1 - - return mutations_num_descendants, mutations_num_inheritors - - -@dataclasses.dataclass -class MutationCounts: - num_parents: np.ndarray - num_inheritors: np.ndarray - num_descendants: np.ndarray - - -def compute_mutation_counts(ts): - logger.info("Computing mutation inheritance counts") - tree_pos = alloc_tree_position(ts) - mutations_position = ts.sites_position[ts.mutations_site].astype(int) - num_descendants, num_inheritors = _compute_mutation_inheritance_counts( - tree_pos, - ts.num_nodes, - ts.num_mutations, - ts.edges_parent, - ts.edges_child, - ts.samples(), - mutations_position, - ts.mutations_node, - ts.mutations_parent, - ) - num_parents = _compute_mutation_parent_counts(ts.mutations_parent) - return MutationCounts(num_parents, num_inheritors, num_descendants) - - -@jit.numba_njit() -def _compute_population_mutation_counts( - tree_pos, - num_nodes, - num_mutations, - num_populations, - edges_parent, - edges_child, - nodes_is_sample, - nodes_population, - mutations_position, - mutations_node, - mutations_parent, -): - num_pop_samples = np.zeros((num_nodes, num_populations), dtype=np.int32) - - pop_mutation_count = np.zeros((num_populations, num_mutations), dtype=np.int32) - parent = np.zeros(num_nodes, dtype=np.int32) - 1 - - for u in range(num_nodes): - if nodes_is_sample[u]: - num_pop_samples[u, nodes_population[u]] = 1 - - mut_id = 0 - while tree_pos.next(): - for j in range(tree_pos.out_range[0], tree_pos.out_range[1]): - e = tree_pos.edge_removal_order[j] - c = edges_child[e] - p = edges_parent[e] - parent[c] = -1 - u = p - while u != -1: - for k in range(num_populations): - num_pop_samples[u, k] -= num_pop_samples[c, k] - u = parent[u] - - for j in range(tree_pos.in_range[0], tree_pos.in_range[1]): - e = tree_pos.edge_insertion_order[j] - p = edges_parent[e] - c = edges_child[e] - parent[c] = p - u = p - while u != -1: - for k in range(num_populations): - num_pop_samples[u, k] += num_pop_samples[c, k] - u = parent[u] - - left, right = tree_pos.interval - while mut_id < num_mutations and mutations_position[mut_id] < right: - assert mutations_position[mut_id] >= left - mutation_node = mutations_node[mut_id] - for pop in range(num_populations): - pop_mutation_count[pop, mut_id] = num_pop_samples[mutation_node, pop] - mut_id += 1 - - return pop_mutation_count - - -def node_is_sample(ts): - sample_flag = np.full_like(ts.nodes_flags, tskit.NODE_IS_SAMPLE) - return np.bitwise_and(ts.nodes_flags, sample_flag) != 0 - - -def compute_population_mutation_counts(ts): - """ - Return a (num_populations, num_mutations) array that gives the frequency - of each mutation in each of the populations in the specified tree sequence. 
- """ - logger.info( - f"Computing mutation frequencies within {ts.num_populations} populations" - ) - mutations_position = ts.sites_position[ts.mutations_site].astype(int) - - if np.any(ts.nodes_population[ts.samples()] == -1): - raise ValueError("Sample nodes must be assigned to populations") - - return _compute_population_mutation_counts( - alloc_tree_position(ts), - ts.num_nodes, - ts.num_mutations, - ts.num_populations, - ts.edges_parent, - ts.edges_child, - node_is_sample(ts), - ts.nodes_population, - mutations_position, - ts.mutations_node, - ts.mutations_parent, - ) - class TSModel: """ @@ -356,37 +18,64 @@ class TSModel: convenience methods for analysing the tree sequence. """ - def __init__(self, ts, name=None): - self.ts = ts - self.name = name - - self.sites_num_mutations = np.bincount( - self.ts.mutations_site, minlength=self.ts.num_sites - ) - self.nodes_num_mutations = np.bincount( - self.ts.mutations_node, minlength=self.ts.num_nodes - ) + def __init__(self, tsbrowse_path): + tsbrowse_path = pathlib.Path(tsbrowse_path) + root = zarr.open(zarr.ZipStore(tsbrowse_path, mode="r")) + if "tsbrowse" not in root.attrs or "data_version" not in root.attrs["tsbrowse"]: + raise ValueError("File is not a tsbrowse file, run tsbrowse preprocess") + if root.attrs["tsbrowse"]["data_version"] != TSBROWSE_DATA_VERSION: + raise ValueError( + f"File {tsbrowse_path} has version " + f"{root.attrs['tsbrowse']['data_version']}, " + f"but this version of tsbrowse expects version " + f"{TSBROWSE_DATA_VERSION} rerun tsbrowse preprocess" + ) + self.ts = tszip.load(tsbrowse_path) + self.name = tsbrowse_path.stem + for table_name in ["edges", "trees", "mutations", "nodes", "sites"]: + # filter out ragged arrays with offset + array_names = set(root[table_name].keys()) + ragged_array_names = { + "_".join(name.split("_")[:-1]) + for name in array_names + if "offset" in name + } + array_names -= set(ragged_array_names) + array_names -= {"metadata_schema"} + array_names -= {f"{name}_offset" for name in ragged_array_names} + arrays = {name: root[table_name][name][:] for name in array_names} + ragged_array_names -= {"metadata"} + # Not needed for now + # for name in ragged_array_names: + # array = root[table_name][name][:] + # offsets = root[table_name][f"{name}_offset"][:] + # arrays[name] = np.array( + # [ + # array[s].tobytes().decode("utf-8") + # for s in ( + # slice(start, end) for start, end in zip(offsets[:-1], offsets[1:]) + # ) + # ] + # ) + df = pd.DataFrame(arrays) + setattr(self, f"{table_name}_df", df) @property def file_uuid(self): return self.ts.file_uuid @cached_property - @disk_cache("v1") def summary_df(self): - nodes_with_zero_muts = np.sum(self.nodes_num_mutations == 0) - sites_with_zero_muts = np.sum(self.sites_num_mutations == 0) - data = [ ("samples", self.ts.num_samples), ("nodes", self.ts.num_nodes), ("mutations", self.ts.num_mutations), - ("nodes_with_zero_muts", nodes_with_zero_muts), - ("sites_with_zero_muts", sites_with_zero_muts), - ("max_mutations_per_site", np.max(self.sites_num_mutations)), - ("mean_mutations_per_site", np.mean(self.sites_num_mutations)), - ("median_mutations_per_site", np.median(self.sites_num_mutations)), - ("max_mutations_per_node", np.max(self.nodes_num_mutations)), + ("nodes_with_zero_muts", np.sum(self.nodes_df.num_mutations == 0)), + ("sites_with_zero_muts", np.sum(self.sites_df.num_mutations == 0)), + ("max_mutations_per_site", np.max(self.sites_df.num_mutations)), + ("mean_mutations_per_site", np.mean(self.sites_df.num_mutations)), + 
("median_mutations_per_site", np.median(self.sites_df.num_mutations)), + ("max_mutations_per_node", np.max(self.nodes_df.num_mutations)), ] df = pd.DataFrame( {"property": [d[0] for d in data], "value": [d[1] for d in data]} @@ -397,215 +86,6 @@ def summary_df(self): def _repr_html_(self): return self.summary_df._repr_html_() - @staticmethod - @jit.numba_njit() - def child_bounds(num_nodes, edges_left, edges_right, edges_child): - num_edges = edges_left.shape[0] - child_left = np.zeros(num_nodes, dtype=np.float64) + np.inf - child_right = np.zeros(num_nodes, dtype=np.float64) - - for e in range(num_edges): - u = edges_child[e] - if edges_left[e] < child_left[u]: - child_left[u] = edges_left[e] - if edges_right[e] > child_right[u]: - child_right[u] = edges_right[e] - return child_left, child_right - - @cached_property - @disk_cache("v2") - def mutations_df(self): - # FIXME use tskit's impute mutations time - ts = self.ts - mutations_time = ts.mutations_time.copy() - mutations_node = ts.mutations_node.copy() - unknown = tskit.is_unknown_time(mutations_time) - mutations_time[unknown] = self.ts.nodes_time[mutations_node[unknown]] - - position = ts.sites_position[ts.mutations_site] - - tables = self.ts.tables - derived_state = tables.mutations.derived_state - offsets = tables.mutations.derived_state_offset - derived_state = np.array( - [ - derived_state[s].tobytes().decode("utf-8") - for s in ( - slice(start, end) for start, end in zip(offsets[:-1], offsets[1:]) - ) - ] - ) - ancestral_state = tables.sites.ancestral_state - offsets = tables.sites.ancestral_state_offset - ancestral_state = np.array( - [ - ancestral_state[s].tobytes().decode("utf-8") - for s in ( - slice(start, end) for start, end in zip(offsets[:-1], offsets[1:]) - ) - ] - ) - del tables - inherited_state = ancestral_state[ts.mutations_site] - mutations_with_parent = ts.mutations_parent != -1 - - parent = ts.mutations_parent[mutations_with_parent] - assert np.all(parent >= 0) - inherited_state[mutations_with_parent] = derived_state[parent] - self.mutations_derived_state = derived_state - self.mutations_inherited_state = inherited_state - - population_data = {} - if ts.num_populations > 0: - pop_mutation_count = compute_population_mutation_counts(ts) - for pop in ts.populations(): - name = f"pop{pop.id}" - if isinstance(pop.metadata, bytes): - metadata_dict = json.loads(pop.metadata.decode("utf-8")) - else: - metadata_dict = pop.metadata - if "name" in metadata_dict: - name = metadata_dict["name"] - col_name = f"pop_{name}_freq" - population_data[col_name] = pop_mutation_count[pop.id] / ts.num_samples - - counts = compute_mutation_counts(ts) - df = pd.DataFrame( - { - "id": np.arange(ts.num_mutations), - "position": position, - "node": ts.mutations_node, - "time": mutations_time, - "derived_state": self.mutations_derived_state, - "inherited_state": self.mutations_inherited_state, - "num_descendants": counts.num_descendants, - "num_inheritors": counts.num_inheritors, - "num_parents": counts.num_parents, - **population_data, - } - ) - logger.info("Computed mutations dataframe") - return df.astype( - { - "id": "int", - "position": "float64", - "node": "int", - "time": "float64", - "derived_state": "str", - "inherited_state": "str", - "num_descendants": "int", - "num_inheritors": "int", - "num_parents": "int", - } - ) - - @cached_property - @disk_cache("v1") - def edges_df(self): - ts = self.ts - left = ts.edges_left - right = ts.edges_right - edges_parent = ts.edges_parent - edges_child = ts.edges_child - nodes_time = ts.nodes_time 
- parent_time = nodes_time[edges_parent] - child_time = nodes_time[edges_child] - branch_length = parent_time - child_time - span = right - left - - df = pd.DataFrame( - { - "left": left, - "right": right, - "parent": edges_parent, - "child": edges_child, - "parent_time": parent_time, - "child_time": child_time, - "branch_length": branch_length, - "span": span, - } - ) - - logger.info("Computed edges dataframe") - return df.astype( - { - "left": "float64", - "right": "float64", - "parent": "int", - "child": "int", - "parent_time": "float64", - "child_time": "float64", - "branch_length": "float64", - "span": "float64", - } - ) - - @cached_property - @disk_cache("v2") - def nodes_df(self): - ts = self.ts - child_left, child_right = self.child_bounds( - ts.num_nodes, ts.edges_left, ts.edges_right, ts.edges_child - ) - df = pd.DataFrame( - { - "time": ts.nodes_time, - "num_mutations": self.nodes_num_mutations, - "ancestors_span": child_right - child_left, - "node_flags": ts.nodes_flags, - } - ) - logger.info("Computed nodes dataframe") - return df.astype( - { - "time": "float64", - "num_mutations": "int", - "ancestors_span": "float64", - "node_flags": "int", - } - ) - - @cached_property - @disk_cache("v1") - def trees_df(self): - ts = self.ts - num_trees = ts.num_trees - - total_branch_length, num_internal_nodes, max_arity = compute_per_tree_stats(ts) - - # FIXME - need to add this to the computation above - mean_internal_arity = np.zeros(num_trees) - - site_tree_index = self.calc_site_tree_index() - unique_values, counts = np.unique(site_tree_index, return_counts=True) - sites_per_tree = np.zeros(ts.num_trees, dtype=np.int64) - sites_per_tree[unique_values] = counts - breakpoints = ts.breakpoints(as_array=True) - df = pd.DataFrame( - { - "left": breakpoints[:-1], - "right": breakpoints[1:], - "total_branch_length": total_branch_length, - "mean_internal_arity": mean_internal_arity, - "max_internal_arity": max_arity, - "num_sites": sites_per_tree, - "num_mutations": self.calc_mutations_per_tree(), - } - ) - - logger.info("Computed trees dataframe") - return df.astype( - { - "left": "int", - "right": "int", - "total_branch_length": "float64", - "mean_internal_arity": "float64", - "max_internal_arity": "float64", - "num_sites": "int", - "num_mutations": "int", - } - ) - def genes_df(self, genes_file): genes_df = pd.read_csv(genes_file, sep=";") # TODO file checks! 
@@ -675,26 +155,3 @@ def calc_mean_node_arity(self): mode="node", )[:, 0] return span_sums / node_spans - - def calc_site_tree_index(self): - return ( - np.searchsorted( - self.ts.breakpoints(as_array=True), self.ts.sites_position, side="right" - ) - - 1 - ) - - def calc_sites_per_tree(self): - site_tree_index = self.calc_site_tree_index() - unique_values, counts = np.unique(site_tree_index, return_counts=True) - sites_per_tree = np.zeros(self.ts.num_trees, dtype=np.int64) - sites_per_tree[unique_values] = counts - return sites_per_tree - - def calc_mutations_per_tree(self): - site_tree_index = self.calc_site_tree_index() - mutation_tree_index = site_tree_index[self.ts.mutations_site] - unique_values, counts = np.unique(mutation_tree_index, return_counts=True) - mutations_per_tree = np.zeros(self.ts.num_trees, dtype=np.int64) - mutations_per_tree[unique_values] = counts - return mutations_per_tree diff --git a/tsbrowse/pages/frequency_spectra.py b/tsbrowse/pages/frequency_spectra.py index 15be79a..9eb1244 100644 --- a/tsbrowse/pages/frequency_spectra.py +++ b/tsbrowse/pages/frequency_spectra.py @@ -33,7 +33,7 @@ def make_afs_panel(afs_df, log_bins, mode): bin_edges = np.linspace(1, len_df, num_bins).astype(int) xrotation = 45 - labels = [f"{bin_edges[i]}-{bin_edges[i+1]}" for i in range(len(bin_edges) - 1)] + labels = [f"{bin_edges[i]} - {bin_edges[i + 1]}" for i in range(len(bin_edges) - 1)] afs_df["bins"] = pd.cut( afs_df["allele_count"], bins=bin_edges, diff --git a/tsbrowse/pages/nodes.py b/tsbrowse/pages/nodes.py index 7be76cb..853f1b7 100644 --- a/tsbrowse/pages/nodes.py +++ b/tsbrowse/pages/nodes.py @@ -33,7 +33,7 @@ def make_node_hist_panel(tsm, log_y): hist_panel = pn.bind(make_node_hist_panel, log_y=log_y_checkbox, tsm=tsm) def make_node_plot(data, node_types): - df = data[data.node_flags.isin(node_types)] + df = data[data["flags"].isin(node_types)] points = df.hvplot.scatter( x="ancestors_span", y="time", @@ -64,7 +64,7 @@ def make_node_panel(node_types): ) return pn.Row(nodes_spans_plot) - anc_options = list(df_nodes.node_flags.unique()) + anc_options = list(df_nodes["flags"].unique()) checkboxes = pn.widgets.CheckBoxGroup( name="Node Types", value=anc_options, options=anc_options ) diff --git a/tsbrowse/preprocess.py b/tsbrowse/preprocess.py new file mode 100644 index 0000000..fd7672b --- /dev/null +++ b/tsbrowse/preprocess.py @@ -0,0 +1,540 @@ +import dataclasses +import json +import pathlib + +import daiquiri +import numba +import numpy as np +import tskit +import tszip +import zarr +from tqdm import tqdm + +from . 
import jit +from tsbrowse import TSBROWSE_DATA_VERSION + +logger = daiquiri.getLogger("tsbrowse") + + +def node_is_sample(ts): + sample_flag = np.full_like(ts.nodes_flags, tskit.NODE_IS_SAMPLE) + return np.bitwise_and(ts.nodes_flags, sample_flag) != 0 + + +spec = [ + ("num_edges", numba.int64), + ("sequence_length", numba.float64), + ("edges_left", numba.float64[:]), + ("edges_right", numba.float64[:]), + ("edge_insertion_order", numba.int32[:]), + ("edge_removal_order", numba.int32[:]), + ("edge_insertion_index", numba.int64), + ("edge_removal_index", numba.int64), + ("interval", numba.float64[:]), + ("in_range", numba.int64[:]), + ("out_range", numba.int64[:]), +] + + +@jit.numba_jitclass(spec) +class TreePosition: + def __init__( + self, + num_edges, + sequence_length, + edges_left, + edges_right, + edge_insertion_order, + edge_removal_order, + ): + self.num_edges = num_edges + self.sequence_length = sequence_length + self.edges_left = edges_left + self.edges_right = edges_right + self.edge_insertion_order = edge_insertion_order + self.edge_removal_order = edge_removal_order + self.edge_insertion_index = 0 + self.edge_removal_index = 0 + self.interval = np.zeros(2) + self.in_range = np.zeros(2, dtype=np.int64) + self.out_range = np.zeros(2, dtype=np.int64) + + def next(self): # noqa + left = self.interval[1] + j = self.in_range[1] + k = self.out_range[1] + self.in_range[0] = j + self.out_range[0] = k + M = self.num_edges + edges_left = self.edges_left + edges_right = self.edges_right + out_order = self.edge_removal_order + in_order = self.edge_insertion_order + + while k < M and edges_right[out_order[k]] == left: + k += 1 + while j < M and edges_left[in_order[j]] == left: + j += 1 + self.out_range[1] = k + self.in_range[1] = j + + right = self.sequence_length + if j < M: + right = min(right, edges_left[in_order[j]]) + if k < M: + right = min(right, edges_right[out_order[k]]) + self.interval[:] = [left, right] + return j < M or left < self.sequence_length + + +# Helper function to make it easier to communicate with the numba class +def alloc_tree_position(ts): + return TreePosition( + num_edges=ts.num_edges, + sequence_length=ts.sequence_length, + edges_left=ts.edges_left, + edges_right=ts.edges_right, + edge_insertion_order=ts.indexes_edge_insertion_order, + edge_removal_order=ts.indexes_edge_removal_order, + ) + + +@jit.numba_jit() +def _compute_population_mutation_counts( + tree_pos, + num_nodes, + num_mutations, + num_populations, + edges_parent, + edges_child, + nodes_is_sample, + nodes_population, + mutations_position, + mutations_node, + mutations_parent, +): + num_pop_samples = np.zeros((num_nodes, num_populations), dtype=np.int32) + + pop_mutation_count = np.zeros((num_populations, num_mutations), dtype=np.int32) + parent = np.zeros(num_nodes, dtype=np.int32) - 1 + + for u in range(num_nodes): + if nodes_is_sample[u]: + num_pop_samples[u, nodes_population[u]] = 1 + + mut_id = 0 + while tree_pos.next(): + for j in range(tree_pos.out_range[0], tree_pos.out_range[1]): + e = tree_pos.edge_removal_order[j] + c = edges_child[e] + p = edges_parent[e] + parent[c] = -1 + u = p + while u != -1: + for k in range(num_populations): + num_pop_samples[u, k] -= num_pop_samples[c, k] + u = parent[u] + + for j in range(tree_pos.in_range[0], tree_pos.in_range[1]): + e = tree_pos.edge_insertion_order[j] + p = edges_parent[e] + c = edges_child[e] + parent[c] = p + u = p + while u != -1: + for k in range(num_populations): + num_pop_samples[u, k] += num_pop_samples[c, k] + u = parent[u] + + left, 
right = tree_pos.interval + while mut_id < num_mutations and mutations_position[mut_id] < right: + assert mutations_position[mut_id] >= left + mutation_node = mutations_node[mut_id] + for pop in range(num_populations): + pop_mutation_count[pop, mut_id] = num_pop_samples[mutation_node, pop] + mut_id += 1 + + return pop_mutation_count + + +def compute_population_mutation_counts(ts): + """ + Return a (num_populations, num_mutations) array that gives the frequency + of each mutation in each of the populations in the specified tree sequence. + """ + logger.info( + f"Computing mutation frequencies within {ts.num_populations} populations" + ) + mutations_position = ts.sites_position[ts.mutations_site].astype(int) + + if np.any(ts.nodes_population[ts.samples()] == -1): + raise ValueError("Sample nodes must be assigned to populations") + + return _compute_population_mutation_counts( + alloc_tree_position(ts), + ts.num_nodes, + ts.num_mutations, + ts.num_populations, + ts.edges_parent, + ts.edges_child, + node_is_sample(ts), + ts.nodes_population, + mutations_position, + ts.mutations_node, + ts.mutations_parent, + ) + + +@dataclasses.dataclass +class MutationCounts: + num_parents: np.ndarray + num_inheritors: np.ndarray + num_descendants: np.ndarray + + +def compute_mutation_counts(ts): + logger.info("Computing mutation inheritance counts") + tree_pos = alloc_tree_position(ts) + mutations_position = ts.sites_position[ts.mutations_site].astype(int) + num_descendants, num_inheritors = _compute_mutation_inheritance_counts( + tree_pos, + ts.num_nodes, + ts.num_mutations, + ts.edges_parent, + ts.edges_child, + ts.samples(), + mutations_position, + ts.mutations_node, + ts.mutations_parent, + ) + num_parents = _compute_mutation_parent_counts(ts.mutations_parent) + return MutationCounts(num_parents, num_inheritors, num_descendants) + + +@jit.numba_jit() +def _compute_mutation_parent_counts(mutations_parent): + N = mutations_parent.shape[0] + num_parents = np.zeros(N, dtype=np.int32) + + for j in range(N): + u = j + while mutations_parent[u] != -1: + num_parents[j] += 1 + u = mutations_parent[u] + return num_parents + + +@jit.numba_jit() +def _compute_mutation_inheritance_counts( + tree_pos, + num_nodes, + num_mutations, + edges_parent, + edges_child, + samples, + mutations_position, + mutations_node, + mutations_parent, +): + parent = np.zeros(num_nodes, dtype=np.int32) - 1 + num_samples = np.zeros(num_nodes, dtype=np.int32) + num_samples[samples] = 1 + mutations_num_descendants = np.zeros(num_mutations, dtype=np.int32) + mutations_num_inheritors = np.zeros(num_mutations, dtype=np.int32) + + mut_id = 0 + + while tree_pos.next(): + for j in range(tree_pos.out_range[0], tree_pos.out_range[1]): + e = tree_pos.edge_removal_order[j] + c = edges_child[e] + p = edges_parent[e] + parent[c] = -1 + u = p + while u != -1: + num_samples[u] -= num_samples[c] + u = parent[u] + + for j in range(tree_pos.in_range[0], tree_pos.in_range[1]): + e = tree_pos.edge_insertion_order[j] + p = edges_parent[e] + c = edges_child[e] + parent[c] = p + u = p + while u != -1: + num_samples[u] += num_samples[c] + u = parent[u] + left, right = tree_pos.interval + while mut_id < num_mutations and mutations_position[mut_id] < right: + assert mutations_position[mut_id] >= left + mutation_node = mutations_node[mut_id] + descendants = num_samples[mutation_node] + mutations_num_descendants[mut_id] = descendants + mutations_num_inheritors[mut_id] = descendants + # Subtract this number of descendants from the parent mutation. 
We are
+            # guaranteed to list parent mutations before their children
+            mut_parent = mutations_parent[mut_id]
+            if mut_parent != -1:
+                mutations_num_inheritors[mut_parent] -= descendants
+            mut_id += 1
+
+    return mutations_num_descendants, mutations_num_inheritors
+
+
+def mutations(ts):
+    # FIXME use tskit's mutation time imputation
+    mutations_time = ts.mutations_time.copy()
+    mutations_node = ts.mutations_node.copy()
+    unknown = tskit.is_unknown_time(mutations_time)
+    mutations_time[unknown] = ts.nodes_time[mutations_node[unknown]]
+
+    position = ts.sites_position[ts.mutations_site]
+
+    tables = ts.tables
+    derived_state = tables.mutations.derived_state
+    offsets = tables.mutations.derived_state_offset
+    derived_state = np.array(
+        [
+            derived_state[s].tobytes().decode("utf-8")
+            for s in (
+                slice(start, end) for start, end in zip(offsets[:-1], offsets[1:])
+            )
+        ]
+    )
+    ancestral_state = tables.sites.ancestral_state
+    offsets = tables.sites.ancestral_state_offset
+    ancestral_state = np.array(
+        [
+            ancestral_state[s].tobytes().decode("utf-8")
+            for s in (
+                slice(start, end) for start, end in zip(offsets[:-1], offsets[1:])
+            )
+        ]
+    )
+    del tables
+    inherited_state = ancestral_state[ts.mutations_site]
+
+    mutations_with_parent = ts.mutations_parent != -1
+
+    parent = ts.mutations_parent[mutations_with_parent]
+    assert np.all(parent >= 0)
+    inherited_state[mutations_with_parent] = derived_state[parent]
+    mutations_inherited_state = inherited_state
+
+    population_data = {}
+    if ts.num_populations > 0:
+        pop_mutation_count = compute_population_mutation_counts(ts)
+        for pop in ts.populations():
+            name = f"pop{pop.id}"
+            if isinstance(pop.metadata, bytes):
+                try:
+                    metadata_dict = json.loads(pop.metadata.decode("utf-8"))
+                except (json.JSONDecodeError, UnicodeDecodeError):
+                    metadata_dict = {}
+            else:
+                metadata_dict = pop.metadata
+            if "name" in metadata_dict:
+                name = metadata_dict["name"]
+            col_name = f"pop_{name}_freq"
+            population_data[col_name] = pop_mutation_count[pop.id] / ts.num_samples
+
+    counts = compute_mutation_counts(ts)
+    logger.info("Preprocessed mutations")
+    return {
+        "position": position,
+        "inherited_state": mutations_inherited_state,
+        "num_descendants": counts.num_descendants,
+        "num_inheritors": counts.num_inheritors,
+        "num_parents": counts.num_parents,
+        **population_data,
+    }
+
+
+def edges(ts):
+    parent_time = ts.nodes_time[ts.edges_parent]
+    child_time = ts.nodes_time[ts.edges_child]
+    branch_length = parent_time - child_time
+    span = ts.edges_right - ts.edges_left
+
+    logger.info("Preprocessed edges")
+    return {
+        "parent_time": parent_time,
+        "child_time": child_time,
+        "branch_length": branch_length,
+        "span": span,
+    }
+
+
+@jit.numba_jit()
+def child_bounds(num_nodes, edges_left, edges_right, edges_child):
+    num_edges = edges_left.shape[0]
+    child_left = np.zeros(num_nodes, dtype=np.float64) + np.inf
+    child_right = np.zeros(num_nodes, dtype=np.float64)
+
+    for e in range(num_edges):
+        u = edges_child[e]
+        if edges_left[e] < child_left[u]:
+            child_left[u] = edges_left[e]
+        if edges_right[e] > child_right[u]:
+            child_right[u] = edges_right[e]
+    return child_left, child_right
+
+
+def nodes(ts):
+    child_left, child_right = child_bounds(
+        ts.num_nodes, ts.edges_left, ts.edges_right, ts.edges_child
+    )
+    nodes_num_mutations = np.bincount(ts.mutations_node, minlength=ts.num_nodes)
+    logger.info("Preprocessed nodes")
+    return {
+        "num_mutations": nodes_num_mutations,
+        "ancestors_span": child_right - child_left,
+    }
+
+
+def sites(ts):
+    sites_num_mutations = np.bincount(ts.mutations_site, minlength=ts.num_sites)
+    logger.info("Preprocessed sites")
+    return {
+        "num_mutations": sites_num_mutations,
+    }
+
+
+@jit.numba_jit()
+def _compute_per_tree_stats(
+    tree_pos, num_trees, num_nodes, nodes_time, edges_parent, edges_child
+):
+    tbl = np.zeros(num_trees)
+    num_internal_nodes = np.zeros(num_trees)
+    max_arity = np.zeros(num_trees, dtype=np.int32)
+    num_children = np.zeros(num_nodes, dtype=np.int32)
+    nodes_with_arity = np.zeros(num_nodes, dtype=np.int32)
+
+    current_tbl = 0
+    tree_index = 0
+    current_num_internal_nodes = 0
+    current_max_arity = 0
+    while tree_pos.next():
+        for j in range(tree_pos.out_range[0], tree_pos.out_range[1]):
+            e = tree_pos.edge_removal_order[j]
+            p = edges_parent[e]
+            nodes_with_arity[num_children[p]] -= 1
+            if (
+                num_children[p] == current_max_arity
+                and nodes_with_arity[num_children[p]] == 1
+            ):
+                current_max_arity -= 1
+
+            num_children[p] -= 1
+            if num_children[p] == 0:
+                current_num_internal_nodes -= 1
+            else:
+                nodes_with_arity[num_children[p]] += 1
+            c = edges_child[e]
+            branch_length = nodes_time[p] - nodes_time[c]
+            current_tbl -= branch_length
+
+        for j in range(tree_pos.in_range[0], tree_pos.in_range[1]):
+            e = tree_pos.edge_insertion_order[j]
+            p = edges_parent[e]
+            if num_children[p] == 0:
+                current_num_internal_nodes += 1
+            else:
+                nodes_with_arity[num_children[p]] -= 1
+            num_children[p] += 1
+            nodes_with_arity[num_children[p]] += 1
+            if num_children[p] > current_max_arity:
+                current_max_arity = num_children[p]
+            c = edges_child[e]
+            branch_length = nodes_time[p] - nodes_time[c]
+            current_tbl += branch_length
+        tbl[tree_index] = current_tbl
+        num_internal_nodes[tree_index] = current_num_internal_nodes
+        max_arity[tree_index] = current_max_arity
+        tree_index += 1
+        # print("tree", tree_index, nodes_with_arity)
+
+    return tbl, num_internal_nodes, max_arity
+
+
+def compute_per_tree_stats(ts):
+    """
+    Returns the per-tree statistics
+    """
+    tree_pos = alloc_tree_position(ts)
+    return _compute_per_tree_stats(
+        tree_pos,
+        ts.num_trees,
+        ts.num_nodes,
+        ts.nodes_time,
+        ts.edges_parent,
+        ts.edges_child,
+    )
+
+
+def trees(ts):
+    num_trees = ts.num_trees
+    total_branch_length, num_internal_nodes, max_arity = compute_per_tree_stats(ts)
+
+    # FIXME - need to add this to the computation above
+    mean_internal_arity = np.zeros(num_trees)
+
+    site_tree_index = (
+        np.searchsorted(ts.breakpoints(as_array=True), ts.sites_position, side="right")
+        - 1
+    )
+    unique_values, counts = np.unique(site_tree_index, return_counts=True)
+    sites_per_tree = np.zeros(ts.num_trees, dtype=np.int64)
+    sites_per_tree[unique_values] = counts
+
+    mutation_tree_index = site_tree_index[ts.mutations_site]
+    unique_mutation_values, mutation_counts = np.unique(
+        mutation_tree_index, return_counts=True
+    )
+    mutations_per_tree = np.zeros(ts.num_trees, dtype=np.int64)
+    mutations_per_tree[unique_mutation_values] = mutation_counts
+
+    breakpoints = ts.breakpoints(as_array=True)
+    logger.info("Preprocessed trees")
+    return {
+        "left": breakpoints[:-1],
+        "right": breakpoints[1:],
+        "total_branch_length": total_branch_length,
+        "mean_internal_arity": mean_internal_arity,
+        "max_internal_arity": max_arity,
+        "num_sites": sites_per_tree,
+        "num_mutations": mutations_per_tree,
+    }
+
+
+def preprocess(tszip_path, output_path, show_progress=False):
+    tszip_path = pathlib.Path(tszip_path)
+    preprocessors = [mutations, nodes, trees, edges, sites]
+    with tqdm(
+        total=2 + len(preprocessors), desc="Processing", disable=not show_progress
+    ) as pbar:
+ ts = tszip.load(tszip_path) + pbar.update(1) + + tszip.compress(ts, output_path) + pbar.update(1) + + # Preprocess the data first so we can error out before writing to the file + data = {} + for preprocessor in preprocessors: + group_name = preprocessor.__name__.split(".")[-1] + pbar.set_description(f"Processing {group_name}") + data[group_name] = preprocessor(ts) + pbar.update(1) + + with zarr.ZipStore(output_path, mode="a") as zarr_store: + root = zarr.group(store=zarr_store) + total_arrays = sum(len(arrays) for arrays in data.values()) + with tqdm( + total=total_arrays, desc="Writing", disable=not show_progress + ) as pbar: + for table_name, arrays in data.items(): + for array_name, array in arrays.items(): + root[f"{table_name}/{array_name}"] = array + pbar.update(1) + root.attrs["tsbrowse"] = {"data_version": TSBROWSE_DATA_VERSION} diff --git a/tsbrowse/raster.py b/tsbrowse/raster.py index 2a88a8b..c8a8f03 100644 --- a/tsbrowse/raster.py +++ b/tsbrowse/raster.py @@ -1,8 +1,10 @@ from tsbrowse import model -def raster_component(page, ts, png_filename, *, width=None, height=None, **kwargs): - tsm = model.TSModel(ts) +def raster_component( + page, tsbrowse, png_filename, *, width=None, height=None, **kwargs +): + tsm = model.TSModel(tsbrowse) p = page(tsm, **kwargs) if width is not None: p.width = width