Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update perf plots #379

Merged
merged 5 commits into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 90 additions & 67 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def tally_lineages(ts, metadata_db, show_progress=False):
md = ts.metadata["sc2ts"]
date = md["date"]
# Take the exact matches into account also.
counter = collections.Counter(md["num_exact_matches"])
counter = collections.Counter(md["exact_matches"]["pango"])
key = "Viridian_pangolin"
iterator = tqdm.tqdm(
ts.samples()[1:],
Expand Down Expand Up @@ -873,6 +873,21 @@ def site_summary(self, position):
{"property": [d[0] for d in data], "value": [d[1] for d in data]}
)

def samples_summary(self):
data = []
md = self.ts.metadata["sc2ts"]
for days_ago in np.arange(self.num_samples_per_day.shape[0]):
date = str(self.time_zero_as_date - days_ago)
data.append(
{
"date": self.time_zero_as_date - days_ago,
"samples_in_arg": self.num_samples_per_day[days_ago],
"samples_processed": md["num_samples_processed"].get(date, 0),
"exact_matches": md["exact_matches"]["date"].get(date, 0),
}
)
return pd.DataFrame(data)

def recombinants_summary(self):
data = []
for u in self.recombinants:
Expand Down Expand Up @@ -1501,86 +1516,91 @@ def plot_deletion_overlaps(self, annotate_threshold=0.9):
ax.set_ylabel("Overlapping deletions")
return fig, [ax]

def plot_samples_per_day(self):
fig, ax = self._wide_plot(1, 1)
t = np.arange(self.num_samples_per_day.shape[0])
ax.plot(self.time_zero_as_date - t, self.num_samples_per_day)
ax.set_xlabel("Date")
ax.set_ylabel("Number of samples")
return fig, [ax]
def plot_samples_per_day(self, start_date="2020-04-01"):
    """
    Plot two stacked panels over time: absolute per-day sample counts
    (in ARG, processed, exact matches) and the corresponding fractions
    of processed samples, starting from ``start_date``.

    Returns the figure and a list of the two axes.
    """
    summary = self.samples_summary()
    summary = summary[summary.date >= start_date]
    fig, (ax_counts, ax_fracs) = self._wide_plot(2, height=6, sharex=True)

    dates = summary.date
    processed = summary.samples_processed
    ax_counts.plot(dates, summary.samples_in_arg, label="In ARG")
    ax_counts.plot(dates, processed, label="Processed")
    ax_counts.plot(dates, summary.exact_matches, label="Exact matches")

    ax_fracs.plot(
        dates,
        summary.samples_in_arg / processed,
        label="Fraction processed in ARG",
    )
    ax_fracs.plot(
        dates,
        summary.exact_matches / processed,
        label="Fraction processed exact matches",
    )
    # Samples neither in the ARG nor exact matches were excluded.
    leftover = processed - summary.exact_matches - summary.samples_in_arg
    ax_fracs.plot(dates, leftover / processed, label="Fraction excluded")
    ax_fracs.set_xlabel("Date")
    ax_counts.set_ylabel("Number of samples")
    ax_counts.legend()
    ax_fracs.legend()
    return fig, [ax_counts, ax_fracs]

def plot_resources(self, start_date="2020-04-01"):
ts = self.ts
fig, ax = self._wide_plot(3, height=8, sharex=True)
elapsed_time = np.zeros(ts.num_provenances)
cpu_time = np.zeros(ts.num_provenances)
max_mem = np.zeros(ts.num_provenances)
date = np.zeros(ts.num_provenances, dtype="datetime64[D]")
num_samples = np.zeros(ts.num_provenances, dtype=int)
for j in range(1, ts.num_provenances):
p = ts.provenance(j)
record = json.loads(p.record)
text_date = record["parameters"]["args"][2]
date[j] = text_date
try:
resources = record["resources"]
elapsed_time[j] = resources["elapsed_time"]
cpu_time[j] = resources["user_time"] + resources["sys_time"]
max_mem[j] = resources["max_memory"]
except KeyError:
warnings.warn("Missing required provenance fields")
# The +3 is from lining up peaks by eye, not sure how it happens
days_ago = self.time_zero_as_date - date[j] + 3
# Avoid division by zero
num_samples[j] = max(1, self.num_samples_per_day[days_ago.astype(int)])

keep = date >= np.array([start_date], dtype="datetime64[D]")
total_elapsed = datetime.timedelta(seconds=np.sum(elapsed_time))
total_cpu = datetime.timedelta(seconds=np.sum(cpu_time))

