Skip to content

Commit

Permalink
dask implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
JoOkuma committed May 29, 2024
1 parent 949eef4 commit fa065e6
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 229 deletions.
272 changes: 77 additions & 195 deletions iohub/daxi.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,14 @@
from functools import wraps
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional

import dask
import dask.array as da
import numpy as np
import yaml

if TYPE_CHECKING:
from _typeshed import StrOrBytesPath

from iohub.clearcontrol import ArrayIndex


def _cached(f: Callable) -> Callable:
"""Decorator that caches the array data using its key."""

@wraps(f)
def _key_cache_wrapper(
self: "DaXiFOV", key: tuple[int | None, ...]
) -> np.ndarray:
if not self._cache:
return f(self, key)

elif key != self._cache_key:
self._cache_array = f(self, key)
self._cache_key = key

return cast(np.ndarray, self._cache_array)

return _key_cache_wrapper


class DaXiFOV:
"""
Expand All @@ -48,8 +28,6 @@ class DaXiFOV:
missing_value : Optional[int], optional
If provided this class won't raise an error when missing a volume and
it will return an array with the provided value.
cache : bool
When true caches the last array using the first two indices as key.
"""

_CHANNELS_KEY = "Laser wavelengths"
Expand All @@ -60,7 +38,6 @@ def __init__(
self,
data_path: "StrOrBytesPath",
missing_value: Optional[int] = None,
cache: bool = False,
):
super().__init__()

