Skip to content

Commit

Permalink
Merge pull request #66 from thewtex/dask-image-label
Browse files Browse the repository at this point in the history
Dask image label
  • Loading branch information
thewtex authored Sep 22, 2022
2 parents 8fab9ad + 5253e9b commit c850167
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 21 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ image = input_images[dataset_name]
baseline_name = "2_4/XARRAY_COARSEN"
multiscale = to_multiscale(image, [2, 4], method=Methods.XARRAY_COARSEN)

store_new_image(web3_data, multiscale, dataset_name, baseline_name)
store_new_image(dataset_name, baseline_name, multiscale)

verify_against_baseline(web3_data, dataset_name, baseline_name, multiscale)
verify_against_baseline(dataset_name, baseline_name, multiscale)
```

Run the tests to generate the output. Remove the `store_new_image` call.
Expand Down
39 changes: 30 additions & 9 deletions multiscale_spatial_image/to_multiscale/_dask_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _get_truncate(xarray_image, sigma_values, truncate_start=4.0) -> float:

return truncate

def _downsample_dask_image_gaussian(current_input, default_chunks, out_chunks, scale_factors, data_objects, image):
def _downsample_dask_image(current_input, default_chunks, out_chunks, scale_factors, data_objects, image, label=False):
import dask_image.ndfilters
import dask_image.ndinterp

Expand All @@ -112,31 +112,52 @@ def _downsample_dask_image_gaussian(current_input, default_chunks, out_chunks, s
# Compute/discover region splitting parameters
input_spacing = _compute_input_spacing(current_input)
input_spacing = [input_spacing[dim] for dim in image_dims if dim in dim_factors]
sigma_values = _compute_sigma(input_spacing, shrink_factors)
truncate = _get_truncate(current_input, np.flip(sigma_values))

# Compute output shape and metadata
output_shape = [int(image_len / shrink_factor)
for image_len, shrink_factor in zip(current_input.shape, np.flip(shrink_factors))]
output_spacing = _compute_output_spacing(current_input, dim_factors)
output_origin = _compute_output_origin(current_input, dim_factors)

blurred_array = dask_image.ndfilters.gaussian_filter(
image=current_input.data,
sigma=np.flip(sigma_values), # tzyx order
mode='nearest',
truncate=truncate
)
if label == 'mode':
def largest_mode(arr):
values, counts = np.unique(arr, return_counts=True)
m = counts.argmax()
return values[m]
size = tuple(np.flip(shrink_factors))
blurred_array = dask_image.ndfilters.generic_filter(
image=current_input.data,
function=largest_mode,
size=size,
mode='nearest',
)
elif label == 'nearest':
blurred_array = current_input.data
else:
sigma_values = _compute_sigma(input_spacing, shrink_factors)
truncate = _get_truncate(current_input, np.flip(sigma_values))

blurred_array = dask_image.ndfilters.gaussian_filter(
image=current_input.data,
sigma=np.flip(sigma_values), # tzyx order
mode='nearest',
truncate=truncate
)

# Construct downsample parameters
image_dimension = len(dim_factors)
transform = np.eye(image_dimension)
for dim, shrink_factor in enumerate(np.flip(shrink_factors)):
transform[dim,dim] = shrink_factor
if label:
order = 0
else:
order = 1

downscaled_array = dask_image.ndinterp.affine_transform(
blurred_array,
matrix=transform,
order=order,
output_shape=output_shape # tzyx order
).compute()

Expand Down
20 changes: 13 additions & 7 deletions multiscale_spatial_image/to_multiscale/to_multiscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@

from ._xarray import _downsample_xarray_coarsen
from ._itk import _downsample_itk_bin_shrink, _downsample_itk_gaussian, _downsample_itk_label
from ._dask_image import _downsample_dask_image_gaussian
from ._dask_image import _downsample_dask_image

class Methods(Enum):
    """Downsampling methods selectable through the ``method`` argument of
    ``to_multiscale``.

    Each member's value is a short, snake_case identifier for the method.
    Member names must be unique: defining the same name twice in an ``Enum``
    body raises ``TypeError`` when the class is created, so only one value
    per method is kept here.
    """

    XARRAY_COARSEN = "xarray_coarsen"
    ITK_BIN_SHRINK = "itk_bin_shrink"
    ITK_GAUSSIAN = "itk_gaussian"
    # Dispatched by to_multiscale to _downsample_itk_label.
    ITK_LABEL_GAUSSIAN = "itk_label_gaussian"
    # The three DASK_IMAGE_* methods share _downsample_dask_image and differ
    # only in its ``label`` argument: False, 'mode' and 'nearest' respectively.
    DASK_IMAGE_GAUSSIAN = "dask_image_gaussian"
    DASK_IMAGE_MODE = "dask_image_mode"
    DASK_IMAGE_NEAREST = "dask_image_nearest"


def to_multiscale(
Expand Down Expand Up @@ -90,7 +92,11 @@ def to_multiscale(
elif method is Methods.ITK_LABEL_GAUSSIAN:
data_objects = _downsample_itk_label(current_input, default_chunks, out_chunks, scale_factors, data_objects, image)
elif method is Methods.DASK_IMAGE_GAUSSIAN:
data_objects = _downsample_dask_image_gaussian(current_input, default_chunks, out_chunks, scale_factors, data_objects, image)
data_objects = _downsample_dask_image(current_input, default_chunks, out_chunks, scale_factors, data_objects, image, label=False)
elif method is Methods.DASK_IMAGE_NEAREST:
data_objects = _downsample_dask_image(current_input, default_chunks, out_chunks, scale_factors, data_objects, image, label='nearest')
elif method is Methods.DASK_IMAGE_MODE:
data_objects = _downsample_dask_image(current_input, default_chunks, out_chunks, scale_factors, data_objects, image, label='mode')

multiscale = MultiscaleSpatialImage.from_dict(
d=data_objects
Expand Down
6 changes: 3 additions & 3 deletions test/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import xarray as xr
from datatree import open_datatree

test_data_ipfs_cid = 'bafybeidr5be65a67njdaiw4cm27gjqpcmxlnhor7wak5hgm3jbhcnikt4y'
test_data_sha256 = '95c5836b49c0f2a29b48a3865b3e5e23858d555c8dceebcd43f129052ee4525d'
test_data_ipfs_cid = 'bafybeia73oin2pi7hdbfquvrad5jctvcn3vubk3slvh47fvwtwlvbdxqfm'
test_data_sha256 = '29695d19bb6bac5b31b95bdbe451ff5535f202bdc9b43731f9a5fc8e0cfa1230'


test_dir = Path(__file__).resolve().parent
Expand Down Expand Up @@ -59,7 +59,7 @@ def verify_against_baseline(dataset_name, baseline_name, multiscale):
for scale in multiscale.children:
xr.testing.assert_equal(dt[scale].ds, multiscale[scale].ds)

def store_new_image(multiscale_image, dataset_name, baseline_name):
def store_new_image(dataset_name, baseline_name, multiscale_image):
'''Helper method for writing output results to disk
for later upload as test baseline'''
store = DirectoryStore(
Expand Down
49 changes: 49 additions & 0 deletions test/test_to_multiscale_dask_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,52 @@ def test_gaussian_anisotropic_scale_factors(input_images):
multiscale = to_multiscale(image, scale_factors, method=Methods.DASK_IMAGE_GAUSSIAN)
baseline_name = "x3y2z4_x2y2z2_x1y2z1/DASK_IMAGE_GAUSSIAN"
verify_against_baseline(dataset_name, baseline_name, multiscale)


def test_label_nearest_isotropic_scale_factors(input_images):
    """Check DASK_IMAGE_NEAREST downsampling with isotropic scale factors.

    The "2th_cthead1" test image is downsampled with scale factors [2, 4]
    and then [2, 3]; each result is compared against its stored baseline.

    The baseline is deliberately NOT rewritten here. Calling
    ``store_new_image`` before ``verify_against_baseline`` would replace the
    baseline with the result under test, so the comparison could never fail.
    ``store_new_image`` is only meant to be added temporarily when
    generating a new baseline.
    """
    dataset_name = "2th_cthead1"
    cases = (
        ([2, 4], "2_4/DASK_IMAGE_NEAREST"),
        ([2, 3], "2_3/DASK_IMAGE_NEAREST"),
    )
    for scale_factors, baseline_name in cases:
        image = input_images[dataset_name]
        multiscale = to_multiscale(
            image, scale_factors, method=Methods.DASK_IMAGE_NEAREST
        )
        verify_against_baseline(dataset_name, baseline_name, multiscale)


def test_label_nearest_anisotropic_scale_factors(input_images):
    """Check DASK_IMAGE_NEAREST downsampling with per-axis scale factors.

    The "2th_cthead1" test image is downsampled by x=2, y=4 and then by
    x=1, y=2, and the result is compared against its stored baseline.

    The baseline is deliberately NOT rewritten here. Calling
    ``store_new_image`` before ``verify_against_baseline`` would replace the
    baseline with the result under test, so the comparison could never fail.
    """
    dataset_name = "2th_cthead1"
    image = input_images[dataset_name]
    scale_factors = [{"x": 2, "y": 4}, {"x": 1, "y": 2}]
    multiscale = to_multiscale(
        image, scale_factors, method=Methods.DASK_IMAGE_NEAREST
    )
    baseline_name = "x2y4_x1y2/DASK_IMAGE_NEAREST"
    verify_against_baseline(dataset_name, baseline_name, multiscale)


def test_label_mode_isotropic_scale_factors(input_images):
    """Check DASK_IMAGE_MODE downsampling with isotropic scale factors.

    The "2th_cthead1" test image is downsampled with scale factors [2, 4]
    and then [2, 3]; each result is compared against its stored baseline.
    """
    name = "2th_cthead1"
    # (scale factors, baseline path) pairs, checked in this order.
    expectations = (
        ([2, 4], "2_4/DASK_IMAGE_MODE"),
        ([2, 3], "2_3/DASK_IMAGE_MODE"),
    )
    for factors, baseline in expectations:
        source = input_images[name]
        pyramid = to_multiscale(source, factors, method=Methods.DASK_IMAGE_MODE)
        verify_against_baseline(name, baseline, pyramid)


def test_label_mode_anisotropic_scale_factors(input_images):
    """Check DASK_IMAGE_MODE downsampling with per-axis scale factors.

    The "2th_cthead1" test image is downsampled by x=2, y=4 and then by
    x=1, y=2, and the result is compared against its stored baseline.
    """
    name = "2th_cthead1"
    pyramid = to_multiscale(
        input_images[name],
        [{"x": 2, "y": 4}, {"x": 1, "y": 2}],
        method=Methods.DASK_IMAGE_MODE,
    )
    verify_against_baseline(name, "x2y4_x1y2/DASK_IMAGE_MODE", pyramid)

0 comments on commit c850167

Please sign in to comment.