dfs = self.samples_summary().set_index("date")
df = self.resources_summary().set_index("date")
# Should be able to do this with join, but I failed
df["samples_in_arg"] = dfs.loc[df.index]["samples_in_arg"]
df["samples_processed"] = dfs.loc[df.index]["samples_processed"]

df = df[df.index >= start_date]
df["cpu_time"] = df.user_time + df.sys_time
x = np.array(df.index, dtype="datetime64[D]")

total_elapsed = datetime.timedelta(seconds=np.sum(df.elapsed_time))
total_cpu = datetime.timedelta(seconds=np.sum(df.cpu_time))
title = (
f"{humanize.naturaldelta(total_elapsed)} elapsed "
f"using {humanize.naturaldelta(total_cpu)} of CPU time "
f"(utilisation = {np.sum(cpu_time) / np.sum(elapsed_time):.2f})"
f"(utilisation = {np.sum(df.cpu_time) / np.sum(df.elapsed_time):.2f})"
)

max_mem /= 1024**3 # Convert to GiB
# df.max_mem /= 1024**3 # Convert to GiB
ax[0].set_title(title)
ax[0].plot(date[keep], elapsed_time[keep] / 60, label="elapsed time")
ax[0].plot(x, df.elapsed_time / 60, label="elapsed time")
ax[-1].set_xlabel("Date")
ax_twin = ax[0].twinx()
ax_twin.plot(
date[keep], num_samples[keep], color="tab:red", alpha=0.5, label="samples"
x, df.samples_processed, color="tab:red", alpha=0.5, label="samples"
)
ax_twin.set_ylabel("Num samples")
ax_twin.set_ylabel("Samples processed")
ax[0].set_ylabel("Elapsed time (mins)")
ax[0].legend()
ax_twin.legend()
ax[1].plot(date[keep], elapsed_time[keep] / num_samples[keep])
ax[1].plot(x, df.elapsed_time / df.samples_processed)
ax[1].set_ylabel("Elapsed time per sample (s)")
ax[2].plot(date[keep], max_mem[keep])
ax[2].plot(x, df.max_memory / 1024**3)
ax[2].set_ylabel("Max memory (GiB)")
return fig, ax

def fixme_plot_recombinants_per_day(self):
counter = collections.Counter()
for u in self.recombinants:
date = np.datetime64(self.nodes_metadata[u]["date_added"])
counter[date] += 1

samples_per_day = np.zeros(len(counter))
sample_date = self.nodes_date[self.ts.samples()]
for j, date in enumerate(counter.keys()):
samples_per_day[j] = np.sum(sample_date == date)
x = np.array(list(counter.keys()))
y = np.array(list(counter.values()))

_, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 8))
ax1.plot(x, y)
ax2.plot(x, y / samples_per_day)
ax2.set_xlabel("Date")
ax1.set_ylabel("Number of recombinant samples")
ax2.set_ylabel("Fraction of samples recombinant")
ax2.set_ylim(0, 0.01)
def resources_summary(self):
    """
    Return a DataFrame with one row per inference step (provenance entry
    after the first), giving the date of the step and the resource usage
    fields recorded in that provenance's ``resources`` dict.
    """
    ts = self.ts
    # Dates processed, in order; each non-initial provenance corresponds
    # to exactly one of them.
    dates = sorted(ts.metadata["sc2ts"]["num_samples_processed"])
    assert len(dates) == ts.num_provenances - 1
    rows = []
    for j in range(1, ts.num_provenances):
        record = json.loads(ts.provenance(j).record)
        try:
            # Sanity check that this is the same date the provenance is for
            # when using production data from CLI (test fixtures don't
            # have the args entry, hence the IndexError guard).
            assert record["parameters"]["args"][2] == dates[j - 1]
        except IndexError:
            pass
        rows.append({"date": dates[j - 1], **record["resources"]})
    return pd.DataFrame(rows)

