From 31c9489c8d145cf83bfb35cc6eea36b9d2e97f46 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Tue, 22 Mar 2022 09:13:15 -0600 Subject: [PATCH 01/37] Rough in a new BlockArray based on tuple --- scico/_generic_operators.py | 2 +- scico/blockarray.py | 1047 +++------------------------------ scico/numpy/__init__.py | 2 +- scico/numpy/linalg.py | 2 +- scico/random.py | 2 +- scico/test/test_blockarray.py | 28 +- 6 files changed, 102 insertions(+), 981 deletions(-) diff --git a/scico/_generic_operators.py b/scico/_generic_operators.py index 8d0c0d66b..d97bb6592 100644 --- a/scico/_generic_operators.py +++ b/scico/_generic_operators.py @@ -24,7 +24,7 @@ import scico.numpy as snp from scico._autograd import linear_adjoint from scico.array import is_complex_dtype, is_nested -from scico.blockarray import BlockArray, block_sizes +from scico.blockarray import BlockArray from scico.typing import BlockShape, DType, JaxArray, Shape diff --git a/scico/blockarray.py b/scico/blockarray.py index 601d5e4a7..7196b9363 100644 --- a/scico/blockarray.py +++ b/scico/blockarray.py @@ -448,777 +448,29 @@ """ - -from __future__ import annotations - from functools import wraps -from typing import Iterator, List, Optional, Tuple, Union - -import numpy as np import jax import jax.numpy as jnp -from jax import core -from jax.interpreters import xla -from jax.interpreters.xla import DeviceArray -from jax.tree_util import register_pytree_node, tree_flatten - -from jaxlib.xla_extension import Buffer - -from scico import array -from scico.typing import Axes, AxisIndex, BlockShape, DType, JaxArray, Shape - -_arraylikes = (Buffer, DeviceArray, np.ndarray) - - -def atleast_1d(*arys): - """Convert inputs to arrays with at least one dimension. - - A wrapper for :func:`jax.numpy.atleast_1d` that acts as usual on - ndarrays and DeviceArrays, and returns BlockArrays unmodified. - """ - - if len(arys) == 1: - arr = arys[0] - return arr if isinstance(arr, BlockArray) else jnp.atleast_1d(arr) - - out = [] - for arr in arys: - if isinstance(arr, BlockArray): - out.append(arr) - else: - out.append(jnp.atleast_1d(arr)) - return out - - -# Append docstring from original jax.numpy function -atleast_1d.__doc__ = ( - atleast_1d.__doc__.replace("\n ", "\n") # deal with indentation differences - + "\nDocstring for :func:`jax.numpy.atleast_1d`:\n\n" - + "\n".join(jax.numpy.atleast_1d.__doc__.split("\n")[2:]) -) - - -def reshape( - a: Union[JaxArray, BlockArray], newshape: Union[Shape, BlockShape] -) -> Union[JaxArray, BlockArray]: - """Change the shape of an array without changing its data. - - Args: - a: Array to be reshaped. - newshape: The new shape should be compatible with the original - shape. If an integer, then the result will be a 1-D array of - that length. One shape dimension can be -1. In this case, - the value is inferred from the length of the array and - remaining dimensions. If a tuple of tuple of ints, a - :class:`.BlockArray` is returned. - - Returns: - The reshaped array. Unlike :func:`numpy.reshape`, a copy is - always returned. - """ - - if array.is_nested(newshape): - # x is a blockarray - return BlockArray.array_from_flattened(a, newshape) - - return jnp.reshape(a, newshape) - - -def block_sizes(shape: Union[Shape, BlockShape]) -> Axes: - r"""Compute the 'sizes' of (possibly nested) block shapes. - - This function computes `block_sizes(z.shape) == (_.size for _ in z)`. - - - Args: - shape: A shape tuple; possibly containing nested tuples. - - - Examples: - - .. doctest:: - >>> import scico.numpy as snp - - >>> x = BlockArray.ones( ( (4, 4), (2,))) - >>> x.size - 18 - - >>> y = snp.ones((3, 3)) - >>> y.size - 9 - >>> z = BlockArray.array([x, y]) - >>> block_sizes(z.shape) - (18, 9) +from jaxlib.xla_extension import DeviceArray - >>> zz = BlockArray.array([z, z]) - >>> block_sizes(zz.shape) - (27, 27) - """ - if isinstance(shape, BlockArray): - raise TypeError( - "Expected a `shape` (possibly nested tuple of ints); got :class:`.BlockArray`." - ) +class BlockArray(tuple): + """BlockArray""" - out = [] - if array.is_nested(shape): - # shape is nested -> at least one element came from a blockarray - for y in shape: - if array.is_nested(y): - # recursively calculate the block size until we arrive at - # a tuple (shape of a non-block array) - while array.is_nested(y): - y = block_sizes(y) - out.append(np.sum(y)) # adjacent block sizes are added together - else: - # this is a tuple; size given by product of elements - out.append(np.prod(y)) - return tuple(out) - - # shape is a non-nested tuple; return the product - return np.prod(shape) - - -def _decompose_index(idx: Union[int, Tuple(AxisIndex)]) -> Tuple: - """Decompose a BlockArray indexing expression into components. - - Decompose a BlockArray indexing expression into block and array - components. - - Args: - idx: BlockArray indexing expression. - - Returns: - A tuple (idxblk, idxarr) with entries corresponding to the - integer block index and the indexing to be applied to the - selected block, respectively. The latter is ``None`` if the - indexing expression simply selects one of the blocks (i.e. - it consists of a single integer). - - Raises: - TypeError: If the block index is not an integer. - """ - if isinstance(idx, tuple): - idxblk = idx[0] - idxarr = idx[1:] - else: - idxblk = idx - idxarr = None - if not isinstance(idxblk, int): - raise TypeError("Block index must be an integer") - return idxblk, idxarr - - -def indexed_shape(shape: Shape, idx: Union[int, Tuple[AxisIndex, ...]]) -> Tuple[int, ...]: - """Determine the shape of the result of indexing a BlockArray. - - Args: - shape: Shape of BlockArray. - idx: BlockArray indexing expression. - - Returns: - Shape of the selected block, or slice of that block if `idx` is a tuple - rather than an integer. - """ - idxblk, idxarr = _decompose_index(idx) - if idxblk < 0: - idxblk = len(shape) + idxblk - if idxarr is None: - return shape[idxblk] - return array.indexed_shape(shape[idxblk], idxarr) - - -def _flatten_blockarrays(inp, *args, **kwargs): - """Flatten any blockarrays present in `inp`, `args`, or `kwargs`.""" - - def _flatten_if_blockarray(inp): - if isinstance(inp, BlockArray): - return inp._data - return inp - - inp_ = _flatten_if_blockarray(inp) - args_ = (_flatten_if_blockarray(_) for _ in args) - kwargs_ = {key: _flatten_if_blockarray(val) for key, val in kwargs.items()} - return inp_, args_, kwargs_ - - -def _block_array_ufunc_wrapper(func): - """Wrap a "ufunc" to allow for joint operation on `DeviceArray` and `BlockArray`.""" - - @wraps(func) - def wrapper(inp, *args, **kwargs): - all_args = (inp,) + args + tuple(kwargs.items()) - if any([isinstance(_, BlockArray) for _ in all_args]): - # If 'inp' is a BlockArray, call func on inp._data - # Then return a BlockArray of the same shape as inp - - inp_, args_, kwargs_ = _flatten_blockarrays(inp, *args, **kwargs) - flat_out = func(inp_, *args_, **kwargs_) - return BlockArray.array_from_flattened(flat_out, inp.shape) - - # Otherwise call the function normally - return func(inp, *args, **kwargs) - - if not hasattr(func, "__doc__") or func.__doc__ is None: - return wrapper - - wrapper.__doc__ = ( - f":func:`{func.__name__}` wrapped to operate on :class:`BlockArray`" + "\n\n" + func.__doc__ - ) - return wrapper - - -def _block_array_reduction_wrapper(func): - """Wrap a reduction (eg. sum, norm) to allow for joint operation on - `DeviceArray` and `BlockArray`.""" - - @wraps(func) - def wrapper(inp, *args, axis=None, **kwargs): - - all_args = (inp,) + args + tuple(kwargs.items()) - if any([isinstance(_, BlockArray) for _ in all_args]): - if axis is None: - # Treat as a single long vector - inp_, args_, kwargs_ = _flatten_blockarrays(inp, *args, **kwargs) - return func(inp_, *args_, **kwargs_) - - if type(axis) == tuple: - raise Exception( - f"""Evaluating {func.__name__} on a BlockArray with a tuple argument to - axis is not currently supported""" - ) - - if axis == 0: # reduction along block axis - # reduction along axis=0 only makes sense if all blocks are the same shape - # so we can convert to a standard DeviceArray of shape (inp.num_blocks, ...) - # and reduce along axis = 0 - if all([bk_shape == inp.shape[0] for bk_shape in inp.shape]): - view_shape = (inp.num_blocks,) + inp.shape[0] - return func(inp._data.reshape(view_shape), *args, axis=0, **kwargs) - - raise ValueError( - f"Evaluating {func.__name__} of BlockArray along axis=0 requires " - f"all blocks to be same shape; got {inp.shape}" - ) - - # Reduce each block individually along axis-1 - out = [] - for bk in inp: - if isinstance(bk, BlockArray): - # This block is itself a blockarray, so call this wrapped reduction - # on axis-1 - tmp = _block_array_reduction_wrapper(func)(bk, *args, axis=axis - 1, **kwargs) - else: - if axis - 1 >= bk.ndim: - # Trying to reduce along a dim that doesn't exist for this block, - # so just return the block. - # i.e. broadcast to shape (..., 1) and reduce along axis=-1 - tmp = bk - else: - tmp = func(bk, *args, axis=axis - 1, **kwargs) - out.append(atleast_1d(tmp)) - return BlockArray.array(out) - - if axis is None: - # 'axis' might not be a valid kwarg (eg dot, vdot), so don't pass it - return func(inp, *args, **kwargs) - - return func(inp, *args, axis=axis, **kwargs) - - if not hasattr(func, "__doc__") or func.__doc__ is None: - return wrapper - - wrapper.__doc__ = ( - f":func:`{func.__name__}` wrapped to operate on :class:`BlockArray`" + "\n\n" + func.__doc__ - ) - return wrapper - - -def _block_array_matmul_wrapper(func): - @wraps(func) - def wrapper(self, other): - if isinstance(self, BlockArray): - if isinstance(other, BlockArray): - # Both blockarrays, work block by block - return BlockArray.array([func(x, y) for x, y in zip(self, other)]) - raise TypeError( - f"Operation {func.__name__} not implemented between {type(self)} and {type(other)}" - ) - return func(self, other) - - if not hasattr(func, "__doc__") or func.__doc__ is None: - return wrapper - wrapper.__doc__ = ( - f":func:`{func.__name__}` wrapped to operate on :class:`BlockArray`" + "\n\n" + func.__doc__ - ) - return wrapper - - -def _block_array_binary_op_wrapper(func): - """Return a decorator that performs type and shape checking for - :class:`.BlockArray` arithmetic. - """ - - @wraps(func) - def wrapper(self, other): - if isinstance(other, BlockArray): - if other.shape == self.shape: - # Same shape blocks, can operate on flattened arrays - return BlockArray.array_from_flattened(func(self._data, other._data), self.shape) - if other.num_blocks == self.num_blocks: - # Will work as long as the shapes are broadcastable - return BlockArray.array([func(x, y) for x, y in zip(self, other)]) - raise ValueError( - f"operation not valid on operands with shapes {self.shape} {other.shape}" - ) - if any([isinstance(other, _) for _ in _arraylikes]): - if other.size == 1: - # Same as operating on a scalar - return BlockArray.array_from_flattened(func(self._data, other), self.shape) - if other.size == self.size: - # A little fast and loose, treat the block array as a length self.size vector - return BlockArray.array_from_flattened(func(self._data, other), self.shape) - if other.size == self.num_blocks: - return BlockArray.array([func(blk, other_) for blk, other_ in zip(self, other)]) - raise ValueError( - f"operation not valid on operands with shapes {self.shape} {other.shape}" - ) - if jnp.isscalar(other) or isinstance(other, core.Tracer): - return BlockArray.array_from_flattened(func(self._data, other), self.shape) - raise TypeError( - f"Operation {func.__name__} not implemented between {type(self)} and {type(other)}" - ) - - if not hasattr(func, "__doc__") or func.__doc__ is None: - return wrapper - wrapper.__doc__ = ( - f":func:`{func.__name__}` wrapped to operate on :class:`BlockArray`" + "\n\n" + func.__doc__ - ) - return wrapper - - -class _AbstractBlockArray(core.ShapedArray): - """Abstract BlockArray class for JAX tracing. - - See https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html - """ - - array_abstraction_level = 0 # Same as jax.core.ConcreteArray - - def __init__(self, shapes, dtype): - - sizes = block_sizes(shapes) - size = np.sum(sizes) - - super(_AbstractBlockArray, self).__init__((size,), dtype) - - #: Abstract data value - self._data_aval: core.ShapedArray = core.ShapedArray((size,), dtype) - - #: Array dtype - self.dtype: DType = dtype - - #: Shape of each block - self.shapes: BlockShape = shapes - - #: Size of each block - self.sizes: Shape = sizes - - #: Array specifying boundaries of components as indices in base array - self.bndpos: np.ndarray = np.r_[0, np.cumsum(sizes)] - - -# The Jax class is heavily inspired by SparseArray/AbstractSparseArray here: -# https://github.com/google/jax/blob/7724322d1c08c13008815bfb52759a29c2a6823b/tests/custom_object_test.py -class BlockArray: - """A tuple of :class:`jax.interpreters.xla.DeviceArray` objects. - - A tuple of `DeviceArray` objects that all share their memory buffers - with non-overlapping, contiguous regions of a common one-dimensional - `DeviceArray`. It can be used as the common one-dimensional array via - the :func:`BlockArray.ravel` method, or individual component arrays - can be accessed individually. - """ - - # Ensure we use BlockArray.__radd__, __rmul__, etc for binary operations of the form - # op(np.ndarray, BlockArray) - # See https://docs.scipy.org/doc/numpy-1.10.1/user/c-info.beyond-basics.html#ndarray.__array_priority__ + # Ensure we use BlockArray.__radd__, __rmul__, etc for binary + # operations of the form op(np.ndarray, BlockArray) See + # https://docs.scipy.org/doc/numpy-1.10.1/user/c-info.beyond-basics.html#ndarray.__array_priority__ __array_priority__ = 1 - def __init__(self, aval: _AbstractBlockArray, data: JaxArray): - """BlockArray init method. - - Args: - aval: `Abstract value`_ associated with this array (shape+dtype+weak_type) - data: The underlying contiguous, flattened `DeviceArray`. - - .. _Abstract value: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html - - """ - self._aval = aval - self._data = data - - def __repr__(self): - return "scico.blockarray.BlockArray: \n" + self._data.__repr__() - - def __getitem__(self, idx: Union[int, Tuple[AxisIndex, ...]]) -> JaxArray: - idxblk, idxarr = _decompose_index(idx) - if idxblk < 0: - idxblk = self.num_blocks + idxblk - blk = reshape(self._data[self.bndpos[idxblk] : self.bndpos[idxblk + 1]], self.shape[idxblk]) - if idxarr is not None: - blk = blk[idxarr] - return blk - - @_block_array_matmul_wrapper - def __matmul__(self, other: Union[np.ndarray, BlockArray, JaxArray]) -> JaxArray: - return self @ other - - @_block_array_matmul_wrapper - def __rmatmul__(self, other: Union[np.ndarray, BlockArray, JaxArray]) -> JaxArray: - return other @ self - - @_block_array_binary_op_wrapper - def __sub__(a, b): - return a - b - - @_block_array_binary_op_wrapper - def __rsub__(a, b): - return b - a - - @_block_array_binary_op_wrapper - def __mul__(a, b): - return a * b - - @_block_array_binary_op_wrapper - def __rmul__(a, b): - return a * b - - @_block_array_binary_op_wrapper - def __add__(a, b): - return a + b - - @_block_array_binary_op_wrapper - def __radd__(a, b): - return a + b - - @_block_array_binary_op_wrapper - def __truediv__(a, b): - return a / b - - @_block_array_binary_op_wrapper - def __rtruediv__(a, b): - return b / a - - @_block_array_binary_op_wrapper - def __floordiv__(a, b): - return a // b - - @_block_array_binary_op_wrapper - def __rfloordiv__(a, b): - return b // a - - @_block_array_binary_op_wrapper - def __pow__(a, b): - return a**b - - @_block_array_binary_op_wrapper - def __rpow__(a, b): - return b**a - - @_block_array_binary_op_wrapper - def __gt__(a, b): - return a > b - - @_block_array_binary_op_wrapper - def __ge__(a, b): - return a >= b - - @_block_array_binary_op_wrapper - def __lt__(a, b): - return a < b - - @_block_array_binary_op_wrapper - def __le__(a, b): - return a <= b - - @_block_array_binary_op_wrapper - def __eq__(a, b): - return a == b - - @_block_array_binary_op_wrapper - def __ne__(a, b): - return a != b - - def __iter__(self) -> Iterator[int]: - for i in range(self.num_blocks): - yield self[i] - - @property - def blocks(self) -> Iterator[int]: - """Return an iterator yielding component blocks.""" - return self.__iter__() - - @property - def bndpos(self) -> np.ndarray: - """Array specifying boundaries of components as indices in base array.""" - return self._aval.bndpos - - @property - def dtype(self) -> DType: - """Array dtype.""" - return self._data.dtype - - @property - def device_buffer(self) -> Buffer: - """The :class:`jaxlib.xla_extension.Buffer` that backs the - underlying data array.""" - return self._data.device_buffer - @property def size(self) -> int: """Total number of elements in the array.""" - return self._aval.size - - @property - def num_blocks(self) -> int: - """Number of :class:`.BlockArray` components.""" - - return len(self.shape) - - @property - def ndim(self) -> Shape: - """Tuple of component ndims.""" - - return tuple(len(c) for c in self.shape) - - @property - def shape(self) -> BlockShape: - """Tuple of component shapes.""" - - return self._aval.shapes - - @property - def split(self) -> Tuple[JaxArray, ...]: - """Tuple of component arrays.""" - - return tuple(self[k] for k in range(self.num_blocks)) - - def conj(self) -> BlockArray: - """Return a :class:`.BlockArray` with complex-conjugated elements.""" - - # Much faster than BlockArray.array([_.conj() for _ in self.blocks]) - return BlockArray.array_from_flattened(self.ravel().conj(), self.shape) - - @property - def real(self) -> BlockArray: - """Return a :class:`.BlockArray` with the real part of this array.""" - return BlockArray.array_from_flattened(self.ravel().real, self.shape) - - @property - def imag(self) -> BlockArray: - """Return a :class:`.BlockArray` with the imaginary part of this array.""" - return BlockArray.array_from_flattened(self.ravel().imag, self.shape) - - @classmethod - def array( - cls, alst: List[Union[np.ndarray, JaxArray]], dtype: Optional[np.dtype] = None - ) -> BlockArray: - """Construct a :class:`.BlockArray` from a list or tuple of existing array-like. - - Args: - alst: Initializers for array components. - Can be :class:`numpy.ndarray` or - :class:`jax.interpreters.xla.DeviceArray` - dtype: Data type of array. If ``None``, dtype is derived from - dtype of initializers. - - Returns: - :class:`.BlockArray` initialized from `alst` tuple. - """ + return sum(x_i.size for x_i in self) - if isinstance(alst, (tuple, list)) is False: - raise TypeError("Input to `array` must be a list or tuple of existing arrays") - - if dtype is None: - present_types = jax.tree_flatten(jax.tree_map(lambda x: x.dtype, alst))[0] - dtype = np.find_common_type(present_types, []) - - # alst can be a list/tuple of arrays, or a list/tuple containing list/tuples of arrays - # consider alst to be a tree where leaves are arrays (possibly abstract arrays) - # use tree_map to find the shape of each leaf - # `shapes` will be a tuple of ints and tuples containing ints (possibly nested further) - - # ensure any scalar leaves are converted to (1,) arrays - def shape_atleast_1d(x): - return x.shape if x.shape != () else (1,) - - shapes = tuple( - jax.tree_map(shape_atleast_1d, alst, is_leaf=lambda x: not isinstance(x, (list, tuple))) - ) - - _aval = _AbstractBlockArray(shapes, dtype) - data_ravel = jnp.hstack(jax.tree_map(lambda x: x.ravel(), jax.tree_flatten(alst)[0])) - return cls(_aval, data_ravel) - - @classmethod - def array_from_flattened( - cls, data_ravel: Union[np.ndarray, JaxArray], shape_tuple: BlockShape - ) -> BlockArray: - """Construct a :class:`.BlockArray` from a flattened array and tuple of shapes. - - Args: - data_ravel: Flattened data array. - shape_tuple: Tuple of tuples containing desired block shapes. - - Returns: - :class:`.BlockArray` initialized from `data_ravel` and `shape_tuple`. - """ - - if not isinstance(data_ravel, DeviceArray): - data_ravel = jax.device_put(data_ravel) - - shape_tuple_size = np.sum(block_sizes(shape_tuple)) - - if shape_tuple_size != data_ravel.size: - raise ValueError( - f"""The specified shape_tuple is incompatible with provided data_ravel - shape_tuple = {shape_tuple} - shape_tuple_size = {shape_tuple_size} - len(data_ravel) = {len(data_ravel)} - """ - ) - - _aval = _AbstractBlockArray(shape_tuple, dtype=data_ravel.dtype) - return cls(_aval, data_ravel) - - @classmethod - def ones(cls, shape_tuple: BlockShape, dtype: DType = np.float32) -> BlockArray: - """ - Return a new :class:`.BlockArray` with given block shapes and type, filled with ones. - - Args: - shape_tuple: Tuple of shapes for component blocks. - dtype: Desired data-type for the :class:`.BlockArray`. - Default is ``numpy.float32``. - - Returns: - :class:`.BlockArray` of ones with the given component shapes - and dtype. - """ - _aval = _AbstractBlockArray(shape_tuple, dtype=dtype) - data_ravel = jnp.ones(_aval.size, dtype=dtype) - return cls(_aval, data_ravel) - - @classmethod - def zeros(cls, shape_tuple: BlockShape, dtype: DType = np.float32) -> BlockArray: - """ - Return a new :class:`.BlockArray` with given block shapes and type, filled with zeros. - - Args: - shape_tuple: Tuple of shapes for component blocks. - dtype: Desired data-type for the :class:`.BlockArray`. - Default is ``numpy.float32``. - - Returns: - :class:`.BlockArray` of zeros with the given component shapes - and dtype. - """ - _aval = _AbstractBlockArray(shape_tuple, dtype=dtype) - data_ravel = jnp.zeros(_aval.size, dtype=dtype) - return cls(_aval, data_ravel) - - @classmethod - def empty(cls, shape_tuple: BlockShape, dtype: DType = np.float32) -> BlockArray: - """ - Return a new :class:`.BlockArray` with given block shapes and type, filled with zeros. - - Note: like :func:`jax.numpy.empty`, this does not return an - uninitalized array. - - Args: - shape_tuple: Tuple of shapes for component blocks - dtype: Desired data-type for the :class:`.BlockArray`. - Default is ``numpy.float32``. - - Returns: - :class:`.BlockArray` of zeros with the given component shapes - and dtype. - """ - _aval = _AbstractBlockArray(shape_tuple, dtype=dtype) - data_ravel = jnp.empty(_aval.size, dtype=dtype) - return cls(_aval, data_ravel) - - @classmethod - def full( - cls, - shape_tuple: BlockShape, - fill_value: Union[float, complex, int], - dtype: DType = np.float32, - ) -> BlockArray: - """ - Return a new :class:`.BlockArray` with given block shapes and type, filled with - `fill_value`. - - Args: - shape_tuple: Tuple of shapes for component blocks. - fill_value: Fill value. - dtype: Desired data-type for the BlockArray. The default, - None, means `np.array(fill_value).dtype`. - - Returns: - :class:`.BlockArray` with the given component shapes and - dtype and all entries equal to `fill_value`. - """ - if dtype is None: - dtype = np.asarray(fill_value).dtype - - _aval = _AbstractBlockArray(shape_tuple, dtype=dtype) - data_ravel = jnp.full(_aval.size, fill_value=fill_value, dtype=dtype) - return cls(_aval, data_ravel) - - def copy(self) -> BlockArray: - """Return a copy of this :class:`.BlockArray`. - - This method is not implemented for BlockArray. - - See Also: - :meth:`.to_numpy`: Convert a :class:`.BlockArray` into a - flattened NumPy array. - """ - # jax DeviceArray copies return a NumPy ndarray. This blockarray class must be backed - # by a DeviceArray, so cannot be converted to a NumPy-backed BlockArray. The BlockArray - # .to_numpy() method returns a flattened ndarray. - # - # This method may be implemented in the future if jax DeviceArray .copy() is modified to - # return another DeviceArray. - raise NotImplementedError - - def to_numpy(self) -> np.ndarray: - """Return a :class:`numpy.ndarray` containing the flattened form of this - :class:`.BlockArray`.""" - - if isinstance(self._data, DeviceArray): - host_arr = jax.device_get(self._data.copy()) - else: - host_arr = self._data.copy() - return host_arr - - def blockidx(self, idx: int) -> jax._src.ops.scatter._Indexable: - """Return :class:`jax.ops.index` for a given component block. - - Args: - idx: Desired block index. - - Returns: - :class:`jax.ops.index` pointing to desired block. - """ - return slice(self.bndpos[idx], self.bndpos[idx + 1]) - - def ravel(self) -> JaxArray: - """Return a copy of `self._data` as a contiguous, flattened `DeviceArray`. + def ravel(self) -> DeviceArray: + """Return a copy of ``self._data`` as a contiguous, flattened `DeviceArray`. Note that a copy, rather than a view, of the underlying array is returned. This is consistent with :func:`jax.numpy.ravel`. @@ -1227,240 +479,101 @@ def ravel(self) -> JaxArray: Copy of underlying flattened array. """ - return self._data[:] + return jnp.concatenate(tuple(x_i.ravel() for x_i in self)) - def flatten(self) -> JaxArray: - """Return a copy of `self._data` as a contiguous, flattened `DeviceArray`. + """ backwards compatibility methods, could be removed """ - Note that a copy, rather than a view, of the underlying array is - returned. This is consistent with :func:`jax.numpy.ravel`. + @staticmethod + def array(iterable): + """Construct a :class:`.BlockArray` from a list or tuple of existing array-like.""" + return BlockArray(iterable) - Returns: - Copy of underlying flattened array. - - """ - return self._data[:] - def sum(self, axis=None, keepdims=False): - """Return the sum of the blockarray elements over the given axis. - - Refer to :func:`scico.numpy.sum` for full documentation. - """ - # Can't just call scico.numpy.sum due to pesky circular import... - return _block_array_reduction_wrapper(jnp.sum)(self, axis=axis, keepdims=keepdims) - - -## Register BlockArray as a Jax type -# Our BlockArray is just a single large vector with some extra sugar -class _ConcreteBlockArray(_AbstractBlockArray): - pass - - -def _block_array_result_handler(device, _aval): - def build_block_array(data_buf): - data = xla.DeviceArray(_aval._data_aval, device, None, data_buf) - return BlockArray(_aval, data) - - return build_block_array - - -def _block_array_shape_handler(a): - return (xla.xc.Shape.array_shape(a._data_aval.dtype, a._data_aval.shape),) - - -def _block_array_device_put_handler(a, device): - return (xla.xb.get_device_backend(device).buffer_from_pyval(a._data, device),) - - -core.pytype_aval_mappings[BlockArray] = lambda x: x._aval -core.raise_to_shaped_mappings[_AbstractBlockArray] = lambda _aval, _: _aval -xla.pytype_aval_mappings[BlockArray] = lambda x: x._aval -xla.canonicalize_dtype_handlers[BlockArray] = lambda x: x -jax._src.dispatch.device_put_handlers[BlockArray] = _block_array_device_put_handler -jax._src.dispatch.result_handlers[_AbstractBlockArray] = _block_array_result_handler -xla.xla_shape_handlers[_AbstractBlockArray] = _block_array_shape_handler - - -## Handlers to use jax.device_put on BlockArray -def _block_array_tree_flatten(block_arr): - """Flatten a :class:`.BlockArray` pytree. - - See :func:`jax.tree_util.tree_flatten`. - - Args: - block_arr (:class:`.BlockArray`): :class:`.BlockArray` to flatten - - Returns: - children (tuple): :class:`.BlockArray` leaves. - aux_data (tuple): Extra metadata used to reconstruct BlockArray. - """ - - data_children, data_aux_data = tree_flatten(block_arr._data) - return (data_children, block_arr._aval) - - -def _block_array_tree_unflatten(aux_data, children): - """Construct a :class:`.BlockArray` from a flattened pytree. - - See jax.tree_utils.tree_unflatten - - Args: - aux_data (tuple): Metadata needed to construct block array. - children (tuple): Contains block array elements. - - Returns: - block_arr: Constructed :class:`.BlockArray`. - """ - return BlockArray(aux_data, children[0]) - - -register_pytree_node(BlockArray, _block_array_tree_flatten, _block_array_tree_unflatten) - -# Syntactic sugar for the .at operations -# see https://github.com/google/jax/blob/56e9f7cb92e3a099adaaca161cc14153f024047c/jax/_src/numpy/lax_numpy.py#L5900 -class _BlockArrayIndexUpdateHelper: - """The helper class for the `at` property to call indexed update functions. - - The `at` property is syntactic sugar for calling the indexed update - functions as is done in jax. The index must be of the form [ibk] or - [ibk,idx], where `ibk` is the index of the block to be updated, and - `idx` is a general index of the elements to be updated in that block. - - In particular: - - `x = x.at[ibk].set(y)` is an equivalent of `x[ibk] = y`. - - `x = x.at[ibk,idx].set(y)` is an equivalent of `x[ibk,idx] = y`. - - The methods `set, add, multiply, divide, power, maximum, minimum` - are supported. - """ - - __slots__ = ("_block_array",) - - def __init__(self, block_array): - self._block_array = block_array - - def __getitem__(self, index): - if isinstance(index, tuple): - if isinstance(index[0], slice): - raise TypeError(f"Slicing not supported along block index") - return _BlockArrayIndexUpdateRef(self._block_array, index) - - def __repr__(self): - print(f"_BlockArrayIndexUpdateHelper({repr(self._block_array)})") - - -class _BlockArrayIndexUpdateRef: - """Helper object to call indexed update functions for an (advanced) index. +# register BlockArray as a jax pytree, without this, jax autograd won't work +# taken from what is done with tuples in jax._src.tree_util +jax.tree_util.register_pytree_node( + BlockArray, + lambda xs: (xs, None), # to iter + lambda _, xs: BlockArray(xs), # from iter +) - This object references a source block array and a specific indexer, - with the first integer specifying the block being updated, and rest - being the indexer into the array of that block. Methods on this - object return copies of the source block array that have been - modified at the positions specified by the indexer in the given block. - """ +""" wrap binary ops like +, @ """ +binary_ops = ( + "__add__", + "__radd__", + "__sub__", + "__rsub__", + "__mul__", + "__rmul__", + "__matmul__", + "__rmatmul__", + "__truediv__", + "__rtruediv__", + "__floordiv__", + "__rfloordiv__", + "__pow__", + "__rpow__", + "__gt__", + "__ge__", + "__lt__", + "__le__", + "__eq__", + "__ne__", +) - __slots__ = ("_block_array", "bk_index", "index") - def __init__(self, block_array, index): - self._block_array = block_array - if isinstance(index, int): - self.bk_index = index - self.index = Ellipsis - elif index == Ellipsis: - self.bk_index = Ellipsis - self.index = Ellipsis +def _binary_op_wrapper(func): + @wraps(func) + def func_ba(self, other): + if isinstance(other, BlockArray): + result = BlockArray(map(func, self, other)) else: - self.bk_index = index[0] - self.index = index[1:] - - def __repr__(self): - return f"_BlockArrayIndexUpdateRef({repr(self._block_array)}, {repr(self.bk_index)}, {repr(self.index)})" - - def _index_wrapper(self, func_str, values): - bk_index = self.bk_index - index = self.index - arr_tuple = self._block_array.split - if bk_index == Ellipsis: - # This may result in multiple copies: one per sub-blockarray, - # then one to combine into a nested BA. - retval = BlockArray.array([getattr(_.at[index], func_str)(values) for _ in arr_tuple]) + result = BlockArray(map(lambda self_n: func(self_n, other), self)) + if NotImplemented in result: + return NotImplemented else: - retval = BlockArray.array( - arr_tuple[:bk_index] - + (getattr(arr_tuple[bk_index].at[index], func_str)(values),) - + arr_tuple[bk_index + 1 :] - ) - return retval + return result - def set(self, values): - """Pure equivalent of `x[idx] = y`. + return func_ba - Return the value of `x` that would result from the NumPy-style - :mod:indexed assignment ` `x[idx] = y`. - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("set", values) +for op in binary_ops: + setattr(BlockArray, op, _binary_op_wrapper(getattr(DeviceArray, op))) - def add(self, values): - """Pure equivalent of `x[idx] += y`. - Return the value of `x` that would result from the NumPy-style - :mod:indexed assignment ` `x[idx] += y`. +""" wrap blockwise DeviceArray methods, like conj """ - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("add", values) +da_methods = ("conj",) - def multiply(self, values): - """Pure equivalent of `x[idx] *= y`. - Return the value of `x` that would result from the NumPy-style - :mod:indexed assignment ` `x[idx] *= y`. - - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("multiply", values) - - def divide(self, values): - """Pure equivalent of `x[idx] /= y`. - - Return the value of `x` that would result from the NumPy-style - :mod:indexed assignment ` `x[idx] /= y`. - - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("divide", values) - - def power(self, values): - """Pure equivalent of `x[idx] **= y`. +def _da_method_wrapper(func): + @wraps(func) + def func_ba(self): + return BlockArray(map(func, self)) - Return the value of `x` that would result from the NumPy-style - :mod:indexed assignment ` `x[idx] **= y`. + return func_ba - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("power", values) - def min(self, values): - """Pure equivalent of `x[idx] = minimum(x[idx], y)`. +for meth in da_methods: + setattr(BlockArray, meth, _da_method_wrapper(getattr(DeviceArray, meth))) - Return the value of `x` that would result from the NumPy-style - :mod:indexed assignment ` `x[idx] = minimum(x[idx], y)`. +""" wrap blockwise DeviceArray properties, like real """ - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("min", values) +da_props = ( + "real", + "imag", + "shape", + "ndim", +) - def max(self, values): - """Pure equivalent of `x[idx] = maximum(x[idx], y)`. - Return the value of `x` that would result from the NumPy-style - :mod:indexed assignment ` `x[idx] = maximum(x[idx], y)`. +def _da_prop_wrapper(prop): + @property + def prop_ba(self): + return BlockArray((getattr(x, prop) for x in self)) - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("max", values) + return prop_ba -setattr(BlockArray, "at", property(_BlockArrayIndexUpdateHelper)) +for prop in da_props: + setattr(BlockArray, prop, _da_prop_wrapper(prop)) diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index 01d9e435d..68bbc847c 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -27,7 +27,7 @@ # These functions rely on the definition of a BlockArray and must be in # scico.blockarray to avoid a circular import -from scico.blockarray import ( +from scico.blockarray_old import ( BlockArray, _block_array_matmul_wrapper, _block_array_reduction_wrapper, diff --git a/scico/numpy/linalg.py b/scico/numpy/linalg.py index 71974465e..a6729d0ea 100644 --- a/scico/numpy/linalg.py +++ b/scico/numpy/linalg.py @@ -23,7 +23,7 @@ import jax import jax.numpy.linalg as jla -from scico.blockarray import _block_array_reduction_wrapper +from scico.blockarray_old import _block_array_reduction_wrapper from scico.linop._matrix import MatrixOperator from ._util import _attach_wrapped_func, _not_implemented diff --git a/scico/random.py b/scico/random.py index ee45e46d5..ead5af813 100644 --- a/scico/random.py +++ b/scico/random.py @@ -60,7 +60,7 @@ import jax from scico.array import is_nested -from scico.blockarray import BlockArray, block_sizes +from scico.blockarray import BlockArray from scico.typing import BlockShape, DType, JaxArray, PRNGKey, Shape diff --git a/scico/test/test_blockarray.py b/scico/test/test_blockarray.py index d5aa30858..7ab581731 100644 --- a/scico/test/test_blockarray.py +++ b/scico/test/test_blockarray.py @@ -27,25 +27,25 @@ def __init__(self, dtype): self.a0, key = randn(shape=(2, 3), dtype=dtype, key=key) self.a1, key = randn(shape=(2, 3, 4), dtype=dtype, key=key) - self.a = ba.BlockArray.array((self.a0, self.a1), dtype=dtype) + self.a = ba.BlockArray.array((self.a0, self.a1)) self.b0, key = randn(shape=(2, 3), dtype=dtype, key=key) self.b1, key = randn(shape=(2, 3, 4), dtype=dtype, key=key) - self.b = ba.BlockArray.array((self.b0, self.b1), dtype=dtype) + self.b = ba.BlockArray.array((self.b0, self.b1)) self.d0, key = randn(shape=(3, 2), dtype=dtype, key=key) self.d1, key = randn(shape=(2, 4, 3), dtype=dtype, key=key) - self.d = ba.BlockArray.array((self.d0, self.d1), dtype=dtype) + self.d = ba.BlockArray.array((self.d0, self.d1)) c0, key = randn(shape=(2, 3), dtype=dtype, key=key) - self.c = ba.BlockArray.array((c0,), dtype=dtype) + self.c = ba.BlockArray.array((c0,)) # A flat device array with same size as self.a & self.b self.flat_da, key = randn(shape=(self.a.size,), dtype=dtype, key=key) self.flat_nd = np.array(self.flat_da) # A device array with length == self.a.num_blocks - self.block_da, key = randn(shape=(self.a.num_blocks,), dtype=dtype, key=key) + self.block_da, key = randn(shape=(len(self.a),), dtype=dtype, key=key) # block_da but as a numpy array self.block_nd = np.array(self.block_da) @@ -78,6 +78,8 @@ def test_operator_right(test_operator_obj, operator): # Operations between a blockarray and a flat DeviceArray +@pytest.mark.skip # do we want to allow ((3,4), (4, 5, 6)) + (132,) ? +# argument against: numpy doesn't allow (3, 4) + (12,) @pytest.mark.parametrize("operator", math_ops + comp_ops) def test_ba_da_left(test_operator_obj, operator): flat_da = test_operator_obj.flat_da @@ -87,6 +89,7 @@ def test_ba_da_left(test_operator_obj, operator): np.testing.assert_allclose(x, y, rtol=5e-5) +@pytest.mark.skip # see previous @pytest.mark.parametrize("operator", math_ops + comp_ops) def test_ba_da_right(test_operator_obj, operator): flat_da = test_operator_obj.flat_da @@ -97,44 +100,49 @@ def test_ba_da_right(test_operator_obj, operator): # Blockwise comparison between a BlockArray and Ndarray +@pytest.mark.skip # do we want to allow ((3,4), (4, 5, 6)) + (2,) ? +# argument against numpy doesn't allow (3, 4) + (3,), though leading dims match @pytest.mark.parametrize("operator", math_ops + comp_ops) def test_ndarray_left(test_operator_obj, operator): a = test_operator_obj.a block_nd = test_operator_obj.block_nd x = operator(a, block_nd).ravel() - y = ba.BlockArray.array([operator(a[i], block_nd[i]) for i in range(a.num_blocks)]).ravel() + y = ba.BlockArray.array([operator(a[i], block_nd[i]) for i in range(len(a))]).ravel() np.testing.assert_allclose(x, y) +@pytest.mark.skip # see previous @pytest.mark.parametrize("operator", math_ops + comp_ops) def test_ndarray_right(test_operator_obj, operator): a = test_operator_obj.a block_nd = test_operator_obj.block_nd x = operator(block_nd, a).ravel() - y = ba.BlockArray.array([operator(block_nd[i], a[i]) for i in range(a.num_blocks)]).ravel() - np.testing.assert_allclose(x, y, rtol=1e-6) + y = ba.BlockArray.array([operator(block_nd[i], a[i]) for i in range(len(a))]).ravel() + np.testing.assert_allclose(x, y) # Blockwise comparison between a BlockArray and DeviceArray +@pytest.mark.skip # see previous @pytest.mark.parametrize("operator", math_ops + comp_ops) def test_devicearray_left(test_operator_obj, operator): a = test_operator_obj.a block_da = test_operator_obj.block_da x = operator(a, block_da).ravel() - y = ba.BlockArray.array([operator(a[i], block_da[i]) for i in range(a.num_blocks)]).ravel() + y = ba.BlockArray.array([operator(a[i], block_da[i]) for i in range(len(a))]).ravel() np.testing.assert_allclose(x, y) +@pytest.mark.skip # see previous @pytest.mark.parametrize("operator", math_ops + comp_ops) def test_devicearray_right(test_operator_obj, operator): a = test_operator_obj.a block_da = test_operator_obj.block_da x = operator(block_da, a).ravel() - y = ba.BlockArray.array([operator(block_da[i], a[i]) for i in range(a.num_blocks)]).ravel() + y = ba.BlockArray.array([operator(block_da[i], a[i]) for i in range(len(a))]).ravel() np.testing.assert_allclose(x, y, atol=1e-7, rtol=0) From 39c68645d6177171654f0131f38badc8e3e75126 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Tue, 22 Mar 2022 09:27:30 -0600 Subject: [PATCH 02/37] Add old blockarray which is needed to even import scico --- scico/blockarray_old.py | 1466 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 1466 insertions(+) create mode 100644 scico/blockarray_old.py diff --git a/scico/blockarray_old.py b/scico/blockarray_old.py new file mode 100644 index 000000000..c611c1eb7 --- /dev/null +++ b/scico/blockarray_old.py @@ -0,0 +1,1466 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2020-2022 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SPORCO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +r"""Extensions of numpy ndarray class. + + .. testsetup:: + + >>> import scico + >>> import scico.numpy as snp + >>> from scico.blockarray import BlockArray + >>> import numpy as np + >>> import jax.numpy + +The class :class:`.BlockArray` is a `jagged array +`_ that aims to mimic the +:class:`numpy.ndarray` interface where appropriate. + +A :class:`.BlockArray` object consists of a tuple of `DeviceArray` +objects that share their memory buffers with non-overlapping, contiguous +regions of a common one-dimensional `DeviceArray`. A :class:`.BlockArray` +contains the following size attributes: + +* `shape`: A tuple of tuples containing component dimensions. +* `size`: The sum of the size of each component block; this is the length + of the underlying one-dimensional `DeviceArray`. +* `num_blocks`: The number of components (blocks) that comprise the + :class:`.BlockArray`. + + +Motivating Example +================== + +Consider a two dimensional array :math:`\mb{x} \in \mathbb{R}^{n \times m}`. + +We compute the discrete differences of :math:`\mb{x}` in the horizontal +and vertical directions, generating two new arrays: :math:`\mb{x}_h \in +\mathbb{R}^{n \times (m-1)}` and :math:`\mb{x}_v \in \mathbb{R}^{(n-1) +\times m}`. + +As these arrays are of different shapes, we cannot combine them into a +single `ndarray`. Instead, we might vectorize each array and concatenate +the resulting vectors, leading to :math:`\mb{\bar{x}} \in +\mathbb{R}^{n(m-1) + m(n-1)}`, which can be stored as a one-dimensional +`ndarray`. Unfortunately, this makes it hard to access the individual +components :math:`\mb{x}_h` and :math:`\mb{x}_v`. + +Instead, we can form a :class:`.BlockArray`: :math:`\mb{x}_B = +[\mb{x}_h, \mb{x}_v]` + + + :: + + >>> n = 32 + >>> m = 16 + >>> x_h, key = scico.random.randn((n, m-1)) + >>> x_v, _ = scico.random.randn((n-1, m), key=key) + + # Form the blockarray + >>> x_B = BlockArray.array([x_h, x_v]) + + # The blockarray shape is a tuple of tuples + >>> x_B.shape + ((32, 15), (31, 16)) + + # Each block component can be easily accessed + >>> x_B[0].shape + (32, 15) + >>> x_B[1].shape + (31, 16) + + +Constructing a BlockArray +========================= + +Construct from a tuple of arrays (either `ndarray` or `DeviceArray`) +-------------------------------------------------------------------- + + .. doctest:: + + >>> from scico.blockarray import BlockArray + >>> import numpy as np + >>> x0, key = scico.random.randn((32, 32)) + >>> x1, _ = scico.random.randn((16,), key=key) + >>> X = BlockArray.array((x0, x1)) + >>> X.shape + ((32, 32), (16,)) + >>> X.size + 1040 + >>> X.num_blocks + 2 + +While :func:`.BlockArray.array` will accept either `ndarray` or +`DeviceArray` as input, the resulting :class:`.BlockArray` will be backed +by a `DeviceArray` memory buffer. + +**Note**: constructing a :class:`.BlockArray` always involves a copy to +a new `DeviceArray` memory buffer. + +**Note**: by default, the resulting :class:`.BlockArray` is cast to +single precision and will have dtype `float32` or `complex64`. + + +Construct from a single vector and tuple of shapes +-------------------------------------------------- + + :: + + >>> x_flat, _ = scico.random.randn((1040,)) + >>> shape_tuple = ((32, 32), (16,)) + >>> X = BlockArray.array_from_flattened(x_flat, shape_tuple=shape_tuple) + >>> X.shape + ((32, 32), (16,)) + + + +Operating on a BlockArray +========================= + +.. _blockarray_indexing: + +Indexing +-------- + +The block index is required to be an integer, selecting a single block and +returning it as an array (*not* a singleton BlockArray). If the index +expression has more than one component, then the initial index indexes the +block, and the remainder of the indexing expression indexes within the +selected block, e.g. ``x[2, 3:4]`` is equivalent to ``y[3:4]`` after +setting ``y = x[2]``. + + +Indexed Updating +---------------- + +BlockArrays support the JAX DeviceArray `indexed update syntax +`_ + + +The index must be of the form [ibk] or [ibk, idx], where `ibk` is the +index of the block to be updated, and `idx` is a general index of the +elements to be updated in that block. In particular, `ibk` cannot be a +`slice`. The general index `idx` can be omitted, in which case an entire +block is updated. + + +============================== ============================================== +Alternate syntax Equivalent in-place expression +============================== ============================================== +``x.at[ibk, idx].set(y)`` ``x[ibk, idx] = y`` +``x.at[ibk, idx].add(y)`` ``x[ibk, idx] += y`` +``x.at[ibk, idx].multiply(y)`` ``x[ibk, idx] *= y`` +``x.at[ibk, idx].divide(y)`` ``x[ibk, idx] /= y`` +``x.at[ibk, idx].power(y)`` ``x[ibk, idx] **= y`` +``x.at[ibk, idx].min(y)`` ``x[ibk, idx] = np.minimum(x[idx], y)`` +``x.at[ibk, idx].max(y)`` ``x[ibk, idx] = np.maximum(x[idx], y)`` +============================== ============================================== + + +Arithmetic and Broadcasting +--------------------------- + +Suppose :math:`\mb{x}` is a BlockArray with shape :math:`((n, n), (m,))`. + + :: + + >>> x1, key = scico.random.randn((4, 4)) + >>> x2, _ = scico.random.randn((5,), key=key) + >>> x = BlockArray.array( (x1, x2) ) + >>> x.shape + ((4, 4), (5,)) + >>> x.num_blocks + 2 + >>> x.size # 4*4 + 5 + 21 + +Illustrated for the operation ``+``, but equally valid for operators +``+, -, *, /, //, **, <, <=, >, >=, ==`` + + +Operations with BlockArrays with same number of blocks +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Let :math:`\mb{y}` be a BlockArray with the same number of blocks as +:math:`\mb{x}`. + + .. math:: + \mb{x} + \mb{y} + = + \begin{bmatrix} + \mb{x}[0] + \mb{y}[0] \\ + \mb{x}[1] + \mb{y}[1] \\ + \end{bmatrix} + +This operation depends on pair of blocks from :math:`\mb{x}` and +:math:`\mb{y}` being broadcastable against each other. + + + +Operations with a scalar +^^^^^^^^^^^^^^^^^^^^^^^^ + +The scalar is added to each element of the :class:`.BlockArray`: + + .. math:: + \mb{x} + 1 + = + \begin{bmatrix} + \mb{x}[0] + 1 \\ + \mb{x}[1] + 1\\ + \end{bmatrix} + + + :: + + >>> y = x + 1 + >>> np.testing.assert_allclose(y[0], x[0] + 1) + >>> np.testing.assert_allclose(y[1], x[1] + 1) + + + +Operations with a 1D `ndarray` of size equal to `num_blocks` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The *i*\th scalar is added to the *i*\th block of the +:class:`.BlockArray`: + + .. math:: + \mb{x} + + + \begin{bmatrix} + 1 \\ + 2 + \end{bmatrix} + = + \begin{bmatrix} + \mb{x}[0] + 1 \\ + \mb{x}[1] + 2\\ + \end{bmatrix} + + + :: + + >>> y = x + np.array([1, 2]) + >>> np.testing.assert_allclose(y[0], x[0] + 1) + >>> np.testing.assert_allclose(y[1], x[1] + 2) + + +Operations with an ndarray of `size` equal to :class:`.BlockArray` size +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +We first cast the `ndarray` to a BlockArray with same shape as +:math:`\mb{x}`, then apply the operation on the resulting BlockArrays. +With ``y.size = x.size``, we have: + + .. math:: + \mb{x} + + + \mb{y} + = + \begin{bmatrix} + \mb{x}[0] + \mb{y}[0] \\ + \mb{x}[1] + \mb{y}[1]\\ + \end{bmatrix} + +Equivalently, the BlockArray is first flattened, then added to the +flattened `ndarray`, and the result is reformed into a BlockArray with +the same shape as :math:`\mb{x}` + + + +MatMul +------ + +Between two BlockArrays +^^^^^^^^^^^^^^^^^^^^^^^ + +The matmul is computed between each block of the two BlockArrays. + +The BlockArrays must have the same number of blocks, and each pair of +blocks must be broadcastable. + + .. math:: + \mb{x} @ \mb{y} + = + \begin{bmatrix} + \mb{x}[0] @ \mb{y}[0] \\ + \mb{x}[1] @ \mb{y}[1]\\ + \end{bmatrix} + + + +Between BlockArray and Ndarray/DeviceArray +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This operation is not defined. + + +Between BlockArray and :class:`.LinearOperator` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :class:`.Operator` and :class:`.LinearOperator` classes are designed +to work on :class:`.BlockArray`\ s in addition to `DeviceArray`\ s. +For example + + + :: + + >>> x, key = scico.random.randn((3, 4)) + >>> A_1 = scico.linop.Identity(x.shape) + >>> A_1.shape # array -> array + ((3, 4), (3, 4)) + + >>> A_2 = scico.linop.FiniteDifference(x.shape) + >>> A_2.shape # array -> BlockArray + (((2, 4), (3, 3)), (3, 4)) + + >>> diag = BlockArray.array([np.array(1.0), np.array(2.0)]) + >>> A_3 = scico.linop.Diagonal(diag, input_shape=(A_2.output_shape)) + >>> A_3.shape # BlockArray -> BlockArray + (((2, 4), (3, 3)), ((2, 4), (3, 3))) + + +NumPy ufuncs +------------ + +`NumPy universal functions (ufuncs) `_ +are functions that operate on an `ndarray` on an element-by-element +fashion and support array broadcasting. Examples of ufuncs are ``abs``, +``sign``, ``conj``, and ``exp``. + +The JAX library implements most NumPy ufuncs in the :mod:`jax.numpy` +module. However, as JAX does not support subclassing of `DeviceArray`, +the JAX ufuncs cannot be used on :class:`.BlockArray`. As a workaround, +we have wrapped several JAX ufuncs for use on :class:`.BlockArray`; these +are defined in the :mod:`scico.numpy` module. + + +Reductions +^^^^^^^^^^ + +Reductions are functions that take an array-like as an input and return +an array of lower dimension. Examples include ``mean``, ``sum``, ``norm``. +BlockArray reductions are located in the :mod:`scico.numpy` module + +:class:`.BlockArray` tries to mirror `ndarray` reduction semantics where +possible, but cannot provide a one-to-one match as the block components +may be of different size. + +Consider the example BlockArray + + .. math:: + \mb{x} = \begin{bmatrix} + \begin{bmatrix} + 1 & 1 \\ + 1 & 1 + \end{bmatrix} \\ + \begin{bmatrix} + 2 \\ + 2 + \end{bmatrix} + \end{bmatrix}. + +We have + + .. doctest:: + + >>> import scico.numpy as snp + >>> x = BlockArray.array((np.ones((2,2)), 2*np.ones((2)))) + >>> x.shape + ((2, 2), (2,)) + >>> x.size + 6 + >>> x.num_blocks + 2 + + + + If no axis is specified, the reduction is applied to the flattened + array: + + .. doctest:: + + >>> snp.sum(x, axis=None).item() + 8.0 + + Reducing along the 0-th axis crushes the `BlockArray` down into a + single `DeviceArray` and requires all blocks to have the same shape + otherwise, an error is raised. + + .. doctest:: + + >>> snp.sum(x, axis=0) + Traceback (most recent call last): + ValueError: Evaluating sum of BlockArray along axis=0 requires all blocks to be same shape; got ((2, 2), (2,)) + + >>> y = BlockArray.array((np.ones((2,2)), 2*np.ones((2, 2)))) + >>> snp.sum(y, axis=0) + DeviceArray([[3., 3.], + [3., 3.]], dtype=float32) + + Reducing along axis :math:`n` is equivalent to reducing each component + along axis :math:`n-1`: + + .. math:: + \text{sum}(x, axis=1) = \begin{bmatrix} + \begin{bmatrix} + 2 \\ + 2 + \end{bmatrix} \\ + \begin{bmatrix} + 4 \\ + \end{bmatrix} + \end{bmatrix} + + + If a component does not have axis :math:`n-1`, the reduction is not + applied to that component. In this example, ``x[1].ndim == 1``, so no + reduction is applied to block ``x[1]``. + + .. math:: + \text{sum}(x, axis=2) = \begin{bmatrix} + \begin{bmatrix} + 2 \\ + 2 + \end{bmatrix} \\ + \begin{bmatrix} + 2 \\ + 2 + \end{bmatrix} + \end{bmatrix} + + +Code version + + .. doctest:: + + >>> snp.sum(x, axis=1) # doctest: +SKIP + BlockArray([[2, 2], + [4,] ]) + + >>> snp.sum(x, axis=2) # doctest: +SKIP + BlockArray([ [2, 2], + [2,] ]) + + +""" + +from __future__ import annotations + +from functools import wraps +from typing import Iterator, List, Optional, Tuple, Union + +import numpy as np + +import jax +import jax.numpy as jnp +from jax import core +from jax.interpreters import xla +from jax.interpreters.xla import DeviceArray +from jax.tree_util import register_pytree_node, tree_flatten + +from jaxlib.xla_extension import Buffer + +from scico import array +from scico.typing import Axes, AxisIndex, BlockShape, DType, JaxArray, Shape + +_arraylikes = (Buffer, DeviceArray, np.ndarray) + + +def atleast_1d(*arys): + """Convert inputs to arrays with at least one dimension. + + A wrapper for :func:`jax.numpy.atleast_1d` that acts as usual on + ndarrays and DeviceArrays, and returns BlockArrays unmodified. + """ + + if len(arys) == 1: + arr = arys[0] + return arr if isinstance(arr, BlockArray) else jnp.atleast_1d(arr) + + out = [] + for arr in arys: + if isinstance(arr, BlockArray): + out.append(arr) + else: + out.append(jnp.atleast_1d(arr)) + return out + + +# Append docstring from original jax.numpy function +atleast_1d.__doc__ = ( + atleast_1d.__doc__.replace("\n ", "\n") # deal with indentation differences + + "\nDocstring for :func:`jax.numpy.atleast_1d`:\n\n" + + "\n".join(jax.numpy.atleast_1d.__doc__.split("\n")[2:]) +) + + +def reshape( + a: Union[JaxArray, BlockArray], newshape: Union[Shape, BlockShape] +) -> Union[JaxArray, BlockArray]: + """Change the shape of an array without changing its data. + + Args: + a: Array to be reshaped. + newshape: The new shape should be compatible with the original + shape. If an integer, then the result will be a 1-D array of + that length. One shape dimension can be -1. In this case, + the value is inferred from the length of the array and + remaining dimensions. If a tuple of tuple of ints, a + :class:`.BlockArray` is returned. + + Returns: + The reshaped array. Unlike :func:`numpy.reshape`, a copy is + always returned. + """ + + if array.is_nested(newshape): + # x is a blockarray + return BlockArray.array_from_flattened(a, newshape) + + return jnp.reshape(a, newshape) + + +def block_sizes(shape: Union[Shape, BlockShape]) -> Axes: + r"""Compute the 'sizes' of (possibly nested) block shapes. + + This function computes ``block_sizes(z.shape) == (_.size for _ in z)`` + + + Args: + shape: A shape tuple; possibly containing nested tuples. + + + Examples: + + .. doctest:: + >>> import scico.numpy as snp + + >>> x = BlockArray.ones( ( (4, 4), (2,))) + >>> x.size + 18 + + >>> y = snp.ones((3, 3)) + >>> y.size + 9 + + >>> z = BlockArray.array([x, y]) + >>> block_sizes(z.shape) + (18, 9) + + >>> zz = BlockArray.array([z, z]) + >>> block_sizes(zz.shape) + (27, 27) + """ + + if isinstance(shape, BlockArray): + raise TypeError( + "Expected a `shape` (possibly nested tuple of ints); got :class:`.BlockArray`." + ) + + out = [] + if array.is_nested(shape): + # shape is nested -> at least one element came from a blockarray + for y in shape: + if array.is_nested(y): + # recursively calculate the block size until we arrive at + # a tuple (shape of a non-block array) + while array.is_nested(y): + y = block_sizes(y) + out.append(np.sum(y)) # adjacent block sizes are added together + else: + # this is a tuple; size given by product of elements + out.append(np.prod(y)) + return tuple(out) + + # shape is a non-nested tuple; return the product + return np.prod(shape) + + +def _decompose_index(idx: Union[int, Tuple(AxisIndex)]) -> Tuple: + """Decompose a BlockArray indexing expression into components. + + Decompose a BlockArray indexing expression into block and array + components. + + Args: + idx: BlockArray indexing expression. + + Returns: + A tuple (idxblk, idxarr) with entries corresponding to the + integer block index and the indexing to be applied to the + selected block, respectively. The latter is ``None`` if the + indexing expression simply selects one of the blocks (i.e. + it consists of a single integer). + + Raises: + TypeError: If the block index is not an integer. + """ + if isinstance(idx, tuple): + idxblk = idx[0] + idxarr = idx[1:] + else: + idxblk = idx + idxarr = None + if not isinstance(idxblk, int): + raise TypeError("Block index must be an integer") + return idxblk, idxarr + + +def indexed_shape(shape: Shape, idx: Union[int, Tuple[AxisIndex, ...]]) -> Tuple[int, ...]: + """Determine the shape of the result of indexing a BlockArray. + + Args: + shape: Shape of BlockArray. + idx: BlockArray indexing expression. + + Returns: + Shape of the selected block, or slice of that block if ``idx`` is a tuple + rather than an integer. + """ + idxblk, idxarr = _decompose_index(idx) + if idxblk < 0: + idxblk = len(shape) + idxblk + if idxarr is None: + return shape[idxblk] + return array.indexed_shape(shape[idxblk], idxarr) + + +def _flatten_blockarrays(inp, *args, **kwargs): + """Flatten any blockarrays present in inp, args, or kwargs.""" + + def _flatten_if_blockarray(inp): + if isinstance(inp, BlockArray): + return inp._data + return inp + + inp_ = _flatten_if_blockarray(inp) + args_ = (_flatten_if_blockarray(_) for _ in args) + kwargs_ = {key: _flatten_if_blockarray(val) for key, val in kwargs.items()} + return inp_, args_, kwargs_ + + +def _block_array_ufunc_wrapper(func): + """Wrap a "ufunc" to allow for joint operation on `DeviceArray` and `BlockArray`.""" + + @wraps(func) + def wrapper(inp, *args, **kwargs): + all_args = (inp,) + args + tuple(kwargs.items()) + if any([isinstance(_, BlockArray) for _ in all_args]): + # If 'inp' is a BlockArray, call func on inp._data + # Then return a BlockArray of the same shape as inp + + inp_, args_, kwargs_ = _flatten_blockarrays(inp, *args, **kwargs) + flat_out = func(inp_, *args_, **kwargs_) + return BlockArray.array_from_flattened(flat_out, inp.shape) + + # Otherwise call the function normally + return func(inp, *args, **kwargs) + + if not hasattr(func, "__doc__") or func.__doc__ is None: + return wrapper + + wrapper.__doc__ = ( + f":func:`{func.__name__}` wrapped to operate on :class:`BlockArray`" + "\n\n" + func.__doc__ + ) + return wrapper + + +def _block_array_reduction_wrapper(func): + """Wrap a reduction (eg. sum, norm) to allow for joint operation on + `DeviceArray` and `BlockArray`.""" + + @wraps(func) + def wrapper(inp, *args, axis=None, **kwargs): + + all_args = (inp,) + args + tuple(kwargs.items()) + if any([isinstance(_, BlockArray) for _ in all_args]): + if axis is None: + # Treat as a single long vector + inp_, args_, kwargs_ = _flatten_blockarrays(inp, *args, **kwargs) + return func(inp_, *args_, **kwargs_) + + if type(axis) == tuple: + raise Exception( + f"""Evaluating {func.__name__} on a BlockArray with a tuple argument to + axis is not currently supported""" + ) + + if axis == 0: # reduction along block axis + # reduction along axis=0 only makes sense if all blocks are the same shape + # so we can convert to a standard DeviceArray of shape (inp.num_blocks, ...) + # and reduce along axis = 0 + if all([bk_shape == inp.shape[0] for bk_shape in inp.shape]): + view_shape = (inp.num_blocks,) + inp.shape[0] + return func(inp._data.reshape(view_shape), *args, axis=0, **kwargs) + + raise ValueError( + f"Evaluating {func.__name__} of BlockArray along axis=0 requires " + f"all blocks to be same shape; got {inp.shape}" + ) + + # Reduce each block individually along axis-1 + out = [] + for bk in inp: + if isinstance(bk, BlockArray): + # This block is itself a blockarray, so call this wrapped reduction + # on axis-1 + tmp = _block_array_reduction_wrapper(func)(bk, *args, axis=axis - 1, **kwargs) + else: + if axis - 1 >= bk.ndim: + # Trying to reduce along a dim that doesn't exist for this block, + # so just return the block. + # ie broadcast to shape (..., 1) and reduce along axis=-1 + tmp = bk + else: + tmp = func(bk, *args, axis=axis - 1, **kwargs) + out.append(atleast_1d(tmp)) + return BlockArray.array(out) + + if axis is None: + # 'axis' might not be a valid kwarg (eg dot, vdot), so don't pass it + return func(inp, *args, **kwargs) + + return func(inp, *args, axis=axis, **kwargs) + + if not hasattr(func, "__doc__") or func.__doc__ is None: + return wrapper + + wrapper.__doc__ = ( + f":func:`{func.__name__}` wrapped to operate on :class:`BlockArray`" + "\n\n" + func.__doc__ + ) + return wrapper + + +def _block_array_matmul_wrapper(func): + @wraps(func) + def wrapper(self, other): + if isinstance(self, BlockArray): + if isinstance(other, BlockArray): + # Both blockarrays, work block by block + return BlockArray.array([func(x, y) for x, y in zip(self, other)]) + raise TypeError( + f"Operation {func.__name__} not implemented between {type(self)} and {type(other)}" + ) + return func(self, other) + + if not hasattr(func, "__doc__") or func.__doc__ is None: + return wrapper + wrapper.__doc__ = ( + f":func:`{func.__name__}` wrapped to operate on :class:`BlockArray`" + "\n\n" + func.__doc__ + ) + return wrapper + + +def _block_array_binary_op_wrapper(func): + """Return a decorator that performs type and shape checking for + :class:`.BlockArray` arithmetic. + """ + + @wraps(func) + def wrapper(self, other): + if isinstance(other, BlockArray): + if other.shape == self.shape: + # Same shape blocks, can operate on flattened arrays + return BlockArray.array_from_flattened(func(self._data, other._data), self.shape) + if other.num_blocks == self.num_blocks: + # Will work as long as the shapes are broadcastable + return BlockArray.array([func(x, y) for x, y in zip(self, other)]) + raise ValueError( + f"operation not valid on operands with shapes {self.shape} {other.shape}" + ) + if any([isinstance(other, _) for _ in _arraylikes]): + if other.size == 1: + # Same as operating on a scalar + return BlockArray.array_from_flattened(func(self._data, other), self.shape) + if other.size == self.size: + # A little fast and loose, treat the block array as a length self.size vector + return BlockArray.array_from_flattened(func(self._data, other), self.shape) + if other.size == self.num_blocks: + return BlockArray.array([func(blk, other_) for blk, other_ in zip(self, other)]) + raise ValueError( + f"operation not valid on operands with shapes {self.shape} {other.shape}" + ) + if jnp.isscalar(other) or isinstance(other, core.Tracer): + return BlockArray.array_from_flattened(func(self._data, other), self.shape) + raise TypeError( + f"Operation {func.__name__} not implemented between {type(self)} and {type(other)}" + ) + + if not hasattr(func, "__doc__") or func.__doc__ is None: + return wrapper + wrapper.__doc__ = ( + f":func:`{func.__name__}` wrapped to operate on :class:`BlockArray`" + "\n\n" + func.__doc__ + ) + return wrapper + + +class _AbstractBlockArray(core.ShapedArray): + """Abstract BlockArray class for JAX tracing. + + See https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html + """ + + array_abstraction_level = 0 # Same as jax.core.ConcreteArray + + def __init__(self, shapes, dtype): + + sizes = block_sizes(shapes) + size = np.sum(sizes) + + super(_AbstractBlockArray, self).__init__((size,), dtype) + + #: Abstract data value + self._data_aval: core.ShapedArray = core.ShapedArray((size,), dtype) + + #: Array dtype + self.dtype: DType = dtype + + #: Shape of each block + self.shapes: BlockShape = shapes + + #: Size of each block + self.sizes: Shape = sizes + + #: Array specifying boundaries of components as indices in base array + self.bndpos: np.ndarray = np.r_[0, np.cumsum(sizes)] + + +# The Jax class is heavily inspired by SparseArray/AbstractSparseArray here: +# https://github.com/google/jax/blob/7724322d1c08c13008815bfb52759a29c2a6823b/tests/custom_object_test.py +class BlockArray: + """A tuple of :class:`jax.interpreters.xla.DeviceArray` objects. + + A tuple of `DeviceArray` objects that all share their memory buffers + with non-overlapping, contiguous regions of a common one-dimensional + `DeviceArray`. It can be used as the common one-dimensional array via + the :func:`BlockArray.ravel` method, or individual component arrays + can be accessed individually. + """ + + # Ensure we use BlockArray.__radd__, __rmul__, etc for binary operations of the form + # op(np.ndarray, BlockArray) + # See https://docs.scipy.org/doc/numpy-1.10.1/user/c-info.beyond-basics.html#ndarray.__array_priority__ + __array_priority__ = 1 + + def __init__(self, aval: _AbstractBlockArray, data: JaxArray): + """BlockArray init method. + + Args: + aval: `Abstract value`_ associated with this array (shape+dtype+weak_type) + data: The underlying contiguous, flattened `DeviceArray`. + + .. _Abstract value: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html + + """ + self._aval = aval + self._data = data + + def __repr__(self): + return "scico.blockarray.BlockArray: \n" + self._data.__repr__() + + def __getitem__(self, idx: Union[int, Tuple[AxisIndex, ...]]) -> JaxArray: + idxblk, idxarr = _decompose_index(idx) + if idxblk < 0: + idxblk = self.num_blocks + idxblk + blk = reshape(self._data[self.bndpos[idxblk] : self.bndpos[idxblk + 1]], self.shape[idxblk]) + if idxarr is not None: + blk = blk[idxarr] + return blk + + @_block_array_matmul_wrapper + def __matmul__(self, other: Union[np.ndarray, BlockArray, JaxArray]) -> JaxArray: + return self @ other + + @_block_array_matmul_wrapper + def __rmatmul__(self, other: Union[np.ndarray, BlockArray, JaxArray]) -> JaxArray: + return other @ self + + @_block_array_binary_op_wrapper + def __sub__(a, b): + return a - b + + @_block_array_binary_op_wrapper + def __rsub__(a, b): + return b - a + + @_block_array_binary_op_wrapper + def __mul__(a, b): + return a * b + + @_block_array_binary_op_wrapper + def __rmul__(a, b): + return a * b + + @_block_array_binary_op_wrapper + def __add__(a, b): + return a + b + + @_block_array_binary_op_wrapper + def __radd__(a, b): + return a + b + + @_block_array_binary_op_wrapper + def __truediv__(a, b): + return a / b + + @_block_array_binary_op_wrapper + def __rtruediv__(a, b): + return b / a + + @_block_array_binary_op_wrapper + def __floordiv__(a, b): + return a // b + + @_block_array_binary_op_wrapper + def __rfloordiv__(a, b): + return b // a + + @_block_array_binary_op_wrapper + def __pow__(a, b): + return a ** b + + @_block_array_binary_op_wrapper + def __rpow__(a, b): + return b ** a + + @_block_array_binary_op_wrapper + def __gt__(a, b): + return a > b + + @_block_array_binary_op_wrapper + def __ge__(a, b): + return a >= b + + @_block_array_binary_op_wrapper + def __lt__(a, b): + return a < b + + @_block_array_binary_op_wrapper + def __le__(a, b): + return a <= b + + @_block_array_binary_op_wrapper + def __eq__(a, b): + return a == b + + @_block_array_binary_op_wrapper + def __ne__(a, b): + return a != b + + def __iter__(self) -> Iterator[int]: + for i in range(self.num_blocks): + yield self[i] + + @property + def blocks(self) -> Iterator[int]: + """Return an iterator yielding component blocks.""" + return self.__iter__() + + @property + def bndpos(self) -> np.ndarray: + """Array specifying boundaries of components as indices in base array.""" + return self._aval.bndpos + + @property + def dtype(self) -> DType: + """Array dtype.""" + return self._data.dtype + + @property + def device_buffer(self) -> Buffer: + """The :class:`jaxlib.xla_extension.Buffer` that backs the + underlying data array.""" + return self._data.device_buffer + + @property + def size(self) -> int: + """Total number of elements in the array.""" + return self._aval.size + + @property + def num_blocks(self) -> int: + """Number of :class:`.BlockArray` components.""" + + return len(self.shape) + + @property + def ndim(self) -> Shape: + """Tuple of component ndims.""" + + return tuple(len(c) for c in self.shape) + + @property + def shape(self) -> BlockShape: + """Tuple of component shapes.""" + + return self._aval.shapes + + @property + def split(self) -> Tuple[JaxArray, ...]: + """Tuple of component arrays.""" + + return tuple(self[k] for k in range(self.num_blocks)) + + def conj(self) -> BlockArray: + """Return a :class:`.BlockArray` with complex-conjugated elements.""" + + # Much faster than BlockArray.array([_.conj() for _ in self.blocks]) + return BlockArray.array_from_flattened(self.ravel().conj(), self.shape) + + @property + def real(self) -> BlockArray: + """Return a :class:`.BlockArray` with the real part of this array.""" + return BlockArray.array_from_flattened(self.ravel().real, self.shape) + + @property + def imag(self) -> BlockArray: + """Return a :class:`.BlockArray` with the imaginary part of this array.""" + return BlockArray.array_from_flattened(self.ravel().imag, self.shape) + + @classmethod + def array( + cls, alst: List[Union[np.ndarray, JaxArray]], dtype: Optional[np.dtype] = None + ) -> BlockArray: + """Construct a :class:`.BlockArray` from a list or tuple of existing array-like. + + Args: + alst: Initializers for array components. + Can be :class:`numpy.ndarray` or + :class:`jax.interpreters.xla.DeviceArray` + dtype: Data type of array. If none, dtype is derived from + dtype of initializers + + Returns: + :class:`.BlockArray` initialized from `alst` tuple + """ + + if isinstance(alst, (tuple, list)) is False: + raise TypeError("Input to `array` must be a list or tuple of existing arrays") + + if dtype is None: + present_types = jax.tree_flatten(jax.tree_map(lambda x: x.dtype, alst))[0] + dtype = np.find_common_type(present_types, []) + + # alst can be a list/tuple of arrays, or a list/tuple containing list/tuples of arrays + # consider alst to be a tree where leaves are arrays (possibly abstract arrays) + # use tree_map to find the shape of each leaf + # `shapes` will be a tuple of ints and tuples containing ints (possibly nested further) + + # ensure any scalar leaves are converted to (1,) arrays + def shape_atleast_1d(x): + return x.shape if x.shape != () else (1,) + + shapes = tuple( + jax.tree_map(shape_atleast_1d, alst, is_leaf=lambda x: not isinstance(x, (list, tuple))) + ) + + _aval = _AbstractBlockArray(shapes, dtype) + data_ravel = jnp.hstack(jax.tree_map(lambda x: x.ravel(), jax.tree_flatten(alst)[0])) + return cls(_aval, data_ravel) + + @classmethod + def array_from_flattened( + cls, data_ravel: Union[np.ndarray, JaxArray], shape_tuple: BlockShape + ) -> BlockArray: + """Construct a :class:`.BlockArray` from a flattened array and tuple of shapes. + + Args: + data_ravel: Flattened data array + shape_tuple: Tuple of tuples containing desired block shapes. + + Returns: + :class:`.BlockArray` initialized from `data_ravel` and `shape_tuple` + """ + + if not isinstance(data_ravel, DeviceArray): + data_ravel = jax.device_put(data_ravel) + + shape_tuple_size = np.sum(block_sizes(shape_tuple)) + + if shape_tuple_size != data_ravel.size: + raise ValueError( + f"""The specified shape_tuple is incompatible with provided data_ravel + shape_tuple = {shape_tuple} + shape_tuple_size = {shape_tuple_size} + len(data_ravel) = {len(data_ravel)} + """ + ) + + _aval = _AbstractBlockArray(shape_tuple, dtype=data_ravel.dtype) + return cls(_aval, data_ravel) + + @classmethod + def ones(cls, shape_tuple: BlockShape, dtype: DType = np.float32) -> BlockArray: + """ + Return a new :class:`.BlockArray` with given block shapes and type, filled with ones. + + Args: + shape_tuple: Tuple of shapes for component blocks + dtype: Desired data-type for the :class:`.BlockArray`. + Default is `numpy.float32`. + + Returns: + :class:`.BlockArray` of ones with the given component shapes + and dtype. + """ + _aval = _AbstractBlockArray(shape_tuple, dtype=dtype) + data_ravel = jnp.ones(_aval.size, dtype=dtype) + return cls(_aval, data_ravel) + + @classmethod + def zeros(cls, shape_tuple: BlockShape, dtype: DType = np.float32) -> BlockArray: + """ + Return a new :class:`.BlockArray` with given block shapes and type, filled with zeros. + + Args: + shape_tuple: Tuple of shapes for component blocks. + dtype: Desired data-type for the :class:`.BlockArray`. + Default is `numpy.float32`. + + Returns: + :class:`.BlockArray` of zeros with the given component shapes + and dtype. + """ + _aval = _AbstractBlockArray(shape_tuple, dtype=dtype) + data_ravel = jnp.zeros(_aval.size, dtype=dtype) + return cls(_aval, data_ravel) + + @classmethod + def empty(cls, shape_tuple: BlockShape, dtype: DType = np.float32) -> BlockArray: + """ + Return a new :class:`.BlockArray` with given block shapes and type, filled with zeros. + + Note: like :func:`jax.numpy.empty`, this does not return an + uninitalized array. + + Args: + shape_tuple: Tuple of shapes for component blocks + dtype: Desired data-type for the :class:`.BlockArray`. + Default is `numpy.float32`. + + Returns: + :class:`.BlockArray` of zeros with the given component shapes + and dtype. + """ + _aval = _AbstractBlockArray(shape_tuple, dtype=dtype) + data_ravel = jnp.empty(_aval.size, dtype=dtype) + return cls(_aval, data_ravel) + + @classmethod + def full( + cls, + shape_tuple: BlockShape, + fill_value: Union[float, complex, int], + dtype: DType = np.float32, + ) -> BlockArray: + """ + Return a new :class:`.BlockArray` with given block shapes and type, filled with + `fill_value`. + + Args: + shape_tuple: Tuple of shapes for component blocks. + fill_value: Fill value + dtype: Desired data-type for the BlockArray. The default, + None, means `np.array(fill_value).dtype`. + + Returns: + :class:`.BlockArray` with the given component shapes and + dtype and all entries equal to `fill_value`. + """ + if dtype is None: + dtype = np.asarray(fill_value).dtype + + _aval = _AbstractBlockArray(shape_tuple, dtype=dtype) + data_ravel = jnp.full(_aval.size, fill_value=fill_value, dtype=dtype) + return cls(_aval, data_ravel) + + def copy(self) -> BlockArray: + """Return a copy of this :class:`.BlockArray`. + + This method is not implemented for BlockArray. + + See Also: + :meth:`.to_numpy`: Convert a :class:`.BlockArray` into a + flattened NumPy array. + """ + # jax DeviceArray copies return a NumPy ndarray. This blockarray class must be backed + # by a DeviceArray, so cannot be converted to a NumPy-backed BlockArray. The BlockArray + # .to_numpy() method returns a flattened ndarray. + # + # This method may be implemented in the future if jax DeviceArray .copy() is modified to + # return another DeviceArray. + raise NotImplementedError + + def to_numpy(self) -> np.ndarray: + """Return a :class:`numpy.ndarray` containing the flattened form of this + :class:`.BlockArray`.""" + + if isinstance(self._data, DeviceArray): + host_arr = jax.device_get(self._data.copy()) + else: + host_arr = self._data.copy() + return host_arr + + def blockidx(self, idx: int) -> jax._src.ops.scatter._Indexable: + """Return :class:`jax.ops.index` for a given component block. + + Args: + idx: Desired block index. + + Returns: + :class:`jax.ops.index` pointing to desired block. + """ + return slice(self.bndpos[idx], self.bndpos[idx + 1]) + + def ravel(self) -> JaxArray: + """Return a copy of ``self._data`` as a contiguous, flattened `DeviceArray`. + + Note that a copy, rather than a view, of the underlying array is + returned. This is consistent with :func:`jax.numpy.ravel`. + + Returns: + Copy of underlying flattened array. + + """ + return self._data[:] + + def flatten(self) -> JaxArray: + """Return a copy of ``self._data`` as a contiguous, flattened `DeviceArray`. + + Note that a copy, rather than a view, of the underlying array is + returned. This is consistent with :func:`jax.numpy.ravel`. + + Returns: + Copy of underlying flattened array. + + """ + return self._data[:] + + def sum(self, axis=None, keepdims=False): + """Return the sum of the blockarray elements over the given axis. + + Refer to :func:`scico.numpy.sum` for full documentation. + """ + # Can't just call scico.numpy.sum due to pesky circular import... + return _block_array_reduction_wrapper(jnp.sum)(self, axis=axis, keepdims=keepdims) + + +## Register BlockArray as a Jax type +# Our BlockArray is just a single large vector with some extra sugar +class _ConcreteBlockArray(_AbstractBlockArray): + pass + + +def _block_array_result_handler(device, _aval): + def build_block_array(data_buf): + data = xla.DeviceArray(_aval._data_aval, device, None, data_buf) + return BlockArray(_aval, data) + + return build_block_array + + +def _block_array_shape_handler(a): + return (xla.xc.Shape.array_shape(a._data_aval.dtype, a._data_aval.shape),) + + +def _block_array_device_put_handler(a, device): + return (xla.xb.get_device_backend(device).buffer_from_pyval(a._data, device),) + + +core.pytype_aval_mappings[BlockArray] = lambda x: x._aval +core.raise_to_shaped_mappings[_AbstractBlockArray] = lambda _aval, _: _aval +xla.pytype_aval_mappings[BlockArray] = lambda x: x._aval +xla.canonicalize_dtype_handlers[BlockArray] = lambda x: x +jax._src.dispatch.device_put_handlers[BlockArray] = _block_array_device_put_handler +jax._src.dispatch.result_handlers[_AbstractBlockArray] = _block_array_result_handler +xla.xla_shape_handlers[_AbstractBlockArray] = _block_array_shape_handler + + +## Handlers to use jax.device_put on BlockArray +def _block_array_tree_flatten(block_arr): + """Flatten a :class:`.BlockArray` pytree. + + See :func:`jax.tree_util.tree_flatten`. + + Args: + block_arr (:class:`.BlockArray`): :class:`.BlockArray` to flatten + + Returns: + children (tuple): :class:`.BlockArray` leaves. + aux_data (tuple): Extra metadata used to reconstruct BlockArray. + """ + + data_children, data_aux_data = tree_flatten(block_arr._data) + return (data_children, block_arr._aval) + + +def _block_array_tree_unflatten(aux_data, children): + """Construct a :class:`.BlockArray` from a flattened pytree. + + See jax.tree_utils.tree_unflatten + + Args: + aux_data (tuple): Metadata needed to construct block array. + children (tuple): Contains block array elements. + + Returns: + block_arr: Constructed :class:`.BlockArray`. + """ + return BlockArray(aux_data, children[0]) + + +register_pytree_node(BlockArray, _block_array_tree_flatten, _block_array_tree_unflatten) + +# Syntactic sugar for the .at operations +# see https://github.com/google/jax/blob/56e9f7cb92e3a099adaaca161cc14153f024047c/jax/_src/numpy/lax_numpy.py#L5900 +class _BlockArrayIndexUpdateHelper: + """The helper class for the `at` property to call indexed update functions. + + The `at` property is syntactic sugar for calling the indexed update + functions as is done in jax. The index must be of the form [ibk] or + [ibk,idx], where `ibk` is the index of the block to be updated, and + `idx` is a general index of the elements to be updated in that block. + + In particular: + - ``x = x.at[ibk].set(y)`` is an equivalent of ``x[ibk] = y``. + - ``x = x.at[ibk,idx].set(y)`` is an equivalent of ``x[ibk,idx] = y``. + + The methods ``set, add, multiply, divide, power, maximum, minimum`` + are supported. + """ + + __slots__ = ("_block_array",) + + def __init__(self, block_array): + self._block_array = block_array + + def __getitem__(self, index): + if isinstance(index, tuple): + if isinstance(index[0], slice): + raise TypeError(f"Slicing not supported along block index") + return _BlockArrayIndexUpdateRef(self._block_array, index) + + def __repr__(self): + print(f"_BlockArrayIndexUpdateHelper({repr(self._block_array)})") + + +class _BlockArrayIndexUpdateRef: + """Helper object to call indexed update functions for an (advanced) index. + + This object references a source block array and a specific indexer, + with the first integer specifying the block being updated, and rest + being the indexer into the array of that block. Methods on this + object return copies of the source block array that have been + modified at the positions specified by the indexer in the given block. + """ + + __slots__ = ("_block_array", "bk_index", "index") + + def __init__(self, block_array, index): + self._block_array = block_array + if isinstance(index, int): + self.bk_index = index + self.index = Ellipsis + elif index == Ellipsis: + self.bk_index = Ellipsis + self.index = Ellipsis + else: + self.bk_index = index[0] + self.index = index[1:] + + def __repr__(self): + return f"_BlockArrayIndexUpdateRef({repr(self._block_array)}, {repr(self.bk_index)}, {repr(self.index)})" + + def _index_wrapper(self, func_str, values): + bk_index = self.bk_index + index = self.index + arr_tuple = self._block_array.split + if bk_index == Ellipsis: + # This may result in multiple copies: one per sub-blockarray, + # then one to combine into a nested BA. + retval = BlockArray.array([getattr(_.at[index], func_str)(values) for _ in arr_tuple]) + else: + retval = BlockArray.array( + arr_tuple[:bk_index] + + (getattr(arr_tuple[bk_index].at[index], func_str)(values),) + + arr_tuple[bk_index + 1 :] + ) + return retval + + def set(self, values): + """Pure equivalent of ``x[idx] = y``. + + Return the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` ``x[idx] = y``. + + See :mod:`jax.ops` for details. + """ + return self._index_wrapper("set", values) + + def add(self, values): + """Pure equivalent of ``x[idx] += y``. + + Return the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` ``x[idx] += y``. + + See :mod:`jax.ops` for details. + """ + return self._index_wrapper("add", values) + + def multiply(self, values): + """Pure equivalent of ``x[idx] *= y``. + + Return the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` ``x[idx] *= y``. + + See :mod:`jax.ops` for details. + """ + return self._index_wrapper("multiply", values) + + def divide(self, values): + """Pure equivalent of ``x[idx] /= y``. + + Return the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` ``x[idx] /= y``. + + See :mod:`jax.ops` for details. + """ + return self._index_wrapper("divide", values) + + def power(self, values): + """Pure equivalent of ``x[idx] **= y``. + + Return the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` ``x[idx] **= y``. + + See :mod:`jax.ops` for details. + """ + return self._index_wrapper("power", values) + + def min(self, values): + """Pure equivalent of ``x[idx] = minimum(x[idx], y)``. + + Return the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` ``x[idx] = minimum(x[idx], y)``. + + See :mod:`jax.ops` for details. + """ + return self._index_wrapper("min", values) + + def max(self, values): + """Pure equivalent of ``x[idx] = maximum(x[idx], y)``. + + Return the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` ``x[idx] = maximum(x[idx], y)``. + + See :mod:`jax.ops` for details. + """ + return self._index_wrapper("max", values) + + +setattr(BlockArray, "at", property(_BlockArrayIndexUpdateHelper)) From 420ba24b8e22f572f6ee2684361addf09edb0cbd Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Fri, 25 Mar 2022 17:08:47 -0600 Subject: [PATCH 03/37] Start on snp --- scico/blockarray.py | 104 ++++++++++++++++++++-------------- scico/numpy/__init__.py | 46 ++++++++++++++- scico/random.py | 1 + scico/test/test_blockarray.py | 14 +++-- 4 files changed, 116 insertions(+), 49 deletions(-) diff --git a/scico/blockarray.py b/scico/blockarray.py index 7196b9363..f4bfc31f4 100644 --- a/scico/blockarray.py +++ b/scico/blockarray.py @@ -15,20 +15,34 @@ >>> import numpy as np >>> import jax.numpy -The class :class:`.BlockArray` is a `jagged array -`_ that aims to mimic the -:class:`numpy.ndarray` interface where appropriate. +The class :class:`.BlockArray` provides a way to combine several arrays +of different shapes and/or data types into a single object. +A :class:`.BlockArray` consists of a tuple of `DeviceArray` +objects, which we refer to as blocks. +:class:`.BlockArray`s differ from tuples in that mathematical operations +on :class:`.BlockArray`s automatically map along the blocks, returning +another :class:`.BlockArray` or tuple as appropriate. For example, + + :: + + >>> x = BlockArray(( + snp.array( + [[1, 3, 7], + [2, 2, 1],] + ), + snp.array( + [2, 4, 8] + ), + )) + >>> x.shape + ((2, 3), (3,)) # tuple + + >>> x + 1 + (DeviceArray([[2, 4, 8], + [3, 3, 2]], dtype=int32), + DeviceArray([3, 5, 9], dtype=int32)) # BlockArray -A :class:`.BlockArray` object consists of a tuple of `DeviceArray` -objects that share their memory buffers with non-overlapping, contiguous -regions of a common one-dimensional `DeviceArray`. A :class:`.BlockArray` -contains the following size attributes: -* `shape`: A tuple of tuples containing component dimensions. -* `size`: The sum of the size of each component block; this is the length - of the underlying one-dimensional `DeviceArray`. -* `num_blocks`: The number of components (blocks) that comprise the - :class:`.BlockArray`. Motivating Example @@ -455,6 +469,8 @@ from jaxlib.xla_extension import DeviceArray +import scico.numpy as snp + class BlockArray(tuple): """BlockArray""" @@ -464,11 +480,6 @@ class BlockArray(tuple): # https://docs.scipy.org/doc/numpy-1.10.1/user/c-info.beyond-basics.html#ndarray.__array_priority__ __array_priority__ = 1 - @property - def size(self) -> int: - """Total number of elements in the array.""" - return sum(x_i.size for x_i in self) - def ravel(self) -> DeviceArray: """Return a copy of ``self._data`` as a contiguous, flattened `DeviceArray`. @@ -497,7 +508,7 @@ def array(iterable): lambda _, xs: BlockArray(xs), # from iter ) -""" wrap binary ops like +, @ """ +""" Wrap binary ops like +, @. """ binary_ops = ( "__add__", "__radd__", @@ -522,23 +533,23 @@ def array(iterable): ) -def _binary_op_wrapper(func): - @wraps(func) - def func_ba(self, other): +def _binary_op_wrapper(op): + @wraps(op) + def op_ba(self, other): if isinstance(other, BlockArray): - result = BlockArray(map(func, self, other)) + result = BlockArray(getattr(x, op)(y) for x, y in zip(self, other)) else: - result = BlockArray(map(lambda self_n: func(self_n, other), self)) + result = BlockArray(getattr(x, op)(other) for x in self) if NotImplemented in result: return NotImplemented else: return result - return func_ba + return op_ba for op in binary_ops: - setattr(BlockArray, op, _binary_op_wrapper(getattr(DeviceArray, op))) + setattr(BlockArray, op, _binary_op_wrapper(op)) """ wrap blockwise DeviceArray methods, like conj """ @@ -546,34 +557,41 @@ def func_ba(self, other): da_methods = ("conj",) -def _da_method_wrapper(func): - @wraps(func) - def func_ba(self): - return BlockArray(map(func, self)) +def _da_method_wrapper(method): + @wraps(method) + def method_ba(self): + return BlockArray(map(method, self)) + + return method_ba + - return func_ba +for method in da_methods: + setattr(BlockArray, method, _da_method_wrapper(getattr(DeviceArray, method))) -for meth in da_methods: - setattr(BlockArray, meth, _da_method_wrapper(getattr(DeviceArray, meth))) +""" Wrap DeviceArray methods and properties that are implemented in jnp so that they call snp. """ + + +def _da_method_wrapper(method, is_property=False): + def method_ba(self, *args, **kwargs): + return getattr(snp, method)(self, *args, **kwargs) + + if is_property: + return property(method_ba) + return method_ba -""" wrap blockwise DeviceArray properties, like real """ da_props = ( "real", "imag", + "size", "shape", "ndim", ) +for prop in da_props: + setattr(BlockArray, prop, _da_method_wrapper(prop, is_property=True)) -def _da_prop_wrapper(prop): - @property - def prop_ba(self): - return BlockArray((getattr(x, prop) for x in self)) - - return prop_ba - - -for prop in da_props: - setattr(BlockArray, prop, _da_prop_wrapper(prop)) +da_methods = ("sum",) +for method in da_methods: + setattr(BlockArray, method, _da_method_wrapper(method)) diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index 68bbc847c..8d3c5b18a 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -15,6 +15,7 @@ their documentation, below. """ + import sys from functools import wraps @@ -27,8 +28,8 @@ # These functions rely on the definition of a BlockArray and must be in # scico.blockarray to avoid a circular import +from scico.blockarray import BlockArray, da_methods, da_props from scico.blockarray_old import ( - BlockArray, _block_array_matmul_wrapper, _block_array_reduction_wrapper, _block_array_ufunc_wrapper, @@ -165,3 +166,46 @@ def vdot(a, b): # these must be imported towards the end to avoid a circular import with # linalg and _matrixop from . import fft, linalg + +""" new code, everything above should eventually go. """ +import inspect + +""" Make functions that map over BlockArray arguments""" +jnp_funcs_to_map = da_props + da_methods +jnp_funcs_to_map += ("dot",) + + +def _map_func_over_ba(func): + """Create a version of `func` that maps over all of its BlockArray + arguments.""" + + @wraps(func) + def mapped(*args, **kwargs): + bound_args = inspect.signature(func).bind(*args, **kwargs) + + ba_args = {} + for k, v in list(bound_args.arguments.items()): + if isinstance(v, BlockArray): + ba_args[k] = bound_args.arguments.pop(k) + + if len(ba_args): # if any BlockArray arguments, + return BlockArray( + map( # map over + lambda *args: ( # lambda x_1, x_2, ..., x_N + func( + *bound_args.args, + **bound_args.kwargs, # ... nonBlockArray args + **dict(zip(ba_args.keys(), args)), + ) # plus dict of block args + ), + *ba_args.values(), # map(f, ba_1, ba_2, ..., ba_N) + ) + ) + else: + return func(*args, **kwargs) + + return mapped + + +for func in jnp_funcs_to_map: + vars()[func] = _map_func_over_ba(getattr(jnp, func)) diff --git a/scico/random.py b/scico/random.py index ead5af813..fd68bdf08 100644 --- a/scico/random.py +++ b/scico/random.py @@ -61,6 +61,7 @@ from scico.array import is_nested from scico.blockarray import BlockArray +from scico.blockarray_old import block_sizes from scico.typing import BlockShape, DType, JaxArray, PRNGKey, Shape diff --git a/scico/test/test_blockarray.py b/scico/test/test_blockarray.py index 7ab581731..93fb4f40e 100644 --- a/scico/test/test_blockarray.py +++ b/scico/test/test_blockarray.py @@ -41,7 +41,7 @@ def __init__(self, dtype): self.c = ba.BlockArray.array((c0,)) # A flat device array with same size as self.a & self.b - self.flat_da, key = randn(shape=(self.a.size,), dtype=dtype, key=key) + self.flat_da, key = randn(shape=self.a.size, dtype=dtype, key=key) self.flat_nd = np.array(self.flat_da) # A device array with length == self.a.num_blocks @@ -227,6 +227,9 @@ def test_getitem(test_operator_obj): np.testing.assert_allclose(x[-1], b1) +@pytest.mark.skip() +# this is indexing block dimension and internal dimensions simultaneously +# supporting it adds complexity, are we okay with just x[0][1:3] instead of x[0, 1:3]? @pytest.mark.parametrize("index", (np.s_[0, 0], np.s_[0, 1:3], np.s_[0, :, 0:2], np.s_[0, ..., 2:])) def test_getitem_tuple(test_operator_obj, index): a = test_operator_obj.a @@ -234,6 +237,8 @@ def test_getitem_tuple(test_operator_obj, index): np.testing.assert_allclose(a[index], a0[index[1:]]) +@pytest.mark.skip() +# `.blockidx` was an index into the underlying 1D array that no longer exists def test_blockidx(test_operator_obj): a = test_operator_obj.a a0 = test_operator_obj.a0 @@ -248,9 +253,8 @@ def test_blockidx(test_operator_obj): def test_split(test_operator_obj): a = test_operator_obj.a - a_split = a.split - np.testing.assert_allclose(a_split[0], test_operator_obj.a0) - np.testing.assert_allclose(a_split[1], test_operator_obj.a1) + np.testing.assert_allclose(a[0], test_operator_obj.a0) + np.testing.assert_allclose(a[1], test_operator_obj.a1) def test_blockarray_from_one_array(): @@ -526,7 +530,7 @@ def test_nested_shape(nested_obj): np.testing.assert_allclose(a[1].ravel(), a1.ravel()) # basic test for block_sizes - assert ba.block_sizes(a.shape) == (a[0].size, a[1].size) + assert a.shape == (a[0].size, a[1].size) NESTED_REDUCTION_PARAMS = dict( From c22f67f1335c827640fd351b94583ba6614c643b Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Tue, 29 Mar 2022 09:51:08 -0600 Subject: [PATCH 04/37] Wrap all of jnp, add total reduction mechanism --- scico/blockarray.py | 23 ++-- scico/numpy/__init__.py | 205 +++++++++------------------------- scico/test/test_blockarray.py | 14 ++- 3 files changed, 74 insertions(+), 168 deletions(-) diff --git a/scico/blockarray.py b/scico/blockarray.py index f4bfc31f4..21dbdc891 100644 --- a/scico/blockarray.py +++ b/scico/blockarray.py @@ -480,7 +480,7 @@ class BlockArray(tuple): # https://docs.scipy.org/doc/numpy-1.10.1/user/c-info.beyond-basics.html#ndarray.__array_priority__ __array_priority__ = 1 - def ravel(self) -> DeviceArray: + def full_ravel(self) -> DeviceArray: """Return a copy of ``self._data`` as a contiguous, flattened `DeviceArray`. Note that a copy, rather than a view, of the underlying array is @@ -581,17 +581,22 @@ def method_ba(self, *args, **kwargs): return method_ba -da_props = ( - "real", - "imag", - "size", - "shape", - "ndim", -) +# list and wrap properties +da_props = [ + x + for x in dir(DeviceArray) + if (isinstance(getattr(DeviceArray, x), property) and x[0] != "_" and x in dir(jnp)) +] + for prop in da_props: setattr(BlockArray, prop, _da_method_wrapper(prop, is_property=True)) +# list and wrap methods +da_methods = [ + x + for x in dir(DeviceArray) + if (not isinstance(getattr(DeviceArray, x), property) and x[0] != "_" and x in dir(jnp)) +] -da_methods = ("sum",) for method in da_methods: setattr(BlockArray, method, _da_method_wrapper(method)) diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index 8d3c5b18a..748c90d7d 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -16,178 +16,71 @@ """ -import sys from functools import wraps +from inspect import Parameter, signature +from types import FunctionType, ModuleType -import numpy as np +import jax.numpy as jnp -import jax -from jax import numpy as jnp +from jaxlib.xla_extension import CompiledFunction -from scico.array import is_nested +from scico.blockarray import BlockArray -# These functions rely on the definition of a BlockArray and must be in -# scico.blockarray to avoid a circular import -from scico.blockarray import BlockArray, da_methods, da_props -from scico.blockarray_old import ( - _block_array_matmul_wrapper, - _block_array_reduction_wrapper, - _block_array_ufunc_wrapper, - _flatten_blockarrays, - atleast_1d, - reshape, -) -from scico.typing import BlockShape, JaxArray, Shape - -from ._create import ( - empty, - empty_like, - full, - full_like, - ones, - ones_like, - zeros, - zeros_like, -) -from ._util import _attach_wrapped_func, _get_module_functions, _not_implemented - -# Numpy constants -pi = np.pi -e = np.e -euler_gamma = np.euler_gamma -inf = np.inf -NINF = np.NINF -PZERO = np.PZERO -NZERO = np.NZERO -nan = np.nan - -bool_ = jnp.bool_ -uint8 = jnp.uint8 -uint16 = jnp.uint16 -uint32 = jnp.uint32 -uint64 = jnp.uint64 -int8 = jnp.int8 -int16 = jnp.int16 -int32 = jnp.int32 -int64 = jnp.int64 -bfloat16 = jnp.bfloat16 -float16 = jnp.float16 -float32 = single = jnp.float32 -float64 = double = jnp.float64 -complex64 = csingle = jnp.complex64 -complex128 = cdouble = jnp.complex128 - -dtype = jnp.dtype -newaxis = None - -# Functions to which _block_array_ufunc_wrapper is to be applied -_ufunc_functions = [ - ("abs", jnp.abs), - jnp.maximum, - jnp.sign, - jnp.where, - jnp.true_divide, - jnp.floor_divide, - jnp.real, - jnp.imag, - jnp.conjugate, - jnp.angle, - jnp.exp, - jnp.sqrt, - jnp.log, - jnp.log10, -] -# Functions to which _block_array_reduction_wrapper is to be applied -_reduction_functions = [ - jnp.count_nonzero, - jnp.sum, - jnp.mean, - jnp.median, - jnp.any, - jnp.var, - ("max", jnp.max), - ("min", jnp.min), - jnp.amin, - jnp.amax, - jnp.all, - jnp.any, -] - -dot = _block_array_matmul_wrapper(jnp.dot) -matmul = _block_array_matmul_wrapper(jnp.matmul) - - -@wraps(jnp.vdot) -def vdot(a, b): - """Dot product of `a` and `b` (with first argument complex conjugated). - Wrapped to work on `BlockArray`s.""" - if isinstance(a, BlockArray): - a = a.ravel() - if isinstance(b, BlockArray): - b = b.ravel() - return jnp.vdot(a, b) - - -vdot.__doc__ = ":func:`vdot` wrapped to operate on :class:`.BlockArray`" + "\n\n" + jnp.vdot.__doc__ - -# Attach wrapped functions to this module -_attach_wrapped_func( - _ufunc_functions, - _block_array_ufunc_wrapper, - module_name=sys.modules[__name__], - fix_mod_name=True, -) -_attach_wrapped_func( - _reduction_functions, - _block_array_reduction_wrapper, - module_name=sys.modules[__name__], - fix_mod_name=True, -) -# divide is just an alias to true_divide -divide = true_divide -conj = conjugate - -# Find functions that exist in jax.numpy but not scico.numpy -# see jax.numpy.__init__.py -_not_implemented_functions = [] -for name, func in _get_module_functions(jnp).items(): - if name not in globals(): - _not_implemented_functions.append((name, func)) - -_attach_wrapped_func( - _not_implemented_functions, - _not_implemented, - module_name=sys.modules[__name__], - fix_mod_name=False, -) +def _copy_attributes(to_dict, from_dict, modules_to_recurse=None, reductions=None): + """Add attributes in `from_dict` to `to_dict`. + + Underscore methods are ignored. Functions are wrapped to allow for + `BlockArray` inputs. Modules are ignored, except those listed in + `modules_to_recurse`, which are added recursively. Functions with + names listed in `reductions` are given the `block_axis` argument, + allowing reduction down the block axis. + """ -# these must be imported towards the end to avoid a circular import with -# linalg and _matrixop -from . import fft, linalg + if modules_to_recurse is None: + modules_to_recurse = () -""" new code, everything above should eventually go. """ -import inspect + if reductions is None: + reductions = () -""" Make functions that map over BlockArray arguments""" -jnp_funcs_to_map = da_props + da_methods -jnp_funcs_to_map += ("dot",) + for name, obj in from_dict.items(): + if name[0] == "_": + continue + elif isinstance(obj, ModuleType) and name in modules_to_recurse: + to_dict[name] = ModuleType(name) + to_dict[name].__package__ = to_dict["__name__"] + to_dict[name].__doc__ = obj.__doc__ + _copy_attributes(to_dict[name].__dict__, obj.__dict__) + elif isinstance(obj, (FunctionType, CompiledFunction)): + obj = _map_func_over_ba(obj) + to_dict[name] = obj + else: + to_dict[name] = obj def _map_func_over_ba(func): - """Create a version of `func` that maps over all of its BlockArray - arguments.""" + """Create a version of `func` that maps over all of its `BlockArray` + arguments. + """ @wraps(func) def mapped(*args, **kwargs): - bound_args = inspect.signature(func).bind(*args, **kwargs) + sig = signature(func) + bound_args = sig.bind(*args, **kwargs) ba_args = {} for k, v in list(bound_args.arguments.items()): if isinstance(v, BlockArray): ba_args[k] = bound_args.arguments.pop(k) + ravel_blocks = "axis" not in bound_args.arguments and "axis" in sig.parameters + + if len(ba_args) and ravel_blocks: + ba_args = {k: v.full_ravel() for k, v in list(ba_args.items())} + print(ba_args) + return func(*bound_args.args, **bound_args.kwargs, **ba_args) + if len(ba_args): # if any BlockArray arguments, return BlockArray( map( # map over @@ -201,11 +94,15 @@ def mapped(*args, **kwargs): *ba_args.values(), # map(f, ba_1, ba_2, ..., ba_N) ) ) - else: - return func(*args, **kwargs) + + return func(*args, **kwargs) return mapped -for func in jnp_funcs_to_map: - vars()[func] = _map_func_over_ba(getattr(jnp, func)) +_copy_attributes( + vars(), + jnp.__dict__, + modules_to_recurse=("linalg",), + reductions=("sum",), +) diff --git a/scico/test/test_blockarray.py b/scico/test/test_blockarray.py index 93fb4f40e..9456de7a1 100644 --- a/scico/test/test_blockarray.py +++ b/scico/test/test_blockarray.py @@ -257,6 +257,10 @@ def test_split(test_operator_obj): np.testing.assert_allclose(a[1], test_operator_obj.a1) +@pytest.mark.skip() +# currently creation is exactly like a tuple, +# so BlockArray(np.jnp.zeros((32,32))) makes a block array +# with 32 1d blocks def test_blockarray_from_one_array(): with pytest.raises(TypeError): ba.BlockArray.array(np.random.randn(32, 32)) @@ -334,9 +338,9 @@ def __init__(self, dtype): c0, key = randn(shape=(2, 3), dtype=dtype, key=key) c1, key = randn(shape=(3,), dtype=dtype, key=key) - self.a = ba.BlockArray.array((a0, a1), dtype=dtype) - self.b = ba.BlockArray.array((b0, b1), dtype=dtype) - self.c = ba.BlockArray.array((c0, c1), dtype=dtype) + self.a = ba.BlockArray.array((a0, a1)) + self.b = ba.BlockArray.array((b0, b1)) + self.c = ba.BlockArray.array((c0, c1)) @pytest.fixture(scope="module") # so that random objects are cached @@ -360,8 +364,8 @@ def test_reduce(reduction_obj, func): x = func(reduction_obj.a) x_jit = jax.jit(func)(reduction_obj.a) y = func(reduction_obj.a.ravel()) - np.testing.assert_allclose(x, x_jit, rtol=1e-4) # test jitted function - np.testing.assert_allclose(x, y) # test for correctness + np.testing.assert_allclose(x, x_jit, rtol=1e-6) # test jitted function + np.testing.assert_allclose(x, y, rtol=1e-6) # test for correctness @pytest.mark.parametrize(**REDUCTION_PARAMS) From 97ecbdd587f4888aee95ecfa44f4effb9066b71d Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Tue, 29 Mar 2022 13:17:40 -0600 Subject: [PATCH 05/37] Finish reductions --- scico/blockarray.py | 62 ++++----- scico/numpy/__init__.py | 30 +++-- scico/random.py | 44 +----- scico/test/test_blockarray.py | 243 ++++++++++++++++------------------ 4 files changed, 163 insertions(+), 216 deletions(-) diff --git a/scico/blockarray.py b/scico/blockarray.py index 21dbdc891..8fc322073 100644 --- a/scico/blockarray.py +++ b/scico/blockarray.py @@ -43,6 +43,10 @@ DeviceArray([3, 5, 9], dtype=int32)) # BlockArray +TODO: not specifying axis to get a full reduction +TODO: using a BlockArray for axis or shape arguments + + Motivating Example @@ -462,15 +466,15 @@ """ +import inspect from functools import wraps +from typing import Callable import jax import jax.numpy as jnp from jaxlib.xla_extension import DeviceArray -import scico.numpy as snp - class BlockArray(tuple): """BlockArray""" @@ -552,50 +556,46 @@ def op_ba(self, other): setattr(BlockArray, op, _binary_op_wrapper(op)) -""" wrap blockwise DeviceArray methods, like conj """ - -da_methods = ("conj",) +""" Wrap DeviceArray properties. """ -def _da_method_wrapper(method): - @wraps(method) - def method_ba(self): - return BlockArray(map(method, self)) +def _da_prop_wrapper(prop): + @property + def prop_ba(self): + result = tuple(getattr(x, prop) for x in self) + if isinstance(result[0], jnp.ndarray): + return BlockArray(result) + return result - return method_ba + return prop_ba -for method in da_methods: - setattr(BlockArray, method, _da_method_wrapper(getattr(DeviceArray, method))) +da_props = [ + k + for k, v in dict(inspect.getmembers(DeviceArray)).items() + if isinstance(v, property) and k[0] != "_" +] +for prop in da_props: + setattr(BlockArray, prop, _da_prop_wrapper(prop)) -""" Wrap DeviceArray methods and properties that are implemented in jnp so that they call snp. """ +""" Wrap DeviceArray methods. """ -def _da_method_wrapper(method, is_property=False): +def _da_method_wrapper(method): def method_ba(self, *args, **kwargs): - return getattr(snp, method)(self, *args, **kwargs) + result = tuple(getattr(x, method)(*args, **kwargs) for x in self) + if isinstance(result[0], jnp.ndarray): + return BlockArray(result) + return result - if is_property: - return property(method_ba) return method_ba -# list and wrap properties -da_props = [ - x - for x in dir(DeviceArray) - if (isinstance(getattr(DeviceArray, x), property) and x[0] != "_" and x in dir(jnp)) -] - -for prop in da_props: - setattr(BlockArray, prop, _da_method_wrapper(prop, is_property=True)) - -# list and wrap methods da_methods = [ - x - for x in dir(DeviceArray) - if (not isinstance(getattr(DeviceArray, x), property) and x[0] != "_" and x in dir(jnp)) + k + for k, v in dict(inspect.getmembers(DeviceArray)).items() + if isinstance(v, Callable) and k[0] != "_" ] for method in da_methods: diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index 748c90d7d..1460cd440 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -19,31 +19,31 @@ from functools import wraps from inspect import Parameter, signature from types import FunctionType, ModuleType +from typing import Iterable, Optional import jax.numpy as jnp from jaxlib.xla_extension import CompiledFunction +from scico.array import is_nested from scico.blockarray import BlockArray -def _copy_attributes(to_dict, from_dict, modules_to_recurse=None, reductions=None): +def _copy_attributes( + to_dict: dict, from_dict: dict, modules_to_recurse: Optional[Iterable[str]] = None +): """Add attributes in `from_dict` to `to_dict`. Underscore methods are ignored. Functions are wrapped to allow for `BlockArray` inputs. Modules are ignored, except those listed in - `modules_to_recurse`, which are added recursively. Functions with - names listed in `reductions` are given the `block_axis` argument, - allowing reduction down the block axis. + `modules_to_recurse`, which are added recursively. All others are + passed through unwrapped. """ if modules_to_recurse is None: modules_to_recurse = () - if reductions is None: - reductions = () - for name, obj in from_dict.items(): if name[0] == "_": continue @@ -62,6 +62,13 @@ def _copy_attributes(to_dict, from_dict, modules_to_recurse=None, reductions=Non def _map_func_over_ba(func): """Create a version of `func` that maps over all of its `BlockArray` arguments. + + Functions with an `axis` parameter are handled in a special way in + order to allow full reductions of `BlockArray`s. If the axis + parameter exists but is not specified, each `BlockArray` argument + is fully ravelled before the function is called and no mapping is + applied. + """ @wraps(func) @@ -71,18 +78,17 @@ def mapped(*args, **kwargs): ba_args = {} for k, v in list(bound_args.arguments.items()): - if isinstance(v, BlockArray): + if isinstance(v, BlockArray) or is_nested(v): ba_args[k] = bound_args.arguments.pop(k) ravel_blocks = "axis" not in bound_args.arguments and "axis" in sig.parameters if len(ba_args) and ravel_blocks: ba_args = {k: v.full_ravel() for k, v in list(ba_args.items())} - print(ba_args) return func(*bound_args.args, **bound_args.kwargs, **ba_args) if len(ba_args): # if any BlockArray arguments, - return BlockArray( + result = tuple( map( # map over lambda *args: ( # lambda x_1, x_2, ..., x_N func( @@ -94,6 +100,9 @@ def mapped(*args, **kwargs): *ba_args.values(), # map(f, ba_1, ba_2, ..., ba_N) ) ) + if isinstance(result[0], jnp.ndarray): # True for abstract arrays, too + return BlockArray(result) + return result return func(*args, **kwargs) @@ -104,5 +113,4 @@ def mapped(*args, **kwargs): vars(), jnp.__dict__, modules_to_recurse=("linalg",), - reductions=("sum",), ) diff --git a/scico/random.py b/scico/random.py index fd68bdf08..c405d7304 100644 --- a/scico/random.py +++ b/scico/random.py @@ -50,7 +50,6 @@ """ -import functools import inspect import sys from typing import Optional, Tuple, Union @@ -59,9 +58,8 @@ import jax -from scico.array import is_nested +import scico.numpy as snp from scico.blockarray import BlockArray -from scico.blockarray_old import block_sizes from scico.typing import BlockShape, DType, JaxArray, PRNGKey, Shape @@ -123,45 +121,7 @@ def fun_alt(*args, key=None, seed=None, **kwargs): return fun_alt -def _allow_block_shape(fun): - """ - Decorate a jax.random function so that the `shape` argument may be a BlockShape. - """ - - # use inspect to find which argument number is `shape` - shape_ind = list(inspect.signature(fun).parameters.keys()).index("shape") - - @functools.wraps(fun) - def fun_alt(*args, **kwargs): - - # get the shape argument if it was passed - if len(args) > shape_ind: - shape = args[shape_ind] - elif "shape" in kwargs: - shape = kwargs["shape"] - else: # shape was not passed, call fun as normal - return fun(*args, **kwargs) - - # if shape is not nested, call fun as normal - if not is_nested(shape): - return fun(*args, **kwargs) - # shape is nested, so make a BlockArray! - - # call the wrapped fun with an shape=(size,) - subargs = list(args) - subkwargs = kwargs.copy() - size = np.sum(block_sizes(shape)) - - if len(subargs) > shape_ind: - subargs[shape_ind] = (size,) - else: # shape must be a kwarg if not a positional arg - subkwargs["shape"] = (size,) - - result_flat = fun(*subargs, **subkwargs) - - return BlockArray.array_from_flattened(result_flat, shape) - - return fun_alt +_allow_block_shape = snp._map_func_over_ba def _wrap(fun): diff --git a/scico/test/test_blockarray.py b/scico/test/test_blockarray.py index 9456de7a1..138c1ccb1 100644 --- a/scico/test/test_blockarray.py +++ b/scico/test/test_blockarray.py @@ -9,8 +9,8 @@ import pytest -import scico.blockarray as ba import scico.numpy as snp +from scico.blockarray import BlockArray from scico.random import randn math_ops = [op.add, op.sub, op.mul, op.truediv, op.pow] # op.floordiv doesn't work on complex @@ -27,18 +27,18 @@ def __init__(self, dtype): self.a0, key = randn(shape=(2, 3), dtype=dtype, key=key) self.a1, key = randn(shape=(2, 3, 4), dtype=dtype, key=key) - self.a = ba.BlockArray.array((self.a0, self.a1)) + self.a = BlockArray((self.a0, self.a1)) self.b0, key = randn(shape=(2, 3), dtype=dtype, key=key) self.b1, key = randn(shape=(2, 3, 4), dtype=dtype, key=key) - self.b = ba.BlockArray.array((self.b0, self.b1)) + self.b = BlockArray((self.b0, self.b1)) self.d0, key = randn(shape=(3, 2), dtype=dtype, key=key) self.d1, key = randn(shape=(2, 4, 3), dtype=dtype, key=key) - self.d = ba.BlockArray.array((self.d0, self.d1)) + self.d = BlockArray((self.d0, self.d1)) c0, key = randn(shape=(2, 3), dtype=dtype, key=key) - self.c = ba.BlockArray.array((c0,)) + self.c = BlockArray((c0,)) # A flat device array with same size as self.a & self.b self.flat_da, key = randn(shape=self.a.size, dtype=dtype, key=key) @@ -63,17 +63,17 @@ def test_operator_obj(request): def test_operator_left(test_operator_obj, operator): scalar = test_operator_obj.scalar a = test_operator_obj.a - x = operator(scalar, a).ravel() - y = operator(scalar, a.ravel()) - np.testing.assert_allclose(x, y, rtol=1e-6) + x = operator(scalar, a).full_ravel() + y = operator(scalar, a.full_ravel()) + np.testing.assert_allclose(x, y) @pytest.mark.parametrize("operator", math_ops + comp_ops) def test_operator_right(test_operator_obj, operator): scalar = test_operator_obj.scalar a = test_operator_obj.a - x = operator(a, scalar).ravel() - y = operator(a.ravel(), scalar) + x = operator(a, scalar).full_ravel() + y = operator(a.full_ravel(), scalar) np.testing.assert_allclose(x, y) @@ -84,8 +84,8 @@ def test_operator_right(test_operator_obj, operator): def test_ba_da_left(test_operator_obj, operator): flat_da = test_operator_obj.flat_da a = test_operator_obj.a - x = operator(flat_da, a).ravel() - y = operator(flat_da, a.ravel()) + x = operator(flat_da, a).full_ravel() + y = operator(flat_da, a.full_ravel()) np.testing.assert_allclose(x, y, rtol=5e-5) @@ -94,8 +94,8 @@ def test_ba_da_left(test_operator_obj, operator): def test_ba_da_right(test_operator_obj, operator): flat_da = test_operator_obj.flat_da a = test_operator_obj.a - x = operator(a, flat_da).ravel() - y = operator(a.ravel(), flat_da) + x = operator(a, flat_da).full_ravel() + y = operator(a.full_ravel(), flat_da) np.testing.assert_allclose(x, y) @@ -107,8 +107,8 @@ def test_ndarray_left(test_operator_obj, operator): a = test_operator_obj.a block_nd = test_operator_obj.block_nd - x = operator(a, block_nd).ravel() - y = ba.BlockArray.array([operator(a[i], block_nd[i]) for i in range(len(a))]).ravel() + x = operator(a, block_nd).full_ravel() + y = BlockArray([operator(a[i], block_nd[i]) for i in range(len(a))]).full_ravel() np.testing.assert_allclose(x, y) @@ -118,8 +118,8 @@ def test_ndarray_right(test_operator_obj, operator): a = test_operator_obj.a block_nd = test_operator_obj.block_nd - x = operator(block_nd, a).ravel() - y = ba.BlockArray.array([operator(block_nd[i], a[i]) for i in range(len(a))]).ravel() + x = operator(block_nd, a).full_ravel() + y = BlockArray([operator(block_nd[i], a[i]) for i in range(len(a))]).full_ravel() np.testing.assert_allclose(x, y) @@ -130,8 +130,8 @@ def test_devicearray_left(test_operator_obj, operator): a = test_operator_obj.a block_da = test_operator_obj.block_da - x = operator(a, block_da).ravel() - y = ba.BlockArray.array([operator(a[i], block_da[i]) for i in range(len(a))]).ravel() + x = operator(a, block_da).full_ravel() + y = BlockArray([operator(a[i], block_da[i]) for i in range(len(a))]).full_ravel() np.testing.assert_allclose(x, y) @@ -141,8 +141,8 @@ def test_devicearray_right(test_operator_obj, operator): a = test_operator_obj.a block_da = test_operator_obj.block_da - x = operator(block_da, a).ravel() - y = ba.BlockArray.array([operator(block_da[i], a[i]) for i in range(len(a))]).ravel() + x = operator(block_da, a).full_ravel() + y = BlockArray([operator(block_da[i], a[i]) for i in range(len(a))]).full_ravel() np.testing.assert_allclose(x, y, atol=1e-7, rtol=0) @@ -151,8 +151,8 @@ def test_devicearray_right(test_operator_obj, operator): def test_ba_ba_operator(test_operator_obj, operator): a = test_operator_obj.a b = test_operator_obj.b - x = operator(a, b).ravel() - y = operator(a.ravel(), b.ravel()) + x = operator(a, b).full_ravel() + y = operator(a.full_ravel(), b.full_ravel()) np.testing.assert_allclose(x, y) @@ -169,9 +169,9 @@ def test_ba_ba_matmul(test_operator_obj): x = a @ b - y = ba.BlockArray.array([a0 @ d0, a1 @ d1]) + y = BlockArray([a0 @ d0, a1 @ d1]) assert x.shape == y.shape - np.testing.assert_allclose(x.ravel(), y.ravel()) + np.testing.assert_allclose(x.full_ravel(), y.full_ravel()) with pytest.raises(TypeError): z = a @ c @@ -182,7 +182,7 @@ def test_conj(test_operator_obj): ac = a.conj() assert a.shape == ac.shape - np.testing.assert_allclose(a.ravel().conj(), ac.ravel()) + np.testing.assert_allclose(a.full_ravel().conj(), ac.full_ravel()) def test_real(test_operator_obj): @@ -190,7 +190,7 @@ def test_real(test_operator_obj): ac = a.real assert a.shape == ac.shape - np.testing.assert_allclose(a.ravel().real, ac.ravel()) + np.testing.assert_allclose(a.full_ravel().real, ac.full_ravel()) def test_imag(test_operator_obj): @@ -198,7 +198,7 @@ def test_imag(test_operator_obj): ac = a.imag assert a.shape == ac.shape - np.testing.assert_allclose(a.ravel().imag, ac.ravel()) + np.testing.assert_allclose(a.full_ravel().imag, ac.full_ravel()) def test_ndim(test_operator_obj): @@ -212,7 +212,7 @@ def test_getitem(test_operator_obj): a1 = test_operator_obj.a1 b0 = test_operator_obj.b0 b1 = test_operator_obj.b1 - x = ba.BlockArray.array([a0, a1, b0, b1]) + x = BlockArray([a0, a1, b0, b1]) # Positive indexing np.testing.assert_allclose(x[0], a0) @@ -245,10 +245,10 @@ def test_blockidx(test_operator_obj): a1 = test_operator_obj.a1 # use the blockidx to index the flattened data - x0 = a.ravel()[a.blockidx(0)] - x1 = a.ravel()[a.blockidx(1)] - np.testing.assert_allclose(x0, a0.ravel()) - np.testing.assert_allclose(x1, a1.ravel()) + x0 = a.full_ravel()[a.blockidx(0)] + x1 = a.full_ravel()[a.blockidx(1)] + np.testing.assert_allclose(x0, a0.full_ravel()) + np.testing.assert_allclose(x1, a1.full_ravel()) def test_split(test_operator_obj): @@ -263,7 +263,7 @@ def test_split(test_operator_obj): # with 32 1d blocks def test_blockarray_from_one_array(): with pytest.raises(TypeError): - ba.BlockArray.array(np.random.randn(32, 32)) + BlockArray(np.random.randn(32, 32)) @pytest.mark.parametrize("axis", [None, 1]) @@ -271,13 +271,16 @@ def test_blockarray_from_one_array(): def test_sum_method(test_operator_obj, axis, keepdims): a = test_operator_obj.a - method_result = a.sum(axis=axis, keepdims=keepdims).ravel() - snp_result = snp.sum(a, axis=axis, keepdims=keepdims).ravel() + method_result = a.sum(axis=axis, keepdims=keepdims).full_ravel() + snp_result = snp.sum(a, axis=axis, keepdims=keepdims).full_ravel() assert method_result.shape == snp_result.shape np.testing.assert_allclose(method_result, snp_result) +@pytest.mark.skip() +# previously vdot returned a scalar, +# in this proposal, it acts blockwize def test_ba_ba_vdot(test_operator_obj): a = test_operator_obj.a d = test_operator_obj.d @@ -287,7 +290,7 @@ def test_ba_ba_vdot(test_operator_obj): d1 = test_operator_obj.d1 x = snp.vdot(a, d) - y = jnp.vdot(a.ravel(), d.ravel()) + y = jnp.vdot(a.full_ravel(), d.full_ravel()) np.testing.assert_allclose(x, y) @@ -301,8 +304,8 @@ def test_ba_ba_dot(test_operator_obj, operator): d1 = test_operator_obj.d1 x = operator(a, d) - y = ba.BlockArray.array([operator(a0, d0), operator(a1, d1)]) - np.testing.assert_allclose(x.ravel(), y.ravel()) + y = BlockArray([operator(a0, d0), operator(a1, d1)]) + np.testing.assert_allclose(x.full_ravel(), y.full_ravel()) ############################################################################### @@ -338,9 +341,9 @@ def __init__(self, dtype): c0, key = randn(shape=(2, 3), dtype=dtype, key=key) c1, key = randn(shape=(3,), dtype=dtype, key=key) - self.a = ba.BlockArray.array((a0, a1)) - self.b = ba.BlockArray.array((b0, b1)) - self.c = ba.BlockArray.array((c0, c1)) + self.a = BlockArray((a0, a1)) + self.b = BlockArray((b0, b1)) + self.c = BlockArray((c0, c1)) @pytest.fixture(scope="module") # so that random objects are cached @@ -363,13 +366,17 @@ def reduction_obj(request): def test_reduce(reduction_obj, func): x = func(reduction_obj.a) x_jit = jax.jit(func)(reduction_obj.a) - y = func(reduction_obj.a.ravel()) + y = func(reduction_obj.a.full_ravel()) np.testing.assert_allclose(x, x_jit, rtol=1e-6) # test jitted function np.testing.assert_allclose(x, y, rtol=1e-6) # test for correctness +@pytest.mark.skip +# this is reduction along the block axis, which (in the old version) +# requires all blocks to be the same shape. If you know all blocks are the same shape, +# why use a block array? @pytest.mark.parametrize(**REDUCTION_PARAMS) -def test_reduce_axis0(reduction_obj, func): +def test_reduce_axis0_old(reduction_obj, func): f = lambda x: func(x, axis=0) x = f(reduction_obj.b) x_jit = jax.jit(f)(reduction_obj.b) @@ -387,55 +394,27 @@ def test_reduce_axis0(reduction_obj, func): @pytest.mark.parametrize(**REDUCTION_PARAMS) -def test_reduce_axis1(reduction_obj, func): - """this is _not_ duplicated from test_reduce_axis0""" - f = lambda x: func(x, axis=1).ravel() +@pytest.mark.parametrize("axis", (0, 1)) +def test_reduce_axis(reduction_obj, func, axis): + f = lambda x: func(x, axis=axis) x = f(reduction_obj.a) x_jit = jax.jit(f)(reduction_obj.a) - np.testing.assert_allclose(x, x_jit, rtol=1e-4) # test jitted function + np.testing.assert_allclose( + x.full_ravel(), x_jit.full_ravel(), rtol=1e-4 + ) # test jitted function # test for correctness - y0 = func(reduction_obj.a[0], axis=0) - y1 = func(reduction_obj.a[1], axis=0) - y = ba.BlockArray.array((y0, y1), dtype=reduction_obj.a[0].dtype).ravel() - np.testing.assert_allclose(x, y) - - -@pytest.mark.parametrize(**REDUCTION_PARAMS) -def test_reduce_axis2(reduction_obj, func): - """this is _not_ duplicated from test_reduce_axis0""" - f = lambda x: func(x, axis=2).ravel() - x = f(reduction_obj.a) - x_jit = jax.jit(f)(reduction_obj.a) - - np.testing.assert_allclose(x, x_jit, rtol=1e-4) # test jitted function - - y0 = func(reduction_obj.a[0], axis=1) - y1 = func(reduction_obj.a[1], axis=1) - y = ba.BlockArray.array((y0, y1), dtype=reduction_obj.a[0].dtype).ravel() - np.testing.assert_allclose(x, y) - - -@pytest.mark.parametrize(**REDUCTION_PARAMS) -def test_reduce_axis3(reduction_obj, func): - """this is _not_ duplicated from test_reduce_axis0""" - f = lambda x: func(x, axis=3).ravel() - x = f(reduction_obj.a) - x_jit = jax.jit(f)(reduction_obj.a) - - np.testing.assert_allclose(x, x_jit, rtol=1e-4) # test jitted function - - y0 = reduction_obj.a[0] - y1 = func(reduction_obj.a[1], axis=2) - y = ba.BlockArray.array((y0, y1), dtype=reduction_obj.a[0].dtype).ravel() - np.testing.assert_allclose(x.ravel(), y) + y0 = func(reduction_obj.a[0], axis=axis) + y1 = func(reduction_obj.a[1], axis=axis) + y = BlockArray((y0, y1)) + np.testing.assert_allclose(x.full_ravel(), y.full_ravel()) @pytest.mark.parametrize(**REDUCTION_PARAMS) def test_reduce_singleton(reduction_obj, func): - # Case where a block is reduced to a singleton - f = lambda x: func(x, axis=1).ravel() + # Case where one block is reduced to a singleton + f = lambda x: func(x, axis=0).full_ravel() x = f(reduction_obj.c) x_jit = jax.jit(f)(reduction_obj.c) @@ -443,7 +422,7 @@ def test_reduce_singleton(reduction_obj, func): y0 = func(reduction_obj.c[0], axis=0) y1 = func(reduction_obj.c[1], axis=0)[None] # Ensure size (1,) - y = ba.BlockArray.array((y0, y1), dtype=reduction_obj.a[0].dtype).ravel() + y = BlockArray((y0, y1), dtype=reduction_obj.a[0].dtype).full_ravel() np.testing.assert_allclose(x, y) @@ -457,42 +436,44 @@ def setup_method(self, method): self.size = np.prod(self.a_shape) + np.prod(self.b_shape) + np.prod(self.c_shape) def test_zeros(self): - x = ba.BlockArray.zeros(self.shape, dtype=np.float32) + x = snp.zeros(self.shape, dtype=np.float32) assert x.shape == self.shape assert snp.all(x == 0) def test_empty(self): - x = ba.BlockArray.empty(self.shape, dtype=np.float32) + x = snp.empty(self.shape, dtype=np.float32) assert x.shape == self.shape assert snp.all(x == 0) def test_ones(self): - x = ba.BlockArray.ones(self.shape, dtype=np.float32) + x = snp.ones(self.shape, dtype=np.float32) assert x.shape == self.shape assert snp.all(x == 1) def test_full(self): fill_value = np.float32(np.random.randn()) - x = ba.BlockArray.full(self.shape, fill_value=fill_value, dtype=np.float32) + x = snp.full(self.shape, fill_value=fill_value, dtype=np.float32) assert x.shape == self.shape - assert x.dtype == np.float32 + assert x.dtype == (np.float32, np.float32, np.float32) assert snp.all(x == fill_value) def test_full_nodtype(self): fill_value = np.float32(np.random.randn()) - x = ba.BlockArray.full(self.shape, fill_value=fill_value, dtype=None) + x = snp.full(self.shape, fill_value=fill_value, dtype=None) assert x.shape == self.shape - assert x.dtype == fill_value.dtype + assert x.dtype == (fill_value.dtype, fill_value.dtype, fill_value.dtype) assert snp.all(x == fill_value) +@pytest.mark.skip +# it no longer makes sense to make a BlockArray from a flattened array def test_incompatible_shapes(): # Verify that array_from_flattened raises exception when # len(data_ravel) != size determined by shape_tuple shape_tuple = ((32, 32), (16,)) # len == 1040 data_ravel = np.ones(1030) with pytest.raises(ValueError): - ba.BlockArray.array_from_flattened(data_ravel=data_ravel, shape_tuple=shape_tuple) + BlockArray.array_from_flattened(data_ravel=data_ravel, shape_tuple=shape_tuple) class NestedTestObj: @@ -501,13 +482,13 @@ class NestedTestObj: def __init__(self, dtype): key = None scalar, key = randn(shape=(1,), dtype=dtype, key=key) - self.scalar = scalar.copy().ravel()[0] # convert to float + self.scalar = scalar.copy().full_ravel()[0] # convert to float self.a00, key = randn(shape=(2, 2, 2), dtype=dtype, key=key) self.a01, key = randn(shape=(3, 2, 4), dtype=dtype, key=key) self.a1, key = randn(shape=(2, 4), dtype=dtype, key=key) - self.a = ba.BlockArray.array(((self.a00, self.a01), self.a1)) + self.a = BlockArray(((self.a00, self.a01), self.a1)) @pytest.fixture(scope="module") @@ -515,6 +496,7 @@ def nested_obj(request): yield NestedTestObj(request.param) +@pytest.mark.skip # deeply nested shapes no longer allowed @pytest.mark.parametrize("nested_obj", [np.float32, np.complex64], indirect=True) def test_nested_shape(nested_obj): a = nested_obj.a @@ -529,9 +511,9 @@ def test_nested_shape(nested_obj): assert a[0].shape == ((2, 2, 2), (3, 2, 4)) assert a[1].shape == (2, 4) - np.testing.assert_allclose(a[0][0].ravel(), a00.ravel()) - np.testing.assert_allclose(a[0][1].ravel(), a01.ravel()) - np.testing.assert_allclose(a[1].ravel(), a1.ravel()) + np.testing.assert_allclose(a[0][0].full_ravel(), a00.full_ravel()) + np.testing.assert_allclose(a[0][1].full_ravel(), a01.full_ravel()) + np.testing.assert_allclose(a[1].full_ravel(), a1.full_ravel()) # basic test for block_sizes assert a.shape == (a[0].size, a[1].size) @@ -548,14 +530,16 @@ def test_nested_shape(nested_obj): ) +@pytest.mark.skip # deeply nested shapes no longer allowed @pytest.mark.parametrize(**NESTED_REDUCTION_PARAMS) def test_nested_reduce_singleton(nested_obj, func): a = nested_obj.a x = func(a) - y = func(a.ravel()) + y = func(a.full_ravel()) np.testing.assert_allclose(x, y, rtol=5e-5) +@pytest.mark.skip # deeply nested shapes no longer allowed @pytest.mark.parametrize(**NESTED_REDUCTION_PARAMS) def test_nested_reduce_axis1(nested_obj, func): a = nested_obj.a @@ -565,6 +549,7 @@ def test_nested_reduce_axis1(nested_obj, func): x = func(a, axis=1) +@pytest.mark.skip # deeply nested shapes no longer allowed @pytest.mark.parametrize(**NESTED_REDUCTION_PARAMS) def test_nested_reduce_axis2(nested_obj, func): a = nested_obj.a @@ -572,12 +557,13 @@ def test_nested_reduce_axis2(nested_obj, func): x = func(a, axis=2) assert x.shape == (((2, 2), (2, 4)), (2,)) - y = ba.BlockArray.array((func(a[0], axis=1), func(a[1], axis=1))) + y = BlockArray((func(a[0], axis=1), func(a[1], axis=1))) assert x.shape == y.shape - np.testing.assert_allclose(x.ravel(), y.ravel(), rtol=5e-5) + np.testing.assert_allclose(x.full_ravel(), y.full_ravel(), rtol=5e-5) +@pytest.mark.skip # deeply nested shapes no longer allowed @pytest.mark.parametrize(**NESTED_REDUCTION_PARAMS) def test_nested_reduce_axis3(nested_obj, func): a = nested_obj.a @@ -585,12 +571,14 @@ def test_nested_reduce_axis3(nested_obj, func): x = func(a, axis=3) assert x.shape == (((2, 2), (3, 4)), (2, 4)) - y = ba.BlockArray.array((func(a[0], axis=2), a[1])) + y = BlockArray((func(a[0], axis=2), a[1])) assert x.shape == y.shape - np.testing.assert_allclose(x.ravel(), y.ravel(), rtol=5e-5) + np.testing.assert_allclose(x.full_ravel(), y.full_ravel(), rtol=5e-5) +@pytest.mark.skip +# no longer makes sense to make BlockArray from 1d array def test_array_from_flattened(): x = np.random.randn(19) x_b = ba.BlockArray.array_from_flattened(x, shape_tuple=((4, 4), (3,))) @@ -605,71 +593,62 @@ def setup_method(self): self.B, key = randn(shape=((3, 3), (4, 2, 3)), key=key) self.C, key = randn(shape=((3, 3), (4, 2), (4, 4)), key=key) - # nested - self.D, key = randn(shape=((self.A.shape, self.B.shape)), key=key) - def test_set_block(self): # Test assignment of an entire block - A2 = self.A.at[0].set(1) + A2 = self.A.at[0][:].set(1) np.testing.assert_allclose(A2[0], snp.ones_like(A2[0]), rtol=5e-5) np.testing.assert_allclose(A2[1], A2[1], rtol=5e-5) - D2 = self.D.at[1].set(1.45) - np.testing.assert_allclose(D2[0].ravel(), self.D[0].ravel(), rtol=5e-5) - np.testing.assert_allclose( - D2[1].ravel(), 1.45 * snp.ones_like(self.D[1]).ravel(), rtol=5e-5 - ) - def test_set(self): # Test assignment using (bkidx, idx) format - A2 = self.A.at[0, 2:, :-2].set(1.45) + A2 = self.A[0].at[2:, :-2].set(1.45) tmp = A2[0][2:, :-2] np.testing.assert_allclose(A2[0][2:, :-2], 1.45 * snp.ones_like(tmp), rtol=5e-5) - np.testing.assert_allclose(A2[1].ravel(), A2[1], rtol=5e-5) + np.testing.assert_allclose(A2[1].full_ravel(), A2[1], rtol=5e-5) def test_add(self): A2 = self.A.at[0, 2:, :-2].add(1.45) tmp = np.array(self.A[0]) tmp[2:, :-2] += 1.45 - y = ba.BlockArray.array([tmp, self.A[1]]) - np.testing.assert_allclose(A2.ravel(), y.ravel(), rtol=5e-5) + y = BlockArray([tmp, self.A[1]]) + np.testing.assert_allclose(A2.full_ravel(), y.full_ravel(), rtol=5e-5) D2 = self.D.at[1].add(1.45) - y = ba.BlockArray.array([self.D[0], self.D[1] + 1.45]) - np.testing.assert_allclose(D2.ravel(), y.ravel(), rtol=5e-5) + y = BlockArray([self.D[0], self.D[1] + 1.45]) + np.testing.assert_allclose(D2.full_ravel(), y.full_ravel(), rtol=5e-5) def test_multiply(self): A2 = self.A.at[0, 2:, :-2].multiply(1.45) tmp = np.array(self.A[0]) tmp[2:, :-2] *= 1.45 - y = ba.BlockArray.array([tmp, self.A[1]]) - np.testing.assert_allclose(A2.ravel(), y.ravel(), rtol=5e-5) + y = BlockArray([tmp, self.A[1]]) + np.testing.assert_allclose(A2.full_ravel(), y.full_ravel(), rtol=5e-5) D2 = self.D.at[1].multiply(1.45) - y = ba.BlockArray.array([self.D[0], self.D[1] * 1.45]) - np.testing.assert_allclose(D2.ravel(), y.ravel(), rtol=5e-5) + y = BlockArray([self.D[0], self.D[1] * 1.45]) + np.testing.assert_allclose(D2.full_ravel(), y.full_ravel(), rtol=5e-5) def test_divide(self): A2 = self.A.at[0, 2:, :-2].divide(1.45) tmp = np.array(self.A[0]) tmp[2:, :-2] /= 1.45 - y = ba.BlockArray.array([tmp, self.A[1]]) - np.testing.assert_allclose(A2.ravel(), y.ravel(), rtol=5e-5) + y = BlockArray([tmp, self.A[1]]) + np.testing.assert_allclose(A2.full_ravel(), y.full_ravel(), rtol=5e-5) D2 = self.D.at[1].divide(1.45) - y = ba.BlockArray.array([self.D[0], self.D[1] / 1.45]) - np.testing.assert_allclose(D2.ravel(), y.ravel(), rtol=5e-5) + y = BlockArray([self.D[0], self.D[1] / 1.45]) + np.testing.assert_allclose(D2.full_ravel(), y.full_ravel(), rtol=5e-5) def test_power(self): A2 = self.A.at[0, 2:, :-2].power(2) tmp = np.array(self.A[0]) tmp[2:, :-2] **= 2 - y = ba.BlockArray.array([tmp, self.A[1]]) - np.testing.assert_allclose(A2.ravel(), y.ravel(), rtol=5e-5) + y = BlockArray([tmp, self.A[1]]) + np.testing.assert_allclose(A2.full_ravel(), y.full_ravel(), rtol=5e-5) D2 = self.D.at[1].power(1.45) - y = ba.BlockArray.array([self.D[0], self.D[1] ** 1.45]) - np.testing.assert_allclose(D2.ravel(), y.ravel(), rtol=5e-5) + y = BlockArray([self.D[0], self.D[1] ** 1.45]) + np.testing.assert_allclose(D2.full_ravel(), y.full_ravel(), rtol=5e-5) def test_set_slice(self): with pytest.raises(TypeError): From 22dbce0890a479b7d4d65735263e09ee871ebc2a Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Tue, 29 Mar 2022 14:28:09 -0600 Subject: [PATCH 06/37] Finish first pass over the tests --- scico/blockarray.py | 15 +++++++++++---- scico/test/test_blockarray.py | 7 +++++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/scico/blockarray.py b/scico/blockarray.py index 8fc322073..0b863a575 100644 --- a/scico/blockarray.py +++ b/scico/blockarray.py @@ -17,7 +17,7 @@ The class :class:`.BlockArray` provides a way to combine several arrays of different shapes and/or data types into a single object. -A :class:`.BlockArray` consists of a tuple of `DeviceArray` +A :class:`.BlockArray` consists of a list of `DeviceArray` objects, which we refer to as blocks. :class:`.BlockArray`s differ from tuples in that mathematical operations on :class:`.BlockArray`s automatically map along the blocks, returning @@ -476,7 +476,7 @@ from jaxlib.xla_extension import DeviceArray -class BlockArray(tuple): +class BlockArray(list): """BlockArray""" # Ensure we use BlockArray.__radd__, __rmul__, etc for binary @@ -570,10 +570,12 @@ def prop_ba(self): return prop_ba +skip_props = ("at",) + da_props = [ k for k, v in dict(inspect.getmembers(DeviceArray)).items() - if isinstance(v, property) and k[0] != "_" + if isinstance(v, property) and k[0] != "_" and k not in dir(BlockArray) and k not in skip_props ] for prop in da_props: @@ -592,10 +594,15 @@ def method_ba(self, *args, **kwargs): return method_ba +skip_methods = () + da_methods = [ k for k, v in dict(inspect.getmembers(DeviceArray)).items() - if isinstance(v, Callable) and k[0] != "_" + if isinstance(v, Callable) + and k[0] != "_" + and k not in dir(BlockArray) + and k not in skip_methods ] for method in da_methods: diff --git a/scico/test/test_blockarray.py b/scico/test/test_blockarray.py index 138c1ccb1..a856820e4 100644 --- a/scico/test/test_blockarray.py +++ b/scico/test/test_blockarray.py @@ -585,6 +585,9 @@ def test_array_from_flattened(): assert isinstance(x_b._data, DeviceArray) +@pytest.mark.skip +# indexing now works just like a list of DeviceArrays: +# x[1] = x[1].at[:].set(0) class TestBlockArrayIndex: def setup_method(self): key = None @@ -595,14 +598,14 @@ def setup_method(self): def test_set_block(self): # Test assignment of an entire block - A2 = self.A.at[0][:].set(1) + A2 = self.A[0].at[:].set(1) np.testing.assert_allclose(A2[0], snp.ones_like(A2[0]), rtol=5e-5) np.testing.assert_allclose(A2[1], A2[1], rtol=5e-5) def test_set(self): # Test assignment using (bkidx, idx) format A2 = self.A[0].at[2:, :-2].set(1.45) - tmp = A2[0][2:, :-2] + tmp = A2[2:, :-2] np.testing.assert_allclose(A2[0][2:, :-2], 1.45 * snp.ones_like(tmp), rtol=5e-5) np.testing.assert_allclose(A2[1].full_ravel(), A2[1], rtol=5e-5) From f78a11f01f1e63e46dba80693478cb3df6de8829 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Tue, 29 Mar 2022 15:58:31 -0600 Subject: [PATCH 07/37] Start on entire test suite --- scico/_generic_operators.py | 2 +- scico/array.py | 71 +++++++++++++++++-- scico/blockarray_old.py | 58 +--------------- scico/numpy/__init__.py | 9 ++- scico/numpy/fft.py | 31 --------- scico/numpy/linalg.py | 108 ----------------------------- scico/scipy/special.py | 12 +++- scico/test/functional/test_core.py | 58 +++++++++++++++- scico/test/test_array.py | 6 +- 9 files changed, 143 insertions(+), 212 deletions(-) delete mode 100644 scico/numpy/fft.py delete mode 100644 scico/numpy/linalg.py diff --git a/scico/_generic_operators.py b/scico/_generic_operators.py index d97bb6592..e5f484082 100644 --- a/scico/_generic_operators.py +++ b/scico/_generic_operators.py @@ -23,7 +23,7 @@ import scico.numpy as snp from scico._autograd import linear_adjoint -from scico.array import is_complex_dtype, is_nested +from scico.array import block_sizes, is_complex_dtype, is_nested from scico.blockarray import BlockArray from scico.typing import BlockShape, DType, JaxArray, Shape diff --git a/scico/array.py b/scico/array.py index 17d582277..451fa737f 100644 --- a/scico/array.py +++ b/scico/array.py @@ -19,14 +19,14 @@ from jax.interpreters.pxla import ShardedDeviceArray from jax.interpreters.xla import DeviceArray -import scico.blockarray import scico.numpy as snp +from scico.blockarray import BlockArray from scico.typing import ArrayIndex, Axes, AxisIndex, DType, JaxArray, Shape def ensure_on_device( - *arrays: Union[np.ndarray, JaxArray, scico.blockarray.BlockArray] -) -> Union[JaxArray, scico.blockarray.BlockArray]: + *arrays: Union[np.ndarray, JaxArray, BlockArray] +) -> Union[JaxArray, BlockArray]: """Cast ndarrays to DeviceArrays. Cast ndarrays to DeviceArrays and leaves DeviceArrays, BlockArrays, @@ -58,24 +58,25 @@ def ensure_on_device( stacklevel=2, ) - arrays[i] = jax.device_put(arrays[i]) elif not isinstance( array, - (DeviceArray, scico.blockarray.BlockArray, ShardedDeviceArray), + (DeviceArray, BlockArray, ShardedDeviceArray), ): raise TypeError( "Each item of `arrays` must be ndarray, DeviceArray, BlockArray, or " f"ShardedDeviceArray; Argument {i+1} of {len(arrays)} is {type(arrays[i])}." ) + arrays[i] = jax.device_put(arrays[i]) + if len(arrays) == 1: return arrays[0] return arrays def no_nan_divide( - x: Union[scico.blockarray.BlockArray, JaxArray], y: Union[scico.blockarray.BlockArray, JaxArray] -) -> Union[scico.blockarray.BlockArray, JaxArray]: + x: Union[BlockArray, JaxArray], y: Union[BlockArray, JaxArray] +) -> Union[BlockArray, JaxArray]: """Return `x/y`, with 0 instead of NaN where `y` is 0. Args: @@ -192,6 +193,62 @@ def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]: return tuple(filter(lambda x: x is not None, idx_shape)) +def block_sizes(shape: Union[Shape, BlockShape]) -> Axes: + r"""Compute the 'sizes' of block shapes. + + This function computes ``block_sizes(z.shape) == (_.size for _ in z)`` + + + Args: + shape: A shape tuple; possibly containing nested tuples. + + + Examples: + + .. doctest:: + >>> import scico.numpy as snp + + >>> x = BlockArray.ones( ( (4, 4), (2,))) + >>> x.size + 18 + + >>> y = snp.ones((3, 3)) + >>> y.size + 9 + + >>> z = BlockArray.array([x, y]) + >>> block_sizes(z.shape) + (18, 9) + + >>> zz = BlockArray.array([z, z]) + >>> block_sizes(zz.shape) + (27, 27) + """ + + if isinstance(shape, BlockArray): + raise TypeError( + "Expected a `shape` (possibly nested tuple of ints); got :class:`.BlockArray`." + ) + + out = [] + if is_nested(shape): + # shape is nested -> at least one element came from a blockarray + for y in shape: + if is_nested(y): + # recursively calculate the block size until we arrive at + # a tuple (shape of a non-block array) + while is_nested(y): + y = block_sizes(y) + out.append(np.sum(y)) # adjacent block sizes are added together + else: + # this is a tuple; size given by product of elements + out.append(np.prod(y)) + return tuple(out) + + # shape is a non-nested tuple; return the product + return np.prod(shape) + + def is_nested(x: Any) -> bool: """Check if input is a list/tuple containing at least one list/tuple. diff --git a/scico/blockarray_old.py b/scico/blockarray_old.py index c611c1eb7..ba71073bc 100644 --- a/scico/blockarray_old.py +++ b/scico/blockarray_old.py @@ -466,7 +466,7 @@ from jaxlib.xla_extension import Buffer from scico import array -from scico.typing import Axes, AxisIndex, BlockShape, DType, JaxArray, Shape +from scico.typing import AxisIndex, BlockShape, DType, JaxArray, Shape _arraylikes = (Buffer, DeviceArray, np.ndarray) @@ -525,62 +525,6 @@ def reshape( return jnp.reshape(a, newshape) -def block_sizes(shape: Union[Shape, BlockShape]) -> Axes: - r"""Compute the 'sizes' of (possibly nested) block shapes. - - This function computes ``block_sizes(z.shape) == (_.size for _ in z)`` - - - Args: - shape: A shape tuple; possibly containing nested tuples. - - - Examples: - - .. doctest:: - >>> import scico.numpy as snp - - >>> x = BlockArray.ones( ( (4, 4), (2,))) - >>> x.size - 18 - - >>> y = snp.ones((3, 3)) - >>> y.size - 9 - - >>> z = BlockArray.array([x, y]) - >>> block_sizes(z.shape) - (18, 9) - - >>> zz = BlockArray.array([z, z]) - >>> block_sizes(zz.shape) - (27, 27) - """ - - if isinstance(shape, BlockArray): - raise TypeError( - "Expected a `shape` (possibly nested tuple of ints); got :class:`.BlockArray`." - ) - - out = [] - if array.is_nested(shape): - # shape is nested -> at least one element came from a blockarray - for y in shape: - if array.is_nested(y): - # recursively calculate the block size until we arrive at - # a tuple (shape of a non-block array) - while array.is_nested(y): - y = block_sizes(y) - out.append(np.sum(y)) # adjacent block sizes are added together - else: - # this is a tuple; size given by product of elements - out.append(np.prod(y)) - return tuple(out) - - # shape is a non-nested tuple; return the product - return np.prod(shape) - - def _decompose_index(idx: Union[int, Tuple(AxisIndex)]) -> Tuple: """Decompose a BlockArray indexing expression into components. diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index 1460cd440..bdeb06450 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -14,8 +14,7 @@ wrapped. Functions that have not been wrapped yet have WARNING text in their documentation, below. """ - - +import sys from functools import wraps from inspect import Parameter, signature from types import FunctionType, ModuleType @@ -112,5 +111,9 @@ def mapped(*args, **kwargs): _copy_attributes( vars(), jnp.__dict__, - modules_to_recurse=("linalg",), + modules_to_recurse=("linalg", "fft"), ) + +# enable `import scico.numpy.linalg` and `from scico.numpy.linalg import norm` +sys.modules["scico.numpy.linalg"] = linalg +sys.modules["scico.numpy.fft"] = fft diff --git a/scico/numpy/fft.py b/scico/numpy/fft.py deleted file mode 100644 index 0577f6acd..000000000 --- a/scico/numpy/fft.py +++ /dev/null @@ -1,31 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers -# All rights reserved. BSD 3-clause License. -# This file is part of the SCICO package. Details of the copyright and -# user license can be found in the 'LICENSE' file distributed with the -# package. - -"""Construct wrapped versions of :mod:`jax.numpy.fft` functions. - -This modules consists of functions from :mod:`jax.numpy.fft`. Some of -these functions are wrapped to support compatibility with -:class:`scico.blockarray.BlockArray` and are documented here. -The remaining functions are imported directly from :mod:`jax.numpy.fft`. -While they can be imported from the :mod:`scico.numpy.fft` namespace, -they are not documented here; please consult the documentation for the -source module :mod:`jax.numpy.fft`. -""" -import sys - -import jax.numpy.fft - -from ._util import _attach_wrapped_func, _not_implemented - -_not_implemented_functions = [] -for name, func in jax._src.util.get_module_functions(jax.numpy.fft).items(): - if name not in globals(): - _not_implemented_functions.append((name, func)) - -_attach_wrapped_func( - _not_implemented_functions, _not_implemented, module_name=sys.modules[__name__] -) diff --git a/scico/numpy/linalg.py b/scico/numpy/linalg.py deleted file mode 100644 index a6729d0ea..000000000 --- a/scico/numpy/linalg.py +++ /dev/null @@ -1,108 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers -# All rights reserved. BSD 3-clause License. -# This file is part of the SCICO package. Details of the copyright and -# user license can be found in the 'LICENSE' file distributed with the -# package. - -"""Construct wrapped versions of :mod:`jax.numpy.linalg` functions. - -This modules consists of functions from :mod:`jax.numpy.linalg`. Some of -these functions are wrapped to support compatibility with -:class:`scico.blockarray.BlockArray` and are documented here. The -remaining functions are imported directly from :mod:`jax.numpy.linalg`. -While they can be imported from the :mod:`scico.numpy.linalg` namespace, -they are not documented here; please consult the documentation for the -source module :mod:`jax.numpy.linalg`. -""" - - -import sys -from functools import wraps - -import jax -import jax.numpy.linalg as jla - -from scico.blockarray_old import _block_array_reduction_wrapper -from scico.linop._matrix import MatrixOperator - -from ._util import _attach_wrapped_func, _not_implemented - - -def _extract_if_matrix(x): - if isinstance(x, MatrixOperator): - return x.A - return x - - -def _matrixop_linalg_wrapper(func): - """Wrap :mod:`jax.numpy.linalg` functions. - - Wrap :mod:`jax.numpy.linalg` functions for joint operation on - `MatrixOperator` and `DeviceArray`.""" - - @wraps(func) - def wrapper(*args, **kwargs): - all_args = args + tuple(kwargs.items()) - if any([isinstance(_, MatrixOperator) for _ in all_args]): - args = [_extract_if_matrix(_) for _ in args] - kwargs = {key: _extract_if_matrix(val) for key, val in kwargs.items()} - return func(*args, **kwargs) - - if hasattr(func, "__doc__"): - wrapper.__doc__ = ( - f":func:`{func.__name__}` wrapped to operate on :class:`.MatrixOperator`" - + "\n\n" - + func.__doc__ - ) - return wrapper - - -# norm is a reduction and gets both block array and matrixop wrapping -norm = _block_array_reduction_wrapper(_matrixop_linalg_wrapper(jla.norm)) - -svd = _matrixop_linalg_wrapper(jla.svd) -cond = _matrixop_linalg_wrapper(jla.cond) -det = _matrixop_linalg_wrapper(jla.det) -eig = _matrixop_linalg_wrapper(jla.eig) -eigh = _matrixop_linalg_wrapper(jla.eigh) -eigvals = _matrixop_linalg_wrapper(jla.eigvals) -eigvalsh = _matrixop_linalg_wrapper(jla.eigvalsh) -inv = _matrixop_linalg_wrapper(jla.inv) -lstsq = _matrixop_linalg_wrapper(jla.lstsq) -matrix_power = _matrixop_linalg_wrapper(jla.matrix_power) -matrix_rank = _matrixop_linalg_wrapper(jla.matrix_rank) -pinv = _matrixop_linalg_wrapper(jla.pinv) -qr = _matrixop_linalg_wrapper(jla.qr) -slogdet = _matrixop_linalg_wrapper(jla.slogdet) -solve = _matrixop_linalg_wrapper(jla.solve) - - -# multidot is somewhat unique -def multi_dot(arrays, *, precision=None): - """Compute the dot product of two or more arrays. - - Compute the dot product of two or more arrays. - Wrapped to work with `MatrixOperator`s. - """ - arrays_ = [_extract_if_matrix(_) for _ in arrays] - return jla.multi_dot(arrays_, precision=precision) - - -multi_dot.__doc__ = ( - f":func:`multi_dot` wrapped to operate on :class:`.MatrixOperator`" - + "\n\n" - + jla.multi_dot.__doc__ -) - - -# Attach unwrapped functions -# jla.tensorinv, jla.tensorsolve use n-dim arrays; not supported by MatrixOperator -_not_implemented_functions = [] -for name, func in jax._src.util.get_module_functions(jla).items(): - if name not in globals(): - _not_implemented_functions.append((name, func)) - -_attach_wrapped_func( - _not_implemented_functions, _not_implemented, module_name=sys.modules[__name__] -) diff --git a/scico/scipy/special.py b/scico/scipy/special.py index b65004cd7..f30cbcb4b 100644 --- a/scico/scipy/special.py +++ b/scico/scipy/special.py @@ -16,7 +16,7 @@ :mod:`jax.scipy.special`. """ - +""" import sys import jax @@ -70,3 +70,13 @@ _attach_wrapped_func( _not_implemented_functions, _not_implemented, module_name=sys.modules[__name__] ) +""" + +import jax.scipy.special as js + +import scico.numpy as snp + +snp._copy_attributes( + vars(), + js.__dict__, +) diff --git a/scico/test/functional/test_core.py b/scico/test/functional/test_core.py index 96e3bf7f9..d69087ba8 100644 --- a/scico/test/functional/test_core.py +++ b/scico/test/functional/test_core.py @@ -33,6 +33,60 @@ def test_prox_obj(request): return ProxTestObj(request.param) +class SeparableTestObject: + def __init__(self, dtype): + self.f = functional.L1Norm() + self.g = functional.SquaredL2Norm() + self.fg = functional.SeparableFunctional([self.f, self.g]) + + n = 4 + m = 6 + key = None + + self.v1, key = randn((n,), key=key, dtype=dtype) # point for prox eval + self.v2, key = randn((m,), key=key, dtype=dtype) # point for prox eval + self.vb = BlockArray.array([self.v1, self.v2]) + + +@pytest.fixture(params=[np.float32, np.complex64, np.float64, np.complex128]) +def test_separable_obj(request): + return SeparableTestObject(request.param) + + +def test_separable_eval(test_separable_obj): + fv1 = test_separable_obj.f(test_separable_obj.v1) + gv2 = test_separable_obj.g(test_separable_obj.v2) + fgv = test_separable_obj.fg(test_separable_obj.vb) + np.testing.assert_allclose(fv1 + gv2, fgv, rtol=5e-2) + + +def test_separable_prox(test_separable_obj): + alpha = 0.1 + fv1 = test_separable_obj.f.prox(test_separable_obj.v1, alpha) + gv2 = test_separable_obj.g.prox(test_separable_obj.v2, alpha) + fgv = test_separable_obj.fg.prox(test_separable_obj.vb, alpha) + out = BlockArray.array((fv1, gv2)).full_ravel() + np.testing.assert_allclose(out, fgv.full_ravel(), rtol=5e-2) + + +def test_separable_grad(test_separable_obj): + # Used to restore the warnings after the context is used + with warnings.catch_warnings(): + # Ignores warning raised by ensure_on_device + warnings.filterwarnings(action="ignore", category=UserWarning) + + # Verifies that there is a warning on f.grad and fg.grad + np.testing.assert_warns(test_separable_obj.f.grad(test_separable_obj.v1)) + np.testing.assert_warns(test_separable_obj.fg.grad(test_separable_obj.vb)) + + # Tests the separable grad with warnings being supressed + fv1 = test_separable_obj.f.grad(test_separable_obj.v1) + gv2 = test_separable_obj.g.grad(test_separable_obj.v2) + fgv = test_separable_obj.fg.grad(test_separable_obj.vb) + out = BlockArray.array((fv1, gv2)).full_ravel() + np.testing.assert_allclose(out, fgv.full_ravel(), rtol=5e-2) + + class TestNormProx: alphalist = [1e-2, 1e-1, 1e0, 1e1] @@ -73,9 +127,9 @@ def test_prox_blockarray(self, norm, alpha, test_prox_obj): nrmobj = norm() nrm = nrmobj.__call__ prx = nrmobj.prox - pf = nrmobj.prox(test_prox_obj.vb.ravel(), alpha) + pf = nrmobj.prox(test_prox_obj.vb.full_ravel(), alpha) pf_b = nrmobj.prox(test_prox_obj.vb, alpha) - np.testing.assert_allclose(pf, pf_b.ravel()) + np.testing.assert_allclose(pf, pf_b.full_ravel()) @pytest.mark.parametrize("norm", normlist) def test_prox_zeros(self, norm, test_prox_obj): diff --git a/scico/test/test_array.py b/scico/test/test_array.py index b838ce0c0..1b44b595a 100644 --- a/scico/test/test_array.py +++ b/scico/test/test_array.py @@ -40,7 +40,9 @@ def test_ensure_on_device(): assert SNP.unsafe_buffer_pointer() == SNP_.unsafe_buffer_pointer() assert isinstance(BA_, BlockArray) - assert BA._data.unsafe_buffer_pointer() == BA_._data.unsafe_buffer_pointer() + assert isinstance(BA_[0], DeviceArray) + assert isinstance(BA_[1], DeviceArray) + assert BA[1].unsafe_buffer_pointer() == BA_[1].unsafe_buffer_pointer() np.testing.assert_raises(TypeError, ensure_on_device, [1, 1, 1]) @@ -64,7 +66,7 @@ def test_no_nan_divide_blockarray(): x, key = randn(((3, 3), (4,)), dtype=np.float32) y, key = randn(x.shape, dtype=np.float32, key=key) - y = y.at[1].set(0 * y[1]) + y[1] = y[1].at[:].set(0 * y[1]) res = no_nan_divide(x, y) From 807098457fd602935a65d4e7c4877979b0e607d8 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Thu, 7 Apr 2022 09:37:24 -0600 Subject: [PATCH 08/37] Remove automatic reduction wrapping --- scico/blockarray.py | 4 +- scico/numpy/__init__.py | 114 +++++-------------------- scico/numpy/_util.py | 181 ++++++++++++++++++++++------------------ 3 files changed, 119 insertions(+), 180 deletions(-) diff --git a/scico/blockarray.py b/scico/blockarray.py index 0b863a575..933ee2524 100644 --- a/scico/blockarray.py +++ b/scico/blockarray.py @@ -561,6 +561,7 @@ def op_ba(self, other): def _da_prop_wrapper(prop): @property + @wraps(prop) def prop_ba(self): result = tuple(getattr(x, prop) for x in self) if isinstance(result[0], jnp.ndarray): @@ -571,7 +572,6 @@ def prop_ba(self): skip_props = ("at",) - da_props = [ k for k, v in dict(inspect.getmembers(DeviceArray)).items() @@ -585,6 +585,7 @@ def prop_ba(self): def _da_method_wrapper(method): + @wraps(method) def method_ba(self, *args, **kwargs): result = tuple(getattr(x, method)(*args, **kwargs) for x in self) if isinstance(result[0], jnp.ndarray): @@ -595,7 +596,6 @@ def method_ba(self, *args, **kwargs): skip_methods = () - da_methods = [ k for k, v in dict(inspect.getmembers(DeviceArray)).items() diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index bdeb06450..b974072ba 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -14,106 +14,30 @@ wrapped. Functions that have not been wrapped yet have WARNING text in their documentation, below. """ -import sys -from functools import wraps -from inspect import Parameter, signature -from types import FunctionType, ModuleType -from typing import Iterable, Optional +import numpy as np import jax.numpy as jnp -from jaxlib.xla_extension import CompiledFunction +from . import _util +from .blockarray import BlockArray -from scico.array import is_nested -from scico.blockarray import BlockArray - - -def _copy_attributes( - to_dict: dict, from_dict: dict, modules_to_recurse: Optional[Iterable[str]] = None -): - """Add attributes in `from_dict` to `to_dict`. - - Underscore methods are ignored. Functions are wrapped to allow for - `BlockArray` inputs. Modules are ignored, except those listed in - `modules_to_recurse`, which are added recursively. All others are - passed through unwrapped. - - """ - - if modules_to_recurse is None: - modules_to_recurse = () - - for name, obj in from_dict.items(): - if name[0] == "_": - continue - elif isinstance(obj, ModuleType) and name in modules_to_recurse: - to_dict[name] = ModuleType(name) - to_dict[name].__package__ = to_dict["__name__"] - to_dict[name].__doc__ = obj.__doc__ - _copy_attributes(to_dict[name].__dict__, obj.__dict__) - elif isinstance(obj, (FunctionType, CompiledFunction)): - obj = _map_func_over_ba(obj) - to_dict[name] = obj - else: - to_dict[name] = obj - - -def _map_func_over_ba(func): - """Create a version of `func` that maps over all of its `BlockArray` - arguments. - - Functions with an `axis` parameter are handled in a special way in - order to allow full reductions of `BlockArray`s. If the axis - parameter exists but is not specified, each `BlockArray` argument - is fully ravelled before the function is called and no mapping is - applied. - - """ - - @wraps(func) - def mapped(*args, **kwargs): - sig = signature(func) - bound_args = sig.bind(*args, **kwargs) - - ba_args = {} - for k, v in list(bound_args.arguments.items()): - if isinstance(v, BlockArray) or is_nested(v): - ba_args[k] = bound_args.arguments.pop(k) - - ravel_blocks = "axis" not in bound_args.arguments and "axis" in sig.parameters - - if len(ba_args) and ravel_blocks: - ba_args = {k: v.full_ravel() for k, v in list(ba_args.items())} - return func(*bound_args.args, **bound_args.kwargs, **ba_args) - - if len(ba_args): # if any BlockArray arguments, - result = tuple( - map( # map over - lambda *args: ( # lambda x_1, x_2, ..., x_N - func( - *bound_args.args, - **bound_args.kwargs, # ... nonBlockArray args - **dict(zip(ba_args.keys(), args)), - ) # plus dict of block args - ), - *ba_args.values(), # map(f, ba_1, ba_2, ..., ba_N) - ) - ) - if isinstance(result[0], jnp.ndarray): # True for abstract arrays, too - return BlockArray(result) - return result - - return func(*args, **kwargs) +# wrap jnp +_util.wrap_attributes( + to_dict=vars(), + from_dict=jnp.__dict__, + modules_to_recurse=("linalg", "fft"), + reductions=("sum",), +) - return mapped +# wrap np.testing +_util.wrap_attributes( + to_dict=vars(), + from_dict={k: v for k, v in np.__dict__.items() if k == "testing"}, + modules_to_recurse=("testing"), +) -_copy_attributes( - vars(), - jnp.__dict__, - modules_to_recurse=("linalg", "fft"), -) +__all__ = ["BlockArray"] -# enable `import scico.numpy.linalg` and `from scico.numpy.linalg import norm` -sys.modules["scico.numpy.linalg"] = linalg -sys.modules["scico.numpy.fft"] = fft +# clean up +del np, jnp, _util diff --git a/scico/numpy/_util.py b/scico/numpy/_util.py index af38cdb1e..99106462c 100644 --- a/scico/numpy/_util.py +++ b/scico/numpy/_util.py @@ -1,93 +1,108 @@ -# -*- coding: utf-8 -*- -# Copyright (C) 2020-2022 by SCICO Developers -# All rights reserved. BSD 3-clause License. -# This file is part of the SCICO package. Details of the copyright and -# user license can be found in the 'LICENSE' file distributed with the -# package. - -"""Tools to construct wrapped versions of :mod:`jax.numpy` functions.""" +""" +Utilities for wrapping jnp functions to handle BlockArray inputs. +""" -import re -import types +import sys from functools import wraps +from inspect import signature +from types import ModuleType +from typing import Callable, Iterable, Optional -import numpy as np +import jax.numpy as jnp -from jaxlib.xla_extension import CompiledFunction +from scico.array import is_nested -# wrapper for not-implemented jax.numpy functions -# stripped down version of jax._src.lax_numpy._not_implemented and jax.utils._wraps +from .blockarray import BlockArray -_NOT_IMPLEMENTED_DESC = """ -**WARNING**: This function is not yet implemented by :mod:`scico.numpy` and -may raise an error when operating on :class:`scico.blockarray.BlockArray`. -""" +def wrap_attributes( + to_dict: dict, + from_dict: dict, + modules_to_recurse: Optional[Iterable[str]] = None, + reductions: Optional[Iterable[str]] = None, +): + """Add attributes in `from_dict` to `to_dict`. + + Underscore attributes are ignored. Functions are wrapped to allow for + `BlockArray` inputs. Modules are ignored, except those listed in + `modules_to_recurse`, which are added recursively. All others are + passed through unwrapped. -def _not_implemented(fun): - @wraps(fun) - def wrapped(*args, **kwargs): - return fun(*args, **kwargs) - - if not hasattr(fun, "__doc__") or fun.__doc__ is None: - return wrapped - - # wrapped.__doc__ = fun.__doc__ + "\n\n" + _NOT_IMPLEMENTED_DESC - wrapped.__doc__ = re.sub( - r"^\*Original docstring below\.\*", - _NOT_IMPLEMENTED_DESC + r"\n\n" + "*Original docstring below.*", - wrapped.__doc__, - flags=re.M, - ) - return wrapped - - -def _attach_wrapped_func(funclist, wrapper, module_name, fix_mod_name=False): - # funclist is either a function, or a tuple (name-in-this-module, function) - # wrapper is a function that is applied to each function in funclist, with - # the output being assigned as an attribute of the module `module_name` - for func in funclist: - # Test required because func.__name__ isn't always the name we want - # e.g. jnp.abs.__name__ resolves to 'absolute', not 'abs' - if isinstance(func, tuple): - fname = func[0] - fref = func[1] - else: - fname = func.__name__ - fref = func - # Set wrapped function as an attribute in module_name - setattr(module_name, fname, wrapper(fref)) - # Set __module__ attribute of wrapped function to - # module_name.__name__ (i.e., scico.numpy) so that it does not - # appear to autodoc be an imported function - if fix_mod_name: - getattr(module_name, fname).__module__ = module_name.__name__ - - -def _get_module_functions(module): - """Finds functions in module. - - This function is a slightly modified version of - :func:`jax._src.util.get_module_functions`. Unlike the JAX version, - this version will also return any - :class:`jaxlib.xla_extension.CompiledFunction`s that exist in the - module. - - Args: - module: A Python module. - Returns: - module_fns: A dict of names mapped to functions, builtins or - ufuncs in `module`. """ - module_fns = {} - for key in dir(module): - # Omitting module level __getattr__, __dir__ which was added in Python 3.7 - # https://www.python.org/dev/peps/pep-0562/ - if key in ("__getattr__", "__dir__"): + + if modules_to_recurse is None: + modules_to_recurse = () + + if reductions is None: + reductions = () + + for name, obj in from_dict.items(): + if name[0] == "_": continue - attr = getattr(module, key) - if isinstance( - attr, (types.BuiltinFunctionType, types.FunctionType, np.ufunc, CompiledFunction) - ): - module_fns[key] = attr - return module_fns + if isinstance(obj, ModuleType) and name in modules_to_recurse: + to_dict[name] = ModuleType(name) + to_dict[name].__package__ = to_dict["__name__"] + to_dict[name].__doc__ = obj.__doc__ + # enable `import scico.numpy.linalg` and `from scico.numpy.linalg import norm` + sys.modules[to_dict["__name__"] + "." + name] = to_dict[name] + wrap_attributes(to_dict[name].__dict__, obj.__dict__) + elif isinstance(obj, Callable): + obj = map_func_over_ba(obj, is_reduction=name in reductions) + to_dict[name] = obj + else: + to_dict[name] = obj + + +def map_func_over_ba(func, is_reduction=False): + """Create a version of `func` that maps over all of its `BlockArray` + arguments. + + is_reduction: function is handled in a special way in order to allow + full reductions of `BlockArray`s. If the axis parameter exists but + is not specified, the function is mapped over the blocks, then + called again on the stacked result. + """ + + @wraps(func) + def mapped(*args, **kwargs): + sig = signature(func) + bound_args = sig.bind(*args, **kwargs) + + ba_args = {} + for k, v in list(bound_args.arguments.items()): + if isinstance(v, BlockArray) or is_nested(v): + ba_args[k] = bound_args.arguments.pop(k) + + if not len(ba_args): # no BlockArray arguments + return func(*args, **kwargs) + + result = tuple( + map( # map over + lambda *args: ( # lambda x_1, x_2, ..., x_N + func( + *bound_args.args, + **bound_args.kwargs, # ... nonBlockArray args + **dict(zip(ba_args.keys(), args)), + ) # plus dict of block args + ), + *ba_args.values(), # map(f, ba_1, ba_2, ..., ba_N) + ) + ) + + if is_reduction and "axis" not in bound_args.arguments: + if len(ba_args) > 1: + raise ValueError( + "Cannot perform a full reduction with multiple BlockArray arguments." + ) + return func( + *bound_args.args, + **bound_args.kwargs, + **{list(ba_args.keys())[0]: jnp.stack(result)}, + ) + + if isinstance(result[0], jnp.ndarray): # True for abstract arrays, too + return BlockArray(result) + + return result + + return mapped From 2d547270c0a8f7ea6b1fa7a339294fb4d3ddb1c4 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Fri, 8 Apr 2022 14:22:52 -0600 Subject: [PATCH 09/37] Add files --- scico/numpy/_util_old.py | 93 +++++++++++++++++++++++++++++++ scico/numpy/blockarray.py | 9 +++ scico/test/test_blockarray_new.py | 24 ++++++++ 3 files changed, 126 insertions(+) create mode 100644 scico/numpy/_util_old.py create mode 100644 scico/numpy/blockarray.py create mode 100644 scico/test/test_blockarray_new.py diff --git a/scico/numpy/_util_old.py b/scico/numpy/_util_old.py new file mode 100644 index 000000000..af38cdb1e --- /dev/null +++ b/scico/numpy/_util_old.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2020-2022 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. + +"""Tools to construct wrapped versions of :mod:`jax.numpy` functions.""" + +import re +import types +from functools import wraps + +import numpy as np + +from jaxlib.xla_extension import CompiledFunction + +# wrapper for not-implemented jax.numpy functions +# stripped down version of jax._src.lax_numpy._not_implemented and jax.utils._wraps + +_NOT_IMPLEMENTED_DESC = """ +**WARNING**: This function is not yet implemented by :mod:`scico.numpy` and +may raise an error when operating on :class:`scico.blockarray.BlockArray`. +""" + + +def _not_implemented(fun): + @wraps(fun) + def wrapped(*args, **kwargs): + return fun(*args, **kwargs) + + if not hasattr(fun, "__doc__") or fun.__doc__ is None: + return wrapped + + # wrapped.__doc__ = fun.__doc__ + "\n\n" + _NOT_IMPLEMENTED_DESC + wrapped.__doc__ = re.sub( + r"^\*Original docstring below\.\*", + _NOT_IMPLEMENTED_DESC + r"\n\n" + "*Original docstring below.*", + wrapped.__doc__, + flags=re.M, + ) + return wrapped + + +def _attach_wrapped_func(funclist, wrapper, module_name, fix_mod_name=False): + # funclist is either a function, or a tuple (name-in-this-module, function) + # wrapper is a function that is applied to each function in funclist, with + # the output being assigned as an attribute of the module `module_name` + for func in funclist: + # Test required because func.__name__ isn't always the name we want + # e.g. jnp.abs.__name__ resolves to 'absolute', not 'abs' + if isinstance(func, tuple): + fname = func[0] + fref = func[1] + else: + fname = func.__name__ + fref = func + # Set wrapped function as an attribute in module_name + setattr(module_name, fname, wrapper(fref)) + # Set __module__ attribute of wrapped function to + # module_name.__name__ (i.e., scico.numpy) so that it does not + # appear to autodoc be an imported function + if fix_mod_name: + getattr(module_name, fname).__module__ = module_name.__name__ + + +def _get_module_functions(module): + """Finds functions in module. + + This function is a slightly modified version of + :func:`jax._src.util.get_module_functions`. Unlike the JAX version, + this version will also return any + :class:`jaxlib.xla_extension.CompiledFunction`s that exist in the + module. + + Args: + module: A Python module. + Returns: + module_fns: A dict of names mapped to functions, builtins or + ufuncs in `module`. + """ + module_fns = {} + for key in dir(module): + # Omitting module level __getattr__, __dir__ which was added in Python 3.7 + # https://www.python.org/dev/peps/pep-0562/ + if key in ("__getattr__", "__dir__"): + continue + attr = getattr(module, key) + if isinstance( + attr, (types.BuiltinFunctionType, types.FunctionType, np.ufunc, CompiledFunction) + ): + module_fns[key] = attr + return module_fns diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py new file mode 100644 index 000000000..26a0d83e3 --- /dev/null +++ b/scico/numpy/blockarray.py @@ -0,0 +1,9 @@ +""" +BlockArray. +""" + + +class BlockArray(list): + """BlockArray.""" + + ... diff --git a/scico/test/test_blockarray_new.py b/scico/test/test_blockarray_new.py new file mode 100644 index 000000000..d4c9afb91 --- /dev/null +++ b/scico/test/test_blockarray_new.py @@ -0,0 +1,24 @@ +import scico.numpy as snp + +# from scico.random import randn + + +# from scico.blockarray import BlockArray + +x = snp.BlockArray( + ( + snp.ones((3, 4)), + snp.arange(4), + ) +) + +y = snp.BlockArray( + ( + 2 * snp.ones((3, 4)), + snp.arange(4), + ) +) + +snp.sum(x) + +snp.testing.assert_allclose(x, y) From 6fbe32f2164885fa4b56497fc2e75864308bcf47 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Fri, 8 Apr 2022 17:56:59 -0600 Subject: [PATCH 10/37] Move BlockArray to scico.numpy, like DeviceArray in jax.numpy --- examples/scripts/denoise_tv_iso_pgm.py | 2 +- examples/scripts/sparsecode_poisson_pgm.py | 4 +- scico/_flax.py | 2 +- scico/_generic_operators.py | 7 +- scico/array.py | 166 +----- scico/blockarray.py | 609 --------------------- scico/blockarray_old.py | 10 +- scico/functional/_functional.py | 2 +- scico/functional/_indicator.py | 2 +- scico/functional/_norm.py | 4 +- scico/linop/_circconv.py | 2 +- scico/linop/_linop.py | 2 +- scico/linop/_stack.py | 2 +- scico/linop/optics.py | 2 +- scico/loss.py | 4 +- scico/metric.py | 2 +- scico/numpy/__init__.py | 8 +- scico/numpy/_create.py | 3 +- scico/numpy/_util.py | 34 +- scico/numpy/_util_old.py | 2 +- scico/numpy/blockarray.py | 608 +++++++++++++++++++- scico/operator/biconvolve.py | 3 +- scico/optimize/_ladmm.py | 2 +- scico/optimize/_primaldual.py | 2 +- scico/optimize/admm.py | 4 +- scico/optimize/pgm.py | 2 +- scico/random.py | 8 +- scico/scipy/special.py | 8 +- scico/solver.py | 2 +- scico/test/functional/test_core.py | 9 +- scico/test/functional/test_loss.py | 4 +- scico/test/linop/test_diff.py | 2 +- scico/test/linop/test_linop.py | 2 +- scico/test/optimize/test_ladmm.py | 2 +- scico/test/optimize/test_pdhg.py | 2 +- scico/test/test_array.py | 9 +- scico/test/test_biconvolve.py | 2 +- scico/test/test_blockarray.py | 2 +- scico/test/test_blockarray_new.py | 24 - scico/test/test_numpy.py | 2 +- scico/test/test_operator.py | 2 +- scico/test/test_random.py | 4 +- scico/test/test_solver.py | 2 +- 43 files changed, 692 insertions(+), 884 deletions(-) delete mode 100644 scico/blockarray.py delete mode 100644 scico/test/test_blockarray_new.py diff --git a/examples/scripts/denoise_tv_iso_pgm.py b/examples/scripts/denoise_tv_iso_pgm.py index dc0bb39e2..34f29ad2e 100644 --- a/examples/scripts/denoise_tv_iso_pgm.py +++ b/examples/scripts/denoise_tv_iso_pgm.py @@ -40,7 +40,7 @@ import scico.random from scico import functional, linop, loss, operator, plot from scico.array import ensure_on_device -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.optimize.pgm import AcceleratedPGM, RobustLineSearchStepSize from scico.typing import JaxArray from scico.util import device_info diff --git a/examples/scripts/sparsecode_poisson_pgm.py b/examples/scripts/sparsecode_poisson_pgm.py index 7479d01d3..64191671c 100644 --- a/examples/scripts/sparsecode_poisson_pgm.py +++ b/examples/scripts/sparsecode_poisson_pgm.py @@ -22,7 +22,7 @@ $I(\mathbf{x}^{(0)} \geq 0)$ is the non-negative indicator. This example also demonstrates the application of -[blockarray.BlockArray](../_autosummary/scico.blockarray.rst#scico.blockarray.BlockArray), +[blockarray.BlockArray](../_autosummary/scico.numpy.rst#scico.numpy.BlockArray), [functional.SeparableFunctional](../_autosummary/scico.functional.rst#scico.functional.SeparableFunctional), and [functional.ZeroFunctional](../_autosummary/scico.functional.rst#scico.functional.ZeroFunctional) @@ -40,7 +40,7 @@ import scico.numpy as snp import scico.random from scico import functional, loss, plot -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.operator import Operator from scico.optimize.pgm import ( AcceleratedPGM, diff --git a/scico/_flax.py b/scico/_flax.py index 5a2a67a18..bda3facc7 100644 --- a/scico/_flax.py +++ b/scico/_flax.py @@ -17,7 +17,7 @@ from flax.core import Scope # noqa from flax.linen.module import _Sentinel # noqa -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.typing import JaxArray # The imports of Scope and _Sentinel (above) and the definition of Module diff --git a/scico/_generic_operators.py b/scico/_generic_operators.py index e5f484082..139b683f2 100644 --- a/scico/_generic_operators.py +++ b/scico/_generic_operators.py @@ -23,8 +23,7 @@ import scico.numpy as snp from scico._autograd import linear_adjoint -from scico.array import block_sizes, is_complex_dtype, is_nested -from scico.blockarray import BlockArray +from scico.numpy import BlockArray, is_complex_dtype, is_nested, shape_to_size from scico.typing import BlockShape, DType, JaxArray, Shape @@ -152,8 +151,8 @@ def __init__( # Determine the shape of the "vectorized" operator (as an element of ℝ^{n × m} # If the function returns a BlockArray we need to compute the size of each block, # then sum. - self.input_size = int(np.sum(block_sizes(self.input_shape))) - self.output_size = int(np.sum(block_sizes(self.output_shape))) + self.input_size = shape_to_size(self.input_shape) + self.output_size = shape_to_size(self.output_shape) self.shape = (self.output_shape, self.input_shape) self.matrix_shape = (self.output_size, self.input_size) diff --git a/scico/array.py b/scico/array.py index 451fa737f..6cf940b03 100644 --- a/scico/array.py +++ b/scico/array.py @@ -11,7 +11,7 @@ from __future__ import annotations import warnings -from typing import Any, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np @@ -19,9 +19,8 @@ from jax.interpreters.pxla import ShardedDeviceArray from jax.interpreters.xla import DeviceArray -import scico.numpy as snp -from scico.blockarray import BlockArray -from scico.typing import ArrayIndex, Axes, AxisIndex, DType, JaxArray, Shape +from scico.numpy import BlockArray +from scico.typing import ArrayIndex, Axes, AxisIndex, JaxArray, Shape def ensure_on_device( @@ -74,22 +73,6 @@ def ensure_on_device( return arrays -def no_nan_divide( - x: Union[BlockArray, JaxArray], y: Union[BlockArray, JaxArray] -) -> Union[BlockArray, JaxArray]: - """Return `x/y`, with 0 instead of NaN where `y` is 0. - - Args: - x: Numerator. - y: Denominator. - - Returns: - `x / y` with 0 wherever `y == 0`. - """ - - return snp.where(y != 0, snp.divide(x, snp.where(y != 0, y, 1)), 0) - - def parse_axes( axes: Axes, shape: Optional[Shape] = None, default: Optional[List[int]] = None ) -> List[int]: @@ -191,146 +174,3 @@ def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]: continue idx_shape[axis + offset] = slice_length(shape[axis + offset], ax_idx) return tuple(filter(lambda x: x is not None, idx_shape)) - - -def block_sizes(shape: Union[Shape, BlockShape]) -> Axes: - r"""Compute the 'sizes' of block shapes. - - This function computes ``block_sizes(z.shape) == (_.size for _ in z)`` - - - Args: - shape: A shape tuple; possibly containing nested tuples. - - - Examples: - - .. doctest:: - >>> import scico.numpy as snp - - >>> x = BlockArray.ones( ( (4, 4), (2,))) - >>> x.size - 18 - - >>> y = snp.ones((3, 3)) - >>> y.size - 9 - - >>> z = BlockArray.array([x, y]) - >>> block_sizes(z.shape) - (18, 9) - - >>> zz = BlockArray.array([z, z]) - >>> block_sizes(zz.shape) - (27, 27) - """ - - if isinstance(shape, BlockArray): - raise TypeError( - "Expected a `shape` (possibly nested tuple of ints); got :class:`.BlockArray`." - ) - - out = [] - if is_nested(shape): - # shape is nested -> at least one element came from a blockarray - for y in shape: - if is_nested(y): - # recursively calculate the block size until we arrive at - # a tuple (shape of a non-block array) - while is_nested(y): - y = block_sizes(y) - out.append(np.sum(y)) # adjacent block sizes are added together - else: - # this is a tuple; size given by product of elements - out.append(np.prod(y)) - return tuple(out) - - # shape is a non-nested tuple; return the product - return np.prod(shape) - - -def is_nested(x: Any) -> bool: - """Check if input is a list/tuple containing at least one list/tuple. - - Args: - x: Object to be tested. - - Returns: - ``True`` if `x` is a list/tuple of list/tuples, otherwise - ``False``. - - - Example: - >>> is_nested([1, 2, 3]) - False - >>> is_nested([(1,2), (3,)]) - True - >>> is_nested([[1, 2], 3]) - True - - """ - if isinstance(x, (list, tuple)): - return any([isinstance(_, (list, tuple)) for _ in x]) - return False - - -def is_real_dtype(dtype: DType) -> bool: - """Determine whether a dtype is real. - - Args: - dtype: A numpy or scico.numpy dtype (e.g. ``np.float32``, - ``np.complex64``). - - Returns: - ``False`` if the dtype is complex, otherwise ``True``. - """ - return snp.dtype(dtype).kind != "c" - - -def is_complex_dtype(dtype: DType) -> bool: - """Determine whether a dtype is complex. - - Args: - dtype: A numpy or scico.numpy dtype (e.g. ``np.float32``, - ``np.complex64``). - - Returns: - ``True`` if the dtype is complex, otherwise ``False``. - """ - return snp.dtype(dtype).kind == "c" - - -def real_dtype(dtype: DType) -> DType: - """Construct the corresponding real dtype for a given complex dtype. - - Construct the corresponding real dtype for a given complex dtype, - e.g. the real dtype corresponding to ``np.complex64`` is - ``np.float32``. - - Args: - dtype: A complex numpy or scico.numpy dtype (e.g. ``np.complex64``, - ``np.complex128``). - - Returns: - The real dtype corresponding to the input dtype - """ - - return snp.zeros(1, dtype).real.dtype - - -def complex_dtype(dtype: DType) -> DType: - """Construct the corresponding complex dtype for a given real dtype. - - Construct the corresponding complex dtype for a given real dtype, - e.g. the complex dtype corresponding to ``np.float32`` is - ``np.complex64``. - - Args: - dtype: A real numpy or scico.numpy dtype (e.g. ``np.float32``, - ``np.float64``). - - Returns: - The complex dtype corresponding to the input dtype. - """ - - return (snp.zeros(1, dtype) + 1j).dtype diff --git a/scico/blockarray.py b/scico/blockarray.py deleted file mode 100644 index 933ee2524..000000000 --- a/scico/blockarray.py +++ /dev/null @@ -1,609 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (C) 2020-2022 by SCICO Developers -# All rights reserved. BSD 3-clause License. -# This file is part of the SPORCO package. Details of the copyright -# and user license can be found in the 'LICENSE.txt' file distributed -# with the package. - -r"""Extensions of numpy ndarray class. - - .. testsetup:: - - >>> import scico - >>> import scico.numpy as snp - >>> from scico.blockarray import BlockArray - >>> import numpy as np - >>> import jax.numpy - -The class :class:`.BlockArray` provides a way to combine several arrays -of different shapes and/or data types into a single object. -A :class:`.BlockArray` consists of a list of `DeviceArray` -objects, which we refer to as blocks. -:class:`.BlockArray`s differ from tuples in that mathematical operations -on :class:`.BlockArray`s automatically map along the blocks, returning -another :class:`.BlockArray` or tuple as appropriate. For example, - - :: - - >>> x = BlockArray(( - snp.array( - [[1, 3, 7], - [2, 2, 1],] - ), - snp.array( - [2, 4, 8] - ), - )) - >>> x.shape - ((2, 3), (3,)) # tuple - - >>> x + 1 - (DeviceArray([[2, 4, 8], - [3, 3, 2]], dtype=int32), - DeviceArray([3, 5, 9], dtype=int32)) # BlockArray - - -TODO: not specifying axis to get a full reduction -TODO: using a BlockArray for axis or shape arguments - - - - -Motivating Example -================== - -Consider a two dimensional array :math:`\mb{x} \in \mathbb{R}^{n \times m}`. - -We compute the discrete differences of :math:`\mb{x}` in the horizontal -and vertical directions, generating two new arrays: :math:`\mb{x}_h \in -\mathbb{R}^{n \times (m-1)}` and :math:`\mb{x}_v \in \mathbb{R}^{(n-1) -\times m}`. - -As these arrays are of different shapes, we cannot combine them into a -single `ndarray`. Instead, we might vectorize each array and concatenate -the resulting vectors, leading to :math:`\mb{\bar{x}} \in -\mathbb{R}^{n(m-1) + m(n-1)}`, which can be stored as a one-dimensional -`ndarray`. Unfortunately, this makes it hard to access the individual -components :math:`\mb{x}_h` and :math:`\mb{x}_v`. - -Instead, we can form a :class:`.BlockArray`: :math:`\mb{x}_B = -[\mb{x}_h, \mb{x}_v]` - - - :: - - >>> n = 32 - >>> m = 16 - >>> x_h, key = scico.random.randn((n, m-1)) - >>> x_v, _ = scico.random.randn((n-1, m), key=key) - - # Form the blockarray - >>> x_B = BlockArray.array([x_h, x_v]) - - # The blockarray shape is a tuple of tuples - >>> x_B.shape - ((32, 15), (31, 16)) - - # Each block component can be easily accessed - >>> x_B[0].shape - (32, 15) - >>> x_B[1].shape - (31, 16) - - -Constructing a BlockArray -========================= - -Construct from a tuple of arrays (either `ndarray` or `DeviceArray`) --------------------------------------------------------------------- - - .. doctest:: - - >>> from scico.blockarray import BlockArray - >>> import numpy as np - >>> x0, key = scico.random.randn((32, 32)) - >>> x1, _ = scico.random.randn((16,), key=key) - >>> X = BlockArray.array((x0, x1)) - >>> X.shape - ((32, 32), (16,)) - >>> X.size - 1040 - >>> X.num_blocks - 2 - -While :func:`.BlockArray.array` will accept either `ndarray` or -`DeviceArray` as input, the resulting :class:`.BlockArray` will be backed -by a `DeviceArray` memory buffer. - -**Note**: constructing a :class:`.BlockArray` always involves a copy to -a new `DeviceArray` memory buffer. - -**Note**: by default, the resulting :class:`.BlockArray` is cast to -single precision and will have dtype ``float32`` or ``complex64``. - - -Construct from a single vector and tuple of shapes --------------------------------------------------- - - :: - - >>> x_flat, _ = scico.random.randn((1040,)) - >>> shape_tuple = ((32, 32), (16,)) - >>> X = BlockArray.array_from_flattened(x_flat, shape_tuple=shape_tuple) - >>> X.shape - ((32, 32), (16,)) - - - -Operating on a BlockArray -========================= - -.. _blockarray_indexing: - -Indexing --------- - -The block index is required to be an integer, selecting a single block and -returning it as an array (*not* a singleton BlockArray). If the index -expression has more than one component, then the initial index indexes the -block, and the remainder of the indexing expression indexes within the -selected block, e.g. `x[2, 3:4]` is equivalent to `y[3:4]` after -setting `y = x[2]`. - - -Indexed Updating ----------------- - -BlockArrays support the JAX DeviceArray `indexed update syntax -`_ - - -The index must be of the form [ibk] or [ibk, idx], where `ibk` is the -index of the block to be updated, and `idx` is a general index of the -elements to be updated in that block. In particular, `ibk` cannot be a -`slice`. The general index `idx` can be omitted, in which case an entire -block is updated. - - -============================== ============================================== -Alternate syntax Equivalent in-place expression -============================== ============================================== -`x.at[ibk, idx].set(y)` `x[ibk, idx] = y` -`x.at[ibk, idx].add(y)` `x[ibk, idx] += y` -`x.at[ibk, idx].multiply(y)` `x[ibk, idx] *= y` -`x.at[ibk, idx].divide(y)` `x[ibk, idx] /= y` -`x.at[ibk, idx].power(y)` `x[ibk, idx] **= y` -`x.at[ibk, idx].min(y)` `x[ibk, idx] = np.minimum(x[idx], y)` -`x.at[ibk, idx].max(y)` `x[ibk, idx] = np.maximum(x[idx], y)` -============================== ============================================== - - -Arithmetic and Broadcasting ---------------------------- - -Suppose :math:`\mb{x}` is a BlockArray with shape :math:`((n, n), (m,))`. - - :: - - >>> x1, key = scico.random.randn((4, 4)) - >>> x2, _ = scico.random.randn((5,), key=key) - >>> x = BlockArray.array( (x1, x2) ) - >>> x.shape - ((4, 4), (5,)) - >>> x.num_blocks - 2 - >>> x.size # 4*4 + 5 - 21 - -Illustrated for the operation `+`, but equally valid for operators -`+, -, *, /, //, **, <, <=, >, >=, ==` - - -Operations with BlockArrays with same number of blocks -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Let :math:`\mb{y}` be a BlockArray with the same number of blocks as -:math:`\mb{x}`. - - .. math:: - \mb{x} + \mb{y} - = - \begin{bmatrix} - \mb{x}[0] + \mb{y}[0] \\ - \mb{x}[1] + \mb{y}[1] \\ - \end{bmatrix} - -This operation depends on pair of blocks from :math:`\mb{x}` and -:math:`\mb{y}` being broadcastable against each other. - - - -Operations with a scalar -^^^^^^^^^^^^^^^^^^^^^^^^ - -The scalar is added to each element of the :class:`.BlockArray`: - - .. math:: - \mb{x} + 1 - = - \begin{bmatrix} - \mb{x}[0] + 1 \\ - \mb{x}[1] + 1\\ - \end{bmatrix} - - - :: - - >>> y = x + 1 - >>> np.testing.assert_allclose(y[0], x[0] + 1) - >>> np.testing.assert_allclose(y[1], x[1] + 1) - - - -Operations with a 1D `ndarray` of size equal to `num_blocks` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The *i*\th scalar is added to the *i*\th block of the -:class:`.BlockArray`: - - .. math:: - \mb{x} - + - \begin{bmatrix} - 1 \\ - 2 - \end{bmatrix} - = - \begin{bmatrix} - \mb{x}[0] + 1 \\ - \mb{x}[1] + 2\\ - \end{bmatrix} - - - :: - - >>> y = x + np.array([1, 2]) - >>> np.testing.assert_allclose(y[0], x[0] + 1) - >>> np.testing.assert_allclose(y[1], x[1] + 2) - - -Operations with an ndarray of `size` equal to :class:`.BlockArray` size -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -We first cast the `ndarray` to a BlockArray with same shape as -:math:`\mb{x}`, then apply the operation on the resulting BlockArrays. -With `y.size = x.size`, we have: - - .. math:: - \mb{x} - + - \mb{y} - = - \begin{bmatrix} - \mb{x}[0] + \mb{y}[0] \\ - \mb{x}[1] + \mb{y}[1]\\ - \end{bmatrix} - -Equivalently, the BlockArray is first flattened, then added to the -flattened `ndarray`, and the result is reformed into a BlockArray with -the same shape as :math:`\mb{x}` - - - -MatMul ------- - -Between two BlockArrays -^^^^^^^^^^^^^^^^^^^^^^^ - -The matmul is computed between each block of the two BlockArrays. - -The BlockArrays must have the same number of blocks, and each pair of -blocks must be broadcastable. - - .. math:: - \mb{x} @ \mb{y} - = - \begin{bmatrix} - \mb{x}[0] @ \mb{y}[0] \\ - \mb{x}[1] @ \mb{y}[1]\\ - \end{bmatrix} - - - -Between BlockArray and Ndarray/DeviceArray -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -This operation is not defined. - - -Between BlockArray and :class:`.LinearOperator` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The :class:`.Operator` and :class:`.LinearOperator` classes are designed -to work on :class:`.BlockArray`\ s in addition to `DeviceArray`\ s. -For example - - - :: - - >>> x, key = scico.random.randn((3, 4)) - >>> A_1 = scico.linop.Identity(x.shape) - >>> A_1.shape # array -> array - ((3, 4), (3, 4)) - - >>> A_2 = scico.linop.FiniteDifference(x.shape) - >>> A_2.shape # array -> BlockArray - (((2, 4), (3, 3)), (3, 4)) - - >>> diag = BlockArray.array([np.array(1.0), np.array(2.0)]) - >>> A_3 = scico.linop.Diagonal(diag, input_shape=(A_2.output_shape)) - >>> A_3.shape # BlockArray -> BlockArray - (((2, 4), (3, 3)), ((2, 4), (3, 3))) - - -NumPy ufuncs ------------- - -`NumPy universal functions (ufuncs) `_ -are functions that operate on an `ndarray` on an element-by-element -fashion and support array broadcasting. Examples of ufuncs are `abs`, -`sign`, `conj`, and `exp`. - -The JAX library implements most NumPy ufuncs in the :mod:`jax.numpy` -module. However, as JAX does not support subclassing of `DeviceArray`, -the JAX ufuncs cannot be used on :class:`.BlockArray`. As a workaround, -we have wrapped several JAX ufuncs for use on :class:`.BlockArray`; these -are defined in the :mod:`scico.numpy` module. - - -Reductions -^^^^^^^^^^ - -Reductions are functions that take an array-like as an input and return -an array of lower dimension. Examples include `mean`, `sum`, `norm`. -BlockArray reductions are located in the :mod:`scico.numpy` module - -:class:`.BlockArray` tries to mirror `ndarray` reduction semantics where -possible, but cannot provide a one-to-one match as the block components -may be of different size. - -Consider the example BlockArray - - .. math:: - \mb{x} = \begin{bmatrix} - \begin{bmatrix} - 1 & 1 \\ - 1 & 1 - \end{bmatrix} \\ - \begin{bmatrix} - 2 \\ - 2 - \end{bmatrix} - \end{bmatrix}. - -We have - - .. doctest:: - - >>> import scico.numpy as snp - >>> x = BlockArray.array((np.ones((2,2)), 2*np.ones((2)))) - >>> x.shape - ((2, 2), (2,)) - >>> x.size - 6 - >>> x.num_blocks - 2 - - - - If no axis is specified, the reduction is applied to the flattened - array: - - .. doctest:: - - >>> snp.sum(x, axis=None).item() - 8.0 - - Reducing along the 0-th axis crushes the `BlockArray` down into a - single `DeviceArray` and requires all blocks to have the same shape - otherwise, an error is raised. - - .. doctest:: - - >>> snp.sum(x, axis=0) - Traceback (most recent call last): - ValueError: Evaluating sum of BlockArray along axis=0 requires all blocks to be same shape; got ((2, 2), (2,)) - - >>> y = BlockArray.array((np.ones((2,2)), 2*np.ones((2, 2)))) - >>> snp.sum(y, axis=0) - DeviceArray([[3., 3.], - [3., 3.]], dtype=float32) - - Reducing along axis :math:`n` is equivalent to reducing each component - along axis :math:`n-1`: - - .. math:: - \text{sum}(x, axis=1) = \begin{bmatrix} - \begin{bmatrix} - 2 \\ - 2 - \end{bmatrix} \\ - \begin{bmatrix} - 4 \\ - \end{bmatrix} - \end{bmatrix} - - - If a component does not have axis :math:`n-1`, the reduction is not - applied to that component. In this example, `x[1].ndim == 1`, so no - reduction is applied to block `x[1]`. - - .. math:: - \text{sum}(x, axis=2) = \begin{bmatrix} - \begin{bmatrix} - 2 \\ - 2 - \end{bmatrix} \\ - \begin{bmatrix} - 2 \\ - 2 - \end{bmatrix} - \end{bmatrix} - - -Code version - - .. doctest:: - - >>> snp.sum(x, axis=1) # doctest: +SKIP - BlockArray([[2, 2], - [4,] ]) - - >>> snp.sum(x, axis=2) # doctest: +SKIP - BlockArray([ [2, 2], - [2,] ]) - - -""" -import inspect -from functools import wraps -from typing import Callable - -import jax -import jax.numpy as jnp - -from jaxlib.xla_extension import DeviceArray - - -class BlockArray(list): - """BlockArray""" - - # Ensure we use BlockArray.__radd__, __rmul__, etc for binary - # operations of the form op(np.ndarray, BlockArray) See - # https://docs.scipy.org/doc/numpy-1.10.1/user/c-info.beyond-basics.html#ndarray.__array_priority__ - __array_priority__ = 1 - - def full_ravel(self) -> DeviceArray: - """Return a copy of ``self._data`` as a contiguous, flattened `DeviceArray`. - - Note that a copy, rather than a view, of the underlying array is - returned. This is consistent with :func:`jax.numpy.ravel`. - - Returns: - Copy of underlying flattened array. - - """ - return jnp.concatenate(tuple(x_i.ravel() for x_i in self)) - - """ backwards compatibility methods, could be removed """ - - @staticmethod - def array(iterable): - """Construct a :class:`.BlockArray` from a list or tuple of existing array-like.""" - return BlockArray(iterable) - - -# register BlockArray as a jax pytree, without this, jax autograd won't work -# taken from what is done with tuples in jax._src.tree_util -jax.tree_util.register_pytree_node( - BlockArray, - lambda xs: (xs, None), # to iter - lambda _, xs: BlockArray(xs), # from iter -) - -""" Wrap binary ops like +, @. """ -binary_ops = ( - "__add__", - "__radd__", - "__sub__", - "__rsub__", - "__mul__", - "__rmul__", - "__matmul__", - "__rmatmul__", - "__truediv__", - "__rtruediv__", - "__floordiv__", - "__rfloordiv__", - "__pow__", - "__rpow__", - "__gt__", - "__ge__", - "__lt__", - "__le__", - "__eq__", - "__ne__", -) - - -def _binary_op_wrapper(op): - @wraps(op) - def op_ba(self, other): - if isinstance(other, BlockArray): - result = BlockArray(getattr(x, op)(y) for x, y in zip(self, other)) - else: - result = BlockArray(getattr(x, op)(other) for x in self) - if NotImplemented in result: - return NotImplemented - else: - return result - - return op_ba - - -for op in binary_ops: - setattr(BlockArray, op, _binary_op_wrapper(op)) - - -""" Wrap DeviceArray properties. """ - - -def _da_prop_wrapper(prop): - @property - @wraps(prop) - def prop_ba(self): - result = tuple(getattr(x, prop) for x in self) - if isinstance(result[0], jnp.ndarray): - return BlockArray(result) - return result - - return prop_ba - - -skip_props = ("at",) -da_props = [ - k - for k, v in dict(inspect.getmembers(DeviceArray)).items() - if isinstance(v, property) and k[0] != "_" and k not in dir(BlockArray) and k not in skip_props -] - -for prop in da_props: - setattr(BlockArray, prop, _da_prop_wrapper(prop)) - -""" Wrap DeviceArray methods. """ - - -def _da_method_wrapper(method): - @wraps(method) - def method_ba(self, *args, **kwargs): - result = tuple(getattr(x, method)(*args, **kwargs) for x in self) - if isinstance(result[0], jnp.ndarray): - return BlockArray(result) - return result - - return method_ba - - -skip_methods = () -da_methods = [ - k - for k, v in dict(inspect.getmembers(DeviceArray)).items() - if isinstance(v, Callable) - and k[0] != "_" - and k not in dir(BlockArray) - and k not in skip_methods -] - -for method in da_methods: - setattr(BlockArray, method, _da_method_wrapper(method)) diff --git a/scico/blockarray_old.py b/scico/blockarray_old.py index ba71073bc..21772f138 100644 --- a/scico/blockarray_old.py +++ b/scico/blockarray_old.py @@ -11,7 +11,7 @@ >>> import scico >>> import scico.numpy as snp - >>> from scico.blockarray import BlockArray + >>> from scico.numpy import BlockArray >>> import numpy as np >>> import jax.numpy @@ -81,7 +81,7 @@ .. doctest:: - >>> from scico.blockarray import BlockArray + >>> from scico.numpy import BlockArray >>> import numpy as np >>> x0, key = scico.random.randn((32, 32)) >>> x1, _ = scico.random.randn((16,), key=key) @@ -805,7 +805,7 @@ def __init__(self, aval: _AbstractBlockArray, data: JaxArray): self._data = data def __repr__(self): - return "scico.blockarray.BlockArray: \n" + self._data.__repr__() + return "scico.numpy.BlockArray: \n" + self._data.__repr__() def __getitem__(self, idx: Union[int, Tuple[AxisIndex, ...]]) -> JaxArray: idxblk, idxarr = _decompose_index(idx) @@ -866,11 +866,11 @@ def __rfloordiv__(a, b): @_block_array_binary_op_wrapper def __pow__(a, b): - return a ** b + return a**b @_block_array_binary_op_wrapper def __rpow__(a, b): - return b ** a + return b**a @_block_array_binary_op_wrapper def __gt__(a, b): diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py index 408946805..ab94244a5 100644 --- a/scico/functional/_functional.py +++ b/scico/functional/_functional.py @@ -13,7 +13,7 @@ import scico from scico import numpy as snp -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.typing import JaxArray diff --git a/scico/functional/_indicator.py b/scico/functional/_indicator.py index 0fcd98330..21ab0f69b 100644 --- a/scico/functional/_indicator.py +++ b/scico/functional/_indicator.py @@ -12,7 +12,7 @@ import jax from scico import numpy as snp -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.numpy.linalg import norm from scico.typing import JaxArray diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 20dc1eaf7..ba77a4d3e 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -12,9 +12,7 @@ from jax import jit, lax from scico import numpy as snp -from scico.array import no_nan_divide -from scico.blockarray import BlockArray -from scico.numpy import count_nonzero +from scico.numpy import BlockArray, count_nonzero, no_nan_divide from scico.numpy.linalg import norm from scico.typing import JaxArray diff --git a/scico/linop/_circconv.py b/scico/linop/_circconv.py index 9520ec9a9..f21219dec 100644 --- a/scico/linop/_circconv.py +++ b/scico/linop/_circconv.py @@ -18,7 +18,7 @@ import scico.numpy as snp from scico._generic_operators import Operator -from scico.array import is_nested +from scico.numpy import is_nested from scico.typing import DType, JaxArray, Shape from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar diff --git a/scico/linop/_linop.py b/scico/linop/_linop.py index dc03b549c..d2f95c946 100644 --- a/scico/linop/_linop.py +++ b/scico/linop/_linop.py @@ -18,7 +18,7 @@ import scico.numpy as snp from scico import array, blockarray from scico._generic_operators import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.random import randn from scico.typing import ArrayIndex, BlockShape, DType, JaxArray, PRNGKey, Shape diff --git a/scico/linop/_stack.py b/scico/linop/_stack.py index 8b09642ca..3cbd243ab 100644 --- a/scico/linop/_stack.py +++ b/scico/linop/_stack.py @@ -16,7 +16,7 @@ import numpy as np import scico.numpy as snp -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.typing import JaxArray from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar diff --git a/scico/linop/optics.py b/scico/linop/optics.py index 256e93311..67cf100fd 100644 --- a/scico/linop/optics.py +++ b/scico/linop/optics.py @@ -60,8 +60,8 @@ import jax import scico.numpy as snp -from scico.array import no_nan_divide from scico.linop import Diagonal, Identity, LinearOperator +from scico.numpy import no_nan_divide from scico.typing import Shape from ._dft import DFT diff --git a/scico/loss.py b/scico/loss.py index 2c4e98c40..5d001ea08 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -17,8 +17,8 @@ import scico.numpy as snp from scico import functional, linop, operator -from scico.array import ensure_on_device, no_nan_divide -from scico.blockarray import BlockArray +from scico.array import ensure_on_device +from scico.numpy import BlockArray, no_nan_divide from scico.scipy.special import gammaln from scico.solver import cg from scico.typing import JaxArray diff --git a/scico/metric.py b/scico/metric.py index 94bbe265b..6b1b28665 100644 --- a/scico/metric.py +++ b/scico/metric.py @@ -14,7 +14,7 @@ import numpy as np import scico.numpy as snp -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.typing import JaxArray diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index b974072ba..042dc6e12 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -5,11 +5,11 @@ # user license can be found in the 'LICENSE' file distributed with the # package. -""":class:`scico.blockarray.BlockArray`-compatible +""":class:`scico.numpy.BlockArray`-compatible versions of :mod:`jax.numpy` functions. This modules consists of functions from :mod:`jax.numpy` wrapped to -support compatibility with :class:`scico.blockarray.BlockArray`. This +support compatibility with :class:`scico.numpy.BlockArray`. This module is a work in progress and therefore not all functions are wrapped. Functions that have not been wrapped yet have WARNING text in their documentation, below. @@ -20,13 +20,15 @@ from . import _util from .blockarray import BlockArray +from .util import * # wrap jnp _util.wrap_attributes( to_dict=vars(), from_dict=jnp.__dict__, modules_to_recurse=("linalg", "fft"), - reductions=("sum",), + reductions=("sum", "norm"), + no_wrap=("dtype"), ) # wrap np.testing diff --git a/scico/numpy/_create.py b/scico/numpy/_create.py index 8478d71d6..9f51f14f3 100644 --- a/scico/numpy/_create.py +++ b/scico/numpy/_create.py @@ -14,8 +14,7 @@ import jax from jax import numpy as jnp -from scico.array import is_nested -from scico.blockarray import BlockArray +from scico.numpy import BlockArray, is_nested from scico.typing import BlockShape, DType, JaxArray, Shape diff --git a/scico/numpy/_util.py b/scico/numpy/_util.py index 99106462c..c61ce4d4f 100644 --- a/scico/numpy/_util.py +++ b/scico/numpy/_util.py @@ -10,9 +10,8 @@ import jax.numpy as jnp -from scico.array import is_nested - from .blockarray import BlockArray +from .util import is_nested def wrap_attributes( @@ -20,6 +19,7 @@ def wrap_attributes( from_dict: dict, modules_to_recurse: Optional[Iterable[str]] = None, reductions: Optional[Iterable[str]] = None, + no_wrap: Optional[Iterable[str]] = None, ): """Add attributes in `from_dict` to `to_dict`. @@ -28,6 +28,8 @@ def wrap_attributes( `modules_to_recurse`, which are added recursively. All others are passed through unwrapped. + no_warp: list of functions to attach unwrapped + """ if modules_to_recurse is None: @@ -36,6 +38,9 @@ def wrap_attributes( if reductions is None: reductions = () + if no_wrap is None: + no_wrap = () + for name, obj in from_dict.items(): if name[0] == "_": continue @@ -45,8 +50,10 @@ def wrap_attributes( to_dict[name].__doc__ = obj.__doc__ # enable `import scico.numpy.linalg` and `from scico.numpy.linalg import norm` sys.modules[to_dict["__name__"] + "." + name] = to_dict[name] - wrap_attributes(to_dict[name].__dict__, obj.__dict__) - elif isinstance(obj, Callable): + wrap_attributes( + to_dict[name].__dict__, obj.__dict__, reductions=reductions, no_wrap=no_wrap + ) + elif isinstance(obj, Callable) and name not in no_wrap: obj = map_func_over_ba(obj, is_reduction=name in reductions) to_dict[name] = obj else: @@ -76,6 +83,14 @@ def mapped(*args, **kwargs): if not len(ba_args): # no BlockArray arguments return func(*args, **kwargs) + if is_reduction and "axis" not in bound_args.arguments: + if len(ba_args) > 1: + raise ValueError( + "Cannot perform a full reduction with multiple BlockArray arguments." + ) + # fully ravel the ba argument + ba_args = {k: [jnp.concatenate(v.ravel())] for k, v in ba_args.items()} + result = tuple( map( # map over lambda *args: ( # lambda x_1, x_2, ..., x_N @@ -89,17 +104,6 @@ def mapped(*args, **kwargs): ) ) - if is_reduction and "axis" not in bound_args.arguments: - if len(ba_args) > 1: - raise ValueError( - "Cannot perform a full reduction with multiple BlockArray arguments." - ) - return func( - *bound_args.args, - **bound_args.kwargs, - **{list(ba_args.keys())[0]: jnp.stack(result)}, - ) - if isinstance(result[0], jnp.ndarray): # True for abstract arrays, too return BlockArray(result) diff --git a/scico/numpy/_util_old.py b/scico/numpy/_util_old.py index af38cdb1e..6e81ccd10 100644 --- a/scico/numpy/_util_old.py +++ b/scico/numpy/_util_old.py @@ -20,7 +20,7 @@ _NOT_IMPLEMENTED_DESC = """ **WARNING**: This function is not yet implemented by :mod:`scico.numpy` and -may raise an error when operating on :class:`scico.blockarray.BlockArray`. +may raise an error when operating on :class:`scico.numpy.BlockArray`. """ diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py index 26a0d83e3..df6c4b619 100644 --- a/scico/numpy/blockarray.py +++ b/scico/numpy/blockarray.py @@ -1,9 +1,609 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2020-2022 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SPORCO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +r"""Block array class. + + .. testsetup:: + + >>> import scico + >>> import scico.numpy as snp + >>> from scico.numpy import BlockArray + >>> import numpy as np + >>> import jax.numpy + +The class :class:`.BlockArray` provides a way to combine several arrays +of different shapes and/or data types into a single object. +A :class:`.BlockArray` consists of a list of `DeviceArray` +objects, which we refer to as blocks. +:class:`.BlockArray`s differ from tuples in that mathematical operations +on :class:`.BlockArray`s automatically map along the blocks, returning +another :class:`.BlockArray` or tuple as appropriate. For example, + + :: + + >>> x = BlockArray(( + snp.array( + [[1, 3, 7], + [2, 2, 1],] + ), + snp.array( + [2, 4, 8] + ), + )) + >>> x.shape + ((2, 3), (3,)) # tuple + + >>> x + 1 + (DeviceArray([[2, 4, 8], + [3, 3, 2]], dtype=int32), + DeviceArray([3, 5, 9], dtype=int32)) # BlockArray + + +TODO: not specifying axis to get a full reduction +TODO: using a BlockArray for axis or shape arguments + + + + +Motivating Example +================== + +Consider a two dimensional array :math:`\mb{x} \in \mathbb{R}^{n \times m}`. + +We compute the discrete differences of :math:`\mb{x}` in the horizontal +and vertical directions, generating two new arrays: :math:`\mb{x}_h \in +\mathbb{R}^{n \times (m-1)}` and :math:`\mb{x}_v \in \mathbb{R}^{(n-1) +\times m}`. + +As these arrays are of different shapes, we cannot combine them into a +single `ndarray`. Instead, we might vectorize each array and concatenate +the resulting vectors, leading to :math:`\mb{\bar{x}} \in +\mathbb{R}^{n(m-1) + m(n-1)}`, which can be stored as a one-dimensional +`ndarray`. Unfortunately, this makes it hard to access the individual +components :math:`\mb{x}_h` and :math:`\mb{x}_v`. + +Instead, we can form a :class:`.BlockArray`: :math:`\mb{x}_B = +[\mb{x}_h, \mb{x}_v]` + + + :: + + >>> n = 32 + >>> m = 16 + >>> x_h, key = scico.random.randn((n, m-1)) + >>> x_v, _ = scico.random.randn((n-1, m), key=key) + + # Form the blockarray + >>> x_B = BlockArray.array([x_h, x_v]) + + # The blockarray shape is a tuple of tuples + >>> x_B.shape + ((32, 15), (31, 16)) + + # Each block component can be easily accessed + >>> x_B[0].shape + (32, 15) + >>> x_B[1].shape + (31, 16) + + +Constructing a BlockArray +========================= + +Construct from a tuple of arrays (either `ndarray` or `DeviceArray`) +-------------------------------------------------------------------- + + .. doctest:: + + >>> from scico.numpy import BlockArray + >>> import numpy as np + >>> x0, key = scico.random.randn((32, 32)) + >>> x1, _ = scico.random.randn((16,), key=key) + >>> X = BlockArray.array((x0, x1)) + >>> X.shape + ((32, 32), (16,)) + >>> X.size + 1040 + >>> X.num_blocks + 2 + +While :func:`.BlockArray.array` will accept either `ndarray` or +`DeviceArray` as input, the resulting :class:`.BlockArray` will be backed +by a `DeviceArray` memory buffer. + +**Note**: constructing a :class:`.BlockArray` always involves a copy to +a new `DeviceArray` memory buffer. + +**Note**: by default, the resulting :class:`.BlockArray` is cast to +single precision and will have dtype ``float32`` or ``complex64``. + + +Construct from a single vector and tuple of shapes +-------------------------------------------------- + + :: + + >>> x_flat, _ = scico.random.randn((1040,)) + >>> shape_tuple = ((32, 32), (16,)) + >>> X = BlockArray.array_from_flattened(x_flat, shape_tuple=shape_tuple) + >>> X.shape + ((32, 32), (16,)) + + + +Operating on a BlockArray +========================= + +.. _blockarray_indexing: + +Indexing +-------- + +The block index is required to be an integer, selecting a single block and +returning it as an array (*not* a singleton BlockArray). If the index +expression has more than one component, then the initial index indexes the +block, and the remainder of the indexing expression indexes within the +selected block, e.g. `x[2, 3:4]` is equivalent to `y[3:4]` after +setting `y = x[2]`. + + +Indexed Updating +---------------- + +BlockArrays support the JAX DeviceArray `indexed update syntax +`_ + + +The index must be of the form [ibk] or [ibk, idx], where `ibk` is the +index of the block to be updated, and `idx` is a general index of the +elements to be updated in that block. In particular, `ibk` cannot be a +`slice`. The general index `idx` can be omitted, in which case an entire +block is updated. + + +============================== ============================================== +Alternate syntax Equivalent in-place expression +============================== ============================================== +`x.at[ibk, idx].set(y)` `x[ibk, idx] = y` +`x.at[ibk, idx].add(y)` `x[ibk, idx] += y` +`x.at[ibk, idx].multiply(y)` `x[ibk, idx] *= y` +`x.at[ibk, idx].divide(y)` `x[ibk, idx] /= y` +`x.at[ibk, idx].power(y)` `x[ibk, idx] **= y` +`x.at[ibk, idx].min(y)` `x[ibk, idx] = np.minimum(x[idx], y)` +`x.at[ibk, idx].max(y)` `x[ibk, idx] = np.maximum(x[idx], y)` +============================== ============================================== + + +Arithmetic and Broadcasting +--------------------------- + +Suppose :math:`\mb{x}` is a BlockArray with shape :math:`((n, n), (m,))`. + + :: + + >>> x1, key = scico.random.randn((4, 4)) + >>> x2, _ = scico.random.randn((5,), key=key) + >>> x = BlockArray.array( (x1, x2) ) + >>> x.shape + ((4, 4), (5,)) + >>> x.num_blocks + 2 + >>> x.size # 4*4 + 5 + 21 + +Illustrated for the operation `+`, but equally valid for operators +`+, -, *, /, //, **, <, <=, >, >=, ==` + + +Operations with BlockArrays with same number of blocks +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Let :math:`\mb{y}` be a BlockArray with the same number of blocks as +:math:`\mb{x}`. + + .. math:: + \mb{x} + \mb{y} + = + \begin{bmatrix} + \mb{x}[0] + \mb{y}[0] \\ + \mb{x}[1] + \mb{y}[1] \\ + \end{bmatrix} + +This operation depends on pair of blocks from :math:`\mb{x}` and +:math:`\mb{y}` being broadcastable against each other. + + + +Operations with a scalar +^^^^^^^^^^^^^^^^^^^^^^^^ + +The scalar is added to each element of the :class:`.BlockArray`: + + .. math:: + \mb{x} + 1 + = + \begin{bmatrix} + \mb{x}[0] + 1 \\ + \mb{x}[1] + 1\\ + \end{bmatrix} + + + :: + + >>> y = x + 1 + >>> np.testing.assert_allclose(y[0], x[0] + 1) + >>> np.testing.assert_allclose(y[1], x[1] + 1) + + + +Operations with a 1D `ndarray` of size equal to `num_blocks` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The *i*\th scalar is added to the *i*\th block of the +:class:`.BlockArray`: + + .. math:: + \mb{x} + + + \begin{bmatrix} + 1 \\ + 2 + \end{bmatrix} + = + \begin{bmatrix} + \mb{x}[0] + 1 \\ + \mb{x}[1] + 2\\ + \end{bmatrix} + + + :: + + >>> y = x + np.array([1, 2]) + >>> np.testing.assert_allclose(y[0], x[0] + 1) + >>> np.testing.assert_allclose(y[1], x[1] + 2) + + +Operations with an ndarray of `size` equal to :class:`.BlockArray` size +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +We first cast the `ndarray` to a BlockArray with same shape as +:math:`\mb{x}`, then apply the operation on the resulting BlockArrays. +With `y.size = x.size`, we have: + + .. math:: + \mb{x} + + + \mb{y} + = + \begin{bmatrix} + \mb{x}[0] + \mb{y}[0] \\ + \mb{x}[1] + \mb{y}[1]\\ + \end{bmatrix} + +Equivalently, the BlockArray is first flattened, then added to the +flattened `ndarray`, and the result is reformed into a BlockArray with +the same shape as :math:`\mb{x}` + + + +MatMul +------ + +Between two BlockArrays +^^^^^^^^^^^^^^^^^^^^^^^ + +The matmul is computed between each block of the two BlockArrays. + +The BlockArrays must have the same number of blocks, and each pair of +blocks must be broadcastable. + + .. math:: + \mb{x} @ \mb{y} + = + \begin{bmatrix} + \mb{x}[0] @ \mb{y}[0] \\ + \mb{x}[1] @ \mb{y}[1]\\ + \end{bmatrix} + + + +Between BlockArray and Ndarray/DeviceArray +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This operation is not defined. + + +Between BlockArray and :class:`.LinearOperator` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :class:`.Operator` and :class:`.LinearOperator` classes are designed +to work on :class:`.BlockArray`\ s in addition to `DeviceArray`\ s. +For example + + + :: + + >>> x, key = scico.random.randn((3, 4)) + >>> A_1 = scico.linop.Identity(x.shape) + >>> A_1.shape # array -> array + ((3, 4), (3, 4)) + + >>> A_2 = scico.linop.FiniteDifference(x.shape) + >>> A_2.shape # array -> BlockArray + (((2, 4), (3, 3)), (3, 4)) + + >>> diag = BlockArray.array([np.array(1.0), np.array(2.0)]) + >>> A_3 = scico.linop.Diagonal(diag, input_shape=(A_2.output_shape)) + >>> A_3.shape # BlockArray -> BlockArray + (((2, 4), (3, 3)), ((2, 4), (3, 3))) + + +NumPy ufuncs +------------ + +`NumPy universal functions (ufuncs) `_ +are functions that operate on an `ndarray` on an element-by-element +fashion and support array broadcasting. Examples of ufuncs are `abs`, +`sign`, `conj`, and `exp`. + +The JAX library implements most NumPy ufuncs in the :mod:`jax.numpy` +module. However, as JAX does not support subclassing of `DeviceArray`, +the JAX ufuncs cannot be used on :class:`.BlockArray`. As a workaround, +we have wrapped several JAX ufuncs for use on :class:`.BlockArray`; these +are defined in the :mod:`scico.numpy` module. + + +Reductions +^^^^^^^^^^ + +Reductions are functions that take an array-like as an input and return +an array of lower dimension. Examples include `mean`, `sum`, `norm`. +BlockArray reductions are located in the :mod:`scico.numpy` module + +:class:`.BlockArray` tries to mirror `ndarray` reduction semantics where +possible, but cannot provide a one-to-one match as the block components +may be of different size. + +Consider the example BlockArray + + .. math:: + \mb{x} = \begin{bmatrix} + \begin{bmatrix} + 1 & 1 \\ + 1 & 1 + \end{bmatrix} \\ + \begin{bmatrix} + 2 \\ + 2 + \end{bmatrix} + \end{bmatrix}. + +We have + + .. doctest:: + + >>> import scico.numpy as snp + >>> x = BlockArray.array((np.ones((2,2)), 2*np.ones((2)))) + >>> x.shape + ((2, 2), (2,)) + >>> x.size + 6 + >>> x.num_blocks + 2 + + + + If no axis is specified, the reduction is applied to the flattened + array: + + .. doctest:: + + >>> snp.sum(x, axis=None).item() + 8.0 + + Reducing along the 0-th axis crushes the `BlockArray` down into a + single `DeviceArray` and requires all blocks to have the same shape + otherwise, an error is raised. + + .. doctest:: + + >>> snp.sum(x, axis=0) + Traceback (most recent call last): + ValueError: Evaluating sum of BlockArray along axis=0 requires all blocks to be same shape; got ((2, 2), (2,)) + + >>> y = BlockArray.array((np.ones((2,2)), 2*np.ones((2, 2)))) + >>> snp.sum(y, axis=0) + DeviceArray([[3., 3.], + [3., 3.]], dtype=float32) + + Reducing along axis :math:`n` is equivalent to reducing each component + along axis :math:`n-1`: + + .. math:: + \text{sum}(x, axis=1) = \begin{bmatrix} + \begin{bmatrix} + 2 \\ + 2 + \end{bmatrix} \\ + \begin{bmatrix} + 4 \\ + \end{bmatrix} + \end{bmatrix} + + + If a component does not have axis :math:`n-1`, the reduction is not + applied to that component. In this example, `x[1].ndim == 1`, so no + reduction is applied to block `x[1]`. + + .. math:: + \text{sum}(x, axis=2) = \begin{bmatrix} + \begin{bmatrix} + 2 \\ + 2 + \end{bmatrix} \\ + \begin{bmatrix} + 2 \\ + 2 + \end{bmatrix} + \end{bmatrix} + + +Code version + + .. doctest:: + + >>> snp.sum(x, axis=1) # doctest: +SKIP + BlockArray([[2, 2], + [4,] ]) + + >>> snp.sum(x, axis=2) # doctest: +SKIP + BlockArray([ [2, 2], + [2,] ]) + + """ -BlockArray. -""" + +import inspect +from functools import wraps +from typing import Callable + +import jax +import jax.numpy as jnp + +from jaxlib.xla_extension import DeviceArray class BlockArray(list): - """BlockArray.""" + """BlockArray""" + + # Ensure we use BlockArray.__radd__, __rmul__, etc for binary + # operations of the form op(np.ndarray, BlockArray) See + # https://docs.scipy.org/doc/numpy-1.10.1/user/c-info.beyond-basics.html#ndarray.__array_priority__ + __array_priority__ = 1 + + def _full_ravel(self) -> DeviceArray: + """Return a copy of ``self._data`` as a contiguous, flattened `DeviceArray`. + + Note that a copy, rather than a view, of the underlying array is + returned. This is consistent with :func:`jax.numpy.ravel`. + + Returns: + Copy of underlying flattened array. + + """ + return jnp.concatenate(tuple(x_i.ravel() for x_i in self)) + + """ backwards compatibility methods, could be removed """ + + @staticmethod + def array(iterable): + """Construct a :class:`.BlockArray` from a list or tuple of existing array-like.""" + return BlockArray(iterable) + + +# Register BlockArray as a jax pytree, without this, jax autograd won't work. +# taken from what is done with tuples in jax._src.tree_util +jax.tree_util.register_pytree_node( + BlockArray, + lambda xs: (xs, None), # to iter + lambda _, xs: BlockArray(xs), # from iter +) + +""" Wrap binary ops like +, @. """ +binary_ops = ( + "__add__", + "__radd__", + "__sub__", + "__rsub__", + "__mul__", + "__rmul__", + "__matmul__", + "__rmatmul__", + "__truediv__", + "__rtruediv__", + "__floordiv__", + "__rfloordiv__", + "__pow__", + "__rpow__", + "__gt__", + "__ge__", + "__lt__", + "__le__", + "__eq__", + "__ne__", +) + + +def _binary_op_wrapper(op): + @wraps(op) + def op_ba(self, other): + if isinstance(other, BlockArray): + result = BlockArray(getattr(x, op)(y) for x, y in zip(self, other)) + else: + result = BlockArray(getattr(x, op)(other) for x in self) + if NotImplemented in result: + return NotImplemented + return result + + return op_ba + + +for op in binary_ops: + setattr(BlockArray, op, _binary_op_wrapper(op)) + + +""" Wrap DeviceArray properties. """ + + +def _da_prop_wrapper(prop): + @property + @wraps(prop) + def prop_ba(self): + result = tuple(getattr(x, prop) for x in self) + if isinstance(result[0], jnp.ndarray): + return BlockArray(result) + return result + + return prop_ba + + +skip_props = ("at",) +da_props = [ + k + for k, v in dict(inspect.getmembers(DeviceArray)).items() + if isinstance(v, property) and k[0] != "_" and k not in dir(BlockArray) and k not in skip_props +] + +for prop in da_props: + setattr(BlockArray, prop, _da_prop_wrapper(prop)) + +""" Wrap DeviceArray methods. """ + + +def _da_method_wrapper(method): + @wraps(method) + def method_ba(self, *args, **kwargs): + result = tuple(getattr(x, method)(*args, **kwargs) for x in self) + if isinstance(result[0], jnp.ndarray): + return BlockArray(result) + return result + + return method_ba + + +skip_methods = () +da_methods = [ + k + for k, v in dict(inspect.getmembers(DeviceArray)).items() + if isinstance(v, Callable) + and k[0] != "_" + and k not in dir(BlockArray) + and k not in skip_methods +] - ... +for method in da_methods: + setattr(BlockArray, method, _da_method_wrapper(method)) diff --git a/scico/operator/biconvolve.py b/scico/operator/biconvolve.py index f1e54565d..bd964a216 100644 --- a/scico/operator/biconvolve.py +++ b/scico/operator/biconvolve.py @@ -13,9 +13,8 @@ from jax.scipy.signal import convolve from scico._generic_operators import LinearOperator, Operator -from scico.array import is_nested -from scico.blockarray import BlockArray from scico.linop import Convolve, ConvolveByX +from scico.numpy import BlockArray, is_nested from scico.typing import BlockShape, DType, JaxArray diff --git a/scico/optimize/_ladmm.py b/scico/optimize/_ladmm.py index 77261d91d..e0fbd7741 100644 --- a/scico/optimize/_ladmm.py +++ b/scico/optimize/_ladmm.py @@ -15,10 +15,10 @@ import scico.numpy as snp from scico.array import ensure_on_device -from scico.blockarray import BlockArray from scico.diagnostics import IterationStats from scico.functional import Functional from scico.linop import LinearOperator +from scico.numpy import BlockArray from scico.numpy.linalg import norm from scico.typing import JaxArray from scico.util import Timer diff --git a/scico/optimize/_primaldual.py b/scico/optimize/_primaldual.py index 23aaa609c..7b006e2c4 100644 --- a/scico/optimize/_primaldual.py +++ b/scico/optimize/_primaldual.py @@ -15,10 +15,10 @@ import scico.numpy as snp from scico.array import ensure_on_device -from scico.blockarray import BlockArray from scico.diagnostics import IterationStats from scico.functional import Functional from scico.linop import LinearOperator +from scico.numpy import BlockArray from scico.numpy.linalg import norm from scico.typing import JaxArray from scico.util import Timer diff --git a/scico/optimize/admm.py b/scico/optimize/admm.py index f9d70ae07..06cb77e2c 100644 --- a/scico/optimize/admm.py +++ b/scico/optimize/admm.py @@ -18,12 +18,12 @@ from jax.scipy.sparse.linalg import cg as jax_cg import scico.numpy as snp -from scico.array import ensure_on_device, is_real_dtype -from scico.blockarray import BlockArray +from scico.array import ensure_on_device from scico.diagnostics import IterationStats from scico.functional import Functional from scico.linop import CircularConvolve, Identity, LinearOperator from scico.loss import SquaredL2Loss +from scico.numpy import BlockArray, is_real_dtype from scico.numpy.linalg import norm from scico.solver import cg as scico_cg from scico.solver import minimize diff --git a/scico/optimize/pgm.py b/scico/optimize/pgm.py index 2fe8fcc4d..3656e969b 100644 --- a/scico/optimize/pgm.py +++ b/scico/optimize/pgm.py @@ -17,10 +17,10 @@ import scico.numpy as snp from scico.array import ensure_on_device -from scico.blockarray import BlockArray from scico.diagnostics import IterationStats from scico.functional import Functional from scico.loss import Loss +from scico.numpy import BlockArray from scico.typing import JaxArray from scico.util import Timer diff --git a/scico/random.py b/scico/random.py index c405d7304..21f48fd0f 100644 --- a/scico/random.py +++ b/scico/random.py @@ -44,7 +44,7 @@ :: x, key = scico.random.randn( ((1, 1), (2,)), key=key) - print(x) # scico.blockarray.BlockArray: + print(x) # scico.numpy.BlockArray: # DeviceArray([ 1.1378784 , -1.220955 , -0.59153646], dtype=float32) """ @@ -58,8 +58,8 @@ import jax -import scico.numpy as snp -from scico.blockarray import BlockArray +from scico.numpy import BlockArray +from scico.numpy._util import map_func_over_ba from scico.typing import BlockShape, DType, JaxArray, PRNGKey, Shape @@ -121,7 +121,7 @@ def fun_alt(*args, key=None, seed=None, **kwargs): return fun_alt -_allow_block_shape = snp._map_func_over_ba +_allow_block_shape = map_func_over_ba def _wrap(fun): diff --git a/scico/scipy/special.py b/scico/scipy/special.py index f30cbcb4b..beec232f8 100644 --- a/scico/scipy/special.py +++ b/scico/scipy/special.py @@ -9,7 +9,7 @@ This modules consists of functions from :mod:`jax.scipy.special`. Some of these functions are wrapped to support compatibility with -:class:`scico.blockarray.BlockArray` and are documented here. The +:class:`scico.numpy.BlockArray` and are documented here. The remaining functions are imported directly from :mod:`jax.numpy`. While they can be imported from the :mod:`scico.numpy` namespace, they are not documented here; please consult the documentation for the source module @@ -22,7 +22,7 @@ import jax import jax.scipy.special as js -from scico.blockarray import _block_array_ufunc_wrapper +from scico.numpy import _block_array_ufunc_wrapper from scico.numpy._util import _attach_wrapped_func, _not_implemented _ufunc_functions = [ @@ -74,9 +74,9 @@ import jax.scipy.special as js -import scico.numpy as snp +from scico.numpy._util import wrap_attributes -snp._copy_attributes( +wrap_attributes( vars(), js.__dict__, ) diff --git a/scico/solver.py b/scico/solver.py index 5c261764d..b73971232 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -65,7 +65,7 @@ import jax import scico.numpy as snp -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.typing import BlockShape, DType, JaxArray, Shape from scipy import optimize as spopt diff --git a/scico/test/functional/test_core.py b/scico/test/functional/test_core.py index d69087ba8..2f17fb738 100644 --- a/scico/test/functional/test_core.py +++ b/scico/test/functional/test_core.py @@ -11,6 +11,7 @@ import scico.numpy as snp from scico import functional +from scico.numpy import BlockArray from scico.random import randn NO_BLOCK_ARRAY = [functional.L21Norm, functional.NuclearNorm] @@ -127,9 +128,13 @@ def test_prox_blockarray(self, norm, alpha, test_prox_obj): nrmobj = norm() nrm = nrmobj.__call__ prx = nrmobj.prox - pf = nrmobj.prox(test_prox_obj.vb.full_ravel(), alpha) + pf = BlockArray(nrmobj.prox(x, alpha) for x in test_prox_obj.vb) pf_b = nrmobj.prox(test_prox_obj.vb, alpha) - np.testing.assert_allclose(pf, pf_b.full_ravel()) + + print(test_prox_obj.vb) + print(pf) + print(pf_b) + snp.testing.assert_allclose(pf, pf_b) @pytest.mark.parametrize("norm", normlist) def test_prox_zeros(self, norm, test_prox_obj): diff --git a/scico/test/functional/test_loss.py b/scico/test/functional/test_loss.py index c606fb285..8a31406cf 100644 --- a/scico/test/functional/test_loss.py +++ b/scico/test/functional/test_loss.py @@ -7,12 +7,10 @@ # enable 64-bit mode for output dtype checks config.update("jax_enable_x64", True) - from prox import prox_test - import scico.numpy as snp from scico import functional, linop, loss -from scico.array import complex_dtype +from scico.numpy import complex_dtype from scico.random import randn, uniform diff --git a/scico/test/linop/test_diff.py b/scico/test/linop/test_diff.py index 6af05612b..09d08138f 100644 --- a/scico/test/linop/test_diff.py +++ b/scico/test/linop/test_diff.py @@ -3,8 +3,8 @@ import pytest import scico.numpy as snp -from scico.blockarray import BlockArray from scico.linop import FiniteDifference +from scico.numpy import BlockArray from scico.random import randn from scico.test.linop.test_linop import adjoint_test diff --git a/scico/test/linop/test_linop.py b/scico/test/linop/test_linop.py index f556bcd71..4f5b26c56 100644 --- a/scico/test/linop/test_linop.py +++ b/scico/test/linop/test_linop.py @@ -14,7 +14,7 @@ import scico.numpy as snp from scico import linop -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.random import randn from scico.typing import JaxArray, PRNGKey diff --git a/scico/test/optimize/test_ladmm.py b/scico/test/optimize/test_ladmm.py index b0f05328c..d5cf0370f 100644 --- a/scico/test/optimize/test_ladmm.py +++ b/scico/test/optimize/test_ladmm.py @@ -4,7 +4,7 @@ import scico.numpy as snp from scico import functional, linop, loss, random -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.optimize import LinearizedADMM diff --git a/scico/test/optimize/test_pdhg.py b/scico/test/optimize/test_pdhg.py index 4083ad5d2..5989a080c 100644 --- a/scico/test/optimize/test_pdhg.py +++ b/scico/test/optimize/test_pdhg.py @@ -4,7 +4,7 @@ import scico.numpy as snp from scico import functional, linop, loss, random -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.optimize import PDHG diff --git a/scico/test/test_array.py b/scico/test/test_array.py index 1b44b595a..c05de3fd0 100644 --- a/scico/test/test_array.py +++ b/scico/test/test_array.py @@ -7,19 +7,16 @@ import pytest import scico.numpy as snp -from scico.array import ( +from scico.array import ensure_on_device, indexed_shape, parse_axes, slice_length +from scico.numpy import ( + BlockArray, complex_dtype, - ensure_on_device, - indexed_shape, is_complex_dtype, is_nested, is_real_dtype, no_nan_divide, - parse_axes, real_dtype, - slice_length, ) -from scico.blockarray import BlockArray from scico.random import randn diff --git a/scico/test/test_biconvolve.py b/scico/test/test_biconvolve.py index 328e10f56..192dcd923 100644 --- a/scico/test/test_biconvolve.py +++ b/scico/test/test_biconvolve.py @@ -5,8 +5,8 @@ import pytest -from scico.blockarray import BlockArray from scico.linop import Convolve, ConvolveByX +from scico.numpy import BlockArray from scico.operator.biconvolve import BiConvolve from scico.random import randn diff --git a/scico/test/test_blockarray.py b/scico/test/test_blockarray.py index a856820e4..e28f8b812 100644 --- a/scico/test/test_blockarray.py +++ b/scico/test/test_blockarray.py @@ -10,7 +10,7 @@ import pytest import scico.numpy as snp -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.random import randn math_ops = [op.add, op.sub, op.mul, op.truediv, op.pow] # op.floordiv doesn't work on complex diff --git a/scico/test/test_blockarray_new.py b/scico/test/test_blockarray_new.py deleted file mode 100644 index d4c9afb91..000000000 --- a/scico/test/test_blockarray_new.py +++ /dev/null @@ -1,24 +0,0 @@ -import scico.numpy as snp - -# from scico.random import randn - - -# from scico.blockarray import BlockArray - -x = snp.BlockArray( - ( - snp.ones((3, 4)), - snp.arange(4), - ) -) - -y = snp.BlockArray( - ( - 2 * snp.ones((3, 4)), - snp.arange(4), - ) -) - -snp.sum(x) - -snp.testing.assert_allclose(x, y) diff --git a/scico/test/test_numpy.py b/scico/test/test_numpy.py index 495b38795..229253dd3 100644 --- a/scico/test/test_numpy.py +++ b/scico/test/test_numpy.py @@ -8,8 +8,8 @@ import scico.numpy as snp import scico.numpy._create as snc import scico.numpy.linalg as sla -from scico.blockarray import BlockArray from scico.linop import MatrixOperator +from scico.numpy import BlockArray def on_cpu(): diff --git a/scico/test/test_operator.py b/scico/test/test_operator.py index 30788d883..d02d34ac6 100644 --- a/scico/test/test_operator.py +++ b/scico/test/test_operator.py @@ -12,7 +12,7 @@ import jax import scico.numpy as snp -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.operator import Operator from scico.random import randn diff --git a/scico/test/test_random.py b/scico/test/test_random.py index 7e044a58a..183f0397a 100644 --- a/scico/test/test_random.py +++ b/scico/test/test_random.py @@ -85,12 +85,12 @@ def test_block_shape_adapter(): key = jax.random.PRNGKey(seed) result = fun_alt(key, shape) - assert isinstance(result, scico.blockarray.BlockArray) + assert isinstance(result, scico.numpy.BlockArray) # should work for deeply nested as well shape = ((7,), (3, (2, 1)), (2, 4, 1)) result = fun_alt(key, shape) - assert isinstance(result, scico.blockarray.BlockArray) + assert isinstance(result, scico.numpy.BlockArray) # when shape is not nested, behavior should be normal shape = (1,) diff --git a/scico/test/test_solver.py b/scico/test/test_solver.py index b7173f91d..256fceefe 100644 --- a/scico/test/test_solver.py +++ b/scico/test/test_solver.py @@ -6,7 +6,7 @@ import scico.numpy as snp from scico import random, solver -from scico.blockarray import BlockArray +from scico.numpy import BlockArray class TestSet: From ccbfd60b5e7b1e01c39582e45c5528d6902aac8d Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Fri, 8 Apr 2022 18:00:24 -0600 Subject: [PATCH 11/37] Change reductions back to ravel --- scico/numpy/_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scico/numpy/_util.py b/scico/numpy/_util.py index c61ce4d4f..98612cc42 100644 --- a/scico/numpy/_util.py +++ b/scico/numpy/_util.py @@ -89,7 +89,8 @@ def mapped(*args, **kwargs): "Cannot perform a full reduction with multiple BlockArray arguments." ) # fully ravel the ba argument - ba_args = {k: [jnp.concatenate(v.ravel())] for k, v in ba_args.items()} + ba_args = {k: jnp.concatenate(v.ravel()) for k, v in ba_args.items()} + return func(*bound_args.args, **bound_args.kwargs, **ba_args) result = tuple( map( # map over From 9af42bd4c3392c09fff39a4686dc18f2165c8682 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Tue, 12 Apr 2022 09:53:43 -0600 Subject: [PATCH 12/37] Hack away at the tests --- scico/functional/_dist.py | 2 +- scico/linop/_linop.py | 12 +-- scico/numpy/__init__.py | 9 +- scico/numpy/blockarray.py | 2 + scico/test/functional/test_core.py | 19 ++-- scico/test/functional/test_separable.py | 9 +- scico/test/test_blockarray.py | 108 ++++++++----------- scico/test/test_numpy.py | 137 +++++++++++++----------- scico/test/test_random.py | 11 +- scico/test/test_solver.py | 7 +- 10 files changed, 158 insertions(+), 158 deletions(-) diff --git a/scico/functional/_dist.py b/scico/functional/_dist.py index 5aaca6b0e..935d4c5c9 100644 --- a/scico/functional/_dist.py +++ b/scico/functional/_dist.py @@ -10,7 +10,7 @@ from typing import Callable, Union from scico import numpy as snp -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.typing import JaxArray from ._functional import Functional diff --git a/scico/linop/_linop.py b/scico/linop/_linop.py index d2f95c946..6a7f3c5ac 100644 --- a/scico/linop/_linop.py +++ b/scico/linop/_linop.py @@ -16,9 +16,9 @@ from typing import Any, Callable, Optional, Union import scico.numpy as snp -from scico import array, blockarray +from scico import array from scico._generic_operators import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar -from scico.numpy import BlockArray +from scico.numpy import BlockArray, is_nested from scico.random import randn from scico.typing import ArrayIndex, BlockShape, DType, JaxArray, PRNGKey, Shape @@ -189,9 +189,9 @@ def __init__( if input_dtype is None: input_dtype = self.diagonal.dtype - if isinstance(diagonal, BlockArray) and array.is_nested(input_shape): + if isinstance(diagonal, BlockArray) and is_nested(input_shape): output_shape = (snp.empty(input_shape) * diagonal).shape - elif not isinstance(diagonal, BlockArray) and not array.is_nested(input_shape): + elif not isinstance(diagonal, BlockArray) and not is_nested(input_shape): output_shape = snp.broadcast_shapes(input_shape, self.diagonal.shape) elif isinstance(diagonal, BlockArray): raise ValueError(f"`diagonal` was a BlockArray but `input_shape` was not nested.") @@ -282,8 +282,8 @@ def __init__( functions of the LinearOperator. """ - if array.is_nested(input_shape): - output_shape = blockarray.indexed_shape(input_shape, idx) + if is_nested(input_shape): + output_shape = input_shape[idx] else: output_shape = array.indexed_shape(input_shape, idx) diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index 042dc6e12..7f43e2139 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -28,7 +28,14 @@ from_dict=jnp.__dict__, modules_to_recurse=("linalg", "fft"), reductions=("sum", "norm"), - no_wrap=("dtype"), + no_wrap=( + "dtype", + "broadcast_shapes", # nested tuples as normal input (*shapes) + "array", # no meaning mapped over blocks + "stack", # no meaning mapped over blocks + "concatenate", # no meaning mapped over blocks + "pad", + ), # nested tuples as normal input ) # wrap np.testing diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py index df6c4b619..8004e6430 100644 --- a/scico/numpy/blockarray.py +++ b/scico/numpy/blockarray.py @@ -45,6 +45,8 @@ TODO: not specifying axis to get a full reduction TODO: using a BlockArray for axis or shape arguments +TODO: indexing +TODO: mention snp.testing here or in numpy diff --git a/scico/test/functional/test_core.py b/scico/test/functional/test_core.py index 2f17fb738..68c06209f 100644 --- a/scico/test/functional/test_core.py +++ b/scico/test/functional/test_core.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np from jax.config import config @@ -66,8 +68,8 @@ def test_separable_prox(test_separable_obj): fv1 = test_separable_obj.f.prox(test_separable_obj.v1, alpha) gv2 = test_separable_obj.g.prox(test_separable_obj.v2, alpha) fgv = test_separable_obj.fg.prox(test_separable_obj.vb, alpha) - out = BlockArray.array((fv1, gv2)).full_ravel() - np.testing.assert_allclose(out, fgv.full_ravel(), rtol=5e-2) + out = BlockArray.array((fv1, gv2)) + snp.testing.assert_allclose(out, fgv, rtol=5e-2) def test_separable_grad(test_separable_obj): @@ -84,8 +86,8 @@ def test_separable_grad(test_separable_obj): fv1 = test_separable_obj.f.grad(test_separable_obj.v1) gv2 = test_separable_obj.g.grad(test_separable_obj.v2) fgv = test_separable_obj.fg.grad(test_separable_obj.vb) - out = BlockArray.array((fv1, gv2)).full_ravel() - np.testing.assert_allclose(out, fgv.full_ravel(), rtol=5e-2) + out = BlockArray.array((fv1, gv2)) + snp.testing.assert_allclose(out, fgv, rtol=5e-2) class TestNormProx: @@ -128,13 +130,10 @@ def test_prox_blockarray(self, norm, alpha, test_prox_obj): nrmobj = norm() nrm = nrmobj.__call__ prx = nrmobj.prox - pf = BlockArray(nrmobj.prox(x, alpha) for x in test_prox_obj.vb) + pf = nrmobj.prox(snp.concatenate(snp.ravel(test_prox_obj.vb)), alpha) pf_b = nrmobj.prox(test_prox_obj.vb, alpha) - print(test_prox_obj.vb) - print(pf) - print(pf_b) - snp.testing.assert_allclose(pf, pf_b) + snp.testing.assert_allclose(pf, snp.concatenate(snp.ravel(pf_b)), rtol=1e-6) @pytest.mark.parametrize("norm", normlist) def test_prox_zeros(self, norm, test_prox_obj): @@ -206,7 +205,7 @@ def test_eval(self, cls, test_prox_obj): x = func(test_prox_obj.vb) y = func(test_prox_obj.vb.ravel()) - np.testing.assert_allclose(x, y) + np.testing.assert_allclose(x, y, rtol=1e-6) # only check double precision on projections diff --git a/scico/test/functional/test_separable.py b/scico/test/functional/test_separable.py index 8160f8ee2..a59e8d487 100644 --- a/scico/test/functional/test_separable.py +++ b/scico/test/functional/test_separable.py @@ -10,7 +10,8 @@ import pytest from scico import functional -from scico.blockarray import BlockArray +from scico.numpy import BlockArray +from scico.numpy.testing import assert_allclose from scico.random import randn @@ -38,7 +39,7 @@ def test_separable_eval(test_separable_obj): fv1 = test_separable_obj.f(test_separable_obj.v1) gv2 = test_separable_obj.g(test_separable_obj.v2) fgv = test_separable_obj.fg(test_separable_obj.vb) - np.testing.assert_allclose(fv1 + gv2, fgv, rtol=5e-2) + assert_allclose(fv1 + gv2, fgv, rtol=5e-2) def test_separable_prox(test_separable_obj): @@ -47,7 +48,7 @@ def test_separable_prox(test_separable_obj): gv2 = test_separable_obj.g.prox(test_separable_obj.v2, alpha) fgv = test_separable_obj.fg.prox(test_separable_obj.vb, alpha) out = BlockArray.array((fv1, gv2)).ravel() - np.testing.assert_allclose(out, fgv.ravel(), rtol=5e-2) + assert_allclose(out, fgv.ravel(), rtol=5e-2) def test_separable_grad(test_separable_obj): @@ -65,4 +66,4 @@ def test_separable_grad(test_separable_obj): gv2 = test_separable_obj.g.grad(test_separable_obj.v2) fgv = test_separable_obj.fg.grad(test_separable_obj.vb) out = BlockArray.array((fv1, gv2)).ravel() - np.testing.assert_allclose(out, fgv.ravel(), rtol=5e-2) + assert_allclose(out, fgv.ravel(), rtol=5e-2) diff --git a/scico/test/test_blockarray.py b/scico/test/test_blockarray.py index e28f8b812..c5d53a1c1 100644 --- a/scico/test/test_blockarray.py +++ b/scico/test/test_blockarray.py @@ -63,18 +63,18 @@ def test_operator_obj(request): def test_operator_left(test_operator_obj, operator): scalar = test_operator_obj.scalar a = test_operator_obj.a - x = operator(scalar, a).full_ravel() - y = operator(scalar, a.full_ravel()) - np.testing.assert_allclose(x, y) + x = operator(scalar, a) + y = BlockArray(operator(scalar, a_i) for a_i in a) + snp.testing.assert_allclose(x, y) @pytest.mark.parametrize("operator", math_ops + comp_ops) def test_operator_right(test_operator_obj, operator): scalar = test_operator_obj.scalar a = test_operator_obj.a - x = operator(a, scalar).full_ravel() - y = operator(a.full_ravel(), scalar) - np.testing.assert_allclose(x, y) + x = operator(a, scalar) + y = BlockArray(operator(a_i, scalar) for a_i in a) + snp.testing.assert_allclose(x, y) # Operations between a blockarray and a flat DeviceArray @@ -84,9 +84,9 @@ def test_operator_right(test_operator_obj, operator): def test_ba_da_left(test_operator_obj, operator): flat_da = test_operator_obj.flat_da a = test_operator_obj.a - x = operator(flat_da, a).full_ravel() - y = operator(flat_da, a.full_ravel()) - np.testing.assert_allclose(x, y, rtol=5e-5) + x = operator(flat_da, a) + y = BlockArray(operator(flat_da, a_i) for a_i in a) + snp.testing.assert_allclose(x, y, rtol=5e-5) @pytest.mark.skip # see previous @@ -94,8 +94,8 @@ def test_ba_da_left(test_operator_obj, operator): def test_ba_da_right(test_operator_obj, operator): flat_da = test_operator_obj.flat_da a = test_operator_obj.a - x = operator(a, flat_da).full_ravel() - y = operator(a.full_ravel(), flat_da) + x = operator(a, flat_da) + y = BlockArray(operator(a_i, flat_da) for a_i in a) np.testing.assert_allclose(x, y) @@ -107,9 +107,9 @@ def test_ndarray_left(test_operator_obj, operator): a = test_operator_obj.a block_nd = test_operator_obj.block_nd - x = operator(a, block_nd).full_ravel() - y = BlockArray([operator(a[i], block_nd[i]) for i in range(len(a))]).full_ravel() - np.testing.assert_allclose(x, y) + x = operator(a, block_nd) + y = BlockArray([operator(a[i], block_nd[i]) for i in range(len(a))]) + snp.testing.assert_allclose(x, y) @pytest.mark.skip # see previous @@ -118,9 +118,9 @@ def test_ndarray_right(test_operator_obj, operator): a = test_operator_obj.a block_nd = test_operator_obj.block_nd - x = operator(block_nd, a).full_ravel() - y = BlockArray([operator(block_nd[i], a[i]) for i in range(len(a))]).full_ravel() - np.testing.assert_allclose(x, y) + x = operator(block_nd, a) + y = BlockArray([operator(block_nd[i], a[i]) for i in range(len(a))]) + snp.testing.assert_allclose(x, y) # Blockwise comparison between a BlockArray and DeviceArray @@ -130,9 +130,9 @@ def test_devicearray_left(test_operator_obj, operator): a = test_operator_obj.a block_da = test_operator_obj.block_da - x = operator(a, block_da).full_ravel() - y = BlockArray([operator(a[i], block_da[i]) for i in range(len(a))]).full_ravel() - np.testing.assert_allclose(x, y) + x = operator(a, block_da) + y = BlockArray([operator(a[i], block_da[i]) for i in range(len(a))]) + snp.testing.assert_allclose(x, y) @pytest.mark.skip # see previous @@ -141,9 +141,9 @@ def test_devicearray_right(test_operator_obj, operator): a = test_operator_obj.a block_da = test_operator_obj.block_da - x = operator(block_da, a).full_ravel() - y = BlockArray([operator(block_da[i], a[i]) for i in range(len(a))]).full_ravel() - np.testing.assert_allclose(x, y, atol=1e-7, rtol=0) + x = operator(block_da, a) + y = BlockArray([operator(block_da[i], a[i]) for i in range(len(a))]) + snp.testing.assert_allclose(x, y, atol=1e-7, rtol=0) # Operations between two blockarrays of same size @@ -151,9 +151,9 @@ def test_devicearray_right(test_operator_obj, operator): def test_ba_ba_operator(test_operator_obj, operator): a = test_operator_obj.a b = test_operator_obj.b - x = operator(a, b).full_ravel() - y = operator(a.full_ravel(), b.full_ravel()) - np.testing.assert_allclose(x, y) + x = operator(a, b) + y = BlockArray(operator(a_i, b_i) for a_i, b_i in zip(a, b)) + snp.testing.assert_allclose(x, y) # Testing the @ interface for blockarrays of same size, and a blockarray and flattened ndarray/devicearray @@ -171,7 +171,7 @@ def test_ba_ba_matmul(test_operator_obj): y = BlockArray([a0 @ d0, a1 @ d1]) assert x.shape == y.shape - np.testing.assert_allclose(x.full_ravel(), y.full_ravel()) + snp.testing.assert_allclose(x, y) with pytest.raises(TypeError): z = a @ c @@ -182,23 +182,21 @@ def test_conj(test_operator_obj): ac = a.conj() assert a.shape == ac.shape - np.testing.assert_allclose(a.full_ravel().conj(), ac.full_ravel()) + snp.testing.assert_allclose(BlockArray(a_i.conj() for a_i in a), ac) def test_real(test_operator_obj): a = test_operator_obj.a ac = a.real - assert a.shape == ac.shape - np.testing.assert_allclose(a.full_ravel().real, ac.full_ravel()) + snp.testing.assert_allclose(BlockArray(a_i.real for a_i in a), ac) def test_imag(test_operator_obj): a = test_operator_obj.a ac = a.imag - assert a.shape == ac.shape - np.testing.assert_allclose(a.full_ravel().imag, ac.full_ravel()) + snp.testing.assert_allclose(BlockArray(a_i.imag for a_i in a), ac) def test_ndim(test_operator_obj): @@ -271,11 +269,10 @@ def test_blockarray_from_one_array(): def test_sum_method(test_operator_obj, axis, keepdims): a = test_operator_obj.a - method_result = a.sum(axis=axis, keepdims=keepdims).full_ravel() - snp_result = snp.sum(a, axis=axis, keepdims=keepdims).full_ravel() + method_result = a.sum(axis=axis, keepdims=keepdims) + snp_result = snp.sum(a, axis=axis, keepdims=keepdims) - assert method_result.shape == snp_result.shape - np.testing.assert_allclose(method_result, snp_result) + snp.testing.assert_allclose(method_result, snp_result) @pytest.mark.skip() @@ -305,29 +302,18 @@ def test_ba_ba_dot(test_operator_obj, operator): x = operator(a, d) y = BlockArray([operator(a0, d0), operator(a1, d1)]) - np.testing.assert_allclose(x.full_ravel(), y.full_ravel()) + snp.testing.assert_allclose(x, y) ############################################################################### # Reduction tests ############################################################################### reduction_funcs = [ - snp.count_nonzero, snp.sum, snp.linalg.norm, - snp.mean, - snp.var, - snp.max, - snp.min, - snp.amin, - snp.amax, - snp.all, - snp.any, ] -real_reduction_funcs = [ - snp.median, -] +real_reduction_funcs = [] class BlockArrayReductionObj: @@ -366,7 +352,7 @@ def reduction_obj(request): def test_reduce(reduction_obj, func): x = func(reduction_obj.a) x_jit = jax.jit(func)(reduction_obj.a) - y = func(reduction_obj.a.full_ravel()) + y = func(snp.concatenate(snp.ravel(reduction_obj.a))) np.testing.assert_allclose(x, x_jit, rtol=1e-6) # test jitted function np.testing.assert_allclose(x, y, rtol=1e-6) # test for correctness @@ -400,30 +386,28 @@ def test_reduce_axis(reduction_obj, func, axis): x = f(reduction_obj.a) x_jit = jax.jit(f)(reduction_obj.a) - np.testing.assert_allclose( - x.full_ravel(), x_jit.full_ravel(), rtol=1e-4 - ) # test jitted function + snp.testing.assert_allclose(x, x_jit, rtol=1e-4) # test jitted function # test for correctness y0 = func(reduction_obj.a[0], axis=axis) y1 = func(reduction_obj.a[1], axis=axis) y = BlockArray((y0, y1)) - np.testing.assert_allclose(x.full_ravel(), y.full_ravel()) + snp.testing.assert_allclose(x, y) @pytest.mark.parametrize(**REDUCTION_PARAMS) def test_reduce_singleton(reduction_obj, func): # Case where one block is reduced to a singleton - f = lambda x: func(x, axis=0).full_ravel() + f = lambda x: func(x, axis=0) x = f(reduction_obj.c) x_jit = jax.jit(f)(reduction_obj.c) - np.testing.assert_allclose(x, x_jit, rtol=1e-4) # test jitted function + snp.testing.assert_allclose(x, x_jit, rtol=1e-4) # test jitted function y0 = func(reduction_obj.c[0], axis=0) y1 = func(reduction_obj.c[1], axis=0)[None] # Ensure size (1,) - y = BlockArray((y0, y1), dtype=reduction_obj.a[0].dtype).full_ravel() - np.testing.assert_allclose(x, y) + y = BlockArray((y0, y1)) + snp.testing.assert_allclose(x, y) class TestCreators: @@ -482,7 +466,7 @@ class NestedTestObj: def __init__(self, dtype): key = None scalar, key = randn(shape=(1,), dtype=dtype, key=key) - self.scalar = scalar.copy().full_ravel()[0] # convert to float + self.scalar = scalar.item() # convert to float self.a00, key = randn(shape=(2, 2, 2), dtype=dtype, key=key) self.a01, key = randn(shape=(3, 2, 4), dtype=dtype, key=key) @@ -511,9 +495,9 @@ def test_nested_shape(nested_obj): assert a[0].shape == ((2, 2, 2), (3, 2, 4)) assert a[1].shape == (2, 4) - np.testing.assert_allclose(a[0][0].full_ravel(), a00.full_ravel()) - np.testing.assert_allclose(a[0][1].full_ravel(), a01.full_ravel()) - np.testing.assert_allclose(a[1].full_ravel(), a1.full_ravel()) + snp.testing.assert_allclose(a[0][0], a00) + snp.testing.assert_allclose(a[0][1], a01) + snp.testing.assert_allclose(a[1], a1) # basic test for block_sizes assert a.shape == (a[0].size, a[1].size) diff --git a/scico/test/test_numpy.py b/scico/test/test_numpy.py index 229253dd3..47e62fe93 100644 --- a/scico/test/test_numpy.py +++ b/scico/test/test_numpy.py @@ -6,7 +6,6 @@ import pytest import scico.numpy as snp -import scico.numpy._create as snc import scico.numpy.linalg as sla from scico.linop import MatrixOperator from scico.numpy import BlockArray @@ -37,6 +36,7 @@ def test_reshape_array(): np.testing.assert_allclose(snp.reshape(a.ravel(), (4, 4)), a) +@pytest.mark.skip # no reshaping into a BlockArray def test_reshape_array(): a = np.random.randn(13) b = snp.reshape(a, ((3, 3), (4,))) @@ -48,6 +48,7 @@ def test_reshape_array(): np.testing.assert_allclose(b.ravel(), c.ravel()) +@pytest.mark.skip # do we care to support svd of matrix operator? @pytest.mark.parametrize("compute_uv", [True, False]) @pytest.mark.parametrize("full_matrices", [True, False]) @pytest.mark.parametrize("shape", [(8, 8), (4, 8), (8, 4)]) @@ -58,6 +59,7 @@ def test_svd(compute_uv, full_matrices, shape): check_results(f(A), f(Ao)) +@pytest.mark.skip # do we care to support cond of matrix operator? def test_cond(): A = jax.device_put(np.random.randn(8, 8)) Ao = MatrixOperator(A) @@ -65,6 +67,7 @@ def test_cond(): check_results(f(A), f(Ao)) +@pytest.mark.skip # do we care to support det of matrix operator? def test_det(): A = jax.device_put(np.random.randn(8, 8)) Ao = MatrixOperator(A) @@ -72,6 +75,7 @@ def test_det(): check_results(f(A), f(Ao)) +@pytest.mark.skip # do we care to support eig of matrix operator? @pytest.mark.skipif( on_cpu() == False, reason="nonsymmetric eigendecompositions only supported on cpu" ) @@ -82,6 +86,7 @@ def test_eig(): check_results(f(A), f(Ao)) +@pytest.mark.skip # do we care to support eigh of matrix operator? @pytest.mark.parametrize("symmetrize", [True, False]) @pytest.mark.parametrize("UPLO", [None, "L", "U"]) def test_eigh(UPLO, symmetrize): @@ -92,6 +97,7 @@ def test_eigh(UPLO, symmetrize): check_results(f(A), f(Ao)) +@pytest.mark.skip # do we care to support...? @pytest.mark.skipif( on_cpu() == False, reason="nonsymmetric eigendecompositions only supported on cpu" ) @@ -102,6 +108,7 @@ def test_eigvals(): check_results(f(A), f(Ao)) +@pytest.mark.skip # do we care to support...? @pytest.mark.parametrize("UPLO", [None, "L", "U"]) def test_eigvalsh(UPLO): A = jax.device_put(np.random.randn(8, 8)) @@ -111,6 +118,7 @@ def test_eigvalsh(UPLO): check_results(f(A), f(Ao)) +@pytest.mark.skip # do we care to support...? def test_inv(): A = jax.device_put(np.random.randn(8, 8)) Ao = MatrixOperator(A) @@ -118,6 +126,7 @@ def test_inv(): check_results(f(A), f(Ao)) +@pytest.mark.skip # do we care to support...? def test_lstsq(): A = jax.device_put(np.random.randn(8, 8)) b = jax.device_put(np.random.randn(8)) @@ -126,6 +135,7 @@ def test_lstsq(): check_results(f(A), f(Ao)) +@pytest.mark.skip # do we care to support...? def test_matrix_power(): A = jax.device_put(np.random.randn(8, 8)) Ao = MatrixOperator(A) @@ -133,6 +143,7 @@ def test_matrix_power(): check_results(f(A), f(Ao)) +@pytest.mark.skip # do we care to support...? def test_matrixrank(): A = jax.device_put(np.random.randn(8, 8)) Ao = MatrixOperator(A) @@ -140,6 +151,7 @@ def test_matrixrank(): check_results(f(A), f(Ao)) +@pytest.mark.skip # do we care to support...? @pytest.mark.parametrize("rcond", [None, 1e-3]) def test_pinv(rcond): A = jax.device_put(np.random.randn(8, 8)) @@ -148,6 +160,7 @@ def test_pinv(rcond): check_results(f(A), f(Ao)) +@pytest.mark.skip # do we care to support...? @pytest.mark.parametrize("rcond", [None, 1e-3]) def test_pinv(rcond): A = jax.device_put(np.random.randn(8, 8)) @@ -156,6 +169,7 @@ def test_pinv(rcond): check_results(f(A), f(Ao)) +@pytest.mark.skip # do we care to support...? @pytest.mark.parametrize("shape", [(8, 8), (4, 8), (8, 4)]) @pytest.mark.parametrize("mode", ["reduced", "complete", "r"]) def test_qr(shape, mode): @@ -165,6 +179,7 @@ def test_qr(shape, mode): check_results(f(A), f(Ao)) +@pytest.mark.skip # do we care to support...? def test_slogdet(): A = jax.device_put(np.random.randn(8, 8)) Ao = MatrixOperator(A) @@ -172,6 +187,7 @@ def test_slogdet(): check_results(f(A), f(Ao)) +@pytest.mark.skip # do we care to support...? def test_solve(): A = jax.device_put(np.random.randn(8, 8)) b = jax.device_put(np.random.randn(8)) @@ -180,6 +196,7 @@ def test_solve(): check_results(f(A), f(Ao)) +@pytest.mark.skip # do we care to support...? def test_multi_dot(): A = jax.device_put(np.random.randn(8, 8)) B = jax.device_put(np.random.randn(8, 4)) @@ -244,8 +261,7 @@ def test_ufunc_maximum(): Bb = BlockArray.array((B, D)) res = BlockArray.array((snp.array([2, 3, 4]), snp.array([5, 7]))) Bmax = snp.maximum(Ba, Bb) - assert Bmax.shape == res.shape - np.testing.assert_allclose(Bmax.ravel(), res.ravel()) + snp.testing.assert_allclose(Bmax, res) A = snp.array([1, 6, 3]) B = snp.array([6, 3, 8]) @@ -253,8 +269,7 @@ def test_ufunc_maximum(): Ba = BlockArray.array((A, B)) res = BlockArray.array((snp.array([5, 6, 5]), snp.array([6, 5, 8]))) Bmax = snp.maximum(Ba, C) - assert Bmax.shape == res.shape - np.testing.assert_allclose(Bmax.ravel(), res.ravel()) + snp.testing.assert_allclose(Bmax, res) def test_ufunc_sign(): @@ -264,11 +279,11 @@ def test_ufunc_sign(): Ba = BlockArray.array((snp.array([10, -5, 0]),)) res = BlockArray.array((snp.array([1, -1, 0]),)) - np.testing.assert_allclose(snp.sign(Ba).ravel(), res.ravel()) + snp.testing.assert_allclose(snp.sign(Ba), res) Ba = BlockArray.array((snp.array([10, -5, 0]), snp.array([0, 5, -6]))) res = BlockArray.array((snp.array([1, -1, 0]), snp.array([0, 1, -1]))) - np.testing.assert_allclose(snp.sign(Ba).ravel(), res.ravel()) + snp.testing.assert_allclose(snp.sign(Ba), res) def test_ufunc_where(): @@ -309,17 +324,17 @@ def test_ufunc_true_divide(): Ba = BlockArray.array((snp.array([1, 2, 3]),)) Bb = BlockArray.array((snp.array([3, 3, 3]),)) res = BlockArray.array((snp.array([0.33333333, 0.66666667, 1.0]),)) - np.testing.assert_allclose(snp.true_divide(Ba, Bb).ravel(), res.ravel()) + snp.testing.assert_allclose(snp.true_divide(Ba, Bb), res) Ba = BlockArray.array((snp.array([1, 2, 3]), snp.array([1, 2]))) Bb = BlockArray.array((snp.array([3, 3, 3]), snp.array([2, 2]))) res = BlockArray.array((snp.array([0.33333333, 0.66666667, 1.0]), snp.array([0.5, 1.0]))) - np.testing.assert_allclose(snp.true_divide(Ba, Bb).ravel(), res.ravel()) + snp.testing.assert_allclose(snp.true_divide(Ba, Bb), res) Ba = BlockArray.array((snp.array([1, 2, 3]), snp.array([1, 2]))) A = 2 res = BlockArray.array((snp.array([0.5, 1.0, 1.5]), snp.array([0.5, 1.0]))) - np.testing.assert_allclose(snp.true_divide(Ba, A).ravel(), res.ravel()) + snp.testing.assert_allclose(snp.true_divide(Ba, A), res) def test_ufunc_floor_divide(): @@ -336,17 +351,17 @@ def test_ufunc_floor_divide(): Ba = BlockArray.array((snp.array([1, 2, 3]),)) Bb = BlockArray.array((snp.array([3, 3, 3]),)) res = BlockArray.array((snp.array([0, 0, 1.0]),)) - np.testing.assert_allclose(snp.floor_divide(Ba, Bb).ravel(), res.ravel()) + snp.testing.assert_allclose(snp.floor_divide(Ba, Bb), res) Ba = BlockArray.array((snp.array([1, 7, 3]), snp.array([1, 2]))) Bb = BlockArray.array((snp.array([3, 3, 3]), snp.array([2, 2]))) res = BlockArray.array((snp.array([0, 2, 1.0]), snp.array([0, 1.0]))) - np.testing.assert_allclose(snp.floor_divide(Ba, Bb).ravel(), res.ravel()) + snp.testing.assert_allclose(snp.floor_divide(Ba, Bb), res) Ba = BlockArray.array((snp.array([1, 2, 3]), snp.array([1, 2]))) A = 2 res = BlockArray.array((snp.array([0, 1.0, 1.0]), snp.array([0, 1.0]))) - np.testing.assert_allclose(snp.floor_divide(Ba, A).ravel(), res.ravel()) + snp.testing.assert_allclose(snp.floor_divide(Ba, A), res) def test_ufunc_real(): @@ -360,11 +375,11 @@ def test_ufunc_real(): Ba = BlockArray.array((snp.array([1 + 3j]),)) res = BlockArray.array((snp.array([1]),)) - np.testing.assert_allclose(snp.real(Ba).ravel(), res.ravel()) + snp.testing.assert_allclose(snp.real(Ba), res) Ba = BlockArray.array((snp.array([1 + 3j]), snp.array([1 + 3j, 4.0]))) res = BlockArray.array((snp.array([1]), snp.array([1, 4.0]))) - np.testing.assert_allclose(snp.real(Ba).ravel(), res.ravel()) + snp.testing.assert_allclose(snp.real(Ba), res) def test_ufunc_imag(): @@ -378,11 +393,11 @@ def test_ufunc_imag(): Ba = BlockArray.array((snp.array([1 + 3j]),)) res = BlockArray.array((snp.array([3]),)) - np.testing.assert_allclose(snp.imag(Ba).ravel(), res.ravel()) + snp.testing.assert_allclose(snp.imag(Ba), res) Ba = BlockArray.array((snp.array([1 + 3j]), snp.array([1 + 3j, 4.0]))) res = BlockArray.array((snp.array([3]), snp.array([3, 0]))) - np.testing.assert_allclose(snp.imag(Ba).ravel(), res.ravel()) + snp.testing.assert_allclose(snp.imag(Ba), res) def test_ufunc_conj(): @@ -396,99 +411,101 @@ def test_ufunc_conj(): Ba = BlockArray.array((snp.array([1 + 3j]),)) res = BlockArray.array((snp.array([1 - 3j]),)) - np.testing.assert_allclose(snp.conj(Ba).ravel(), res.ravel()) + snp.testing.assert_allclose(snp.conj(Ba), res) Ba = BlockArray.array((snp.array([1 + 3j]), snp.array([1 + 3j, 4.0]))) res = BlockArray.array((snp.array([1 - 3j]), snp.array([1 - 3j, 4.0 - 0j]))) - np.testing.assert_allclose(snp.conj(Ba).ravel(), res.ravel()) + snp.testing.assert_allclose(snp.conj(Ba), res) def test_create_zeros(): - A = snc.zeros(2) + A = snp.zeros(2) assert np.all(A == 0) - A = snc.zeros([(2,), (2,)]) - assert np.all(A.ravel() == 0) + A = snp.zeros([(2,), (2,)]) + assert all(snp.all(A == 0)) def test_create_ones(): - A = snc.ones(2, dtype=np.float32) + A = snp.ones(2, dtype=np.float32) assert np.all(A == 1) - A = snc.ones([(2,), (2,)]) - assert np.all(A.ravel() == 1) + A = snp.ones([(2,), (2,)]) + assert all(snp.all(A == 1)) def test_create_zeros(): - A = snc.empty(2) + A = snp.empty(2) assert np.all(A == 0) - A = snc.empty([(2,), (2,)]) - assert np.all(A.ravel() == 0) + A = snp.empty([(2,), (2,)]) + assert all(snp.all(A == 0)) def test_create_full(): - A = snc.full((2,), 1) + A = snp.full((2,), 1) assert np.all(A == 1) - A = snc.full((2,), 1, dtype=np.float32) + A = snp.full((2,), 1, dtype=np.float32) assert np.all(A == 1) - A = snc.full([(2,), (2,)], 1) - assert np.all(A.ravel() == 1) + A = snp.full([(2,), (2,)], 1) + assert all(snp.all(A == 1)) def test_create_zeros_like(): - A = snc.ones(2, dtype=np.float32) - B = snc.zeros_like(A) + A = snp.ones(2, dtype=np.float32) + B = snp.zeros_like(A) assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype - A = snc.ones(2, dtype=np.float32) - B = snc.zeros_like(A) + A = snp.ones(2, dtype=np.float32) + B = snp.zeros_like(A) assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype - A = snc.ones([(2,), (2,)], dtype=np.float32) - B = snc.zeros_like(A) - assert np.all(B.ravel() == 0) and A.shape == B.shape and A.dtype == B.dtype + A = snp.ones([(2,), (2,)], dtype=np.float32) + B = snp.zeros_like(A) + assert all(snp.all(B == 0)) + assert A.shape == B.shape + assert A.dtype == B.dtype def test_create_empty_like(): - A = snc.ones(2, dtype=np.float32) - B = snc.empty_like(A) + A = snp.ones(2, dtype=np.float32) + B = snp.empty_like(A) assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype - A = snc.ones(2, dtype=np.float32) - B = snc.empty_like(A) + A = snp.ones(2, dtype=np.float32) + B = snp.empty_like(A) assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype - A = snc.ones([(2,), (2,)], dtype=np.float32) - B = snc.empty_like(A) - assert np.all(B.ravel() == 0) and A.shape == B.shape and A.dtype == B.dtype + A = snp.ones([(2,), (2,)], dtype=np.float32) + B = snp.empty_like(A) + assert all(snp.all(B == 0)) and A.shape == B.shape and A.dtype == B.dtype def test_create_ones_like(): - A = snc.zeros(2, dtype=np.float32) - B = snc.ones_like(A) + A = snp.zeros(2, dtype=np.float32) + B = snp.ones_like(A) assert np.all(B == 1) and A.shape == B.shape and A.dtype == B.dtype - A = snc.zeros(2, dtype=np.float32) - B = snc.ones_like(A) + A = snp.zeros(2, dtype=np.float32) + B = snp.ones_like(A) assert np.all(B == 1) and A.shape == B.shape and A.dtype == B.dtype - A = snc.zeros([(2,), (2,)], dtype=np.float32) - B = snc.ones_like(A) - assert np.all(B.ravel() == 1) and A.shape == B.shape and A.dtype == B.dtype + A = snp.zeros([(2,), (2,)], dtype=np.float32) + B = snp.ones_like(A) + assert all(snp.all(B == 1)) and A.shape == B.shape and A.dtype == B.dtype def test_create_full_like(): - A = snc.zeros(2, dtype=np.float32) - B = snc.full_like(A, 1.0) + A = snp.zeros(2, dtype=np.float32) + B = snp.full_like(A, 1.0) assert np.all(B == 1) and (A.shape == B.shape) and (A.dtype == B.dtype) - A = snc.zeros(2, dtype=np.float32) - B = snc.full_like(A, 1) + A = snp.zeros(2, dtype=np.float32) + B = snp.full_like(A, 1) assert np.all(B == 1) and (A.shape == B.shape) and (A.dtype == B.dtype) - A = snc.zeros([(2,), (2,)], dtype=np.float32) - B = snc.full_like(A, 1) - assert np.all(B.ravel() == 1) and (A.shape == B.shape) and (A.dtype == B.dtype) + A = snp.zeros([(2,), (2,)], dtype=np.float32) + B = snp.full_like(A, 1) + assert all(snp.all(B == 1)) and (A.shape == B.shape) and (A.dtype == B.dtype) diff --git a/scico/test/test_random.py b/scico/test/test_random.py index 183f0397a..23d2cf678 100644 --- a/scico/test/test_random.py +++ b/scico/test/test_random.py @@ -25,11 +25,7 @@ def test_wrapped_funcs(seed): seed = 42 key = jax.random.PRNGKey(seed) - result = fun_wrapped(shape, seed=seed) - - # test nested blockarray - shape = ((7, (1, 3)), (3, 2), (2, 4, 1)) - result = fun_wrapped(shape, seed=seed) + result, _ = fun_wrapped(shape, seed=seed) def test_add_seed_adapter(): @@ -87,11 +83,6 @@ def test_block_shape_adapter(): result = fun_alt(key, shape) assert isinstance(result, scico.numpy.BlockArray) - # should work for deeply nested as well - shape = ((7,), (3, (2, 1)), (2, 4, 1)) - result = fun_alt(key, shape) - assert isinstance(result, scico.numpy.BlockArray) - # when shape is not nested, behavior should be normal shape = (1,) result_A = fun(key, shape) diff --git a/scico/test/test_solver.py b/scico/test/test_solver.py index 256fceefe..04c80c16c 100644 --- a/scico/test/test_solver.py +++ b/scico/test/test_solver.py @@ -198,12 +198,11 @@ def test_split_join_blockarray(): real_block = BlockArray.array((x_s[0][0], x_s[1][0])) imag_block = BlockArray.array((x_s[0][1], x_s[1][1])) - np.testing.assert_allclose(real_block.ravel(), snp.real(x).ravel(), rtol=1e-4) - np.testing.assert_allclose(imag_block.ravel(), snp.imag(x).ravel(), rtol=1e-4) + snp.testing.assert_allclose(real_block, snp.real(x), rtol=1e-4) + snp.testing.assert_allclose(imag_block, snp.imag(x), rtol=1e-4) x_j = solver._join_real_imag(x_s) - assert x_j.shape == x.shape - np.testing.assert_allclose(x_j.ravel(), x.ravel(), rtol=1e-4) + snp.testing.assert_allclose(x_j, x, rtol=1e-4) def test_bisect(): From 99b6127a72435524f1380890cbff840082fa485b Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Tue, 12 Apr 2022 14:49:11 -0600 Subject: [PATCH 13/37] Fix tests, add dtype temporary fix --- scico/linop/_stack.py | 2 +- scico/numpy/__init__.py | 2 +- scico/numpy/blockarray.py | 12 +++++ scico/test/linop/test_abel.py | 1 + scico/test/linop/test_diff.py | 65 +++++++++------------------ scico/test/linop/test_linop.py | 17 ++++--- scico/test/linop/test_radon_svmbir.py | 4 +- 7 files changed, 47 insertions(+), 56 deletions(-) diff --git a/scico/linop/_stack.py b/scico/linop/_stack.py index 3cbd243ab..2289904ef 100644 --- a/scico/linop/_stack.py +++ b/scico/linop/_stack.py @@ -85,7 +85,7 @@ def __init__( def _eval(self, x: JaxArray) -> Union[JaxArray, BlockArray]: if self.collapsable and self.collapse: return snp.stack([op @ x for op in self.ops]) - return BlockArray.array([op @ x for op in self.ops]) + return BlockArray([op @ x for op in self.ops]) def _adj(self, y: Union[JaxArray, BlockArray]) -> JaxArray: return sum([op.adj(y_block) for y_block, op in zip(y, self.ops)]) diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index 7f43e2139..7ce2a93ae 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -19,7 +19,7 @@ import jax.numpy as jnp from . import _util -from .blockarray import BlockArray +from .blockarray import BlockArray, BlockDType, BlockShape from .util import * # wrap jnp diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py index 8004e6430..4bba465c6 100644 --- a/scico/numpy/blockarray.py +++ b/scico/numpy/blockarray.py @@ -499,6 +499,18 @@ def _full_ravel(self) -> DeviceArray: """ return jnp.concatenate(tuple(x_i.ravel() for x_i in self)) + @property + def dtype(self): + """Allow snp.zeros(x.shape, x.dtype) to work.""" + return self[0].dtype # TODO: a better solution is beyond current scope + + def __getitem__(self, key): + """Make, e.g., x[:2] return a BlockArray, not a list.""" + result = super().__getitem__(key) + if not isinstance(result, jnp.ndarray): + return BlockArray(result) + return result + """ backwards compatibility methods, could be removed """ @staticmethod diff --git a/scico/test/linop/test_abel.py b/scico/test/linop/test_abel.py index 8c63019f7..ca5323af3 100644 --- a/scico/test/linop/test_abel.py +++ b/scico/test/linop/test_abel.py @@ -27,6 +27,7 @@ def test_inverse(Nx, Ny): Ax = A @ im im_hat = A.inverse(Ax) + np.testing.assert_allclose(im_hat, im, rtol=5e-5) diff --git a/scico/test/linop/test_diff.py b/scico/test/linop/test_diff.py index 09d08138f..89aca58e0 100644 --- a/scico/test/linop/test_diff.py +++ b/scico/test/linop/test_diff.py @@ -4,61 +4,38 @@ import scico.numpy as snp from scico.linop import FiniteDifference -from scico.numpy import BlockArray from scico.random import randn from scico.test.linop.test_linop import adjoint_test +def test_eval(): + with pytest.raises(ValueError): # axis 3 does not exist + A = FiniteDifference(input_shape=(3, 4, 5), axes=(0, 3)) + + A = FiniteDifference(input_shape=(2, 3), append=0.0) + + x = snp.array([[1, 0, 1], [1, 1, 0]], dtype=snp.float32) + + Ax = A @ x + + snp.testing.assert_allclose( + Ax[0], # down columns x[1] - x[0], ..., append - x[N-1] + snp.array([[0, 1, -1], [-1, -1, 0]]), + ) + snp.testing.assert_allclose(Ax[1], snp.array([[-1, 1, -1], [0, -1, 0]])) # along rows + + @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("input_shape", [(16,), (16, 24)]) @pytest.mark.parametrize("axes", [0, 1, (0,), (1,), None]) @pytest.mark.parametrize("jit", [False, True]) -@pytest.mark.parametrize("append", [None, 0.0]) -def test_eval(input_shape, input_dtype, axes, jit, append): - +def test_adjoint(self, input_shape, input_dtype, axes, jit): ndim = len(input_shape) - x, _ = randn(input_shape, dtype=input_dtype) - if axes in [1, (1,)] and ndim == 1: - with pytest.raises(ValueError): - A = FiniteDifference( - input_shape=input_shape, input_dtype=input_dtype, axes=axes, append=append - ) + pass else: - A = FiniteDifference( - input_shape=input_shape, input_dtype=input_dtype, axes=axes, jit=jit, append=append - ) - Ax = A @ x - assert A.input_dtype == input_dtype - - # construct expected output - if axes is None: - if ndim == 1: - y = snp.diff(x, append=append) - else: - y = BlockArray.array( - [snp.diff(x, axis=0, append=append), snp.diff(x, axis=1, append=append)] - ) - elif np.isscalar(axes): - y = snp.diff(x, axis=axes, append=append) - elif len(axes) == 1: - y = snp.diff(x, axis=axes[0], append=append) - - np.testing.assert_allclose(Ax.ravel(), y.ravel(), rtol=1e-4) - - @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) - @pytest.mark.parametrize("input_shape", [(16,), (16, 24)]) - @pytest.mark.parametrize("axes", [0, 1, (0,), (1,), None]) - @pytest.mark.parametrize("jit", [False, True]) - def test_adjoint(self, input_shape, input_dtype, axes, jit): - ndim = len(input_shape) - if axes in [1, (1,)] and ndim == 1: - pass - else: - A = FiniteDifference( - input_shape=input_shape, input_dtype=input_dtype, axes=axes, jit=jit - ) - adjoint_test(A) + A = FiniteDifference(input_shape=input_shape, input_dtype=input_dtype, axes=axes, jit=jit) + adjoint_test(A) @pytest.mark.parametrize( diff --git a/scico/test/linop/test_linop.py b/scico/test/linop/test_linop.py index 4f5b26c56..2f524028a 100644 --- a/scico/test/linop/test_linop.py +++ b/scico/test/linop/test_linop.py @@ -328,7 +328,7 @@ def test_eval(self, input_shape, diagonal_dtype): D = linop.Diagonal(diagonal=diagonal) assert (D @ x).shape == D.output_shape - np.testing.assert_allclose((diagonal * x).ravel(), (D @ x).ravel(), rtol=1e-5) + snp.testing.assert_allclose((diagonal * x), (D @ x), rtol=1e-5) @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64]) def test_eval_broadcasting(self, diagonal_dtype): @@ -341,10 +341,10 @@ def test_eval_broadcasting(self, diagonal_dtype): # blockarray broadcast diagonal, key = randn(((3, 1, 4), (5, 5)), dtype=diagonal_dtype, key=self.key) - x, key = randn(((5, 1), 1), dtype=diagonal_dtype, key=key) + x, key = randn(((5, 1), (1,)), dtype=diagonal_dtype, key=key) D = linop.Diagonal(diagonal, x.shape) assert (D @ x).shape == ((3, 5, 4), (5, 5)) - np.testing.assert_allclose((diagonal * x).ravel(), (D @ x).ravel(), rtol=1e-5) + snp.testing.assert_allclose((diagonal * x), (D @ x), rtol=1e-5) # blockarray x array -> error diagonal, key = randn(((3, 1, 4), (5, 5)), dtype=diagonal_dtype, key=self.key) @@ -354,7 +354,7 @@ def test_eval_broadcasting(self, diagonal_dtype): # array x blockarray -> error diagonal, key = randn((3, 1, 4), dtype=diagonal_dtype, key=self.key) - x, key = randn(((5, 1), 1), dtype=diagonal_dtype, key=key) + x, key = randn(((5, 1), (1,)), dtype=diagonal_dtype, key=key) with pytest.raises(ValueError): D = linop.Diagonal(diagonal, x.shape) @@ -386,7 +386,7 @@ def test_binary_op(self, input_shape1, input_shape2, diagonal_dtype, operator): a = operator(D1, D2) @ x Dnew = linop.Diagonal(operator(diagonal1, diagonal2)) b = Dnew @ x - np.testing.assert_allclose(a.ravel(), b.ravel(), rtol=1e-5) + snp.testing.assert_allclose(a, b, rtol=1e-5) @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_binary_op_mismatch(self, operator): @@ -545,15 +545,14 @@ def test_slice_adj(slicetestobj, idx): block_slice_examples = [ 1, - np.s_[1, :-3], - np.s_[1, :, :3], - np.s_[1, ..., 2:], + np.s_[0:1], + np.s_[:1], ] @pytest.mark.parametrize("idx", block_slice_examples) def test_slice_blockarray(idx): - x = BlockArray.array((snp.zeros((3, 4)), snp.ones((3, 4, 5, 6)))) + x = BlockArray((snp.zeros((3, 4)), snp.ones((3, 4, 5, 6)))) A = linop.Slice(idx=idx, input_shape=x.shape, input_dtype=x.dtype) assert (A @ x).shape == x[idx].shape diff --git a/scico/test/linop/test_radon_svmbir.py b/scico/test/linop/test_radon_svmbir.py index 17f67bfb5..624a008b2 100644 --- a/scico/test/linop/test_radon_svmbir.py +++ b/scico/test/linop/test_radon_svmbir.py @@ -274,7 +274,9 @@ def test_prox_cg( ) y = A @ im - A_colsum = A.H @ snp.ones(y.shape) # backproject ones to get sum over cols of A + A_colsum = A.H @ snp.ones( + y.shape, dtype=snp.float32 + ) # backproject ones to get sum over cols of A if is_masked: mask = np.asarray(A_colsum) > 0 # cols of A which are not all zeros else: From ee5e39a56f10f2e4584ff14439ac5eade52eed29 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Tue, 12 Apr 2022 15:53:28 -0600 Subject: [PATCH 14/37] Get tests passing --- docs/source/notes.rst | 4 ++-- scico/loss.py | 2 +- scico/numpy/__init__.py | 2 +- scico/numpy/blockarray.py | 11 +++++++++++ scico/test/linop/test_diff.py | 2 +- scico/test/test_blockarray.py | 4 ++-- scico/test/test_operator.py | 10 +++++----- 7 files changed, 23 insertions(+), 12 deletions(-) diff --git a/docs/source/notes.rst b/docs/source/notes.rst index eef3b2fc7..2431ae9e7 100644 --- a/docs/source/notes.rst +++ b/docs/source/notes.rst @@ -185,7 +185,7 @@ When evaluating the gradient of ``f`` at 0, :func:`scico.grad` returns ``nan``: >>> import scico >>> import scico.numpy as snp >>> f = lambda x: snp.linalg.norm(x)**2 - >>> scico.grad(f)(snp.zeros(2)) # + >>> scico.grad(f)(snp.zeros(2, dtype=snp.float32)) # DeviceArray([nan, nan], dtype=float32) This can be fixed by defining the squared :math:`\ell_2` norm directly as @@ -194,7 +194,7 @@ This can be fixed by defining the squared :math:`\ell_2` norm directly as :: >>> g = lambda x: snp.sum(x**2) - >>> scico.grad(g)(snp.zeros(2)) + >>> scico.grad(g)(snp.zeros(2, dtype=snp.float32)) DeviceArray([0., 0.], dtype=float32) An alternative is to define a `custom derivative rule `_ to enforce a particular derivative convention at a point. diff --git a/scico/loss.py b/scico/loss.py index 5d001ea08..59faed517 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -201,7 +201,7 @@ def __init__( self.has_prox = True def __call__(self, x: Union[JaxArray, BlockArray]) -> float: - return self.scale * (self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2).sum() + return self.scale * snp.sum(self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2) def prox( self, v: Union[JaxArray, BlockArray], lam: float, **kwargs diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index 7ce2a93ae..7f43e2139 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -19,7 +19,7 @@ import jax.numpy as jnp from . import _util -from .blockarray import BlockArray, BlockDType, BlockShape +from .blockarray import BlockArray from .util import * # wrap jnp diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py index 4bba465c6..9dff15ca7 100644 --- a/scico/numpy/blockarray.py +++ b/scico/numpy/blockarray.py @@ -473,11 +473,15 @@ from functools import wraps from typing import Callable +import numpy as np + import jax import jax.numpy as jnp from jaxlib.xla_extension import DeviceArray +# TODO: .sum(), etc. should call snp + class BlockArray(list): """BlockArray""" @@ -487,6 +491,13 @@ class BlockArray(list): # https://docs.scipy.org/doc/numpy-1.10.1/user/c-info.beyond-basics.html#ndarray.__array_priority__ __array_priority__ = 1 + def __init__(self, arrays): + arrays = list(arrays) # in case it is a generator + if any(not isinstance(x, (jnp.ndarray, np.ndarray)) for x in arrays): + raise ValueError("BlockArrays must be constructed from DeviceArrays or ndarrays") + + return super().__init__(x if isinstance(x, jnp.ndarray) else jnp.array(x) for x in arrays) + def _full_ravel(self) -> DeviceArray: """Return a copy of ``self._data`` as a contiguous, flattened `DeviceArray`. diff --git a/scico/test/linop/test_diff.py b/scico/test/linop/test_diff.py index 89aca58e0..159f07c30 100644 --- a/scico/test/linop/test_diff.py +++ b/scico/test/linop/test_diff.py @@ -29,7 +29,7 @@ def test_eval(): @pytest.mark.parametrize("input_shape", [(16,), (16, 24)]) @pytest.mark.parametrize("axes", [0, 1, (0,), (1,), None]) @pytest.mark.parametrize("jit", [False, True]) -def test_adjoint(self, input_shape, input_dtype, axes, jit): +def test_adjoint(input_shape, input_dtype, axes, jit): ndim = len(input_shape) if axes in [1, (1,)] and ndim == 1: pass diff --git a/scico/test/test_blockarray.py b/scico/test/test_blockarray.py index c5d53a1c1..067e781d8 100644 --- a/scico/test/test_blockarray.py +++ b/scico/test/test_blockarray.py @@ -438,14 +438,14 @@ def test_full(self): fill_value = np.float32(np.random.randn()) x = snp.full(self.shape, fill_value=fill_value, dtype=np.float32) assert x.shape == self.shape - assert x.dtype == (np.float32, np.float32, np.float32) + assert x.dtype == np.float32 assert snp.all(x == fill_value) def test_full_nodtype(self): fill_value = np.float32(np.random.randn()) x = snp.full(self.shape, fill_value=fill_value, dtype=None) assert x.shape == self.shape - assert x.dtype == (fill_value.dtype, fill_value.dtype, fill_value.dtype) + assert x.dtype == fill_value.dtype assert snp.all(x == fill_value) diff --git a/scico/test/test_operator.py b/scico/test/test_operator.py index d02d34ac6..5b768525f 100644 --- a/scico/test/test_operator.py +++ b/scico/test/test_operator.py @@ -176,9 +176,9 @@ def test_freeze_3arg(): input_shape=((1, 3, 4), (2, 1, 4), (2, 3, 1)), eval_fn=lambda x: x[0] * x[1] * x[2] ) - a = np.random.randn(1, 3, 4) - b = np.random.randn(2, 1, 4) - c = np.random.randn(2, 3, 1) + a, _ = randn((1, 3, 4)) + b, _ = randn((2, 1, 4)) + c, _ = randn((2, 3, 1)) x = BlockArray.array([a, b, c]) Abc = A.freeze(0, a) # A as a function of b, c @@ -201,8 +201,8 @@ def test_freeze_2arg(): A = Operator(input_shape=((1, 3, 4), (2, 1, 4)), eval_fn=lambda x: x[0] * x[1]) - a = np.random.randn(1, 3, 4) - b = np.random.randn(2, 1, 4) + a, _ = randn((1, 3, 4)) + b, _ = randn((2, 1, 4)) x = BlockArray.array([a, b]) Ab = A.freeze(0, a) # A as a function of 'b' only From 8c6780d9e721f6e28444e646355ce6fbfd2d4caf Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Wed, 13 Apr 2022 08:52:11 -0600 Subject: [PATCH 15/37] Add missing module --- scico/numpy/blockarray.py | 73 +++++++++++++--------- scico/numpy/util.py | 124 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 170 insertions(+), 27 deletions(-) create mode 100644 scico/numpy/util.py diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py index 9dff15ca7..1be77d52c 100644 --- a/scico/numpy/blockarray.py +++ b/scico/numpy/blockarray.py @@ -15,39 +15,58 @@ >>> import numpy as np >>> import jax.numpy -The class :class:`.BlockArray` provides a way to combine several arrays -of different shapes and/or data types into a single object. -A :class:`.BlockArray` consists of a list of `DeviceArray` -objects, which we refer to as blocks. -:class:`.BlockArray`s differ from tuples in that mathematical operations -on :class:`.BlockArray`s automatically map along the blocks, returning -another :class:`.BlockArray` or tuple as appropriate. For example, +The class :class:`.BlockArray` provides a way to combine arrays of +different shapes into a single object for use with other SCICO classes. +A :class:`.BlockArray` consists of a list of `DeviceArray` objects, +which we refer to as blocks. :class:`.BlockArray`s differ from lists in +that, whenever possible, operations involving :class:`.BlockArray`s +automatically map along the blocks, returning another +:class:`.BlockArray` or tuple as appropriate. For example, :: >>> x = BlockArray(( - snp.array( - [[1, 3, 7], - [2, 2, 1],] - ), - snp.array( - [2, 4, 8] - ), + [[1, 3, 7], + [2, 2, 1]], + [2, 4, 8] )) >>> x.shape ((2, 3), (3,)) # tuple - >>> x + 1 - (DeviceArray([[2, 4, 8], - [3, 3, 2]], dtype=int32), - DeviceArray([3, 5, 9], dtype=int32)) # BlockArray + >>> x * 2 + (DeviceArray([[2, 6, 14], + [4, 4, 2]], dtype=int32), + DeviceArray([4, 8, 16], dtype=int32)) # BlockArray + + >>> y = BlockArray(( + [[.2], + [.3]], + [.4] + )) + >>> x + y + [DeviceArray([[1.2, 3.2, 7.2], + [2.3, 2.3, 1.3]], dtype=float32), + DeviceArray([2.4, 4.4, 8.4], dtype=float32)] # BlockArray + + +NumPy Functions +=============== + +:mod:`scico.numpy` provides a wrapper around :mod:`jax.numpy` where many +of the functions have been extended to work with `BlockArray`s. In +particular, array creation + + :: + >>> import scico.numpy as snp + >>> ... +TODO: working with SCICO operators TODO: not specifying axis to get a full reduction TODO: using a BlockArray for axis or shape arguments TODO: indexing TODO: mention snp.testing here or in numpy - +TODO: -x doesn't work @@ -466,15 +485,12 @@ BlockArray([ [2, 2], [2,] ]) - """ import inspect from functools import wraps from typing import Callable -import numpy as np - import jax import jax.numpy as jnp @@ -491,12 +507,15 @@ class BlockArray(list): # https://docs.scipy.org/doc/numpy-1.10.1/user/c-info.beyond-basics.html#ndarray.__array_priority__ __array_priority__ = 1 - def __init__(self, arrays): - arrays = list(arrays) # in case it is a generator - if any(not isinstance(x, (jnp.ndarray, np.ndarray)) for x in arrays): - raise ValueError("BlockArrays must be constructed from DeviceArrays or ndarrays") + def __init__(self, inputs): + # convert inputs to DeviceArrays + arrays = [x if isinstance(x, jnp.ndarray) else jnp.array(x) for x in inputs] + + # check that dtypes match + if not all(a.dtype == arrays[0].dtype for a in arrays): + raise ValueError("Heterogeneous dtypes not supported") - return super().__init__(x if isinstance(x, jnp.ndarray) else jnp.array(x) for x in arrays) + return super().__init__(arrays) def _full_ravel(self) -> DeviceArray: """Return a copy of ``self._data`` as a contiguous, flattened `DeviceArray`. diff --git a/scico/numpy/util.py b/scico/numpy/util.py new file mode 100644 index 000000000..cd71408dd --- /dev/null +++ b/scico/numpy/util.py @@ -0,0 +1,124 @@ +""" Utility functions for working with BlockArrays and DeviceArrays. """ + +from math import prod +from typing import Any, Union + +import scico.numpy as snp +from scico.typing import Axes, BlockShape, DType, JaxArray, Shape + +from .blockarray import BlockArray + + +def no_nan_divide( + x: Union[BlockArray, JaxArray], y: Union[BlockArray, JaxArray] +) -> Union[BlockArray, JaxArray]: + """Return `x/y`, with 0 instead of NaN where `y` is 0. + + Args: + x: Numerator. + y: Denominator. + + Returns: + `x / y` with 0 wherever `y == 0`. + """ + + return snp.where(y != 0, snp.divide(x, snp.where(y != 0, y, 1)), 0) + + +def shape_to_size(shape: Union[Shape, BlockShape]) -> Axes: + r"""Compute the size corresponding to a (possibly nested) shape. + + Args: + shape: A shape tuple; possibly tuples. + """ + + if is_nested(shape): + return sum(prod(s) for s in shape) + + return prod(shape) + + +def is_nested(x: Any) -> bool: + """Check if input is a list/tuple containing at least one list/tuple. + + Args: + x: Object to be tested. + + Returns: + ``True`` if `x` is a list/tuple of list/tuples, ``False`` otherwise. + + + Example: + >>> is_nested([1, 2, 3]) + False + >>> is_nested([(1,2), (3,)]) + True + >>> is_nested([[1, 2], 3]) + True + + """ + if isinstance(x, (list, tuple)): + return any([isinstance(_, (list, tuple)) for _ in x]) + return False + + +def is_real_dtype(dtype: DType) -> bool: + """Determine whether a dtype is real. + + Args: + dtype: A numpy or scico.numpy dtype (e.g. np.float32, + snp.complex64). + + Returns: + ``False`` if the dtype is complex, otherwise ``True``. + """ + return snp.dtype(dtype).kind != "c" + + +def is_complex_dtype(dtype: DType) -> bool: + """Determine whether a dtype is complex. + + Args: + dtype: A numpy or scico.numpy dtype (e.g. ``np.float32``, + ``np.complex64``). + + Returns: + ``True`` if the dtype is complex, otherwise ``False``. + """ + return snp.dtype(dtype).kind == "c" + + +def real_dtype(dtype: DType) -> DType: + """Construct the corresponding real dtype for a given complex dtype. + + Construct the corresponding real dtype for a given complex dtype, + e.g. the real dtype corresponding to `np.complex64` is + `np.float32`. + + Args: + dtype: A complex numpy or scico.numpy dtype (e.g. ``np.complex64``, + ``np.complex128``). + + Returns: + The real dtype corresponding to the input dtype + """ + + return snp.zeros(1, dtype).real.dtype + + +def complex_dtype(dtype: DType) -> DType: + """Construct the corresponding complex dtype for a given real dtype. + + Construct the corresponding complex dtype for a given real dtype, + e.g. the complex dtype corresponding to ``np.float32`` is + ``np.complex64``. + + Args: + dtype: A real numpy or scico.numpy dtype (e.g. ``np.float32``, + ``np.float64``). + + Returns: + The complex dtype corresponding to the input dtype. + """ + + return (snp.zeros(1, dtype) + 1j).dtype From b05d70dbe8f5a68bb9651387a1540c29950aae49 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Wed, 13 Apr 2022 16:20:54 -0600 Subject: [PATCH 16/37] Work on docs --- scico/numpy/__init__.py | 36 ++++----- scico/numpy/_util.py | 149 ++++++++++++++++++++++++-------------- scico/numpy/blockarray.py | 41 ++++++++--- scico/test/test_numpy.py | 4 +- 4 files changed, 145 insertions(+), 85 deletions(-) diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index 7f43e2139..a63c1d8c1 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -20,33 +20,35 @@ from . import _util from .blockarray import BlockArray +from .function_lists import ( + creation_routines, + mathematical_functions, + reduction_fuctions, + testing_functions, +) from .util import * -# wrap jnp -_util.wrap_attributes( +# copy most of jnp without wrapping +_util.add_attributes( to_dict=vars(), from_dict=jnp.__dict__, modules_to_recurse=("linalg", "fft"), - reductions=("sum", "norm"), - no_wrap=( - "dtype", - "broadcast_shapes", # nested tuples as normal input (*shapes) - "array", # no meaning mapped over blocks - "stack", # no meaning mapped over blocks - "concatenate", # no meaning mapped over blocks - "pad", - ), # nested tuples as normal input ) -# wrap np.testing -_util.wrap_attributes( +# wrap jnp funcs +_util.wrap_recursively(vars(), creation_routines, _util.map_func_over_tuple_of_tuples) +_util.wrap_recursively(vars(), mathematical_functions, _util.map_func_over_blocks) +_util.wrap_recursively(vars(), reduction_functions, _util.add_full_reduction) + +# copy np.testing +_util.add_attributes( to_dict=vars(), - from_dict={k: v for k, v in np.__dict__.items() if k == "testing"}, - modules_to_recurse=("testing"), + from_dict={"testing": np.testing}, + modules_to_recurse=("testing",), ) - -__all__ = ["BlockArray"] +# wrap testing funcs +_util.wrap_recursively(vars(), testing_functions, _util.map_func_over_blocks) # clean up del np, jnp, _util diff --git a/scico/numpy/_util.py b/scico/numpy/_util.py index 98612cc42..52f30e400 100644 --- a/scico/numpy/_util.py +++ b/scico/numpy/_util.py @@ -11,36 +11,23 @@ import jax.numpy as jnp from .blockarray import BlockArray -from .util import is_nested -def wrap_attributes( +def add_attributes( to_dict: dict, from_dict: dict, modules_to_recurse: Optional[Iterable[str]] = None, - reductions: Optional[Iterable[str]] = None, - no_wrap: Optional[Iterable[str]] = None, ): """Add attributes in `from_dict` to `to_dict`. - Underscore attributes are ignored. Functions are wrapped to allow for - `BlockArray` inputs. Modules are ignored, except those listed in - `modules_to_recurse`, which are added recursively. All others are - passed through unwrapped. - - no_warp: list of functions to attach unwrapped - + Underscore attributes are ignored. Modules are ignored, except those + listed in `modules_to_recurse`, which are added recursively. All + others are added. """ if modules_to_recurse is None: modules_to_recurse = () - if reductions is None: - reductions = () - - if no_wrap is None: - no_wrap = () - for name, obj in from_dict.items(): if name[0] == "_": continue @@ -50,64 +37,116 @@ def wrap_attributes( to_dict[name].__doc__ = obj.__doc__ # enable `import scico.numpy.linalg` and `from scico.numpy.linalg import norm` sys.modules[to_dict["__name__"] + "." + name] = to_dict[name] - wrap_attributes( - to_dict[name].__dict__, obj.__dict__, reductions=reductions, no_wrap=no_wrap - ) - elif isinstance(obj, Callable) and name not in no_wrap: - obj = map_func_over_ba(obj, is_reduction=name in reductions) - to_dict[name] = obj + add_attributes(to_dict[name].__dict__, obj.__dict__) else: to_dict[name] = obj -def map_func_over_ba(func, is_reduction=False): - """Create a version of `func` that maps over all of its `BlockArray` +def wrap_recursively( + target_dict: dict, + names: Iterable[str], + wrap: Callable, +): + """Call wrap functions in `target_dict`, correctly handling names like `"linalg.norm"`.""" + + for name in names: + if "." in name: + module, rest = name.split(".", maxsplit=1) + wrap_recursively(target_dict[module].__dict__, [rest], wrap) + else: + target_dict[name] = wrap(target_dict[name]) + + +def map_func_over_tuple_of_tuples(func: Callable, map_arg_name: Optional[str] = "shape"): + """Wrap a function so that it automatically maps over a tuple of tuples + argument, returning a `BlockArray`. + + """ + + @wraps(func) + def mapped(*args, **kwargs): + bound_args = signature(func).bind(*args, **kwargs) + + if map_arg_name not in bound_args.arguments: # no shape arg + return func(*args, **kwargs) # no mapping + + map_arg_val = bound_args.arguments.pop(map_arg_name) + + if not all(isinstance(x, tuple) for x in map_arg_val): # not nested tuple + return func(*args, **kwargs) # no mapping + + # map + return BlockArray( + func(*bound_args.args, **bound_args.kwargs, **{map_arg_name: x}) for x in map_arg_val + ) + + return mapped + + +def map_func_over_blocks(func, is_reduction=False): + """Wrap a function so that it maps over all of its `BlockArray` arguments. is_reduction: function is handled in a special way in order to allow full reductions of `BlockArray`s. If the axis parameter exists but - is not specified, the function is mapped over the blocks, then - called again on the stacked result. + is not specified, the function is called on a fully ravelled version + of all `BlockArray` inputs. """ + sig = signature(func) @wraps(func) def mapped(*args, **kwargs): - sig = signature(func) bound_args = sig.bind(*args, **kwargs) ba_args = {} for k, v in list(bound_args.arguments.items()): - if isinstance(v, BlockArray) or is_nested(v): + if isinstance(v, BlockArray): ba_args[k] = bound_args.arguments.pop(k) if not len(ba_args): # no BlockArray arguments - return func(*args, **kwargs) - - if is_reduction and "axis" not in bound_args.arguments: - if len(ba_args) > 1: - raise ValueError( - "Cannot perform a full reduction with multiple BlockArray arguments." - ) - # fully ravel the ba argument - ba_args = {k: jnp.concatenate(v.ravel()) for k, v in ba_args.items()} - return func(*bound_args.args, **bound_args.kwargs, **ba_args) - - result = tuple( - map( # map over - lambda *args: ( # lambda x_1, x_2, ..., x_N - func( - *bound_args.args, - **bound_args.kwargs, # ... nonBlockArray args - **dict(zip(ba_args.keys(), args)), - ) # plus dict of block args - ), - *ba_args.values(), # map(f, ba_1, ba_2, ..., ba_N) - ) - ) + return func(*args, **kwargs) # no mapping - if isinstance(result[0], jnp.ndarray): # True for abstract arrays, too - return BlockArray(result) + num_blocks = len(ba_args[k]) - return result + return BlockArray( + func(*bound_args.args, **bound_args.kwargs, **{k: v[i] for k, v in ba_args.items()}) + for i in range(num_blocks) + ) return mapped + + +def add_full_reduction(func: Callable, axis_arg_name: Optional[str] = "axis"): + """Wrap a function so that it can fully reduce a `BlockArray`. If + nothing is passed for the axis argument and the function is called + on a `BlockArray`, it is fully ravelled before the function is + called. + + Should be outside `map_func_over_blocks`. + """ + sig = signature(func) + if axis_arg_name not in sig.parameters: + raise ValueError( + f"Cannot wrap {func} as a reduction because it has no {axis_arg_name} argument" + ) + + @wraps(func) + def wrapped(*args, **kwargs): + bound_args = sig.bind(*args, **kwargs) + + ba_args = {} + for k, v in list(bound_args.arguments.items()): + if isinstance(v, BlockArray): + ba_args[k] = bound_args.arguments.pop(k) + + if "axis" in bound_args.arguments: + return func(*args, **kwargs, **ba_args) # call func as normal + + if len(ba_args) > 1: + raise ValueError("Cannot perform a full reduction with multiple BlockArray arguments.") + + # fully ravel the ba argument + ba_args = {k: jnp.concatenate(v.ravel()) for k, v in ba_args.items()} + return func(*bound_args.args, **bound_args.kwargs, **ba_args) + + return wrapped diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py index 1be77d52c..af80b6b85 100644 --- a/scico/numpy/blockarray.py +++ b/scico/numpy/blockarray.py @@ -18,10 +18,11 @@ The class :class:`.BlockArray` provides a way to combine arrays of different shapes into a single object for use with other SCICO classes. A :class:`.BlockArray` consists of a list of `DeviceArray` objects, -which we refer to as blocks. :class:`.BlockArray`s differ from lists in -that, whenever possible, operations involving :class:`.BlockArray`s -automatically map along the blocks, returning another -:class:`.BlockArray` or tuple as appropriate. For example, +which we refer to as blocks. :class:`.BlockArray`s differ from lists in +that, whenever possible, :class:`.BlockArray` properties and methods +(including unary and binary operators like +, -, *, ...) automatically +map along the blocks, returning another :class:`.BlockArray` or tuple as +appropriate. For example, :: @@ -54,18 +55,36 @@ :mod:`scico.numpy` provides a wrapper around :mod:`jax.numpy` where many of the functions have been extended to work with `BlockArray`s. In -particular, array creation +particular: + +* When a tuple of tuples is passed as the `shape` +argument to an array creation routine, a `BlockArray` is created. + +* When a `BlockArray` is passed to a reduction function, the blocks are +ravelled (i.e., reshaped to be 1D) and concatenated before the reduction +is applied. This behavior may be prevented by passing the `axis` +argument, in which case the function is mapped over the blocks. + +* When one or more `BlockArray`s is passed to a mathematical +function that is not a reduction, the function is mapped over +(corresponding) blocks. + +For lists of array creation routines, reduction functions, and mathematical +functions that have been wrapped in this manner, +see `scico.numpy.creation_routines`, `scico.numpy.reduction_fuctions`, +and +`scico.numpy.mathematical_functions`. + +:mod:`scico.numpy.testing` provides a wrapper around :mod:`numpy.testing` +where some functions have been extended to map over blocks, +notably `scico.numpy.testing.allclose`. +For a list of the extended functions, see `scico.numpy.testing_functions`. + - :: - >>> import scico.numpy as snp - >>> ... TODO: working with SCICO operators -TODO: not specifying axis to get a full reduction -TODO: using a BlockArray for axis or shape arguments TODO: indexing -TODO: mention snp.testing here or in numpy TODO: -x doesn't work diff --git a/scico/test/test_numpy.py b/scico/test/test_numpy.py index 47e62fe93..62531e907 100644 --- a/scico/test/test_numpy.py +++ b/scico/test/test_numpy.py @@ -377,8 +377,8 @@ def test_ufunc_real(): res = BlockArray.array((snp.array([1]),)) snp.testing.assert_allclose(snp.real(Ba), res) - Ba = BlockArray.array((snp.array([1 + 3j]), snp.array([1 + 3j, 4.0]))) - res = BlockArray.array((snp.array([1]), snp.array([1, 4.0]))) + Ba = BlockArray.array((snp.array([1.0 + 3j]), snp.array([1 + 3j, 4.0]))) + res = BlockArray.array((snp.array([1.0]), snp.array([1, 4.0]))) snp.testing.assert_allclose(snp.real(Ba), res) From 4d8b60a0899165dec859da293609d9f2df3e2c50 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Thu, 14 Apr 2022 07:28:52 -0600 Subject: [PATCH 17/37] Refactor out function lists --- scico/numpy/__init__.py | 11 +-- scico/numpy/_util.py | 4 +- scico/numpy/blockarray.py | 68 +++++++------- scico/numpy/function_lists.py | 145 ++++++++++++++++++++++++++++++ scico/test/test_new_blockarray.py | 15 ++++ 5 files changed, 200 insertions(+), 43 deletions(-) create mode 100644 scico/numpy/function_lists.py create mode 100644 scico/test/test_new_blockarray.py diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index a63c1d8c1..95c93fd68 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -20,12 +20,7 @@ from . import _util from .blockarray import BlockArray -from .function_lists import ( - creation_routines, - mathematical_functions, - reduction_fuctions, - testing_functions, -) +from .function_lists import * from .util import * # copy most of jnp without wrapping @@ -37,7 +32,9 @@ # wrap jnp funcs _util.wrap_recursively(vars(), creation_routines, _util.map_func_over_tuple_of_tuples) -_util.wrap_recursively(vars(), mathematical_functions, _util.map_func_over_blocks) +_util.wrap_recursively( + vars(), mathematical_functions + reduction_functions, _util.map_func_over_blocks +) _util.wrap_recursively(vars(), reduction_functions, _util.add_full_reduction) # copy np.testing diff --git a/scico/numpy/_util.py b/scico/numpy/_util.py index 52f30e400..3d737f0b4 100644 --- a/scico/numpy/_util.py +++ b/scico/numpy/_util.py @@ -106,7 +106,7 @@ def mapped(*args, **kwargs): if not len(ba_args): # no BlockArray arguments return func(*args, **kwargs) # no mapping - num_blocks = len(ba_args[k]) + num_blocks = len(list(ba_args.values())[0]) return BlockArray( func(*bound_args.args, **bound_args.kwargs, **{k: v[i] for k, v in ba_args.items()}) @@ -140,7 +140,7 @@ def wrapped(*args, **kwargs): ba_args[k] = bound_args.arguments.pop(k) if "axis" in bound_args.arguments: - return func(*args, **kwargs, **ba_args) # call func as normal + return func(*bound_args.args, **bound_args.kwargs, **ba_args) # call func as normal if len(ba_args) > 1: raise ValueError("Cannot perform a full reduction with multiple BlockArray arguments.") diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py index af80b6b85..4edaab4bc 100644 --- a/scico/numpy/blockarray.py +++ b/scico/numpy/blockarray.py @@ -515,7 +515,9 @@ from jaxlib.xla_extension import DeviceArray -# TODO: .sum(), etc. should call snp +from .function_lists import binary_ops, unary_ops + +# CANCELED: .sum(), etc. should call snp class BlockArray(list): @@ -576,38 +578,34 @@ def array(iterable): lambda _, xs: BlockArray(xs), # from iter ) -""" Wrap binary ops like +, @. """ -binary_ops = ( - "__add__", - "__radd__", - "__sub__", - "__rsub__", - "__mul__", - "__rmul__", - "__matmul__", - "__rmatmul__", - "__truediv__", - "__rtruediv__", - "__floordiv__", - "__rfloordiv__", - "__pow__", - "__rpow__", - "__gt__", - "__ge__", - "__lt__", - "__le__", - "__eq__", - "__ne__", -) +""" Wrap unary ops like -x. """ + + +def _unary_op_wrapper(op): + op = getattr(DeviceArray, op_name) + + @wraps(op) + def op_ba(self): + return BlockArray(op(x) for x in self) + + return op_ba + + +for op_name in unary_ops: + setattr(BlockArray, op_name, _unary_op_wrapper(op_name)) + +""" Wrap binary ops like x+y. """ -def _binary_op_wrapper(op): +def _binary_op_wrapper(op_name): + op = getattr(DeviceArray, op_name) + @wraps(op) def op_ba(self, other): if isinstance(other, BlockArray): - result = BlockArray(getattr(x, op)(y) for x, y in zip(self, other)) - else: - result = BlockArray(getattr(x, op)(other) for x in self) + return BlockArray(op(x, y) for x, y in zip(self, other)) + + result = BlockArray(op(x, other) for x in self) if NotImplemented in result: return NotImplemented return result @@ -615,18 +613,20 @@ def op_ba(self, other): return op_ba -for op in binary_ops: - setattr(BlockArray, op, _binary_op_wrapper(op)) +for op_name in binary_ops: + setattr(BlockArray, op_name, _binary_op_wrapper(op_name)) """ Wrap DeviceArray properties. """ -def _da_prop_wrapper(prop): +def _da_prop_wrapper(prop_name): + prop = getattr(DeviceArray, prop_name) + @property @wraps(prop) def prop_ba(self): - result = tuple(getattr(x, prop) for x in self) + result = tuple(getattr(x, prop_name) for x in self) if isinstance(result[0], jnp.ndarray): return BlockArray(result) return result @@ -641,8 +641,8 @@ def prop_ba(self): if isinstance(v, property) and k[0] != "_" and k not in dir(BlockArray) and k not in skip_props ] -for prop in da_props: - setattr(BlockArray, prop, _da_prop_wrapper(prop)) +for prop_name in da_props: + setattr(BlockArray, prop_name, _da_prop_wrapper(prop_name)) """ Wrap DeviceArray methods. """ diff --git a/scico/numpy/function_lists.py b/scico/numpy/function_lists.py new file mode 100644 index 000000000..ca21b6071 --- /dev/null +++ b/scico/numpy/function_lists.py @@ -0,0 +1,145 @@ +""" BlockArray """ +unary_ops = ( # found from dir(DeviceArray) + "__abs__", + "__neg__", +) + +binary_ops = ( # found from dir(DeviceArray) + "__add__", + "__radd__", + "__sub__", + "__rsub__", + "__mul__", + "__rmul__", + "__matmul__", + "__rmatmul__", + "__truediv__", + "__rtruediv__", + "__floordiv__", + "__rfloordiv__", + "__pow__", + "__rpow__", + "__gt__", + "__ge__", + "__lt__", + "__le__", + "__eq__", + "__ne__", +) + +""" jax.numpy """ + +creation_routines = ( + "empty", + "ones", + "zeros", + "full", +) + +mathematical_functions = ( + "sin", # https://numpy.org/doc/stable/reference/routines.math.html# + "cos", + "tan", + "arcsin", + "arccos", + "arctan", + "hypot", + "arctan2", + "degrees", + "radians", + "unwrap", + "deg2rad", + "rad2deg", + "sinh", + "cosh", + "tanh", + "arcsinh", + "arccosh", + "arctanh", + "around", + "round_", + "rint", + "fix", + "floor", + "ceil", + "trunc", + "prod", + "sum", + "nanprod", + "nansum", + "cumprod", + "cumsum", + "nancumprod", + "nancumsum", + "diff", + "ediff1d", + "gradient", + "cross", + "trapz", + "exp", + "expm1", + "exp2", + "log", + "log10", + "log2", + "log1p", + "logaddexp", + "logaddexp2", + "i0", + "sinc", + "signbit", + "copysign", + "frexp", + "ldexp", + "nextafter", + "spacing", + "lcm", + "gcd", + "add", + "reciprocal", + "positive", + "negative", + "multiply", + "divide", + "power", + "subtract", + "true_divide", + "floor_divide", + "float_power", + "fmod", + "mod", + "modf", + "remainder", + "divmod", + "angle", + "real", + "imag", + "conj", + "conjugate", + "maximum", + "fmax", + "amax", + "nanmax", + "minimum", + "fmin", + "amin", + "nanmin", + "convolve", + "clip", + "sqrt", + "cbrt", + "square", + "absolute", + "fabs", + "sign", + "heaviside", + "nan_to_num", + "real_if_close", + "interp", +) + +reduction_functions = ("sum", "linalg.norm") + +""" numpy.testing """ + +testing_functions = ("testing.assert_allclose",) diff --git a/scico/test/test_new_blockarray.py b/scico/test/test_new_blockarray.py new file mode 100644 index 000000000..64b0eb5ba --- /dev/null +++ b/scico/test/test_new_blockarray.py @@ -0,0 +1,15 @@ +import pytest + +from scico.numpy import BlockArray + + +@pytest.fixture +def ba(): + return + + +def test_unary(): + x = BlockArray(([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], [42])) + y = -x + + # TODO FINISH From ae76e8da79f8b391cf902956aaeb203b07064787 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Thu, 14 Apr 2022 08:29:17 -0600 Subject: [PATCH 18/37] Start on new BlockArray tests --- scico/numpy/function_lists.py | 4 +- scico/test/test_new_blockarray.py | 68 ++++++++++++++++++++++++++++--- 2 files changed, 65 insertions(+), 7 deletions(-) diff --git a/scico/numpy/function_lists.py b/scico/numpy/function_lists.py index ca21b6071..59441826e 100644 --- a/scico/numpy/function_lists.py +++ b/scico/numpy/function_lists.py @@ -2,6 +2,7 @@ unary_ops = ( # found from dir(DeviceArray) "__abs__", "__neg__", + "__pos__", ) binary_ops = ( # found from dir(DeviceArray) @@ -10,6 +11,7 @@ "__sub__", "__rsub__", "__mul__", + "__mod__", "__rmul__", "__matmul__", "__rmatmul__", @@ -142,4 +144,4 @@ """ numpy.testing """ -testing_functions = ("testing.assert_allclose",) +testing_functions = ("testing.assert_allclose", "testing.assert_array_equal") diff --git a/scico/test/test_new_blockarray.py b/scico/test/test_new_blockarray.py index 64b0eb5ba..27a421036 100644 --- a/scico/test/test_new_blockarray.py +++ b/scico/test/test_new_blockarray.py @@ -1,15 +1,71 @@ +import operator as op + import pytest from scico.numpy import BlockArray +from scico.numpy.testing import assert_array_equal + +for a in dir(op): + help(getattr(op, a)) + + +@pytest.fixture +def x(): + # any BlockArray, arbitrary shape, content, type + return BlockArray(([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], [42.0])) @pytest.fixture -def ba(): - return +def y(): + # another BlockArray, content, type, matching shape + return BlockArray(([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]], [-2.0])) + + +@pytest.mark.parametrize("op", [op.neg, op.pos, op.abs]) +def test_unary(op, x): + actual = op(x) + expected = BlockArray(op(x_i) for x_i in x) + assert_array_equal(actual, expected) + assert actual.dtype == expected.dtype + + +@pytest.mark.parametrize( + "op", + [ + op.mul, + op.mod, + op.lt, + op.le, + op.gt, + op.ge, + op.floordiv, + op.eq, + op.add, + op.truediv, + op.sub, + op.ne, + ], +) +def test_elementwise_binary(op, x, y): + actual = op(x, y) + expected = BlockArray(op(x_i, y_i) for x_i, y_i in zip(x, y)) + assert_array_equal(actual, expected) + assert actual.dtype == expected.dtype + + +# TODO: op.matmul + +def test_property(): + x = BlockArray(([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [0.0])) + actual = x.shape + expected = ((2, 3), (1,)) + assert actual == expected -def test_unary(): - x = BlockArray(([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], [42])) - y = -x - # TODO FINISH +def test_method(): + x = BlockArray(([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], [42.0])) + actual = x.max() + expected = BlockArray([[3.0], [42.0]]) + assert_array_equal(actual, expected) + assert actual.dtype == expected.dtype From e0b1b03eea715be2e247b290b4a5446db36b36e8 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Thu, 14 Apr 2022 10:00:20 -0600 Subject: [PATCH 19/37] Wrap additional functions --- scico/numpy/__init__.py | 4 +- scico/numpy/_util.py | 4 +- scico/numpy/function_lists.py | 129 +++++++++++++++++++++++++++++++++- scico/random.py | 7 +- scico/scipy/special.py | 88 ++++++++--------------- scico/test/test_numpy.py | 16 ++--- scico/test/test_random.py | 21 ------ 7 files changed, 173 insertions(+), 96 deletions(-) diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index 95c93fd68..e3cee1fd1 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -32,9 +32,7 @@ # wrap jnp funcs _util.wrap_recursively(vars(), creation_routines, _util.map_func_over_tuple_of_tuples) -_util.wrap_recursively( - vars(), mathematical_functions + reduction_functions, _util.map_func_over_blocks -) +_util.wrap_recursively(vars(), mathematical_functions, _util.map_func_over_blocks) _util.wrap_recursively(vars(), reduction_functions, _util.add_full_reduction) # copy np.testing diff --git a/scico/numpy/_util.py b/scico/numpy/_util.py index 3d737f0b4..504e4d9f5 100644 --- a/scico/numpy/_util.py +++ b/scico/numpy/_util.py @@ -72,7 +72,9 @@ def mapped(*args, **kwargs): map_arg_val = bound_args.arguments.pop(map_arg_name) - if not all(isinstance(x, tuple) for x in map_arg_val): # not nested tuple + if not isinstance(map_arg_val, tuple) or not all( + isinstance(x, tuple) for x in map_arg_val + ): # not nested tuple return func(*args, **kwargs) # no mapping # map diff --git a/scico/numpy/function_lists.py b/scico/numpy/function_lists.py index 59441826e..56471c881 100644 --- a/scico/numpy/function_lists.py +++ b/scico/numpy/function_lists.py @@ -131,6 +131,7 @@ "sqrt", "cbrt", "square", + "abs", "absolute", "fabs", "sign", @@ -138,10 +139,136 @@ "nan_to_num", "real_if_close", "interp", + "sort", # https://numpy.org/doc/stable/reference/routines.sort.html + "lexsort", + "argsort", + "msort", + "sort_complex", + "partition", + "argpartition", + "argmax", + "nanargmax", + "argmin", + "nanargmin", + "argwhere", + "nonzero", + "flatnonzero", + "where", + "searchsorted", + "extract", + "count_nonzero", + "dot", # https://numpy.org/doc/stable/reference/routines.linalg.html + "linalg.multi_dot", + "vdot", + "inner", + "outer", + "matmul", + "tensordot", + "einsum", + "einsum_path", + "linalg.matrix_power", + "kron", + "linalg.cholesky", + "linalg.qr", + "linalg.svd", + "linalg.eig", + "linalg.eigh", + "linalg.eigvals", + "linalg.eigvalsh", + "linalg.norm", + "linalg.cond", + "linalg.det", + "linalg.matrix_rank", + "linalg.slogdet", + "trace", + "linalg.solve", + "linalg.tensorsolve", + "linalg.lstsq", + "linalg.inv", + "linalg.pinv", + "linalg.tensorinv", + "shape", # https://numpy.org/doc/stable/reference/routines.array-manipulation.html + "reshape", + "ravel", + "moveaxis", + "rollaxis", + "swapaxes", + "transpose", + "atleast_1d", + "atleast_2d", + "atleast_3d", + "expand_dims", + "squeeze", + "asarray", + "asanyarray", + "asmatrix", + "asfarray", + "asfortranarray", + "ascontiguousarray", + "asarray_chkfinite", + "asscalar", + "require", + "stack", + "block", + "vstack", + "hstack", + "dstack", + "column_stack", + "row_stack", + "split", + "array_split", + "dsplit", + "hsplit", + "vsplit", + "tile", + "repeat", + "insert", + "append", + "resize", + "trim_zeros", + "unique", + "flip", + "fliplr", + "flipud", + "reshape", + "roll", + "rot90", + "all", + "any", + "isfinite", + "isinf", + "isnan", + "isnat", + "isneginf", + "isposinf", + "iscomplex", + "iscomplexobj", + "isfortran", + "isreal", + "isrealobj", + "isscalar", + "logical_and", + "logical_or", + "logical_not", + "logical_xor", + "allclose", + "isclose", + "array_equal", + "array_equiv", + "greater", + "greater_equal", + "less", + "less_equal", + "equal", + "not_equal", + "empty_like", # https://numpy.org/doc/stable/reference/routines.array-creation.html + "ones_like", + "zeros_like", + "full_like", ) reduction_functions = ("sum", "linalg.norm") -""" numpy.testing """ +""" "testing", """ testing_functions = ("testing.assert_allclose", "testing.assert_array_equal") diff --git a/scico/random.py b/scico/random.py index 21f48fd0f..08ab8f472 100644 --- a/scico/random.py +++ b/scico/random.py @@ -59,7 +59,7 @@ import jax from scico.numpy import BlockArray -from scico.numpy._util import map_func_over_ba +from scico.numpy._util import map_func_over_tuple_of_tuples from scico.typing import BlockShape, DType, JaxArray, PRNGKey, Shape @@ -121,11 +121,8 @@ def fun_alt(*args, key=None, seed=None, **kwargs): return fun_alt -_allow_block_shape = map_func_over_ba - - def _wrap(fun): - fun_wrapped = _add_seed(_allow_block_shape(fun)) + fun_wrapped = _add_seed(map_func_over_tuple_of_tuples(fun)) fun_wrapped.__module__ = __name__ # so it appears in docs return fun_wrapped diff --git a/scico/scipy/special.py b/scico/scipy/special.py index beec232f8..cc5e11749 100644 --- a/scico/scipy/special.py +++ b/scico/scipy/special.py @@ -15,68 +15,42 @@ documented here; please consult the documentation for the source module :mod:`jax.scipy.special`. """ - -""" -import sys - -import jax import jax.scipy.special as js -from scico.numpy import _block_array_ufunc_wrapper -from scico.numpy._util import _attach_wrapped_func, _not_implemented - -_ufunc_functions = [ - js.betainc, - js.entr, - js.erf, - js.erfc, - js.erfinv, - js.expit, - js.gammainc, - js.gammaincc, - js.gammaln, - js.i0, - js.i0e, - js.i1, - js.i1e, - js.log_ndtr, - js.logit, - js.logsumexp, - js.multigammaln, - js.ndtr, - js.ndtri, - js.polygamma, - js.sph_harm, - js.xlog1py, - js.xlogy, - js.zeta, -] +from scico.numpy import _util -_attach_wrapped_func( - _ufunc_functions, - _block_array_ufunc_wrapper, - module_name=sys.modules[__name__], - fix_mod_name=True, +_util.add_attributes( + vars(), + js.__dict__, ) -psi = _block_array_ufunc_wrapper(js.digamma) -digamma = _block_array_ufunc_wrapper(js.digamma) - -_not_implemented_functions = [] -for name, func in jax._src.util.get_module_functions(js).items(): - if name not in globals(): - _not_implemented_functions.append((name, func)) - -_attach_wrapped_func( - _not_implemented_functions, _not_implemented, module_name=sys.modules[__name__] +functions = ( + "betainc", + "entr", + "erf", + "erfc", + "erfinv", + "expit", + "gammainc", + "gammaincc", + "gammaln", + "i0", + "i0e", + "i1", + "i1e", + "log_ndtr", + "logit", + "logsumexp", + "multigammaln", + "ndtr", + "ndtri", + "polygamma", + "sph_harm", + "xlog1py", + "xlogy", + "zeta", + "digamma", ) -""" - -import jax.scipy.special as js -from scico.numpy._util import wrap_attributes -wrap_attributes( - vars(), - js.__dict__, -) +_util.wrap_recursively(vars(), functions, _util.map_func_over_blocks) diff --git a/scico/test/test_numpy.py b/scico/test/test_numpy.py index 62531e907..dc4b14be3 100644 --- a/scico/test/test_numpy.py +++ b/scico/test/test_numpy.py @@ -422,7 +422,7 @@ def test_create_zeros(): A = snp.zeros(2) assert np.all(A == 0) - A = snp.zeros([(2,), (2,)]) + A = snp.zeros(((2,), (2,))) assert all(snp.all(A == 0)) @@ -430,7 +430,7 @@ def test_create_ones(): A = snp.ones(2, dtype=np.float32) assert np.all(A == 1) - A = snp.ones([(2,), (2,)]) + A = snp.ones(((2,), (2,))) assert all(snp.all(A == 1)) @@ -438,7 +438,7 @@ def test_create_zeros(): A = snp.empty(2) assert np.all(A == 0) - A = snp.empty([(2,), (2,)]) + A = snp.empty(((2,), (2,))) assert all(snp.all(A == 0)) @@ -449,7 +449,7 @@ def test_create_full(): A = snp.full((2,), 1, dtype=np.float32) assert np.all(A == 1) - A = snp.full([(2,), (2,)], 1) + A = snp.full(((2,), (2,)), 1) assert all(snp.all(A == 1)) @@ -462,7 +462,7 @@ def test_create_zeros_like(): B = snp.zeros_like(A) assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype - A = snp.ones([(2,), (2,)], dtype=np.float32) + A = snp.ones(((2,), (2,)), dtype=np.float32) B = snp.zeros_like(A) assert all(snp.all(B == 0)) assert A.shape == B.shape @@ -478,7 +478,7 @@ def test_create_empty_like(): B = snp.empty_like(A) assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype - A = snp.ones([(2,), (2,)], dtype=np.float32) + A = snp.ones(((2,), (2,)), dtype=np.float32) B = snp.empty_like(A) assert all(snp.all(B == 0)) and A.shape == B.shape and A.dtype == B.dtype @@ -492,7 +492,7 @@ def test_create_ones_like(): B = snp.ones_like(A) assert np.all(B == 1) and A.shape == B.shape and A.dtype == B.dtype - A = snp.zeros([(2,), (2,)], dtype=np.float32) + A = snp.zeros(((2,), (2,)), dtype=np.float32) B = snp.ones_like(A) assert all(snp.all(B == 1)) and A.shape == B.shape and A.dtype == B.dtype @@ -506,6 +506,6 @@ def test_create_full_like(): B = snp.full_like(A, 1) assert np.all(B == 1) and (A.shape == B.shape) and (A.dtype == B.dtype) - A = snp.zeros([(2,), (2,)], dtype=np.float32) + A = snp.zeros(((2,), (2,)), dtype=np.float32) B = snp.full_like(A, 1) assert all(snp.all(B == 1)) and (A.shape == B.shape) and (A.dtype == B.dtype) diff --git a/scico/test/test_random.py b/scico/test/test_random.py index 23d2cf678..7b6b8f80b 100644 --- a/scico/test/test_random.py +++ b/scico/test/test_random.py @@ -69,24 +69,3 @@ def test_add_seed_adapter(): # error when key and seed are specified with pytest.raises(Exception): _ = fun_alt(key=jax.random.PRNGKey(0), seed=42)[0] - - -def test_block_shape_adapter(): - fun = jax.random.normal - fun_alt = scico.random._allow_block_shape(fun) - - # when shape is nested, result should be a BlockArray... - shape = ((7,), (3, 2), (2, 4, 1)) - seed = 42 - key = jax.random.PRNGKey(seed) - - result = fun_alt(key, shape) - assert isinstance(result, scico.numpy.BlockArray) - - # when shape is not nested, behavior should be normal - shape = (1,) - result_A = fun(key, shape) - result_B = fun_alt(key, shape) - np.testing.assert_array_equal(result_A, result_B) - - assert fun(key) == fun_alt(key) From 72f4d20388b60ad67e7ed8c88249a9303337f2f4 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Thu, 14 Apr 2022 11:18:11 -0600 Subject: [PATCH 20/37] Stop tests ending on first failure --- .github/workflows/pytest.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 0122e0c70..b1e823d1a 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -82,11 +82,11 @@ jobs: - name: Run main unit tests with coverage report if: matrix.os == 'ubuntu-latest' run: | - pytest -x --cov=scico --cov-report=xml + pytest --cov=scico --cov-report=xml - name: Run main unit tests if: matrix.os != 'ubuntu-latest' run: | - pytest -x + pytest # Run doc tests - name: Run doc tests run: | From 0600cfa994c943beaf09258bf78b61e6b7e924ac Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Thu, 14 Apr 2022 11:18:32 -0600 Subject: [PATCH 21/37] Make ray test less (not?) stochastic --- scico/test/test_ray_tune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/test/test_ray_tune.py b/scico/test/test_ray_tune.py index c712c3828..72036a143 100644 --- a/scico/test/test_ray_tune.py +++ b/scico/test/test_ray_tune.py @@ -22,7 +22,7 @@ def eval_params(config, reporter): tune.ray.tune.register_trainable("eval_func", eval_params) -config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} +config = {"x": tune.grid_search(-1, 1), "y": tune.grid_search(-1, 1)} resources = {"gpu": 0, "cpu": 1} From 8fe69302f4c1497da3f18646d13b76b8f0ff795c Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Thu, 14 Apr 2022 11:26:24 -0600 Subject: [PATCH 22/37] Add matmul test --- scico/test/test_new_blockarray.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/scico/test/test_new_blockarray.py b/scico/test/test_new_blockarray.py index 27a421036..6f9971643 100644 --- a/scico/test/test_new_blockarray.py +++ b/scico/test/test_new_blockarray.py @@ -12,13 +12,13 @@ @pytest.fixture def x(): # any BlockArray, arbitrary shape, content, type - return BlockArray(([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], [42.0])) + return BlockArray([[[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], [42.0]]) @pytest.fixture def y(): # another BlockArray, content, type, matching shape - return BlockArray(([[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]], [-2.0])) + return BlockArray([[[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]], [-2.0]]) @pytest.mark.parametrize("op", [op.neg, op.pos, op.abs]) @@ -53,7 +53,14 @@ def test_elementwise_binary(op, x, y): assert actual.dtype == expected.dtype -# TODO: op.matmul +def test_matmul(x): + # x is ((2, 3), (1,)) + # y will be ((3, 1), (1, 2)) + y = BlockArray([[[1.0], [2.0], [3.0]], [[0.0, 1.0]]]) + actual = x @ y + expected = BlockArray([[[14.0], [0.0]], [0.0, 42.0]]) + assert_array_equal(actual, expected) + assert actual.dtype == expected.dtype def test_property(): From fb76d0192f46dde592b4ab4523cfa18fe0a418d5 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Thu, 14 Apr 2022 11:32:00 -0600 Subject: [PATCH 23/37] Handle CodeFactor --- scico/numpy/_util.py | 2 +- scico/numpy/blockarray.py | 11 +---------- scico/ray/tune.py | 2 +- 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/scico/numpy/_util.py b/scico/numpy/_util.py index 504e4d9f5..891683f7b 100644 --- a/scico/numpy/_util.py +++ b/scico/numpy/_util.py @@ -105,7 +105,7 @@ def mapped(*args, **kwargs): if isinstance(v, BlockArray): ba_args[k] = bound_args.arguments.pop(k) - if not len(ba_args): # no BlockArray arguments + if not ba_args: # no BlockArray arguments return func(*args, **kwargs) # no mapping num_blocks = len(list(ba_args.values())[0]) diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py index 4edaab4bc..aa43902e8 100644 --- a/scico/numpy/blockarray.py +++ b/scico/numpy/blockarray.py @@ -82,13 +82,6 @@ - -TODO: working with SCICO operators -TODO: indexing -TODO: -x doesn't work - - - Motivating Example ================== @@ -517,8 +510,6 @@ from .function_lists import binary_ops, unary_ops -# CANCELED: .sum(), etc. should call snp - class BlockArray(list): """BlockArray""" @@ -553,7 +544,7 @@ def _full_ravel(self) -> DeviceArray: @property def dtype(self): """Allow snp.zeros(x.shape, x.dtype) to work.""" - return self[0].dtype # TODO: a better solution is beyond current scope + return self[0].dtype def __getitem__(self, key): """Make, e.g., x[:2] return a BlockArray, not a list.""" diff --git a/scico/ray/tune.py b/scico/ray/tune.py index 726e6baac..4b174f9bc 100644 --- a/scico/ray/tune.py +++ b/scico/ray/tune.py @@ -18,7 +18,7 @@ import ray.tune except ImportError: raise ImportError("Could not import ray.tune; please install it.") -from ray.tune import loguniform, report, uniform # noqa +from ray.tune import grid_search, loguniform, report, uniform # noqa from ray.tune.progress_reporter import TuneReporterBase, _get_trials_by_state from ray.tune.schedulers import AsyncHyperBandScheduler from ray.tune.suggest.hyperopt import HyperOptSearch From ffe96ce2980748d36c02058e7c8f07d3f301d38d Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Thu, 14 Apr 2022 11:49:30 -0600 Subject: [PATCH 24/37] Make ray test less (not?) stochastic --- scico/test/test_ray_tune.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/test/test_ray_tune.py b/scico/test/test_ray_tune.py index 72036a143..f2db6da6b 100644 --- a/scico/test/test_ray_tune.py +++ b/scico/test/test_ray_tune.py @@ -22,7 +22,7 @@ def eval_params(config, reporter): tune.ray.tune.register_trainable("eval_func", eval_params) -config = {"x": tune.grid_search(-1, 1), "y": tune.grid_search(-1, 1)} +config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)} resources = {"gpu": 0, "cpu": 1} @@ -32,7 +32,7 @@ def test_random(): "eval_func", metric="cost", mode="min", - num_samples=50, + num_samples=100, config=config, resources_per_trial=resources, hyperopt=False, From e017bdfade84f8799b9163b222a31ab4e3b5952f Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Thu, 14 Apr 2022 12:45:19 -0600 Subject: [PATCH 25/37] Fix doc tests --- .github/workflows/pytest.yml | 4 +- scico/__init__.py | 2 + scico/blockarray_old.py | 1410 ---------------------------------- scico/numpy/_create.py | 174 ----- scico/numpy/_util_old.py | 93 --- scico/numpy/blockarray.py | 357 +-------- scico/ray/tune.py | 2 +- 7 files changed, 28 insertions(+), 2014 deletions(-) delete mode 100644 scico/blockarray_old.py delete mode 100644 scico/numpy/_create.py delete mode 100644 scico/numpy/_util_old.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index b1e823d1a..0122e0c70 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -82,11 +82,11 @@ jobs: - name: Run main unit tests with coverage report if: matrix.os == 'ubuntu-latest' run: | - pytest --cov=scico --cov-report=xml + pytest -x --cov=scico --cov-report=xml - name: Run main unit tests if: matrix.os != 'ubuntu-latest' run: | - pytest + pytest -x # Run doc tests - name: Run doc tests run: | diff --git a/scico/__init__.py b/scico/__init__.py index 96ae86b69..b6dc10572 100644 --- a/scico/__init__.py +++ b/scico/__init__.py @@ -48,6 +48,8 @@ "custom_vjp", ] +from . import random, linop + # Imported items in __all__ appear to originate in top-level functional module for name in __all__: getattr(sys.modules[__name__], name).__module__ = __name__ diff --git a/scico/blockarray_old.py b/scico/blockarray_old.py deleted file mode 100644 index 21772f138..000000000 --- a/scico/blockarray_old.py +++ /dev/null @@ -1,1410 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (C) 2020-2022 by SCICO Developers -# All rights reserved. BSD 3-clause License. -# This file is part of the SPORCO package. Details of the copyright -# and user license can be found in the 'LICENSE.txt' file distributed -# with the package. - -r"""Extensions of numpy ndarray class. - - .. testsetup:: - - >>> import scico - >>> import scico.numpy as snp - >>> from scico.numpy import BlockArray - >>> import numpy as np - >>> import jax.numpy - -The class :class:`.BlockArray` is a `jagged array -`_ that aims to mimic the -:class:`numpy.ndarray` interface where appropriate. - -A :class:`.BlockArray` object consists of a tuple of `DeviceArray` -objects that share their memory buffers with non-overlapping, contiguous -regions of a common one-dimensional `DeviceArray`. A :class:`.BlockArray` -contains the following size attributes: - -* `shape`: A tuple of tuples containing component dimensions. -* `size`: The sum of the size of each component block; this is the length - of the underlying one-dimensional `DeviceArray`. -* `num_blocks`: The number of components (blocks) that comprise the - :class:`.BlockArray`. - - -Motivating Example -================== - -Consider a two dimensional array :math:`\mb{x} \in \mathbb{R}^{n \times m}`. - -We compute the discrete differences of :math:`\mb{x}` in the horizontal -and vertical directions, generating two new arrays: :math:`\mb{x}_h \in -\mathbb{R}^{n \times (m-1)}` and :math:`\mb{x}_v \in \mathbb{R}^{(n-1) -\times m}`. - -As these arrays are of different shapes, we cannot combine them into a -single `ndarray`. Instead, we might vectorize each array and concatenate -the resulting vectors, leading to :math:`\mb{\bar{x}} \in -\mathbb{R}^{n(m-1) + m(n-1)}`, which can be stored as a one-dimensional -`ndarray`. Unfortunately, this makes it hard to access the individual -components :math:`\mb{x}_h` and :math:`\mb{x}_v`. - -Instead, we can form a :class:`.BlockArray`: :math:`\mb{x}_B = -[\mb{x}_h, \mb{x}_v]` - - - :: - - >>> n = 32 - >>> m = 16 - >>> x_h, key = scico.random.randn((n, m-1)) - >>> x_v, _ = scico.random.randn((n-1, m), key=key) - - # Form the blockarray - >>> x_B = BlockArray.array([x_h, x_v]) - - # The blockarray shape is a tuple of tuples - >>> x_B.shape - ((32, 15), (31, 16)) - - # Each block component can be easily accessed - >>> x_B[0].shape - (32, 15) - >>> x_B[1].shape - (31, 16) - - -Constructing a BlockArray -========================= - -Construct from a tuple of arrays (either `ndarray` or `DeviceArray`) --------------------------------------------------------------------- - - .. doctest:: - - >>> from scico.numpy import BlockArray - >>> import numpy as np - >>> x0, key = scico.random.randn((32, 32)) - >>> x1, _ = scico.random.randn((16,), key=key) - >>> X = BlockArray.array((x0, x1)) - >>> X.shape - ((32, 32), (16,)) - >>> X.size - 1040 - >>> X.num_blocks - 2 - -While :func:`.BlockArray.array` will accept either `ndarray` or -`DeviceArray` as input, the resulting :class:`.BlockArray` will be backed -by a `DeviceArray` memory buffer. - -**Note**: constructing a :class:`.BlockArray` always involves a copy to -a new `DeviceArray` memory buffer. - -**Note**: by default, the resulting :class:`.BlockArray` is cast to -single precision and will have dtype `float32` or `complex64`. - - -Construct from a single vector and tuple of shapes --------------------------------------------------- - - :: - - >>> x_flat, _ = scico.random.randn((1040,)) - >>> shape_tuple = ((32, 32), (16,)) - >>> X = BlockArray.array_from_flattened(x_flat, shape_tuple=shape_tuple) - >>> X.shape - ((32, 32), (16,)) - - - -Operating on a BlockArray -========================= - -.. _blockarray_indexing: - -Indexing --------- - -The block index is required to be an integer, selecting a single block and -returning it as an array (*not* a singleton BlockArray). If the index -expression has more than one component, then the initial index indexes the -block, and the remainder of the indexing expression indexes within the -selected block, e.g. ``x[2, 3:4]`` is equivalent to ``y[3:4]`` after -setting ``y = x[2]``. - - -Indexed Updating ----------------- - -BlockArrays support the JAX DeviceArray `indexed update syntax -`_ - - -The index must be of the form [ibk] or [ibk, idx], where `ibk` is the -index of the block to be updated, and `idx` is a general index of the -elements to be updated in that block. In particular, `ibk` cannot be a -`slice`. The general index `idx` can be omitted, in which case an entire -block is updated. - - -============================== ============================================== -Alternate syntax Equivalent in-place expression -============================== ============================================== -``x.at[ibk, idx].set(y)`` ``x[ibk, idx] = y`` -``x.at[ibk, idx].add(y)`` ``x[ibk, idx] += y`` -``x.at[ibk, idx].multiply(y)`` ``x[ibk, idx] *= y`` -``x.at[ibk, idx].divide(y)`` ``x[ibk, idx] /= y`` -``x.at[ibk, idx].power(y)`` ``x[ibk, idx] **= y`` -``x.at[ibk, idx].min(y)`` ``x[ibk, idx] = np.minimum(x[idx], y)`` -``x.at[ibk, idx].max(y)`` ``x[ibk, idx] = np.maximum(x[idx], y)`` -============================== ============================================== - - -Arithmetic and Broadcasting ---------------------------- - -Suppose :math:`\mb{x}` is a BlockArray with shape :math:`((n, n), (m,))`. - - :: - - >>> x1, key = scico.random.randn((4, 4)) - >>> x2, _ = scico.random.randn((5,), key=key) - >>> x = BlockArray.array( (x1, x2) ) - >>> x.shape - ((4, 4), (5,)) - >>> x.num_blocks - 2 - >>> x.size # 4*4 + 5 - 21 - -Illustrated for the operation ``+``, but equally valid for operators -``+, -, *, /, //, **, <, <=, >, >=, ==`` - - -Operations with BlockArrays with same number of blocks -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Let :math:`\mb{y}` be a BlockArray with the same number of blocks as -:math:`\mb{x}`. - - .. math:: - \mb{x} + \mb{y} - = - \begin{bmatrix} - \mb{x}[0] + \mb{y}[0] \\ - \mb{x}[1] + \mb{y}[1] \\ - \end{bmatrix} - -This operation depends on pair of blocks from :math:`\mb{x}` and -:math:`\mb{y}` being broadcastable against each other. - - - -Operations with a scalar -^^^^^^^^^^^^^^^^^^^^^^^^ - -The scalar is added to each element of the :class:`.BlockArray`: - - .. math:: - \mb{x} + 1 - = - \begin{bmatrix} - \mb{x}[0] + 1 \\ - \mb{x}[1] + 1\\ - \end{bmatrix} - - - :: - - >>> y = x + 1 - >>> np.testing.assert_allclose(y[0], x[0] + 1) - >>> np.testing.assert_allclose(y[1], x[1] + 1) - - - -Operations with a 1D `ndarray` of size equal to `num_blocks` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The *i*\th scalar is added to the *i*\th block of the -:class:`.BlockArray`: - - .. math:: - \mb{x} - + - \begin{bmatrix} - 1 \\ - 2 - \end{bmatrix} - = - \begin{bmatrix} - \mb{x}[0] + 1 \\ - \mb{x}[1] + 2\\ - \end{bmatrix} - - - :: - - >>> y = x + np.array([1, 2]) - >>> np.testing.assert_allclose(y[0], x[0] + 1) - >>> np.testing.assert_allclose(y[1], x[1] + 2) - - -Operations with an ndarray of `size` equal to :class:`.BlockArray` size -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -We first cast the `ndarray` to a BlockArray with same shape as -:math:`\mb{x}`, then apply the operation on the resulting BlockArrays. -With ``y.size = x.size``, we have: - - .. math:: - \mb{x} - + - \mb{y} - = - \begin{bmatrix} - \mb{x}[0] + \mb{y}[0] \\ - \mb{x}[1] + \mb{y}[1]\\ - \end{bmatrix} - -Equivalently, the BlockArray is first flattened, then added to the -flattened `ndarray`, and the result is reformed into a BlockArray with -the same shape as :math:`\mb{x}` - - - -MatMul ------- - -Between two BlockArrays -^^^^^^^^^^^^^^^^^^^^^^^ - -The matmul is computed between each block of the two BlockArrays. - -The BlockArrays must have the same number of blocks, and each pair of -blocks must be broadcastable. - - .. math:: - \mb{x} @ \mb{y} - = - \begin{bmatrix} - \mb{x}[0] @ \mb{y}[0] \\ - \mb{x}[1] @ \mb{y}[1]\\ - \end{bmatrix} - - - -Between BlockArray and Ndarray/DeviceArray -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -This operation is not defined. - - -Between BlockArray and :class:`.LinearOperator` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The :class:`.Operator` and :class:`.LinearOperator` classes are designed -to work on :class:`.BlockArray`\ s in addition to `DeviceArray`\ s. -For example - - - :: - - >>> x, key = scico.random.randn((3, 4)) - >>> A_1 = scico.linop.Identity(x.shape) - >>> A_1.shape # array -> array - ((3, 4), (3, 4)) - - >>> A_2 = scico.linop.FiniteDifference(x.shape) - >>> A_2.shape # array -> BlockArray - (((2, 4), (3, 3)), (3, 4)) - - >>> diag = BlockArray.array([np.array(1.0), np.array(2.0)]) - >>> A_3 = scico.linop.Diagonal(diag, input_shape=(A_2.output_shape)) - >>> A_3.shape # BlockArray -> BlockArray - (((2, 4), (3, 3)), ((2, 4), (3, 3))) - - -NumPy ufuncs ------------- - -`NumPy universal functions (ufuncs) `_ -are functions that operate on an `ndarray` on an element-by-element -fashion and support array broadcasting. Examples of ufuncs are ``abs``, -``sign``, ``conj``, and ``exp``. - -The JAX library implements most NumPy ufuncs in the :mod:`jax.numpy` -module. However, as JAX does not support subclassing of `DeviceArray`, -the JAX ufuncs cannot be used on :class:`.BlockArray`. As a workaround, -we have wrapped several JAX ufuncs for use on :class:`.BlockArray`; these -are defined in the :mod:`scico.numpy` module. - - -Reductions -^^^^^^^^^^ - -Reductions are functions that take an array-like as an input and return -an array of lower dimension. Examples include ``mean``, ``sum``, ``norm``. -BlockArray reductions are located in the :mod:`scico.numpy` module - -:class:`.BlockArray` tries to mirror `ndarray` reduction semantics where -possible, but cannot provide a one-to-one match as the block components -may be of different size. - -Consider the example BlockArray - - .. math:: - \mb{x} = \begin{bmatrix} - \begin{bmatrix} - 1 & 1 \\ - 1 & 1 - \end{bmatrix} \\ - \begin{bmatrix} - 2 \\ - 2 - \end{bmatrix} - \end{bmatrix}. - -We have - - .. doctest:: - - >>> import scico.numpy as snp - >>> x = BlockArray.array((np.ones((2,2)), 2*np.ones((2)))) - >>> x.shape - ((2, 2), (2,)) - >>> x.size - 6 - >>> x.num_blocks - 2 - - - - If no axis is specified, the reduction is applied to the flattened - array: - - .. doctest:: - - >>> snp.sum(x, axis=None).item() - 8.0 - - Reducing along the 0-th axis crushes the `BlockArray` down into a - single `DeviceArray` and requires all blocks to have the same shape - otherwise, an error is raised. - - .. doctest:: - - >>> snp.sum(x, axis=0) - Traceback (most recent call last): - ValueError: Evaluating sum of BlockArray along axis=0 requires all blocks to be same shape; got ((2, 2), (2,)) - - >>> y = BlockArray.array((np.ones((2,2)), 2*np.ones((2, 2)))) - >>> snp.sum(y, axis=0) - DeviceArray([[3., 3.], - [3., 3.]], dtype=float32) - - Reducing along axis :math:`n` is equivalent to reducing each component - along axis :math:`n-1`: - - .. math:: - \text{sum}(x, axis=1) = \begin{bmatrix} - \begin{bmatrix} - 2 \\ - 2 - \end{bmatrix} \\ - \begin{bmatrix} - 4 \\ - \end{bmatrix} - \end{bmatrix} - - - If a component does not have axis :math:`n-1`, the reduction is not - applied to that component. In this example, ``x[1].ndim == 1``, so no - reduction is applied to block ``x[1]``. - - .. math:: - \text{sum}(x, axis=2) = \begin{bmatrix} - \begin{bmatrix} - 2 \\ - 2 - \end{bmatrix} \\ - \begin{bmatrix} - 2 \\ - 2 - \end{bmatrix} - \end{bmatrix} - - -Code version - - .. doctest:: - - >>> snp.sum(x, axis=1) # doctest: +SKIP - BlockArray([[2, 2], - [4,] ]) - - >>> snp.sum(x, axis=2) # doctest: +SKIP - BlockArray([ [2, 2], - [2,] ]) - - -""" - -from __future__ import annotations - -from functools import wraps -from typing import Iterator, List, Optional, Tuple, Union - -import numpy as np - -import jax -import jax.numpy as jnp -from jax import core -from jax.interpreters import xla -from jax.interpreters.xla import DeviceArray -from jax.tree_util import register_pytree_node, tree_flatten - -from jaxlib.xla_extension import Buffer - -from scico import array -from scico.typing import AxisIndex, BlockShape, DType, JaxArray, Shape - -_arraylikes = (Buffer, DeviceArray, np.ndarray) - - -def atleast_1d(*arys): - """Convert inputs to arrays with at least one dimension. - - A wrapper for :func:`jax.numpy.atleast_1d` that acts as usual on - ndarrays and DeviceArrays, and returns BlockArrays unmodified. - """ - - if len(arys) == 1: - arr = arys[0] - return arr if isinstance(arr, BlockArray) else jnp.atleast_1d(arr) - - out = [] - for arr in arys: - if isinstance(arr, BlockArray): - out.append(arr) - else: - out.append(jnp.atleast_1d(arr)) - return out - - -# Append docstring from original jax.numpy function -atleast_1d.__doc__ = ( - atleast_1d.__doc__.replace("\n ", "\n") # deal with indentation differences - + "\nDocstring for :func:`jax.numpy.atleast_1d`:\n\n" - + "\n".join(jax.numpy.atleast_1d.__doc__.split("\n")[2:]) -) - - -def reshape( - a: Union[JaxArray, BlockArray], newshape: Union[Shape, BlockShape] -) -> Union[JaxArray, BlockArray]: - """Change the shape of an array without changing its data. - - Args: - a: Array to be reshaped. - newshape: The new shape should be compatible with the original - shape. If an integer, then the result will be a 1-D array of - that length. One shape dimension can be -1. In this case, - the value is inferred from the length of the array and - remaining dimensions. If a tuple of tuple of ints, a - :class:`.BlockArray` is returned. - - Returns: - The reshaped array. Unlike :func:`numpy.reshape`, a copy is - always returned. - """ - - if array.is_nested(newshape): - # x is a blockarray - return BlockArray.array_from_flattened(a, newshape) - - return jnp.reshape(a, newshape) - - -def _decompose_index(idx: Union[int, Tuple(AxisIndex)]) -> Tuple: - """Decompose a BlockArray indexing expression into components. - - Decompose a BlockArray indexing expression into block and array - components. - - Args: - idx: BlockArray indexing expression. - - Returns: - A tuple (idxblk, idxarr) with entries corresponding to the - integer block index and the indexing to be applied to the - selected block, respectively. The latter is ``None`` if the - indexing expression simply selects one of the blocks (i.e. - it consists of a single integer). - - Raises: - TypeError: If the block index is not an integer. - """ - if isinstance(idx, tuple): - idxblk = idx[0] - idxarr = idx[1:] - else: - idxblk = idx - idxarr = None - if not isinstance(idxblk, int): - raise TypeError("Block index must be an integer") - return idxblk, idxarr - - -def indexed_shape(shape: Shape, idx: Union[int, Tuple[AxisIndex, ...]]) -> Tuple[int, ...]: - """Determine the shape of the result of indexing a BlockArray. - - Args: - shape: Shape of BlockArray. - idx: BlockArray indexing expression. - - Returns: - Shape of the selected block, or slice of that block if ``idx`` is a tuple - rather than an integer. - """ - idxblk, idxarr = _decompose_index(idx) - if idxblk < 0: - idxblk = len(shape) + idxblk - if idxarr is None: - return shape[idxblk] - return array.indexed_shape(shape[idxblk], idxarr) - - -def _flatten_blockarrays(inp, *args, **kwargs): - """Flatten any blockarrays present in inp, args, or kwargs.""" - - def _flatten_if_blockarray(inp): - if isinstance(inp, BlockArray): - return inp._data - return inp - - inp_ = _flatten_if_blockarray(inp) - args_ = (_flatten_if_blockarray(_) for _ in args) - kwargs_ = {key: _flatten_if_blockarray(val) for key, val in kwargs.items()} - return inp_, args_, kwargs_ - - -def _block_array_ufunc_wrapper(func): - """Wrap a "ufunc" to allow for joint operation on `DeviceArray` and `BlockArray`.""" - - @wraps(func) - def wrapper(inp, *args, **kwargs): - all_args = (inp,) + args + tuple(kwargs.items()) - if any([isinstance(_, BlockArray) for _ in all_args]): - # If 'inp' is a BlockArray, call func on inp._data - # Then return a BlockArray of the same shape as inp - - inp_, args_, kwargs_ = _flatten_blockarrays(inp, *args, **kwargs) - flat_out = func(inp_, *args_, **kwargs_) - return BlockArray.array_from_flattened(flat_out, inp.shape) - - # Otherwise call the function normally - return func(inp, *args, **kwargs) - - if not hasattr(func, "__doc__") or func.__doc__ is None: - return wrapper - - wrapper.__doc__ = ( - f":func:`{func.__name__}` wrapped to operate on :class:`BlockArray`" + "\n\n" + func.__doc__ - ) - return wrapper - - -def _block_array_reduction_wrapper(func): - """Wrap a reduction (eg. sum, norm) to allow for joint operation on - `DeviceArray` and `BlockArray`.""" - - @wraps(func) - def wrapper(inp, *args, axis=None, **kwargs): - - all_args = (inp,) + args + tuple(kwargs.items()) - if any([isinstance(_, BlockArray) for _ in all_args]): - if axis is None: - # Treat as a single long vector - inp_, args_, kwargs_ = _flatten_blockarrays(inp, *args, **kwargs) - return func(inp_, *args_, **kwargs_) - - if type(axis) == tuple: - raise Exception( - f"""Evaluating {func.__name__} on a BlockArray with a tuple argument to - axis is not currently supported""" - ) - - if axis == 0: # reduction along block axis - # reduction along axis=0 only makes sense if all blocks are the same shape - # so we can convert to a standard DeviceArray of shape (inp.num_blocks, ...) - # and reduce along axis = 0 - if all([bk_shape == inp.shape[0] for bk_shape in inp.shape]): - view_shape = (inp.num_blocks,) + inp.shape[0] - return func(inp._data.reshape(view_shape), *args, axis=0, **kwargs) - - raise ValueError( - f"Evaluating {func.__name__} of BlockArray along axis=0 requires " - f"all blocks to be same shape; got {inp.shape}" - ) - - # Reduce each block individually along axis-1 - out = [] - for bk in inp: - if isinstance(bk, BlockArray): - # This block is itself a blockarray, so call this wrapped reduction - # on axis-1 - tmp = _block_array_reduction_wrapper(func)(bk, *args, axis=axis - 1, **kwargs) - else: - if axis - 1 >= bk.ndim: - # Trying to reduce along a dim that doesn't exist for this block, - # so just return the block. - # ie broadcast to shape (..., 1) and reduce along axis=-1 - tmp = bk - else: - tmp = func(bk, *args, axis=axis - 1, **kwargs) - out.append(atleast_1d(tmp)) - return BlockArray.array(out) - - if axis is None: - # 'axis' might not be a valid kwarg (eg dot, vdot), so don't pass it - return func(inp, *args, **kwargs) - - return func(inp, *args, axis=axis, **kwargs) - - if not hasattr(func, "__doc__") or func.__doc__ is None: - return wrapper - - wrapper.__doc__ = ( - f":func:`{func.__name__}` wrapped to operate on :class:`BlockArray`" + "\n\n" + func.__doc__ - ) - return wrapper - - -def _block_array_matmul_wrapper(func): - @wraps(func) - def wrapper(self, other): - if isinstance(self, BlockArray): - if isinstance(other, BlockArray): - # Both blockarrays, work block by block - return BlockArray.array([func(x, y) for x, y in zip(self, other)]) - raise TypeError( - f"Operation {func.__name__} not implemented between {type(self)} and {type(other)}" - ) - return func(self, other) - - if not hasattr(func, "__doc__") or func.__doc__ is None: - return wrapper - wrapper.__doc__ = ( - f":func:`{func.__name__}` wrapped to operate on :class:`BlockArray`" + "\n\n" + func.__doc__ - ) - return wrapper - - -def _block_array_binary_op_wrapper(func): - """Return a decorator that performs type and shape checking for - :class:`.BlockArray` arithmetic. - """ - - @wraps(func) - def wrapper(self, other): - if isinstance(other, BlockArray): - if other.shape == self.shape: - # Same shape blocks, can operate on flattened arrays - return BlockArray.array_from_flattened(func(self._data, other._data), self.shape) - if other.num_blocks == self.num_blocks: - # Will work as long as the shapes are broadcastable - return BlockArray.array([func(x, y) for x, y in zip(self, other)]) - raise ValueError( - f"operation not valid on operands with shapes {self.shape} {other.shape}" - ) - if any([isinstance(other, _) for _ in _arraylikes]): - if other.size == 1: - # Same as operating on a scalar - return BlockArray.array_from_flattened(func(self._data, other), self.shape) - if other.size == self.size: - # A little fast and loose, treat the block array as a length self.size vector - return BlockArray.array_from_flattened(func(self._data, other), self.shape) - if other.size == self.num_blocks: - return BlockArray.array([func(blk, other_) for blk, other_ in zip(self, other)]) - raise ValueError( - f"operation not valid on operands with shapes {self.shape} {other.shape}" - ) - if jnp.isscalar(other) or isinstance(other, core.Tracer): - return BlockArray.array_from_flattened(func(self._data, other), self.shape) - raise TypeError( - f"Operation {func.__name__} not implemented between {type(self)} and {type(other)}" - ) - - if not hasattr(func, "__doc__") or func.__doc__ is None: - return wrapper - wrapper.__doc__ = ( - f":func:`{func.__name__}` wrapped to operate on :class:`BlockArray`" + "\n\n" + func.__doc__ - ) - return wrapper - - -class _AbstractBlockArray(core.ShapedArray): - """Abstract BlockArray class for JAX tracing. - - See https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html - """ - - array_abstraction_level = 0 # Same as jax.core.ConcreteArray - - def __init__(self, shapes, dtype): - - sizes = block_sizes(shapes) - size = np.sum(sizes) - - super(_AbstractBlockArray, self).__init__((size,), dtype) - - #: Abstract data value - self._data_aval: core.ShapedArray = core.ShapedArray((size,), dtype) - - #: Array dtype - self.dtype: DType = dtype - - #: Shape of each block - self.shapes: BlockShape = shapes - - #: Size of each block - self.sizes: Shape = sizes - - #: Array specifying boundaries of components as indices in base array - self.bndpos: np.ndarray = np.r_[0, np.cumsum(sizes)] - - -# The Jax class is heavily inspired by SparseArray/AbstractSparseArray here: -# https://github.com/google/jax/blob/7724322d1c08c13008815bfb52759a29c2a6823b/tests/custom_object_test.py -class BlockArray: - """A tuple of :class:`jax.interpreters.xla.DeviceArray` objects. - - A tuple of `DeviceArray` objects that all share their memory buffers - with non-overlapping, contiguous regions of a common one-dimensional - `DeviceArray`. It can be used as the common one-dimensional array via - the :func:`BlockArray.ravel` method, or individual component arrays - can be accessed individually. - """ - - # Ensure we use BlockArray.__radd__, __rmul__, etc for binary operations of the form - # op(np.ndarray, BlockArray) - # See https://docs.scipy.org/doc/numpy-1.10.1/user/c-info.beyond-basics.html#ndarray.__array_priority__ - __array_priority__ = 1 - - def __init__(self, aval: _AbstractBlockArray, data: JaxArray): - """BlockArray init method. - - Args: - aval: `Abstract value`_ associated with this array (shape+dtype+weak_type) - data: The underlying contiguous, flattened `DeviceArray`. - - .. _Abstract value: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html - - """ - self._aval = aval - self._data = data - - def __repr__(self): - return "scico.numpy.BlockArray: \n" + self._data.__repr__() - - def __getitem__(self, idx: Union[int, Tuple[AxisIndex, ...]]) -> JaxArray: - idxblk, idxarr = _decompose_index(idx) - if idxblk < 0: - idxblk = self.num_blocks + idxblk - blk = reshape(self._data[self.bndpos[idxblk] : self.bndpos[idxblk + 1]], self.shape[idxblk]) - if idxarr is not None: - blk = blk[idxarr] - return blk - - @_block_array_matmul_wrapper - def __matmul__(self, other: Union[np.ndarray, BlockArray, JaxArray]) -> JaxArray: - return self @ other - - @_block_array_matmul_wrapper - def __rmatmul__(self, other: Union[np.ndarray, BlockArray, JaxArray]) -> JaxArray: - return other @ self - - @_block_array_binary_op_wrapper - def __sub__(a, b): - return a - b - - @_block_array_binary_op_wrapper - def __rsub__(a, b): - return b - a - - @_block_array_binary_op_wrapper - def __mul__(a, b): - return a * b - - @_block_array_binary_op_wrapper - def __rmul__(a, b): - return a * b - - @_block_array_binary_op_wrapper - def __add__(a, b): - return a + b - - @_block_array_binary_op_wrapper - def __radd__(a, b): - return a + b - - @_block_array_binary_op_wrapper - def __truediv__(a, b): - return a / b - - @_block_array_binary_op_wrapper - def __rtruediv__(a, b): - return b / a - - @_block_array_binary_op_wrapper - def __floordiv__(a, b): - return a // b - - @_block_array_binary_op_wrapper - def __rfloordiv__(a, b): - return b // a - - @_block_array_binary_op_wrapper - def __pow__(a, b): - return a**b - - @_block_array_binary_op_wrapper - def __rpow__(a, b): - return b**a - - @_block_array_binary_op_wrapper - def __gt__(a, b): - return a > b - - @_block_array_binary_op_wrapper - def __ge__(a, b): - return a >= b - - @_block_array_binary_op_wrapper - def __lt__(a, b): - return a < b - - @_block_array_binary_op_wrapper - def __le__(a, b): - return a <= b - - @_block_array_binary_op_wrapper - def __eq__(a, b): - return a == b - - @_block_array_binary_op_wrapper - def __ne__(a, b): - return a != b - - def __iter__(self) -> Iterator[int]: - for i in range(self.num_blocks): - yield self[i] - - @property - def blocks(self) -> Iterator[int]: - """Return an iterator yielding component blocks.""" - return self.__iter__() - - @property - def bndpos(self) -> np.ndarray: - """Array specifying boundaries of components as indices in base array.""" - return self._aval.bndpos - - @property - def dtype(self) -> DType: - """Array dtype.""" - return self._data.dtype - - @property - def device_buffer(self) -> Buffer: - """The :class:`jaxlib.xla_extension.Buffer` that backs the - underlying data array.""" - return self._data.device_buffer - - @property - def size(self) -> int: - """Total number of elements in the array.""" - return self._aval.size - - @property - def num_blocks(self) -> int: - """Number of :class:`.BlockArray` components.""" - - return len(self.shape) - - @property - def ndim(self) -> Shape: - """Tuple of component ndims.""" - - return tuple(len(c) for c in self.shape) - - @property - def shape(self) -> BlockShape: - """Tuple of component shapes.""" - - return self._aval.shapes - - @property - def split(self) -> Tuple[JaxArray, ...]: - """Tuple of component arrays.""" - - return tuple(self[k] for k in range(self.num_blocks)) - - def conj(self) -> BlockArray: - """Return a :class:`.BlockArray` with complex-conjugated elements.""" - - # Much faster than BlockArray.array([_.conj() for _ in self.blocks]) - return BlockArray.array_from_flattened(self.ravel().conj(), self.shape) - - @property - def real(self) -> BlockArray: - """Return a :class:`.BlockArray` with the real part of this array.""" - return BlockArray.array_from_flattened(self.ravel().real, self.shape) - - @property - def imag(self) -> BlockArray: - """Return a :class:`.BlockArray` with the imaginary part of this array.""" - return BlockArray.array_from_flattened(self.ravel().imag, self.shape) - - @classmethod - def array( - cls, alst: List[Union[np.ndarray, JaxArray]], dtype: Optional[np.dtype] = None - ) -> BlockArray: - """Construct a :class:`.BlockArray` from a list or tuple of existing array-like. - - Args: - alst: Initializers for array components. - Can be :class:`numpy.ndarray` or - :class:`jax.interpreters.xla.DeviceArray` - dtype: Data type of array. If none, dtype is derived from - dtype of initializers - - Returns: - :class:`.BlockArray` initialized from `alst` tuple - """ - - if isinstance(alst, (tuple, list)) is False: - raise TypeError("Input to `array` must be a list or tuple of existing arrays") - - if dtype is None: - present_types = jax.tree_flatten(jax.tree_map(lambda x: x.dtype, alst))[0] - dtype = np.find_common_type(present_types, []) - - # alst can be a list/tuple of arrays, or a list/tuple containing list/tuples of arrays - # consider alst to be a tree where leaves are arrays (possibly abstract arrays) - # use tree_map to find the shape of each leaf - # `shapes` will be a tuple of ints and tuples containing ints (possibly nested further) - - # ensure any scalar leaves are converted to (1,) arrays - def shape_atleast_1d(x): - return x.shape if x.shape != () else (1,) - - shapes = tuple( - jax.tree_map(shape_atleast_1d, alst, is_leaf=lambda x: not isinstance(x, (list, tuple))) - ) - - _aval = _AbstractBlockArray(shapes, dtype) - data_ravel = jnp.hstack(jax.tree_map(lambda x: x.ravel(), jax.tree_flatten(alst)[0])) - return cls(_aval, data_ravel) - - @classmethod - def array_from_flattened( - cls, data_ravel: Union[np.ndarray, JaxArray], shape_tuple: BlockShape - ) -> BlockArray: - """Construct a :class:`.BlockArray` from a flattened array and tuple of shapes. - - Args: - data_ravel: Flattened data array - shape_tuple: Tuple of tuples containing desired block shapes. - - Returns: - :class:`.BlockArray` initialized from `data_ravel` and `shape_tuple` - """ - - if not isinstance(data_ravel, DeviceArray): - data_ravel = jax.device_put(data_ravel) - - shape_tuple_size = np.sum(block_sizes(shape_tuple)) - - if shape_tuple_size != data_ravel.size: - raise ValueError( - f"""The specified shape_tuple is incompatible with provided data_ravel - shape_tuple = {shape_tuple} - shape_tuple_size = {shape_tuple_size} - len(data_ravel) = {len(data_ravel)} - """ - ) - - _aval = _AbstractBlockArray(shape_tuple, dtype=data_ravel.dtype) - return cls(_aval, data_ravel) - - @classmethod - def ones(cls, shape_tuple: BlockShape, dtype: DType = np.float32) -> BlockArray: - """ - Return a new :class:`.BlockArray` with given block shapes and type, filled with ones. - - Args: - shape_tuple: Tuple of shapes for component blocks - dtype: Desired data-type for the :class:`.BlockArray`. - Default is `numpy.float32`. - - Returns: - :class:`.BlockArray` of ones with the given component shapes - and dtype. - """ - _aval = _AbstractBlockArray(shape_tuple, dtype=dtype) - data_ravel = jnp.ones(_aval.size, dtype=dtype) - return cls(_aval, data_ravel) - - @classmethod - def zeros(cls, shape_tuple: BlockShape, dtype: DType = np.float32) -> BlockArray: - """ - Return a new :class:`.BlockArray` with given block shapes and type, filled with zeros. - - Args: - shape_tuple: Tuple of shapes for component blocks. - dtype: Desired data-type for the :class:`.BlockArray`. - Default is `numpy.float32`. - - Returns: - :class:`.BlockArray` of zeros with the given component shapes - and dtype. - """ - _aval = _AbstractBlockArray(shape_tuple, dtype=dtype) - data_ravel = jnp.zeros(_aval.size, dtype=dtype) - return cls(_aval, data_ravel) - - @classmethod - def empty(cls, shape_tuple: BlockShape, dtype: DType = np.float32) -> BlockArray: - """ - Return a new :class:`.BlockArray` with given block shapes and type, filled with zeros. - - Note: like :func:`jax.numpy.empty`, this does not return an - uninitalized array. - - Args: - shape_tuple: Tuple of shapes for component blocks - dtype: Desired data-type for the :class:`.BlockArray`. - Default is `numpy.float32`. - - Returns: - :class:`.BlockArray` of zeros with the given component shapes - and dtype. - """ - _aval = _AbstractBlockArray(shape_tuple, dtype=dtype) - data_ravel = jnp.empty(_aval.size, dtype=dtype) - return cls(_aval, data_ravel) - - @classmethod - def full( - cls, - shape_tuple: BlockShape, - fill_value: Union[float, complex, int], - dtype: DType = np.float32, - ) -> BlockArray: - """ - Return a new :class:`.BlockArray` with given block shapes and type, filled with - `fill_value`. - - Args: - shape_tuple: Tuple of shapes for component blocks. - fill_value: Fill value - dtype: Desired data-type for the BlockArray. The default, - None, means `np.array(fill_value).dtype`. - - Returns: - :class:`.BlockArray` with the given component shapes and - dtype and all entries equal to `fill_value`. - """ - if dtype is None: - dtype = np.asarray(fill_value).dtype - - _aval = _AbstractBlockArray(shape_tuple, dtype=dtype) - data_ravel = jnp.full(_aval.size, fill_value=fill_value, dtype=dtype) - return cls(_aval, data_ravel) - - def copy(self) -> BlockArray: - """Return a copy of this :class:`.BlockArray`. - - This method is not implemented for BlockArray. - - See Also: - :meth:`.to_numpy`: Convert a :class:`.BlockArray` into a - flattened NumPy array. - """ - # jax DeviceArray copies return a NumPy ndarray. This blockarray class must be backed - # by a DeviceArray, so cannot be converted to a NumPy-backed BlockArray. The BlockArray - # .to_numpy() method returns a flattened ndarray. - # - # This method may be implemented in the future if jax DeviceArray .copy() is modified to - # return another DeviceArray. - raise NotImplementedError - - def to_numpy(self) -> np.ndarray: - """Return a :class:`numpy.ndarray` containing the flattened form of this - :class:`.BlockArray`.""" - - if isinstance(self._data, DeviceArray): - host_arr = jax.device_get(self._data.copy()) - else: - host_arr = self._data.copy() - return host_arr - - def blockidx(self, idx: int) -> jax._src.ops.scatter._Indexable: - """Return :class:`jax.ops.index` for a given component block. - - Args: - idx: Desired block index. - - Returns: - :class:`jax.ops.index` pointing to desired block. - """ - return slice(self.bndpos[idx], self.bndpos[idx + 1]) - - def ravel(self) -> JaxArray: - """Return a copy of ``self._data`` as a contiguous, flattened `DeviceArray`. - - Note that a copy, rather than a view, of the underlying array is - returned. This is consistent with :func:`jax.numpy.ravel`. - - Returns: - Copy of underlying flattened array. - - """ - return self._data[:] - - def flatten(self) -> JaxArray: - """Return a copy of ``self._data`` as a contiguous, flattened `DeviceArray`. - - Note that a copy, rather than a view, of the underlying array is - returned. This is consistent with :func:`jax.numpy.ravel`. - - Returns: - Copy of underlying flattened array. - - """ - return self._data[:] - - def sum(self, axis=None, keepdims=False): - """Return the sum of the blockarray elements over the given axis. - - Refer to :func:`scico.numpy.sum` for full documentation. - """ - # Can't just call scico.numpy.sum due to pesky circular import... - return _block_array_reduction_wrapper(jnp.sum)(self, axis=axis, keepdims=keepdims) - - -## Register BlockArray as a Jax type -# Our BlockArray is just a single large vector with some extra sugar -class _ConcreteBlockArray(_AbstractBlockArray): - pass - - -def _block_array_result_handler(device, _aval): - def build_block_array(data_buf): - data = xla.DeviceArray(_aval._data_aval, device, None, data_buf) - return BlockArray(_aval, data) - - return build_block_array - - -def _block_array_shape_handler(a): - return (xla.xc.Shape.array_shape(a._data_aval.dtype, a._data_aval.shape),) - - -def _block_array_device_put_handler(a, device): - return (xla.xb.get_device_backend(device).buffer_from_pyval(a._data, device),) - - -core.pytype_aval_mappings[BlockArray] = lambda x: x._aval -core.raise_to_shaped_mappings[_AbstractBlockArray] = lambda _aval, _: _aval -xla.pytype_aval_mappings[BlockArray] = lambda x: x._aval -xla.canonicalize_dtype_handlers[BlockArray] = lambda x: x -jax._src.dispatch.device_put_handlers[BlockArray] = _block_array_device_put_handler -jax._src.dispatch.result_handlers[_AbstractBlockArray] = _block_array_result_handler -xla.xla_shape_handlers[_AbstractBlockArray] = _block_array_shape_handler - - -## Handlers to use jax.device_put on BlockArray -def _block_array_tree_flatten(block_arr): - """Flatten a :class:`.BlockArray` pytree. - - See :func:`jax.tree_util.tree_flatten`. - - Args: - block_arr (:class:`.BlockArray`): :class:`.BlockArray` to flatten - - Returns: - children (tuple): :class:`.BlockArray` leaves. - aux_data (tuple): Extra metadata used to reconstruct BlockArray. - """ - - data_children, data_aux_data = tree_flatten(block_arr._data) - return (data_children, block_arr._aval) - - -def _block_array_tree_unflatten(aux_data, children): - """Construct a :class:`.BlockArray` from a flattened pytree. - - See jax.tree_utils.tree_unflatten - - Args: - aux_data (tuple): Metadata needed to construct block array. - children (tuple): Contains block array elements. - - Returns: - block_arr: Constructed :class:`.BlockArray`. - """ - return BlockArray(aux_data, children[0]) - - -register_pytree_node(BlockArray, _block_array_tree_flatten, _block_array_tree_unflatten) - -# Syntactic sugar for the .at operations -# see https://github.com/google/jax/blob/56e9f7cb92e3a099adaaca161cc14153f024047c/jax/_src/numpy/lax_numpy.py#L5900 -class _BlockArrayIndexUpdateHelper: - """The helper class for the `at` property to call indexed update functions. - - The `at` property is syntactic sugar for calling the indexed update - functions as is done in jax. The index must be of the form [ibk] or - [ibk,idx], where `ibk` is the index of the block to be updated, and - `idx` is a general index of the elements to be updated in that block. - - In particular: - - ``x = x.at[ibk].set(y)`` is an equivalent of ``x[ibk] = y``. - - ``x = x.at[ibk,idx].set(y)`` is an equivalent of ``x[ibk,idx] = y``. - - The methods ``set, add, multiply, divide, power, maximum, minimum`` - are supported. - """ - - __slots__ = ("_block_array",) - - def __init__(self, block_array): - self._block_array = block_array - - def __getitem__(self, index): - if isinstance(index, tuple): - if isinstance(index[0], slice): - raise TypeError(f"Slicing not supported along block index") - return _BlockArrayIndexUpdateRef(self._block_array, index) - - def __repr__(self): - print(f"_BlockArrayIndexUpdateHelper({repr(self._block_array)})") - - -class _BlockArrayIndexUpdateRef: - """Helper object to call indexed update functions for an (advanced) index. - - This object references a source block array and a specific indexer, - with the first integer specifying the block being updated, and rest - being the indexer into the array of that block. Methods on this - object return copies of the source block array that have been - modified at the positions specified by the indexer in the given block. - """ - - __slots__ = ("_block_array", "bk_index", "index") - - def __init__(self, block_array, index): - self._block_array = block_array - if isinstance(index, int): - self.bk_index = index - self.index = Ellipsis - elif index == Ellipsis: - self.bk_index = Ellipsis - self.index = Ellipsis - else: - self.bk_index = index[0] - self.index = index[1:] - - def __repr__(self): - return f"_BlockArrayIndexUpdateRef({repr(self._block_array)}, {repr(self.bk_index)}, {repr(self.index)})" - - def _index_wrapper(self, func_str, values): - bk_index = self.bk_index - index = self.index - arr_tuple = self._block_array.split - if bk_index == Ellipsis: - # This may result in multiple copies: one per sub-blockarray, - # then one to combine into a nested BA. - retval = BlockArray.array([getattr(_.at[index], func_str)(values) for _ in arr_tuple]) - else: - retval = BlockArray.array( - arr_tuple[:bk_index] - + (getattr(arr_tuple[bk_index].at[index], func_str)(values),) - + arr_tuple[bk_index + 1 :] - ) - return retval - - def set(self, values): - """Pure equivalent of ``x[idx] = y``. - - Return the value of ``x`` that would result from the NumPy-style - :mod:indexed assignment ` ``x[idx] = y``. - - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("set", values) - - def add(self, values): - """Pure equivalent of ``x[idx] += y``. - - Return the value of ``x`` that would result from the NumPy-style - :mod:indexed assignment ` ``x[idx] += y``. - - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("add", values) - - def multiply(self, values): - """Pure equivalent of ``x[idx] *= y``. - - Return the value of ``x`` that would result from the NumPy-style - :mod:indexed assignment ` ``x[idx] *= y``. - - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("multiply", values) - - def divide(self, values): - """Pure equivalent of ``x[idx] /= y``. - - Return the value of ``x`` that would result from the NumPy-style - :mod:indexed assignment ` ``x[idx] /= y``. - - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("divide", values) - - def power(self, values): - """Pure equivalent of ``x[idx] **= y``. - - Return the value of ``x`` that would result from the NumPy-style - :mod:indexed assignment ` ``x[idx] **= y``. - - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("power", values) - - def min(self, values): - """Pure equivalent of ``x[idx] = minimum(x[idx], y)``. - - Return the value of ``x`` that would result from the NumPy-style - :mod:indexed assignment ` ``x[idx] = minimum(x[idx], y)``. - - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("min", values) - - def max(self, values): - """Pure equivalent of ``x[idx] = maximum(x[idx], y)``. - - Return the value of ``x`` that would result from the NumPy-style - :mod:indexed assignment ` ``x[idx] = maximum(x[idx], y)``. - - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("max", values) - - -setattr(BlockArray, "at", property(_BlockArrayIndexUpdateHelper)) diff --git a/scico/numpy/_create.py b/scico/numpy/_create.py deleted file mode 100644 index 9f51f14f3..000000000 --- a/scico/numpy/_create.py +++ /dev/null @@ -1,174 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (C) 2021-2022 by SCICO Developers -# All rights reserved. BSD 3-clause License. -# This file is part of the SCICO package. Details of the copyright and -# user license can be found in the 'LICENSE' file distributed with the -# package. - -"""Functions for creating new arrays.""" - -from typing import Union - -import numpy as np - -import jax -from jax import numpy as jnp - -from scico.numpy import BlockArray, is_nested -from scico.typing import BlockShape, DType, JaxArray, Shape - - -def zeros( - shape: Union[Shape, BlockShape], dtype: DType = np.float32 -) -> Union[JaxArray, BlockArray]: - """Return a new array of given shape and type, filled with zeros. - - If `shape` is a list of tuples, returns a BlockArray of zeros. - - Args: - shape: Shape of the new array. - dtype: Desired data-type of the array. Default is `np.float32`. - """ - if is_nested(shape): - return BlockArray.zeros(shape, dtype=dtype) - return jnp.zeros(shape, dtype=dtype) - - -def ones(shape: Union[Shape, BlockShape], dtype: DType = np.float32) -> Union[JaxArray, BlockArray]: - """Return a new array of given shape and type, filled with ones. - - If `shape` is a list of tuples, returns a BlockArray of ones. - - Args: - shape: Shape of the new array. - dtype: Desired data-type of the array. Default is `np.float32`. - """ - if is_nested(shape): - return BlockArray.ones(shape, dtype=dtype) - return jnp.ones(shape, dtype=dtype) - - -def empty( - shape: Union[Shape, BlockShape], dtype: DType = np.float32 -) -> Union[JaxArray, BlockArray]: - """Return a new array of given shape and type, filled with zeros. - - If `shape` is a list of tuples, returns a BlockArray of zeros. - - Args: - shape: Shape of the new array. - dtype: Desired data-type of the array. Default is `np.float32`. - """ - if is_nested(shape): - return BlockArray.empty(shape, dtype=dtype) - return jnp.empty(shape, dtype=dtype) - - -def full( - shape: Union[Shape, BlockShape], - fill_value: Union[float, complex], - dtype: DType = None, -) -> Union[JaxArray, BlockArray]: - """Return a new array of given shape and type, filled with `fill_value`. - - If `shape` is a list of tuples, returns a BlockArray filled with - `fill_value`. - - Args: - shape: Shape of the new array. - fill_value : Fill value. - dtype: Desired data-type of the array. The default, None, - means `np.array(fill_value).dtype`. - """ - if dtype is None: - dtype = jax.dtypes.canonicalize_dtype(type(fill_value)) - if is_nested(shape): - return BlockArray.full(shape, fill_value=fill_value, dtype=dtype) - return jnp.full(shape, fill_value=fill_value, dtype=dtype) - - -def zeros_like(x: Union[JaxArray, BlockArray], dtype=None): - """Return an array of zeros with same shape and type as a given array. - - If input is a BlockArray, returns a BlockArray of zeros with same - shape and type as a given array. - - Args: - x (array like): The shape and dtype of `x` define these - attributes on the returned array. - dtype (data-type, optional): Overrides the data type of the - result. - """ - if dtype is None: - dtype = jax.dtypes.canonicalize_dtype(x.dtype) - - if isinstance(x, BlockArray): - return BlockArray.zeros(x.shape, dtype=dtype) - return jnp.zeros_like(x, dtype=dtype) - - -def empty_like(x: Union[JaxArray, BlockArray], dtype: DType = None): - """Return an array of zeros with same shape and type as a given array. - - If input is a BlockArray, returns a BlockArray of zeros with same - shape and type as a given array. - - Note: like :func:`jax.numpy.empty_like`, this does not return an - uninitalized array. - - Args: - x (array like): The shape and dtype of `x` define these - attributes on the returned array. - dtype (data-type, optional): Overrides the data type of the - result. - """ - if dtype is None: - dtype = jax.dtypes.canonicalize_dtype(x.dtype) - - if isinstance(x, BlockArray): - return BlockArray.zeros(x.shape, dtype=dtype) - return jnp.zeros_like(x, dtype=dtype) - - -def ones_like(x: Union[JaxArray, BlockArray], dtype: DType = None): - """Return an array of ones with same shape and type as a given array. - - If input is a BlockArray, returns a BlockArray of ones with same - shape and type as a given array. - - Args: - x (array like): The shape and dtype of `x` define these - attributes on the returned array. - dtype (data-type, optional): Overrides the data type of the - result. - """ - if dtype is None: - dtype = jax.dtypes.canonicalize_dtype(x.dtype) - - if isinstance(x, BlockArray): - return BlockArray.ones(x.shape, dtype=dtype) - return jnp.ones_like(x, dtype=dtype) - - -def full_like( - x: Union[JaxArray, BlockArray], fill_value: Union[float, complex], dtype: DType = None -): - """Return an array filled with `fill_value`. - - Return an array of with same shape and type as a given array, filled - with `fill_value`. If input is a BlockArray, returns a BlockArray of - `fill_value` with same shape and type as a given array. - - Args: - x (array like): The shape and dtype of `x` define these - attributes on the returned array. - fill_value (scalar): Fill value. - dtype (data-type, optional): Overrides the data type of the - result. - """ - if dtype is None: - dtype = jax.dtypes.canonicalize_dtype(x.dtype) - - if isinstance(x, BlockArray): - return BlockArray.full(x.shape, fill_value=fill_value, dtype=dtype) - return jnp.full_like(x, fill_value=fill_value, dtype=dtype) diff --git a/scico/numpy/_util_old.py b/scico/numpy/_util_old.py deleted file mode 100644 index 6e81ccd10..000000000 --- a/scico/numpy/_util_old.py +++ /dev/null @@ -1,93 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (C) 2020-2022 by SCICO Developers -# All rights reserved. BSD 3-clause License. -# This file is part of the SCICO package. Details of the copyright and -# user license can be found in the 'LICENSE' file distributed with the -# package. - -"""Tools to construct wrapped versions of :mod:`jax.numpy` functions.""" - -import re -import types -from functools import wraps - -import numpy as np - -from jaxlib.xla_extension import CompiledFunction - -# wrapper for not-implemented jax.numpy functions -# stripped down version of jax._src.lax_numpy._not_implemented and jax.utils._wraps - -_NOT_IMPLEMENTED_DESC = """ -**WARNING**: This function is not yet implemented by :mod:`scico.numpy` and -may raise an error when operating on :class:`scico.numpy.BlockArray`. -""" - - -def _not_implemented(fun): - @wraps(fun) - def wrapped(*args, **kwargs): - return fun(*args, **kwargs) - - if not hasattr(fun, "__doc__") or fun.__doc__ is None: - return wrapped - - # wrapped.__doc__ = fun.__doc__ + "\n\n" + _NOT_IMPLEMENTED_DESC - wrapped.__doc__ = re.sub( - r"^\*Original docstring below\.\*", - _NOT_IMPLEMENTED_DESC + r"\n\n" + "*Original docstring below.*", - wrapped.__doc__, - flags=re.M, - ) - return wrapped - - -def _attach_wrapped_func(funclist, wrapper, module_name, fix_mod_name=False): - # funclist is either a function, or a tuple (name-in-this-module, function) - # wrapper is a function that is applied to each function in funclist, with - # the output being assigned as an attribute of the module `module_name` - for func in funclist: - # Test required because func.__name__ isn't always the name we want - # e.g. jnp.abs.__name__ resolves to 'absolute', not 'abs' - if isinstance(func, tuple): - fname = func[0] - fref = func[1] - else: - fname = func.__name__ - fref = func - # Set wrapped function as an attribute in module_name - setattr(module_name, fname, wrapper(fref)) - # Set __module__ attribute of wrapped function to - # module_name.__name__ (i.e., scico.numpy) so that it does not - # appear to autodoc be an imported function - if fix_mod_name: - getattr(module_name, fname).__module__ = module_name.__name__ - - -def _get_module_functions(module): - """Finds functions in module. - - This function is a slightly modified version of - :func:`jax._src.util.get_module_functions`. Unlike the JAX version, - this version will also return any - :class:`jaxlib.xla_extension.CompiledFunction`s that exist in the - module. - - Args: - module: A Python module. - Returns: - module_fns: A dict of names mapped to functions, builtins or - ufuncs in `module`. - """ - module_fns = {} - for key in dir(module): - # Omitting module level __getattr__, __dir__ which was added in Python 3.7 - # https://www.python.org/dev/peps/pep-0562/ - if key in ("__getattr__", "__dir__"): - continue - attr = getattr(module, key) - if isinstance( - attr, (types.BuiltinFunctionType, types.FunctionType, np.ufunc, CompiledFunction) - ): - module_fns[key] = attr - return module_fns diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py index aa43902e8..1acab4a2f 100644 --- a/scico/numpy/blockarray.py +++ b/scico/numpy/blockarray.py @@ -27,27 +27,27 @@ :: >>> x = BlockArray(( - [[1, 3, 7], - [2, 2, 1]], - [2, 4, 8] - )) - >>> x.shape - ((2, 3), (3,)) # tuple - - >>> x * 2 - (DeviceArray([[2, 6, 14], - [4, 4, 2]], dtype=int32), - DeviceArray([4, 8, 16], dtype=int32)) # BlockArray + ... [[1, 3, 7], + ... [2, 2, 1]], + ... [2, 4, 8] + ... )) + + >>> x.shape # returns tuple + ((2, 3), (3,)) + + >>> x * 2 # returns BlockArray + [DeviceArray([[ 2, 6, 14], + [ 4, 4, 2]], dtype=int32), DeviceArray([ 4, 8, 16], dtype=int32)] >>> y = BlockArray(( - [[.2], - [.3]], - [.4] - )) - >>> x + y + ... [[.2], + ... [.3]], + ... [.4] + ... )) + + >>> x + y # returns BlockArray [DeviceArray([[1.2, 3.2, 7.2], - [2.3, 2.3, 1.3]], dtype=float32), - DeviceArray([2.4, 4.4, 8.4], dtype=float32)] # BlockArray + [2.3, 2.3, 1.3]], dtype=float32), DeviceArray([2.4, 4.4, 8.4], dtype=float32)] NumPy Functions @@ -81,7 +81,6 @@ For a list of the extended functions, see `scico.numpy.testing_functions`. - Motivating Example ================== @@ -127,9 +126,6 @@ Constructing a BlockArray ========================= -Construct from a tuple of arrays (either `ndarray` or `DeviceArray`) --------------------------------------------------------------------- - .. doctest:: >>> from scico.numpy import BlockArray @@ -140,8 +136,8 @@ >>> X.shape ((32, 32), (16,)) >>> X.size - 1040 - >>> X.num_blocks + (1024, 16) + >>> len(X) 2 While :func:`.BlockArray.array` will accept either `ndarray` or @@ -155,18 +151,6 @@ single precision and will have dtype ``float32`` or ``complex64``. -Construct from a single vector and tuple of shapes --------------------------------------------------- - - :: - - >>> x_flat, _ = scico.random.randn((1040,)) - >>> shape_tuple = ((32, 32), (16,)) - >>> X = BlockArray.array_from_flattened(x_flat, shape_tuple=shape_tuple) - >>> X.shape - ((32, 32), (16,)) - - Operating on a BlockArray ========================= @@ -176,182 +160,10 @@ Indexing -------- -The block index is required to be an integer, selecting a single block and -returning it as an array (*not* a singleton BlockArray). If the index -expression has more than one component, then the initial index indexes the -block, and the remainder of the indexing expression indexes within the -selected block, e.g. `x[2, 3:4]` is equivalent to `y[3:4]` after -setting `y = x[2]`. - - -Indexed Updating ----------------- - -BlockArrays support the JAX DeviceArray `indexed update syntax -`_ - - -The index must be of the form [ibk] or [ibk, idx], where `ibk` is the -index of the block to be updated, and `idx` is a general index of the -elements to be updated in that block. In particular, `ibk` cannot be a -`slice`. The general index `idx` can be omitted, in which case an entire -block is updated. - - -============================== ============================================== -Alternate syntax Equivalent in-place expression -============================== ============================================== -`x.at[ibk, idx].set(y)` `x[ibk, idx] = y` -`x.at[ibk, idx].add(y)` `x[ibk, idx] += y` -`x.at[ibk, idx].multiply(y)` `x[ibk, idx] *= y` -`x.at[ibk, idx].divide(y)` `x[ibk, idx] /= y` -`x.at[ibk, idx].power(y)` `x[ibk, idx] **= y` -`x.at[ibk, idx].min(y)` `x[ibk, idx] = np.minimum(x[idx], y)` -`x.at[ibk, idx].max(y)` `x[ibk, idx] = np.maximum(x[idx], y)` -============================== ============================================== - - -Arithmetic and Broadcasting ---------------------------- - -Suppose :math:`\mb{x}` is a BlockArray with shape :math:`((n, n), (m,))`. - - :: - - >>> x1, key = scico.random.randn((4, 4)) - >>> x2, _ = scico.random.randn((5,), key=key) - >>> x = BlockArray.array( (x1, x2) ) - >>> x.shape - ((4, 4), (5,)) - >>> x.num_blocks - 2 - >>> x.size # 4*4 + 5 - 21 - -Illustrated for the operation `+`, but equally valid for operators -`+, -, *, /, //, **, <, <=, >, >=, ==` - - -Operations with BlockArrays with same number of blocks -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Let :math:`\mb{y}` be a BlockArray with the same number of blocks as -:math:`\mb{x}`. - - .. math:: - \mb{x} + \mb{y} - = - \begin{bmatrix} - \mb{x}[0] + \mb{y}[0] \\ - \mb{x}[1] + \mb{y}[1] \\ - \end{bmatrix} - -This operation depends on pair of blocks from :math:`\mb{x}` and -:math:`\mb{y}` being broadcastable against each other. - - - -Operations with a scalar -^^^^^^^^^^^^^^^^^^^^^^^^ - -The scalar is added to each element of the :class:`.BlockArray`: - - .. math:: - \mb{x} + 1 - = - \begin{bmatrix} - \mb{x}[0] + 1 \\ - \mb{x}[1] + 1\\ - \end{bmatrix} - - - :: - - >>> y = x + 1 - >>> np.testing.assert_allclose(y[0], x[0] + 1) - >>> np.testing.assert_allclose(y[1], x[1] + 1) - - - -Operations with a 1D `ndarray` of size equal to `num_blocks` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The *i*\th scalar is added to the *i*\th block of the -:class:`.BlockArray`: - - .. math:: - \mb{x} - + - \begin{bmatrix} - 1 \\ - 2 - \end{bmatrix} - = - \begin{bmatrix} - \mb{x}[0] + 1 \\ - \mb{x}[1] + 2\\ - \end{bmatrix} - - - :: - - >>> y = x + np.array([1, 2]) - >>> np.testing.assert_allclose(y[0], x[0] + 1) - >>> np.testing.assert_allclose(y[1], x[1] + 2) - - -Operations with an ndarray of `size` equal to :class:`.BlockArray` size -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -We first cast the `ndarray` to a BlockArray with same shape as -:math:`\mb{x}`, then apply the operation on the resulting BlockArrays. -With `y.size = x.size`, we have: +`BlockArray` indexing works just like indexing on a list. - .. math:: - \mb{x} - + - \mb{y} - = - \begin{bmatrix} - \mb{x}[0] + \mb{y}[0] \\ - \mb{x}[1] + \mb{y}[1]\\ - \end{bmatrix} - -Equivalently, the BlockArray is first flattened, then added to the -flattened `ndarray`, and the result is reformed into a BlockArray with -the same shape as :math:`\mb{x}` - - - -MatMul ------- - -Between two BlockArrays -^^^^^^^^^^^^^^^^^^^^^^^ - -The matmul is computed between each block of the two BlockArrays. - -The BlockArrays must have the same number of blocks, and each pair of -blocks must be broadcastable. - - .. math:: - \mb{x} @ \mb{y} - = - \begin{bmatrix} - \mb{x}[0] @ \mb{y}[0] \\ - \mb{x}[1] @ \mb{y}[1]\\ - \end{bmatrix} - - - -Between BlockArray and Ndarray/DeviceArray -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -This operation is not defined. - - -Between BlockArray and :class:`.LinearOperator` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Multiplication Between BlockArray and :class:`.LinearOperator` +-------------------------------------------------------------- The :class:`.Operator` and :class:`.LinearOperator` classes are designed to work on :class:`.BlockArray`\ s in addition to `DeviceArray`\ s. @@ -374,129 +186,6 @@ >>> A_3.shape # BlockArray -> BlockArray (((2, 4), (3, 3)), ((2, 4), (3, 3))) - -NumPy ufuncs ------------- - -`NumPy universal functions (ufuncs) `_ -are functions that operate on an `ndarray` on an element-by-element -fashion and support array broadcasting. Examples of ufuncs are `abs`, -`sign`, `conj`, and `exp`. - -The JAX library implements most NumPy ufuncs in the :mod:`jax.numpy` -module. However, as JAX does not support subclassing of `DeviceArray`, -the JAX ufuncs cannot be used on :class:`.BlockArray`. As a workaround, -we have wrapped several JAX ufuncs for use on :class:`.BlockArray`; these -are defined in the :mod:`scico.numpy` module. - - -Reductions -^^^^^^^^^^ - -Reductions are functions that take an array-like as an input and return -an array of lower dimension. Examples include `mean`, `sum`, `norm`. -BlockArray reductions are located in the :mod:`scico.numpy` module - -:class:`.BlockArray` tries to mirror `ndarray` reduction semantics where -possible, but cannot provide a one-to-one match as the block components -may be of different size. - -Consider the example BlockArray - - .. math:: - \mb{x} = \begin{bmatrix} - \begin{bmatrix} - 1 & 1 \\ - 1 & 1 - \end{bmatrix} \\ - \begin{bmatrix} - 2 \\ - 2 - \end{bmatrix} - \end{bmatrix}. - -We have - - .. doctest:: - - >>> import scico.numpy as snp - >>> x = BlockArray.array((np.ones((2,2)), 2*np.ones((2)))) - >>> x.shape - ((2, 2), (2,)) - >>> x.size - 6 - >>> x.num_blocks - 2 - - - - If no axis is specified, the reduction is applied to the flattened - array: - - .. doctest:: - - >>> snp.sum(x, axis=None).item() - 8.0 - - Reducing along the 0-th axis crushes the `BlockArray` down into a - single `DeviceArray` and requires all blocks to have the same shape - otherwise, an error is raised. - - .. doctest:: - - >>> snp.sum(x, axis=0) - Traceback (most recent call last): - ValueError: Evaluating sum of BlockArray along axis=0 requires all blocks to be same shape; got ((2, 2), (2,)) - - >>> y = BlockArray.array((np.ones((2,2)), 2*np.ones((2, 2)))) - >>> snp.sum(y, axis=0) - DeviceArray([[3., 3.], - [3., 3.]], dtype=float32) - - Reducing along axis :math:`n` is equivalent to reducing each component - along axis :math:`n-1`: - - .. math:: - \text{sum}(x, axis=1) = \begin{bmatrix} - \begin{bmatrix} - 2 \\ - 2 - \end{bmatrix} \\ - \begin{bmatrix} - 4 \\ - \end{bmatrix} - \end{bmatrix} - - - If a component does not have axis :math:`n-1`, the reduction is not - applied to that component. In this example, `x[1].ndim == 1`, so no - reduction is applied to block `x[1]`. - - .. math:: - \text{sum}(x, axis=2) = \begin{bmatrix} - \begin{bmatrix} - 2 \\ - 2 - \end{bmatrix} \\ - \begin{bmatrix} - 2 \\ - 2 - \end{bmatrix} - \end{bmatrix} - - -Code version - - .. doctest:: - - >>> snp.sum(x, axis=1) # doctest: +SKIP - BlockArray([[2, 2], - [4,] ]) - - >>> snp.sum(x, axis=2) # doctest: +SKIP - BlockArray([ [2, 2], - [2,] ]) - """ import inspect diff --git a/scico/ray/tune.py b/scico/ray/tune.py index 4b174f9bc..726e6baac 100644 --- a/scico/ray/tune.py +++ b/scico/ray/tune.py @@ -18,7 +18,7 @@ import ray.tune except ImportError: raise ImportError("Could not import ray.tune; please install it.") -from ray.tune import grid_search, loguniform, report, uniform # noqa +from ray.tune import loguniform, report, uniform # noqa from ray.tune.progress_reporter import TuneReporterBase, _get_trials_by_state from ray.tune.schedulers import AsyncHyperBandScheduler from ray.tune.suggest.hyperopt import HyperOptSearch From c61cbe64a8e6dfd649f923d6297537da2d934837 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Tue, 19 Apr 2022 07:35:38 -0600 Subject: [PATCH 26/37] Update example scripts --- examples/scripts/denoise_tv_iso_pgm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/scripts/denoise_tv_iso_pgm.py b/examples/scripts/denoise_tv_iso_pgm.py index 34f29ad2e..9a6008490 100644 --- a/examples/scripts/denoise_tv_iso_pgm.py +++ b/examples/scripts/denoise_tv_iso_pgm.py @@ -122,9 +122,9 @@ def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray: x_out = v / jnp.maximum(jnp.ones(v.shape), norm_v_ptp) out1 = v[0, :, -1] / jnp.maximum(jnp.ones(v[0, :, -1].shape), jnp.abs(v[0, :, -1])) - x_out_1 = jax.ops.index_update(x_out, jax.ops.index[0, :, -1], out1) + x_out = x_out.at[0, :, -1].set(out1) out2 = v[1, -1, :] / jnp.maximum(jnp.ones(v[1, -1, :].shape), jnp.abs(v[1, -1, :])) - x_out = jax.ops.index_update(x_out_1, jax.ops.index[1, -1, :], out2) + x_out = x_out.at[1, -1, :].set(out2) return x_out From 8d97d1f9b90f7d90c977c9ac742ac82b42ef38bf Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Tue, 19 Apr 2022 18:18:41 -0600 Subject: [PATCH 27/37] Work on docs --- scico/numpy/__init__.py | 36 +-- ...on_lists.py => _wrapped_function_lists.py} | 0 scico/numpy/{_util.py => _wrappers.py} | 0 scico/numpy/blockarray.py | 88 +++--- scico/random.py | 2 +- scico/scipy/special.py | 24 +- scico/test/test_blockarray.py | 259 +----------------- scico/test/test_numpy.py | 174 ------------ 8 files changed, 79 insertions(+), 504 deletions(-) rename scico/numpy/{function_lists.py => _wrapped_function_lists.py} (100%) rename scico/numpy/{_util.py => _wrappers.py} (100%) diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index e3cee1fd1..a0cb17330 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -5,45 +5,47 @@ # user license can be found in the 'LICENSE' file distributed with the # package. -""":class:`scico.numpy.BlockArray`-compatible -versions of :mod:`jax.numpy` functions. - -This modules consists of functions from :mod:`jax.numpy` wrapped to -support compatibility with :class:`scico.numpy.BlockArray`. This -module is a work in progress and therefore not all functions are -wrapped. Functions that have not been wrapped yet have WARNING text in -their documentation, below. +""":class:`.BlockArray` and functions for working with them alongside +:class:`DeviceArray` . + +This module consists :class:`.BlockArray` and functions for working with +it alongside class:`DeviceArray`. The latter include all the functions +from :mod:`jax.numpy` and :mod:`numpy.testing`, where many have been +extended to automatically map over block array blocks as described in +:mod:`scico.numpy.blockarray`. Also included are additional functions +unique to SCICO in :mod:`.util`. + """ import numpy as np import jax.numpy as jnp -from . import _util +from . import _wrappers +from ._wrapped_function_lists import * from .blockarray import BlockArray -from .function_lists import * from .util import * # copy most of jnp without wrapping -_util.add_attributes( +_wrappers.add_attributes( to_dict=vars(), from_dict=jnp.__dict__, modules_to_recurse=("linalg", "fft"), ) # wrap jnp funcs -_util.wrap_recursively(vars(), creation_routines, _util.map_func_over_tuple_of_tuples) -_util.wrap_recursively(vars(), mathematical_functions, _util.map_func_over_blocks) -_util.wrap_recursively(vars(), reduction_functions, _util.add_full_reduction) +_wrappers.wrap_recursively(vars(), creation_routines, _wrappers.map_func_over_tuple_of_tuples) +_wrappers.wrap_recursively(vars(), mathematical_functions, _wrappers.map_func_over_blocks) +_wrappers.wrap_recursively(vars(), reduction_functions, _wrappers.add_full_reduction) # copy np.testing -_util.add_attributes( +_wrappers.add_attributes( to_dict=vars(), from_dict={"testing": np.testing}, modules_to_recurse=("testing",), ) # wrap testing funcs -_util.wrap_recursively(vars(), testing_functions, _util.map_func_over_blocks) +_wrappers.wrap_recursively(vars(), testing_functions, _wrappers.map_func_over_blocks) # clean up -del np, jnp, _util +del np, jnp, _wrappers diff --git a/scico/numpy/function_lists.py b/scico/numpy/_wrapped_function_lists.py similarity index 100% rename from scico/numpy/function_lists.py rename to scico/numpy/_wrapped_function_lists.py diff --git a/scico/numpy/_util.py b/scico/numpy/_wrappers.py similarity index 100% rename from scico/numpy/_util.py rename to scico/numpy/_wrappers.py diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py index 1acab4a2f..7ceca2a8a 100644 --- a/scico/numpy/blockarray.py +++ b/scico/numpy/blockarray.py @@ -18,9 +18,9 @@ The class :class:`.BlockArray` provides a way to combine arrays of different shapes into a single object for use with other SCICO classes. A :class:`.BlockArray` consists of a list of `DeviceArray` objects, -which we refer to as blocks. :class:`.BlockArray`s differ from lists in +which we refer to as blocks. :class:`.BlockArray` s differ from lists in that, whenever possible, :class:`.BlockArray` properties and methods -(including unary and binary operators like +, -, *, ...) automatically +(including unary and binary operators like +, -, \*, ...) automatically map along the blocks, returning another :class:`.BlockArray` or tuple as appropriate. For example, @@ -50,35 +50,49 @@ [2.3, 2.3, 1.3]], dtype=float32), DeviceArray([2.4, 4.4, 8.4], dtype=float32)] -NumPy Functions -=============== +NumPy and SciPy Functions +========================= + +:mod:`scico.numpy`, :mod:`scico.numpy.testing`, and +:mod:`scico.scipy.special` provide wrappers around :mod:`jax.numpy`, +:mod:`numpy.testing` and :mod:`jax.scipy.special` where many of the +functions have been extended to work with `BlockArray` s. In particular: -:mod:`scico.numpy` provides a wrapper around :mod:`jax.numpy` where many -of the functions have been extended to work with `BlockArray`s. In -particular: + * When a tuple of tuples is passed as the `shape` + argument to an array creation routine, a `BlockArray` is created. + * When a `BlockArray` is passed to a reduction function, the blocks are + ravelled (i.e., reshaped to be 1D) and concatenated before the reduction + is applied. This behavior may be prevented by passing the `axis` + argument, in which case the function is mapped over the blocks. + * When one or more `BlockArray`s is passed to a mathematical + function that is not a reduction, the function is mapped over + (corresponding) blocks. -* When a tuple of tuples is passed as the `shape` -argument to an array creation routine, a `BlockArray` is created. +For a list of array creation routines, see -* When a `BlockArray` is passed to a reduction function, the blocks are -ravelled (i.e., reshaped to be 1D) and concatenated before the reduction -is applied. This behavior may be prevented by passing the `axis` -argument, in which case the function is mapped over the blocks. + :: -* When one or more `BlockArray`s is passed to a mathematical -function that is not a reduction, the function is mapped over -(corresponding) blocks. + >>> scico.numpy.creation_routines # doctest: +ELLIPSIS + ('empty', ...) -For lists of array creation routines, reduction functions, and mathematical -functions that have been wrapped in this manner, -see `scico.numpy.creation_routines`, `scico.numpy.reduction_fuctions`, -and -`scico.numpy.mathematical_functions`. +For a list of reduction functions, see + + :: + + >>> scico.numpy.reduction_functions # doctest: +ELLIPSIS + ('sum', ...) + +For lists of the remaining wrapped functions, see + + :: -:mod:`scico.numpy.testing` provides a wrapper around :mod:`numpy.testing` -where some functions have been extended to map over blocks, -notably `scico.numpy.testing.allclose`. -For a list of the extended functions, see `scico.numpy.testing_functions`. + >>> scico.numpy.mathematical_functions # doctest: +ELLIPSIS + ('sin', ...) + >>> scico.numpy.testing_functions # doctest: +ELLIPSIS + ('testing.assert_allclose', ...) + >>> import scico.scipy + >>> scico.scipy.special.functions # doctest: +ELLIPSIS + ('betainc', ...) Motivating Example @@ -126,7 +140,7 @@ Constructing a BlockArray ========================= - .. doctest:: + :: >>> from scico.numpy import BlockArray >>> import numpy as np @@ -147,11 +161,6 @@ **Note**: constructing a :class:`.BlockArray` always involves a copy to a new `DeviceArray` memory buffer. -**Note**: by default, the resulting :class:`.BlockArray` is cast to -single precision and will have dtype ``float32`` or ``complex64``. - - - Operating on a BlockArray ========================= @@ -160,7 +169,7 @@ Indexing -------- -`BlockArray` indexing works just like indexing on a list. +`BlockArray` indexing works just like indexing a list. Multiplication Between BlockArray and :class:`.LinearOperator` -------------------------------------------------------------- @@ -197,7 +206,7 @@ from jaxlib.xla_extension import DeviceArray -from .function_lists import binary_ops, unary_ops +from ._wrapped_function_lists import binary_ops, unary_ops class BlockArray(list): @@ -218,18 +227,6 @@ def __init__(self, inputs): return super().__init__(arrays) - def _full_ravel(self) -> DeviceArray: - """Return a copy of ``self._data`` as a contiguous, flattened `DeviceArray`. - - Note that a copy, rather than a view, of the underlying array is - returned. This is consistent with :func:`jax.numpy.ravel`. - - Returns: - Copy of underlying flattened array. - - """ - return jnp.concatenate(tuple(x_i.ravel() for x_i in self)) - @property def dtype(self): """Allow snp.zeros(x.shape, x.dtype) to work.""" @@ -258,6 +255,7 @@ def array(iterable): lambda _, xs: BlockArray(xs), # from iter ) + """ Wrap unary ops like -x. """ diff --git a/scico/random.py b/scico/random.py index 08ab8f472..700ed82f8 100644 --- a/scico/random.py +++ b/scico/random.py @@ -59,7 +59,7 @@ import jax from scico.numpy import BlockArray -from scico.numpy._util import map_func_over_tuple_of_tuples +from scico.numpy._wrappers import map_func_over_tuple_of_tuples from scico.typing import BlockShape, DType, JaxArray, PRNGKey, Shape diff --git a/scico/scipy/special.py b/scico/scipy/special.py index cc5e11749..75a46da6d 100644 --- a/scico/scipy/special.py +++ b/scico/scipy/special.py @@ -5,25 +5,24 @@ # user license can be found in the 'LICENSE' file distributed with the # package. -"""Wrapped versions of :mod:`jax.scipy.special` functions. +""":class:`scico.numpy.BlockArray`-compatible :mod:`jax.scipy.special` +functions. -This modules consists of functions from :mod:`jax.scipy.special`. Some of -these functions are wrapped to support compatibility with -:class:`scico.numpy.BlockArray` and are documented here. The -remaining functions are imported directly from :mod:`jax.numpy`. While -they can be imported from the :mod:`scico.numpy` namespace, they are not -documented here; please consult the documentation for the source module -:mod:`jax.scipy.special`. +This modules is a wrapper for :mod:`jax.scipy.special` where some +functions have been extended to automatically map over block array +blocks as described in :class:`scico.numpy.BlockArray` """ import jax.scipy.special as js -from scico.numpy import _util +from scico.numpy import _wrappers -_util.add_attributes( +# add most everything in jax.scipy.special to this module +_wrappers.add_attributes( vars(), js.__dict__, ) +# wrap select functions functions = ( "betainc", "entr", @@ -51,6 +50,7 @@ "zeta", "digamma", ) +_wrappers.wrap_recursively(vars(), functions, _wrappers.map_func_over_blocks) - -_util.wrap_recursively(vars(), functions, _util.map_func_over_blocks) +# clean up +del js, _wrappers diff --git a/scico/test/test_blockarray.py b/scico/test/test_blockarray.py index 067e781d8..50e186f91 100644 --- a/scico/test/test_blockarray.py +++ b/scico/test/test_blockarray.py @@ -4,8 +4,6 @@ import numpy as np import jax -import jax.numpy as jnp -from jax.interpreters.xla import DeviceArray import pytest @@ -77,75 +75,6 @@ def test_operator_right(test_operator_obj, operator): snp.testing.assert_allclose(x, y) -# Operations between a blockarray and a flat DeviceArray -@pytest.mark.skip # do we want to allow ((3,4), (4, 5, 6)) + (132,) ? -# argument against: numpy doesn't allow (3, 4) + (12,) -@pytest.mark.parametrize("operator", math_ops + comp_ops) -def test_ba_da_left(test_operator_obj, operator): - flat_da = test_operator_obj.flat_da - a = test_operator_obj.a - x = operator(flat_da, a) - y = BlockArray(operator(flat_da, a_i) for a_i in a) - snp.testing.assert_allclose(x, y, rtol=5e-5) - - -@pytest.mark.skip # see previous -@pytest.mark.parametrize("operator", math_ops + comp_ops) -def test_ba_da_right(test_operator_obj, operator): - flat_da = test_operator_obj.flat_da - a = test_operator_obj.a - x = operator(a, flat_da) - y = BlockArray(operator(a_i, flat_da) for a_i in a) - np.testing.assert_allclose(x, y) - - -# Blockwise comparison between a BlockArray and Ndarray -@pytest.mark.skip # do we want to allow ((3,4), (4, 5, 6)) + (2,) ? -# argument against numpy doesn't allow (3, 4) + (3,), though leading dims match -@pytest.mark.parametrize("operator", math_ops + comp_ops) -def test_ndarray_left(test_operator_obj, operator): - a = test_operator_obj.a - block_nd = test_operator_obj.block_nd - - x = operator(a, block_nd) - y = BlockArray([operator(a[i], block_nd[i]) for i in range(len(a))]) - snp.testing.assert_allclose(x, y) - - -@pytest.mark.skip # see previous -@pytest.mark.parametrize("operator", math_ops + comp_ops) -def test_ndarray_right(test_operator_obj, operator): - a = test_operator_obj.a - block_nd = test_operator_obj.block_nd - - x = operator(block_nd, a) - y = BlockArray([operator(block_nd[i], a[i]) for i in range(len(a))]) - snp.testing.assert_allclose(x, y) - - -# Blockwise comparison between a BlockArray and DeviceArray -@pytest.mark.skip # see previous -@pytest.mark.parametrize("operator", math_ops + comp_ops) -def test_devicearray_left(test_operator_obj, operator): - a = test_operator_obj.a - block_da = test_operator_obj.block_da - - x = operator(a, block_da) - y = BlockArray([operator(a[i], block_da[i]) for i in range(len(a))]) - snp.testing.assert_allclose(x, y) - - -@pytest.mark.skip # see previous -@pytest.mark.parametrize("operator", math_ops + comp_ops) -def test_devicearray_right(test_operator_obj, operator): - a = test_operator_obj.a - block_da = test_operator_obj.block_da - - x = operator(block_da, a) - y = BlockArray([operator(block_da[i], a[i]) for i in range(len(a))]) - snp.testing.assert_allclose(x, y, atol=1e-7, rtol=0) - - # Operations between two blockarrays of same size @pytest.mark.parametrize("operator", math_ops + comp_ops) def test_ba_ba_operator(test_operator_obj, operator): @@ -225,30 +154,6 @@ def test_getitem(test_operator_obj): np.testing.assert_allclose(x[-1], b1) -@pytest.mark.skip() -# this is indexing block dimension and internal dimensions simultaneously -# supporting it adds complexity, are we okay with just x[0][1:3] instead of x[0, 1:3]? -@pytest.mark.parametrize("index", (np.s_[0, 0], np.s_[0, 1:3], np.s_[0, :, 0:2], np.s_[0, ..., 2:])) -def test_getitem_tuple(test_operator_obj, index): - a = test_operator_obj.a - a0 = test_operator_obj.a0 - np.testing.assert_allclose(a[index], a0[index[1:]]) - - -@pytest.mark.skip() -# `.blockidx` was an index into the underlying 1D array that no longer exists -def test_blockidx(test_operator_obj): - a = test_operator_obj.a - a0 = test_operator_obj.a0 - a1 = test_operator_obj.a1 - - # use the blockidx to index the flattened data - x0 = a.full_ravel()[a.blockidx(0)] - x1 = a.full_ravel()[a.blockidx(1)] - np.testing.assert_allclose(x0, a0.full_ravel()) - np.testing.assert_allclose(x1, a1.full_ravel()) - - def test_split(test_operator_obj): a = test_operator_obj.a np.testing.assert_allclose(a[0], test_operator_obj.a0) @@ -257,8 +162,9 @@ def test_split(test_operator_obj): @pytest.mark.skip() # currently creation is exactly like a tuple, -# so BlockArray(np.jnp.zeros((32,32))) makes a block array -# with 32 1d blocks +# so BlockArray(np.jnp.zeros((3,6))) makes a block array +# with 3 length-6 blocks +# TODO replace with test of new behavior def test_blockarray_from_one_array(): with pytest.raises(TypeError): BlockArray(np.random.randn(32, 32)) @@ -275,22 +181,6 @@ def test_sum_method(test_operator_obj, axis, keepdims): snp.testing.assert_allclose(method_result, snp_result) -@pytest.mark.skip() -# previously vdot returned a scalar, -# in this proposal, it acts blockwize -def test_ba_ba_vdot(test_operator_obj): - a = test_operator_obj.a - d = test_operator_obj.d - a0 = test_operator_obj.a0 - a1 = test_operator_obj.a1 - d0 = test_operator_obj.d0 - d1 = test_operator_obj.d1 - - x = snp.vdot(a, d) - y = jnp.vdot(a.full_ravel(), d.full_ravel()) - np.testing.assert_allclose(x, y) - - @pytest.mark.parametrize("operator", [snp.dot, snp.matmul]) def test_ba_ba_dot(test_operator_obj, operator): a = test_operator_obj.a @@ -357,28 +247,6 @@ def test_reduce(reduction_obj, func): np.testing.assert_allclose(x, y, rtol=1e-6) # test for correctness -@pytest.mark.skip -# this is reduction along the block axis, which (in the old version) -# requires all blocks to be the same shape. If you know all blocks are the same shape, -# why use a block array? -@pytest.mark.parametrize(**REDUCTION_PARAMS) -def test_reduce_axis0_old(reduction_obj, func): - f = lambda x: func(x, axis=0) - x = f(reduction_obj.b) - x_jit = jax.jit(f)(reduction_obj.b) - - np.testing.assert_allclose(x, x_jit, rtol=1e-4) # test jitted function - - # test for correctness - # stack into a (2, 3, 4) array, call func - y = func(np.stack(list(reduction_obj.b)), axis=0) - np.testing.assert_allclose(x, y) - - with pytest.raises(ValueError): - # Reduction along axis=0 only works if all blocks are same shape - func(reduction_obj.a, axis=0) - - @pytest.mark.parametrize(**REDUCTION_PARAMS) @pytest.mark.parametrize("axis", (0, 1)) def test_reduce_axis(reduction_obj, func, axis): @@ -449,129 +317,10 @@ def test_full_nodtype(self): assert snp.all(x == fill_value) -@pytest.mark.skip -# it no longer makes sense to make a BlockArray from a flattened array -def test_incompatible_shapes(): - # Verify that array_from_flattened raises exception when - # len(data_ravel) != size determined by shape_tuple - shape_tuple = ((32, 32), (16,)) # len == 1040 - data_ravel = np.ones(1030) - with pytest.raises(ValueError): - BlockArray.array_from_flattened(data_ravel=data_ravel, shape_tuple=shape_tuple) - - -class NestedTestObj: - operators = math_ops + comp_ops - - def __init__(self, dtype): - key = None - scalar, key = randn(shape=(1,), dtype=dtype, key=key) - self.scalar = scalar.item() # convert to float - - self.a00, key = randn(shape=(2, 2, 2), dtype=dtype, key=key) - self.a01, key = randn(shape=(3, 2, 4), dtype=dtype, key=key) - self.a1, key = randn(shape=(2, 4), dtype=dtype, key=key) - - self.a = BlockArray(((self.a00, self.a01), self.a1)) - - -@pytest.fixture(scope="module") -def nested_obj(request): - yield NestedTestObj(request.param) - - -@pytest.mark.skip # deeply nested shapes no longer allowed -@pytest.mark.parametrize("nested_obj", [np.float32, np.complex64], indirect=True) -def test_nested_shape(nested_obj): - a = nested_obj.a - - a00 = nested_obj.a00 - a01 = nested_obj.a01 - a1 = nested_obj.a1 - - assert a.shape == (((2, 2, 2), (3, 2, 4)), (2, 4)) - assert a.size == 2 * 2 * 2 + 3 * 2 * 4 + 2 * 4 - - assert a[0].shape == ((2, 2, 2), (3, 2, 4)) - assert a[1].shape == (2, 4) - - snp.testing.assert_allclose(a[0][0], a00) - snp.testing.assert_allclose(a[0][1], a01) - snp.testing.assert_allclose(a[1], a1) - - # basic test for block_sizes - assert a.shape == (a[0].size, a[1].size) - - -NESTED_REDUCTION_PARAMS = dict( - argnames="nested_obj, func", - argvalues=( - list(zip(itertools.repeat(np.float32), reduction_funcs)) - + list(zip(itertools.repeat(np.complex64), reduction_funcs)) - + list(zip(itertools.repeat(np.float32), real_reduction_funcs)) - ), - indirect=["nested_obj"], -) - - -@pytest.mark.skip # deeply nested shapes no longer allowed -@pytest.mark.parametrize(**NESTED_REDUCTION_PARAMS) -def test_nested_reduce_singleton(nested_obj, func): - a = nested_obj.a - x = func(a) - y = func(a.full_ravel()) - np.testing.assert_allclose(x, y, rtol=5e-5) - - -@pytest.mark.skip # deeply nested shapes no longer allowed -@pytest.mark.parametrize(**NESTED_REDUCTION_PARAMS) -def test_nested_reduce_axis1(nested_obj, func): - a = nested_obj.a - - with pytest.raises(ValueError): - # Blocks don't conform! - x = func(a, axis=1) - - -@pytest.mark.skip # deeply nested shapes no longer allowed -@pytest.mark.parametrize(**NESTED_REDUCTION_PARAMS) -def test_nested_reduce_axis2(nested_obj, func): - a = nested_obj.a - - x = func(a, axis=2) - assert x.shape == (((2, 2), (2, 4)), (2,)) - - y = BlockArray((func(a[0], axis=1), func(a[1], axis=1))) - assert x.shape == y.shape - - np.testing.assert_allclose(x.full_ravel(), y.full_ravel(), rtol=5e-5) - - -@pytest.mark.skip # deeply nested shapes no longer allowed -@pytest.mark.parametrize(**NESTED_REDUCTION_PARAMS) -def test_nested_reduce_axis3(nested_obj, func): - a = nested_obj.a - - x = func(a, axis=3) - assert x.shape == (((2, 2), (3, 4)), (2, 4)) - - y = BlockArray((func(a[0], axis=2), a[1])) - assert x.shape == y.shape - - np.testing.assert_allclose(x.full_ravel(), y.full_ravel(), rtol=5e-5) - - -@pytest.mark.skip -# no longer makes sense to make BlockArray from 1d array -def test_array_from_flattened(): - x = np.random.randn(19) - x_b = ba.BlockArray.array_from_flattened(x, shape_tuple=((4, 4), (3,))) - assert isinstance(x_b._data, DeviceArray) - - @pytest.mark.skip # indexing now works just like a list of DeviceArrays: # x[1] = x[1].at[:].set(0) +# TODO: some of these are new syntax? class TestBlockArrayIndex: def setup_method(self): key = None diff --git a/scico/test/test_numpy.py b/scico/test/test_numpy.py index dc4b14be3..fc09827f2 100644 --- a/scico/test/test_numpy.py +++ b/scico/test/test_numpy.py @@ -3,11 +3,7 @@ import jax from jax.interpreters.xla import DeviceArray -import pytest - import scico.numpy as snp -import scico.numpy.linalg as sla -from scico.linop import MatrixOperator from scico.numpy import BlockArray @@ -36,176 +32,6 @@ def test_reshape_array(): np.testing.assert_allclose(snp.reshape(a.ravel(), (4, 4)), a) -@pytest.mark.skip # no reshaping into a BlockArray -def test_reshape_array(): - a = np.random.randn(13) - b = snp.reshape(a, ((3, 3), (4,))) - - c = BlockArray.array_from_flattened(a, ((3, 3), (4,))) - - assert isinstance(b, BlockArray) - assert b.shape == c.shape - np.testing.assert_allclose(b.ravel(), c.ravel()) - - -@pytest.mark.skip # do we care to support svd of matrix operator? -@pytest.mark.parametrize("compute_uv", [True, False]) -@pytest.mark.parametrize("full_matrices", [True, False]) -@pytest.mark.parametrize("shape", [(8, 8), (4, 8), (8, 4)]) -def test_svd(compute_uv, full_matrices, shape): - A = jax.device_put(np.random.randn(*shape)) - Ao = MatrixOperator(A) - f = lambda x: sla.svd(x, compute_uv=compute_uv, full_matrices=full_matrices) - check_results(f(A), f(Ao)) - - -@pytest.mark.skip # do we care to support cond of matrix operator? -def test_cond(): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = sla.cond - check_results(f(A), f(Ao)) - - -@pytest.mark.skip # do we care to support det of matrix operator? -def test_det(): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = sla.det - check_results(f(A), f(Ao)) - - -@pytest.mark.skip # do we care to support eig of matrix operator? -@pytest.mark.skipif( - on_cpu() == False, reason="nonsymmetric eigendecompositions only supported on cpu" -) -def test_eig(): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = sla.eig - check_results(f(A), f(Ao)) - - -@pytest.mark.skip # do we care to support eigh of matrix operator? -@pytest.mark.parametrize("symmetrize", [True, False]) -@pytest.mark.parametrize("UPLO", [None, "L", "U"]) -def test_eigh(UPLO, symmetrize): - A = jax.device_put(np.random.randn(8, 8)) - A = A.T @ A - Ao = MatrixOperator(A) - f = lambda x: sla.eigh(x, UPLO=UPLO, symmetrize_input=symmetrize) - check_results(f(A), f(Ao)) - - -@pytest.mark.skip # do we care to support...? -@pytest.mark.skipif( - on_cpu() == False, reason="nonsymmetric eigendecompositions only supported on cpu" -) -def test_eigvals(): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = sla.eigvals - check_results(f(A), f(Ao)) - - -@pytest.mark.skip # do we care to support...? -@pytest.mark.parametrize("UPLO", [None, "L", "U"]) -def test_eigvalsh(UPLO): - A = jax.device_put(np.random.randn(8, 8)) - A = A.T @ A - Ao = MatrixOperator(A) - f = lambda x: sla.eigvalsh(x, UPLO=UPLO) - check_results(f(A), f(Ao)) - - -@pytest.mark.skip # do we care to support...? -def test_inv(): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = sla.inv - check_results(f(A), f(Ao)) - - -@pytest.mark.skip # do we care to support...? -def test_lstsq(): - A = jax.device_put(np.random.randn(8, 8)) - b = jax.device_put(np.random.randn(8)) - Ao = MatrixOperator(A) - f = lambda A: sla.lstsq(A, b) - check_results(f(A), f(Ao)) - - -@pytest.mark.skip # do we care to support...? -def test_matrix_power(): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = lambda A: sla.matrix_power(A, 3) - check_results(f(A), f(Ao)) - - -@pytest.mark.skip # do we care to support...? -def test_matrixrank(): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = lambda A: sla.matrix_rank(A, 3) - check_results(f(A), f(Ao)) - - -@pytest.mark.skip # do we care to support...? -@pytest.mark.parametrize("rcond", [None, 1e-3]) -def test_pinv(rcond): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = sla.pinv - check_results(f(A), f(Ao)) - - -@pytest.mark.skip # do we care to support...? -@pytest.mark.parametrize("rcond", [None, 1e-3]) -def test_pinv(rcond): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = sla.pinv - check_results(f(A), f(Ao)) - - -@pytest.mark.skip # do we care to support...? -@pytest.mark.parametrize("shape", [(8, 8), (4, 8), (8, 4)]) -@pytest.mark.parametrize("mode", ["reduced", "complete", "r"]) -def test_qr(shape, mode): - A = jax.device_put(np.random.randn(*shape)) - Ao = MatrixOperator(A) - f = lambda A: sla.qr(A, mode) - check_results(f(A), f(Ao)) - - -@pytest.mark.skip # do we care to support...? -def test_slogdet(): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = sla.slogdet - check_results(f(A), f(Ao)) - - -@pytest.mark.skip # do we care to support...? -def test_solve(): - A = jax.device_put(np.random.randn(8, 8)) - b = jax.device_put(np.random.randn(8)) - Ao = MatrixOperator(A) - f = lambda A: sla.solve(A, b) - check_results(f(A), f(Ao)) - - -@pytest.mark.skip # do we care to support...? -def test_multi_dot(): - A = jax.device_put(np.random.randn(8, 8)) - B = jax.device_put(np.random.randn(8, 4)) - Ao = MatrixOperator(A) - Bo = MatrixOperator(B) - f = sla.multi_dot - check_results(f([A, B]), f([Ao, Bo])) - - def test_ufunc_abs(): A = snp.array([-1, 2, 5]) res = snp.array([1, 2, 5]) From 23cb53adff54ff881b9c3e6f8a7ca236d849afd5 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Wed, 20 Apr 2022 08:16:10 -0600 Subject: [PATCH 28/37] Move array, improve blockarray docs --- examples/scripts/denoise_tv_iso_pgm.py | 2 +- scico/_generic_operators.py | 3 +- scico/array.py | 176 ------------------------- scico/linop/_circconv.py | 2 +- scico/linop/_diff.py | 2 +- scico/linop/_linop.py | 3 +- scico/loss.py | 2 +- scico/numpy/__init__.py | 4 +- scico/numpy/blockarray.py | 65 +++++---- scico/numpy/util.py | 166 ++++++++++++++++++++++- scico/operator/biconvolve.py | 3 +- scico/optimize/_ladmm.py | 2 +- scico/optimize/_primaldual.py | 2 +- scico/optimize/admm.py | 2 +- scico/optimize/pgm.py | 2 +- scico/test/test_array.py | 2 +- scico/test/test_new_blockarray.py | 7 +- 17 files changed, 227 insertions(+), 218 deletions(-) delete mode 100644 scico/array.py diff --git a/examples/scripts/denoise_tv_iso_pgm.py b/examples/scripts/denoise_tv_iso_pgm.py index 9a6008490..b2fe6e483 100644 --- a/examples/scripts/denoise_tv_iso_pgm.py +++ b/examples/scripts/denoise_tv_iso_pgm.py @@ -39,8 +39,8 @@ import scico.numpy as snp import scico.random from scico import functional, linop, loss, operator, plot -from scico.array import ensure_on_device from scico.numpy import BlockArray +from scico.numpy.util import ensure_on_device from scico.optimize.pgm import AcceleratedPGM, RobustLineSearchStepSize from scico.typing import JaxArray from scico.util import device_info diff --git a/scico/_generic_operators.py b/scico/_generic_operators.py index 139b683f2..a1b36f417 100644 --- a/scico/_generic_operators.py +++ b/scico/_generic_operators.py @@ -23,7 +23,8 @@ import scico.numpy as snp from scico._autograd import linear_adjoint -from scico.numpy import BlockArray, is_complex_dtype, is_nested, shape_to_size +from scico.numpy import BlockArray +from scico.numpy.util import is_complex_dtype, is_nested, shape_to_size from scico.typing import BlockShape, DType, JaxArray, Shape diff --git a/scico/array.py b/scico/array.py deleted file mode 100644 index 6cf940b03..000000000 --- a/scico/array.py +++ /dev/null @@ -1,176 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (C) 2020-2022 by SCICO Developers -# All rights reserved. BSD 3-clause License. -# This file is part of the SCICO package. Details of the copyright and -# user license can be found in the 'LICENSE' file distributed with the -# package. - -"""Utility functions for arrays, array shapes, array indexing, etc.""" - - -from __future__ import annotations - -import warnings -from typing import List, Optional, Tuple, Union - -import numpy as np - -import jax -from jax.interpreters.pxla import ShardedDeviceArray -from jax.interpreters.xla import DeviceArray - -from scico.numpy import BlockArray -from scico.typing import ArrayIndex, Axes, AxisIndex, JaxArray, Shape - - -def ensure_on_device( - *arrays: Union[np.ndarray, JaxArray, BlockArray] -) -> Union[JaxArray, BlockArray]: - """Cast ndarrays to DeviceArrays. - - Cast ndarrays to DeviceArrays and leaves DeviceArrays, BlockArrays, - and ShardedDeviceArray as is. This is intended to be used when - initializing optimizers and functionals so that all arrays are either - DeviceArrays, BlockArrays, or ShardedDeviceArray. - - Args: - *arrays: One or more input arrays (ndarray, DeviceArray, - BlockArray, or ShardedDeviceArray). - - Returns: - Modified array or arrays. Modified are only those that were - necessary. - - Raises: - TypeError: If the arrays contain something that is neither - ndarray, DeviceArray, BlockArray, nor ShardedDeviceArray. - """ - arrays = list(arrays) - - for i, array in enumerate(arrays): - - if isinstance(array, np.ndarray): - warnings.warn( - f"Argument {i+1} of {len(arrays)} is an np.ndarray. " - f"Will cast it to DeviceArray. " - f"To suppress this warning cast all np.ndarrays to DeviceArray first.", - stacklevel=2, - ) - - elif not isinstance( - array, - (DeviceArray, BlockArray, ShardedDeviceArray), - ): - raise TypeError( - "Each item of `arrays` must be ndarray, DeviceArray, BlockArray, or " - f"ShardedDeviceArray; Argument {i+1} of {len(arrays)} is {type(arrays[i])}." - ) - - arrays[i] = jax.device_put(arrays[i]) - - if len(arrays) == 1: - return arrays[0] - return arrays - - -def parse_axes( - axes: Axes, shape: Optional[Shape] = None, default: Optional[List[int]] = None -) -> List[int]: - """Normalize `axes` to a list and optionally ensure correctness. - - Normalize `axes` to a list and (optionally) ensure that entries refer - to axes that exist in `shape`. - - Args: - axes: User specification of one or more axes: int, list, tuple, - or ``None``. - shape: The shape of the array of which axes are being specified. - If not ``None``, `axes` is checked to make sure its entries - refer to axes that exist in `shape`. - default: Default value to return if `axes` is ``None``. By - default, `list(range(len(shape)))`. - - Returns: - List of axes (never an int, never ``None``). - """ - - if axes is None: - if default is None: - if shape is None: - raise ValueError("`axes` cannot be `None` without a default or shape specified.") - axes = list(range(len(shape))) - else: - axes = default - elif isinstance(axes, (list, tuple)): - axes = axes - elif isinstance(axes, int): - axes = (axes,) - else: - raise ValueError(f"Could not understand axes {axes} as a list of axes") - if shape is not None and max(axes) >= len(shape): - raise ValueError( - f"Invalid axes {axes} specified; each axis must be less than `len(shape)`={len(shape)}." - ) - if len(set(axes)) != len(axes): - raise ValueError(f"Duplicate value in axes {axes}; each axis must be unique.") - return axes - - -def slice_length(length: int, idx: AxisIndex) -> Optional[int]: - """Determine the length of an array axis after indexing. - - Determine the length of an array axis after slicing. An exception is - raised if the indexing expression is an integer that is out of bounds - for the specified axis length. A value of ``None`` is returned for - valid integer indexing expressions as an indication that the - corresponding axis shape is an empty tuple; this value should be - converted to a unit integer if the axis size is required. - - Args: - length: Length of axis being sliced. - idx: Indexing/slice to be applied to axis. - - Returns: - Length of indexed/sliced axis. - - Raises: - ValueError: If `idx` is an integer index that is out bounds for - the axis length. - """ - if idx is Ellipsis: - return length - if isinstance(idx, int): - if idx < -length or idx > length - 1: - raise ValueError(f"Index {idx} out of bounds for axis of length {length}.") - return None - start, stop, stride = idx.indices(length) - if start > stop: - start = stop - return (stop - start + stride - 1) // stride - - -def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]: - """Determine the shape of an array after indexing/slicing. - - Args: - shape: Shape of array. - idx: Indexing expression. - - Returns: - Shape of indexed/sliced array. - - Raises: - ValueError: If `idx` is longer than `shape`. - """ - if not isinstance(idx, tuple): - idx = (idx,) - if len(idx) > len(shape): - raise ValueError(f"Slice {idx} has more dimensions than shape {shape}.") - idx_shape = list(shape) - offset = 0 - for axis, ax_idx in enumerate(idx): - if ax_idx is Ellipsis: - offset = len(shape) - len(idx) - continue - idx_shape[axis + offset] = slice_length(shape[axis + offset], ax_idx) - return tuple(filter(lambda x: x is not None, idx_shape)) diff --git a/scico/linop/_circconv.py b/scico/linop/_circconv.py index f21219dec..05bb024aa 100644 --- a/scico/linop/_circconv.py +++ b/scico/linop/_circconv.py @@ -18,7 +18,7 @@ import scico.numpy as snp from scico._generic_operators import Operator -from scico.numpy import is_nested +from scico.numpy.util import is_nested from scico.typing import DType, JaxArray, Shape from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar diff --git a/scico/linop/_diff.py b/scico/linop/_diff.py index b82013e2d..9f0e77d52 100644 --- a/scico/linop/_diff.py +++ b/scico/linop/_diff.py @@ -17,7 +17,7 @@ import numpy as np import scico.numpy as snp -from scico.array import parse_axes +from scico.numpy.util import parse_axes from scico.typing import Axes, DType, JaxArray, Shape from ._linop import LinearOperator diff --git a/scico/linop/_linop.py b/scico/linop/_linop.py index 6a7f3c5ac..86d7b818c 100644 --- a/scico/linop/_linop.py +++ b/scico/linop/_linop.py @@ -18,7 +18,8 @@ import scico.numpy as snp from scico import array from scico._generic_operators import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar -from scico.numpy import BlockArray, is_nested +from scico.numpy import BlockArray +from scico.numpy.util import is_nested from scico.random import randn from scico.typing import ArrayIndex, BlockShape, DType, JaxArray, PRNGKey, Shape diff --git a/scico/loss.py b/scico/loss.py index 59faed517..288bede32 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -17,8 +17,8 @@ import scico.numpy as snp from scico import functional, linop, operator -from scico.array import ensure_on_device from scico.numpy import BlockArray, no_nan_divide +from scico.numpy.util import ensure_on_device from scico.scipy.special import gammaln from scico.solver import cg from scico.typing import JaxArray diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index a0cb17330..957c1ecd3 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -23,7 +23,9 @@ from . import _wrappers from ._wrapped_function_lists import * from .blockarray import BlockArray -from .util import * + +# allow snp.blockarray(...) to create BlockArrays +blockarray = BlockArray.blockarray # copy most of jnp without wrapping _wrappers.add_attributes( diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py index 7ceca2a8a..2daef89ce 100644 --- a/scico/numpy/blockarray.py +++ b/scico/numpy/blockarray.py @@ -229,20 +229,26 @@ def __init__(self, inputs): @property def dtype(self): - """Allow snp.zeros(x.shape, x.dtype) to work.""" + """Return the dtype of the blocks, which must currently be homogeneous. + + This allows snp.zeros(x.shape, x.dtype) to work without a mechanism + to handle to lists of dtypes. + """ return self[0].dtype def __getitem__(self, key): - """Make, e.g., x[:2] return a BlockArray, not a list.""" + """Indexing method equivalent to x[key]. + + This is overridden to make, e.g., x[:2] return a BlockArray + rather than a list. + """ result = super().__getitem__(key) if not isinstance(result, jnp.ndarray): - return BlockArray(result) - return result - - """ backwards compatibility methods, could be removed """ + return BlockArray(result) # x[k:k+1] returns a BlockArray + return result # x[k] returns a DeviceArray @staticmethod - def array(iterable): + def blockarray(iterable): """Construct a :class:`.BlockArray` from a list or tuple of existing array-like.""" return BlockArray(iterable) @@ -256,10 +262,8 @@ def array(iterable): ) -""" Wrap unary ops like -x. """ - - -def _unary_op_wrapper(op): +# Wrap unary ops like -x. +def _unary_op_wrapper(op_name): op = getattr(DeviceArray, op_name) @wraps(op) @@ -272,21 +276,24 @@ def op_ba(self): for op_name in unary_ops: setattr(BlockArray, op_name, _unary_op_wrapper(op_name)) -""" Wrap binary ops like x+y. """ - +# Wrap binary ops like x + y. """ def _binary_op_wrapper(op_name): op = getattr(DeviceArray, op_name) @wraps(op) def op_ba(self, other): + # If other is a BA, we can assume the operation is implemented + # (because BAs must contain DeviceArrays) if isinstance(other, BlockArray): return BlockArray(op(x, y) for x, y in zip(self, other)) - result = BlockArray(op(x, other) for x in self) + # If not, need to handle possible NotImplemented + # without this, ba + 'hi' -> [NotImplemented, NotImplemented, ...] + result = list(op(x, other) for x in self) if NotImplemented in result: return NotImplemented - return result + return BlockArray(result) return op_ba @@ -295,9 +302,7 @@ def op_ba(self, other): setattr(BlockArray, op_name, _binary_op_wrapper(op_name)) -""" Wrap DeviceArray properties. """ - - +# Wrap DeviceArray properties. def _da_prop_wrapper(prop_name): prop = getattr(DeviceArray, prop_name) @@ -305,8 +310,12 @@ def _da_prop_wrapper(prop_name): @wraps(prop) def prop_ba(self): result = tuple(getattr(x, prop_name) for x in self) + + # if da.prop is a DA, return BA if isinstance(result[0], jnp.ndarray): return BlockArray(result) + + # otherwise, return tuple return result return prop_ba @@ -315,22 +324,26 @@ def prop_ba(self): skip_props = ("at",) da_props = [ k - for k, v in dict(inspect.getmembers(DeviceArray)).items() + for k, v in dict(inspect.getmembers(DeviceArray)).items() # (name, method) pairs if isinstance(v, property) and k[0] != "_" and k not in dir(BlockArray) and k not in skip_props ] for prop_name in da_props: setattr(BlockArray, prop_name, _da_prop_wrapper(prop_name)) -""" Wrap DeviceArray methods. """ - +# Wrap DeviceArray methods. +def _da_method_wrapper(method_name): + method = getattr(DeviceArray, method_name) -def _da_method_wrapper(method): @wraps(method) def method_ba(self, *args, **kwargs): - result = tuple(getattr(x, method)(*args, **kwargs) for x in self) + result = tuple(getattr(x, method_name)(*args, **kwargs) for x in self) + + # if da.method(...) is a DA, return a BA if isinstance(result[0], jnp.ndarray): return BlockArray(result) + + # otherwise return a tuple return result return method_ba @@ -339,12 +352,12 @@ def method_ba(self, *args, **kwargs): skip_methods = () da_methods = [ k - for k, v in dict(inspect.getmembers(DeviceArray)).items() + for k, v in dict(inspect.getmembers(DeviceArray)).items() # (name, method) pairs if isinstance(v, Callable) and k[0] != "_" and k not in dir(BlockArray) and k not in skip_methods ] -for method in da_methods: - setattr(BlockArray, method, _da_method_wrapper(method)) +for method_name in da_methods: + setattr(BlockArray, method_name, _da_method_wrapper(method_name)) diff --git a/scico/numpy/util.py b/scico/numpy/util.py index cd71408dd..2784856c5 100644 --- a/scico/numpy/util.py +++ b/scico/numpy/util.py @@ -1,14 +1,176 @@ """ Utility functions for working with BlockArrays and DeviceArrays. """ +from __future__ import annotations + +import warnings from math import prod -from typing import Any, Union +from typing import Any, List, Optional, Tuple, Union + +import numpy as np + +import jax +from jax.interpreters.pxla import ShardedDeviceArray +from jax.interpreters.xla import DeviceArray import scico.numpy as snp -from scico.typing import Axes, BlockShape, DType, JaxArray, Shape +from scico.typing import ArrayIndex, Axes, AxisIndex, BlockShape, DType, JaxArray, Shape from .blockarray import BlockArray +def ensure_on_device( + *arrays: Union[np.ndarray, JaxArray, BlockArray] +) -> Union[JaxArray, BlockArray]: + """Cast ndarrays to DeviceArrays. + + Cast ndarrays to DeviceArrays and leaves DeviceArrays, BlockArrays, + and ShardedDeviceArray as is. This is intended to be used when + initializing optimizers and functionals so that all arrays are either + DeviceArrays, BlockArrays, or ShardedDeviceArray. + + Args: + *arrays: One or more input arrays (ndarray, DeviceArray, + BlockArray, or ShardedDeviceArray). + + Returns: + Modified array or arrays. Modified are only those that were + necessary. + + Raises: + TypeError: If the arrays contain something that is neither + ndarray, DeviceArray, BlockArray, nor ShardedDeviceArray. + """ + arrays = list(arrays) + + for i, array in enumerate(arrays): + + if isinstance(array, np.ndarray): + warnings.warn( + f"Argument {i+1} of {len(arrays)} is an np.ndarray. " + f"Will cast it to DeviceArray. " + f"To suppress this warning cast all np.ndarrays to DeviceArray first.", + stacklevel=2, + ) + + elif not isinstance( + array, + (DeviceArray, BlockArray, ShardedDeviceArray), + ): + raise TypeError( + "Each item of `arrays` must be ndarray, DeviceArray, BlockArray, or " + f"ShardedDeviceArray; Argument {i+1} of {len(arrays)} is {type(arrays[i])}." + ) + + arrays[i] = jax.device_put(arrays[i]) + + if len(arrays) == 1: + return arrays[0] + return arrays + + +def parse_axes( + axes: Axes, shape: Optional[Shape] = None, default: Optional[List[int]] = None +) -> List[int]: + """Normalize `axes` to a list and optionally ensure correctness. + + Normalize `axes` to a list and (optionally) ensure that entries refer + to axes that exist in `shape`. + + Args: + axes: User specification of one or more axes: int, list, tuple, + or ``None``. + shape: The shape of the array of which axes are being specified. + If not ``None``, `axes` is checked to make sure its entries + refer to axes that exist in `shape`. + default: Default value to return if `axes` is ``None``. By + default, `list(range(len(shape)))`. + + Returns: + List of axes (never an int, never ``None``). + """ + + if axes is None: + if default is None: + if shape is None: + raise ValueError("`axes` cannot be `None` without a default or shape specified.") + axes = list(range(len(shape))) + else: + axes = default + elif isinstance(axes, (list, tuple)): + axes = axes + elif isinstance(axes, int): + axes = (axes,) + else: + raise ValueError(f"Could not understand axes {axes} as a list of axes") + if shape is not None and max(axes) >= len(shape): + raise ValueError( + f"Invalid axes {axes} specified; each axis must be less than `len(shape)`={len(shape)}." + ) + if len(set(axes)) != len(axes): + raise ValueError(f"Duplicate value in axes {axes}; each axis must be unique.") + return axes + + +def slice_length(length: int, idx: AxisIndex) -> Optional[int]: + """Determine the length of an array axis after indexing. + + Determine the length of an array axis after slicing. An exception is + raised if the indexing expression is an integer that is out of bounds + for the specified axis length. A value of ``None`` is returned for + valid integer indexing expressions as an indication that the + corresponding axis shape is an empty tuple; this value should be + converted to a unit integer if the axis size is required. + + Args: + length: Length of axis being sliced. + idx: Indexing/slice to be applied to axis. + + Returns: + Length of indexed/sliced axis. + + Raises: + ValueError: If `idx` is an integer index that is out bounds for + the axis length. + """ + if idx is Ellipsis: + return length + if isinstance(idx, int): + if idx < -length or idx > length - 1: + raise ValueError(f"Index {idx} out of bounds for axis of length {length}.") + return None + start, stop, stride = idx.indices(length) + if start > stop: + start = stop + return (stop - start + stride - 1) // stride + + +def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]: + """Determine the shape of an array after indexing/slicing. + + Args: + shape: Shape of array. + idx: Indexing expression. + + Returns: + Shape of indexed/sliced array. + + Raises: + ValueError: If `idx` is longer than `shape`. + """ + if not isinstance(idx, tuple): + idx = (idx,) + if len(idx) > len(shape): + raise ValueError(f"Slice {idx} has more dimensions than shape {shape}.") + idx_shape = list(shape) + offset = 0 + for axis, ax_idx in enumerate(idx): + if ax_idx is Ellipsis: + offset = len(shape) - len(idx) + continue + idx_shape[axis + offset] = slice_length(shape[axis + offset], ax_idx) + return tuple(filter(lambda x: x is not None, idx_shape)) + + def no_nan_divide( x: Union[BlockArray, JaxArray], y: Union[BlockArray, JaxArray] ) -> Union[BlockArray, JaxArray]: diff --git a/scico/operator/biconvolve.py b/scico/operator/biconvolve.py index bd964a216..601d1bcd0 100644 --- a/scico/operator/biconvolve.py +++ b/scico/operator/biconvolve.py @@ -14,7 +14,8 @@ from scico._generic_operators import LinearOperator, Operator from scico.linop import Convolve, ConvolveByX -from scico.numpy import BlockArray, is_nested +from scico.numpy import BlockArray +from scico.numpy.util import is_nested from scico.typing import BlockShape, DType, JaxArray diff --git a/scico/optimize/_ladmm.py b/scico/optimize/_ladmm.py index e0fbd7741..2be8fae21 100644 --- a/scico/optimize/_ladmm.py +++ b/scico/optimize/_ladmm.py @@ -14,12 +14,12 @@ from typing import Callable, List, Optional, Union import scico.numpy as snp -from scico.array import ensure_on_device from scico.diagnostics import IterationStats from scico.functional import Functional from scico.linop import LinearOperator from scico.numpy import BlockArray from scico.numpy.linalg import norm +from scico.numpy.util import ensure_on_device from scico.typing import JaxArray from scico.util import Timer diff --git a/scico/optimize/_primaldual.py b/scico/optimize/_primaldual.py index 7b006e2c4..eb0e27299 100644 --- a/scico/optimize/_primaldual.py +++ b/scico/optimize/_primaldual.py @@ -14,12 +14,12 @@ from typing import Callable, Optional, Union import scico.numpy as snp -from scico.array import ensure_on_device from scico.diagnostics import IterationStats from scico.functional import Functional from scico.linop import LinearOperator from scico.numpy import BlockArray from scico.numpy.linalg import norm +from scico.numpy.util import ensure_on_device from scico.typing import JaxArray from scico.util import Timer diff --git a/scico/optimize/admm.py b/scico/optimize/admm.py index 06cb77e2c..8d1a0100a 100644 --- a/scico/optimize/admm.py +++ b/scico/optimize/admm.py @@ -18,13 +18,13 @@ from jax.scipy.sparse.linalg import cg as jax_cg import scico.numpy as snp -from scico.array import ensure_on_device from scico.diagnostics import IterationStats from scico.functional import Functional from scico.linop import CircularConvolve, Identity, LinearOperator from scico.loss import SquaredL2Loss from scico.numpy import BlockArray, is_real_dtype from scico.numpy.linalg import norm +from scico.numpy.util import ensure_on_device from scico.solver import cg as scico_cg from scico.solver import minimize from scico.typing import JaxArray diff --git a/scico/optimize/pgm.py b/scico/optimize/pgm.py index 3656e969b..318dce827 100644 --- a/scico/optimize/pgm.py +++ b/scico/optimize/pgm.py @@ -16,11 +16,11 @@ import jax import scico.numpy as snp -from scico.array import ensure_on_device from scico.diagnostics import IterationStats from scico.functional import Functional from scico.loss import Loss from scico.numpy import BlockArray +from scico.numpy.util import ensure_on_device from scico.typing import JaxArray from scico.util import Timer diff --git a/scico/test/test_array.py b/scico/test/test_array.py index c05de3fd0..6b1ef3e95 100644 --- a/scico/test/test_array.py +++ b/scico/test/test_array.py @@ -7,7 +7,6 @@ import pytest import scico.numpy as snp -from scico.array import ensure_on_device, indexed_shape, parse_axes, slice_length from scico.numpy import ( BlockArray, complex_dtype, @@ -17,6 +16,7 @@ no_nan_divide, real_dtype, ) +from scico.numpy.util import ensure_on_device, indexed_shape, parse_axes, slice_length from scico.random import randn diff --git a/scico/test/test_new_blockarray.py b/scico/test/test_new_blockarray.py index 6f9971643..9e2f87547 100644 --- a/scico/test/test_new_blockarray.py +++ b/scico/test/test_new_blockarray.py @@ -53,9 +53,14 @@ def test_elementwise_binary(op, x, y): assert actual.dtype == expected.dtype +def test_not_implemented_binary(x): + with pytest.raises(TypeError, match=r"unsupported operand type\(s\)"): + y = x + "a string" + + def test_matmul(x): # x is ((2, 3), (1,)) - # y will be ((3, 1), (1, 2)) + # y is ((3, 1), (1, 2)) y = BlockArray([[[1.0], [2.0], [3.0]], [[0.0, 1.0]]]) actual = x @ y expected = BlockArray([[[14.0], [0.0]], [0.0, 42.0]]) From 83266303a02a4967bbd9b5847aff492b784b69e0 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Wed, 20 Apr 2022 08:28:45 -0600 Subject: [PATCH 29/37] Combine BlockArray tests --- scico/test/test_blockarray.py | 155 +++++++++++++++--------------- scico/test/test_new_blockarray.py | 83 ---------------- 2 files changed, 77 insertions(+), 161 deletions(-) delete mode 100644 scico/test/test_new_blockarray.py diff --git a/scico/test/test_blockarray.py b/scico/test/test_blockarray.py index 50e186f91..4b85476a4 100644 --- a/scico/test/test_blockarray.py +++ b/scico/test/test_blockarray.py @@ -9,6 +9,7 @@ import scico.numpy as snp from scico.numpy import BlockArray +from scico.numpy.testing import assert_array_equal from scico.random import randn math_ops = [op.add, op.sub, op.mul, op.truediv, op.pow] # op.floordiv doesn't work on complex @@ -160,14 +161,11 @@ def test_split(test_operator_obj): np.testing.assert_allclose(a[1], test_operator_obj.a1) -@pytest.mark.skip() -# currently creation is exactly like a tuple, -# so BlockArray(np.jnp.zeros((3,6))) makes a block array -# with 3 length-6 blocks -# TODO replace with test of new behavior def test_blockarray_from_one_array(): - with pytest.raises(TypeError): - BlockArray(np.random.randn(32, 32)) + # BlockArray(np.jnp.zeros((3,6))) makes a block array + # with 3 length-6 blocks + x = BlockArray(np.random.randn(3, 6)) + assert len(x) == 3 @pytest.mark.parametrize("axis", [None, 1]) @@ -317,75 +315,76 @@ def test_full_nodtype(self): assert snp.all(x == fill_value) -@pytest.mark.skip -# indexing now works just like a list of DeviceArrays: -# x[1] = x[1].at[:].set(0) -# TODO: some of these are new syntax? -class TestBlockArrayIndex: - def setup_method(self): - key = None +# tests added for the BlockArray refactor +@pytest.fixture +def x(): + # any BlockArray, arbitrary shape, content, type + return BlockArray([[[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], [42.0]]) + + +@pytest.fixture +def y(): + # another BlockArray, content, type, shape matches x + return BlockArray([[[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]], [-2.0]]) + + +@pytest.mark.parametrize("op", [op.neg, op.pos, op.abs]) +def test_unary(op, x): + actual = op(x) + expected = BlockArray(op(x_i) for x_i in x) + assert_array_equal(actual, expected) + assert actual.dtype == expected.dtype + + +@pytest.mark.parametrize( + "op", + [ + op.mul, + op.mod, + op.lt, + op.le, + op.gt, + op.ge, + op.floordiv, + op.eq, + op.add, + op.truediv, + op.sub, + op.ne, + ], +) +def test_elementwise_binary(op, x, y): + actual = op(x, y) + expected = BlockArray(op(x_i, y_i) for x_i, y_i in zip(x, y)) + assert_array_equal(actual, expected) + assert actual.dtype == expected.dtype + + +def test_not_implemented_binary(x): + with pytest.raises(TypeError, match=r"unsupported operand type\(s\)"): + y = x + "a string" + + +def test_matmul(x): + # x is ((2, 3), (1,)) + # y is ((3, 1), (1, 2)) + y = BlockArray([[[1.0], [2.0], [3.0]], [[0.0, 1.0]]]) + actual = x @ y + expected = BlockArray([[[14.0], [0.0]], [0.0, 42.0]]) + assert_array_equal(actual, expected) + assert actual.dtype == expected.dtype + + +def test_property(): + x = BlockArray(([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [0.0])) + actual = x.shape + expected = ((2, 3), (1,)) + assert actual == expected + - self.A, key = randn(shape=((4, 4), (3,)), key=key) - self.B, key = randn(shape=((3, 3), (4, 2, 3)), key=key) - self.C, key = randn(shape=((3, 3), (4, 2), (4, 4)), key=key) - - def test_set_block(self): - # Test assignment of an entire block - A2 = self.A[0].at[:].set(1) - np.testing.assert_allclose(A2[0], snp.ones_like(A2[0]), rtol=5e-5) - np.testing.assert_allclose(A2[1], A2[1], rtol=5e-5) - - def test_set(self): - # Test assignment using (bkidx, idx) format - A2 = self.A[0].at[2:, :-2].set(1.45) - tmp = A2[2:, :-2] - np.testing.assert_allclose(A2[0][2:, :-2], 1.45 * snp.ones_like(tmp), rtol=5e-5) - np.testing.assert_allclose(A2[1].full_ravel(), A2[1], rtol=5e-5) - - def test_add(self): - A2 = self.A.at[0, 2:, :-2].add(1.45) - tmp = np.array(self.A[0]) - tmp[2:, :-2] += 1.45 - y = BlockArray([tmp, self.A[1]]) - np.testing.assert_allclose(A2.full_ravel(), y.full_ravel(), rtol=5e-5) - - D2 = self.D.at[1].add(1.45) - y = BlockArray([self.D[0], self.D[1] + 1.45]) - np.testing.assert_allclose(D2.full_ravel(), y.full_ravel(), rtol=5e-5) - - def test_multiply(self): - A2 = self.A.at[0, 2:, :-2].multiply(1.45) - tmp = np.array(self.A[0]) - tmp[2:, :-2] *= 1.45 - y = BlockArray([tmp, self.A[1]]) - np.testing.assert_allclose(A2.full_ravel(), y.full_ravel(), rtol=5e-5) - - D2 = self.D.at[1].multiply(1.45) - y = BlockArray([self.D[0], self.D[1] * 1.45]) - np.testing.assert_allclose(D2.full_ravel(), y.full_ravel(), rtol=5e-5) - - def test_divide(self): - A2 = self.A.at[0, 2:, :-2].divide(1.45) - tmp = np.array(self.A[0]) - tmp[2:, :-2] /= 1.45 - y = BlockArray([tmp, self.A[1]]) - np.testing.assert_allclose(A2.full_ravel(), y.full_ravel(), rtol=5e-5) - - D2 = self.D.at[1].divide(1.45) - y = BlockArray([self.D[0], self.D[1] / 1.45]) - np.testing.assert_allclose(D2.full_ravel(), y.full_ravel(), rtol=5e-5) - - def test_power(self): - A2 = self.A.at[0, 2:, :-2].power(2) - tmp = np.array(self.A[0]) - tmp[2:, :-2] **= 2 - y = BlockArray([tmp, self.A[1]]) - np.testing.assert_allclose(A2.full_ravel(), y.full_ravel(), rtol=5e-5) - - D2 = self.D.at[1].power(1.45) - y = BlockArray([self.D[0], self.D[1] ** 1.45]) - np.testing.assert_allclose(D2.full_ravel(), y.full_ravel(), rtol=5e-5) - - def test_set_slice(self): - with pytest.raises(TypeError): - C2 = self.C.at[::2, 0].set(0) +def test_method(): + x = BlockArray(([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], [42.0])) + actual = x.max() + expected = BlockArray([[3.0], [42.0]]) + assert_array_equal(actual, expected) + assert actual.dtype == expected.dtype diff --git a/scico/test/test_new_blockarray.py b/scico/test/test_new_blockarray.py deleted file mode 100644 index 9e2f87547..000000000 --- a/scico/test/test_new_blockarray.py +++ /dev/null @@ -1,83 +0,0 @@ -import operator as op - -import pytest - -from scico.numpy import BlockArray -from scico.numpy.testing import assert_array_equal - -for a in dir(op): - help(getattr(op, a)) - - -@pytest.fixture -def x(): - # any BlockArray, arbitrary shape, content, type - return BlockArray([[[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], [42.0]]) - - -@pytest.fixture -def y(): - # another BlockArray, content, type, matching shape - return BlockArray([[[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]], [-2.0]]) - - -@pytest.mark.parametrize("op", [op.neg, op.pos, op.abs]) -def test_unary(op, x): - actual = op(x) - expected = BlockArray(op(x_i) for x_i in x) - assert_array_equal(actual, expected) - assert actual.dtype == expected.dtype - - -@pytest.mark.parametrize( - "op", - [ - op.mul, - op.mod, - op.lt, - op.le, - op.gt, - op.ge, - op.floordiv, - op.eq, - op.add, - op.truediv, - op.sub, - op.ne, - ], -) -def test_elementwise_binary(op, x, y): - actual = op(x, y) - expected = BlockArray(op(x_i, y_i) for x_i, y_i in zip(x, y)) - assert_array_equal(actual, expected) - assert actual.dtype == expected.dtype - - -def test_not_implemented_binary(x): - with pytest.raises(TypeError, match=r"unsupported operand type\(s\)"): - y = x + "a string" - - -def test_matmul(x): - # x is ((2, 3), (1,)) - # y is ((3, 1), (1, 2)) - y = BlockArray([[[1.0], [2.0], [3.0]], [[0.0, 1.0]]]) - actual = x @ y - expected = BlockArray([[[14.0], [0.0]], [0.0, 42.0]]) - assert_array_equal(actual, expected) - assert actual.dtype == expected.dtype - - -def test_property(): - x = BlockArray(([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [0.0])) - actual = x.shape - expected = ((2, 3), (1,)) - assert actual == expected - - -def test_method(): - x = BlockArray(([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], [42.0])) - actual = x.max() - expected = BlockArray([[3.0], [42.0]]) - assert_array_equal(actual, expected) - assert actual.dtype == expected.dtype From c5e595c33412b5a0760c545ce48a665b65b5d038 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Wed, 20 Apr 2022 12:35:55 -0600 Subject: [PATCH 30/37] Continue removing scico.array, allow snp.blockarray syntax --- examples/scripts/denoise_tv_iso_pgm.py | 2 +- scico/_generic_operators.py | 6 +- scico/functional/_functional.py | 2 +- scico/functional/_norm.py | 3 +- scico/linop/_convolve.py | 4 +- scico/linop/_linop.py | 7 +- scico/linop/optics.py | 2 +- scico/loss.py | 4 +- scico/numpy/blockarray.py | 8 +- scico/operator/biconvolve.py | 2 +- scico/optimize/admm.py | 4 +- scico/solver.py | 4 +- scico/test/functional/test_core.py | 7 +- scico/test/functional/test_loss.py | 2 +- scico/test/functional/test_separable.py | 8 +- scico/test/optimize/test_ladmm.py | 2 +- scico/test/optimize/test_pdhg.py | 2 +- scico/test/test_array.py | 11 ++- scico/test/test_biconvolve.py | 4 +- scico/test/test_numpy.py | 103 ++++++++++++------------ scico/test/test_operator.py | 11 ++- scico/test/test_solver.py | 5 +- 22 files changed, 101 insertions(+), 102 deletions(-) diff --git a/examples/scripts/denoise_tv_iso_pgm.py b/examples/scripts/denoise_tv_iso_pgm.py index b2fe6e483..cb8ab58ec 100644 --- a/examples/scripts/denoise_tv_iso_pgm.py +++ b/examples/scripts/denoise_tv_iso_pgm.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/Usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed diff --git a/scico/_generic_operators.py b/scico/_generic_operators.py index a1b36f417..0657ea108 100644 --- a/scico/_generic_operators.py +++ b/scico/_generic_operators.py @@ -320,8 +320,8 @@ def freeze(self, argnum: int, val: Union[JaxArray, BlockArray]) -> Operator: def concat_args(args): # Creates a blockarray with args and the frozen value in the correct place # Eg if this operator takes a blockarray with two blocks, then - # concat_args(args) = BlockArray.array([val, args]) if argnum = 0 - # concat_args(args) = BlockArray.array([args, val]) if argnum = 1 + # concat_args(args) = snp.blockarray([val, args]) if argnum = 0 + # concat_args(args) = snp.blockarray([args, val]) if argnum = 1 if isinstance(args, (DeviceArray, np.ndarray)): # In the case that the original operator takes a blcokarray with two @@ -336,7 +336,7 @@ def concat_args(args): arg_list.append(args[i - 1]) else: arg_list.append(val) - return BlockArray.array(arg_list) + return snp.blockarray(arg_list) return Operator( input_shape=input_shape, diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py index ab94244a5..df8c10046 100644 --- a/scico/functional/_functional.py +++ b/scico/functional/_functional.py @@ -252,7 +252,7 @@ def prox(self, v: BlockArray, lam: float = 1.0, **kwargs) -> BlockArray: """ if len(v.shape) == len(self.functional_list): - return BlockArray.array([fi.prox(vi, lam) for fi, vi in zip(self.functional_list, v)]) + return snp.blockarray([fi.prox(vi, lam) for fi, vi in zip(self.functional_list, v)]) raise ValueError( f"Number of blocks in v, {len(v.shape)}, and length of functional_list, " f"{len(self.functional_list)}, do not match" diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index ba77a4d3e..70d094127 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -12,8 +12,9 @@ from jax import jit, lax from scico import numpy as snp -from scico.numpy import BlockArray, count_nonzero, no_nan_divide +from scico.numpy import BlockArray, count_nonzero from scico.numpy.linalg import norm +from scico.numpy.util import no_nan_divide from scico.typing import JaxArray from ._functional import Functional diff --git a/scico/linop/_convolve.py b/scico/linop/_convolve.py index 1e675c094..2164b892a 100644 --- a/scico/linop/_convolve.py +++ b/scico/linop/_convolve.py @@ -23,8 +23,8 @@ from jax.scipy.signal import convolve import scico.numpy as snp -from scico import array from scico._generic_operators import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar +from scico.numpy.util import ensure_on_device from scico.typing import DType, JaxArray, Shape @@ -66,7 +66,7 @@ def __init__( if h.ndim != len(input_shape): raise ValueError(f"h.ndim = {h.ndim} must equal len(input_shape) = {len(input_shape)}") - self.h = array.ensure_on_device(h) + self.h = ensure_on_device(h) if mode not in ["full", "valid", "same"]: raise ValueError(f"Invalid mode={mode}; must be one of 'full', 'valid', 'same'") diff --git a/scico/linop/_linop.py b/scico/linop/_linop.py index 86d7b818c..3866b55bf 100644 --- a/scico/linop/_linop.py +++ b/scico/linop/_linop.py @@ -16,10 +16,9 @@ from typing import Any, Callable, Optional, Union import scico.numpy as snp -from scico import array from scico._generic_operators import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar from scico.numpy import BlockArray -from scico.numpy.util import is_nested +from scico.numpy.util import ensure_on_device, indexed_shape, is_nested from scico.random import randn from scico.typing import ArrayIndex, BlockShape, DType, JaxArray, PRNGKey, Shape @@ -182,7 +181,7 @@ def __init__( """ - self.diagonal = array.ensure_on_device(diagonal) + self.diagonal = ensure_on_device(diagonal) if input_shape is None: input_shape = self.diagonal.shape @@ -286,7 +285,7 @@ def __init__( if is_nested(input_shape): output_shape = input_shape[idx] else: - output_shape = array.indexed_shape(input_shape, idx) + output_shape = indexed_shape(input_shape, idx) self.idx: ArrayIndex = idx super().__init__( diff --git a/scico/linop/optics.py b/scico/linop/optics.py index 67cf100fd..4a785d4f5 100644 --- a/scico/linop/optics.py +++ b/scico/linop/optics.py @@ -61,7 +61,7 @@ import scico.numpy as snp from scico.linop import Diagonal, Identity, LinearOperator -from scico.numpy import no_nan_divide +from scico.numpy.util import no_nan_divide from scico.typing import Shape from ._dft import DFT diff --git a/scico/loss.py b/scico/loss.py index 288bede32..2e76d9566 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -17,8 +17,8 @@ import scico.numpy as snp from scico import functional, linop, operator -from scico.numpy import BlockArray, no_nan_divide -from scico.numpy.util import ensure_on_device +from scico.numpy import BlockArray +from scico.numpy.util import ensure_on_device, no_nan_divide from scico.scipy.special import gammaln from scico.solver import cg from scico.typing import JaxArray diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py index 2daef89ce..9b4865076 100644 --- a/scico/numpy/blockarray.py +++ b/scico/numpy/blockarray.py @@ -124,7 +124,7 @@ >>> x_v, _ = scico.random.randn((n-1, m), key=key) # Form the blockarray - >>> x_B = BlockArray.array([x_h, x_v]) + >>> x_B = snp.blockarray([x_h, x_v]) # The blockarray shape is a tuple of tuples >>> x_B.shape @@ -146,7 +146,7 @@ >>> import numpy as np >>> x0, key = scico.random.randn((32, 32)) >>> x1, _ = scico.random.randn((16,), key=key) - >>> X = BlockArray.array((x0, x1)) + >>> X = snp.blockarray((x0, x1)) >>> X.shape ((32, 32), (16,)) >>> X.size @@ -154,7 +154,7 @@ >>> len(X) 2 -While :func:`.BlockArray.array` will accept either `ndarray` or +While :func:`.snp.blockarray` will accept either `ndarray` or `DeviceArray` as input, the resulting :class:`.BlockArray` will be backed by a `DeviceArray` memory buffer. @@ -190,7 +190,7 @@ >>> A_2.shape # array -> BlockArray (((2, 4), (3, 3)), (3, 4)) - >>> diag = BlockArray.array([np.array(1.0), np.array(2.0)]) + >>> diag = snp.blockarray([np.array(1.0), np.array(2.0)]) >>> A_3 = scico.linop.Diagonal(diag, input_shape=(A_2.output_shape)) >>> A_3.shape # BlockArray -> BlockArray (((2, 4), (3, 3)), ((2, 4), (3, 3))) diff --git a/scico/operator/biconvolve.py b/scico/operator/biconvolve.py index 601d1bcd0..77a0be36c 100644 --- a/scico/operator/biconvolve.py +++ b/scico/operator/biconvolve.py @@ -26,7 +26,7 @@ class BiConvolve(Operator): blocks of equal ndims, and convolves the first block with the second. If `A` is a BiConvolve operator, then - `A(BlockArray.array([x, h]))` equals `jax.scipy.signal.convolve(x, h)`. + `A(snp.blockarray([x, h]))` equals `jax.scipy.signal.convolve(x, h)`. """ diff --git a/scico/optimize/admm.py b/scico/optimize/admm.py index 8d1a0100a..369450d49 100644 --- a/scico/optimize/admm.py +++ b/scico/optimize/admm.py @@ -22,9 +22,9 @@ from scico.functional import Functional from scico.linop import CircularConvolve, Identity, LinearOperator from scico.loss import SquaredL2Loss -from scico.numpy import BlockArray, is_real_dtype +from scico.numpy import BlockArray from scico.numpy.linalg import norm -from scico.numpy.util import ensure_on_device +from scico.numpy.util import ensure_on_device, is_real_dtype from scico.solver import cg as scico_cg from scico.solver import minimize from scico.typing import JaxArray diff --git a/scico/solver.py b/scico/solver.py index b73971232..177d1434b 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -145,7 +145,7 @@ def _split_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArr BlockArray. """ if isinstance(x, BlockArray): - return BlockArray.array([_split_real_imag(_) for _ in x]) + return snp.blockarray([_split_real_imag(_) for _ in x]) return snp.stack((snp.real(x), snp.imag(x))) @@ -163,7 +163,7 @@ def _join_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArra and `x[1]` respectively. """ if isinstance(x, BlockArray): - return BlockArray.array([_join_real_imag(_) for _ in x]) + return snp.blockarray([_join_real_imag(_) for _ in x]) return x[0] + 1j * x[1] diff --git a/scico/test/functional/test_core.py b/scico/test/functional/test_core.py index 68c06209f..0c8835216 100644 --- a/scico/test/functional/test_core.py +++ b/scico/test/functional/test_core.py @@ -13,7 +13,6 @@ import scico.numpy as snp from scico import functional -from scico.numpy import BlockArray from scico.random import randn NO_BLOCK_ARRAY = [functional.L21Norm, functional.NuclearNorm] @@ -48,7 +47,7 @@ def __init__(self, dtype): self.v1, key = randn((n,), key=key, dtype=dtype) # point for prox eval self.v2, key = randn((m,), key=key, dtype=dtype) # point for prox eval - self.vb = BlockArray.array([self.v1, self.v2]) + self.vb = snp.blockarray([self.v1, self.v2]) @pytest.fixture(params=[np.float32, np.complex64, np.float64, np.complex128]) @@ -68,7 +67,7 @@ def test_separable_prox(test_separable_obj): fv1 = test_separable_obj.f.prox(test_separable_obj.v1, alpha) gv2 = test_separable_obj.g.prox(test_separable_obj.v2, alpha) fgv = test_separable_obj.fg.prox(test_separable_obj.vb, alpha) - out = BlockArray.array((fv1, gv2)) + out = snp.blockarray((fv1, gv2)) snp.testing.assert_allclose(out, fgv, rtol=5e-2) @@ -86,7 +85,7 @@ def test_separable_grad(test_separable_obj): fv1 = test_separable_obj.f.grad(test_separable_obj.v1) gv2 = test_separable_obj.g.grad(test_separable_obj.v2) fgv = test_separable_obj.fg.grad(test_separable_obj.vb) - out = BlockArray.array((fv1, gv2)) + out = snp.blockarray((fv1, gv2)) snp.testing.assert_allclose(out, fgv, rtol=5e-2) diff --git a/scico/test/functional/test_loss.py b/scico/test/functional/test_loss.py index 8a31406cf..70d78404d 100644 --- a/scico/test/functional/test_loss.py +++ b/scico/test/functional/test_loss.py @@ -10,7 +10,7 @@ from prox import prox_test import scico.numpy as snp from scico import functional, linop, loss -from scico.numpy import complex_dtype +from scico.numpy.util import complex_dtype from scico.random import randn, uniform diff --git a/scico/test/functional/test_separable.py b/scico/test/functional/test_separable.py index a59e8d487..0d4473ce0 100644 --- a/scico/test/functional/test_separable.py +++ b/scico/test/functional/test_separable.py @@ -10,7 +10,7 @@ import pytest from scico import functional -from scico.numpy import BlockArray +from scico.numpy import blockarray from scico.numpy.testing import assert_allclose from scico.random import randn @@ -27,7 +27,7 @@ def __init__(self, dtype): self.v1, key = randn((n,), key=key, dtype=dtype) # point for prox eval self.v2, key = randn((m,), key=key, dtype=dtype) # point for prox eval - self.vb = BlockArray.array([self.v1, self.v2]) + self.vb = blockarray([self.v1, self.v2]) @pytest.fixture(params=[np.float32, np.complex64, np.float64, np.complex128]) @@ -47,7 +47,7 @@ def test_separable_prox(test_separable_obj): fv1 = test_separable_obj.f.prox(test_separable_obj.v1, alpha) gv2 = test_separable_obj.g.prox(test_separable_obj.v2, alpha) fgv = test_separable_obj.fg.prox(test_separable_obj.vb, alpha) - out = BlockArray.array((fv1, gv2)).ravel() + out = blockarray((fv1, gv2)).ravel() assert_allclose(out, fgv.ravel(), rtol=5e-2) @@ -65,5 +65,5 @@ def test_separable_grad(test_separable_obj): fv1 = test_separable_obj.f.grad(test_separable_obj.v1) gv2 = test_separable_obj.g.grad(test_separable_obj.v2) fgv = test_separable_obj.fg.grad(test_separable_obj.vb) - out = BlockArray.array((fv1, gv2)).ravel() + out = blockarray((fv1, gv2)).ravel() assert_allclose(out, fgv.ravel(), rtol=5e-2) diff --git a/scico/test/optimize/test_ladmm.py b/scico/test/optimize/test_ladmm.py index d5cf0370f..5b7071b6c 100644 --- a/scico/test/optimize/test_ladmm.py +++ b/scico/test/optimize/test_ladmm.py @@ -69,7 +69,7 @@ def callback(obj): class TestBlockArray: def setup_method(self, method): np.random.seed(12345) - self.y = BlockArray.array( + self.y = snp.blockarray( ( np.random.randn(32, 33).astype(np.float32), np.random.randn( diff --git a/scico/test/optimize/test_pdhg.py b/scico/test/optimize/test_pdhg.py index 5989a080c..45746e3de 100644 --- a/scico/test/optimize/test_pdhg.py +++ b/scico/test/optimize/test_pdhg.py @@ -69,7 +69,7 @@ def callback(obj): class TestBlockArray: def setup_method(self, method): np.random.seed(12345) - self.y = BlockArray.array( + self.y = snp.blockarray( ( np.random.randn(32, 33).astype(np.float32), np.random.randn( diff --git a/scico/test/test_array.py b/scico/test/test_array.py index 6b1ef3e95..e37497f8c 100644 --- a/scico/test/test_array.py +++ b/scico/test/test_array.py @@ -7,16 +7,19 @@ import pytest import scico.numpy as snp -from scico.numpy import ( - BlockArray, +from scico.numpy import BlockArray +from scico.numpy.util import ( complex_dtype, + ensure_on_device, + indexed_shape, is_complex_dtype, is_nested, is_real_dtype, no_nan_divide, + parse_axes, real_dtype, + slice_length, ) -from scico.numpy.util import ensure_on_device, indexed_shape, parse_axes, slice_length from scico.random import randn @@ -28,7 +31,7 @@ def test_ensure_on_device(): NP = np.ones(2) SNP = snp.ones(2) - BA = BlockArray.array([NP, SNP]) + BA = snp.blockarray([NP, SNP]) NP_, SNP_, BA_ = ensure_on_device(NP, SNP, BA) assert isinstance(NP_, DeviceArray) diff --git a/scico/test/test_biconvolve.py b/scico/test/test_biconvolve.py index 192dcd923..d35761293 100644 --- a/scico/test/test_biconvolve.py +++ b/scico/test/test_biconvolve.py @@ -6,7 +6,7 @@ import pytest from scico.linop import Convolve, ConvolveByX -from scico.numpy import BlockArray +from scico.numpy import blockarray from scico.operator.biconvolve import BiConvolve from scico.random import randn @@ -22,7 +22,7 @@ def test_eval(self, input_dtype, mode, jit): x, key = randn((32, 32), dtype=input_dtype, key=self.key) h, key = randn((4, 4), dtype=input_dtype, key=self.key) - x_h = BlockArray.array([x, h]) + x_h = blockarray([x, h]) A = BiConvolve(input_shape=x_h.shape, mode=mode, jit=jit) signal_out = signal.convolve(x, h, mode=mode) diff --git a/scico/test/test_numpy.py b/scico/test/test_numpy.py index fc09827f2..2a9ae390a 100644 --- a/scico/test/test_numpy.py +++ b/scico/test/test_numpy.py @@ -4,7 +4,6 @@ from jax.interpreters.xla import DeviceArray import scico.numpy as snp -from scico.numpy import BlockArray def on_cpu(): @@ -41,16 +40,16 @@ def test_ufunc_abs(): res = snp.array([1, 1, 1]) np.testing.assert_allclose(snp.abs(A), res) - Ba = BlockArray.array((snp.array([-1, 2, 5]),)) - res = BlockArray.array((snp.array([1, 2, 5]),)) + Ba = snp.blockarray((snp.array([-1, 2, 5]),)) + res = snp.blockarray((snp.array([1, 2, 5]),)) np.testing.assert_allclose(snp.abs(Ba).ravel(), res.ravel()) - Ba = BlockArray.array((snp.array([-1, -1, -1]),)) - res = BlockArray.array((snp.array([1, 1, 1]),)) + Ba = snp.blockarray((snp.array([-1, -1, -1]),)) + res = snp.blockarray((snp.array([1, 1, 1]),)) np.testing.assert_allclose(snp.abs(Ba).ravel(), res.ravel()) - Ba = BlockArray.array((snp.array([-1, 2, -3]), snp.array([1, -2, 3]))) - res = BlockArray.array((snp.array([1, 2, 3]), snp.array([1, 2, 3]))) + Ba = snp.blockarray((snp.array([-1, 2, -3]), snp.array([1, -2, 3]))) + res = snp.blockarray((snp.array([1, 2, 3]), snp.array([1, 2, 3]))) np.testing.assert_allclose(snp.abs(Ba).ravel(), res.ravel()) @@ -83,17 +82,17 @@ def test_ufunc_maximum(): B = snp.array([2, 3, 4]) C = snp.array([5, 6]) D = snp.array([2, 7]) - Ba = BlockArray.array((A, C)) - Bb = BlockArray.array((B, D)) - res = BlockArray.array((snp.array([2, 3, 4]), snp.array([5, 7]))) + Ba = snp.blockarray((A, C)) + Bb = snp.blockarray((B, D)) + res = snp.blockarray((snp.array([2, 3, 4]), snp.array([5, 7]))) Bmax = snp.maximum(Ba, Bb) snp.testing.assert_allclose(Bmax, res) A = snp.array([1, 6, 3]) B = snp.array([6, 3, 8]) C = 5 - Ba = BlockArray.array((A, B)) - res = BlockArray.array((snp.array([5, 6, 5]), snp.array([6, 5, 8]))) + Ba = snp.blockarray((A, B)) + res = snp.blockarray((snp.array([5, 6, 5]), snp.array([6, 5, 8]))) Bmax = snp.maximum(Ba, C) snp.testing.assert_allclose(Bmax, res) @@ -103,12 +102,12 @@ def test_ufunc_sign(): res = snp.array([1, -1, 0]) np.testing.assert_allclose(snp.sign(A), res) - Ba = BlockArray.array((snp.array([10, -5, 0]),)) - res = BlockArray.array((snp.array([1, -1, 0]),)) + Ba = snp.blockarray((snp.array([10, -5, 0]),)) + res = snp.blockarray((snp.array([1, -1, 0]),)) snp.testing.assert_allclose(snp.sign(Ba), res) - Ba = BlockArray.array((snp.array([10, -5, 0]), snp.array([0, 5, -6]))) - res = BlockArray.array((snp.array([1, -1, 0]), snp.array([0, 1, -1]))) + Ba = snp.blockarray((snp.array([10, -5, 0]), snp.array([0, 5, -6]))) + res = snp.blockarray((snp.array([1, -1, 0]), snp.array([0, 1, -1]))) snp.testing.assert_allclose(snp.sign(Ba), res) @@ -119,19 +118,19 @@ def test_ufunc_where(): res = snp.array([-1, -1, 4, 5]) np.testing.assert_allclose(snp.where(cond, A, B), res) - Ba = BlockArray.array((snp.array([1, 2, 4, 5]),)) - Bb = BlockArray.array((snp.array([-1, -1, -1, -1]),)) - Bcond = BlockArray.array((snp.array([False, False, True, True]),)) - Bres = BlockArray.array((snp.array([-1, -1, 4, 5]),)) + Ba = snp.blockarray((snp.array([1, 2, 4, 5]),)) + Bb = snp.blockarray((snp.array([-1, -1, -1, -1]),)) + Bcond = snp.blockarray((snp.array([False, False, True, True]),)) + Bres = snp.blockarray((snp.array([-1, -1, 4, 5]),)) assert snp.where(Bcond, Ba, Bb).shape == Bres.shape np.testing.assert_allclose(snp.where(Bcond, Ba, Bb).ravel(), Bres.ravel()) - Ba = BlockArray.array((snp.array([1, 2, 4, 5]), snp.array([1, 2, 4, 5]))) - Bb = BlockArray.array((snp.array([-1, -1, -1, -1]), snp.array([-1, -1, -1, -1]))) - Bcond = BlockArray.array( + Ba = snp.blockarray((snp.array([1, 2, 4, 5]), snp.array([1, 2, 4, 5]))) + Bb = snp.blockarray((snp.array([-1, -1, -1, -1]), snp.array([-1, -1, -1, -1]))) + Bcond = snp.blockarray( (snp.array([False, False, True, True]), snp.array([True, True, False, False])) ) - Bres = BlockArray.array((snp.array([-1, -1, 4, 5]), snp.array([1, 2, -1, -1]))) + Bres = snp.blockarray((snp.array([-1, -1, 4, 5]), snp.array([1, 2, -1, -1]))) assert snp.where(Bcond, Ba, Bb).shape == Bres.shape np.testing.assert_allclose(snp.where(Bcond, Ba, Bb).ravel(), Bres.ravel()) @@ -147,19 +146,19 @@ def test_ufunc_true_divide(): res = snp.array([0.33333333, 0.66666667, 1.0]) np.testing.assert_allclose(snp.true_divide(A, B), res) - Ba = BlockArray.array((snp.array([1, 2, 3]),)) - Bb = BlockArray.array((snp.array([3, 3, 3]),)) - res = BlockArray.array((snp.array([0.33333333, 0.66666667, 1.0]),)) + Ba = snp.blockarray((snp.array([1, 2, 3]),)) + Bb = snp.blockarray((snp.array([3, 3, 3]),)) + res = snp.blockarray((snp.array([0.33333333, 0.66666667, 1.0]),)) snp.testing.assert_allclose(snp.true_divide(Ba, Bb), res) - Ba = BlockArray.array((snp.array([1, 2, 3]), snp.array([1, 2]))) - Bb = BlockArray.array((snp.array([3, 3, 3]), snp.array([2, 2]))) - res = BlockArray.array((snp.array([0.33333333, 0.66666667, 1.0]), snp.array([0.5, 1.0]))) + Ba = snp.blockarray((snp.array([1, 2, 3]), snp.array([1, 2]))) + Bb = snp.blockarray((snp.array([3, 3, 3]), snp.array([2, 2]))) + res = snp.blockarray((snp.array([0.33333333, 0.66666667, 1.0]), snp.array([0.5, 1.0]))) snp.testing.assert_allclose(snp.true_divide(Ba, Bb), res) - Ba = BlockArray.array((snp.array([1, 2, 3]), snp.array([1, 2]))) + Ba = snp.blockarray((snp.array([1, 2, 3]), snp.array([1, 2]))) A = 2 - res = BlockArray.array((snp.array([0.5, 1.0, 1.5]), snp.array([0.5, 1.0]))) + res = snp.blockarray((snp.array([0.5, 1.0, 1.5]), snp.array([0.5, 1.0]))) snp.testing.assert_allclose(snp.true_divide(Ba, A), res) @@ -174,19 +173,19 @@ def test_ufunc_floor_divide(): res = snp.array([1.0, 0, 1.0]) np.testing.assert_allclose(snp.floor_divide(A, B), res) - Ba = BlockArray.array((snp.array([1, 2, 3]),)) - Bb = BlockArray.array((snp.array([3, 3, 3]),)) - res = BlockArray.array((snp.array([0, 0, 1.0]),)) + Ba = snp.blockarray((snp.array([1, 2, 3]),)) + Bb = snp.blockarray((snp.array([3, 3, 3]),)) + res = snp.blockarray((snp.array([0, 0, 1.0]),)) snp.testing.assert_allclose(snp.floor_divide(Ba, Bb), res) - Ba = BlockArray.array((snp.array([1, 7, 3]), snp.array([1, 2]))) - Bb = BlockArray.array((snp.array([3, 3, 3]), snp.array([2, 2]))) - res = BlockArray.array((snp.array([0, 2, 1.0]), snp.array([0, 1.0]))) + Ba = snp.blockarray((snp.array([1, 7, 3]), snp.array([1, 2]))) + Bb = snp.blockarray((snp.array([3, 3, 3]), snp.array([2, 2]))) + res = snp.blockarray((snp.array([0, 2, 1.0]), snp.array([0, 1.0]))) snp.testing.assert_allclose(snp.floor_divide(Ba, Bb), res) - Ba = BlockArray.array((snp.array([1, 2, 3]), snp.array([1, 2]))) + Ba = snp.blockarray((snp.array([1, 2, 3]), snp.array([1, 2]))) A = 2 - res = BlockArray.array((snp.array([0, 1.0, 1.0]), snp.array([0, 1.0]))) + res = snp.blockarray((snp.array([0, 1.0, 1.0]), snp.array([0, 1.0]))) snp.testing.assert_allclose(snp.floor_divide(Ba, A), res) @@ -199,12 +198,12 @@ def test_ufunc_real(): res = snp.array([1, 4.0]) np.testing.assert_allclose(snp.real(A), res) - Ba = BlockArray.array((snp.array([1 + 3j]),)) - res = BlockArray.array((snp.array([1]),)) + Ba = snp.blockarray((snp.array([1 + 3j]),)) + res = snp.blockarray((snp.array([1]),)) snp.testing.assert_allclose(snp.real(Ba), res) - Ba = BlockArray.array((snp.array([1.0 + 3j]), snp.array([1 + 3j, 4.0]))) - res = BlockArray.array((snp.array([1.0]), snp.array([1, 4.0]))) + Ba = snp.blockarray((snp.array([1.0 + 3j]), snp.array([1 + 3j, 4.0]))) + res = snp.blockarray((snp.array([1.0]), snp.array([1, 4.0]))) snp.testing.assert_allclose(snp.real(Ba), res) @@ -217,12 +216,12 @@ def test_ufunc_imag(): res = snp.array([3, 2]) np.testing.assert_allclose(snp.imag(A), res) - Ba = BlockArray.array((snp.array([1 + 3j]),)) - res = BlockArray.array((snp.array([3]),)) + Ba = snp.blockarray((snp.array([1 + 3j]),)) + res = snp.blockarray((snp.array([3]),)) snp.testing.assert_allclose(snp.imag(Ba), res) - Ba = BlockArray.array((snp.array([1 + 3j]), snp.array([1 + 3j, 4.0]))) - res = BlockArray.array((snp.array([3]), snp.array([3, 0]))) + Ba = snp.blockarray((snp.array([1 + 3j]), snp.array([1 + 3j, 4.0]))) + res = snp.blockarray((snp.array([3]), snp.array([3, 0]))) snp.testing.assert_allclose(snp.imag(Ba), res) @@ -235,12 +234,12 @@ def test_ufunc_conj(): res = snp.array([1 - 3j, 4.0 - 2j]) np.testing.assert_allclose(snp.conj(A), res) - Ba = BlockArray.array((snp.array([1 + 3j]),)) - res = BlockArray.array((snp.array([1 - 3j]),)) + Ba = snp.blockarray((snp.array([1 + 3j]),)) + res = snp.blockarray((snp.array([1 - 3j]),)) snp.testing.assert_allclose(snp.conj(Ba), res) - Ba = BlockArray.array((snp.array([1 + 3j]), snp.array([1 + 3j, 4.0]))) - res = BlockArray.array((snp.array([1 - 3j]), snp.array([1 - 3j, 4.0 - 0j]))) + Ba = snp.blockarray((snp.array([1 + 3j]), snp.array([1 + 3j, 4.0]))) + res = snp.blockarray((snp.array([1 - 3j]), snp.array([1 - 3j, 4.0 - 0j]))) snp.testing.assert_allclose(snp.conj(Ba), res) diff --git a/scico/test/test_operator.py b/scico/test/test_operator.py index 5b768525f..9fd23860d 100644 --- a/scico/test/test_operator.py +++ b/scico/test/test_operator.py @@ -12,7 +12,6 @@ import jax import scico.numpy as snp -from scico.numpy import BlockArray from scico.operator import Operator from scico.random import randn @@ -180,7 +179,7 @@ def test_freeze_3arg(): b, _ = randn((2, 1, 4)) c, _ = randn((2, 3, 1)) - x = BlockArray.array([a, b, c]) + x = snp.blockarray([a, b, c]) Abc = A.freeze(0, a) # A as a function of b, c Aac = A.freeze(1, b) # A as a function of a, c Aab = A.freeze(2, c) # A as a function of a, b @@ -189,9 +188,9 @@ def test_freeze_3arg(): assert Aac.input_shape == ((1, 3, 4), (2, 3, 1)) assert Aab.input_shape == ((1, 3, 4), (2, 1, 4)) - bc = BlockArray.array([b, c]) - ac = BlockArray.array([a, c]) - ab = BlockArray.array([a, b]) + bc = snp.blockarray([b, c]) + ac = snp.blockarray([a, c]) + ab = snp.blockarray([a, b]) np.testing.assert_allclose(A(x), Abc(bc), rtol=5e-4) np.testing.assert_allclose(A(x), Aac(ac), rtol=5e-4) np.testing.assert_allclose(A(x), Aab(ab), rtol=5e-4) @@ -204,7 +203,7 @@ def test_freeze_2arg(): a, _ = randn((1, 3, 4)) b, _ = randn((2, 1, 4)) - x = BlockArray.array([a, b]) + x = snp.blockarray([a, b]) Ab = A.freeze(0, a) # A as a function of 'b' only Aa = A.freeze(1, b) # A as a function of 'a' only diff --git a/scico/test/test_solver.py b/scico/test/test_solver.py index 04c80c16c..d694d796d 100644 --- a/scico/test/test_solver.py +++ b/scico/test/test_solver.py @@ -6,7 +6,6 @@ import scico.numpy as snp from scico import random, solver -from scico.numpy import BlockArray class TestSet: @@ -196,8 +195,8 @@ def test_split_join_blockarray(): x_s = solver._split_real_imag(x) assert x_s.shape == ((2, 4, 4), (2, 3)) - real_block = BlockArray.array((x_s[0][0], x_s[1][0])) - imag_block = BlockArray.array((x_s[0][1], x_s[1][1])) + real_block = snp.blockarray((x_s[0][0], x_s[1][0])) + imag_block = snp.blockarray((x_s[0][1], x_s[1][1])) snp.testing.assert_allclose(real_block, snp.real(x), rtol=1e-4) snp.testing.assert_allclose(imag_block, snp.imag(x), rtol=1e-4) From 8cb7b318c88e28cf57ca519b7508ec8d9a0d2212 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Wed, 20 Apr 2022 12:39:44 -0600 Subject: [PATCH 31/37] Remove imports --- scico/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scico/__init__.py b/scico/__init__.py index b6dc10572..96ae86b69 100644 --- a/scico/__init__.py +++ b/scico/__init__.py @@ -48,8 +48,6 @@ "custom_vjp", ] -from . import random, linop - # Imported items in __all__ appear to originate in top-level functional module for name in __all__: getattr(sys.modules[__name__], name).__module__ = __name__ From f9bac794244680f1cbebf49529c6e13c7eda5c73 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Wed, 20 Apr 2022 12:52:49 -0600 Subject: [PATCH 32/37] Trigger lint --- scico/test/functional/test_loss.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scico/test/functional/test_loss.py b/scico/test/functional/test_loss.py index 70d78404d..420872aa3 100644 --- a/scico/test/functional/test_loss.py +++ b/scico/test/functional/test_loss.py @@ -8,6 +8,7 @@ config.update("jax_enable_x64", True) from prox import prox_test + import scico.numpy as snp from scico import functional, linop, loss from scico.numpy.util import complex_dtype From 87baffa4319be0cc9c0e69241afca6a83af1c458 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Wed, 20 Apr 2022 13:14:31 -0600 Subject: [PATCH 33/37] Clean up docs --- scico/numpy/__init__.py | 18 +++++++++--------- scico/numpy/blockarray.py | 8 +++++--- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index 957c1ecd3..3cd1d1807 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -5,15 +5,15 @@ # user license can be found in the 'LICENSE' file distributed with the # package. -""":class:`.BlockArray` and functions for working with them alongside -:class:`DeviceArray` . - -This module consists :class:`.BlockArray` and functions for working with -it alongside class:`DeviceArray`. The latter include all the functions -from :mod:`jax.numpy` and :mod:`numpy.testing`, where many have been -extended to automatically map over block array blocks as described in -:mod:`scico.numpy.blockarray`. Also included are additional functions -unique to SCICO in :mod:`.util`. +r""":class:`.BlockArray` and functions for working with them alongside +:class:`DeviceArray`\ s. + +This module consists of :class:`.BlockArray` and functions for working +:class`.BlockArray`\ s alongside class:`DeviceArray`\ s. This includes +all the functions from :mod:`jax.numpy` and :mod:`numpy.testing`, where +many have been extended to automatically map over block array blocks as +described in :mod:`scico.numpy.blockarray`. Also included are additional +functions unique to SCICO in :mod:`.util`. """ import numpy as np diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py index 9b4865076..69fec5f88 100644 --- a/scico/numpy/blockarray.py +++ b/scico/numpy/blockarray.py @@ -140,10 +140,12 @@ Constructing a BlockArray ========================= +The recommended way to construct a :class:`.BlockArray` is by using the +`snp.blockarray` function. + :: - >>> from scico.numpy import BlockArray - >>> import numpy as np + >>> import scico.numpy as snp >>> x0, key = scico.random.randn((32, 32)) >>> x1, _ = scico.random.randn((16,), key=key) >>> X = snp.blockarray((x0, x1)) @@ -210,7 +212,7 @@ class BlockArray(list): - """BlockArray""" + """BlockArray class""" # Ensure we use BlockArray.__radd__, __rmul__, etc for binary # operations of the form op(np.ndarray, BlockArray) See From 81c5d326bce8c18f30d4f4a2689a3aaab84c4ad9 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Thu, 21 Apr 2022 09:31:16 -0600 Subject: [PATCH 34/37] Add jits --- examples/scripts/denoise_tv_iso_pgm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/scripts/denoise_tv_iso_pgm.py b/examples/scripts/denoise_tv_iso_pgm.py index cb8ab58ec..82546f30b 100644 --- a/examples/scripts/denoise_tv_iso_pgm.py +++ b/examples/scripts/denoise_tv_iso_pgm.py @@ -96,6 +96,7 @@ def __init__( super().__init__(y=y, A=A, scale=1.0) self.lmbda = lmbda + @jax.jit def __call__(self, x: Union[JaxArray, BlockArray]) -> float: xint = self.y - self.lmbda * self.A(x) @@ -117,6 +118,7 @@ class IsoProjector(functional.Functional): def __call__(self, x: Union[JaxArray, BlockArray]) -> float: return 0.0 + @jax.jit def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray: norm_v_ptp = jnp.sqrt(jnp.sum(jnp.abs(v) ** 2, axis=0)) From 5a11b2135f4014c0646f3d3f7c661a2be50415b3 Mon Sep 17 00:00:00 2001 From: Thilo Balke Date: Thu, 21 Apr 2022 16:00:28 -0600 Subject: [PATCH 35/37] consistent spelling of #-dimensional --- docs/source/operator.rst | 2 +- scico/denoiser.py | 6 +++--- scico/numpy/blockarray.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/operator.rst b/docs/source/operator.rst index 549a0c796..b38704c8e 100644 --- a/docs/source/operator.rst +++ b/docs/source/operator.rst @@ -13,7 +13,7 @@ Each :class:`.Operator` object has an ``input_shape`` and ``output_shape``; thes The ``matrix_shape`` attribute describes the shape of the :class:`.LinearOperator` if it were to act on vectorized, or flattened, inputs. -For example, consider a two dimensional array :math:`\mb{x} \in \mathbb{R}^{n \times m}`. +For example, consider a two-dimensional array :math:`\mb{x} \in \mathbb{R}^{n \times m}`. We compute the discrete differences of :math:`\mb{x}` in the horizontal and vertical directions, generating two new arrays: :math:`\mb{x}_h \in \mathbb{R}^{n \times (m-1)}` and :math:`\mb{x}_v \in \mathbb{R}^{(n-1) \times m}`. We represent this linear operator by diff --git a/scico/denoiser.py b/scico/denoiser.py index b48a7b1d7..e79c0a08e 100644 --- a/scico/denoiser.py +++ b/scico/denoiser.py @@ -43,7 +43,7 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False): Args: x: Input image. Expected to be a 2D array (gray-scale denoising) - or 3D array (color denoising). Higher dimensional arrays are + or 3D array (color denoising). Higher-dimensional arrays are tolerated only if the additional dimensions are singletons. For color denoising, the color channel is assumed to be in the last non-singleton dimension. @@ -70,7 +70,7 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False): if isinstance(x.ndim, tuple) or x.ndim < 2: raise ValueError( - "BM3D requires two dimensional or three dimensional inputs;" f" got ndim = {x.ndim}" + "BM3D requires two-dimensional or three dimensional inputs;" f" got ndim = {x.ndim}" ) # This check is also performed inside the BM3D call, but due to the host_callback, @@ -108,7 +108,7 @@ def bm4d(x: JaxArray, sigma: float): :cite:`maggioni-2012-nonlocal`. Args: - x: Input image. Expected to be a 3D array. Higher dimensional + x: Input image. Expected to be a 3D array. Higher-dimensional arrays are tolerated only if the additional dimensions are singletons. sigma: Noise parameter. diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py index 69fec5f88..84e409e46 100644 --- a/scico/numpy/blockarray.py +++ b/scico/numpy/blockarray.py @@ -98,7 +98,7 @@ Motivating Example ================== -Consider a two dimensional array :math:`\mb{x} \in \mathbb{R}^{n \times m}`. +Consider a two-dimensional array :math:`\mb{x} \in \mathbb{R}^{n \times m}`. We compute the discrete differences of :math:`\mb{x}` in the horizontal and vertical directions, generating two new arrays: :math:`\mb{x}_h \in From 1c0f1e42b28ddd0f07ba83f856e4664eacaa47d5 Mon Sep 17 00:00:00 2001 From: Thilo Balke Date: Thu, 21 Apr 2022 18:01:30 -0600 Subject: [PATCH 36/37] consistent spelling of #-dimensional --- scico/denoiser.py | 4 ++-- scico/linop/_matrix.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scico/denoiser.py b/scico/denoiser.py index e79c0a08e..c5e0a0f61 100644 --- a/scico/denoiser.py +++ b/scico/denoiser.py @@ -128,7 +128,7 @@ def bm4d(x: JaxArray, sigma: float): x_in_shape = x.shape if isinstance(x.ndim, tuple) or x.ndim < 3: - raise ValueError(f"BM4D requires three dimensional inputs; got ndim = {x.ndim}") + raise ValueError(f"BM4D requires three-dimensional inputs; got ndim = {x.ndim}") # This check is also performed inside the BM4D call, but due to the host_callback, # no exception is raised and the program will crash with no traceback. @@ -205,7 +205,7 @@ def __call__(self, x: JaxArray) -> JaxArray: if isinstance(x.ndim, tuple) or x.ndim < 2: raise ValueError( - "DnCNN requires two dimensional (M, N) or three dimensional (M, N, C)" + "DnCNN requires two-dimensional (M, N) or three-dimensional (M, N, C)" f" inputs; got ndim = {x.ndim}" ) diff --git a/scico/linop/_matrix.py b/scico/linop/_matrix.py index fd1bab1a9..46da33b4d 100644 --- a/scico/linop/_matrix.py +++ b/scico/linop/_matrix.py @@ -84,7 +84,7 @@ def __init__(self, A: JaxArray): # Can only do rank-2 arrays if A.ndim != 2: - raise TypeError(f"Expected a 2-dimensional array, got array of shape {A.shape}") + raise TypeError(f"Expected a two-dimensional array, got array of shape {A.shape}") super().__init__(input_shape=A.shape[1], output_shape=A.shape[0], input_dtype=self.A.dtype) From 715711c2adfc18b289962b89ee7bcfa172212043 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Fri, 22 Apr 2022 07:55:41 -0600 Subject: [PATCH 37/37] Thilo review --- scico/numpy/util.py | 6 ++---- scico/test/linop/test_diff.py | 8 ++++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/scico/numpy/util.py b/scico/numpy/util.py index 2784856c5..74eacfccd 100644 --- a/scico/numpy/util.py +++ b/scico/numpy/util.py @@ -207,7 +207,7 @@ def is_nested(x: Any) -> bool: x: Object to be tested. Returns: - ``True`` if `x` is a list/tuple of list/tuples, ``False`` otherwise. + ``True`` if `x` is a list/tuple containing at least one list/tuple, ``False`` otherwise. Example: @@ -219,9 +219,7 @@ def is_nested(x: Any) -> bool: True """ - if isinstance(x, (list, tuple)): - return any([isinstance(_, (list, tuple)) for _ in x]) - return False + return isinstance(x, (list, tuple)) and any([isinstance(_, (list, tuple)) for _ in x]) def is_real_dtype(dtype: DType) -> bool: diff --git a/scico/test/linop/test_diff.py b/scico/test/linop/test_diff.py index 159f07c30..94ac3d605 100644 --- a/scico/test/linop/test_diff.py +++ b/scico/test/linop/test_diff.py @@ -32,10 +32,10 @@ def test_eval(): def test_adjoint(input_shape, input_dtype, axes, jit): ndim = len(input_shape) if axes in [1, (1,)] and ndim == 1: - pass - else: - A = FiniteDifference(input_shape=input_shape, input_dtype=input_dtype, axes=axes, jit=jit) - adjoint_test(A) + return + + A = FiniteDifference(input_shape=input_shape, input_dtype=input_dtype, axes=axes, jit=jit) + adjoint_test(A) @pytest.mark.parametrize(