Expand All @@ -71,10 +48,6 @@ def __init__(

self._missing_value = missing_value
self._dtype = np.uint16
self._cache = cache
self._cache_key: tuple[int | None, ...] | None = None
self._cache_array: np.ndarray | None = None

self._channels = self._metadata[self._CHANNELS_KEY]

shape_dict = self._metadata[self._SHAPE_KEY]
Expand All @@ -86,178 +59,109 @@ def __init__(

self._shape = tuple(shape_dict[k] for k in self._SHAPE_IDX)

def _volume_path(self, t: int, v: int) -> Path:
z, y, x = self._raw_shape[2:]
volume_path = (
self._data_path / f"T_{int(t)}.V_{int(v)}.({z}x{y}x{x}).raw"
)
if not volume_path.exists():
raise ValueError(f"Volume not found: {volume_path}")
return volume_path

@property
def shape(self) -> tuple[int, ...]:
"""
Timelapse shape
"""
return self._shape

@property
def channels(self) -> list[str]:
"""Return sorted channels name."""
return self._channels

@_cached
def _load_array(
self, key: tuple[int, int, int | None, int | None]
) -> np.ndarray:
self._data = da.stack( # T
[
da.concatenate( # V
[
da.stack( # C
[
da.from_delayed(
self._load_volume(t, v, c),
shape=self._shape[2:],
dtype=self._dtype,
)
for c in range(len(self._channels))
]
)
for v in range(self._raw_shape[1])
]
)
for t in range(self._shape[0])
]
) # T, V * C, Y, X shape

@dask.delayed
def _load_volume(self, t: int, v: int, c: int) -> np.ndarray:
"""
Loads a single or multiple channels from DaXi raw.
Load a volume from disk.
If the volume is missing it returns an array with the missing value.
Parameters
----------
key : tuple[int, int, int | None, int | None]
Time point, view, channel_index, and z-index.
z-index and are optional
t : int
time index.
v : int
view index.
c : int
channel index.
Returns
-------
np.ndarray
Volume as an array can be single or multiple channels.
Raises
------
ValueError
When expected volume path not found.
Volume array.
"""
shape = list(self._shape[-2:])

time_point, view, channel, z_index = key

len_ch = len(self._channels)
view //= len_ch

if channel is None:
slicing = slice(None)
shape.insert(0, len_ch)
else:
slicing = slice(channel, None, len_ch)
len_ch = 1

if z_index is None:
shape.insert(0, self._shape[2]) # z-shape
else:
z = z_index * len(self._channels)
if channel is not None:
slicing = z + channel
else:
slicing = slice(z, z + len(self._channels))

map_arr = np.memmap(
self._volume_path(time_point, view),
path = self._volume_path(t, v)

if not path.exists():
return np.full(
self._shape[2:], self._missing_value, dtype=self._dtype
)

return np.memmap(
self._volume_path(t, v),
dtype=self._dtype,
shape=self._raw_shape[2:],
mode="r",
)

print("--------")
print(key)
print(slicing)

# TODO: catch when volume is not complete
arr = map_arr[slicing]
print(arr.shape)
arr = arr.reshape(shape)

if len(shape) == 4:
# ZCYX -> CZYX
arr = arr.swapaxes(0, 1)

return arr

@staticmethod
def _fix_indexing(indexing: "ArrayIndex", size: int) -> list[int] | int:
"""Converts numpy array to simple python type or list."""
# TODO: check if necessary
if isinstance(indexing, slice):
return list(range(size)[indexing])

elif np.isscalar(indexing):
try:
int_index: int = indexing.item()
except AttributeError:
int_index = indexing
)[c :: len(self._channels)]

if int_index < 0:
int_index += size

return int_index
def _volume_path(self, t: int, v: int) -> Path:
"""
Return the path for a volume.
elif isinstance(indexing, np.ndarray):
return indexing.tolist()
Parameters
----------
t : int
time index.
v : int
view index.
return indexing
Returns
-------
Path
Volume path.
"""
z, y, x = self._raw_shape[2:]
volume_path = (
self._data_path / f"T_{int(t)}.V_{int(v)}.({z}x{y}x{x}).raw"
)
return volume_path

def _load_array_from_key(
self, key: tuple[list[int] | int | None, ...]
) -> np.ndarray:
@property
def shape(self) -> tuple[int, ...]:
"""
Load array from a key with multiple channel indices.
This function is called recursively until int-only indices are found.
Timelapse shape
"""
for i, k in enumerate(key):
if isinstance(k, int):
continue

elif k is None:
if i >= 2:
continue

k = list(range(self._shape[i]))

arrs = []
for int_key in k:
new_key = key[:i] + (int_key,) + key[i + 1 :]
arrs.append(self._load_array_from_key(new_key))

return np.stack(arrs)
return self._shape

return self._load_array(key)
@property
def channels(self) -> list[str]:
"""Return sorted channels name."""
return self._channels

def __getitem__(
self, key: Union["ArrayIndex", tuple["ArrayIndex", ...]]
) -> np.ndarray:
def __getitem__(self, index) -> da.Array:
"""Lazily load array as indexed.
Parameters
----------
key : ArrayIndex | tuple[ArrayIndex, ...]
index : Array index.
An indexing key as in numpy, but a bit more limited.
Returns
-------
np.ndarray
Output array.
Raises
------
NotImplementedError
Not all numpy array of indexing are implemented.
"""
# standardizing indexing
yx_slicing = slice(None)
min_size = 4

if not isinstance(key, tuple):
key = (key,)

key = tuple(self._fix_indexing(k, s) for s, k in zip(self._shape, key))
args_key = key + (None,) * (min_size - len(key))

if len(args_key) > min_size: # min_size + 1 (z)
args_key, yx_slicing = args_key[:min_size], args_key[min_size:]

return self._load_array_from_key(args_key)[yx_slicing]
return self._data[index]

def __setitem__(self, key: Any, value: Any) -> None:
raise PermissionError("DaXiFOV is read-only.")
Expand All @@ -270,26 +174,14 @@ def ndim(self) -> int:
def dtype(self) -> np.dtype:
return self._dtype

@property
def cache(self) -> bool:
return self._cache

@cache.setter
def cache(self, value: bool) -> None:
"""Free current key/array cache."""
self._cache = value
if not value:
self._cache_array = None
self._cache_key = None

def metadata(self) -> dict[str, Any]:
"""Summarizes Clear Control metadata into a dictionary."""
return self._metadata

@property
def scale(self) -> list[float]:
"""Dataset temporal, channel and spacial scales."""
# TODO: use metadata
# TODO: use actual metadata, information is missing from file format
return [
1.0,
1.0,
Expand Down Expand Up @@ -317,7 +209,7 @@ def create_mock_daxi_dataset(path: "StrOrBytesPath") -> None:
)

metadata = {
"Laser wavelenghts": ["488", "561"],
"Laser wavelengths": ["488", "561"],
"Dataset dimension": {
"T": 2,
"V": 2,
Expand All @@ -337,13 +229,3 @@ def create_mock_daxi_dataset(path: "StrOrBytesPath") -> None:
)
raw_map[...] = array[t, v]
raw_map.flush()


if __name__ == "__main__":
import napari

path = Path("/mnt/royer.daxi2/Merlin/neurog1.h2afva_05_21_2024")
ds = DaXiFOV(path, cache=True)

napari.imshow(ds)
napari.run()
Loading

0 comments on commit fa065e6

Please sign in to comment.