Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Porting create_empty_zarr to iohub #234

Merged
merged 38 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
4386d32
initial commit with added docs prior to refactoring to simplify the n…
edyoshikun Jul 16, 2024
a5bca81
considering varying t_idx_in and out
edyoshikun Jul 16, 2024
6edd790
template for hypothesis
edyoshikun Jul 17, 2024
ee502b2
Refactor NGFF module and migrate to Pydantic v2 (#233)
ziw-liu Jul 29, 2024
7fc31a7
improving docstring for functions and renaming input and output path …
edyoshikun Jul 31, 2024
9ba133a
Fixing pyramid scaling factor (#238)
JoOkuma Aug 7, 2024
84f351c
Use annotation instead of field for tagged union (#244)
ziw-liu Sep 7, 2024
3304dac
Export `ImageArray` from the `ngff` module (#245)
edyoshikun Sep 10, 2024
755e941
renaming method arguments to have consistent naming structure
edyoshikun Sep 17, 2024
d391441
flake8
edyoshikun Sep 17, 2024
3844699
Merge branch 'main' into mp_utils_port
ieivanov Sep 19, 2024
5cc6ca8
refactor _calculate_zyx_chunk_size
ieivanov Sep 20, 2024
6a1252d
use input_store_path and output_store_path throughout
ieivanov Sep 20, 2024
0e1c141
style
ieivanov Sep 20, 2024
99e613e
rename and clean up time indices
ieivanov Sep 20, 2024
cde08f5
update time_indices documentation
ieivanov Sep 20, 2024
1af9254
add processing for channel indices
ieivanov Sep 20, 2024
e55686d
fix syntax and move ngff_utils.py to ngff/utils.py
talonchandler Sep 20, 2024
5bfe7c9
update import
talonchandler Sep 20, 2024
8315d4a
typing
talonchandler Sep 20, 2024
c6cd5e8
docs typos
talonchandler Sep 20, 2024
6621635
fix process_single_position iterator @talonchandler @edyoshikun
ieivanov Sep 20, 2024
671bf3d
update apply_transform... docstring
ieivanov Sep 20, 2024
3a0887e
compatibility with minimal deskew w/ @edyoshikun
talonchandler Sep 20, 2024
82dd335
pretty flat_iterable
ieivanov Sep 20, 2024
07dea4b
adding new tests
edyoshikun Sep 21, 2024
643665c
create_empty test without testing channel names
edyoshikun Sep 21, 2024
9ee6acb
fixing the create_empty_zarr extra indentation
edyoshikun Sep 21, 2024
961995c
-attempt to fix apply_transform_test. @ieivanov revert if needed
edyoshikun Sep 23, 2024
82950b3
fixed apply_transform_czyx
edyoshikun Sep 23, 2024
7e3f0af
debug pytest
ieivanov Sep 23, 2024
cf36dc2
fixing the test for create_empty_plate pytest
edyoshikun Sep 24, 2024
59b4bf2
synchronize log messages
talonchandler Sep 25, 2024
5b7f595
docs improvements
talonchandler Sep 26, 2024
c1246f1
improved docs and typing for Callable func
talonchandler Sep 26, 2024
f357184
remove commented slurmkit fix
talonchandler Sep 26, 2024
17da054
delete unused function
talonchandler Sep 26, 2024
ef0ed41
Merge branch 'main' into mp_utils_port
talonchandler Sep 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
341 changes: 341 additions & 0 deletions iohub/ngff_utils.py
edyoshikun marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,341 @@
import contextlib
import inspect
import io
import itertools
import multiprocessing as mp
from functools import partial
from pathlib import Path
from typing import Tuple

import click
import numpy as np
from numpy.typing import DTypeLike

from iohub.ngff import Position, open_ome_zarr
from iohub.ngff_meta import TransformationMeta


def create_empty_hcs_zarr(
    store_path: Path,
    position_keys: list[Tuple[str]],
    channel_names: list[str],
    shape: Tuple[int],
    chunks: Tuple[int] = None,
    scale: Tuple[float] = (1, 1, 1, 1, 1),
    dtype: DTypeLike = np.float32,
    max_chunk_size_bytes: float = 500e6,
) -> None:
    """
    Create a new HCS plate in OME-Zarr format if the plate does not exist.

    If the plate exists, append positions and channels that are not
    already in the plate.

    Parameters
    ----------
    store_path : Path
        HCS plate path.
    position_keys : list[Tuple[str]]
        Position keys, will append if not present in the plate.
        e.g. [("A", "1", "0"), ("A", "1", "1")]
    channel_names : list[str]
        Channel names, will append if not present in metadata.
    shape : Tuple[int]
        TCZYX shape of each position array.
    chunks : Tuple[int], optional
        Chunk size for the arrays. If None, a (1, 1, Z', Y, X) chunking
        is computed by halving Z until the chunk fits under
        ``max_chunk_size_bytes``.
    scale : Tuple[float], optional
        TCZYX voxel scale, stored as scale transformation metadata.
    dtype : DTypeLike, optional
        Data type of the arrays.
    max_chunk_size_bytes : float, optional
        Upper bound (in bytes) for the automatically computed chunk size.
    """
    bytes_per_pixel = np.dtype(dtype).itemsize

    # Auto-chunk: start from the full ZYX volume and halve Z until the
    # chunk fits under max_chunk_size_bytes (or Z reaches 1, i.e. a
    # single XY plane is already larger than the limit).
    if chunks is None:
        chunk_zyx_shape = list(shape[-3:])
        while (
            chunk_zyx_shape[-3] > 1
            and np.prod(chunk_zyx_shape) * bytes_per_pixel
            > max_chunk_size_bytes
        ):
            chunk_zyx_shape[-3] = int(np.ceil(chunk_zyx_shape[-3] / 2))
        chunks = 2 * (1,) + tuple(chunk_zyx_shape)

    # Open (or create) the plate in append mode.
    output_plate = open_ome_zarr(
        str(store_path), layout="hcs", mode="a", channel_names=channel_names
    )

    # Create any missing positions, each with a zero-initialized "0" array.
    for position_key in position_keys:
        position_key_string = "/".join(position_key)
        # Check if position is already in the store, if not create it
        if position_key_string not in output_plate.zgroup:
            position = output_plate.create_position(*position_key)
            _ = position.create_zeros(
                name="0",
                shape=shape,
                chunks=chunks,
                dtype=dtype,
                transform=[TransformationMeta(type="scale", scale=scale)],
            )
        else:
            position = output_plate[position_key_string]

        # Check if channel_names are already in the store, if not append them
        for channel_name in channel_names:
            # Read channel names directly from metadata to avoid race
            # conditions with concurrent writers
            metadata_channel_names = [
                channel.label for channel in position.metadata.omero.channels
            ]
            if channel_name not in metadata_channel_names:
                position.append_channel(channel_name, resize_arrays=True)


def apply_transform_to_zyx_and_save(
    func,
    position: Position,
    output_path: Path,
    input_channel_indices: list[int],
    output_channel_indices: list[int],
    t_idx_in: int,
    t_idx_out: int,
    **kwargs,
) -> None:
    """
    Load a CZYX array from a Position object, apply a transformation to CZYX,
    and write the result to an OME-Zarr store.

    Parameters
    ----------
    func : CZYX -> CZYX function
        The function to be applied to the data
    position : Position
        The position object to read from
    output_path : Path
        The path to output OME-Zarr Store
    input_channel_indices : list
        The channel indices to process.
        If empty list, process all channels.
        Must match output_channel_indices if not empty
    output_channel_indices : list
        The channel indices to write to.
        If empty list, write to all channels.
        Must match input_channel_indices if not empty
    t_idx_in : int
        The time index to process
    t_idx_out : int
        The time index to write to
    kwargs : dict
        Additional arguments to pass to the CZYX function
    """

    # TODO: temporary fix to slumkit issue that turns the channel index
    # lists into lists of strings; coerce digit strings back to ints.
    if _is_nested(input_channel_indices):
        input_channel_indices = [
            int(x) for x in input_channel_indices if x.isdigit()
        ]
    if _is_nested(output_channel_indices):
        output_channel_indices = [
            int(x) for x in output_channel_indices if x.isdigit()
        ]

    # Forward t_idx_in to func when its signature accepts it. This is
    # needed when each time point requires different processing, for
    # example during stabilization.
    all_func_params = inspect.signature(func).parameters.keys()
    if "t_idx_in" in all_func_params:
        kwargs["t_idx_in"] = t_idx_in

    # Process the CZYX volume at the given indices
    click.echo(f"Processing t={t_idx_in} and channels {input_channel_indices}")
    czyx_data = position.data.oindex[t_idx_in, input_channel_indices]
    if not _check_nan_n_zeros(czyx_data):
        transformed_czyx = func(czyx_data, **kwargs)
        # Write to file
        with open_ome_zarr(output_path, mode="r+") as output_dataset:
            output_dataset[0].oindex[
                t_idx_out, output_channel_indices
            ] = transformed_czyx
        click.echo(
            f"Finished Writing.. t={t_idx_in} and channel output={output_channel_indices}"
        )
    else:
        # Skip degenerate inputs (all-zero or all-NaN channels)
        click.echo(f"Skipping t={t_idx_in} due to all zeros or nans")


# TODO: modify how we get the time and channels like recOrder
# (isinstance(input, list) or isinstance(input, int) or all)
def process_single_position(
    func,
    input_data_path: Path,
    output_path: Path,
    time_indices_in: list = "all",
    time_indices_out: list = None,
    input_channel_idx: list = None,
    output_channel_idx: list = None,
    num_processes: int = mp.cpu_count(),
    **kwargs,
) -> None:
    """
    Register a single position with multiprocessing parallelization over T and C

    Parameters
    ----------
    func : CZYX -> CZYX function
        The function to be applied to the data
    input_data_path : Path
        The path to input position
    output_path : Path
        The path to output OME-Zarr Store
    time_indices_in : list
        The time indices to process.
        Must match time_indices_out if not "all"
    time_indices_out : list
        The time indices to write to.
        Must match time_indices_in if not empty.
        Typically used for stabilization, which needs a per timepoint processing
    input_channel_idx : list
        The channel indices to process.
        If empty list (or None), process all channels.
        Must match output_channel_idx if not empty
    output_channel_idx : list
        The channel indices to write to.
        If empty list (or None), write to all channels.
        Must match input_channel_idx if not empty
    num_processes : int
        Number of simultaneous processes per position
    kwargs : dict
        Additional arguments to pass to the CZYX function

    Usage:
    ------
    Multiprocessing over T and C:
    - input_channel_idx and output_channel_idx should be empty.
    Multiprocessing over T only:
    - input_channel_idx and output_channel_idx should be provided.
    """
    # Avoid mutable default arguments: None means "not provided" and is
    # normalized to an empty list, preserving the original semantics.
    if time_indices_out is None:
        time_indices_out = []
    if input_channel_idx is None:
        input_channel_idx = []
    if output_channel_idx is None:
        output_channel_idx = []

    # Function to be applied
    click.echo(f"Function to be applied: \t{func}")

    # Get the reader and writer
    click.echo(f"Input data path:\t{input_data_path}")
    click.echo(f"Output data path:\t{str(output_path)}")
    input_dataset = open_ome_zarr(str(input_data_path))
    # Capture print_tree()'s stdout so it can be echoed as one message
    stdout_buffer = io.StringIO()
    with contextlib.redirect_stdout(stdout_buffer):
        input_dataset.print_tree()
    click.echo(f" Input data tree: {stdout_buffer.getvalue()}")

    # Find time indices
    if time_indices_in == "all":
        time_indices_in = range(input_dataset.data.shape[0])
        time_indices_out = time_indices_in
    elif isinstance(time_indices_in, list):
        # Check for invalid times
        time_ubound = input_dataset.data.shape[0] - 1
        if np.max(time_indices_in) > time_ubound:
            raise ValueError(
                f"time_indices_in = {time_indices_in} includes a time index beyond the maximum index of the dataset = {time_ubound}"
            )
        # When time_indices_out is not provided, default to writing
        # compactly at output indices 0..len(time_indices_in)-1
        if len(time_indices_out) == 0:
            time_indices_out = range(len(time_indices_in))

    # Split kwargs into arguments accepted by func and the rest
    all_func_params = inspect.signature(func).parameters.keys()
    func_args = {}
    non_func_args = {}
    for k, v in kwargs.items():
        if k in all_func_params:
            func_args[k] = v
        else:
            non_func_args[k] = v

    # Write the settings into the output metadata if provided
    if "extra_metadata" in non_func_args:
        with open_ome_zarr(output_path, mode="r+") as output_dataset:
            output_dataset.zattrs["extra_metadata"] = non_func_args[
                "extra_metadata"
            ]

    # Build the work iterable and bind the fixed arguments
    if len(input_channel_idx) == 0:
        # No channel indices given: parallelize over T and C, mapping
        # each channel onto itself
        _, C, _, _, _ = input_dataset.data.shape
        iterable = [
            ([c], [c], time_idx, time_idx_out)
            for (time_idx, time_idx_out), c in itertools.product(
                zip(time_indices_in, time_indices_out), range(C)
            )
        ]
        partial_apply_transform_to_zyx_and_save = partial(
            apply_transform_to_zyx_and_save,
            func,
            input_dataset,
            output_path / Path(*input_data_path.parts[-3:]),
            **func_args,
        )
    else:
        # Channel indices given: parallelize over T only, using the
        # fixed input/output channel index lists
        iterable = list(zip(time_indices_in, time_indices_out))
        partial_apply_transform_to_zyx_and_save = partial(
            apply_transform_to_zyx_and_save,
            func,
            input_dataset,
            output_path / Path(*input_data_path.parts[-3:]),
            input_channel_idx,
            output_channel_idx,
            **func_args,
        )

    click.echo(f"\nStarting multiprocess pool with {num_processes} processes")
    with mp.Pool(num_processes) as p:
        p.starmap(
            partial_apply_transform_to_zyx_and_save,
            iterable,
        )


def _is_nested(lst):
edyoshikun marked this conversation as resolved.
Show resolved Hide resolved
"""
Check if the list is nested or not.

NOTE: this function was created for a bug in slumkit that nested input_channel_indices into a list of lists
TODO: check if this is still an issue in slumkit
"""
return any(isinstance(i, list) for i in lst) or any(
isinstance(i, str) for i in lst
)


def _check_nan_n_zeros(input_array):
"""
Checks if any of the channels are all zeros or nans and returns true
"""
if len(input_array.shape) == 3:
# Check if all the values are zeros or nans
if np.all(input_array == 0) or np.all(np.isnan(input_array)):
# Return true
return True
elif len(input_array.shape) == 4:
# Get the number of channels
num_channels = input_array.shape[0]
# Loop through the channels
for c in range(num_channels):
# Get the channel
zyx_array = input_array[c, :, :, :]

# Check if all the values are zeros or nans
if np.all(zyx_array == 0) or np.all(np.isnan(zyx_array)):
# Return true
return True
else:
raise ValueError("Input array must be 3D or 4D")

# Return false
return False
Loading
Loading