diff --git a/cache.py b/cache.py new file mode 100644 index 0000000..c4d859a --- /dev/null +++ b/cache.py @@ -0,0 +1,39 @@ +import functools +import pathlib + +import appdirs +import daiquiri +import diskcache + +logger = daiquiri.getLogger("cache") + + +def get_cache_dir(): + cache_dir = pathlib.Path(appdirs.user_cache_dir("tsqc", "tsqc")) + cache_dir.mkdir(exist_ok=True, parents=True) + return cache_dir + + +cache = diskcache.Cache(get_cache_dir()) + + +def disk_cache(version): + def decorator(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + uuid = self.file_uuid + if uuid is None: + logger.info(f"No uuid, not caching {func.__name__}") + return func(self, *args, **kwargs) + key = f"{self.file_uuid}-{func.__name__}-{version}" + if key in cache: + logger.info(f"Fetching {key} from cache") + return cache[key] + logger.info(f"Calculating {key} and caching") + result = func(self, *args, **kwargs) + cache[key] = result + return result + + return wrapper + + return decorator diff --git a/model.py b/model.py index bb81dfd..b5b5a5f 100644 --- a/model.py +++ b/model.py @@ -7,6 +7,9 @@ import pandas as pd import tskit +from cache import disk_cache + + logger = daiquiri.getLogger("model") spec = [ @@ -271,7 +274,12 @@ def __init__(self, ts, name=None): self.ts.mutations_node, minlength=self.ts.num_nodes ) + @property + def file_uuid(self): + return self.ts.file_uuid + @cached_property + @disk_cache("v1") def summary_df(self): nodes_with_zero_muts = np.sum(self.nodes_num_mutations == 0) sites_with_zero_muts = np.sum(self.sites_num_mutations == 0) @@ -312,6 +320,7 @@ def child_bounds(num_nodes, edges_left, edges_right, edges_child): return child_left, child_right @cached_property + @disk_cache("v1") def mutations_df(self): # FIXME use tskit's impute mutations time ts = self.ts @@ -386,6 +395,7 @@ def mutations_df(self): ) @cached_property + @disk_cache("v1") def edges_df(self): ts = self.ts left = ts.edges_left @@ -426,6 +436,7 @@ def edges_df(self): ) @cached_property + @disk_cache("v1") def nodes_df(self): ts = self.ts child_left, child_right = self.child_bounds( @@ -452,6 +463,7 @@ def nodes_df(self): ) @cached_property + @disk_cache("v1") def trees_df(self): ts = self.ts num_trees = ts.num_trees diff --git a/requirements.txt b/requirements.txt index 3d5fb46..376a01e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ click daiquiri panel +diskcache hvplot xarray datashader @@ -8,3 +9,5 @@ tskit seaborn pre-commit pytest +tszip +appdirs diff --git a/tests/test_data_model.py b/tests/test_data_model.py index 2501bbb..a88a211 100644 --- a/tests/test_data_model.py +++ b/tests/test_data_model.py @@ -1,3 +1,5 @@ +import logging + import numpy as np import numpy.testing as nt import tskit @@ -212,3 +214,29 @@ def test_multi_tree_with_polytomies_example(self): nt.assert_array_equal(df.total_branch_length, [11.0, 12.0]) # nt.assert_array_equal(df.mean_internal_arity, [2.25, 2.25]) nt.assert_array_equal(df.max_internal_arity, [3.0, 3.0]) + + +def test_cache(caplog, tmpdir): + caplog.set_level(logging.INFO) + ts = multiple_trees_example_ts() + tsm = model.TSModel(ts) + # Use the logging out put to determine cache usage + t1 = tsm.trees_df + t2 = tsm.trees_df + assert t1.equals(t2) + assert "No uuid, not caching trees_df" in caplog.text + + ts.dump(tmpdir / "cache.trees") + ts = tskit.load(tmpdir / "cache.trees") + tsm = model.TSModel(ts) + # Use the logging out put to determine cache usage + caplog.clear() + t1 = tsm.trees_df + assert "Calculating" in caplog.text + caplog.clear() + + ts2 = tskit.load(tmpdir / "cache.trees") + tsm2 = model.TSModel(ts2) + t2 = tsm2.trees_df + assert "Fetching" in caplog.text + assert t1.equals(t2)