Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add disk caching #90

Merged
merged 1 commit into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import functools
import pathlib

import appdirs
import daiquiri
import diskcache

logger = daiquiri.getLogger("cache")


def get_cache_dir():
cache_dir = pathlib.Path(appdirs.user_cache_dir("tsqc", "tsqc"))
cache_dir.mkdir(exist_ok=True, parents=True)
return cache_dir


cache = diskcache.Cache(get_cache_dir())


def disk_cache(version):
def decorator(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
uuid = self.file_uuid
if uuid is None:
logger.info(f"No uuid, not caching {func.__name__}")
return func(self, *args, **kwargs)
key = f"{self.file_uuid}-{func.__name__}-{version}"
if key in cache:
logger.info(f"Fetching {key} from cache")
return cache[key]
logger.info(f"Calculating {key} and caching")
result = func(self, *args, **kwargs)
cache[key] = result
return result

return wrapper

return decorator
12 changes: 12 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import pandas as pd
import tskit

from cache import disk_cache


logger = daiquiri.getLogger("model")

spec = [
Expand Down Expand Up @@ -271,7 +274,12 @@ def __init__(self, ts, name=None):
self.ts.mutations_node, minlength=self.ts.num_nodes
)

@property
def file_uuid(self):
return self.ts.file_uuid

@cached_property
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right, so there's two layers of caching here. I had to think a bit about what this means - so the first time the property is called, we fall through to the disk_cache version, and then afterwards (within that process) it'll be cached by functools.

@disk_cache("v1")
def summary_df(self):
nodes_with_zero_muts = np.sum(self.nodes_num_mutations == 0)
sites_with_zero_muts = np.sum(self.sites_num_mutations == 0)
Expand Down Expand Up @@ -312,6 +320,7 @@ def child_bounds(num_nodes, edges_left, edges_right, edges_child):
return child_left, child_right

@cached_property
@disk_cache("v1")
def mutations_df(self):
# FIXME use tskit's impute mutations time
ts = self.ts
Expand Down Expand Up @@ -386,6 +395,7 @@ def mutations_df(self):
)

@cached_property
@disk_cache("v1")
def edges_df(self):
ts = self.ts
left = ts.edges_left
Expand Down Expand Up @@ -426,6 +436,7 @@ def edges_df(self):
)

@cached_property
@disk_cache("v1")
def nodes_df(self):
ts = self.ts
child_left, child_right = self.child_bounds(
Expand All @@ -452,6 +463,7 @@ def nodes_df(self):
)

@cached_property
@disk_cache("v1")
def trees_df(self):
ts = self.ts
num_trees = ts.num_trees
Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
click
daiquiri
panel
diskcache
hvplot
xarray
datashader
tskit
seaborn
pre-commit
pytest
tszip
appdirs
28 changes: 28 additions & 0 deletions tests/test_data_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import numpy as np
import numpy.testing as nt
import tskit
Expand Down Expand Up @@ -212,3 +214,29 @@ def test_multi_tree_with_polytomies_example(self):
nt.assert_array_equal(df.total_branch_length, [11.0, 12.0])
# nt.assert_array_equal(df.mean_internal_arity, [2.25, 2.25])
nt.assert_array_equal(df.max_internal_arity, [3.0, 3.0])


def test_cache(caplog, tmpdir):
caplog.set_level(logging.INFO)
ts = multiple_trees_example_ts()
tsm = model.TSModel(ts)
# Use the logging out put to determine cache usage
t1 = tsm.trees_df
t2 = tsm.trees_df
assert t1.equals(t2)
assert "No uuid, not caching trees_df" in caplog.text

ts.dump(tmpdir / "cache.trees")
ts = tskit.load(tmpdir / "cache.trees")
tsm = model.TSModel(ts)
# Use the logging out put to determine cache usage
caplog.clear()
t1 = tsm.trees_df
assert "Calculating" in caplog.text
caplog.clear()

ts2 = tskit.load(tmpdir / "cache.trees")
tsm2 = model.TSModel(ts2)
t2 = tsm2.trees_df
assert "Fetching" in caplog.text
assert t1.equals(t2)
Loading