Skip to content

Commit

Permalink
Reduce sample matching to batches of 5
Browse files Browse the repository at this point in the history
  • Loading branch information
benjeffery committed Sep 6, 2023
1 parent e0bc9ca commit 9627e24
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 56 deletions.
22 changes: 0 additions & 22 deletions tests/test_sgkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,28 +415,6 @@ def test_sgkit_variant_mask(self, tmp_path, sites):
# site.metadata for site in inf_ts.sites()
# ]

def test_sgkit_variant_bad_mask(self, tmp_path):
ts, zarr_path = make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
sites_mask = np.arange(ds.sizes["variants"], dtype=int)
add_array_to_dataset("variant_mask", sites_mask, zarr_path)
# with pytest.raises(
# ValueError,
# match="The variant_mask array contains values " "other than 0 or 1",
# ):
# tsinfer.SgkitSampleData(zarr_path)

def test_sgkit_variant_bad_mask_negative(self, tmp_path):
ts, zarr_path = make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
sites_mask = np.arange(0, -ds.sizes["variants"], -1, dtype=int)
add_array_to_dataset("variant_mask", sites_mask, zarr_path)
with pytest.raises(
ValueError,
match="The variant_mask array contains values " "other than 0 or 1",
):
tsinfer.SgkitSampleData(zarr_path)

def test_sgkit_variant_bad_mask_length(self, tmp_path):
ts, zarr_path = make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
Expand Down
9 changes: 1 addition & 8 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2397,14 +2397,7 @@ def sites_mask(self):
raise ValueError(
"Mask must be the same length as the number of unmasked sites"
)
# Often xarray will save a bool array as int8, so we need to cast,
# but check that a mistake hasn't been made by checking
# that the values are either 0 or 1
# mask = self.data["variant_mask"].astype(np.int8)
# if da.max(mask).compute() > 1 or da.min(mask).compute() < 0:
# raise ValueError(
# "The variant_mask array contains values other than 0 or 1"
# )

return self.data["variant_mask"].astype(bool)
except KeyError:
return np.full(self.data["variant_position"].shape, True, dtype=bool)
Expand Down
37 changes: 11 additions & 26 deletions tsinfer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import humanize
import lmdb
import numpy as np
import psutil
import tskit

import _tsinfer
Expand Down Expand Up @@ -1910,9 +1909,7 @@ def dask_find_path(
site_indexes=None,
sample_id_map=None,
):
sys.stderr.write(
f"Free RAM: {psutil.virtual_memory().free / 1024 ** 3:.2f} GB\n"
)
t = time_.time()
ancestor_matcher_class = (
_tsinfer.AncestorMatcher
if engine == constants.C_ENGINE
Expand All @@ -1925,42 +1922,28 @@ def dask_find_path(
precision=precision,
extended_checks=extended_checks,
)
sys.stderr.write(f"Loading sample data from {data_path}\n")
sys.stderr.write(
f"Free RAM: {psutil.virtual_memory().free / 1024 ** 3:.2f} GB\n"
f"Loading sample data from {data_path} at {time_.time() - t:.2f} seconds\n"
)
t = time_.time()
sample_data = formats.SgkitSampleData(data_path)
sys.stderr.write(f"Loaded sample data in {time_.time() - t:.2f} seconds\n")
sys.stderr.write(
f"Free RAM: {psutil.virtual_memory().free / 1024 ** 3:.2f} GB\n"
)
t = time_.time()
sys.stderr.write(f"Loaded sample data at {time_.time() - t:.2f} seconds\n")
haplotypes = sample_data._slice_haplotypes(
sites=site_indexes, recode_ancestral=True, samples_slice=samples_slice
)
sys.stderr.write(f"Init haplotypes in {time_.time() - t:.2f} seconds\n")
sys.stderr.write(
f"Free RAM: {psutil.virtual_memory().free / 1024 ** 3:.2f} GB\n"
)
sys.stderr.write(f"Init haplotypes at {time_.time() - t:.2f} seconds\n")
# Pickle here rather than let dask deal with it so we can log sizes
results = []
for sample_id, haplotype in haplotypes:
sys.stderr.write(f"Finding path for sample {sample_id}\n")
sys.stderr.write(
f"Free RAM: {psutil.virtual_memory().free / 1024 ** 3:.2f} GB\n"
f"Finding path for {sample_id} at {time_.time() - t:.2f} seconds\n"
)
t = time_.time()
results.append(
AncestorMatcher.find_path(
matcher, sample_id_map[sample_id], haplotype, 0, len(site_indexes)
)
)
sys.stderr.write(
f"Found path for sample {sample_id} in {time_.time() - t:.2f} seconds\n"
)
sys.stderr.write(
f"Free RAM: {psutil.virtual_memory().free / 1024 ** 3:.2f} GB\n"
f"Found path for {sample_id} at {time_.time() - t:.2f} seconds\n"
)
return pickle.dumps(results)

Expand Down Expand Up @@ -2071,11 +2054,13 @@ def match_with_dask(
):
path = os.path.abspath(self.sample_data.path)
# sample slice is a tuple of (start, stop),
# convert to a list of (start, stop) tuples that contain 10 each.
# convert to a list of (start, stop) tuples that contain 5 each.
# The tradeoff here is between the extra cost of loading the samples,
# vs the RAM usage and long-running tasks.
start, stop = sample_slice
sample_slices = [
(sub_start, sub_start + 10 if sub_start + 10 < stop else stop)
for sub_start in range(start, stop, 10)
(sub_start, sub_start + 5 if sub_start + 5 < stop else stop)
for sub_start in range(start, stop, 5)
]
samples_indexes_bag = db.from_sequence(
sample_slices, npartitions=len(sample_slices)
Expand Down

0 comments on commit 9627e24

Please sign in to comment.