From 86880ad618238c4db2d02ada8a7bfd73864f5bd2 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Sun, 20 Oct 2024 13:36:21 +0100 Subject: [PATCH 1/5] Updated samples_per_day plot --- sc2ts/info.py | 62 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/sc2ts/info.py b/sc2ts/info.py index f3e3211..362af0b 100644 --- a/sc2ts/info.py +++ b/sc2ts/info.py @@ -873,6 +873,21 @@ def site_summary(self, position): {"property": [d[0] for d in data], "value": [d[1] for d in data]} ) + def samples_summary(self): + data = [] + md = self.ts.metadata["sc2ts"] + for days_ago in np.arange(self.num_samples_per_day.shape[0]): + date = str(self.time_zero_as_date - days_ago) + data.append( + { + "date": self.time_zero_as_date - days_ago, + "samples_in_arg": self.num_samples_per_day[days_ago], + "samples_processed": md["num_samples_processed"].get(date, 0), + "exact_matches": md["exact_matches"]["date"].get(date, 0), + } + ) + return pd.DataFrame(data) + def recombinants_summary(self): data = [] for u in self.recombinants: @@ -1502,12 +1517,30 @@ def plot_deletion_overlaps(self, annotate_threshold=0.9): return fig, [ax] def plot_samples_per_day(self): - fig, ax = self._wide_plot(1, 1) - t = np.arange(self.num_samples_per_day.shape[0]) - ax.plot(self.time_zero_as_date - t, self.num_samples_per_day) - ax.set_xlabel("Date") - ax.set_ylabel("Number of samples") - return fig, [ax] + df = self.samples_summary() + fig, (ax1, ax2) = self._wide_plot(2, height=6, sharex=True) + + ax1.plot(df.date, df.samples_in_arg, label="In ARG") + ax1.plot(df.date, df.samples_processed, label="Processed") + ax1.plot(df.date, df.exact_matches, label="Exact matches") + + ax2.plot( + df.date, + df.samples_in_arg / df.samples_processed, + label="Fraction processed in ARG", + ) + ax2.plot( + df.date, + df.exact_matches / df.samples_processed, + label="Fraction processed exact matches", + ) + excluded = df.samples_processed - df.exact_matches - 
df.samples_in_arg + ax2.plot(df.date, excluded / df.samples_processed, label="Fraction excluded") + ax2.set_xlabel("Date") + ax1.set_ylabel("Number of samples") + ax1.legend() + ax2.legend() + return fig, [ax1, ax2] def plot_resources(self, start_date="2020-04-01"): ts = self.ts @@ -1596,7 +1629,8 @@ def draw_pango_lineage_subtree( appropriate set of samples. See that function for more details. """ return self.draw_subtree( - tracked_pango=[pango_lineage], position=position, *args, **kwargs) + tracked_pango=[pango_lineage], position=position, *args, **kwargs + ) def draw_subtree( self, @@ -1625,7 +1659,7 @@ def draw_subtree( untracked node lineages within polytomies are condensed into a dotted line. Clades containing more than a certain proportion of tracked nodes can also be collapsed (see the ``collapse_tracked`` parameter). - + Most parameters are passed directly to ``tskit.Tree.draw_svg()`` method, apart from the following: :param position int: The genomic position at which to draw the tree. 
If None, @@ -1698,7 +1732,9 @@ def draw_subtree if extra_tracked_samples is not None: tn_set = set(tracked_nodes) - extra_tracked_samples = [e for e in extra_tracked_samples if e not in tn_set] + extra_tracked_samples = [ + e for e in extra_tracked_samples if e not in tn_set + ] tracked_nodes = np.concatenate((tracked_nodes, extra_tracked_samples)) tree = ts.at(position, tracked_samples=tracked_nodes) order = np.array( @@ -1800,8 +1836,8 @@ def draw_subtree # Recombination nodes as larger open circles re_nodes = np.where(ts.nodes_flags & core.NODE_IS_RECOMBINANT)[0] styles.append( - ",".join([f".node.n{u} > .sym" for u in re_nodes]) + - f"{{r:{symbol_size/2*1.5:.2f}px; stroke:black; fill:white}}" + ",".join([f".node.n{u} > .sym" for u in re_nodes]) + + f"{{r:{symbol_size/2*1.5:.2f}px; stroke:black; fill:white}}" ) return tree.draw_svg( time_scale=time_scale, @@ -2032,8 +2068,8 @@ def draw_svg( # recombination nodes in larger open white circles re_nodes = np.where(ts.nodes_flags & core.NODE_IS_RECOMBINANT)[0] styles.append( - ",".join([f".node.n{u} > .sym" for u in re_nodes]) + - f"{{r: {symbol_size/2*1.5:.2f}px; stroke: black; fill: white}}" + ",".join([f".node.n{u} > .sym" for u in re_nodes]) + + f"{{r: {symbol_size/2*1.5:.2f}px; stroke: black; fill: white}}" ) svg = self.ts.draw_svg( size=size, From e5fd80369b2e0813c2e46b1810980bb640054ed6 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Sun, 20 Oct 2024 14:12:27 +0100 Subject: [PATCH 2/5] Updated resources plot --- sc2ts/info.py | 68 ++++++++++++++++++++++++++------------------------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/sc2ts/info.py b/sc2ts/info.py index 362af0b..bccdaef 100644 --- a/sc2ts/info.py +++ b/sc2ts/info.py @@ -1542,58 +1542,60 @@ def plot_samples_per_day(self): ax2.legend() return fig, [ax1, ax2] - def plot_resources(self, start_date="2020-04-01"): + # Tests don't pass because we're not storing resources + def fix_tests_plot_resources(self, start_date="2020-04-01"): 
ts = self.ts fig, ax = self._wide_plot(3, height=8, sharex=True) - elapsed_time = np.zeros(ts.num_provenances) - cpu_time = np.zeros(ts.num_provenances) - max_mem = np.zeros(ts.num_provenances) - date = np.zeros(ts.num_provenances, dtype="datetime64[D]") - num_samples = np.zeros(ts.num_provenances, dtype=int) - for j in range(1, ts.num_provenances): - p = ts.provenance(j) - record = json.loads(p.record) - text_date = record["parameters"]["args"][2] - date[j] = text_date - try: - resources = record["resources"] - elapsed_time[j] = resources["elapsed_time"] - cpu_time[j] = resources["user_time"] + resources["sys_time"] - max_mem[j] = resources["max_memory"] - except KeyError: - warnings.warn("Missing required provenance fields") - # The +3 is from lining up peaks by eye, not sure how it happens - days_ago = self.time_zero_as_date - date[j] + 3 - # Avoid division by zero - num_samples[j] = max(1, self.num_samples_per_day[days_ago.astype(int)]) - - keep = date >= np.array([start_date], dtype="datetime64[D]") - total_elapsed = datetime.timedelta(seconds=np.sum(elapsed_time)) - total_cpu = datetime.timedelta(seconds=np.sum(cpu_time)) + + dfs = self.samples_summary().set_index("date") + df = self.resources_summary().set_index("date") + # Should be able to do this with join, but I failed + df["samples_in_arg"] = dfs.loc[df.index]["samples_in_arg"] + df["samples_processed"] = dfs.loc[df.index]["samples_processed"] + + # df = resources_summary(self) + df = df[df.index >= start_date] + df["cpu_time"] = df.user_time + df.sys_time + x = np.array(df.index, dtype="datetime64[D]") + + total_elapsed = datetime.timedelta(seconds=np.sum(df.elapsed_time)) + total_cpu = datetime.timedelta(seconds=np.sum(df.cpu_time)) title = ( f"{humanize.naturaldelta(total_elapsed)} elapsed " f"using {humanize.naturaldelta(total_cpu)} of CPU time " - f"(utilisation = {np.sum(cpu_time) / np.sum(elapsed_time):.2f})" + f"(utilisation = {np.sum(df.cpu_time) / np.sum(df.elapsed_time):.2f})" ) - max_mem /= 
1024**3 # Convert to GiB + # df.max_mem /= 1024**3 # Convert to GiB ax[0].set_title(title) - ax[0].plot(date[keep], elapsed_time[keep] / 60, label="elapsed time") + ax[0].plot(x, df.elapsed_time / 60, label="elapsed time") ax[-1].set_xlabel("Date") ax_twin = ax[0].twinx() ax_twin.plot( - date[keep], num_samples[keep], color="tab:red", alpha=0.5, label="samples" + x, df.samples_processed, color="tab:red", alpha=0.5, label="samples" ) - ax_twin.set_ylabel("Num samples") + ax_twin.set_ylabel("Samples processed") ax[0].set_ylabel("Elapsed time (mins)") ax[0].legend() ax_twin.legend() - ax[1].plot(date[keep], elapsed_time[keep] / num_samples[keep]) + ax[1].plot(x, df.elapsed_time / df.samples_processed) ax[1].set_ylabel("Elapsed time per sample (s)") - ax[2].plot(date[keep], max_mem[keep]) + ax[2].plot(x, df.max_memory / 1024**3) ax[2].set_ylabel("Max memory (GiB)") return fig, ax + def resources_summary(self): + ts = self.ts + data = [] + for j in range(1, ts.num_provenances): + p = ts.provenance(j) + record = json.loads(p.record) + text_date = record["parameters"]["args"][2] + + resources = record["resources"] + data.append({"date": text_date, **resources}) + return pd.DataFrame(data) + def fixme_plot_recombinants_per_day(self): counter = collections.Counter() for u in self.recombinants: From 8477b509e008ab44428e905c5501c8c945bc6c41 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Sun, 20 Oct 2024 22:21:21 +0100 Subject: [PATCH 3/5] Fixup plot testing, and add provenance to test fixtures --- sc2ts/info.py | 17 +++++++++++------ tests/conftest.py | 9 ++++++++- tests/test_info.py | 9 +++++++-- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/sc2ts/info.py b/sc2ts/info.py index bccdaef..fb2af57 100644 --- a/sc2ts/info.py +++ b/sc2ts/info.py @@ -1542,8 +1542,7 @@ def plot_samples_per_day(self): ax2.legend() return fig, [ax1, ax2] - # Tests don't pass because we're not storing resources - def fix_tests_plot_resources(self, start_date="2020-04-01"): + def 
plot_resources(self, start_date="2020-04-01"): ts = self.ts fig, ax = self._wide_plot(3, height=8, sharex=True) @@ -1553,7 +1552,6 @@ def fix_tests_plot_resources(self, start_date="2020-04-01"): df["samples_in_arg"] = dfs.loc[df.index]["samples_in_arg"] df["samples_processed"] = dfs.loc[df.index]["samples_processed"] - # df = resources_summary(self) df = df[df.index >= start_date] df["cpu_time"] = df.user_time + df.sys_time x = np.array(df.index, dtype="datetime64[D]") @@ -1587,13 +1585,20 @@ def fix_tests_plot_resources(self, start_date="2020-04-01"): def resources_summary(self): ts = self.ts data = [] + dates = sorted(list(ts.metadata["sc2ts"]["num_samples_processed"].keys())) + assert len(dates) == ts.num_provenances - 1 for j in range(1, ts.num_provenances): p = ts.provenance(j) record = json.loads(p.record) - text_date = record["parameters"]["args"][2] - + try: + # Just double checking that this is the same date the provenance is for + # when using production data from CLI (test fixtures don't have this). 
+ text_date = record["parameters"]["args"][2] + assert text_date == dates[j - 1] + except IndexError: + pass resources = record["resources"] - data.append({"date": text_date, **resources}) + data.append({"date": dates[j - 1], **resources}) return pd.DataFrame(data) def fixme_plot_recombinants_per_day(self): diff --git a/tests/conftest.py b/tests/conftest.py index 14820b5..6e8613a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ import pytest import sc2ts +from sc2ts import cli @pytest.fixture @@ -85,7 +86,11 @@ def fx_ts_map(tmp_path, fx_data_cache, fx_metadata_db, fx_alignment_store, fx_ma # These sites are masked out in all alignments in the initial data # anyway; https://github.com/jeromekelleher/sc2ts/issues/282 last_ts = sc2ts.initial_ts([56, 57, 58, 59, 60]) + cache_path = fx_data_cache / "initial.ts" + cli.add_provenance(last_ts, cache_path) for date in dates: + # Load the ts from file to get the provenance data + last_ts = tskit.load(cache_path) last_ts = sc2ts.extend( alignment_store=fx_alignment_store, metadata_db=fx_metadata_db, @@ -97,7 +102,9 @@ def fx_ts_map(tmp_path, fx_data_cache, fx_metadata_db, fx_alignment_store, fx_ma f"INFERRED {date} nodes={last_ts.num_nodes} mutations={last_ts.num_mutations}" ) cache_path = fx_data_cache / f"{date}.ts" - last_ts.dump(cache_path) + # The values recorded for resources are nonsense here, but at least it's + # something to use for tests + cli.add_provenance(last_ts, cache_path) return {date: tskit.load(fx_data_cache / f"{date}.ts") for date in dates} diff --git a/tests/test_info.py b/tests/test_info.py index 7ef0d6e..337e3ed 100644 --- a/tests/test_info.py +++ b/tests/test_info.py @@ -193,7 +193,7 @@ def test_draw_pango_lineage_subtree(self, fx_ti_2020_02_13): svg2 = ti.draw_subtree(tracked_pango=["A"]) assert svg == svg2 assert svg.startswith(" Date: Sun, 20 Oct 2024 22:25:15 +0100 Subject: [PATCH 4/5] Fixup lineage tallies for new MD format --- sc2ts/info.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/sc2ts/info.py b/sc2ts/info.py index fb2af57..0b50f0d 100644 --- a/sc2ts/info.py +++ b/sc2ts/info.py @@ -141,7 +141,7 @@ def tally_lineages(ts, metadata_db, show_progress=False): md = ts.metadata["sc2ts"] date = md["date"] # Take the exact matches into account also. - counter = collections.Counter(md["num_exact_matches"]) + counter = collections.Counter(md["exact_matches"]["pango"]) key = "Viridian_pangolin" iterator = tqdm.tqdm( ts.samples()[1:], From 45dfa86eaf47f7ce42595c189edafa19f973a55f Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Sun, 20 Oct 2024 22:28:06 +0100 Subject: [PATCH 5/5] Tweaks --- sc2ts/info.py | 24 ++---------------------- 1 file changed, 2 insertions(+), 22 deletions(-) diff --git a/sc2ts/info.py b/sc2ts/info.py index 0b50f0d..257ed7c 100644 --- a/sc2ts/info.py +++ b/sc2ts/info.py @@ -1516,8 +1516,9 @@ def plot_deletion_overlaps(self, annotate_threshold=0.9): ax.set_ylabel("Overlapping deletions") return fig, [ax] - def plot_samples_per_day(self): + def plot_samples_per_day(self, start_date="2020-04-01"): df = self.samples_summary() + df = df[df.date >= start_date] fig, (ax1, ax2) = self._wide_plot(2, height=6, sharex=True) ax1.plot(df.date, df.samples_in_arg, label="In ARG") @@ -1601,27 +1602,6 @@ def resources_summary(self): data.append({"date": dates[j - 1], **resources}) return pd.DataFrame(data) - def fixme_plot_recombinants_per_day(self): - counter = collections.Counter() - for u in self.recombinants: - date = np.datetime64(self.nodes_metadata[u]["date_added"]) - counter[date] += 1 - - samples_per_day = np.zeros(len(counter)) - sample_date = self.nodes_date[self.ts.samples()] - for j, date in enumerate(counter.keys()): - samples_per_day[j] = np.sum(sample_date == date) - x = np.array(list(counter.keys())) - y = np.array(list(counter.values())) - - _, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 8)) - ax1.plot(x, y) - ax2.plot(x, y / samples_per_day) - ax2.set_xlabel("Date") - 
ax1.set_ylabel("Number of recombinant samples") - ax2.set_ylabel("Fraction of samples recombinant") - ax2.set_ylim(0, 0.01) - def draw_pango_lineage_subtree( self, pango_lineage,