def draw_pango_lineage_subtree(
self,
Expand All @@ -1596,7 +1616,8 @@ def draw_pango_lineage_subtree(
appropriate set of samples. See that function for more details.
"""
return self.draw_subtree(
tracked_pango=[pango_lineage], position=position, *args, **kwargs)
tracked_pango=[pango_lineage], position=position, *args, **kwargs
)

def draw_subtree(
self,
Expand Down Expand Up @@ -1625,7 +1646,7 @@ def draw_subtree(
untracked node lineages within polytomies are condensed into a dotted line.
Clades containing more than a certain proportion of tracked nodes can also be
collapsed (see the ``collapse_tracked`` parameter).

Most parameters are passed directly to ``tskit.Tree.draw_svg()`` method, apart
from the following:
:param position int: The genomic position at which to draw the tree. If None,
Expand Down Expand Up @@ -1698,7 +1719,9 @@ def draw_subtree(

if extra_tracked_samples is not None:
tn_set = set(tracked_nodes)
extra_tracked_samples = [e for e in extra_tracked_samples if e not in tn_set]
extra_tracked_samples = [
e for e in extra_tracked_samples if e not in tn_set
]
tracked_nodes = np.concatenate((tracked_nodes, extra_tracked_samples))
tree = ts.at(position, tracked_samples=tracked_nodes)
order = np.array(
Expand Down Expand Up @@ -1800,8 +1823,8 @@ def draw_subtree(
# Recombination nodes as larger open circles
re_nodes = np.where(ts.nodes_flags & core.NODE_IS_RECOMBINANT)[0]
styles.append(
",".join([f".node.n{u} > .sym" for u in re_nodes]) +
f"{{r:{symbol_size/2*1.5:.2f}px; stroke:black; fill:white}}"
",".join([f".node.n{u} > .sym" for u in re_nodes])
+ f"{{r:{symbol_size/2*1.5:.2f}px; stroke:black; fill:white}}"
)
return tree.draw_svg(
time_scale=time_scale,
Expand Down Expand Up @@ -2032,8 +2055,8 @@ def draw_svg(
# recombination nodes in larger open white circles
re_nodes = np.where(ts.nodes_flags & core.NODE_IS_RECOMBINANT)[0]
styles.append(
",".join([f".node.n{u} > .sym" for u in re_nodes]) +
f"{{r: {symbol_size/2*1.5:.2f}px; stroke: black; fill: white}}"
",".join([f".node.n{u} > .sym" for u in re_nodes])
+ f"{{r: {symbol_size/2*1.5:.2f}px; stroke: black; fill: white}}"
)
svg = self.ts.draw_svg(
size=size,
Expand Down
9 changes: 8 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest

import sc2ts
from sc2ts import cli


@pytest.fixture
Expand Down Expand Up @@ -85,7 +86,11 @@ def fx_ts_map(tmp_path, fx_data_cache, fx_metadata_db, fx_alignment_store, fx_ma
# These sites are masked out in all alignments in the initial data
# anyway; https://github.com/jeromekelleher/sc2ts/issues/282
last_ts = sc2ts.initial_ts([56, 57, 58, 59, 60])
cache_path = fx_data_cache / "initial.ts"
cli.add_provenance(last_ts, cache_path)
for date in dates:
# Load the ts from file to get the provenance data
last_ts = tskit.load(cache_path)
last_ts = sc2ts.extend(
alignment_store=fx_alignment_store,
metadata_db=fx_metadata_db,
Expand All @@ -97,7 +102,9 @@ def fx_ts_map(tmp_path, fx_data_cache, fx_metadata_db, fx_alignment_store, fx_ma
f"INFERRED {date} nodes={last_ts.num_nodes} mutations={last_ts.num_mutations}"
)
cache_path = fx_data_cache / f"{date}.ts"
last_ts.dump(cache_path)
# The values recorded for resources are nonsense here, but at least it's
# something to use for tests
cli.add_provenance(last_ts, cache_path)
return {date: tskit.load(fx_data_cache / f"{date}.ts") for date in dates}


Expand Down
9 changes: 7 additions & 2 deletions tests/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def test_draw_pango_lineage_subtree(self, fx_ti_2020_02_13):
svg2 = ti.draw_subtree(tracked_pango=["A"])
assert svg == svg2
assert svg.startswith("<svg")
for u in ti.pango_lineage_samples['A']:
for u in ti.pango_lineage_samples["A"]:
assert f"node n{u}" in svg

def test_draw_subtree(self, fx_ti_2020_02_13):
Expand All @@ -204,11 +204,16 @@ def test_draw_subtree(self, fx_ti_2020_02_13):
for u in samples:
assert f"node n{u}" in svg

def test_resources_summary(self, fx_ti_2020_02_13):
    # resources_summary() yields one row per extend step recorded in
    # provenance; the fixture pipeline runs one step per date, all in 2020.
    # NOTE(review): the 20-row count is fixture-specific — confirm it
    # matches the number of dates in conftest's fx_ts_map.
    df = fx_ti_2020_02_13.resources_summary()
    assert df.shape[0] == 20
    assert np.all(df.date.str.startswith("2020"))


class TestSampleGroupInfo:
def test_draw_svg(self, fx_ti_2020_02_13):
ti = fx_ti_2020_02_13
sg = list(ti.nodes_sample_group.keys())[0]
sg_info = ti.get_sample_group_info(sg)
svg = sg_info.draw_svg()
assert svg.startswith("<svg")
assert svg.startswith("<svg")