diff --git a/docs/source/notes.rst b/docs/source/notes.rst index eef3b2fc7..2431ae9e7 100644 --- a/docs/source/notes.rst +++ b/docs/source/notes.rst @@ -185,7 +185,7 @@ When evaluating the gradient of ``f`` at 0, :func:`scico.grad` returns ``nan``: >>> import scico >>> import scico.numpy as snp >>> f = lambda x: snp.linalg.norm(x)**2 - >>> scico.grad(f)(snp.zeros(2)) # + >>> scico.grad(f)(snp.zeros(2, dtype=snp.float32)) # DeviceArray([nan, nan], dtype=float32) This can be fixed by defining the squared :math:`\ell_2` norm directly as @@ -194,7 +194,7 @@ This can be fixed by defining the squared :math:`\ell_2` norm directly as :: >>> g = lambda x: snp.sum(x**2) - >>> scico.grad(g)(snp.zeros(2)) + >>> scico.grad(g)(snp.zeros(2, dtype=snp.float32)) DeviceArray([0., 0.], dtype=float32) An alternative is to define a `custom derivative rule `_ to enforce a particular derivative convention at a point. diff --git a/docs/source/operator.rst b/docs/source/operator.rst index 549a0c796..b38704c8e 100644 --- a/docs/source/operator.rst +++ b/docs/source/operator.rst @@ -13,7 +13,7 @@ Each :class:`.Operator` object has an ``input_shape`` and ``output_shape``; thes The ``matrix_shape`` attribute describes the shape of the :class:`.LinearOperator` if it were to act on vectorized, or flattened, inputs. -For example, consider a two dimensional array :math:`\mb{x} \in \mathbb{R}^{n \times m}`. +For example, consider a two-dimensional array :math:`\mb{x} \in \mathbb{R}^{n \times m}`. We compute the discrete differences of :math:`\mb{x}` in the horizontal and vertical directions, generating two new arrays: :math:`\mb{x}_h \in \mathbb{R}^{n \times (m-1)}` and :math:`\mb{x}_v \in \mathbb{R}^{(n-1) \times m}`. We represent this linear operator by diff --git a/examples/scripts/denoise_tv_iso_pgm.py b/examples/scripts/denoise_tv_iso_pgm.py index dc0bb39e2..82546f30b 100644 --- a/examples/scripts/denoise_tv_iso_pgm.py +++ b/examples/scripts/denoise_tv_iso_pgm.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/Usr/bin/env python # -*- coding: utf-8 -*- # This file is part of the SCICO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed @@ -39,8 +39,8 @@ import scico.numpy as snp import scico.random from scico import functional, linop, loss, operator, plot -from scico.array import ensure_on_device -from scico.blockarray import BlockArray +from scico.numpy import BlockArray +from scico.numpy.util import ensure_on_device from scico.optimize.pgm import AcceleratedPGM, RobustLineSearchStepSize from scico.typing import JaxArray from scico.util import device_info @@ -96,6 +96,7 @@ def __init__( super().__init__(y=y, A=A, scale=1.0) self.lmbda = lmbda + @jax.jit def __call__(self, x: Union[JaxArray, BlockArray]) -> float: xint = self.y - self.lmbda * self.A(x) @@ -117,14 +118,15 @@ class IsoProjector(functional.Functional): def __call__(self, x: Union[JaxArray, BlockArray]) -> float: return 0.0 + @jax.jit def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray: norm_v_ptp = jnp.sqrt(jnp.sum(jnp.abs(v) ** 2, axis=0)) x_out = v / jnp.maximum(jnp.ones(v.shape), norm_v_ptp) out1 = v[0, :, -1] / jnp.maximum(jnp.ones(v[0, :, -1].shape), jnp.abs(v[0, :, -1])) - x_out_1 = jax.ops.index_update(x_out, jax.ops.index[0, :, -1], out1) + x_out = x_out.at[0, :, -1].set(out1) out2 = v[1, -1, :] / jnp.maximum(jnp.ones(v[1, -1, :].shape), jnp.abs(v[1, -1, :])) - x_out = jax.ops.index_update(x_out_1, jax.ops.index[1, -1, :], out2) + x_out = x_out.at[1, -1, :].set(out2) return x_out diff --git a/examples/scripts/sparsecode_poisson_pgm.py b/examples/scripts/sparsecode_poisson_pgm.py index 7479d01d3..64191671c 100644 --- a/examples/scripts/sparsecode_poisson_pgm.py +++ b/examples/scripts/sparsecode_poisson_pgm.py @@ -22,7 +22,7 @@ $I(\mathbf{x}^{(0)} \geq 0)$ is the non-negative indicator. This example also demonstrates the application of -[blockarray.BlockArray](../_autosummary/scico.blockarray.rst#scico.blockarray.BlockArray), +[blockarray.BlockArray](../_autosummary/scico.numpy.rst#scico.numpy.BlockArray), [functional.SeparableFunctional](../_autosummary/scico.functional.rst#scico.functional.SeparableFunctional), and [functional.ZeroFunctional](../_autosummary/scico.functional.rst#scico.functional.ZeroFunctional) @@ -40,7 +40,7 @@ import scico.numpy as snp import scico.random from scico import functional, loss, plot -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.operator import Operator from scico.optimize.pgm import ( AcceleratedPGM, diff --git a/scico/_flax.py b/scico/_flax.py index 5a2a67a18..bda3facc7 100644 --- a/scico/_flax.py +++ b/scico/_flax.py @@ -17,7 +17,7 @@ from flax.core import Scope # noqa from flax.linen.module import _Sentinel # noqa -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.typing import JaxArray # The imports of Scope and _Sentinel (above) and the definition of Module diff --git a/scico/_generic_operators.py b/scico/_generic_operators.py index 8d0c0d66b..0657ea108 100644 --- a/scico/_generic_operators.py +++ b/scico/_generic_operators.py @@ -23,8 +23,8 @@ import scico.numpy as snp from scico._autograd import linear_adjoint -from scico.array import is_complex_dtype, is_nested -from scico.blockarray import BlockArray, block_sizes +from scico.numpy import BlockArray +from scico.numpy.util import is_complex_dtype, is_nested, shape_to_size from scico.typing import BlockShape, DType, JaxArray, Shape @@ -152,8 +152,8 @@ def __init__( # Determine the shape of the "vectorized" operator (as an element of ℝ^{n × m} # If the function returns a BlockArray we need to compute the size of each block, # then sum. - self.input_size = int(np.sum(block_sizes(self.input_shape))) - self.output_size = int(np.sum(block_sizes(self.output_shape))) + self.input_size = shape_to_size(self.input_shape) + self.output_size = shape_to_size(self.output_shape) self.shape = (self.output_shape, self.input_shape) self.matrix_shape = (self.output_size, self.input_size) @@ -320,8 +320,8 @@ def freeze(self, argnum: int, val: Union[JaxArray, BlockArray]) -> Operator: def concat_args(args): # Creates a blockarray with args and the frozen value in the correct place # Eg if this operator takes a blockarray with two blocks, then - # concat_args(args) = BlockArray.array([val, args]) if argnum = 0 - # concat_args(args) = BlockArray.array([args, val]) if argnum = 1 + # concat_args(args) = snp.blockarray([val, args]) if argnum = 0 + # concat_args(args) = snp.blockarray([args, val]) if argnum = 1 if isinstance(args, (DeviceArray, np.ndarray)): # In the case that the original operator takes a blcokarray with two @@ -336,7 +336,7 @@ def concat_args(args): arg_list.append(args[i - 1]) else: arg_list.append(val) - return BlockArray.array(arg_list) + return snp.blockarray(arg_list) return Operator( input_shape=input_shape, diff --git a/scico/blockarray.py b/scico/blockarray.py deleted file mode 100644 index 601d5e4a7..000000000 --- a/scico/blockarray.py +++ /dev/null @@ -1,1466 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (C) 2020-2022 by SCICO Developers -# All rights reserved. BSD 3-clause License. -# This file is part of the SPORCO package. Details of the copyright -# and user license can be found in the 'LICENSE.txt' file distributed -# with the package. - -r"""Extensions of numpy ndarray class. - - .. testsetup:: - - >>> import scico - >>> import scico.numpy as snp - >>> from scico.blockarray import BlockArray - >>> import numpy as np - >>> import jax.numpy - -The class :class:`.BlockArray` is a `jagged array -`_ that aims to mimic the -:class:`numpy.ndarray` interface where appropriate. - -A :class:`.BlockArray` object consists of a tuple of `DeviceArray` -objects that share their memory buffers with non-overlapping, contiguous -regions of a common one-dimensional `DeviceArray`. A :class:`.BlockArray` -contains the following size attributes: - -* `shape`: A tuple of tuples containing component dimensions. -* `size`: The sum of the size of each component block; this is the length - of the underlying one-dimensional `DeviceArray`. -* `num_blocks`: The number of components (blocks) that comprise the - :class:`.BlockArray`. - - -Motivating Example -================== - -Consider a two dimensional array :math:`\mb{x} \in \mathbb{R}^{n \times m}`. - -We compute the discrete differences of :math:`\mb{x}` in the horizontal -and vertical directions, generating two new arrays: :math:`\mb{x}_h \in -\mathbb{R}^{n \times (m-1)}` and :math:`\mb{x}_v \in \mathbb{R}^{(n-1) -\times m}`. - -As these arrays are of different shapes, we cannot combine them into a -single `ndarray`. Instead, we might vectorize each array and concatenate -the resulting vectors, leading to :math:`\mb{\bar{x}} \in -\mathbb{R}^{n(m-1) + m(n-1)}`, which can be stored as a one-dimensional -`ndarray`. Unfortunately, this makes it hard to access the individual -components :math:`\mb{x}_h` and :math:`\mb{x}_v`. - -Instead, we can form a :class:`.BlockArray`: :math:`\mb{x}_B = -[\mb{x}_h, \mb{x}_v]` - - - :: - - >>> n = 32 - >>> m = 16 - >>> x_h, key = scico.random.randn((n, m-1)) - >>> x_v, _ = scico.random.randn((n-1, m), key=key) - - # Form the blockarray - >>> x_B = BlockArray.array([x_h, x_v]) - - # The blockarray shape is a tuple of tuples - >>> x_B.shape - ((32, 15), (31, 16)) - - # Each block component can be easily accessed - >>> x_B[0].shape - (32, 15) - >>> x_B[1].shape - (31, 16) - - -Constructing a BlockArray -========================= - -Construct from a tuple of arrays (either `ndarray` or `DeviceArray`) --------------------------------------------------------------------- - - .. doctest:: - - >>> from scico.blockarray import BlockArray - >>> import numpy as np - >>> x0, key = scico.random.randn((32, 32)) - >>> x1, _ = scico.random.randn((16,), key=key) - >>> X = BlockArray.array((x0, x1)) - >>> X.shape - ((32, 32), (16,)) - >>> X.size - 1040 - >>> X.num_blocks - 2 - -While :func:`.BlockArray.array` will accept either `ndarray` or -`DeviceArray` as input, the resulting :class:`.BlockArray` will be backed -by a `DeviceArray` memory buffer. - -**Note**: constructing a :class:`.BlockArray` always involves a copy to -a new `DeviceArray` memory buffer. - -**Note**: by default, the resulting :class:`.BlockArray` is cast to -single precision and will have dtype ``float32`` or ``complex64``. - - -Construct from a single vector and tuple of shapes --------------------------------------------------- - - :: - - >>> x_flat, _ = scico.random.randn((1040,)) - >>> shape_tuple = ((32, 32), (16,)) - >>> X = BlockArray.array_from_flattened(x_flat, shape_tuple=shape_tuple) - >>> X.shape - ((32, 32), (16,)) - - - -Operating on a BlockArray -========================= - -.. _blockarray_indexing: - -Indexing --------- - -The block index is required to be an integer, selecting a single block and -returning it as an array (*not* a singleton BlockArray). If the index -expression has more than one component, then the initial index indexes the -block, and the remainder of the indexing expression indexes within the -selected block, e.g. `x[2, 3:4]` is equivalent to `y[3:4]` after -setting `y = x[2]`. - - -Indexed Updating ----------------- - -BlockArrays support the JAX DeviceArray `indexed update syntax -`_ - - -The index must be of the form [ibk] or [ibk, idx], where `ibk` is the -index of the block to be updated, and `idx` is a general index of the -elements to be updated in that block. In particular, `ibk` cannot be a -`slice`. The general index `idx` can be omitted, in which case an entire -block is updated. - - -============================== ============================================== -Alternate syntax Equivalent in-place expression -============================== ============================================== -`x.at[ibk, idx].set(y)` `x[ibk, idx] = y` -`x.at[ibk, idx].add(y)` `x[ibk, idx] += y` -`x.at[ibk, idx].multiply(y)` `x[ibk, idx] *= y` -`x.at[ibk, idx].divide(y)` `x[ibk, idx] /= y` -`x.at[ibk, idx].power(y)` `x[ibk, idx] **= y` -`x.at[ibk, idx].min(y)` `x[ibk, idx] = np.minimum(x[idx], y)` -`x.at[ibk, idx].max(y)` `x[ibk, idx] = np.maximum(x[idx], y)` -============================== ============================================== - - -Arithmetic and Broadcasting ---------------------------- - -Suppose :math:`\mb{x}` is a BlockArray with shape :math:`((n, n), (m,))`. - - :: - - >>> x1, key = scico.random.randn((4, 4)) - >>> x2, _ = scico.random.randn((5,), key=key) - >>> x = BlockArray.array( (x1, x2) ) - >>> x.shape - ((4, 4), (5,)) - >>> x.num_blocks - 2 - >>> x.size # 4*4 + 5 - 21 - -Illustrated for the operation `+`, but equally valid for operators -`+, -, *, /, //, **, <, <=, >, >=, ==` - - -Operations with BlockArrays with same number of blocks -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Let :math:`\mb{y}` be a BlockArray with the same number of blocks as -:math:`\mb{x}`. - - .. math:: - \mb{x} + \mb{y} - = - \begin{bmatrix} - \mb{x}[0] + \mb{y}[0] \\ - \mb{x}[1] + \mb{y}[1] \\ - \end{bmatrix} - -This operation depends on pair of blocks from :math:`\mb{x}` and -:math:`\mb{y}` being broadcastable against each other. - - - -Operations with a scalar -^^^^^^^^^^^^^^^^^^^^^^^^ - -The scalar is added to each element of the :class:`.BlockArray`: - - .. math:: - \mb{x} + 1 - = - \begin{bmatrix} - \mb{x}[0] + 1 \\ - \mb{x}[1] + 1\\ - \end{bmatrix} - - - :: - - >>> y = x + 1 - >>> np.testing.assert_allclose(y[0], x[0] + 1) - >>> np.testing.assert_allclose(y[1], x[1] + 1) - - - -Operations with a 1D `ndarray` of size equal to `num_blocks` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The *i*\th scalar is added to the *i*\th block of the -:class:`.BlockArray`: - - .. math:: - \mb{x} - + - \begin{bmatrix} - 1 \\ - 2 - \end{bmatrix} - = - \begin{bmatrix} - \mb{x}[0] + 1 \\ - \mb{x}[1] + 2\\ - \end{bmatrix} - - - :: - - >>> y = x + np.array([1, 2]) - >>> np.testing.assert_allclose(y[0], x[0] + 1) - >>> np.testing.assert_allclose(y[1], x[1] + 2) - - -Operations with an ndarray of `size` equal to :class:`.BlockArray` size -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -We first cast the `ndarray` to a BlockArray with same shape as -:math:`\mb{x}`, then apply the operation on the resulting BlockArrays. -With `y.size = x.size`, we have: - - .. math:: - \mb{x} - + - \mb{y} - = - \begin{bmatrix} - \mb{x}[0] + \mb{y}[0] \\ - \mb{x}[1] + \mb{y}[1]\\ - \end{bmatrix} - -Equivalently, the BlockArray is first flattened, then added to the -flattened `ndarray`, and the result is reformed into a BlockArray with -the same shape as :math:`\mb{x}` - - - -MatMul ------- - -Between two BlockArrays -^^^^^^^^^^^^^^^^^^^^^^^ - -The matmul is computed between each block of the two BlockArrays. - -The BlockArrays must have the same number of blocks, and each pair of -blocks must be broadcastable. - - .. math:: - \mb{x} @ \mb{y} - = - \begin{bmatrix} - \mb{x}[0] @ \mb{y}[0] \\ - \mb{x}[1] @ \mb{y}[1]\\ - \end{bmatrix} - - - -Between BlockArray and Ndarray/DeviceArray -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -This operation is not defined. - - -Between BlockArray and :class:`.LinearOperator` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The :class:`.Operator` and :class:`.LinearOperator` classes are designed -to work on :class:`.BlockArray`\ s in addition to `DeviceArray`\ s. -For example - - - :: - - >>> x, key = scico.random.randn((3, 4)) - >>> A_1 = scico.linop.Identity(x.shape) - >>> A_1.shape # array -> array - ((3, 4), (3, 4)) - - >>> A_2 = scico.linop.FiniteDifference(x.shape) - >>> A_2.shape # array -> BlockArray - (((2, 4), (3, 3)), (3, 4)) - - >>> diag = BlockArray.array([np.array(1.0), np.array(2.0)]) - >>> A_3 = scico.linop.Diagonal(diag, input_shape=(A_2.output_shape)) - >>> A_3.shape # BlockArray -> BlockArray - (((2, 4), (3, 3)), ((2, 4), (3, 3))) - - -NumPy ufuncs ------------- - -`NumPy universal functions (ufuncs) `_ -are functions that operate on an `ndarray` on an element-by-element -fashion and support array broadcasting. Examples of ufuncs are `abs`, -`sign`, `conj`, and `exp`. - -The JAX library implements most NumPy ufuncs in the :mod:`jax.numpy` -module. However, as JAX does not support subclassing of `DeviceArray`, -the JAX ufuncs cannot be used on :class:`.BlockArray`. As a workaround, -we have wrapped several JAX ufuncs for use on :class:`.BlockArray`; these -are defined in the :mod:`scico.numpy` module. - - -Reductions -^^^^^^^^^^ - -Reductions are functions that take an array-like as an input and return -an array of lower dimension. Examples include `mean`, `sum`, `norm`. -BlockArray reductions are located in the :mod:`scico.numpy` module - -:class:`.BlockArray` tries to mirror `ndarray` reduction semantics where -possible, but cannot provide a one-to-one match as the block components -may be of different size. - -Consider the example BlockArray - - .. math:: - \mb{x} = \begin{bmatrix} - \begin{bmatrix} - 1 & 1 \\ - 1 & 1 - \end{bmatrix} \\ - \begin{bmatrix} - 2 \\ - 2 - \end{bmatrix} - \end{bmatrix}. - -We have - - .. doctest:: - - >>> import scico.numpy as snp - >>> x = BlockArray.array((np.ones((2,2)), 2*np.ones((2)))) - >>> x.shape - ((2, 2), (2,)) - >>> x.size - 6 - >>> x.num_blocks - 2 - - - - If no axis is specified, the reduction is applied to the flattened - array: - - .. doctest:: - - >>> snp.sum(x, axis=None).item() - 8.0 - - Reducing along the 0-th axis crushes the `BlockArray` down into a - single `DeviceArray` and requires all blocks to have the same shape - otherwise, an error is raised. - - .. doctest:: - - >>> snp.sum(x, axis=0) - Traceback (most recent call last): - ValueError: Evaluating sum of BlockArray along axis=0 requires all blocks to be same shape; got ((2, 2), (2,)) - - >>> y = BlockArray.array((np.ones((2,2)), 2*np.ones((2, 2)))) - >>> snp.sum(y, axis=0) - DeviceArray([[3., 3.], - [3., 3.]], dtype=float32) - - Reducing along axis :math:`n` is equivalent to reducing each component - along axis :math:`n-1`: - - .. math:: - \text{sum}(x, axis=1) = \begin{bmatrix} - \begin{bmatrix} - 2 \\ - 2 - \end{bmatrix} \\ - \begin{bmatrix} - 4 \\ - \end{bmatrix} - \end{bmatrix} - - - If a component does not have axis :math:`n-1`, the reduction is not - applied to that component. In this example, `x[1].ndim == 1`, so no - reduction is applied to block `x[1]`. - - .. math:: - \text{sum}(x, axis=2) = \begin{bmatrix} - \begin{bmatrix} - 2 \\ - 2 - \end{bmatrix} \\ - \begin{bmatrix} - 2 \\ - 2 - \end{bmatrix} - \end{bmatrix} - - -Code version - - .. doctest:: - - >>> snp.sum(x, axis=1) # doctest: +SKIP - BlockArray([[2, 2], - [4,] ]) - - >>> snp.sum(x, axis=2) # doctest: +SKIP - BlockArray([ [2, 2], - [2,] ]) - - -""" - -from __future__ import annotations - -from functools import wraps -from typing import Iterator, List, Optional, Tuple, Union - -import numpy as np - -import jax -import jax.numpy as jnp -from jax import core -from jax.interpreters import xla -from jax.interpreters.xla import DeviceArray -from jax.tree_util import register_pytree_node, tree_flatten - -from jaxlib.xla_extension import Buffer - -from scico import array -from scico.typing import Axes, AxisIndex, BlockShape, DType, JaxArray, Shape - -_arraylikes = (Buffer, DeviceArray, np.ndarray) - - -def atleast_1d(*arys): - """Convert inputs to arrays with at least one dimension. - - A wrapper for :func:`jax.numpy.atleast_1d` that acts as usual on - ndarrays and DeviceArrays, and returns BlockArrays unmodified. - """ - - if len(arys) == 1: - arr = arys[0] - return arr if isinstance(arr, BlockArray) else jnp.atleast_1d(arr) - - out = [] - for arr in arys: - if isinstance(arr, BlockArray): - out.append(arr) - else: - out.append(jnp.atleast_1d(arr)) - return out - - -# Append docstring from original jax.numpy function -atleast_1d.__doc__ = ( - atleast_1d.__doc__.replace("\n ", "\n") # deal with indentation differences - + "\nDocstring for :func:`jax.numpy.atleast_1d`:\n\n" - + "\n".join(jax.numpy.atleast_1d.__doc__.split("\n")[2:]) -) - - -def reshape( - a: Union[JaxArray, BlockArray], newshape: Union[Shape, BlockShape] -) -> Union[JaxArray, BlockArray]: - """Change the shape of an array without changing its data. - - Args: - a: Array to be reshaped. - newshape: The new shape should be compatible with the original - shape. If an integer, then the result will be a 1-D array of - that length. One shape dimension can be -1. In this case, - the value is inferred from the length of the array and - remaining dimensions. If a tuple of tuple of ints, a - :class:`.BlockArray` is returned. - - Returns: - The reshaped array. Unlike :func:`numpy.reshape`, a copy is - always returned. - """ - - if array.is_nested(newshape): - # x is a blockarray - return BlockArray.array_from_flattened(a, newshape) - - return jnp.reshape(a, newshape) - - -def block_sizes(shape: Union[Shape, BlockShape]) -> Axes: - r"""Compute the 'sizes' of (possibly nested) block shapes. - - This function computes `block_sizes(z.shape) == (_.size for _ in z)`. - - - Args: - shape: A shape tuple; possibly containing nested tuples. - - - Examples: - - .. doctest:: - >>> import scico.numpy as snp - - >>> x = BlockArray.ones( ( (4, 4), (2,))) - >>> x.size - 18 - - >>> y = snp.ones((3, 3)) - >>> y.size - 9 - - >>> z = BlockArray.array([x, y]) - >>> block_sizes(z.shape) - (18, 9) - - >>> zz = BlockArray.array([z, z]) - >>> block_sizes(zz.shape) - (27, 27) - """ - - if isinstance(shape, BlockArray): - raise TypeError( - "Expected a `shape` (possibly nested tuple of ints); got :class:`.BlockArray`." - ) - - out = [] - if array.is_nested(shape): - # shape is nested -> at least one element came from a blockarray - for y in shape: - if array.is_nested(y): - # recursively calculate the block size until we arrive at - # a tuple (shape of a non-block array) - while array.is_nested(y): - y = block_sizes(y) - out.append(np.sum(y)) # adjacent block sizes are added together - else: - # this is a tuple; size given by product of elements - out.append(np.prod(y)) - return tuple(out) - - # shape is a non-nested tuple; return the product - return np.prod(shape) - - -def _decompose_index(idx: Union[int, Tuple(AxisIndex)]) -> Tuple: - """Decompose a BlockArray indexing expression into components. - - Decompose a BlockArray indexing expression into block and array - components. - - Args: - idx: BlockArray indexing expression. - - Returns: - A tuple (idxblk, idxarr) with entries corresponding to the - integer block index and the indexing to be applied to the - selected block, respectively. The latter is ``None`` if the - indexing expression simply selects one of the blocks (i.e. - it consists of a single integer). - - Raises: - TypeError: If the block index is not an integer. - """ - if isinstance(idx, tuple): - idxblk = idx[0] - idxarr = idx[1:] - else: - idxblk = idx - idxarr = None - if not isinstance(idxblk, int): - raise TypeError("Block index must be an integer") - return idxblk, idxarr - - -def indexed_shape(shape: Shape, idx: Union[int, Tuple[AxisIndex, ...]]) -> Tuple[int, ...]: - """Determine the shape of the result of indexing a BlockArray. - - Args: - shape: Shape of BlockArray. - idx: BlockArray indexing expression. - - Returns: - Shape of the selected block, or slice of that block if `idx` is a tuple - rather than an integer. - """ - idxblk, idxarr = _decompose_index(idx) - if idxblk < 0: - idxblk = len(shape) + idxblk - if idxarr is None: - return shape[idxblk] - return array.indexed_shape(shape[idxblk], idxarr) - - -def _flatten_blockarrays(inp, *args, **kwargs): - """Flatten any blockarrays present in `inp`, `args`, or `kwargs`.""" - - def _flatten_if_blockarray(inp): - if isinstance(inp, BlockArray): - return inp._data - return inp - - inp_ = _flatten_if_blockarray(inp) - args_ = (_flatten_if_blockarray(_) for _ in args) - kwargs_ = {key: _flatten_if_blockarray(val) for key, val in kwargs.items()} - return inp_, args_, kwargs_ - - -def _block_array_ufunc_wrapper(func): - """Wrap a "ufunc" to allow for joint operation on `DeviceArray` and `BlockArray`.""" - - @wraps(func) - def wrapper(inp, *args, **kwargs): - all_args = (inp,) + args + tuple(kwargs.items()) - if any([isinstance(_, BlockArray) for _ in all_args]): - # If 'inp' is a BlockArray, call func on inp._data - # Then return a BlockArray of the same shape as inp - - inp_, args_, kwargs_ = _flatten_blockarrays(inp, *args, **kwargs) - flat_out = func(inp_, *args_, **kwargs_) - return BlockArray.array_from_flattened(flat_out, inp.shape) - - # Otherwise call the function normally - return func(inp, *args, **kwargs) - - if not hasattr(func, "__doc__") or func.__doc__ is None: - return wrapper - - wrapper.__doc__ = ( - f":func:`{func.__name__}` wrapped to operate on :class:`BlockArray`" + "\n\n" + func.__doc__ - ) - return wrapper - - -def _block_array_reduction_wrapper(func): - """Wrap a reduction (eg. sum, norm) to allow for joint operation on - `DeviceArray` and `BlockArray`.""" - - @wraps(func) - def wrapper(inp, *args, axis=None, **kwargs): - - all_args = (inp,) + args + tuple(kwargs.items()) - if any([isinstance(_, BlockArray) for _ in all_args]): - if axis is None: - # Treat as a single long vector - inp_, args_, kwargs_ = _flatten_blockarrays(inp, *args, **kwargs) - return func(inp_, *args_, **kwargs_) - - if type(axis) == tuple: - raise Exception( - f"""Evaluating {func.__name__} on a BlockArray with a tuple argument to - axis is not currently supported""" - ) - - if axis == 0: # reduction along block axis - # reduction along axis=0 only makes sense if all blocks are the same shape - # so we can convert to a standard DeviceArray of shape (inp.num_blocks, ...) - # and reduce along axis = 0 - if all([bk_shape == inp.shape[0] for bk_shape in inp.shape]): - view_shape = (inp.num_blocks,) + inp.shape[0] - return func(inp._data.reshape(view_shape), *args, axis=0, **kwargs) - - raise ValueError( - f"Evaluating {func.__name__} of BlockArray along axis=0 requires " - f"all blocks to be same shape; got {inp.shape}" - ) - - # Reduce each block individually along axis-1 - out = [] - for bk in inp: - if isinstance(bk, BlockArray): - # This block is itself a blockarray, so call this wrapped reduction - # on axis-1 - tmp = _block_array_reduction_wrapper(func)(bk, *args, axis=axis - 1, **kwargs) - else: - if axis - 1 >= bk.ndim: - # Trying to reduce along a dim that doesn't exist for this block, - # so just return the block. - # i.e. broadcast to shape (..., 1) and reduce along axis=-1 - tmp = bk - else: - tmp = func(bk, *args, axis=axis - 1, **kwargs) - out.append(atleast_1d(tmp)) - return BlockArray.array(out) - - if axis is None: - # 'axis' might not be a valid kwarg (eg dot, vdot), so don't pass it - return func(inp, *args, **kwargs) - - return func(inp, *args, axis=axis, **kwargs) - - if not hasattr(func, "__doc__") or func.__doc__ is None: - return wrapper - - wrapper.__doc__ = ( - f":func:`{func.__name__}` wrapped to operate on :class:`BlockArray`" + "\n\n" + func.__doc__ - ) - return wrapper - - -def _block_array_matmul_wrapper(func): - @wraps(func) - def wrapper(self, other): - if isinstance(self, BlockArray): - if isinstance(other, BlockArray): - # Both blockarrays, work block by block - return BlockArray.array([func(x, y) for x, y in zip(self, other)]) - raise TypeError( - f"Operation {func.__name__} not implemented between {type(self)} and {type(other)}" - ) - return func(self, other) - - if not hasattr(func, "__doc__") or func.__doc__ is None: - return wrapper - wrapper.__doc__ = ( - f":func:`{func.__name__}` wrapped to operate on :class:`BlockArray`" + "\n\n" + func.__doc__ - ) - return wrapper - - -def _block_array_binary_op_wrapper(func): - """Return a decorator that performs type and shape checking for - :class:`.BlockArray` arithmetic. - """ - - @wraps(func) - def wrapper(self, other): - if isinstance(other, BlockArray): - if other.shape == self.shape: - # Same shape blocks, can operate on flattened arrays - return BlockArray.array_from_flattened(func(self._data, other._data), self.shape) - if other.num_blocks == self.num_blocks: - # Will work as long as the shapes are broadcastable - return BlockArray.array([func(x, y) for x, y in zip(self, other)]) - raise ValueError( - f"operation not valid on operands with shapes {self.shape} {other.shape}" - ) - if any([isinstance(other, _) for _ in _arraylikes]): - if other.size == 1: - # Same as operating on a scalar - return BlockArray.array_from_flattened(func(self._data, other), self.shape) - if other.size == self.size: - # A little fast and loose, treat the block array as a length self.size vector - return BlockArray.array_from_flattened(func(self._data, other), self.shape) - if other.size == self.num_blocks: - return BlockArray.array([func(blk, other_) for blk, other_ in zip(self, other)]) - raise ValueError( - f"operation not valid on operands with shapes {self.shape} {other.shape}" - ) - if jnp.isscalar(other) or isinstance(other, core.Tracer): - return BlockArray.array_from_flattened(func(self._data, other), self.shape) - raise TypeError( - f"Operation {func.__name__} not implemented between {type(self)} and {type(other)}" - ) - - if not hasattr(func, "__doc__") or func.__doc__ is None: - return wrapper - wrapper.__doc__ = ( - f":func:`{func.__name__}` wrapped to operate on :class:`BlockArray`" + "\n\n" + func.__doc__ - ) - return wrapper - - -class _AbstractBlockArray(core.ShapedArray): - """Abstract BlockArray class for JAX tracing. - - See https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html - """ - - array_abstraction_level = 0 # Same as jax.core.ConcreteArray - - def __init__(self, shapes, dtype): - - sizes = block_sizes(shapes) - size = np.sum(sizes) - - super(_AbstractBlockArray, self).__init__((size,), dtype) - - #: Abstract data value - self._data_aval: core.ShapedArray = core.ShapedArray((size,), dtype) - - #: Array dtype - self.dtype: DType = dtype - - #: Shape of each block - self.shapes: BlockShape = shapes - - #: Size of each block - self.sizes: Shape = sizes - - #: Array specifying boundaries of components as indices in base array - self.bndpos: np.ndarray = np.r_[0, np.cumsum(sizes)] - - -# The Jax class is heavily inspired by SparseArray/AbstractSparseArray here: -# https://github.com/google/jax/blob/7724322d1c08c13008815bfb52759a29c2a6823b/tests/custom_object_test.py -class BlockArray: - """A tuple of :class:`jax.interpreters.xla.DeviceArray` objects. - - A tuple of `DeviceArray` objects that all share their memory buffers - with non-overlapping, contiguous regions of a common one-dimensional - `DeviceArray`. It can be used as the common one-dimensional array via - the :func:`BlockArray.ravel` method, or individual component arrays - can be accessed individually. - """ - - # Ensure we use BlockArray.__radd__, __rmul__, etc for binary operations of the form - # op(np.ndarray, BlockArray) - # See https://docs.scipy.org/doc/numpy-1.10.1/user/c-info.beyond-basics.html#ndarray.__array_priority__ - __array_priority__ = 1 - - def __init__(self, aval: _AbstractBlockArray, data: JaxArray): - """BlockArray init method. - - Args: - aval: `Abstract value`_ associated with this array (shape+dtype+weak_type) - data: The underlying contiguous, flattened `DeviceArray`. - - .. _Abstract value: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html - - """ - self._aval = aval - self._data = data - - def __repr__(self): - return "scico.blockarray.BlockArray: \n" + self._data.__repr__() - - def __getitem__(self, idx: Union[int, Tuple[AxisIndex, ...]]) -> JaxArray: - idxblk, idxarr = _decompose_index(idx) - if idxblk < 0: - idxblk = self.num_blocks + idxblk - blk = reshape(self._data[self.bndpos[idxblk] : self.bndpos[idxblk + 1]], self.shape[idxblk]) - if idxarr is not None: - blk = blk[idxarr] - return blk - - @_block_array_matmul_wrapper - def __matmul__(self, other: Union[np.ndarray, BlockArray, JaxArray]) -> JaxArray: - return self @ other - - @_block_array_matmul_wrapper - def __rmatmul__(self, other: Union[np.ndarray, BlockArray, JaxArray]) -> JaxArray: - return other @ self - - @_block_array_binary_op_wrapper - def __sub__(a, b): - return a - b - - @_block_array_binary_op_wrapper - def __rsub__(a, b): - return b - a - - @_block_array_binary_op_wrapper - def __mul__(a, b): - return a * b - - @_block_array_binary_op_wrapper - def __rmul__(a, b): - return a * b - - @_block_array_binary_op_wrapper - def __add__(a, b): - return a + b - - @_block_array_binary_op_wrapper - def __radd__(a, b): - return a + b - - @_block_array_binary_op_wrapper - def __truediv__(a, b): - return a / b - - @_block_array_binary_op_wrapper - def __rtruediv__(a, b): - return b / a - - @_block_array_binary_op_wrapper - def __floordiv__(a, b): - return a // b - - @_block_array_binary_op_wrapper - def __rfloordiv__(a, b): - return b // a - - @_block_array_binary_op_wrapper - def __pow__(a, b): - return a**b - - @_block_array_binary_op_wrapper - def __rpow__(a, b): - return b**a - - @_block_array_binary_op_wrapper - def __gt__(a, b): - return a > b - - @_block_array_binary_op_wrapper - def __ge__(a, b): - return a >= b - - @_block_array_binary_op_wrapper - def __lt__(a, b): - return a < b - - @_block_array_binary_op_wrapper - def __le__(a, b): - return a <= b - - @_block_array_binary_op_wrapper - def __eq__(a, b): - return a == b - - @_block_array_binary_op_wrapper - def __ne__(a, b): - return a != b - - def __iter__(self) -> Iterator[int]: - for i in range(self.num_blocks): - yield self[i] - - @property - def blocks(self) -> Iterator[int]: - """Return an iterator yielding component blocks.""" - return self.__iter__() - - @property - def bndpos(self) -> np.ndarray: - """Array specifying boundaries of components as indices in base array.""" - return self._aval.bndpos - - @property - def dtype(self) -> DType: - """Array dtype.""" - return self._data.dtype - - @property - def device_buffer(self) -> Buffer: - """The :class:`jaxlib.xla_extension.Buffer` that backs the - underlying data array.""" - return self._data.device_buffer - - @property - def size(self) -> int: - """Total number of elements in the array.""" - return self._aval.size - - @property - def num_blocks(self) -> int: - """Number of :class:`.BlockArray` components.""" - - return len(self.shape) - - @property - def ndim(self) -> Shape: - """Tuple of component ndims.""" - - return tuple(len(c) for c in self.shape) - - @property - def shape(self) -> BlockShape: - """Tuple of component shapes.""" - - return self._aval.shapes - - @property - def split(self) -> Tuple[JaxArray, ...]: - """Tuple of component arrays.""" - - return tuple(self[k] for k in range(self.num_blocks)) - - def conj(self) -> BlockArray: - """Return a :class:`.BlockArray` with complex-conjugated elements.""" - - # Much faster than BlockArray.array([_.conj() for _ in self.blocks]) - return BlockArray.array_from_flattened(self.ravel().conj(), self.shape) - - @property - def real(self) -> BlockArray: - """Return a :class:`.BlockArray` with the real part of this array.""" - return BlockArray.array_from_flattened(self.ravel().real, self.shape) - - @property - def imag(self) -> BlockArray: - """Return a :class:`.BlockArray` with the imaginary part of this array.""" - return BlockArray.array_from_flattened(self.ravel().imag, self.shape) - - @classmethod - def array( - cls, alst: List[Union[np.ndarray, JaxArray]], dtype: Optional[np.dtype] = None - ) -> BlockArray: - """Construct a :class:`.BlockArray` from a list or tuple of existing array-like. - - Args: - alst: Initializers for array components. - Can be :class:`numpy.ndarray` or - :class:`jax.interpreters.xla.DeviceArray` - dtype: Data type of array. If ``None``, dtype is derived from - dtype of initializers. - - Returns: - :class:`.BlockArray` initialized from `alst` tuple. - """ - - if isinstance(alst, (tuple, list)) is False: - raise TypeError("Input to `array` must be a list or tuple of existing arrays") - - if dtype is None: - present_types = jax.tree_flatten(jax.tree_map(lambda x: x.dtype, alst))[0] - dtype = np.find_common_type(present_types, []) - - # alst can be a list/tuple of arrays, or a list/tuple containing list/tuples of arrays - # consider alst to be a tree where leaves are arrays (possibly abstract arrays) - # use tree_map to find the shape of each leaf - # `shapes` will be a tuple of ints and tuples containing ints (possibly nested further) - - # ensure any scalar leaves are converted to (1,) arrays - def shape_atleast_1d(x): - return x.shape if x.shape != () else (1,) - - shapes = tuple( - jax.tree_map(shape_atleast_1d, alst, is_leaf=lambda x: not isinstance(x, (list, tuple))) - ) - - _aval = _AbstractBlockArray(shapes, dtype) - data_ravel = jnp.hstack(jax.tree_map(lambda x: x.ravel(), jax.tree_flatten(alst)[0])) - return cls(_aval, data_ravel) - - @classmethod - def array_from_flattened( - cls, data_ravel: Union[np.ndarray, JaxArray], shape_tuple: BlockShape - ) -> BlockArray: - """Construct a :class:`.BlockArray` from a flattened array and tuple of shapes. - - Args: - data_ravel: Flattened data array. - shape_tuple: Tuple of tuples containing desired block shapes. - - Returns: - :class:`.BlockArray` initialized from `data_ravel` and `shape_tuple`. - """ - - if not isinstance(data_ravel, DeviceArray): - data_ravel = jax.device_put(data_ravel) - - shape_tuple_size = np.sum(block_sizes(shape_tuple)) - - if shape_tuple_size != data_ravel.size: - raise ValueError( - f"""The specified shape_tuple is incompatible with provided data_ravel - shape_tuple = {shape_tuple} - shape_tuple_size = {shape_tuple_size} - len(data_ravel) = {len(data_ravel)} - """ - ) - - _aval = _AbstractBlockArray(shape_tuple, dtype=data_ravel.dtype) - return cls(_aval, data_ravel) - - @classmethod - def ones(cls, shape_tuple: BlockShape, dtype: DType = np.float32) -> BlockArray: - """ - Return a new :class:`.BlockArray` with given block shapes and type, filled with ones. - - Args: - shape_tuple: Tuple of shapes for component blocks. - dtype: Desired data-type for the :class:`.BlockArray`. - Default is ``numpy.float32``. - - Returns: - :class:`.BlockArray` of ones with the given component shapes - and dtype. - """ - _aval = _AbstractBlockArray(shape_tuple, dtype=dtype) - data_ravel = jnp.ones(_aval.size, dtype=dtype) - return cls(_aval, data_ravel) - - @classmethod - def zeros(cls, shape_tuple: BlockShape, dtype: DType = np.float32) -> BlockArray: - """ - Return a new :class:`.BlockArray` with given block shapes and type, filled with zeros. - - Args: - shape_tuple: Tuple of shapes for component blocks. - dtype: Desired data-type for the :class:`.BlockArray`. - Default is ``numpy.float32``. - - Returns: - :class:`.BlockArray` of zeros with the given component shapes - and dtype. - """ - _aval = _AbstractBlockArray(shape_tuple, dtype=dtype) - data_ravel = jnp.zeros(_aval.size, dtype=dtype) - return cls(_aval, data_ravel) - - @classmethod - def empty(cls, shape_tuple: BlockShape, dtype: DType = np.float32) -> BlockArray: - """ - Return a new :class:`.BlockArray` with given block shapes and type, filled with zeros. - - Note: like :func:`jax.numpy.empty`, this does not return an - uninitalized array. - - Args: - shape_tuple: Tuple of shapes for component blocks - dtype: Desired data-type for the :class:`.BlockArray`. - Default is ``numpy.float32``. - - Returns: - :class:`.BlockArray` of zeros with the given component shapes - and dtype. - """ - _aval = _AbstractBlockArray(shape_tuple, dtype=dtype) - data_ravel = jnp.empty(_aval.size, dtype=dtype) - return cls(_aval, data_ravel) - - @classmethod - def full( - cls, - shape_tuple: BlockShape, - fill_value: Union[float, complex, int], - dtype: DType = np.float32, - ) -> BlockArray: - """ - Return a new :class:`.BlockArray` with given block shapes and type, filled with - `fill_value`. - - Args: - shape_tuple: Tuple of shapes for component blocks. - fill_value: Fill value. - dtype: Desired data-type for the BlockArray. The default, - None, means `np.array(fill_value).dtype`. - - Returns: - :class:`.BlockArray` with the given component shapes and - dtype and all entries equal to `fill_value`. - """ - if dtype is None: - dtype = np.asarray(fill_value).dtype - - _aval = _AbstractBlockArray(shape_tuple, dtype=dtype) - data_ravel = jnp.full(_aval.size, fill_value=fill_value, dtype=dtype) - return cls(_aval, data_ravel) - - def copy(self) -> BlockArray: - """Return a copy of this :class:`.BlockArray`. - - This method is not implemented for BlockArray. - - See Also: - :meth:`.to_numpy`: Convert a :class:`.BlockArray` into a - flattened NumPy array. - """ - # jax DeviceArray copies return a NumPy ndarray. This blockarray class must be backed - # by a DeviceArray, so cannot be converted to a NumPy-backed BlockArray. The BlockArray - # .to_numpy() method returns a flattened ndarray. - # - # This method may be implemented in the future if jax DeviceArray .copy() is modified to - # return another DeviceArray. - raise NotImplementedError - - def to_numpy(self) -> np.ndarray: - """Return a :class:`numpy.ndarray` containing the flattened form of this - :class:`.BlockArray`.""" - - if isinstance(self._data, DeviceArray): - host_arr = jax.device_get(self._data.copy()) - else: - host_arr = self._data.copy() - return host_arr - - def blockidx(self, idx: int) -> jax._src.ops.scatter._Indexable: - """Return :class:`jax.ops.index` for a given component block. - - Args: - idx: Desired block index. - - Returns: - :class:`jax.ops.index` pointing to desired block. - """ - return slice(self.bndpos[idx], self.bndpos[idx + 1]) - - def ravel(self) -> JaxArray: - """Return a copy of `self._data` as a contiguous, flattened `DeviceArray`. - - Note that a copy, rather than a view, of the underlying array is - returned. This is consistent with :func:`jax.numpy.ravel`. - - Returns: - Copy of underlying flattened array. - - """ - return self._data[:] - - def flatten(self) -> JaxArray: - """Return a copy of `self._data` as a contiguous, flattened `DeviceArray`. - - Note that a copy, rather than a view, of the underlying array is - returned. This is consistent with :func:`jax.numpy.ravel`. - - Returns: - Copy of underlying flattened array. - - """ - return self._data[:] - - def sum(self, axis=None, keepdims=False): - """Return the sum of the blockarray elements over the given axis. - - Refer to :func:`scico.numpy.sum` for full documentation. - """ - # Can't just call scico.numpy.sum due to pesky circular import... - return _block_array_reduction_wrapper(jnp.sum)(self, axis=axis, keepdims=keepdims) - - -## Register BlockArray as a Jax type -# Our BlockArray is just a single large vector with some extra sugar -class _ConcreteBlockArray(_AbstractBlockArray): - pass - - -def _block_array_result_handler(device, _aval): - def build_block_array(data_buf): - data = xla.DeviceArray(_aval._data_aval, device, None, data_buf) - return BlockArray(_aval, data) - - return build_block_array - - -def _block_array_shape_handler(a): - return (xla.xc.Shape.array_shape(a._data_aval.dtype, a._data_aval.shape),) - - -def _block_array_device_put_handler(a, device): - return (xla.xb.get_device_backend(device).buffer_from_pyval(a._data, device),) - - -core.pytype_aval_mappings[BlockArray] = lambda x: x._aval -core.raise_to_shaped_mappings[_AbstractBlockArray] = lambda _aval, _: _aval -xla.pytype_aval_mappings[BlockArray] = lambda x: x._aval -xla.canonicalize_dtype_handlers[BlockArray] = lambda x: x -jax._src.dispatch.device_put_handlers[BlockArray] = _block_array_device_put_handler -jax._src.dispatch.result_handlers[_AbstractBlockArray] = _block_array_result_handler -xla.xla_shape_handlers[_AbstractBlockArray] = _block_array_shape_handler - - -## Handlers to use jax.device_put on BlockArray -def _block_array_tree_flatten(block_arr): - """Flatten a :class:`.BlockArray` pytree. - - See :func:`jax.tree_util.tree_flatten`. - - Args: - block_arr (:class:`.BlockArray`): :class:`.BlockArray` to flatten - - Returns: - children (tuple): :class:`.BlockArray` leaves. - aux_data (tuple): Extra metadata used to reconstruct BlockArray. - """ - - data_children, data_aux_data = tree_flatten(block_arr._data) - return (data_children, block_arr._aval) - - -def _block_array_tree_unflatten(aux_data, children): - """Construct a :class:`.BlockArray` from a flattened pytree. - - See jax.tree_utils.tree_unflatten - - Args: - aux_data (tuple): Metadata needed to construct block array. - children (tuple): Contains block array elements. - - Returns: - block_arr: Constructed :class:`.BlockArray`. - """ - return BlockArray(aux_data, children[0]) - - -register_pytree_node(BlockArray, _block_array_tree_flatten, _block_array_tree_unflatten) - -# Syntactic sugar for the .at operations -# see https://github.com/google/jax/blob/56e9f7cb92e3a099adaaca161cc14153f024047c/jax/_src/numpy/lax_numpy.py#L5900 -class _BlockArrayIndexUpdateHelper: - """The helper class for the `at` property to call indexed update functions. - - The `at` property is syntactic sugar for calling the indexed update - functions as is done in jax. The index must be of the form [ibk] or - [ibk,idx], where `ibk` is the index of the block to be updated, and - `idx` is a general index of the elements to be updated in that block. - - In particular: - - `x = x.at[ibk].set(y)` is an equivalent of `x[ibk] = y`. - - `x = x.at[ibk,idx].set(y)` is an equivalent of `x[ibk,idx] = y`. - - The methods `set, add, multiply, divide, power, maximum, minimum` - are supported. - """ - - __slots__ = ("_block_array",) - - def __init__(self, block_array): - self._block_array = block_array - - def __getitem__(self, index): - if isinstance(index, tuple): - if isinstance(index[0], slice): - raise TypeError(f"Slicing not supported along block index") - return _BlockArrayIndexUpdateRef(self._block_array, index) - - def __repr__(self): - print(f"_BlockArrayIndexUpdateHelper({repr(self._block_array)})") - - -class _BlockArrayIndexUpdateRef: - """Helper object to call indexed update functions for an (advanced) index. - - This object references a source block array and a specific indexer, - with the first integer specifying the block being updated, and rest - being the indexer into the array of that block. Methods on this - object return copies of the source block array that have been - modified at the positions specified by the indexer in the given block. - """ - - __slots__ = ("_block_array", "bk_index", "index") - - def __init__(self, block_array, index): - self._block_array = block_array - if isinstance(index, int): - self.bk_index = index - self.index = Ellipsis - elif index == Ellipsis: - self.bk_index = Ellipsis - self.index = Ellipsis - else: - self.bk_index = index[0] - self.index = index[1:] - - def __repr__(self): - return f"_BlockArrayIndexUpdateRef({repr(self._block_array)}, {repr(self.bk_index)}, {repr(self.index)})" - - def _index_wrapper(self, func_str, values): - bk_index = self.bk_index - index = self.index - arr_tuple = self._block_array.split - if bk_index == Ellipsis: - # This may result in multiple copies: one per sub-blockarray, - # then one to combine into a nested BA. - retval = BlockArray.array([getattr(_.at[index], func_str)(values) for _ in arr_tuple]) - else: - retval = BlockArray.array( - arr_tuple[:bk_index] - + (getattr(arr_tuple[bk_index].at[index], func_str)(values),) - + arr_tuple[bk_index + 1 :] - ) - return retval - - def set(self, values): - """Pure equivalent of `x[idx] = y`. - - Return the value of `x` that would result from the NumPy-style - :mod:indexed assignment ` `x[idx] = y`. - - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("set", values) - - def add(self, values): - """Pure equivalent of `x[idx] += y`. - - Return the value of `x` that would result from the NumPy-style - :mod:indexed assignment ` `x[idx] += y`. - - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("add", values) - - def multiply(self, values): - """Pure equivalent of `x[idx] *= y`. - - Return the value of `x` that would result from the NumPy-style - :mod:indexed assignment ` `x[idx] *= y`. - - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("multiply", values) - - def divide(self, values): - """Pure equivalent of `x[idx] /= y`. - - Return the value of `x` that would result from the NumPy-style - :mod:indexed assignment ` `x[idx] /= y`. - - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("divide", values) - - def power(self, values): - """Pure equivalent of `x[idx] **= y`. - - Return the value of `x` that would result from the NumPy-style - :mod:indexed assignment ` `x[idx] **= y`. - - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("power", values) - - def min(self, values): - """Pure equivalent of `x[idx] = minimum(x[idx], y)`. - - Return the value of `x` that would result from the NumPy-style - :mod:indexed assignment ` `x[idx] = minimum(x[idx], y)`. - - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("min", values) - - def max(self, values): - """Pure equivalent of `x[idx] = maximum(x[idx], y)`. - - Return the value of `x` that would result from the NumPy-style - :mod:indexed assignment ` `x[idx] = maximum(x[idx], y)`. - - See :mod:`jax.ops` for details. - """ - return self._index_wrapper("max", values) - - -setattr(BlockArray, "at", property(_BlockArrayIndexUpdateHelper)) diff --git a/scico/denoiser.py b/scico/denoiser.py index b48a7b1d7..c5e0a0f61 100644 --- a/scico/denoiser.py +++ b/scico/denoiser.py @@ -43,7 +43,7 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False): Args: x: Input image. Expected to be a 2D array (gray-scale denoising) - or 3D array (color denoising). Higher dimensional arrays are + or 3D array (color denoising). Higher-dimensional arrays are tolerated only if the additional dimensions are singletons. For color denoising, the color channel is assumed to be in the last non-singleton dimension. @@ -70,7 +70,7 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False): if isinstance(x.ndim, tuple) or x.ndim < 2: raise ValueError( - "BM3D requires two dimensional or three dimensional inputs;" f" got ndim = {x.ndim}" + "BM3D requires two-dimensional or three dimensional inputs;" f" got ndim = {x.ndim}" ) # This check is also performed inside the BM3D call, but due to the host_callback, @@ -108,7 +108,7 @@ def bm4d(x: JaxArray, sigma: float): :cite:`maggioni-2012-nonlocal`. Args: - x: Input image. Expected to be a 3D array. Higher dimensional + x: Input image. Expected to be a 3D array. Higher-dimensional arrays are tolerated only if the additional dimensions are singletons. sigma: Noise parameter. @@ -128,7 +128,7 @@ def bm4d(x: JaxArray, sigma: float): x_in_shape = x.shape if isinstance(x.ndim, tuple) or x.ndim < 3: - raise ValueError(f"BM4D requires three dimensional inputs; got ndim = {x.ndim}") + raise ValueError(f"BM4D requires three-dimensional inputs; got ndim = {x.ndim}") # This check is also performed inside the BM4D call, but due to the host_callback, # no exception is raised and the program will crash with no traceback. @@ -205,7 +205,7 @@ def __call__(self, x: JaxArray) -> JaxArray: if isinstance(x.ndim, tuple) or x.ndim < 2: raise ValueError( - "DnCNN requires two dimensional (M, N) or three dimensional (M, N, C)" + "DnCNN requires two-dimensional (M, N) or three-dimensional (M, N, C)" f" inputs; got ndim = {x.ndim}" ) diff --git a/scico/functional/_dist.py b/scico/functional/_dist.py index 5aaca6b0e..935d4c5c9 100644 --- a/scico/functional/_dist.py +++ b/scico/functional/_dist.py @@ -10,7 +10,7 @@ from typing import Callable, Union from scico import numpy as snp -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.typing import JaxArray from ._functional import Functional diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py index 408946805..df8c10046 100644 --- a/scico/functional/_functional.py +++ b/scico/functional/_functional.py @@ -13,7 +13,7 @@ import scico from scico import numpy as snp -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.typing import JaxArray @@ -252,7 +252,7 @@ def prox(self, v: BlockArray, lam: float = 1.0, **kwargs) -> BlockArray: """ if len(v.shape) == len(self.functional_list): - return BlockArray.array([fi.prox(vi, lam) for fi, vi in zip(self.functional_list, v)]) + return snp.blockarray([fi.prox(vi, lam) for fi, vi in zip(self.functional_list, v)]) raise ValueError( f"Number of blocks in v, {len(v.shape)}, and length of functional_list, " f"{len(self.functional_list)}, do not match" diff --git a/scico/functional/_indicator.py b/scico/functional/_indicator.py index 0fcd98330..21ab0f69b 100644 --- a/scico/functional/_indicator.py +++ b/scico/functional/_indicator.py @@ -12,7 +12,7 @@ import jax from scico import numpy as snp -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.numpy.linalg import norm from scico.typing import JaxArray diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 20dc1eaf7..70d094127 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -12,10 +12,9 @@ from jax import jit, lax from scico import numpy as snp -from scico.array import no_nan_divide -from scico.blockarray import BlockArray -from scico.numpy import count_nonzero +from scico.numpy import BlockArray, count_nonzero from scico.numpy.linalg import norm +from scico.numpy.util import no_nan_divide from scico.typing import JaxArray from ._functional import Functional diff --git a/scico/linop/_circconv.py b/scico/linop/_circconv.py index 9520ec9a9..05bb024aa 100644 --- a/scico/linop/_circconv.py +++ b/scico/linop/_circconv.py @@ -18,7 +18,7 @@ import scico.numpy as snp from scico._generic_operators import Operator -from scico.array import is_nested +from scico.numpy.util import is_nested from scico.typing import DType, JaxArray, Shape from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar diff --git a/scico/linop/_convolve.py b/scico/linop/_convolve.py index 1e675c094..2164b892a 100644 --- a/scico/linop/_convolve.py +++ b/scico/linop/_convolve.py @@ -23,8 +23,8 @@ from jax.scipy.signal import convolve import scico.numpy as snp -from scico import array from scico._generic_operators import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar +from scico.numpy.util import ensure_on_device from scico.typing import DType, JaxArray, Shape @@ -66,7 +66,7 @@ def __init__( if h.ndim != len(input_shape): raise ValueError(f"h.ndim = {h.ndim} must equal len(input_shape) = {len(input_shape)}") - self.h = array.ensure_on_device(h) + self.h = ensure_on_device(h) if mode not in ["full", "valid", "same"]: raise ValueError(f"Invalid mode={mode}; must be one of 'full', 'valid', 'same'") diff --git a/scico/linop/_diff.py b/scico/linop/_diff.py index b82013e2d..9f0e77d52 100644 --- a/scico/linop/_diff.py +++ b/scico/linop/_diff.py @@ -17,7 +17,7 @@ import numpy as np import scico.numpy as snp -from scico.array import parse_axes +from scico.numpy.util import parse_axes from scico.typing import Axes, DType, JaxArray, Shape from ._linop import LinearOperator diff --git a/scico/linop/_linop.py b/scico/linop/_linop.py index dc03b549c..3866b55bf 100644 --- a/scico/linop/_linop.py +++ b/scico/linop/_linop.py @@ -16,9 +16,9 @@ from typing import Any, Callable, Optional, Union import scico.numpy as snp -from scico import array, blockarray from scico._generic_operators import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar -from scico.blockarray import BlockArray +from scico.numpy import BlockArray +from scico.numpy.util import ensure_on_device, indexed_shape, is_nested from scico.random import randn from scico.typing import ArrayIndex, BlockShape, DType, JaxArray, PRNGKey, Shape @@ -181,7 +181,7 @@ def __init__( """ - self.diagonal = array.ensure_on_device(diagonal) + self.diagonal = ensure_on_device(diagonal) if input_shape is None: input_shape = self.diagonal.shape @@ -189,9 +189,9 @@ def __init__( if input_dtype is None: input_dtype = self.diagonal.dtype - if isinstance(diagonal, BlockArray) and array.is_nested(input_shape): + if isinstance(diagonal, BlockArray) and is_nested(input_shape): output_shape = (snp.empty(input_shape) * diagonal).shape - elif not isinstance(diagonal, BlockArray) and not array.is_nested(input_shape): + elif not isinstance(diagonal, BlockArray) and not is_nested(input_shape): output_shape = snp.broadcast_shapes(input_shape, self.diagonal.shape) elif isinstance(diagonal, BlockArray): raise ValueError(f"`diagonal` was a BlockArray but `input_shape` was not nested.") @@ -282,10 +282,10 @@ def __init__( functions of the LinearOperator. """ - if array.is_nested(input_shape): - output_shape = blockarray.indexed_shape(input_shape, idx) + if is_nested(input_shape): + output_shape = input_shape[idx] else: - output_shape = array.indexed_shape(input_shape, idx) + output_shape = indexed_shape(input_shape, idx) self.idx: ArrayIndex = idx super().__init__( diff --git a/scico/linop/_matrix.py b/scico/linop/_matrix.py index fd1bab1a9..46da33b4d 100644 --- a/scico/linop/_matrix.py +++ b/scico/linop/_matrix.py @@ -84,7 +84,7 @@ def __init__(self, A: JaxArray): # Can only do rank-2 arrays if A.ndim != 2: - raise TypeError(f"Expected a 2-dimensional array, got array of shape {A.shape}") + raise TypeError(f"Expected a two-dimensional array, got array of shape {A.shape}") super().__init__(input_shape=A.shape[1], output_shape=A.shape[0], input_dtype=self.A.dtype) diff --git a/scico/linop/_stack.py b/scico/linop/_stack.py index 8b09642ca..2289904ef 100644 --- a/scico/linop/_stack.py +++ b/scico/linop/_stack.py @@ -16,7 +16,7 @@ import numpy as np import scico.numpy as snp -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.typing import JaxArray from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar @@ -85,7 +85,7 @@ def __init__( def _eval(self, x: JaxArray) -> Union[JaxArray, BlockArray]: if self.collapsable and self.collapse: return snp.stack([op @ x for op in self.ops]) - return BlockArray.array([op @ x for op in self.ops]) + return BlockArray([op @ x for op in self.ops]) def _adj(self, y: Union[JaxArray, BlockArray]) -> JaxArray: return sum([op.adj(y_block) for y_block, op in zip(y, self.ops)]) diff --git a/scico/linop/optics.py b/scico/linop/optics.py index 256e93311..4a785d4f5 100644 --- a/scico/linop/optics.py +++ b/scico/linop/optics.py @@ -60,8 +60,8 @@ import jax import scico.numpy as snp -from scico.array import no_nan_divide from scico.linop import Diagonal, Identity, LinearOperator +from scico.numpy.util import no_nan_divide from scico.typing import Shape from ._dft import DFT diff --git a/scico/loss.py b/scico/loss.py index 2c4e98c40..2e76d9566 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -17,8 +17,8 @@ import scico.numpy as snp from scico import functional, linop, operator -from scico.array import ensure_on_device, no_nan_divide -from scico.blockarray import BlockArray +from scico.numpy import BlockArray +from scico.numpy.util import ensure_on_device, no_nan_divide from scico.scipy.special import gammaln from scico.solver import cg from scico.typing import JaxArray @@ -201,7 +201,7 @@ def __init__( self.has_prox = True def __call__(self, x: Union[JaxArray, BlockArray]) -> float: - return self.scale * (self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2).sum() + return self.scale * snp.sum(self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2) def prox( self, v: Union[JaxArray, BlockArray], lam: float, **kwargs diff --git a/scico/metric.py b/scico/metric.py index 94bbe265b..6b1b28665 100644 --- a/scico/metric.py +++ b/scico/metric.py @@ -14,7 +14,7 @@ import numpy as np import scico.numpy as snp -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.typing import JaxArray diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index 01d9e435d..3cd1d1807 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -5,163 +5,49 @@ # user license can be found in the 'LICENSE' file distributed with the # package. -""":class:`scico.blockarray.BlockArray`-compatible -versions of :mod:`jax.numpy` functions. +r""":class:`.BlockArray` and functions for working with them alongside +:class:`DeviceArray`\ s. -This modules consists of functions from :mod:`jax.numpy` wrapped to -support compatibility with :class:`scico.blockarray.BlockArray`. This -module is a work in progress and therefore not all functions are -wrapped. Functions that have not been wrapped yet have WARNING text in -their documentation, below. -""" - -import sys -from functools import wraps +This module consists of :class:`.BlockArray` and functions for working +:class`.BlockArray`\ s alongside class:`DeviceArray`\ s. This includes +all the functions from :mod:`jax.numpy` and :mod:`numpy.testing`, where +many have been extended to automatically map over block array blocks as +described in :mod:`scico.numpy.blockarray`. Also included are additional +functions unique to SCICO in :mod:`.util`. +""" import numpy as np -import jax -from jax import numpy as jnp +import jax.numpy as jnp -from scico.array import is_nested +from . import _wrappers +from ._wrapped_function_lists import * +from .blockarray import BlockArray -# These functions rely on the definition of a BlockArray and must be in -# scico.blockarray to avoid a circular import -from scico.blockarray import ( - BlockArray, - _block_array_matmul_wrapper, - _block_array_reduction_wrapper, - _block_array_ufunc_wrapper, - _flatten_blockarrays, - atleast_1d, - reshape, -) -from scico.typing import BlockShape, JaxArray, Shape +# allow snp.blockarray(...) to create BlockArrays +blockarray = BlockArray.blockarray -from ._create import ( - empty, - empty_like, - full, - full_like, - ones, - ones_like, - zeros, - zeros_like, +# copy most of jnp without wrapping +_wrappers.add_attributes( + to_dict=vars(), + from_dict=jnp.__dict__, + modules_to_recurse=("linalg", "fft"), ) -from ._util import _attach_wrapped_func, _get_module_functions, _not_implemented - -# Numpy constants -pi = np.pi -e = np.e -euler_gamma = np.euler_gamma -inf = np.inf -NINF = np.NINF -PZERO = np.PZERO -NZERO = np.NZERO -nan = np.nan - -bool_ = jnp.bool_ -uint8 = jnp.uint8 -uint16 = jnp.uint16 -uint32 = jnp.uint32 -uint64 = jnp.uint64 -int8 = jnp.int8 -int16 = jnp.int16 -int32 = jnp.int32 -int64 = jnp.int64 -bfloat16 = jnp.bfloat16 -float16 = jnp.float16 -float32 = single = jnp.float32 -float64 = double = jnp.float64 -complex64 = csingle = jnp.complex64 -complex128 = cdouble = jnp.complex128 - -dtype = jnp.dtype -newaxis = None - -# Functions to which _block_array_ufunc_wrapper is to be applied -_ufunc_functions = [ - ("abs", jnp.abs), - jnp.maximum, - jnp.sign, - jnp.where, - jnp.true_divide, - jnp.floor_divide, - jnp.real, - jnp.imag, - jnp.conjugate, - jnp.angle, - jnp.exp, - jnp.sqrt, - jnp.log, - jnp.log10, -] -# Functions to which _block_array_reduction_wrapper is to be applied -_reduction_functions = [ - jnp.count_nonzero, - jnp.sum, - jnp.mean, - jnp.median, - jnp.any, - jnp.var, - ("max", jnp.max), - ("min", jnp.min), - jnp.amin, - jnp.amax, - jnp.all, - jnp.any, -] - -dot = _block_array_matmul_wrapper(jnp.dot) -matmul = _block_array_matmul_wrapper(jnp.matmul) - - -@wraps(jnp.vdot) -def vdot(a, b): - """Dot product of `a` and `b` (with first argument complex conjugated). - Wrapped to work on `BlockArray`s.""" - if isinstance(a, BlockArray): - a = a.ravel() - if isinstance(b, BlockArray): - b = b.ravel() - return jnp.vdot(a, b) - - -vdot.__doc__ = ":func:`vdot` wrapped to operate on :class:`.BlockArray`" + "\n\n" + jnp.vdot.__doc__ - -# Attach wrapped functions to this module -_attach_wrapped_func( - _ufunc_functions, - _block_array_ufunc_wrapper, - module_name=sys.modules[__name__], - fix_mod_name=True, -) -_attach_wrapped_func( - _reduction_functions, - _block_array_reduction_wrapper, - module_name=sys.modules[__name__], - fix_mod_name=True, -) - -# divide is just an alias to true_divide -divide = true_divide -conj = conjugate -# Find functions that exist in jax.numpy but not scico.numpy -# see jax.numpy.__init__.py -_not_implemented_functions = [] -for name, func in _get_module_functions(jnp).items(): - if name not in globals(): - _not_implemented_functions.append((name, func)) +# wrap jnp funcs +_wrappers.wrap_recursively(vars(), creation_routines, _wrappers.map_func_over_tuple_of_tuples) +_wrappers.wrap_recursively(vars(), mathematical_functions, _wrappers.map_func_over_blocks) +_wrappers.wrap_recursively(vars(), reduction_functions, _wrappers.add_full_reduction) -_attach_wrapped_func( - _not_implemented_functions, - _not_implemented, - module_name=sys.modules[__name__], - fix_mod_name=False, +# copy np.testing +_wrappers.add_attributes( + to_dict=vars(), + from_dict={"testing": np.testing}, + modules_to_recurse=("testing",), ) +# wrap testing funcs +_wrappers.wrap_recursively(vars(), testing_functions, _wrappers.map_func_over_blocks) -# these must be imported towards the end to avoid a circular import with -# linalg and _matrixop -from . import fft, linalg +# clean up +del np, jnp, _wrappers diff --git a/scico/numpy/_create.py b/scico/numpy/_create.py deleted file mode 100644 index 8478d71d6..000000000 --- a/scico/numpy/_create.py +++ /dev/null @@ -1,175 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (C) 2021-2022 by SCICO Developers -# All rights reserved. BSD 3-clause License. -# This file is part of the SCICO package. Details of the copyright and -# user license can be found in the 'LICENSE' file distributed with the -# package. - -"""Functions for creating new arrays.""" - -from typing import Union - -import numpy as np - -import jax -from jax import numpy as jnp - -from scico.array import is_nested -from scico.blockarray import BlockArray -from scico.typing import BlockShape, DType, JaxArray, Shape - - -def zeros( - shape: Union[Shape, BlockShape], dtype: DType = np.float32 -) -> Union[JaxArray, BlockArray]: - """Return a new array of given shape and type, filled with zeros. - - If `shape` is a list of tuples, returns a BlockArray of zeros. - - Args: - shape: Shape of the new array. - dtype: Desired data-type of the array. Default is `np.float32`. - """ - if is_nested(shape): - return BlockArray.zeros(shape, dtype=dtype) - return jnp.zeros(shape, dtype=dtype) - - -def ones(shape: Union[Shape, BlockShape], dtype: DType = np.float32) -> Union[JaxArray, BlockArray]: - """Return a new array of given shape and type, filled with ones. - - If `shape` is a list of tuples, returns a BlockArray of ones. - - Args: - shape: Shape of the new array. - dtype: Desired data-type of the array. Default is `np.float32`. - """ - if is_nested(shape): - return BlockArray.ones(shape, dtype=dtype) - return jnp.ones(shape, dtype=dtype) - - -def empty( - shape: Union[Shape, BlockShape], dtype: DType = np.float32 -) -> Union[JaxArray, BlockArray]: - """Return a new array of given shape and type, filled with zeros. - - If `shape` is a list of tuples, returns a BlockArray of zeros. - - Args: - shape: Shape of the new array. - dtype: Desired data-type of the array. Default is `np.float32`. - """ - if is_nested(shape): - return BlockArray.empty(shape, dtype=dtype) - return jnp.empty(shape, dtype=dtype) - - -def full( - shape: Union[Shape, BlockShape], - fill_value: Union[float, complex], - dtype: DType = None, -) -> Union[JaxArray, BlockArray]: - """Return a new array of given shape and type, filled with `fill_value`. - - If `shape` is a list of tuples, returns a BlockArray filled with - `fill_value`. - - Args: - shape: Shape of the new array. - fill_value : Fill value. - dtype: Desired data-type of the array. The default, None, - means `np.array(fill_value).dtype`. - """ - if dtype is None: - dtype = jax.dtypes.canonicalize_dtype(type(fill_value)) - if is_nested(shape): - return BlockArray.full(shape, fill_value=fill_value, dtype=dtype) - return jnp.full(shape, fill_value=fill_value, dtype=dtype) - - -def zeros_like(x: Union[JaxArray, BlockArray], dtype=None): - """Return an array of zeros with same shape and type as a given array. - - If input is a BlockArray, returns a BlockArray of zeros with same - shape and type as a given array. - - Args: - x (array like): The shape and dtype of `x` define these - attributes on the returned array. - dtype (data-type, optional): Overrides the data type of the - result. - """ - if dtype is None: - dtype = jax.dtypes.canonicalize_dtype(x.dtype) - - if isinstance(x, BlockArray): - return BlockArray.zeros(x.shape, dtype=dtype) - return jnp.zeros_like(x, dtype=dtype) - - -def empty_like(x: Union[JaxArray, BlockArray], dtype: DType = None): - """Return an array of zeros with same shape and type as a given array. - - If input is a BlockArray, returns a BlockArray of zeros with same - shape and type as a given array. - - Note: like :func:`jax.numpy.empty_like`, this does not return an - uninitalized array. - - Args: - x (array like): The shape and dtype of `x` define these - attributes on the returned array. - dtype (data-type, optional): Overrides the data type of the - result. - """ - if dtype is None: - dtype = jax.dtypes.canonicalize_dtype(x.dtype) - - if isinstance(x, BlockArray): - return BlockArray.zeros(x.shape, dtype=dtype) - return jnp.zeros_like(x, dtype=dtype) - - -def ones_like(x: Union[JaxArray, BlockArray], dtype: DType = None): - """Return an array of ones with same shape and type as a given array. - - If input is a BlockArray, returns a BlockArray of ones with same - shape and type as a given array. - - Args: - x (array like): The shape and dtype of `x` define these - attributes on the returned array. - dtype (data-type, optional): Overrides the data type of the - result. - """ - if dtype is None: - dtype = jax.dtypes.canonicalize_dtype(x.dtype) - - if isinstance(x, BlockArray): - return BlockArray.ones(x.shape, dtype=dtype) - return jnp.ones_like(x, dtype=dtype) - - -def full_like( - x: Union[JaxArray, BlockArray], fill_value: Union[float, complex], dtype: DType = None -): - """Return an array filled with `fill_value`. - - Return an array of with same shape and type as a given array, filled - with `fill_value`. If input is a BlockArray, returns a BlockArray of - `fill_value` with same shape and type as a given array. - - Args: - x (array like): The shape and dtype of `x` define these - attributes on the returned array. - fill_value (scalar): Fill value. - dtype (data-type, optional): Overrides the data type of the - result. - """ - if dtype is None: - dtype = jax.dtypes.canonicalize_dtype(x.dtype) - - if isinstance(x, BlockArray): - return BlockArray.full(x.shape, fill_value=fill_value, dtype=dtype) - return jnp.full_like(x, fill_value=fill_value, dtype=dtype) diff --git a/scico/numpy/_util.py b/scico/numpy/_util.py deleted file mode 100644 index af38cdb1e..000000000 --- a/scico/numpy/_util.py +++ /dev/null @@ -1,93 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (C) 2020-2022 by SCICO Developers -# All rights reserved. BSD 3-clause License. -# This file is part of the SCICO package. Details of the copyright and -# user license can be found in the 'LICENSE' file distributed with the -# package. - -"""Tools to construct wrapped versions of :mod:`jax.numpy` functions.""" - -import re -import types -from functools import wraps - -import numpy as np - -from jaxlib.xla_extension import CompiledFunction - -# wrapper for not-implemented jax.numpy functions -# stripped down version of jax._src.lax_numpy._not_implemented and jax.utils._wraps - -_NOT_IMPLEMENTED_DESC = """ -**WARNING**: This function is not yet implemented by :mod:`scico.numpy` and -may raise an error when operating on :class:`scico.blockarray.BlockArray`. -""" - - -def _not_implemented(fun): - @wraps(fun) - def wrapped(*args, **kwargs): - return fun(*args, **kwargs) - - if not hasattr(fun, "__doc__") or fun.__doc__ is None: - return wrapped - - # wrapped.__doc__ = fun.__doc__ + "\n\n" + _NOT_IMPLEMENTED_DESC - wrapped.__doc__ = re.sub( - r"^\*Original docstring below\.\*", - _NOT_IMPLEMENTED_DESC + r"\n\n" + "*Original docstring below.*", - wrapped.__doc__, - flags=re.M, - ) - return wrapped - - -def _attach_wrapped_func(funclist, wrapper, module_name, fix_mod_name=False): - # funclist is either a function, or a tuple (name-in-this-module, function) - # wrapper is a function that is applied to each function in funclist, with - # the output being assigned as an attribute of the module `module_name` - for func in funclist: - # Test required because func.__name__ isn't always the name we want - # e.g. jnp.abs.__name__ resolves to 'absolute', not 'abs' - if isinstance(func, tuple): - fname = func[0] - fref = func[1] - else: - fname = func.__name__ - fref = func - # Set wrapped function as an attribute in module_name - setattr(module_name, fname, wrapper(fref)) - # Set __module__ attribute of wrapped function to - # module_name.__name__ (i.e., scico.numpy) so that it does not - # appear to autodoc be an imported function - if fix_mod_name: - getattr(module_name, fname).__module__ = module_name.__name__ - - -def _get_module_functions(module): - """Finds functions in module. - - This function is a slightly modified version of - :func:`jax._src.util.get_module_functions`. Unlike the JAX version, - this version will also return any - :class:`jaxlib.xla_extension.CompiledFunction`s that exist in the - module. - - Args: - module: A Python module. - Returns: - module_fns: A dict of names mapped to functions, builtins or - ufuncs in `module`. - """ - module_fns = {} - for key in dir(module): - # Omitting module level __getattr__, __dir__ which was added in Python 3.7 - # https://www.python.org/dev/peps/pep-0562/ - if key in ("__getattr__", "__dir__"): - continue - attr = getattr(module, key) - if isinstance( - attr, (types.BuiltinFunctionType, types.FunctionType, np.ufunc, CompiledFunction) - ): - module_fns[key] = attr - return module_fns diff --git a/scico/numpy/_wrapped_function_lists.py b/scico/numpy/_wrapped_function_lists.py new file mode 100644 index 000000000..56471c881 --- /dev/null +++ b/scico/numpy/_wrapped_function_lists.py @@ -0,0 +1,274 @@ +""" BlockArray """ +unary_ops = ( # found from dir(DeviceArray) + "__abs__", + "__neg__", + "__pos__", +) + +binary_ops = ( # found from dir(DeviceArray) + "__add__", + "__radd__", + "__sub__", + "__rsub__", + "__mul__", + "__mod__", + "__rmul__", + "__matmul__", + "__rmatmul__", + "__truediv__", + "__rtruediv__", + "__floordiv__", + "__rfloordiv__", + "__pow__", + "__rpow__", + "__gt__", + "__ge__", + "__lt__", + "__le__", + "__eq__", + "__ne__", +) + +""" jax.numpy """ + +creation_routines = ( + "empty", + "ones", + "zeros", + "full", +) + +mathematical_functions = ( + "sin", # https://numpy.org/doc/stable/reference/routines.math.html# + "cos", + "tan", + "arcsin", + "arccos", + "arctan", + "hypot", + "arctan2", + "degrees", + "radians", + "unwrap", + "deg2rad", + "rad2deg", + "sinh", + "cosh", + "tanh", + "arcsinh", + "arccosh", + "arctanh", + "around", + "round_", + "rint", + "fix", + "floor", + "ceil", + "trunc", + "prod", + "sum", + "nanprod", + "nansum", + "cumprod", + "cumsum", + "nancumprod", + "nancumsum", + "diff", + "ediff1d", + "gradient", + "cross", + "trapz", + "exp", + "expm1", + "exp2", + "log", + "log10", + "log2", + "log1p", + "logaddexp", + "logaddexp2", + "i0", + "sinc", + "signbit", + "copysign", + "frexp", + "ldexp", + "nextafter", + "spacing", + "lcm", + "gcd", + "add", + "reciprocal", + "positive", + "negative", + "multiply", + "divide", + "power", + "subtract", + "true_divide", + "floor_divide", + "float_power", + "fmod", + "mod", + "modf", + "remainder", + "divmod", + "angle", + "real", + "imag", + "conj", + "conjugate", + "maximum", + "fmax", + "amax", + "nanmax", + "minimum", + "fmin", + "amin", + "nanmin", + "convolve", + "clip", + "sqrt", + "cbrt", + "square", + "abs", + "absolute", + "fabs", + "sign", + "heaviside", + "nan_to_num", + "real_if_close", + "interp", + "sort", # https://numpy.org/doc/stable/reference/routines.sort.html + "lexsort", + "argsort", + "msort", + "sort_complex", + "partition", + "argpartition", + "argmax", + "nanargmax", + "argmin", + "nanargmin", + "argwhere", + "nonzero", + "flatnonzero", + "where", + "searchsorted", + "extract", + "count_nonzero", + "dot", # https://numpy.org/doc/stable/reference/routines.linalg.html + "linalg.multi_dot", + "vdot", + "inner", + "outer", + "matmul", + "tensordot", + "einsum", + "einsum_path", + "linalg.matrix_power", + "kron", + "linalg.cholesky", + "linalg.qr", + "linalg.svd", + "linalg.eig", + "linalg.eigh", + "linalg.eigvals", + "linalg.eigvalsh", + "linalg.norm", + "linalg.cond", + "linalg.det", + "linalg.matrix_rank", + "linalg.slogdet", + "trace", + "linalg.solve", + "linalg.tensorsolve", + "linalg.lstsq", + "linalg.inv", + "linalg.pinv", + "linalg.tensorinv", + "shape", # https://numpy.org/doc/stable/reference/routines.array-manipulation.html + "reshape", + "ravel", + "moveaxis", + "rollaxis", + "swapaxes", + "transpose", + "atleast_1d", + "atleast_2d", + "atleast_3d", + "expand_dims", + "squeeze", + "asarray", + "asanyarray", + "asmatrix", + "asfarray", + "asfortranarray", + "ascontiguousarray", + "asarray_chkfinite", + "asscalar", + "require", + "stack", + "block", + "vstack", + "hstack", + "dstack", + "column_stack", + "row_stack", + "split", + "array_split", + "dsplit", + "hsplit", + "vsplit", + "tile", + "repeat", + "insert", + "append", + "resize", + "trim_zeros", + "unique", + "flip", + "fliplr", + "flipud", + "reshape", + "roll", + "rot90", + "all", + "any", + "isfinite", + "isinf", + "isnan", + "isnat", + "isneginf", + "isposinf", + "iscomplex", + "iscomplexobj", + "isfortran", + "isreal", + "isrealobj", + "isscalar", + "logical_and", + "logical_or", + "logical_not", + "logical_xor", + "allclose", + "isclose", + "array_equal", + "array_equiv", + "greater", + "greater_equal", + "less", + "less_equal", + "equal", + "not_equal", + "empty_like", # https://numpy.org/doc/stable/reference/routines.array-creation.html + "ones_like", + "zeros_like", + "full_like", +) + +reduction_functions = ("sum", "linalg.norm") + +""" "testing", """ + +testing_functions = ("testing.assert_allclose", "testing.assert_array_equal") diff --git a/scico/numpy/_wrappers.py b/scico/numpy/_wrappers.py new file mode 100644 index 000000000..891683f7b --- /dev/null +++ b/scico/numpy/_wrappers.py @@ -0,0 +1,154 @@ +""" +Utilities for wrapping jnp functions to handle BlockArray inputs. +""" + +import sys +from functools import wraps +from inspect import signature +from types import ModuleType +from typing import Callable, Iterable, Optional + +import jax.numpy as jnp + +from .blockarray import BlockArray + + +def add_attributes( + to_dict: dict, + from_dict: dict, + modules_to_recurse: Optional[Iterable[str]] = None, +): + """Add attributes in `from_dict` to `to_dict`. + + Underscore attributes are ignored. Modules are ignored, except those + listed in `modules_to_recurse`, which are added recursively. All + others are added. + """ + + if modules_to_recurse is None: + modules_to_recurse = () + + for name, obj in from_dict.items(): + if name[0] == "_": + continue + if isinstance(obj, ModuleType) and name in modules_to_recurse: + to_dict[name] = ModuleType(name) + to_dict[name].__package__ = to_dict["__name__"] + to_dict[name].__doc__ = obj.__doc__ + # enable `import scico.numpy.linalg` and `from scico.numpy.linalg import norm` + sys.modules[to_dict["__name__"] + "." + name] = to_dict[name] + add_attributes(to_dict[name].__dict__, obj.__dict__) + else: + to_dict[name] = obj + + +def wrap_recursively( + target_dict: dict, + names: Iterable[str], + wrap: Callable, +): + """Call wrap functions in `target_dict`, correctly handling names like `"linalg.norm"`.""" + + for name in names: + if "." in name: + module, rest = name.split(".", maxsplit=1) + wrap_recursively(target_dict[module].__dict__, [rest], wrap) + else: + target_dict[name] = wrap(target_dict[name]) + + +def map_func_over_tuple_of_tuples(func: Callable, map_arg_name: Optional[str] = "shape"): + """Wrap a function so that it automatically maps over a tuple of tuples + argument, returning a `BlockArray`. + + """ + + @wraps(func) + def mapped(*args, **kwargs): + bound_args = signature(func).bind(*args, **kwargs) + + if map_arg_name not in bound_args.arguments: # no shape arg + return func(*args, **kwargs) # no mapping + + map_arg_val = bound_args.arguments.pop(map_arg_name) + + if not isinstance(map_arg_val, tuple) or not all( + isinstance(x, tuple) for x in map_arg_val + ): # not nested tuple + return func(*args, **kwargs) # no mapping + + # map + return BlockArray( + func(*bound_args.args, **bound_args.kwargs, **{map_arg_name: x}) for x in map_arg_val + ) + + return mapped + + +def map_func_over_blocks(func, is_reduction=False): + """Wrap a function so that it maps over all of its `BlockArray` + arguments. + + is_reduction: function is handled in a special way in order to allow + full reductions of `BlockArray`s. If the axis parameter exists but + is not specified, the function is called on a fully ravelled version + of all `BlockArray` inputs. + """ + sig = signature(func) + + @wraps(func) + def mapped(*args, **kwargs): + bound_args = sig.bind(*args, **kwargs) + + ba_args = {} + for k, v in list(bound_args.arguments.items()): + if isinstance(v, BlockArray): + ba_args[k] = bound_args.arguments.pop(k) + + if not ba_args: # no BlockArray arguments + return func(*args, **kwargs) # no mapping + + num_blocks = len(list(ba_args.values())[0]) + + return BlockArray( + func(*bound_args.args, **bound_args.kwargs, **{k: v[i] for k, v in ba_args.items()}) + for i in range(num_blocks) + ) + + return mapped + + +def add_full_reduction(func: Callable, axis_arg_name: Optional[str] = "axis"): + """Wrap a function so that it can fully reduce a `BlockArray`. If + nothing is passed for the axis argument and the function is called + on a `BlockArray`, it is fully ravelled before the function is + called. + + Should be outside `map_func_over_blocks`. + """ + sig = signature(func) + if axis_arg_name not in sig.parameters: + raise ValueError( + f"Cannot wrap {func} as a reduction because it has no {axis_arg_name} argument" + ) + + @wraps(func) + def wrapped(*args, **kwargs): + bound_args = sig.bind(*args, **kwargs) + + ba_args = {} + for k, v in list(bound_args.arguments.items()): + if isinstance(v, BlockArray): + ba_args[k] = bound_args.arguments.pop(k) + + if "axis" in bound_args.arguments: + return func(*bound_args.args, **bound_args.kwargs, **ba_args) # call func as normal + + if len(ba_args) > 1: + raise ValueError("Cannot perform a full reduction with multiple BlockArray arguments.") + + # fully ravel the ba argument + ba_args = {k: jnp.concatenate(v.ravel()) for k, v in ba_args.items()} + return func(*bound_args.args, **bound_args.kwargs, **ba_args) + + return wrapped diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py new file mode 100644 index 000000000..84e409e46 --- /dev/null +++ b/scico/numpy/blockarray.py @@ -0,0 +1,365 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2020-2022 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SPORCO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +r"""Block array class. + + .. testsetup:: + + >>> import scico + >>> import scico.numpy as snp + >>> from scico.numpy import BlockArray + >>> import numpy as np + >>> import jax.numpy + +The class :class:`.BlockArray` provides a way to combine arrays of +different shapes into a single object for use with other SCICO classes. +A :class:`.BlockArray` consists of a list of `DeviceArray` objects, +which we refer to as blocks. :class:`.BlockArray` s differ from lists in +that, whenever possible, :class:`.BlockArray` properties and methods +(including unary and binary operators like +, -, \*, ...) automatically +map along the blocks, returning another :class:`.BlockArray` or tuple as +appropriate. For example, + + :: + + >>> x = BlockArray(( + ... [[1, 3, 7], + ... [2, 2, 1]], + ... [2, 4, 8] + ... )) + + >>> x.shape # returns tuple + ((2, 3), (3,)) + + >>> x * 2 # returns BlockArray + [DeviceArray([[ 2, 6, 14], + [ 4, 4, 2]], dtype=int32), DeviceArray([ 4, 8, 16], dtype=int32)] + + >>> y = BlockArray(( + ... [[.2], + ... [.3]], + ... [.4] + ... )) + + >>> x + y # returns BlockArray + [DeviceArray([[1.2, 3.2, 7.2], + [2.3, 2.3, 1.3]], dtype=float32), DeviceArray([2.4, 4.4, 8.4], dtype=float32)] + + +NumPy and SciPy Functions +========================= + +:mod:`scico.numpy`, :mod:`scico.numpy.testing`, and +:mod:`scico.scipy.special` provide wrappers around :mod:`jax.numpy`, +:mod:`numpy.testing` and :mod:`jax.scipy.special` where many of the +functions have been extended to work with `BlockArray` s. In particular: + + * When a tuple of tuples is passed as the `shape` + argument to an array creation routine, a `BlockArray` is created. + * When a `BlockArray` is passed to a reduction function, the blocks are + ravelled (i.e., reshaped to be 1D) and concatenated before the reduction + is applied. This behavior may be prevented by passing the `axis` + argument, in which case the function is mapped over the blocks. + * When one or more `BlockArray`s is passed to a mathematical + function that is not a reduction, the function is mapped over + (corresponding) blocks. + +For a list of array creation routines, see + + :: + + >>> scico.numpy.creation_routines # doctest: +ELLIPSIS + ('empty', ...) + +For a list of reduction functions, see + + :: + + >>> scico.numpy.reduction_functions # doctest: +ELLIPSIS + ('sum', ...) + +For lists of the remaining wrapped functions, see + + :: + + >>> scico.numpy.mathematical_functions # doctest: +ELLIPSIS + ('sin', ...) + >>> scico.numpy.testing_functions # doctest: +ELLIPSIS + ('testing.assert_allclose', ...) + >>> import scico.scipy + >>> scico.scipy.special.functions # doctest: +ELLIPSIS + ('betainc', ...) + + +Motivating Example +================== + +Consider a two-dimensional array :math:`\mb{x} \in \mathbb{R}^{n \times m}`. + +We compute the discrete differences of :math:`\mb{x}` in the horizontal +and vertical directions, generating two new arrays: :math:`\mb{x}_h \in +\mathbb{R}^{n \times (m-1)}` and :math:`\mb{x}_v \in \mathbb{R}^{(n-1) +\times m}`. + +As these arrays are of different shapes, we cannot combine them into a +single `ndarray`. Instead, we might vectorize each array and concatenate +the resulting vectors, leading to :math:`\mb{\bar{x}} \in +\mathbb{R}^{n(m-1) + m(n-1)}`, which can be stored as a one-dimensional +`ndarray`. Unfortunately, this makes it hard to access the individual +components :math:`\mb{x}_h` and :math:`\mb{x}_v`. + +Instead, we can form a :class:`.BlockArray`: :math:`\mb{x}_B = +[\mb{x}_h, \mb{x}_v]` + + + :: + + >>> n = 32 + >>> m = 16 + >>> x_h, key = scico.random.randn((n, m-1)) + >>> x_v, _ = scico.random.randn((n-1, m), key=key) + + # Form the blockarray + >>> x_B = snp.blockarray([x_h, x_v]) + + # The blockarray shape is a tuple of tuples + >>> x_B.shape + ((32, 15), (31, 16)) + + # Each block component can be easily accessed + >>> x_B[0].shape + (32, 15) + >>> x_B[1].shape + (31, 16) + + +Constructing a BlockArray +========================= + +The recommended way to construct a :class:`.BlockArray` is by using the +`snp.blockarray` function. + + :: + + >>> import scico.numpy as snp + >>> x0, key = scico.random.randn((32, 32)) + >>> x1, _ = scico.random.randn((16,), key=key) + >>> X = snp.blockarray((x0, x1)) + >>> X.shape + ((32, 32), (16,)) + >>> X.size + (1024, 16) + >>> len(X) + 2 + +While :func:`.snp.blockarray` will accept either `ndarray` or +`DeviceArray` as input, the resulting :class:`.BlockArray` will be backed +by a `DeviceArray` memory buffer. + +**Note**: constructing a :class:`.BlockArray` always involves a copy to +a new `DeviceArray` memory buffer. + +Operating on a BlockArray +========================= + +.. _blockarray_indexing: + +Indexing +-------- + +`BlockArray` indexing works just like indexing a list. + +Multiplication Between BlockArray and :class:`.LinearOperator` +-------------------------------------------------------------- + +The :class:`.Operator` and :class:`.LinearOperator` classes are designed +to work on :class:`.BlockArray`\ s in addition to `DeviceArray`\ s. +For example + + + :: + + >>> x, key = scico.random.randn((3, 4)) + >>> A_1 = scico.linop.Identity(x.shape) + >>> A_1.shape # array -> array + ((3, 4), (3, 4)) + + >>> A_2 = scico.linop.FiniteDifference(x.shape) + >>> A_2.shape # array -> BlockArray + (((2, 4), (3, 3)), (3, 4)) + + >>> diag = snp.blockarray([np.array(1.0), np.array(2.0)]) + >>> A_3 = scico.linop.Diagonal(diag, input_shape=(A_2.output_shape)) + >>> A_3.shape # BlockArray -> BlockArray + (((2, 4), (3, 3)), ((2, 4), (3, 3))) + +""" + +import inspect +from functools import wraps +from typing import Callable + +import jax +import jax.numpy as jnp + +from jaxlib.xla_extension import DeviceArray + +from ._wrapped_function_lists import binary_ops, unary_ops + + +class BlockArray(list): + """BlockArray class""" + + # Ensure we use BlockArray.__radd__, __rmul__, etc for binary + # operations of the form op(np.ndarray, BlockArray) See + # https://docs.scipy.org/doc/numpy-1.10.1/user/c-info.beyond-basics.html#ndarray.__array_priority__ + __array_priority__ = 1 + + def __init__(self, inputs): + # convert inputs to DeviceArrays + arrays = [x if isinstance(x, jnp.ndarray) else jnp.array(x) for x in inputs] + + # check that dtypes match + if not all(a.dtype == arrays[0].dtype for a in arrays): + raise ValueError("Heterogeneous dtypes not supported") + + return super().__init__(arrays) + + @property + def dtype(self): + """Return the dtype of the blocks, which must currently be homogeneous. + + This allows snp.zeros(x.shape, x.dtype) to work without a mechanism + to handle to lists of dtypes. + """ + return self[0].dtype + + def __getitem__(self, key): + """Indexing method equivalent to x[key]. + + This is overridden to make, e.g., x[:2] return a BlockArray + rather than a list. + """ + result = super().__getitem__(key) + if not isinstance(result, jnp.ndarray): + return BlockArray(result) # x[k:k+1] returns a BlockArray + return result # x[k] returns a DeviceArray + + @staticmethod + def blockarray(iterable): + """Construct a :class:`.BlockArray` from a list or tuple of existing array-like.""" + return BlockArray(iterable) + + +# Register BlockArray as a jax pytree, without this, jax autograd won't work. +# taken from what is done with tuples in jax._src.tree_util +jax.tree_util.register_pytree_node( + BlockArray, + lambda xs: (xs, None), # to iter + lambda _, xs: BlockArray(xs), # from iter +) + + +# Wrap unary ops like -x. +def _unary_op_wrapper(op_name): + op = getattr(DeviceArray, op_name) + + @wraps(op) + def op_ba(self): + return BlockArray(op(x) for x in self) + + return op_ba + + +for op_name in unary_ops: + setattr(BlockArray, op_name, _unary_op_wrapper(op_name)) + + +# Wrap binary ops like x + y. """ +def _binary_op_wrapper(op_name): + op = getattr(DeviceArray, op_name) + + @wraps(op) + def op_ba(self, other): + # If other is a BA, we can assume the operation is implemented + # (because BAs must contain DeviceArrays) + if isinstance(other, BlockArray): + return BlockArray(op(x, y) for x, y in zip(self, other)) + + # If not, need to handle possible NotImplemented + # without this, ba + 'hi' -> [NotImplemented, NotImplemented, ...] + result = list(op(x, other) for x in self) + if NotImplemented in result: + return NotImplemented + return BlockArray(result) + + return op_ba + + +for op_name in binary_ops: + setattr(BlockArray, op_name, _binary_op_wrapper(op_name)) + + +# Wrap DeviceArray properties. +def _da_prop_wrapper(prop_name): + prop = getattr(DeviceArray, prop_name) + + @property + @wraps(prop) + def prop_ba(self): + result = tuple(getattr(x, prop_name) for x in self) + + # if da.prop is a DA, return BA + if isinstance(result[0], jnp.ndarray): + return BlockArray(result) + + # otherwise, return tuple + return result + + return prop_ba + + +skip_props = ("at",) +da_props = [ + k + for k, v in dict(inspect.getmembers(DeviceArray)).items() # (name, method) pairs + if isinstance(v, property) and k[0] != "_" and k not in dir(BlockArray) and k not in skip_props +] + +for prop_name in da_props: + setattr(BlockArray, prop_name, _da_prop_wrapper(prop_name)) + +# Wrap DeviceArray methods. +def _da_method_wrapper(method_name): + method = getattr(DeviceArray, method_name) + + @wraps(method) + def method_ba(self, *args, **kwargs): + result = tuple(getattr(x, method_name)(*args, **kwargs) for x in self) + + # if da.method(...) is a DA, return a BA + if isinstance(result[0], jnp.ndarray): + return BlockArray(result) + + # otherwise return a tuple + return result + + return method_ba + + +skip_methods = () +da_methods = [ + k + for k, v in dict(inspect.getmembers(DeviceArray)).items() # (name, method) pairs + if isinstance(v, Callable) + and k[0] != "_" + and k not in dir(BlockArray) + and k not in skip_methods +] + +for method_name in da_methods: + setattr(BlockArray, method_name, _da_method_wrapper(method_name)) diff --git a/scico/numpy/fft.py b/scico/numpy/fft.py deleted file mode 100644 index 0577f6acd..000000000 --- a/scico/numpy/fft.py +++ /dev/null @@ -1,31 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers -# All rights reserved. BSD 3-clause License. -# This file is part of the SCICO package. Details of the copyright and -# user license can be found in the 'LICENSE' file distributed with the -# package. - -"""Construct wrapped versions of :mod:`jax.numpy.fft` functions. - -This modules consists of functions from :mod:`jax.numpy.fft`. Some of -these functions are wrapped to support compatibility with -:class:`scico.blockarray.BlockArray` and are documented here. -The remaining functions are imported directly from :mod:`jax.numpy.fft`. -While they can be imported from the :mod:`scico.numpy.fft` namespace, -they are not documented here; please consult the documentation for the -source module :mod:`jax.numpy.fft`. -""" -import sys - -import jax.numpy.fft - -from ._util import _attach_wrapped_func, _not_implemented - -_not_implemented_functions = [] -for name, func in jax._src.util.get_module_functions(jax.numpy.fft).items(): - if name not in globals(): - _not_implemented_functions.append((name, func)) - -_attach_wrapped_func( - _not_implemented_functions, _not_implemented, module_name=sys.modules[__name__] -) diff --git a/scico/numpy/linalg.py b/scico/numpy/linalg.py deleted file mode 100644 index 71974465e..000000000 --- a/scico/numpy/linalg.py +++ /dev/null @@ -1,108 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers -# All rights reserved. BSD 3-clause License. -# This file is part of the SCICO package. Details of the copyright and -# user license can be found in the 'LICENSE' file distributed with the -# package. - -"""Construct wrapped versions of :mod:`jax.numpy.linalg` functions. - -This modules consists of functions from :mod:`jax.numpy.linalg`. Some of -these functions are wrapped to support compatibility with -:class:`scico.blockarray.BlockArray` and are documented here. The -remaining functions are imported directly from :mod:`jax.numpy.linalg`. -While they can be imported from the :mod:`scico.numpy.linalg` namespace, -they are not documented here; please consult the documentation for the -source module :mod:`jax.numpy.linalg`. -""" - - -import sys -from functools import wraps - -import jax -import jax.numpy.linalg as jla - -from scico.blockarray import _block_array_reduction_wrapper -from scico.linop._matrix import MatrixOperator - -from ._util import _attach_wrapped_func, _not_implemented - - -def _extract_if_matrix(x): - if isinstance(x, MatrixOperator): - return x.A - return x - - -def _matrixop_linalg_wrapper(func): - """Wrap :mod:`jax.numpy.linalg` functions. - - Wrap :mod:`jax.numpy.linalg` functions for joint operation on - `MatrixOperator` and `DeviceArray`.""" - - @wraps(func) - def wrapper(*args, **kwargs): - all_args = args + tuple(kwargs.items()) - if any([isinstance(_, MatrixOperator) for _ in all_args]): - args = [_extract_if_matrix(_) for _ in args] - kwargs = {key: _extract_if_matrix(val) for key, val in kwargs.items()} - return func(*args, **kwargs) - - if hasattr(func, "__doc__"): - wrapper.__doc__ = ( - f":func:`{func.__name__}` wrapped to operate on :class:`.MatrixOperator`" - + "\n\n" - + func.__doc__ - ) - return wrapper - - -# norm is a reduction and gets both block array and matrixop wrapping -norm = _block_array_reduction_wrapper(_matrixop_linalg_wrapper(jla.norm)) - -svd = _matrixop_linalg_wrapper(jla.svd) -cond = _matrixop_linalg_wrapper(jla.cond) -det = _matrixop_linalg_wrapper(jla.det) -eig = _matrixop_linalg_wrapper(jla.eig) -eigh = _matrixop_linalg_wrapper(jla.eigh) -eigvals = _matrixop_linalg_wrapper(jla.eigvals) -eigvalsh = _matrixop_linalg_wrapper(jla.eigvalsh) -inv = _matrixop_linalg_wrapper(jla.inv) -lstsq = _matrixop_linalg_wrapper(jla.lstsq) -matrix_power = _matrixop_linalg_wrapper(jla.matrix_power) -matrix_rank = _matrixop_linalg_wrapper(jla.matrix_rank) -pinv = _matrixop_linalg_wrapper(jla.pinv) -qr = _matrixop_linalg_wrapper(jla.qr) -slogdet = _matrixop_linalg_wrapper(jla.slogdet) -solve = _matrixop_linalg_wrapper(jla.solve) - - -# multidot is somewhat unique -def multi_dot(arrays, *, precision=None): - """Compute the dot product of two or more arrays. - - Compute the dot product of two or more arrays. - Wrapped to work with `MatrixOperator`s. - """ - arrays_ = [_extract_if_matrix(_) for _ in arrays] - return jla.multi_dot(arrays_, precision=precision) - - -multi_dot.__doc__ = ( - f":func:`multi_dot` wrapped to operate on :class:`.MatrixOperator`" - + "\n\n" - + jla.multi_dot.__doc__ -) - - -# Attach unwrapped functions -# jla.tensorinv, jla.tensorsolve use n-dim arrays; not supported by MatrixOperator -_not_implemented_functions = [] -for name, func in jax._src.util.get_module_functions(jla).items(): - if name not in globals(): - _not_implemented_functions.append((name, func)) - -_attach_wrapped_func( - _not_implemented_functions, _not_implemented, module_name=sys.modules[__name__] -) diff --git a/scico/array.py b/scico/numpy/util.py similarity index 85% rename from scico/array.py rename to scico/numpy/util.py index 17d582277..74eacfccd 100644 --- a/scico/array.py +++ b/scico/numpy/util.py @@ -1,16 +1,9 @@ -# -*- coding: utf-8 -*- -# Copyright (C) 2020-2022 by SCICO Developers -# All rights reserved. BSD 3-clause License. -# This file is part of the SCICO package. Details of the copyright and -# user license can be found in the 'LICENSE' file distributed with the -# package. - -"""Utility functions for arrays, array shapes, array indexing, etc.""" - +""" Utility functions for working with BlockArrays and DeviceArrays. """ from __future__ import annotations import warnings +from math import prod from typing import Any, List, Optional, Tuple, Union import numpy as np @@ -19,14 +12,15 @@ from jax.interpreters.pxla import ShardedDeviceArray from jax.interpreters.xla import DeviceArray -import scico.blockarray import scico.numpy as snp -from scico.typing import ArrayIndex, Axes, AxisIndex, DType, JaxArray, Shape +from scico.typing import ArrayIndex, Axes, AxisIndex, BlockShape, DType, JaxArray, Shape + +from .blockarray import BlockArray def ensure_on_device( - *arrays: Union[np.ndarray, JaxArray, scico.blockarray.BlockArray] -) -> Union[JaxArray, scico.blockarray.BlockArray]: + *arrays: Union[np.ndarray, JaxArray, BlockArray] +) -> Union[JaxArray, BlockArray]: """Cast ndarrays to DeviceArrays. Cast ndarrays to DeviceArrays and leaves DeviceArrays, BlockArrays, @@ -58,37 +52,22 @@ def ensure_on_device( stacklevel=2, ) - arrays[i] = jax.device_put(arrays[i]) elif not isinstance( array, - (DeviceArray, scico.blockarray.BlockArray, ShardedDeviceArray), + (DeviceArray, BlockArray, ShardedDeviceArray), ): raise TypeError( "Each item of `arrays` must be ndarray, DeviceArray, BlockArray, or " f"ShardedDeviceArray; Argument {i+1} of {len(arrays)} is {type(arrays[i])}." ) + arrays[i] = jax.device_put(arrays[i]) + if len(arrays) == 1: return arrays[0] return arrays -def no_nan_divide( - x: Union[scico.blockarray.BlockArray, JaxArray], y: Union[scico.blockarray.BlockArray, JaxArray] -) -> Union[scico.blockarray.BlockArray, JaxArray]: - """Return `x/y`, with 0 instead of NaN where `y` is 0. - - Args: - x: Numerator. - y: Denominator. - - Returns: - `x / y` with 0 wherever `y == 0`. - """ - - return snp.where(y != 0, snp.divide(x, snp.where(y != 0, y, 1)), 0) - - def parse_axes( axes: Axes, shape: Optional[Shape] = None, default: Optional[List[int]] = None ) -> List[int]: @@ -192,6 +171,35 @@ def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]: return tuple(filter(lambda x: x is not None, idx_shape)) +def no_nan_divide( + x: Union[BlockArray, JaxArray], y: Union[BlockArray, JaxArray] +) -> Union[BlockArray, JaxArray]: + """Return `x/y`, with 0 instead of NaN where `y` is 0. + + Args: + x: Numerator. + y: Denominator. + + Returns: + `x / y` with 0 wherever `y == 0`. + """ + + return snp.where(y != 0, snp.divide(x, snp.where(y != 0, y, 1)), 0) + + +def shape_to_size(shape: Union[Shape, BlockShape]) -> Axes: + r"""Compute the size corresponding to a (possibly nested) shape. + + Args: + shape: A shape tuple; possibly tuples. + """ + + if is_nested(shape): + return sum(prod(s) for s in shape) + + return prod(shape) + + def is_nested(x: Any) -> bool: """Check if input is a list/tuple containing at least one list/tuple. @@ -199,8 +207,7 @@ def is_nested(x: Any) -> bool: x: Object to be tested. Returns: - ``True`` if `x` is a list/tuple of list/tuples, otherwise - ``False``. + ``True`` if `x` is a list/tuple containing at least one list/tuple, ``False`` otherwise. Example: @@ -212,17 +219,15 @@ def is_nested(x: Any) -> bool: True """ - if isinstance(x, (list, tuple)): - return any([isinstance(_, (list, tuple)) for _ in x]) - return False + return isinstance(x, (list, tuple)) and any([isinstance(_, (list, tuple)) for _ in x]) def is_real_dtype(dtype: DType) -> bool: """Determine whether a dtype is real. Args: - dtype: A numpy or scico.numpy dtype (e.g. ``np.float32``, - ``np.complex64``). + dtype: A numpy or scico.numpy dtype (e.g. np.float32, + snp.complex64). Returns: ``False`` if the dtype is complex, otherwise ``True``. @@ -247,8 +252,8 @@ def real_dtype(dtype: DType) -> DType: """Construct the corresponding real dtype for a given complex dtype. Construct the corresponding real dtype for a given complex dtype, - e.g. the real dtype corresponding to ``np.complex64`` is - ``np.float32``. + e.g. the real dtype corresponding to `np.complex64` is + `np.float32`. Args: dtype: A complex numpy or scico.numpy dtype (e.g. ``np.complex64``, diff --git a/scico/operator/biconvolve.py b/scico/operator/biconvolve.py index f1e54565d..77a0be36c 100644 --- a/scico/operator/biconvolve.py +++ b/scico/operator/biconvolve.py @@ -13,9 +13,9 @@ from jax.scipy.signal import convolve from scico._generic_operators import LinearOperator, Operator -from scico.array import is_nested -from scico.blockarray import BlockArray from scico.linop import Convolve, ConvolveByX +from scico.numpy import BlockArray +from scico.numpy.util import is_nested from scico.typing import BlockShape, DType, JaxArray @@ -26,7 +26,7 @@ class BiConvolve(Operator): blocks of equal ndims, and convolves the first block with the second. If `A` is a BiConvolve operator, then - `A(BlockArray.array([x, h]))` equals `jax.scipy.signal.convolve(x, h)`. + `A(snp.blockarray([x, h]))` equals `jax.scipy.signal.convolve(x, h)`. """ diff --git a/scico/optimize/_ladmm.py b/scico/optimize/_ladmm.py index 77261d91d..2be8fae21 100644 --- a/scico/optimize/_ladmm.py +++ b/scico/optimize/_ladmm.py @@ -14,12 +14,12 @@ from typing import Callable, List, Optional, Union import scico.numpy as snp -from scico.array import ensure_on_device -from scico.blockarray import BlockArray from scico.diagnostics import IterationStats from scico.functional import Functional from scico.linop import LinearOperator +from scico.numpy import BlockArray from scico.numpy.linalg import norm +from scico.numpy.util import ensure_on_device from scico.typing import JaxArray from scico.util import Timer diff --git a/scico/optimize/_primaldual.py b/scico/optimize/_primaldual.py index 23aaa609c..eb0e27299 100644 --- a/scico/optimize/_primaldual.py +++ b/scico/optimize/_primaldual.py @@ -14,12 +14,12 @@ from typing import Callable, Optional, Union import scico.numpy as snp -from scico.array import ensure_on_device -from scico.blockarray import BlockArray from scico.diagnostics import IterationStats from scico.functional import Functional from scico.linop import LinearOperator +from scico.numpy import BlockArray from scico.numpy.linalg import norm +from scico.numpy.util import ensure_on_device from scico.typing import JaxArray from scico.util import Timer diff --git a/scico/optimize/admm.py b/scico/optimize/admm.py index f9d70ae07..369450d49 100644 --- a/scico/optimize/admm.py +++ b/scico/optimize/admm.py @@ -18,13 +18,13 @@ from jax.scipy.sparse.linalg import cg as jax_cg import scico.numpy as snp -from scico.array import ensure_on_device, is_real_dtype -from scico.blockarray import BlockArray from scico.diagnostics import IterationStats from scico.functional import Functional from scico.linop import CircularConvolve, Identity, LinearOperator from scico.loss import SquaredL2Loss +from scico.numpy import BlockArray from scico.numpy.linalg import norm +from scico.numpy.util import ensure_on_device, is_real_dtype from scico.solver import cg as scico_cg from scico.solver import minimize from scico.typing import JaxArray diff --git a/scico/optimize/pgm.py b/scico/optimize/pgm.py index 2fe8fcc4d..318dce827 100644 --- a/scico/optimize/pgm.py +++ b/scico/optimize/pgm.py @@ -16,11 +16,11 @@ import jax import scico.numpy as snp -from scico.array import ensure_on_device -from scico.blockarray import BlockArray from scico.diagnostics import IterationStats from scico.functional import Functional from scico.loss import Loss +from scico.numpy import BlockArray +from scico.numpy.util import ensure_on_device from scico.typing import JaxArray from scico.util import Timer diff --git a/scico/random.py b/scico/random.py index ee45e46d5..700ed82f8 100644 --- a/scico/random.py +++ b/scico/random.py @@ -44,13 +44,12 @@ :: x, key = scico.random.randn( ((1, 1), (2,)), key=key) - print(x) # scico.blockarray.BlockArray: + print(x) # scico.numpy.BlockArray: # DeviceArray([ 1.1378784 , -1.220955 , -0.59153646], dtype=float32) """ -import functools import inspect import sys from typing import Optional, Tuple, Union @@ -59,8 +58,8 @@ import jax -from scico.array import is_nested -from scico.blockarray import BlockArray, block_sizes +from scico.numpy import BlockArray +from scico.numpy._wrappers import map_func_over_tuple_of_tuples from scico.typing import BlockShape, DType, JaxArray, PRNGKey, Shape @@ -122,49 +121,8 @@ def fun_alt(*args, key=None, seed=None, **kwargs): return fun_alt -def _allow_block_shape(fun): - """ - Decorate a jax.random function so that the `shape` argument may be a BlockShape. - """ - - # use inspect to find which argument number is `shape` - shape_ind = list(inspect.signature(fun).parameters.keys()).index("shape") - - @functools.wraps(fun) - def fun_alt(*args, **kwargs): - - # get the shape argument if it was passed - if len(args) > shape_ind: - shape = args[shape_ind] - elif "shape" in kwargs: - shape = kwargs["shape"] - else: # shape was not passed, call fun as normal - return fun(*args, **kwargs) - - # if shape is not nested, call fun as normal - if not is_nested(shape): - return fun(*args, **kwargs) - # shape is nested, so make a BlockArray! - - # call the wrapped fun with an shape=(size,) - subargs = list(args) - subkwargs = kwargs.copy() - size = np.sum(block_sizes(shape)) - - if len(subargs) > shape_ind: - subargs[shape_ind] = (size,) - else: # shape must be a kwarg if not a positional arg - subkwargs["shape"] = (size,) - - result_flat = fun(*subargs, **subkwargs) - - return BlockArray.array_from_flattened(result_flat, shape) - - return fun_alt - - def _wrap(fun): - fun_wrapped = _add_seed(_allow_block_shape(fun)) + fun_wrapped = _add_seed(map_func_over_tuple_of_tuples(fun)) fun_wrapped.__module__ = __name__ # so it appears in docs return fun_wrapped diff --git a/scico/scipy/special.py b/scico/scipy/special.py index b65004cd7..75a46da6d 100644 --- a/scico/scipy/special.py +++ b/scico/scipy/special.py @@ -5,68 +5,52 @@ # user license can be found in the 'LICENSE' file distributed with the # package. -"""Wrapped versions of :mod:`jax.scipy.special` functions. +""":class:`scico.numpy.BlockArray`-compatible :mod:`jax.scipy.special` +functions. -This modules consists of functions from :mod:`jax.scipy.special`. Some of -these functions are wrapped to support compatibility with -:class:`scico.blockarray.BlockArray` and are documented here. The -remaining functions are imported directly from :mod:`jax.numpy`. While -they can be imported from the :mod:`scico.numpy` namespace, they are not -documented here; please consult the documentation for the source module -:mod:`jax.scipy.special`. +This modules is a wrapper for :mod:`jax.scipy.special` where some +functions have been extended to automatically map over block array +blocks as described in :class:`scico.numpy.BlockArray` """ - - -import sys - -import jax import jax.scipy.special as js -from scico.blockarray import _block_array_ufunc_wrapper -from scico.numpy._util import _attach_wrapped_func, _not_implemented - -_ufunc_functions = [ - js.betainc, - js.entr, - js.erf, - js.erfc, - js.erfinv, - js.expit, - js.gammainc, - js.gammaincc, - js.gammaln, - js.i0, - js.i0e, - js.i1, - js.i1e, - js.log_ndtr, - js.logit, - js.logsumexp, - js.multigammaln, - js.ndtr, - js.ndtri, - js.polygamma, - js.sph_harm, - js.xlog1py, - js.xlogy, - js.zeta, -] +from scico.numpy import _wrappers -_attach_wrapped_func( - _ufunc_functions, - _block_array_ufunc_wrapper, - module_name=sys.modules[__name__], - fix_mod_name=True, +# add most everything in jax.scipy.special to this module +_wrappers.add_attributes( + vars(), + js.__dict__, ) -psi = _block_array_ufunc_wrapper(js.digamma) -digamma = _block_array_ufunc_wrapper(js.digamma) - -_not_implemented_functions = [] -for name, func in jax._src.util.get_module_functions(js).items(): - if name not in globals(): - _not_implemented_functions.append((name, func)) - -_attach_wrapped_func( - _not_implemented_functions, _not_implemented, module_name=sys.modules[__name__] +# wrap select functions +functions = ( + "betainc", + "entr", + "erf", + "erfc", + "erfinv", + "expit", + "gammainc", + "gammaincc", + "gammaln", + "i0", + "i0e", + "i1", + "i1e", + "log_ndtr", + "logit", + "logsumexp", + "multigammaln", + "ndtr", + "ndtri", + "polygamma", + "sph_harm", + "xlog1py", + "xlogy", + "zeta", + "digamma", ) +_wrappers.wrap_recursively(vars(), functions, _wrappers.map_func_over_blocks) + +# clean up +del js, _wrappers diff --git a/scico/solver.py b/scico/solver.py index 5c261764d..177d1434b 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -65,7 +65,7 @@ import jax import scico.numpy as snp -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.typing import BlockShape, DType, JaxArray, Shape from scipy import optimize as spopt @@ -145,7 +145,7 @@ def _split_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArr BlockArray. """ if isinstance(x, BlockArray): - return BlockArray.array([_split_real_imag(_) for _ in x]) + return snp.blockarray([_split_real_imag(_) for _ in x]) return snp.stack((snp.real(x), snp.imag(x))) @@ -163,7 +163,7 @@ def _join_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArra and `x[1]` respectively. """ if isinstance(x, BlockArray): - return BlockArray.array([_join_real_imag(_) for _ in x]) + return snp.blockarray([_join_real_imag(_) for _ in x]) return x[0] + 1j * x[1] diff --git a/scico/test/functional/test_core.py b/scico/test/functional/test_core.py index 96e3bf7f9..0c8835216 100644 --- a/scico/test/functional/test_core.py +++ b/scico/test/functional/test_core.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np from jax.config import config @@ -33,6 +35,60 @@ def test_prox_obj(request): return ProxTestObj(request.param) +class SeparableTestObject: + def __init__(self, dtype): + self.f = functional.L1Norm() + self.g = functional.SquaredL2Norm() + self.fg = functional.SeparableFunctional([self.f, self.g]) + + n = 4 + m = 6 + key = None + + self.v1, key = randn((n,), key=key, dtype=dtype) # point for prox eval + self.v2, key = randn((m,), key=key, dtype=dtype) # point for prox eval + self.vb = snp.blockarray([self.v1, self.v2]) + + +@pytest.fixture(params=[np.float32, np.complex64, np.float64, np.complex128]) +def test_separable_obj(request): + return SeparableTestObject(request.param) + + +def test_separable_eval(test_separable_obj): + fv1 = test_separable_obj.f(test_separable_obj.v1) + gv2 = test_separable_obj.g(test_separable_obj.v2) + fgv = test_separable_obj.fg(test_separable_obj.vb) + np.testing.assert_allclose(fv1 + gv2, fgv, rtol=5e-2) + + +def test_separable_prox(test_separable_obj): + alpha = 0.1 + fv1 = test_separable_obj.f.prox(test_separable_obj.v1, alpha) + gv2 = test_separable_obj.g.prox(test_separable_obj.v2, alpha) + fgv = test_separable_obj.fg.prox(test_separable_obj.vb, alpha) + out = snp.blockarray((fv1, gv2)) + snp.testing.assert_allclose(out, fgv, rtol=5e-2) + + +def test_separable_grad(test_separable_obj): + # Used to restore the warnings after the context is used + with warnings.catch_warnings(): + # Ignores warning raised by ensure_on_device + warnings.filterwarnings(action="ignore", category=UserWarning) + + # Verifies that there is a warning on f.grad and fg.grad + np.testing.assert_warns(test_separable_obj.f.grad(test_separable_obj.v1)) + np.testing.assert_warns(test_separable_obj.fg.grad(test_separable_obj.vb)) + + # Tests the separable grad with warnings being supressed + fv1 = test_separable_obj.f.grad(test_separable_obj.v1) + gv2 = test_separable_obj.g.grad(test_separable_obj.v2) + fgv = test_separable_obj.fg.grad(test_separable_obj.vb) + out = snp.blockarray((fv1, gv2)) + snp.testing.assert_allclose(out, fgv, rtol=5e-2) + + class TestNormProx: alphalist = [1e-2, 1e-1, 1e0, 1e1] @@ -73,9 +129,10 @@ def test_prox_blockarray(self, norm, alpha, test_prox_obj): nrmobj = norm() nrm = nrmobj.__call__ prx = nrmobj.prox - pf = nrmobj.prox(test_prox_obj.vb.ravel(), alpha) + pf = nrmobj.prox(snp.concatenate(snp.ravel(test_prox_obj.vb)), alpha) pf_b = nrmobj.prox(test_prox_obj.vb, alpha) - np.testing.assert_allclose(pf, pf_b.ravel()) + + snp.testing.assert_allclose(pf, snp.concatenate(snp.ravel(pf_b)), rtol=1e-6) @pytest.mark.parametrize("norm", normlist) def test_prox_zeros(self, norm, test_prox_obj): @@ -147,7 +204,7 @@ def test_eval(self, cls, test_prox_obj): x = func(test_prox_obj.vb) y = func(test_prox_obj.vb.ravel()) - np.testing.assert_allclose(x, y) + np.testing.assert_allclose(x, y, rtol=1e-6) # only check double precision on projections diff --git a/scico/test/functional/test_loss.py b/scico/test/functional/test_loss.py index c606fb285..420872aa3 100644 --- a/scico/test/functional/test_loss.py +++ b/scico/test/functional/test_loss.py @@ -7,12 +7,11 @@ # enable 64-bit mode for output dtype checks config.update("jax_enable_x64", True) - from prox import prox_test import scico.numpy as snp from scico import functional, linop, loss -from scico.array import complex_dtype +from scico.numpy.util import complex_dtype from scico.random import randn, uniform diff --git a/scico/test/functional/test_separable.py b/scico/test/functional/test_separable.py index 8160f8ee2..0d4473ce0 100644 --- a/scico/test/functional/test_separable.py +++ b/scico/test/functional/test_separable.py @@ -10,7 +10,8 @@ import pytest from scico import functional -from scico.blockarray import BlockArray +from scico.numpy import blockarray +from scico.numpy.testing import assert_allclose from scico.random import randn @@ -26,7 +27,7 @@ def __init__(self, dtype): self.v1, key = randn((n,), key=key, dtype=dtype) # point for prox eval self.v2, key = randn((m,), key=key, dtype=dtype) # point for prox eval - self.vb = BlockArray.array([self.v1, self.v2]) + self.vb = blockarray([self.v1, self.v2]) @pytest.fixture(params=[np.float32, np.complex64, np.float64, np.complex128]) @@ -38,7 +39,7 @@ def test_separable_eval(test_separable_obj): fv1 = test_separable_obj.f(test_separable_obj.v1) gv2 = test_separable_obj.g(test_separable_obj.v2) fgv = test_separable_obj.fg(test_separable_obj.vb) - np.testing.assert_allclose(fv1 + gv2, fgv, rtol=5e-2) + assert_allclose(fv1 + gv2, fgv, rtol=5e-2) def test_separable_prox(test_separable_obj): @@ -46,8 +47,8 @@ def test_separable_prox(test_separable_obj): fv1 = test_separable_obj.f.prox(test_separable_obj.v1, alpha) gv2 = test_separable_obj.g.prox(test_separable_obj.v2, alpha) fgv = test_separable_obj.fg.prox(test_separable_obj.vb, alpha) - out = BlockArray.array((fv1, gv2)).ravel() - np.testing.assert_allclose(out, fgv.ravel(), rtol=5e-2) + out = blockarray((fv1, gv2)).ravel() + assert_allclose(out, fgv.ravel(), rtol=5e-2) def test_separable_grad(test_separable_obj): @@ -64,5 +65,5 @@ def test_separable_grad(test_separable_obj): fv1 = test_separable_obj.f.grad(test_separable_obj.v1) gv2 = test_separable_obj.g.grad(test_separable_obj.v2) fgv = test_separable_obj.fg.grad(test_separable_obj.vb) - out = BlockArray.array((fv1, gv2)).ravel() - np.testing.assert_allclose(out, fgv.ravel(), rtol=5e-2) + out = blockarray((fv1, gv2)).ravel() + assert_allclose(out, fgv.ravel(), rtol=5e-2) diff --git a/scico/test/linop/test_abel.py b/scico/test/linop/test_abel.py index 8c63019f7..ca5323af3 100644 --- a/scico/test/linop/test_abel.py +++ b/scico/test/linop/test_abel.py @@ -27,6 +27,7 @@ def test_inverse(Nx, Ny): Ax = A @ im im_hat = A.inverse(Ax) + np.testing.assert_allclose(im_hat, im, rtol=5e-5) diff --git a/scico/test/linop/test_diff.py b/scico/test/linop/test_diff.py index 6af05612b..94ac3d605 100644 --- a/scico/test/linop/test_diff.py +++ b/scico/test/linop/test_diff.py @@ -3,62 +3,39 @@ import pytest import scico.numpy as snp -from scico.blockarray import BlockArray from scico.linop import FiniteDifference from scico.random import randn from scico.test.linop.test_linop import adjoint_test +def test_eval(): + with pytest.raises(ValueError): # axis 3 does not exist + A = FiniteDifference(input_shape=(3, 4, 5), axes=(0, 3)) + + A = FiniteDifference(input_shape=(2, 3), append=0.0) + + x = snp.array([[1, 0, 1], [1, 1, 0]], dtype=snp.float32) + + Ax = A @ x + + snp.testing.assert_allclose( + Ax[0], # down columns x[1] - x[0], ..., append - x[N-1] + snp.array([[0, 1, -1], [-1, -1, 0]]), + ) + snp.testing.assert_allclose(Ax[1], snp.array([[-1, 1, -1], [0, -1, 0]])) # along rows + + @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("input_shape", [(16,), (16, 24)]) @pytest.mark.parametrize("axes", [0, 1, (0,), (1,), None]) @pytest.mark.parametrize("jit", [False, True]) -@pytest.mark.parametrize("append", [None, 0.0]) -def test_eval(input_shape, input_dtype, axes, jit, append): - +def test_adjoint(input_shape, input_dtype, axes, jit): ndim = len(input_shape) - x, _ = randn(input_shape, dtype=input_dtype) - if axes in [1, (1,)] and ndim == 1: - with pytest.raises(ValueError): - A = FiniteDifference( - input_shape=input_shape, input_dtype=input_dtype, axes=axes, append=append - ) - else: - A = FiniteDifference( - input_shape=input_shape, input_dtype=input_dtype, axes=axes, jit=jit, append=append - ) - Ax = A @ x - assert A.input_dtype == input_dtype - - # construct expected output - if axes is None: - if ndim == 1: - y = snp.diff(x, append=append) - else: - y = BlockArray.array( - [snp.diff(x, axis=0, append=append), snp.diff(x, axis=1, append=append)] - ) - elif np.isscalar(axes): - y = snp.diff(x, axis=axes, append=append) - elif len(axes) == 1: - y = snp.diff(x, axis=axes[0], append=append) - - np.testing.assert_allclose(Ax.ravel(), y.ravel(), rtol=1e-4) - - @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) - @pytest.mark.parametrize("input_shape", [(16,), (16, 24)]) - @pytest.mark.parametrize("axes", [0, 1, (0,), (1,), None]) - @pytest.mark.parametrize("jit", [False, True]) - def test_adjoint(self, input_shape, input_dtype, axes, jit): - ndim = len(input_shape) - if axes in [1, (1,)] and ndim == 1: - pass - else: - A = FiniteDifference( - input_shape=input_shape, input_dtype=input_dtype, axes=axes, jit=jit - ) - adjoint_test(A) + return + + A = FiniteDifference(input_shape=input_shape, input_dtype=input_dtype, axes=axes, jit=jit) + adjoint_test(A) @pytest.mark.parametrize( diff --git a/scico/test/linop/test_linop.py b/scico/test/linop/test_linop.py index f556bcd71..2f524028a 100644 --- a/scico/test/linop/test_linop.py +++ b/scico/test/linop/test_linop.py @@ -14,7 +14,7 @@ import scico.numpy as snp from scico import linop -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.random import randn from scico.typing import JaxArray, PRNGKey @@ -328,7 +328,7 @@ def test_eval(self, input_shape, diagonal_dtype): D = linop.Diagonal(diagonal=diagonal) assert (D @ x).shape == D.output_shape - np.testing.assert_allclose((diagonal * x).ravel(), (D @ x).ravel(), rtol=1e-5) + snp.testing.assert_allclose((diagonal * x), (D @ x), rtol=1e-5) @pytest.mark.parametrize("diagonal_dtype", [np.float32, np.complex64]) def test_eval_broadcasting(self, diagonal_dtype): @@ -341,10 +341,10 @@ def test_eval_broadcasting(self, diagonal_dtype): # blockarray broadcast diagonal, key = randn(((3, 1, 4), (5, 5)), dtype=diagonal_dtype, key=self.key) - x, key = randn(((5, 1), 1), dtype=diagonal_dtype, key=key) + x, key = randn(((5, 1), (1,)), dtype=diagonal_dtype, key=key) D = linop.Diagonal(diagonal, x.shape) assert (D @ x).shape == ((3, 5, 4), (5, 5)) - np.testing.assert_allclose((diagonal * x).ravel(), (D @ x).ravel(), rtol=1e-5) + snp.testing.assert_allclose((diagonal * x), (D @ x), rtol=1e-5) # blockarray x array -> error diagonal, key = randn(((3, 1, 4), (5, 5)), dtype=diagonal_dtype, key=self.key) @@ -354,7 +354,7 @@ def test_eval_broadcasting(self, diagonal_dtype): # array x blockarray -> error diagonal, key = randn((3, 1, 4), dtype=diagonal_dtype, key=self.key) - x, key = randn(((5, 1), 1), dtype=diagonal_dtype, key=key) + x, key = randn(((5, 1), (1,)), dtype=diagonal_dtype, key=key) with pytest.raises(ValueError): D = linop.Diagonal(diagonal, x.shape) @@ -386,7 +386,7 @@ def test_binary_op(self, input_shape1, input_shape2, diagonal_dtype, operator): a = operator(D1, D2) @ x Dnew = linop.Diagonal(operator(diagonal1, diagonal2)) b = Dnew @ x - np.testing.assert_allclose(a.ravel(), b.ravel(), rtol=1e-5) + snp.testing.assert_allclose(a, b, rtol=1e-5) @pytest.mark.parametrize("operator", [op.add, op.sub]) def test_binary_op_mismatch(self, operator): @@ -545,15 +545,14 @@ def test_slice_adj(slicetestobj, idx): block_slice_examples = [ 1, - np.s_[1, :-3], - np.s_[1, :, :3], - np.s_[1, ..., 2:], + np.s_[0:1], + np.s_[:1], ] @pytest.mark.parametrize("idx", block_slice_examples) def test_slice_blockarray(idx): - x = BlockArray.array((snp.zeros((3, 4)), snp.ones((3, 4, 5, 6)))) + x = BlockArray((snp.zeros((3, 4)), snp.ones((3, 4, 5, 6)))) A = linop.Slice(idx=idx, input_shape=x.shape, input_dtype=x.dtype) assert (A @ x).shape == x[idx].shape diff --git a/scico/test/linop/test_radon_svmbir.py b/scico/test/linop/test_radon_svmbir.py index 17f67bfb5..624a008b2 100644 --- a/scico/test/linop/test_radon_svmbir.py +++ b/scico/test/linop/test_radon_svmbir.py @@ -274,7 +274,9 @@ def test_prox_cg( ) y = A @ im - A_colsum = A.H @ snp.ones(y.shape) # backproject ones to get sum over cols of A + A_colsum = A.H @ snp.ones( + y.shape, dtype=snp.float32 + ) # backproject ones to get sum over cols of A if is_masked: mask = np.asarray(A_colsum) > 0 # cols of A which are not all zeros else: diff --git a/scico/test/optimize/test_ladmm.py b/scico/test/optimize/test_ladmm.py index b0f05328c..5b7071b6c 100644 --- a/scico/test/optimize/test_ladmm.py +++ b/scico/test/optimize/test_ladmm.py @@ -4,7 +4,7 @@ import scico.numpy as snp from scico import functional, linop, loss, random -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.optimize import LinearizedADMM @@ -69,7 +69,7 @@ def callback(obj): class TestBlockArray: def setup_method(self, method): np.random.seed(12345) - self.y = BlockArray.array( + self.y = snp.blockarray( ( np.random.randn(32, 33).astype(np.float32), np.random.randn( diff --git a/scico/test/optimize/test_pdhg.py b/scico/test/optimize/test_pdhg.py index 4083ad5d2..45746e3de 100644 --- a/scico/test/optimize/test_pdhg.py +++ b/scico/test/optimize/test_pdhg.py @@ -4,7 +4,7 @@ import scico.numpy as snp from scico import functional, linop, loss, random -from scico.blockarray import BlockArray +from scico.numpy import BlockArray from scico.optimize import PDHG @@ -69,7 +69,7 @@ def callback(obj): class TestBlockArray: def setup_method(self, method): np.random.seed(12345) - self.y = BlockArray.array( + self.y = snp.blockarray( ( np.random.randn(32, 33).astype(np.float32), np.random.randn( diff --git a/scico/test/test_array.py b/scico/test/test_array.py index b838ce0c0..e37497f8c 100644 --- a/scico/test/test_array.py +++ b/scico/test/test_array.py @@ -7,7 +7,8 @@ import pytest import scico.numpy as snp -from scico.array import ( +from scico.numpy import BlockArray +from scico.numpy.util import ( complex_dtype, ensure_on_device, indexed_shape, @@ -19,7 +20,6 @@ real_dtype, slice_length, ) -from scico.blockarray import BlockArray from scico.random import randn @@ -31,7 +31,7 @@ def test_ensure_on_device(): NP = np.ones(2) SNP = snp.ones(2) - BA = BlockArray.array([NP, SNP]) + BA = snp.blockarray([NP, SNP]) NP_, SNP_, BA_ = ensure_on_device(NP, SNP, BA) assert isinstance(NP_, DeviceArray) @@ -40,7 +40,9 @@ def test_ensure_on_device(): assert SNP.unsafe_buffer_pointer() == SNP_.unsafe_buffer_pointer() assert isinstance(BA_, BlockArray) - assert BA._data.unsafe_buffer_pointer() == BA_._data.unsafe_buffer_pointer() + assert isinstance(BA_[0], DeviceArray) + assert isinstance(BA_[1], DeviceArray) + assert BA[1].unsafe_buffer_pointer() == BA_[1].unsafe_buffer_pointer() np.testing.assert_raises(TypeError, ensure_on_device, [1, 1, 1]) @@ -64,7 +66,7 @@ def test_no_nan_divide_blockarray(): x, key = randn(((3, 3), (4,)), dtype=np.float32) y, key = randn(x.shape, dtype=np.float32, key=key) - y = y.at[1].set(0 * y[1]) + y[1] = y[1].at[:].set(0 * y[1]) res = no_nan_divide(x, y) diff --git a/scico/test/test_biconvolve.py b/scico/test/test_biconvolve.py index 328e10f56..d35761293 100644 --- a/scico/test/test_biconvolve.py +++ b/scico/test/test_biconvolve.py @@ -5,8 +5,8 @@ import pytest -from scico.blockarray import BlockArray from scico.linop import Convolve, ConvolveByX +from scico.numpy import blockarray from scico.operator.biconvolve import BiConvolve from scico.random import randn @@ -22,7 +22,7 @@ def test_eval(self, input_dtype, mode, jit): x, key = randn((32, 32), dtype=input_dtype, key=self.key) h, key = randn((4, 4), dtype=input_dtype, key=self.key) - x_h = BlockArray.array([x, h]) + x_h = blockarray([x, h]) A = BiConvolve(input_shape=x_h.shape, mode=mode, jit=jit) signal_out = signal.convolve(x, h, mode=mode) diff --git a/scico/test/test_blockarray.py b/scico/test/test_blockarray.py index d5aa30858..4b85476a4 100644 --- a/scico/test/test_blockarray.py +++ b/scico/test/test_blockarray.py @@ -4,13 +4,12 @@ import numpy as np import jax -import jax.numpy as jnp -from jax.interpreters.xla import DeviceArray import pytest -import scico.blockarray as ba import scico.numpy as snp +from scico.numpy import BlockArray +from scico.numpy.testing import assert_array_equal from scico.random import randn math_ops = [op.add, op.sub, op.mul, op.truediv, op.pow] # op.floordiv doesn't work on complex @@ -27,25 +26,25 @@ def __init__(self, dtype): self.a0, key = randn(shape=(2, 3), dtype=dtype, key=key) self.a1, key = randn(shape=(2, 3, 4), dtype=dtype, key=key) - self.a = ba.BlockArray.array((self.a0, self.a1), dtype=dtype) + self.a = BlockArray((self.a0, self.a1)) self.b0, key = randn(shape=(2, 3), dtype=dtype, key=key) self.b1, key = randn(shape=(2, 3, 4), dtype=dtype, key=key) - self.b = ba.BlockArray.array((self.b0, self.b1), dtype=dtype) + self.b = BlockArray((self.b0, self.b1)) self.d0, key = randn(shape=(3, 2), dtype=dtype, key=key) self.d1, key = randn(shape=(2, 4, 3), dtype=dtype, key=key) - self.d = ba.BlockArray.array((self.d0, self.d1), dtype=dtype) + self.d = BlockArray((self.d0, self.d1)) c0, key = randn(shape=(2, 3), dtype=dtype, key=key) - self.c = ba.BlockArray.array((c0,), dtype=dtype) + self.c = BlockArray((c0,)) # A flat device array with same size as self.a & self.b - self.flat_da, key = randn(shape=(self.a.size,), dtype=dtype, key=key) + self.flat_da, key = randn(shape=self.a.size, dtype=dtype, key=key) self.flat_nd = np.array(self.flat_da) # A device array with length == self.a.num_blocks - self.block_da, key = randn(shape=(self.a.num_blocks,), dtype=dtype, key=key) + self.block_da, key = randn(shape=(len(self.a),), dtype=dtype, key=key) # block_da but as a numpy array self.block_nd = np.array(self.block_da) @@ -63,79 +62,18 @@ def test_operator_obj(request): def test_operator_left(test_operator_obj, operator): scalar = test_operator_obj.scalar a = test_operator_obj.a - x = operator(scalar, a).ravel() - y = operator(scalar, a.ravel()) - np.testing.assert_allclose(x, y, rtol=1e-6) + x = operator(scalar, a) + y = BlockArray(operator(scalar, a_i) for a_i in a) + snp.testing.assert_allclose(x, y) @pytest.mark.parametrize("operator", math_ops + comp_ops) def test_operator_right(test_operator_obj, operator): scalar = test_operator_obj.scalar a = test_operator_obj.a - x = operator(a, scalar).ravel() - y = operator(a.ravel(), scalar) - np.testing.assert_allclose(x, y) - - -# Operations between a blockarray and a flat DeviceArray -@pytest.mark.parametrize("operator", math_ops + comp_ops) -def test_ba_da_left(test_operator_obj, operator): - flat_da = test_operator_obj.flat_da - a = test_operator_obj.a - x = operator(flat_da, a).ravel() - y = operator(flat_da, a.ravel()) - np.testing.assert_allclose(x, y, rtol=5e-5) - - -@pytest.mark.parametrize("operator", math_ops + comp_ops) -def test_ba_da_right(test_operator_obj, operator): - flat_da = test_operator_obj.flat_da - a = test_operator_obj.a - x = operator(a, flat_da).ravel() - y = operator(a.ravel(), flat_da) - np.testing.assert_allclose(x, y) - - -# Blockwise comparison between a BlockArray and Ndarray -@pytest.mark.parametrize("operator", math_ops + comp_ops) -def test_ndarray_left(test_operator_obj, operator): - a = test_operator_obj.a - block_nd = test_operator_obj.block_nd - - x = operator(a, block_nd).ravel() - y = ba.BlockArray.array([operator(a[i], block_nd[i]) for i in range(a.num_blocks)]).ravel() - np.testing.assert_allclose(x, y) - - -@pytest.mark.parametrize("operator", math_ops + comp_ops) -def test_ndarray_right(test_operator_obj, operator): - a = test_operator_obj.a - block_nd = test_operator_obj.block_nd - - x = operator(block_nd, a).ravel() - y = ba.BlockArray.array([operator(block_nd[i], a[i]) for i in range(a.num_blocks)]).ravel() - np.testing.assert_allclose(x, y, rtol=1e-6) - - -# Blockwise comparison between a BlockArray and DeviceArray -@pytest.mark.parametrize("operator", math_ops + comp_ops) -def test_devicearray_left(test_operator_obj, operator): - a = test_operator_obj.a - block_da = test_operator_obj.block_da - - x = operator(a, block_da).ravel() - y = ba.BlockArray.array([operator(a[i], block_da[i]) for i in range(a.num_blocks)]).ravel() - np.testing.assert_allclose(x, y) - - -@pytest.mark.parametrize("operator", math_ops + comp_ops) -def test_devicearray_right(test_operator_obj, operator): - a = test_operator_obj.a - block_da = test_operator_obj.block_da - - x = operator(block_da, a).ravel() - y = ba.BlockArray.array([operator(block_da[i], a[i]) for i in range(a.num_blocks)]).ravel() - np.testing.assert_allclose(x, y, atol=1e-7, rtol=0) + x = operator(a, scalar) + y = BlockArray(operator(a_i, scalar) for a_i in a) + snp.testing.assert_allclose(x, y) # Operations between two blockarrays of same size @@ -143,9 +81,9 @@ def test_devicearray_right(test_operator_obj, operator): def test_ba_ba_operator(test_operator_obj, operator): a = test_operator_obj.a b = test_operator_obj.b - x = operator(a, b).ravel() - y = operator(a.ravel(), b.ravel()) - np.testing.assert_allclose(x, y) + x = operator(a, b) + y = BlockArray(operator(a_i, b_i) for a_i, b_i in zip(a, b)) + snp.testing.assert_allclose(x, y) # Testing the @ interface for blockarrays of same size, and a blockarray and flattened ndarray/devicearray @@ -161,9 +99,9 @@ def test_ba_ba_matmul(test_operator_obj): x = a @ b - y = ba.BlockArray.array([a0 @ d0, a1 @ d1]) + y = BlockArray([a0 @ d0, a1 @ d1]) assert x.shape == y.shape - np.testing.assert_allclose(x.ravel(), y.ravel()) + snp.testing.assert_allclose(x, y) with pytest.raises(TypeError): z = a @ c @@ -174,23 +112,21 @@ def test_conj(test_operator_obj): ac = a.conj() assert a.shape == ac.shape - np.testing.assert_allclose(a.ravel().conj(), ac.ravel()) + snp.testing.assert_allclose(BlockArray(a_i.conj() for a_i in a), ac) def test_real(test_operator_obj): a = test_operator_obj.a ac = a.real - assert a.shape == ac.shape - np.testing.assert_allclose(a.ravel().real, ac.ravel()) + snp.testing.assert_allclose(BlockArray(a_i.real for a_i in a), ac) def test_imag(test_operator_obj): a = test_operator_obj.a ac = a.imag - assert a.shape == ac.shape - np.testing.assert_allclose(a.ravel().imag, ac.ravel()) + snp.testing.assert_allclose(BlockArray(a_i.imag for a_i in a), ac) def test_ndim(test_operator_obj): @@ -204,7 +140,7 @@ def test_getitem(test_operator_obj): a1 = test_operator_obj.a1 b0 = test_operator_obj.b0 b1 = test_operator_obj.b1 - x = ba.BlockArray.array([a0, a1, b0, b1]) + x = BlockArray([a0, a1, b0, b1]) # Positive indexing np.testing.assert_allclose(x[0], a0) @@ -219,35 +155,17 @@ def test_getitem(test_operator_obj): np.testing.assert_allclose(x[-1], b1) -@pytest.mark.parametrize("index", (np.s_[0, 0], np.s_[0, 1:3], np.s_[0, :, 0:2], np.s_[0, ..., 2:])) -def test_getitem_tuple(test_operator_obj, index): - a = test_operator_obj.a - a0 = test_operator_obj.a0 - np.testing.assert_allclose(a[index], a0[index[1:]]) - - -def test_blockidx(test_operator_obj): - a = test_operator_obj.a - a0 = test_operator_obj.a0 - a1 = test_operator_obj.a1 - - # use the blockidx to index the flattened data - x0 = a.ravel()[a.blockidx(0)] - x1 = a.ravel()[a.blockidx(1)] - np.testing.assert_allclose(x0, a0.ravel()) - np.testing.assert_allclose(x1, a1.ravel()) - - def test_split(test_operator_obj): a = test_operator_obj.a - a_split = a.split - np.testing.assert_allclose(a_split[0], test_operator_obj.a0) - np.testing.assert_allclose(a_split[1], test_operator_obj.a1) + np.testing.assert_allclose(a[0], test_operator_obj.a0) + np.testing.assert_allclose(a[1], test_operator_obj.a1) def test_blockarray_from_one_array(): - with pytest.raises(TypeError): - ba.BlockArray.array(np.random.randn(32, 32)) + # BlockArray(np.jnp.zeros((3,6))) makes a block array + # with 3 length-6 blocks + x = BlockArray(np.random.randn(3, 6)) + assert len(x) == 3 @pytest.mark.parametrize("axis", [None, 1]) @@ -255,24 +173,10 @@ def test_blockarray_from_one_array(): def test_sum_method(test_operator_obj, axis, keepdims): a = test_operator_obj.a - method_result = a.sum(axis=axis, keepdims=keepdims).ravel() - snp_result = snp.sum(a, axis=axis, keepdims=keepdims).ravel() - - assert method_result.shape == snp_result.shape - np.testing.assert_allclose(method_result, snp_result) - + method_result = a.sum(axis=axis, keepdims=keepdims) + snp_result = snp.sum(a, axis=axis, keepdims=keepdims) -def test_ba_ba_vdot(test_operator_obj): - a = test_operator_obj.a - d = test_operator_obj.d - a0 = test_operator_obj.a0 - a1 = test_operator_obj.a1 - d0 = test_operator_obj.d0 - d1 = test_operator_obj.d1 - - x = snp.vdot(a, d) - y = jnp.vdot(a.ravel(), d.ravel()) - np.testing.assert_allclose(x, y) + snp.testing.assert_allclose(method_result, snp_result) @pytest.mark.parametrize("operator", [snp.dot, snp.matmul]) @@ -285,30 +189,19 @@ def test_ba_ba_dot(test_operator_obj, operator): d1 = test_operator_obj.d1 x = operator(a, d) - y = ba.BlockArray.array([operator(a0, d0), operator(a1, d1)]) - np.testing.assert_allclose(x.ravel(), y.ravel()) + y = BlockArray([operator(a0, d0), operator(a1, d1)]) + snp.testing.assert_allclose(x, y) ############################################################################### # Reduction tests ############################################################################### reduction_funcs = [ - snp.count_nonzero, snp.sum, snp.linalg.norm, - snp.mean, - snp.var, - snp.max, - snp.min, - snp.amin, - snp.amax, - snp.all, - snp.any, ] -real_reduction_funcs = [ - snp.median, -] +real_reduction_funcs = [] class BlockArrayReductionObj: @@ -322,9 +215,9 @@ def __init__(self, dtype): c0, key = randn(shape=(2, 3), dtype=dtype, key=key) c1, key = randn(shape=(3,), dtype=dtype, key=key) - self.a = ba.BlockArray.array((a0, a1), dtype=dtype) - self.b = ba.BlockArray.array((b0, b1), dtype=dtype) - self.c = ba.BlockArray.array((c0, c1), dtype=dtype) + self.a = BlockArray((a0, a1)) + self.b = BlockArray((b0, b1)) + self.c = BlockArray((c0, c1)) @pytest.fixture(scope="module") # so that random objects are cached @@ -347,88 +240,40 @@ def reduction_obj(request): def test_reduce(reduction_obj, func): x = func(reduction_obj.a) x_jit = jax.jit(func)(reduction_obj.a) - y = func(reduction_obj.a.ravel()) - np.testing.assert_allclose(x, x_jit, rtol=1e-4) # test jitted function - np.testing.assert_allclose(x, y) # test for correctness - - -@pytest.mark.parametrize(**REDUCTION_PARAMS) -def test_reduce_axis0(reduction_obj, func): - f = lambda x: func(x, axis=0) - x = f(reduction_obj.b) - x_jit = jax.jit(f)(reduction_obj.b) - - np.testing.assert_allclose(x, x_jit, rtol=1e-4) # test jitted function - - # test for correctness - # stack into a (2, 3, 4) array, call func - y = func(np.stack(list(reduction_obj.b)), axis=0) - np.testing.assert_allclose(x, y) - - with pytest.raises(ValueError): - # Reduction along axis=0 only works if all blocks are same shape - func(reduction_obj.a, axis=0) + y = func(snp.concatenate(snp.ravel(reduction_obj.a))) + np.testing.assert_allclose(x, x_jit, rtol=1e-6) # test jitted function + np.testing.assert_allclose(x, y, rtol=1e-6) # test for correctness @pytest.mark.parametrize(**REDUCTION_PARAMS) -def test_reduce_axis1(reduction_obj, func): - """this is _not_ duplicated from test_reduce_axis0""" - f = lambda x: func(x, axis=1).ravel() +@pytest.mark.parametrize("axis", (0, 1)) +def test_reduce_axis(reduction_obj, func, axis): + f = lambda x: func(x, axis=axis) x = f(reduction_obj.a) x_jit = jax.jit(f)(reduction_obj.a) - np.testing.assert_allclose(x, x_jit, rtol=1e-4) # test jitted function + snp.testing.assert_allclose(x, x_jit, rtol=1e-4) # test jitted function # test for correctness - y0 = func(reduction_obj.a[0], axis=0) - y1 = func(reduction_obj.a[1], axis=0) - y = ba.BlockArray.array((y0, y1), dtype=reduction_obj.a[0].dtype).ravel() - np.testing.assert_allclose(x, y) - - -@pytest.mark.parametrize(**REDUCTION_PARAMS) -def test_reduce_axis2(reduction_obj, func): - """this is _not_ duplicated from test_reduce_axis0""" - f = lambda x: func(x, axis=2).ravel() - x = f(reduction_obj.a) - x_jit = jax.jit(f)(reduction_obj.a) - - np.testing.assert_allclose(x, x_jit, rtol=1e-4) # test jitted function - - y0 = func(reduction_obj.a[0], axis=1) - y1 = func(reduction_obj.a[1], axis=1) - y = ba.BlockArray.array((y0, y1), dtype=reduction_obj.a[0].dtype).ravel() - np.testing.assert_allclose(x, y) - - -@pytest.mark.parametrize(**REDUCTION_PARAMS) -def test_reduce_axis3(reduction_obj, func): - """this is _not_ duplicated from test_reduce_axis0""" - f = lambda x: func(x, axis=3).ravel() - x = f(reduction_obj.a) - x_jit = jax.jit(f)(reduction_obj.a) - - np.testing.assert_allclose(x, x_jit, rtol=1e-4) # test jitted function - - y0 = reduction_obj.a[0] - y1 = func(reduction_obj.a[1], axis=2) - y = ba.BlockArray.array((y0, y1), dtype=reduction_obj.a[0].dtype).ravel() - np.testing.assert_allclose(x.ravel(), y) + y0 = func(reduction_obj.a[0], axis=axis) + y1 = func(reduction_obj.a[1], axis=axis) + y = BlockArray((y0, y1)) + snp.testing.assert_allclose(x, y) @pytest.mark.parametrize(**REDUCTION_PARAMS) def test_reduce_singleton(reduction_obj, func): - # Case where a block is reduced to a singleton - f = lambda x: func(x, axis=1).ravel() + # Case where one block is reduced to a singleton + f = lambda x: func(x, axis=0) x = f(reduction_obj.c) x_jit = jax.jit(f)(reduction_obj.c) - np.testing.assert_allclose(x, x_jit, rtol=1e-4) # test jitted function + snp.testing.assert_allclose(x, x_jit, rtol=1e-4) # test jitted function y0 = func(reduction_obj.c[0], axis=0) y1 = func(reduction_obj.c[1], axis=0)[None] # Ensure size (1,) - y = ba.BlockArray.array((y0, y1), dtype=reduction_obj.a[0].dtype).ravel() - np.testing.assert_allclose(x, y) + y = BlockArray((y0, y1)) + snp.testing.assert_allclose(x, y) class TestCreators: @@ -441,220 +286,105 @@ def setup_method(self, method): self.size = np.prod(self.a_shape) + np.prod(self.b_shape) + np.prod(self.c_shape) def test_zeros(self): - x = ba.BlockArray.zeros(self.shape, dtype=np.float32) + x = snp.zeros(self.shape, dtype=np.float32) assert x.shape == self.shape assert snp.all(x == 0) def test_empty(self): - x = ba.BlockArray.empty(self.shape, dtype=np.float32) + x = snp.empty(self.shape, dtype=np.float32) assert x.shape == self.shape assert snp.all(x == 0) def test_ones(self): - x = ba.BlockArray.ones(self.shape, dtype=np.float32) + x = snp.ones(self.shape, dtype=np.float32) assert x.shape == self.shape assert snp.all(x == 1) def test_full(self): fill_value = np.float32(np.random.randn()) - x = ba.BlockArray.full(self.shape, fill_value=fill_value, dtype=np.float32) + x = snp.full(self.shape, fill_value=fill_value, dtype=np.float32) assert x.shape == self.shape assert x.dtype == np.float32 assert snp.all(x == fill_value) def test_full_nodtype(self): fill_value = np.float32(np.random.randn()) - x = ba.BlockArray.full(self.shape, fill_value=fill_value, dtype=None) + x = snp.full(self.shape, fill_value=fill_value, dtype=None) assert x.shape == self.shape assert x.dtype == fill_value.dtype assert snp.all(x == fill_value) -def test_incompatible_shapes(): - # Verify that array_from_flattened raises exception when - # len(data_ravel) != size determined by shape_tuple - shape_tuple = ((32, 32), (16,)) # len == 1040 - data_ravel = np.ones(1030) - with pytest.raises(ValueError): - ba.BlockArray.array_from_flattened(data_ravel=data_ravel, shape_tuple=shape_tuple) - - -class NestedTestObj: - operators = math_ops + comp_ops - - def __init__(self, dtype): - key = None - scalar, key = randn(shape=(1,), dtype=dtype, key=key) - self.scalar = scalar.copy().ravel()[0] # convert to float - - self.a00, key = randn(shape=(2, 2, 2), dtype=dtype, key=key) - self.a01, key = randn(shape=(3, 2, 4), dtype=dtype, key=key) - self.a1, key = randn(shape=(2, 4), dtype=dtype, key=key) - - self.a = ba.BlockArray.array(((self.a00, self.a01), self.a1)) - - -@pytest.fixture(scope="module") -def nested_obj(request): - yield NestedTestObj(request.param) - - -@pytest.mark.parametrize("nested_obj", [np.float32, np.complex64], indirect=True) -def test_nested_shape(nested_obj): - a = nested_obj.a - - a00 = nested_obj.a00 - a01 = nested_obj.a01 - a1 = nested_obj.a1 - - assert a.shape == (((2, 2, 2), (3, 2, 4)), (2, 4)) - assert a.size == 2 * 2 * 2 + 3 * 2 * 4 + 2 * 4 - - assert a[0].shape == ((2, 2, 2), (3, 2, 4)) - assert a[1].shape == (2, 4) - - np.testing.assert_allclose(a[0][0].ravel(), a00.ravel()) - np.testing.assert_allclose(a[0][1].ravel(), a01.ravel()) - np.testing.assert_allclose(a[1].ravel(), a1.ravel()) - - # basic test for block_sizes - assert ba.block_sizes(a.shape) == (a[0].size, a[1].size) - - -NESTED_REDUCTION_PARAMS = dict( - argnames="nested_obj, func", - argvalues=( - list(zip(itertools.repeat(np.float32), reduction_funcs)) - + list(zip(itertools.repeat(np.complex64), reduction_funcs)) - + list(zip(itertools.repeat(np.float32), real_reduction_funcs)) - ), - indirect=["nested_obj"], +# tests added for the BlockArray refactor +@pytest.fixture +def x(): + # any BlockArray, arbitrary shape, content, type + return BlockArray([[[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], [42.0]]) + + +@pytest.fixture +def y(): + # another BlockArray, content, type, shape matches x + return BlockArray([[[1.0, 4.0, 6.0], [1.0, 2.0, 3.0]], [-2.0]]) + + +@pytest.mark.parametrize("op", [op.neg, op.pos, op.abs]) +def test_unary(op, x): + actual = op(x) + expected = BlockArray(op(x_i) for x_i in x) + assert_array_equal(actual, expected) + assert actual.dtype == expected.dtype + + +@pytest.mark.parametrize( + "op", + [ + op.mul, + op.mod, + op.lt, + op.le, + op.gt, + op.ge, + op.floordiv, + op.eq, + op.add, + op.truediv, + op.sub, + op.ne, + ], ) +def test_elementwise_binary(op, x, y): + actual = op(x, y) + expected = BlockArray(op(x_i, y_i) for x_i, y_i in zip(x, y)) + assert_array_equal(actual, expected) + assert actual.dtype == expected.dtype -@pytest.mark.parametrize(**NESTED_REDUCTION_PARAMS) -def test_nested_reduce_singleton(nested_obj, func): - a = nested_obj.a - x = func(a) - y = func(a.ravel()) - np.testing.assert_allclose(x, y, rtol=5e-5) - +def test_not_implemented_binary(x): + with pytest.raises(TypeError, match=r"unsupported operand type\(s\)"): + y = x + "a string" -@pytest.mark.parametrize(**NESTED_REDUCTION_PARAMS) -def test_nested_reduce_axis1(nested_obj, func): - a = nested_obj.a - with pytest.raises(ValueError): - # Blocks don't conform! - x = func(a, axis=1) +def test_matmul(x): + # x is ((2, 3), (1,)) + # y is ((3, 1), (1, 2)) + y = BlockArray([[[1.0], [2.0], [3.0]], [[0.0, 1.0]]]) + actual = x @ y + expected = BlockArray([[[14.0], [0.0]], [0.0, 42.0]]) + assert_array_equal(actual, expected) + assert actual.dtype == expected.dtype -@pytest.mark.parametrize(**NESTED_REDUCTION_PARAMS) -def test_nested_reduce_axis2(nested_obj, func): - a = nested_obj.a +def test_property(): + x = BlockArray(([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [0.0])) + actual = x.shape + expected = ((2, 3), (1,)) + assert actual == expected - x = func(a, axis=2) - assert x.shape == (((2, 2), (2, 4)), (2,)) - - y = ba.BlockArray.array((func(a[0], axis=1), func(a[1], axis=1))) - assert x.shape == y.shape - - np.testing.assert_allclose(x.ravel(), y.ravel(), rtol=5e-5) - - -@pytest.mark.parametrize(**NESTED_REDUCTION_PARAMS) -def test_nested_reduce_axis3(nested_obj, func): - a = nested_obj.a - - x = func(a, axis=3) - assert x.shape == (((2, 2), (3, 4)), (2, 4)) - - y = ba.BlockArray.array((func(a[0], axis=2), a[1])) - assert x.shape == y.shape - - np.testing.assert_allclose(x.ravel(), y.ravel(), rtol=5e-5) - - -def test_array_from_flattened(): - x = np.random.randn(19) - x_b = ba.BlockArray.array_from_flattened(x, shape_tuple=((4, 4), (3,))) - assert isinstance(x_b._data, DeviceArray) - - -class TestBlockArrayIndex: - def setup_method(self): - key = None - self.A, key = randn(shape=((4, 4), (3,)), key=key) - self.B, key = randn(shape=((3, 3), (4, 2, 3)), key=key) - self.C, key = randn(shape=((3, 3), (4, 2), (4, 4)), key=key) - - # nested - self.D, key = randn(shape=((self.A.shape, self.B.shape)), key=key) - - def test_set_block(self): - # Test assignment of an entire block - A2 = self.A.at[0].set(1) - np.testing.assert_allclose(A2[0], snp.ones_like(A2[0]), rtol=5e-5) - np.testing.assert_allclose(A2[1], A2[1], rtol=5e-5) - - D2 = self.D.at[1].set(1.45) - np.testing.assert_allclose(D2[0].ravel(), self.D[0].ravel(), rtol=5e-5) - np.testing.assert_allclose( - D2[1].ravel(), 1.45 * snp.ones_like(self.D[1]).ravel(), rtol=5e-5 - ) - - def test_set(self): - # Test assignment using (bkidx, idx) format - A2 = self.A.at[0, 2:, :-2].set(1.45) - tmp = A2[0][2:, :-2] - np.testing.assert_allclose(A2[0][2:, :-2], 1.45 * snp.ones_like(tmp), rtol=5e-5) - np.testing.assert_allclose(A2[1].ravel(), A2[1], rtol=5e-5) - - def test_add(self): - A2 = self.A.at[0, 2:, :-2].add(1.45) - tmp = np.array(self.A[0]) - tmp[2:, :-2] += 1.45 - y = ba.BlockArray.array([tmp, self.A[1]]) - np.testing.assert_allclose(A2.ravel(), y.ravel(), rtol=5e-5) - - D2 = self.D.at[1].add(1.45) - y = ba.BlockArray.array([self.D[0], self.D[1] + 1.45]) - np.testing.assert_allclose(D2.ravel(), y.ravel(), rtol=5e-5) - - def test_multiply(self): - A2 = self.A.at[0, 2:, :-2].multiply(1.45) - tmp = np.array(self.A[0]) - tmp[2:, :-2] *= 1.45 - y = ba.BlockArray.array([tmp, self.A[1]]) - np.testing.assert_allclose(A2.ravel(), y.ravel(), rtol=5e-5) - - D2 = self.D.at[1].multiply(1.45) - y = ba.BlockArray.array([self.D[0], self.D[1] * 1.45]) - np.testing.assert_allclose(D2.ravel(), y.ravel(), rtol=5e-5) - - def test_divide(self): - A2 = self.A.at[0, 2:, :-2].divide(1.45) - tmp = np.array(self.A[0]) - tmp[2:, :-2] /= 1.45 - y = ba.BlockArray.array([tmp, self.A[1]]) - np.testing.assert_allclose(A2.ravel(), y.ravel(), rtol=5e-5) - - D2 = self.D.at[1].divide(1.45) - y = ba.BlockArray.array([self.D[0], self.D[1] / 1.45]) - np.testing.assert_allclose(D2.ravel(), y.ravel(), rtol=5e-5) - - def test_power(self): - A2 = self.A.at[0, 2:, :-2].power(2) - tmp = np.array(self.A[0]) - tmp[2:, :-2] **= 2 - y = ba.BlockArray.array([tmp, self.A[1]]) - np.testing.assert_allclose(A2.ravel(), y.ravel(), rtol=5e-5) - - D2 = self.D.at[1].power(1.45) - y = ba.BlockArray.array([self.D[0], self.D[1] ** 1.45]) - np.testing.assert_allclose(D2.ravel(), y.ravel(), rtol=5e-5) - - def test_set_slice(self): - with pytest.raises(TypeError): - C2 = self.C.at[::2, 0].set(0) +def test_method(): + x = BlockArray(([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], [42.0])) + actual = x.max() + expected = BlockArray([[3.0], [42.0]]) + assert_array_equal(actual, expected) + assert actual.dtype == expected.dtype diff --git a/scico/test/test_numpy.py b/scico/test/test_numpy.py index 495b38795..2a9ae390a 100644 --- a/scico/test/test_numpy.py +++ b/scico/test/test_numpy.py @@ -3,13 +3,7 @@ import jax from jax.interpreters.xla import DeviceArray -import pytest - import scico.numpy as snp -import scico.numpy._create as snc -import scico.numpy.linalg as sla -from scico.blockarray import BlockArray -from scico.linop import MatrixOperator def on_cpu(): @@ -37,158 +31,6 @@ def test_reshape_array(): np.testing.assert_allclose(snp.reshape(a.ravel(), (4, 4)), a) -def test_reshape_array(): - a = np.random.randn(13) - b = snp.reshape(a, ((3, 3), (4,))) - - c = BlockArray.array_from_flattened(a, ((3, 3), (4,))) - - assert isinstance(b, BlockArray) - assert b.shape == c.shape - np.testing.assert_allclose(b.ravel(), c.ravel()) - - -@pytest.mark.parametrize("compute_uv", [True, False]) -@pytest.mark.parametrize("full_matrices", [True, False]) -@pytest.mark.parametrize("shape", [(8, 8), (4, 8), (8, 4)]) -def test_svd(compute_uv, full_matrices, shape): - A = jax.device_put(np.random.randn(*shape)) - Ao = MatrixOperator(A) - f = lambda x: sla.svd(x, compute_uv=compute_uv, full_matrices=full_matrices) - check_results(f(A), f(Ao)) - - -def test_cond(): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = sla.cond - check_results(f(A), f(Ao)) - - -def test_det(): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = sla.det - check_results(f(A), f(Ao)) - - -@pytest.mark.skipif( - on_cpu() == False, reason="nonsymmetric eigendecompositions only supported on cpu" -) -def test_eig(): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = sla.eig - check_results(f(A), f(Ao)) - - -@pytest.mark.parametrize("symmetrize", [True, False]) -@pytest.mark.parametrize("UPLO", [None, "L", "U"]) -def test_eigh(UPLO, symmetrize): - A = jax.device_put(np.random.randn(8, 8)) - A = A.T @ A - Ao = MatrixOperator(A) - f = lambda x: sla.eigh(x, UPLO=UPLO, symmetrize_input=symmetrize) - check_results(f(A), f(Ao)) - - -@pytest.mark.skipif( - on_cpu() == False, reason="nonsymmetric eigendecompositions only supported on cpu" -) -def test_eigvals(): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = sla.eigvals - check_results(f(A), f(Ao)) - - -@pytest.mark.parametrize("UPLO", [None, "L", "U"]) -def test_eigvalsh(UPLO): - A = jax.device_put(np.random.randn(8, 8)) - A = A.T @ A - Ao = MatrixOperator(A) - f = lambda x: sla.eigvalsh(x, UPLO=UPLO) - check_results(f(A), f(Ao)) - - -def test_inv(): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = sla.inv - check_results(f(A), f(Ao)) - - -def test_lstsq(): - A = jax.device_put(np.random.randn(8, 8)) - b = jax.device_put(np.random.randn(8)) - Ao = MatrixOperator(A) - f = lambda A: sla.lstsq(A, b) - check_results(f(A), f(Ao)) - - -def test_matrix_power(): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = lambda A: sla.matrix_power(A, 3) - check_results(f(A), f(Ao)) - - -def test_matrixrank(): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = lambda A: sla.matrix_rank(A, 3) - check_results(f(A), f(Ao)) - - -@pytest.mark.parametrize("rcond", [None, 1e-3]) -def test_pinv(rcond): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = sla.pinv - check_results(f(A), f(Ao)) - - -@pytest.mark.parametrize("rcond", [None, 1e-3]) -def test_pinv(rcond): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = sla.pinv - check_results(f(A), f(Ao)) - - -@pytest.mark.parametrize("shape", [(8, 8), (4, 8), (8, 4)]) -@pytest.mark.parametrize("mode", ["reduced", "complete", "r"]) -def test_qr(shape, mode): - A = jax.device_put(np.random.randn(*shape)) - Ao = MatrixOperator(A) - f = lambda A: sla.qr(A, mode) - check_results(f(A), f(Ao)) - - -def test_slogdet(): - A = jax.device_put(np.random.randn(8, 8)) - Ao = MatrixOperator(A) - f = sla.slogdet - check_results(f(A), f(Ao)) - - -def test_solve(): - A = jax.device_put(np.random.randn(8, 8)) - b = jax.device_put(np.random.randn(8)) - Ao = MatrixOperator(A) - f = lambda A: sla.solve(A, b) - check_results(f(A), f(Ao)) - - -def test_multi_dot(): - A = jax.device_put(np.random.randn(8, 8)) - B = jax.device_put(np.random.randn(8, 4)) - Ao = MatrixOperator(A) - Bo = MatrixOperator(B) - f = sla.multi_dot - check_results(f([A, B]), f([Ao, Bo])) - - def test_ufunc_abs(): A = snp.array([-1, 2, 5]) res = snp.array([1, 2, 5]) @@ -198,16 +40,16 @@ def test_ufunc_abs(): res = snp.array([1, 1, 1]) np.testing.assert_allclose(snp.abs(A), res) - Ba = BlockArray.array((snp.array([-1, 2, 5]),)) - res = BlockArray.array((snp.array([1, 2, 5]),)) + Ba = snp.blockarray((snp.array([-1, 2, 5]),)) + res = snp.blockarray((snp.array([1, 2, 5]),)) np.testing.assert_allclose(snp.abs(Ba).ravel(), res.ravel()) - Ba = BlockArray.array((snp.array([-1, -1, -1]),)) - res = BlockArray.array((snp.array([1, 1, 1]),)) + Ba = snp.blockarray((snp.array([-1, -1, -1]),)) + res = snp.blockarray((snp.array([1, 1, 1]),)) np.testing.assert_allclose(snp.abs(Ba).ravel(), res.ravel()) - Ba = BlockArray.array((snp.array([-1, 2, -3]), snp.array([1, -2, 3]))) - res = BlockArray.array((snp.array([1, 2, 3]), snp.array([1, 2, 3]))) + Ba = snp.blockarray((snp.array([-1, 2, -3]), snp.array([1, -2, 3]))) + res = snp.blockarray((snp.array([1, 2, 3]), snp.array([1, 2, 3]))) np.testing.assert_allclose(snp.abs(Ba).ravel(), res.ravel()) @@ -240,21 +82,19 @@ def test_ufunc_maximum(): B = snp.array([2, 3, 4]) C = snp.array([5, 6]) D = snp.array([2, 7]) - Ba = BlockArray.array((A, C)) - Bb = BlockArray.array((B, D)) - res = BlockArray.array((snp.array([2, 3, 4]), snp.array([5, 7]))) + Ba = snp.blockarray((A, C)) + Bb = snp.blockarray((B, D)) + res = snp.blockarray((snp.array([2, 3, 4]), snp.array([5, 7]))) Bmax = snp.maximum(Ba, Bb) - assert Bmax.shape == res.shape - np.testing.assert_allclose(Bmax.ravel(), res.ravel()) + snp.testing.assert_allclose(Bmax, res) A = snp.array([1, 6, 3]) B = snp.array([6, 3, 8]) C = 5 - Ba = BlockArray.array((A, B)) - res = BlockArray.array((snp.array([5, 6, 5]), snp.array([6, 5, 8]))) + Ba = snp.blockarray((A, B)) + res = snp.blockarray((snp.array([5, 6, 5]), snp.array([6, 5, 8]))) Bmax = snp.maximum(Ba, C) - assert Bmax.shape == res.shape - np.testing.assert_allclose(Bmax.ravel(), res.ravel()) + snp.testing.assert_allclose(Bmax, res) def test_ufunc_sign(): @@ -262,13 +102,13 @@ def test_ufunc_sign(): res = snp.array([1, -1, 0]) np.testing.assert_allclose(snp.sign(A), res) - Ba = BlockArray.array((snp.array([10, -5, 0]),)) - res = BlockArray.array((snp.array([1, -1, 0]),)) - np.testing.assert_allclose(snp.sign(Ba).ravel(), res.ravel()) + Ba = snp.blockarray((snp.array([10, -5, 0]),)) + res = snp.blockarray((snp.array([1, -1, 0]),)) + snp.testing.assert_allclose(snp.sign(Ba), res) - Ba = BlockArray.array((snp.array([10, -5, 0]), snp.array([0, 5, -6]))) - res = BlockArray.array((snp.array([1, -1, 0]), snp.array([0, 1, -1]))) - np.testing.assert_allclose(snp.sign(Ba).ravel(), res.ravel()) + Ba = snp.blockarray((snp.array([10, -5, 0]), snp.array([0, 5, -6]))) + res = snp.blockarray((snp.array([1, -1, 0]), snp.array([0, 1, -1]))) + snp.testing.assert_allclose(snp.sign(Ba), res) def test_ufunc_where(): @@ -278,19 +118,19 @@ def test_ufunc_where(): res = snp.array([-1, -1, 4, 5]) np.testing.assert_allclose(snp.where(cond, A, B), res) - Ba = BlockArray.array((snp.array([1, 2, 4, 5]),)) - Bb = BlockArray.array((snp.array([-1, -1, -1, -1]),)) - Bcond = BlockArray.array((snp.array([False, False, True, True]),)) - Bres = BlockArray.array((snp.array([-1, -1, 4, 5]),)) + Ba = snp.blockarray((snp.array([1, 2, 4, 5]),)) + Bb = snp.blockarray((snp.array([-1, -1, -1, -1]),)) + Bcond = snp.blockarray((snp.array([False, False, True, True]),)) + Bres = snp.blockarray((snp.array([-1, -1, 4, 5]),)) assert snp.where(Bcond, Ba, Bb).shape == Bres.shape np.testing.assert_allclose(snp.where(Bcond, Ba, Bb).ravel(), Bres.ravel()) - Ba = BlockArray.array((snp.array([1, 2, 4, 5]), snp.array([1, 2, 4, 5]))) - Bb = BlockArray.array((snp.array([-1, -1, -1, -1]), snp.array([-1, -1, -1, -1]))) - Bcond = BlockArray.array( + Ba = snp.blockarray((snp.array([1, 2, 4, 5]), snp.array([1, 2, 4, 5]))) + Bb = snp.blockarray((snp.array([-1, -1, -1, -1]), snp.array([-1, -1, -1, -1]))) + Bcond = snp.blockarray( (snp.array([False, False, True, True]), snp.array([True, True, False, False])) ) - Bres = BlockArray.array((snp.array([-1, -1, 4, 5]), snp.array([1, 2, -1, -1]))) + Bres = snp.blockarray((snp.array([-1, -1, 4, 5]), snp.array([1, 2, -1, -1]))) assert snp.where(Bcond, Ba, Bb).shape == Bres.shape np.testing.assert_allclose(snp.where(Bcond, Ba, Bb).ravel(), Bres.ravel()) @@ -306,20 +146,20 @@ def test_ufunc_true_divide(): res = snp.array([0.33333333, 0.66666667, 1.0]) np.testing.assert_allclose(snp.true_divide(A, B), res) - Ba = BlockArray.array((snp.array([1, 2, 3]),)) - Bb = BlockArray.array((snp.array([3, 3, 3]),)) - res = BlockArray.array((snp.array([0.33333333, 0.66666667, 1.0]),)) - np.testing.assert_allclose(snp.true_divide(Ba, Bb).ravel(), res.ravel()) + Ba = snp.blockarray((snp.array([1, 2, 3]),)) + Bb = snp.blockarray((snp.array([3, 3, 3]),)) + res = snp.blockarray((snp.array([0.33333333, 0.66666667, 1.0]),)) + snp.testing.assert_allclose(snp.true_divide(Ba, Bb), res) - Ba = BlockArray.array((snp.array([1, 2, 3]), snp.array([1, 2]))) - Bb = BlockArray.array((snp.array([3, 3, 3]), snp.array([2, 2]))) - res = BlockArray.array((snp.array([0.33333333, 0.66666667, 1.0]), snp.array([0.5, 1.0]))) - np.testing.assert_allclose(snp.true_divide(Ba, Bb).ravel(), res.ravel()) + Ba = snp.blockarray((snp.array([1, 2, 3]), snp.array([1, 2]))) + Bb = snp.blockarray((snp.array([3, 3, 3]), snp.array([2, 2]))) + res = snp.blockarray((snp.array([0.33333333, 0.66666667, 1.0]), snp.array([0.5, 1.0]))) + snp.testing.assert_allclose(snp.true_divide(Ba, Bb), res) - Ba = BlockArray.array((snp.array([1, 2, 3]), snp.array([1, 2]))) + Ba = snp.blockarray((snp.array([1, 2, 3]), snp.array([1, 2]))) A = 2 - res = BlockArray.array((snp.array([0.5, 1.0, 1.5]), snp.array([0.5, 1.0]))) - np.testing.assert_allclose(snp.true_divide(Ba, A).ravel(), res.ravel()) + res = snp.blockarray((snp.array([0.5, 1.0, 1.5]), snp.array([0.5, 1.0]))) + snp.testing.assert_allclose(snp.true_divide(Ba, A), res) def test_ufunc_floor_divide(): @@ -333,20 +173,20 @@ def test_ufunc_floor_divide(): res = snp.array([1.0, 0, 1.0]) np.testing.assert_allclose(snp.floor_divide(A, B), res) - Ba = BlockArray.array((snp.array([1, 2, 3]),)) - Bb = BlockArray.array((snp.array([3, 3, 3]),)) - res = BlockArray.array((snp.array([0, 0, 1.0]),)) - np.testing.assert_allclose(snp.floor_divide(Ba, Bb).ravel(), res.ravel()) + Ba = snp.blockarray((snp.array([1, 2, 3]),)) + Bb = snp.blockarray((snp.array([3, 3, 3]),)) + res = snp.blockarray((snp.array([0, 0, 1.0]),)) + snp.testing.assert_allclose(snp.floor_divide(Ba, Bb), res) - Ba = BlockArray.array((snp.array([1, 7, 3]), snp.array([1, 2]))) - Bb = BlockArray.array((snp.array([3, 3, 3]), snp.array([2, 2]))) - res = BlockArray.array((snp.array([0, 2, 1.0]), snp.array([0, 1.0]))) - np.testing.assert_allclose(snp.floor_divide(Ba, Bb).ravel(), res.ravel()) + Ba = snp.blockarray((snp.array([1, 7, 3]), snp.array([1, 2]))) + Bb = snp.blockarray((snp.array([3, 3, 3]), snp.array([2, 2]))) + res = snp.blockarray((snp.array([0, 2, 1.0]), snp.array([0, 1.0]))) + snp.testing.assert_allclose(snp.floor_divide(Ba, Bb), res) - Ba = BlockArray.array((snp.array([1, 2, 3]), snp.array([1, 2]))) + Ba = snp.blockarray((snp.array([1, 2, 3]), snp.array([1, 2]))) A = 2 - res = BlockArray.array((snp.array([0, 1.0, 1.0]), snp.array([0, 1.0]))) - np.testing.assert_allclose(snp.floor_divide(Ba, A).ravel(), res.ravel()) + res = snp.blockarray((snp.array([0, 1.0, 1.0]), snp.array([0, 1.0]))) + snp.testing.assert_allclose(snp.floor_divide(Ba, A), res) def test_ufunc_real(): @@ -358,13 +198,13 @@ def test_ufunc_real(): res = snp.array([1, 4.0]) np.testing.assert_allclose(snp.real(A), res) - Ba = BlockArray.array((snp.array([1 + 3j]),)) - res = BlockArray.array((snp.array([1]),)) - np.testing.assert_allclose(snp.real(Ba).ravel(), res.ravel()) + Ba = snp.blockarray((snp.array([1 + 3j]),)) + res = snp.blockarray((snp.array([1]),)) + snp.testing.assert_allclose(snp.real(Ba), res) - Ba = BlockArray.array((snp.array([1 + 3j]), snp.array([1 + 3j, 4.0]))) - res = BlockArray.array((snp.array([1]), snp.array([1, 4.0]))) - np.testing.assert_allclose(snp.real(Ba).ravel(), res.ravel()) + Ba = snp.blockarray((snp.array([1.0 + 3j]), snp.array([1 + 3j, 4.0]))) + res = snp.blockarray((snp.array([1.0]), snp.array([1, 4.0]))) + snp.testing.assert_allclose(snp.real(Ba), res) def test_ufunc_imag(): @@ -376,13 +216,13 @@ def test_ufunc_imag(): res = snp.array([3, 2]) np.testing.assert_allclose(snp.imag(A), res) - Ba = BlockArray.array((snp.array([1 + 3j]),)) - res = BlockArray.array((snp.array([3]),)) - np.testing.assert_allclose(snp.imag(Ba).ravel(), res.ravel()) + Ba = snp.blockarray((snp.array([1 + 3j]),)) + res = snp.blockarray((snp.array([3]),)) + snp.testing.assert_allclose(snp.imag(Ba), res) - Ba = BlockArray.array((snp.array([1 + 3j]), snp.array([1 + 3j, 4.0]))) - res = BlockArray.array((snp.array([3]), snp.array([3, 0]))) - np.testing.assert_allclose(snp.imag(Ba).ravel(), res.ravel()) + Ba = snp.blockarray((snp.array([1 + 3j]), snp.array([1 + 3j, 4.0]))) + res = snp.blockarray((snp.array([3]), snp.array([3, 0]))) + snp.testing.assert_allclose(snp.imag(Ba), res) def test_ufunc_conj(): @@ -394,101 +234,103 @@ def test_ufunc_conj(): res = snp.array([1 - 3j, 4.0 - 2j]) np.testing.assert_allclose(snp.conj(A), res) - Ba = BlockArray.array((snp.array([1 + 3j]),)) - res = BlockArray.array((snp.array([1 - 3j]),)) - np.testing.assert_allclose(snp.conj(Ba).ravel(), res.ravel()) + Ba = snp.blockarray((snp.array([1 + 3j]),)) + res = snp.blockarray((snp.array([1 - 3j]),)) + snp.testing.assert_allclose(snp.conj(Ba), res) - Ba = BlockArray.array((snp.array([1 + 3j]), snp.array([1 + 3j, 4.0]))) - res = BlockArray.array((snp.array([1 - 3j]), snp.array([1 - 3j, 4.0 - 0j]))) - np.testing.assert_allclose(snp.conj(Ba).ravel(), res.ravel()) + Ba = snp.blockarray((snp.array([1 + 3j]), snp.array([1 + 3j, 4.0]))) + res = snp.blockarray((snp.array([1 - 3j]), snp.array([1 - 3j, 4.0 - 0j]))) + snp.testing.assert_allclose(snp.conj(Ba), res) def test_create_zeros(): - A = snc.zeros(2) + A = snp.zeros(2) assert np.all(A == 0) - A = snc.zeros([(2,), (2,)]) - assert np.all(A.ravel() == 0) + A = snp.zeros(((2,), (2,))) + assert all(snp.all(A == 0)) def test_create_ones(): - A = snc.ones(2, dtype=np.float32) + A = snp.ones(2, dtype=np.float32) assert np.all(A == 1) - A = snc.ones([(2,), (2,)]) - assert np.all(A.ravel() == 1) + A = snp.ones(((2,), (2,))) + assert all(snp.all(A == 1)) def test_create_zeros(): - A = snc.empty(2) + A = snp.empty(2) assert np.all(A == 0) - A = snc.empty([(2,), (2,)]) - assert np.all(A.ravel() == 0) + A = snp.empty(((2,), (2,))) + assert all(snp.all(A == 0)) def test_create_full(): - A = snc.full((2,), 1) + A = snp.full((2,), 1) assert np.all(A == 1) - A = snc.full((2,), 1, dtype=np.float32) + A = snp.full((2,), 1, dtype=np.float32) assert np.all(A == 1) - A = snc.full([(2,), (2,)], 1) - assert np.all(A.ravel() == 1) + A = snp.full(((2,), (2,)), 1) + assert all(snp.all(A == 1)) def test_create_zeros_like(): - A = snc.ones(2, dtype=np.float32) - B = snc.zeros_like(A) + A = snp.ones(2, dtype=np.float32) + B = snp.zeros_like(A) assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype - A = snc.ones(2, dtype=np.float32) - B = snc.zeros_like(A) + A = snp.ones(2, dtype=np.float32) + B = snp.zeros_like(A) assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype - A = snc.ones([(2,), (2,)], dtype=np.float32) - B = snc.zeros_like(A) - assert np.all(B.ravel() == 0) and A.shape == B.shape and A.dtype == B.dtype + A = snp.ones(((2,), (2,)), dtype=np.float32) + B = snp.zeros_like(A) + assert all(snp.all(B == 0)) + assert A.shape == B.shape + assert A.dtype == B.dtype def test_create_empty_like(): - A = snc.ones(2, dtype=np.float32) - B = snc.empty_like(A) + A = snp.ones(2, dtype=np.float32) + B = snp.empty_like(A) assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype - A = snc.ones(2, dtype=np.float32) - B = snc.empty_like(A) + A = snp.ones(2, dtype=np.float32) + B = snp.empty_like(A) assert np.all(B == 0) and A.shape == B.shape and A.dtype == B.dtype - A = snc.ones([(2,), (2,)], dtype=np.float32) - B = snc.empty_like(A) - assert np.all(B.ravel() == 0) and A.shape == B.shape and A.dtype == B.dtype + A = snp.ones(((2,), (2,)), dtype=np.float32) + B = snp.empty_like(A) + assert all(snp.all(B == 0)) and A.shape == B.shape and A.dtype == B.dtype def test_create_ones_like(): - A = snc.zeros(2, dtype=np.float32) - B = snc.ones_like(A) + A = snp.zeros(2, dtype=np.float32) + B = snp.ones_like(A) assert np.all(B == 1) and A.shape == B.shape and A.dtype == B.dtype - A = snc.zeros(2, dtype=np.float32) - B = snc.ones_like(A) + A = snp.zeros(2, dtype=np.float32) + B = snp.ones_like(A) assert np.all(B == 1) and A.shape == B.shape and A.dtype == B.dtype - A = snc.zeros([(2,), (2,)], dtype=np.float32) - B = snc.ones_like(A) - assert np.all(B.ravel() == 1) and A.shape == B.shape and A.dtype == B.dtype + A = snp.zeros(((2,), (2,)), dtype=np.float32) + B = snp.ones_like(A) + assert all(snp.all(B == 1)) and A.shape == B.shape and A.dtype == B.dtype def test_create_full_like(): - A = snc.zeros(2, dtype=np.float32) - B = snc.full_like(A, 1.0) + A = snp.zeros(2, dtype=np.float32) + B = snp.full_like(A, 1.0) assert np.all(B == 1) and (A.shape == B.shape) and (A.dtype == B.dtype) - A = snc.zeros(2, dtype=np.float32) - B = snc.full_like(A, 1) + A = snp.zeros(2, dtype=np.float32) + B = snp.full_like(A, 1) assert np.all(B == 1) and (A.shape == B.shape) and (A.dtype == B.dtype) - A = snc.zeros([(2,), (2,)], dtype=np.float32) - B = snc.full_like(A, 1) - assert np.all(B.ravel() == 1) and (A.shape == B.shape) and (A.dtype == B.dtype) + A = snp.zeros(((2,), (2,)), dtype=np.float32) + B = snp.full_like(A, 1) + assert all(snp.all(B == 1)) and (A.shape == B.shape) and (A.dtype == B.dtype) diff --git a/scico/test/test_operator.py b/scico/test/test_operator.py index 30788d883..9fd23860d 100644 --- a/scico/test/test_operator.py +++ b/scico/test/test_operator.py @@ -12,7 +12,6 @@ import jax import scico.numpy as snp -from scico.blockarray import BlockArray from scico.operator import Operator from scico.random import randn @@ -176,11 +175,11 @@ def test_freeze_3arg(): input_shape=((1, 3, 4), (2, 1, 4), (2, 3, 1)), eval_fn=lambda x: x[0] * x[1] * x[2] ) - a = np.random.randn(1, 3, 4) - b = np.random.randn(2, 1, 4) - c = np.random.randn(2, 3, 1) + a, _ = randn((1, 3, 4)) + b, _ = randn((2, 1, 4)) + c, _ = randn((2, 3, 1)) - x = BlockArray.array([a, b, c]) + x = snp.blockarray([a, b, c]) Abc = A.freeze(0, a) # A as a function of b, c Aac = A.freeze(1, b) # A as a function of a, c Aab = A.freeze(2, c) # A as a function of a, b @@ -189,9 +188,9 @@ def test_freeze_3arg(): assert Aac.input_shape == ((1, 3, 4), (2, 3, 1)) assert Aab.input_shape == ((1, 3, 4), (2, 1, 4)) - bc = BlockArray.array([b, c]) - ac = BlockArray.array([a, c]) - ab = BlockArray.array([a, b]) + bc = snp.blockarray([b, c]) + ac = snp.blockarray([a, c]) + ab = snp.blockarray([a, b]) np.testing.assert_allclose(A(x), Abc(bc), rtol=5e-4) np.testing.assert_allclose(A(x), Aac(ac), rtol=5e-4) np.testing.assert_allclose(A(x), Aab(ab), rtol=5e-4) @@ -201,10 +200,10 @@ def test_freeze_2arg(): A = Operator(input_shape=((1, 3, 4), (2, 1, 4)), eval_fn=lambda x: x[0] * x[1]) - a = np.random.randn(1, 3, 4) - b = np.random.randn(2, 1, 4) + a, _ = randn((1, 3, 4)) + b, _ = randn((2, 1, 4)) - x = BlockArray.array([a, b]) + x = snp.blockarray([a, b]) Ab = A.freeze(0, a) # A as a function of 'b' only Aa = A.freeze(1, b) # A as a function of 'a' only diff --git a/scico/test/test_random.py b/scico/test/test_random.py index 7e044a58a..7b6b8f80b 100644 --- a/scico/test/test_random.py +++ b/scico/test/test_random.py @@ -25,11 +25,7 @@ def test_wrapped_funcs(seed): seed = 42 key = jax.random.PRNGKey(seed) - result = fun_wrapped(shape, seed=seed) - - # test nested blockarray - shape = ((7, (1, 3)), (3, 2), (2, 4, 1)) - result = fun_wrapped(shape, seed=seed) + result, _ = fun_wrapped(shape, seed=seed) def test_add_seed_adapter(): @@ -73,29 +69,3 @@ def test_add_seed_adapter(): # error when key and seed are specified with pytest.raises(Exception): _ = fun_alt(key=jax.random.PRNGKey(0), seed=42)[0] - - -def test_block_shape_adapter(): - fun = jax.random.normal - fun_alt = scico.random._allow_block_shape(fun) - - # when shape is nested, result should be a BlockArray... - shape = ((7,), (3, 2), (2, 4, 1)) - seed = 42 - key = jax.random.PRNGKey(seed) - - result = fun_alt(key, shape) - assert isinstance(result, scico.blockarray.BlockArray) - - # should work for deeply nested as well - shape = ((7,), (3, (2, 1)), (2, 4, 1)) - result = fun_alt(key, shape) - assert isinstance(result, scico.blockarray.BlockArray) - - # when shape is not nested, behavior should be normal - shape = (1,) - result_A = fun(key, shape) - result_B = fun_alt(key, shape) - np.testing.assert_array_equal(result_A, result_B) - - assert fun(key) == fun_alt(key) diff --git a/scico/test/test_ray_tune.py b/scico/test/test_ray_tune.py index c712c3828..f2db6da6b 100644 --- a/scico/test/test_ray_tune.py +++ b/scico/test/test_ray_tune.py @@ -32,7 +32,7 @@ def test_random(): "eval_func", metric="cost", mode="min", - num_samples=50, + num_samples=100, config=config, resources_per_trial=resources, hyperopt=False, diff --git a/scico/test/test_solver.py b/scico/test/test_solver.py index b7173f91d..d694d796d 100644 --- a/scico/test/test_solver.py +++ b/scico/test/test_solver.py @@ -6,7 +6,6 @@ import scico.numpy as snp from scico import random, solver -from scico.blockarray import BlockArray class TestSet: @@ -196,14 +195,13 @@ def test_split_join_blockarray(): x_s = solver._split_real_imag(x) assert x_s.shape == ((2, 4, 4), (2, 3)) - real_block = BlockArray.array((x_s[0][0], x_s[1][0])) - imag_block = BlockArray.array((x_s[0][1], x_s[1][1])) - np.testing.assert_allclose(real_block.ravel(), snp.real(x).ravel(), rtol=1e-4) - np.testing.assert_allclose(imag_block.ravel(), snp.imag(x).ravel(), rtol=1e-4) + real_block = snp.blockarray((x_s[0][0], x_s[1][0])) + imag_block = snp.blockarray((x_s[0][1], x_s[1][1])) + snp.testing.assert_allclose(real_block, snp.real(x), rtol=1e-4) + snp.testing.assert_allclose(imag_block, snp.imag(x), rtol=1e-4) x_j = solver._join_real_imag(x_s) - assert x_j.shape == x.shape - np.testing.assert_allclose(x_j.ravel(), x.ravel(), rtol=1e-4) + snp.testing.assert_allclose(x_j, x, rtol=1e-4) def test_bisect():