diff --git a/model.py b/model.py index fa43ed3..bb81dfd 100644 --- a/model.py +++ b/model.py @@ -324,15 +324,26 @@ def mutations_df(self): position = ts.sites_position[ts.mutations_site] tables = self.ts.tables - assert np.all( - tables.mutations.derived_state_offset == np.arange(ts.num_mutations + 1) + derived_state = tables.mutations.derived_state + offsets = tables.mutations.derived_state_offset + derived_state = np.array( + [ + derived_state[s].tobytes().decode("utf-8") + for s in ( + slice(start, end) for start, end in zip(offsets[:-1], offsets[1:]) + ) + ] ) - derived_state = tables.mutations.derived_state.view("S1").astype(str) - - assert np.all( - tables.sites.ancestral_state_offset == np.arange(ts.num_sites + 1) + ancestral_state = tables.sites.ancestral_state + offsets = tables.sites.ancestral_state_offset + ancestral_state = np.array( + [ + ancestral_state[s].tobytes().decode("utf-8") + for s in ( + slice(start, end) for start, end in zip(offsets[:-1], offsets[1:]) + ) + ] ) - ancestral_state = tables.sites.ancestral_state.view("S1").astype(str) del tables inherited_state = ancestral_state[ts.mutations_site] mutations_with_parent = ts.mutations_parent != -1 diff --git a/tests/test_data_model.py b/tests/test_data_model.py index 7388e18..2501bbb 100644 --- a/tests/test_data_model.py +++ b/tests/test_data_model.py @@ -17,6 +17,8 @@ def single_tree_example_ts(): for j in range(6): tables.sites.add_row(position=j + 1, ancestral_state="A") tables.mutations.add_row(site=j, derived_state="T", node=j) + tables.sites.add_row(position=7, ancestral_state="FOOBAR") + tables.mutations.add_row(site=6, derived_state="FOOBARD", node=6) return tables.tree_sequence() @@ -102,16 +104,16 @@ def test_single_tree_example(self): ts = single_tree_example_ts() tsm = model.TSModel(ts) df = tsm.mutations_df - assert len(df) == 6 - nt.assert_array_equal(df.id, list(range(6))) - nt.assert_array_equal(df.node, list(range(6))) - nt.assert_array_equal(df.position, list(range(1, 7))) - nt.assert_array_equal(df.time, [0, 0, 0, 0, 1, 1]) - nt.assert_array_equal(df.derived_state, ["T"] * 6) - nt.assert_array_equal(df.inherited_state, ["A"] * 6) - nt.assert_array_equal(df.num_parents, [0] * 6) - nt.assert_array_equal(df.num_descendants, [1] * 4 + [2] * 2) - nt.assert_array_equal(df.num_inheritors, [1] * 4 + [2] * 2) + assert len(df) == 7 + nt.assert_array_equal(df.id, list(range(7))) + nt.assert_array_equal(df.node, list(range(7))) + nt.assert_array_equal(df.position, list(range(1, 8))) + nt.assert_array_equal(df.time, [0, 0, 0, 0, 1, 1, 2]) + nt.assert_array_equal(df.derived_state, ["T"] * 6 + ["FOOBARD"]) + nt.assert_array_equal(df.inherited_state, ["A"] * 6 + ["FOOBAR"]) + nt.assert_array_equal(df.num_parents, [0] * 7) + nt.assert_array_equal(df.num_descendants, [1] * 4 + [2] * 2 + [4]) + nt.assert_array_equal(df.num_inheritors, [1] * 4 + [2] * 2 + [4]) def test_single_tree_recurrent_mutation_example(self): ts = single_tree_recurrent_mutation_example_ts() @@ -162,7 +164,7 @@ def test_single_tree_example(self): df = tsm.nodes_df assert len(df) == 7 nt.assert_array_equal(df.time, [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 2.0]) - nt.assert_array_equal(df.num_mutations, [1, 1, 1, 1, 1, 1, 0]) + nt.assert_array_equal(df.num_mutations, [1, 1, 1, 1, 1, 1, 1]) nt.assert_array_equal(df.ancestors_span, [10, 10, 10, 10, 10, 10, -np.inf]) nt.assert_array_equal(df.is_sample, [1, 1, 1, 1, 0, 0, 0])