From 849955a34a8e091c5da81932fae8fcab2c6d0f90 Mon Sep 17 00:00:00 2001 From: Martin Schorb <35071867+martinschorb@users.noreply.github.com> Date: Wed, 21 Feb 2024 14:57:48 +0100 Subject: [PATCH] Make rotate available (#213) * add rotate function (2D and 3D) * rotate doc * rotate ND support * newline at EOF * rotate tests * list functions * rotate tests * rotate doc and temporary fix for tests * include output_shape parameter * rotate: warning if both reshape and output_shape provided * Reformat docstring, remove whitespace & commented code line * rotate: add chunking when not explicitely demanded (suggested by @gcaria) * test_rotate: fix CI error * test_rotate: complete coverage * test_rotate: fix tests * complete parameters * rotate: clean duplicate code * rotate: clean code, fix tests * rotate: fix comparison bug * test_rotate: fix warning tests * rotate: formatted docstring and adapted axes type checking to scipy * rotate: Removed output and output_shape arguments. Updating tests is WIP * Outsourced nonspecific arguments to affine_transform and adapted tests. Improved docstring. Modified prefilter test in rotate (and removed one for affine_transform which was testing for an unrelated UserWarning) * Optimized imports --------- Co-authored-by: Martin Schorb Co-authored-by: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Co-authored-by: Marvin Albert --- dask_image/ndinterp/__init__.py | 159 ++++++++++- docs/coverage.rst | 3 +- .../test_affine_transformation.py | 7 - .../test_ndinterp/test_rotate.py | 253 ++++++++++++++++++ 4 files changed, 412 insertions(+), 10 deletions(-) create mode 100644 tests/test_dask_image/test_ndinterp/test_rotate.py diff --git a/dask_image/ndinterp/__init__.py b/dask_image/ndinterp/__init__.py index 8d6d6837..1d422005 100644 --- a/dask_image/ndinterp/__init__.py +++ b/dask_image/ndinterp/__init__.py @@ -3,7 +3,6 @@ import functools import math from itertools import product -import warnings import dask.array as da import numpy as np @@ -11,6 +10,7 @@ from dask.highlevelgraph import HighLevelGraph import scipy from scipy.ndimage import affine_transform as ndimage_affine_transform +from scipy.special import sindg, cosdg from ..dispatch._dispatch_ndinterp import ( dispatch_affine_transform, @@ -25,6 +25,9 @@ __all__ = [ "affine_transform", + "rotate", + "spline_filter", + "spline_filter1d", ] @@ -247,6 +250,160 @@ def affine_transform( return transformed +def rotate( + input_arr, + angle, + axes=(1, 0), + reshape=True, + output_chunks=None, + **kwargs, + ): + """Rotate an array using Dask. + + The array is rotated in the plane defined by the two axes given by the + `axes` parameter using spline interpolation of the requested order. + + Chunkwise processing is performed using `dask_image.ndinterp.affine_transform`, + for which further parameters supported by the ndimage functions can be + passed as keyword arguments. + + Notes + ----- + Differences to `ndimage.rotate`: + - currently, prefiltering is not supported + (affecting the output in case of interpolation `order > 1`) + - default order is 1 + - modes 'reflect', 'mirror' and 'wrap' are not supported + + Arguments are equal to `ndimage.rotate` except for + - `output` (not present here) + - `output_chunks` (relevant in the dask array context) + + Parameters + ---------- + input_arr : array_like (Numpy Array, Cupy Array, Dask Array...) + The image array. + angle : float + The rotation angle in degrees. + axes : tuple of 2 ints, optional + The two axes that define the plane of rotation. Default is the first + two axes. + reshape : bool, optional + If `reshape` is true, the output shape is adapted so that the input + array is contained completely in the output. Default is True. + output_chunks : tuple of ints, optional + The shape of the chunks of the output Dask Array. + **kwargs : dict, optional + Additional keyword arguments are passed to + `dask_image.ndinterp.affine_transform`. + + Returns + ------- + rotate : Dask Array + A dask array representing the rotated input. + + Examples + -------- + >>> from scipy import ndimage, misc + >>> import matplotlib.pyplot as plt + >>> import dask.array as da + >>> fig = plt.figure(figsize=(10, 3)) + >>> ax1, ax2, ax3 = fig.subplots(1, 3) + >>> img = da.from_array(misc.ascent(),chunks=(64,64)) + >>> img_45 = dask_image.ndinterp.rotate(img, 45, reshape=False) + >>> full_img_45 = dask_image.ndinterp.rotate(img, 45, reshape=True) + >>> ax1.imshow(img, cmap='gray') + >>> ax1.set_axis_off() + >>> ax2.imshow(img_45, cmap='gray') + >>> ax2.set_axis_off() + >>> ax3.imshow(full_img_45, cmap='gray') + >>> ax3.set_axis_off() + >>> fig.set_tight_layout(True) + >>> plt.show() + >>> print(img.shape) + (512, 512) + >>> print(img_45.shape) + (512, 512) + >>> print(full_img_45.shape) + (724, 724) + + """ + if not type(input_arr) == da.core.Array: + input_arr = da.from_array(input_arr) + + if output_chunks is None: + output_chunks = input_arr.chunksize + + ndim = input_arr.ndim + + if ndim < 2: + raise ValueError('input array should be at least 2D') + + axes = list(axes) + + if len(axes) != 2: + raise ValueError('axes should contain exactly two values') + + if not all([float(ax).is_integer() for ax in axes]): + raise ValueError('axes should contain only integer values') + + if axes[0] < 0: + axes[0] += ndim + if axes[1] < 0: + axes[1] += ndim + if axes[0] < 0 or axes[1] < 0 or axes[0] >= ndim or axes[1] >= ndim: + raise ValueError('invalid rotation plane specified') + + axes.sort() + + c, s = cosdg(angle), sindg(angle) + + rot_matrix = np.array([[c, s], + [-s, c]]) + + img_shape = np.asarray(input_arr.shape) + in_plane_shape = img_shape[axes] + + if reshape: + # Compute transformed input bounds + iy, ix = in_plane_shape + out_bounds = rot_matrix @ [[0, 0, iy, iy], + [0, ix, 0, ix]] + # Compute the shape of the transformed input plane + out_plane_shape = (out_bounds.ptp(axis=1) + 0.5).astype(int) + else: + out_plane_shape = img_shape[axes] + + output_shape = np.array(img_shape) + output_shape[axes] = out_plane_shape + output_shape = tuple(output_shape) + + out_center = rot_matrix @ ((out_plane_shape - 1) / 2) + in_center = (in_plane_shape - 1) / 2 + offset = in_center - out_center + + matrix_nd = np.eye(ndim) + offset_nd = np.zeros(ndim) + + for o_x,idx in enumerate(axes): + + matrix_nd[idx,axes[0]] = rot_matrix[o_x,0] + matrix_nd[idx,axes[1]] = rot_matrix[o_x,1] + + offset_nd[idx] = offset[o_x] + + output = affine_transform( + input_arr, + matrix=matrix_nd, + offset=offset_nd, + output_shape=output_shape, + output_chunks=output_chunks, + **kwargs, + ) + + return output + + # magnitude of the maximum filter pole for each order # (obtained from scipy/ndimage/src/ni_splines.c) _maximum_pole = { diff --git a/docs/coverage.rst b/docs/coverage.rst index 2692b75d..93ae06a6 100644 --- a/docs/coverage.rst +++ b/docs/coverage.rst @@ -257,7 +257,7 @@ This table shows which SciPy ndimage functions are supported by dask-image. - ✓ * - ``rotate`` - ✓ - - + - ✓ - * - ``shift`` - ✓ @@ -311,4 +311,3 @@ This table shows which SciPy ndimage functions are supported by dask-image. - ✓ - - - diff --git a/tests/test_dask_image/test_ndinterp/test_affine_transformation.py b/tests/test_dask_image/test_ndinterp/test_affine_transformation.py index 853a63e6..eeb4c4bd 100644 --- a/tests/test_dask_image/test_ndinterp/test_affine_transformation.py +++ b/tests/test_dask_image/test_ndinterp/test_affine_transformation.py @@ -284,13 +284,6 @@ def test_affine_transform_no_output_shape_or_chunks_specified(): assert image_t.chunks == tuple([(s,) for s in image.shape]) -def test_affine_transform_prefilter_warning(): - - with pytest.warns(UserWarning): - dask_image.ndinterp.affine_transform(da.ones(20), [1], [0], - order=3, prefilter=True) - - @pytest.mark.timeout(15) def test_affine_transform_large_input_small_output_cpu(): """ diff --git a/tests/test_dask_image/test_ndinterp/test_rotate.py b/tests/test_dask_image/test_ndinterp/test_rotate.py new file mode 100644 index 00000000..48003421 --- /dev/null +++ b/tests/test_dask_image/test_ndinterp/test_rotate.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import numpy as np +import dask.array as da +import pytest +from scipy import ndimage + +import dask_image.ndinterp as da_ndinterp + + +def validate_rotate(n=2, + axes=(0, 1), + reshape=False, + input_shape_per_dim=16, + interp_order=1, + interp_mode='constant', + input_output_chunksize_per_dim=(6, 6), + random_seed=0, + use_cupy=False, + ): + """ + Compare the outputs of `ndimage.rotate` + and `dask_image.ndinterp.rotate`. + """ + + # define test image + np.random.seed(random_seed) + image = np.random.random([input_shape_per_dim] * n) + + angle = np.random.random() * 360 - 180 + + # transform into dask array + chunksize = [input_output_chunksize_per_dim[0]] * n + image_da = da.from_array(image, chunks=chunksize) + + if use_cupy: + import cupy as cp + image_da = image_da.map_blocks(cp.asarray) + + + # define resampling options + output_chunks = [input_output_chunksize_per_dim[1]] * n + + # transform with dask-image + image_t_dask = da_ndinterp.rotate( + image, angle, + axes=axes, + reshape=reshape, + order=interp_order, + mode=interp_mode, + prefilter=False, + output_chunks=output_chunks + ) + + image_t_dask_computed = image_t_dask.compute() + + # transform with scipy + image_t_scipy = ndimage.rotate( + image, angle, + axes=axes, + reshape=reshape, + order=interp_order, + mode=interp_mode, + prefilter=False) + + assert np.allclose(image_t_scipy, image_t_dask_computed) + + +@pytest.mark.parametrize("n", + [2, 3]) +@pytest.mark.parametrize("input_shape_per_dim", + [25, 2]) +@pytest.mark.parametrize("interp_order", + [0,1]) +@pytest.mark.parametrize("input_output_chunksize_per_dim", + [(16, 16), (16, 7), (7, 16)]) +@pytest.mark.parametrize("random_seed", + [0, 1, 2]) +def test_rotate_general(n, + input_shape_per_dim, + interp_order, + input_output_chunksize_per_dim, + random_seed): + + kwargs = dict() + kwargs['n'] = n + kwargs['input_shape_per_dim'] = input_shape_per_dim + kwargs['interp_order'] = interp_order + kwargs['input_output_chunksize_per_dim'] = input_output_chunksize_per_dim + kwargs['random_seed'] = random_seed + + validate_rotate(**kwargs) + + +@pytest.mark.cupy +@pytest.mark.parametrize("n", + [2, 3]) +@pytest.mark.parametrize("input_shape_per_dim", + [25, 2]) +@pytest.mark.parametrize("interp_order", + [0, 1]) +@pytest.mark.parametrize("input_output_chunksize_per_dim", + [(16, 16), (16, 7)]) +@pytest.mark.parametrize("random_seed", + [0]) +def test_rotate_cupy(n, + input_shape_per_dim, + interp_order, + input_output_chunksize_per_dim, + random_seed): + + cupy = pytest.importorskip("cupy", minversion="6.0.0") + + kwargs = dict() + kwargs['n'] = n + kwargs['input_shape_per_dim'] = input_shape_per_dim + kwargs['interp_order'] = interp_order + kwargs['input_output_chunksize_per_dim'] = input_output_chunksize_per_dim + kwargs['random_seed'] = random_seed + kwargs['use_cupy'] = True + + validate_rotate(**kwargs) + + +@pytest.mark.parametrize("n", + [2, 3]) +@pytest.mark.parametrize("interp_mode", + ['constant', 'nearest']) +@pytest.mark.parametrize("input_shape_per_dim", + [20, 30]) +@pytest.mark.parametrize("input_output_chunksize_per_dim", + [(15, 10)]) +def test_rotate_modes(n, + interp_mode, + input_shape_per_dim, + input_output_chunksize_per_dim, + ): + + kwargs = dict() + kwargs['n'] = n + kwargs['interp_mode'] = interp_mode + kwargs['input_shape_per_dim'] = input_shape_per_dim + kwargs['input_output_chunksize_per_dim'] = input_output_chunksize_per_dim + kwargs['interp_order'] = 0 + + validate_rotate(**kwargs) + + +@pytest.mark.parametrize("interp_mode", + ['wrap', 'reflect', 'mirror']) +def test_rotate_unsupported_modes(interp_mode): + + kwargs = dict() + kwargs['interp_mode'] = interp_mode + + with pytest.raises(NotImplementedError): + validate_rotate(**kwargs) + + +def test_rotate_dimensions(): + with pytest.raises(ValueError): + validate_rotate(n=1) + + +@pytest.mark.parametrize("axes", + [[1], [1, 2, 3], + [-3, 0], [0, -3], [0, 3], [2, 0]]) +def test_rotate_axisdimensions(axes): + kwargs = dict() + kwargs['axes'] = axes + + with pytest.raises(ValueError): + validate_rotate(**kwargs) + + +@pytest.mark.parametrize("axes", + [[1, 2.2], [1, 'a'], [[0, 1], 1], [(0, 1), 1], [0, {}]]) +def test_rotate_axistypes(axes): + kwargs = dict() + kwargs['axes'] = axes + + with pytest.raises((ValueError, TypeError)): + validate_rotate(**kwargs) + + +@pytest.mark.parametrize( + "image", + [ + np.ones((3, 3)).astype(float), + np.ones((3, 3)).astype(int), + np.ones((3, 3)).astype(complex) + ] +) +def test_rotate_dtype(image): + image_t = da_ndinterp.rotate(image, 0, reshape=False) + assert image_t.dtype == image.dtype + + +def test_rotate_numpy_input(): + + image = np.ones((3, 3)) + image_t = da_ndinterp.rotate(image, 0, reshape =False) + + assert image_t.shape == image.shape + assert (da.from_array(image) == image_t).min() + + +def test_rotate_minimal_input(): + + image = np.ones((3, 3)) + image_t = da_ndinterp.rotate(np.ones((3, 3)), 0) + + assert image_t.shape == image.shape + + +def test_rotate_type_consistency(): + + image = da.ones((3, 3)) + image_t = da_ndinterp.rotate(image, 0) + + assert isinstance(image, type(image_t)) + assert isinstance(image[0, 0].compute(), type(image_t[0, 0].compute())) + + +@pytest.mark.cupy +def test_rotate_type_consistency_gpu(): + + cupy = pytest.importorskip("cupy", minversion="6.0.0") + + image = da.ones((3, 3)) + image_t = da_ndinterp.rotate(image, 0) + + image.map_blocks(cupy.asarray) + + assert isinstance(image, type(image_t)) + assert isinstance(image[0, 0].compute(), type(image_t[0, 0].compute())) + + +def test_rotate_no_chunks_specified(): + + image = da.ones((3, 3)) + image_t = da_ndinterp.rotate(image, 0) + + assert image_t.shape == image.shape + assert image_t.chunks == tuple([(s,) for s in image.shape]) + + +def test_rotate_prefilter_not_implemented_error(): + with pytest.raises(NotImplementedError): + da_ndinterp.rotate( + da.ones((15, 15)), 0, + order=3, prefilter=True, mode='nearest')