diff --git a/.github/workflows/pytest_ubuntu.yml b/.github/workflows/pytest_ubuntu.yml index b4ace2c86..db24830e5 100644 --- a/.github/workflows/pytest_ubuntu.yml +++ b/.github/workflows/pytest_ubuntu.yml @@ -65,7 +65,8 @@ jobs: mamba install -c astra-toolbox astra-toolbox mamba install -c conda-forge pyyaml pip install --upgrade --force-reinstall scipy>=1.6.0 # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version - pip install bm4d>=4.0.0 + pip install bm3d>=4.0.0 + pip install bm4d>=4.2.2 pip install "ray[tune]>=2.0.0" pip install hyperopt # Install package to be tested diff --git a/CHANGES.rst b/CHANGES.rst index ebc467893..a413f53cb 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,7 +6,8 @@ SCICO Release Notes Version 0.0.5 (unreleased) ---------------------------- -• No significant changes yet. +• Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``. +• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.19. diff --git a/data b/data index 0d9f1fef8..1f1e9f83b 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit 0d9f1fef8df6eebb98d154e1e6d1ab8357914a88 +Subproject commit 1f1e9f83bb52bf9a08115ab71d8bb32a05c4ff0c diff --git a/examples/examples_requirements.txt b/examples/examples_requirements.txt index afb79a816..60b77e7e7 100644 --- a/examples/examples_requirements.txt +++ b/examples/examples_requirements.txt @@ -3,3 +3,5 @@ colour_demosaicing xdesign>=0.5.5 ray[tune]>=2.0.0 hyperopt +bm3d>=4.0.0 +bm4d>=4.2.2 diff --git a/requirements.txt b/requirements.txt index e62fcb7cc..68ab1ebac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,10 +3,8 @@ scipy>=1.6.0 tifffile imageio>=2.17 matplotlib -jaxlib>=0.4.3,<=0.4.16 -jax>=0.4.3,<=0.4.16 +jaxlib>=0.4.3,<=0.4.19 +jax>=0.4.3,<=0.4.19 flax>=0.6.1,<=0.6.9 -bm3d>=4.0.0 -bm4d>=4.2.2 svmbir>=0.3.3 pyabel>=0.9.0 diff --git a/scico/linop/_matrix.py b/scico/linop/_matrix.py index 0f6f21daa..951c6957e 100644 --- a/scico/linop/_matrix.py +++ b/scico/linop/_matrix.py @@ -17,9 +17,9 @@ import numpy as np -import jax import jax.numpy as jnp from jax.dtypes import result_type +from jax.typing import ArrayLike import scico.numpy as snp @@ -65,7 +65,7 @@ def wrapper(a, b): class MatrixOperator(LinearOperator): """Linear operator implementing matrix multiplication.""" - def __init__(self, A: snp.Array, input_cols: int = 0): + def __init__(self, A: ArrayLike, input_cols: int = 0): """ Args: A: Dense array. The action of the created @@ -80,17 +80,16 @@ def __init__(self, A: snp.Array, input_cols: int = 0): self.A: snp.Array #: Dense array implementing this matrix # if A is an ndarray, make sure it gets converted to a jax array - if isinstance(A, jnp.ndarray): - self.A = A - elif isinstance(A, np.ndarray): - self.A = jax.device_put(A) # TODO: ensure_on_device? - else: + if not snp.util.is_arraylike(A): raise TypeError(f"Expected numpy or jax array, got {type(A)}.") + self.A = jnp.array(A) # Can only do rank-2 arrays if A.ndim != 2: raise TypeError(f"Expected a two-dimensional array, got array of shape {A.shape}.") + self.__array__ = A.__array__ # enables jnp.array(H) + if input_cols == 0: input_shape = A.shape[1] output_shape = A.shape[0] diff --git a/scico/numpy/util.py b/scico/numpy/util.py index 0e776038b..50fefdd4e 100644 --- a/scico/numpy/util.py +++ b/scico/numpy/util.py @@ -204,6 +204,21 @@ def shape_to_size(shape: Union[Shape, BlockShape]) -> int: return prod(shape) +def is_arraylike(x: Any) -> bool: + """Check if input is of type :class:`jax.ArrayLike`. + + `isinstance(x, jax.typing.ArrayLike)` does not work in Python < 3.10, + see https://jax.readthedocs.io/en/latest/jax.typing.html#jax-typing-best-practices. + + Args: + x: Object to be tested. + + Returns: + ``True`` if `x` is an ArrayLike, ``False`` otherwise. + """ + return isinstance(x, (np.ndarray, jax.Array)) or np.isscalar(x) + + def is_nested(x: Any) -> bool: """Check if input is a list/tuple containing at least one list/tuple. @@ -320,3 +335,18 @@ def complex_dtype(dtype: DType) -> DType: """ return (snp.zeros(1, dtype) + 1j).dtype + + +def is_scalar_equiv(s: Any) -> bool: + """Determine whether an object is a scalar or is scalar-equivalent. + + Determine whether an object is a scalar or a singleton array. + + Args: + s: Object to be tested. + + Returns: + ``True`` if the object is a scalar or a singleton array, + otherwise ``False``. + """ + return snp.isscalar(s) or (isinstance(s, jax.Array) and s.ndim == 0) diff --git a/scico/operator/_operator.py b/scico/operator/_operator.py index 2fc01fe7d..baef6ec91 100644 --- a/scico/operator/_operator.py +++ b/scico/operator/_operator.py @@ -47,7 +47,7 @@ def _wrap_mul_div_scalar(func: Callable) -> Callable: @wraps(func) def wrapper(a, b): - if np.isscalar(b) or isinstance(b, jax.core.Tracer): + if snp.util.is_scalar_equiv(b): return func(a, b) raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.") diff --git a/scico/optimize/_admmaux.py b/scico/optimize/_admmaux.py index 7a0ceb0b2..7a1c8710c 100644 --- a/scico/optimize/_admmaux.py +++ b/scico/optimize/_admmaux.py @@ -31,7 +31,7 @@ from scico.loss import SquaredL2Loss from scico.numpy import Array, BlockArray from scico.numpy.util import ensure_on_device, is_real_dtype -from scico.solver import ATADSolver, ConvATADSolver +from scico.solver import ConvATADSolver, MatrixATADSolver from scico.solver import cg as scico_cg from scico.solver import minimize @@ -296,14 +296,14 @@ class MatrixSubproblemSolver(LinearSubproblemSolver): \mb{u}^{(k)}_i) \;, which is solved by factorization of the left hand side of the - equation, using :class:`.ATADSolver`. + equation, using :class:`.MatrixATADSolver`. Attributes: admm (:class:`.ADMM`): ADMM solver object to which the solver is attached. solve_kwargs (dict): Dictionary of arguments for solver - :class:`.ATADSolver` initialization. + :class:`.MatrixATADSolver` initialization. """ def __init__(self, check_solve: bool = False, solve_kwargs: Optional[dict[str, Any]] = None): @@ -313,7 +313,7 @@ def __init__(self, check_solve: bool = False, solve_kwargs: Optional[dict[str, A check_solve: If ``True``, compute solver accuracy after each solve. solve_kwargs: Dictionary of arguments for solver - :class:`.ATADSolver` initialization. + :class:`.MatrixATADSolver` initialization. """ self.check_solve = check_solve default_solve_kwargs = {"cho_factor": False} @@ -352,7 +352,7 @@ def internal_init(self, admm: soa.ADMM): Csum = reduce( lambda a, b: a + b, [rhoi * Ci.gram_op for rhoi, Ci in zip(admm.rho_list, admm.C_list)] ) - self.solver = ATADSolver(A, Csum, W, **self.solve_kwargs) + self.solver = MatrixATADSolver(A, Csum, W, **self.solve_kwargs) def solve(self, x0: Array) -> Array: """Solve the ADMM step. diff --git a/scico/solver.py b/scico/solver.py index 5f0994246..f93cd710e 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -54,6 +54,7 @@ import jax import jax.experimental.host_callback as hcb +import jax.numpy as jnp import jax.scipy.linalg as jsl import scico.numpy as snp @@ -260,14 +261,13 @@ def fun(x0): def minimize_scalar( func: Callable, - bracket: Optional[Union[Sequence[float]]] = None, + bracket: Optional[Sequence[float]] = None, bounds: Optional[Sequence[float]] = None, args: Union[Tuple, Tuple[Any]] = (), method: str = "brent", tol: Optional[float] = None, options: Optional[dict] = None, ) -> spopt.OptimizeResult: - """Minimization of scalar function of one variable. Wrapper around :func:`scipy.optimize.minimize_scalar`. @@ -579,8 +579,8 @@ def golden( return r -class ATADSolver: - r"""Solver for linear system involving a symmetric product plus a diagonal. +class MatrixATADSolver: + r"""Solver for linear system involving a symmetric product. Solve a linear system of the form @@ -596,12 +596,18 @@ class ATADSolver: where :math:`A \in \mbb{R}^{M \times N}`, :math:`W \in \mbb{R}^{M \times M}` and - :math:`D \in \mbb{R}^{N \times N}`. The solution is computed by - factorization of matrix :math:`A^T W A + D` and solution via Gaussian - elimination. If :math:`D` is diagonal and :math:`N < M` (i.e. - :math:`A W A^T` is smaller than :math:`A^T W A`), then - :math:`A W A^T + D` is factorized and the original problem is solved - via the Woodbury matrix identity + :math:`D \in \mbb{R}^{N \times N}`. :math:`A` must be an instance of + :class:`.MatrixOperator` or an array; :math:`D` must be an instance + of :class:`.MatrixOperator`, :class:`.Diagonal`, or an array, and + :math:`W`, if specified, must be an instance of :class:`.Diagonal` + or an array. + + + The solution is computed by factorization of matrix + :math:`A^T W A + D` and solution via Gaussian elimination. If + :math:`D` is diagonal and :math:`N < M` (i.e. :math:`A W A^T` is + smaller than :math:`A^T W A`), then :math:`A W A^T + D` is factorized + and the original problem is solved via the Woodbury matrix identity .. math:: @@ -698,8 +704,12 @@ def __init__( r""" Args: A: Matrix :math:`A`. - D: Matrix :math:`D`. - W: Matrix :math:`W`. + D: Matrix :math:`D`. If a 2D array or :class:`MatrixOperator`, + specifies the 2D matrix :math:`D`. If 1D array or + :class:`Diagonal`, specifies the diagonal elements + of :math:`D`. + W: Matrix :math:`W`. Specifies the diagonal elements of + :math:`W`. Defaults to an array with unit entries. cho_factor: Flag indicating whether to use Cholesky (``True``) or LU (``False``) factorization. lower: Flag indicating whether lower (``True``) or upper @@ -708,16 +718,28 @@ def __init__( check_finite: Flag indicating whether the input array should be checked for ``Inf`` and ``NaN`` values. """ - if isinstance(A, MatrixOperator): - A = A.to_array() - if isinstance(D, MatrixOperator): - D = D.to_array() - elif isinstance(D, Diagonal): + A = jnp.array(A) + + if isinstance(D, Diagonal): D = D.diagonal + if not D.ndim == 1: + raise ValueError("If Diagonal, D should have a 1D diagonal.") + else: + D = jnp.array(D) + if not D.ndim in [1, 2]: + raise ValueError("If array or MatrixOperator, D should be 1D or 2D.") + if W is None: W = snp.ones(A.shape[0], dtype=A.dtype) elif isinstance(W, Diagonal): W = W.diagonal + if not W.ndim == 1: + raise ValueError("If Diagonal, W should have a 1D diagonal.") + elif not isinstance(W, Array): + raise TypeError( + f"Operator W is required to be None, a Diagonal, or an array; got a {type(W)}." + ) + self.A = A self.D = D self.W = W @@ -796,7 +818,7 @@ def accuracy(self, x: Array, b: Array) -> float: class ConvATADSolver: - r"""Solver for sum of convolutions plus diagonal linear system. + r"""Solver for a linear system involving a sum of convolutions. Solve a linear system of the form diff --git a/scico/test/linop/test_linop.py b/scico/test/linop/test_linop.py index aa7e21075..ebd0f2406 100644 --- a/scico/test/linop/test_linop.py +++ b/scico/test/linop/test_linop.py @@ -17,6 +17,8 @@ from scico.random import randn from scico.typing import PRNGKey +SCALARS = (2, 1e0, snp.array(1.0)) + def adjoint_test( A: linop.LinearOperator, @@ -67,8 +69,7 @@ def __init__(self, dtype): self.x, key = randn((N,), dtype=dtype, key=key) self.y, key = randn((M,), dtype=dtype, key=key) - scalar, key = randn((1,), dtype=dtype, key=key) - self.scalar = scalar.item() + self.Ao = AbsMatOp(self.A) self.Bo = AbsMatOp(self.B) self.Co = AbsMatOp(self.C) @@ -101,9 +102,10 @@ def test_binary_op(testobj, operator): @pytest.mark.parametrize("operator", [op.mul, op.truediv]) -def test_scalar_left(testobj, operator): - comp_mat = operator(testobj.A, testobj.scalar) - comp_op = operator(testobj.Ao, testobj.scalar) +@pytest.mark.parametrize("scalar", SCALARS) +def test_scalar_left(testobj, operator, scalar): + comp_mat = operator(testobj.A, scalar) + comp_op = operator(testobj.Ao, scalar) assert isinstance(comp_op, linop.LinearOperator) # Ensure we don't get a Map assert comp_op.input_dtype == testobj.A.dtype np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5) @@ -112,11 +114,12 @@ def test_scalar_left(testobj, operator): @pytest.mark.parametrize("operator", [op.mul, op.truediv]) -def test_scalar_right(testobj, operator): +@pytest.mark.parametrize("scalar", SCALARS) +def test_scalar_right(testobj, operator, scalar): if operator == op.truediv: pytest.xfail("scalar / LinearOperator is not supported") - comp_mat = operator(testobj.scalar, testobj.A) - comp_op = operator(testobj.scalar, testobj.Ao) + comp_mat = operator(scalar, testobj.A) + comp_op = operator(scalar, testobj.Ao) assert comp_op.input_dtype == testobj.A.dtype np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5) diff --git a/scico/test/linop/test_matrix.py b/scico/test/linop/test_matrix.py index 3f00b2c7a..178c1fce5 100644 --- a/scico/test/linop/test_matrix.py +++ b/scico/test/linop/test_matrix.py @@ -22,7 +22,6 @@ def setup_method(self, method): @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("matrix_shape", [(3, 3), (3, 4)]) def test_eval(self, matrix_shape, input_dtype, input_cols): - A, key = randn(matrix_shape, dtype=input_dtype, key=self.key) Ao = MatrixOperator(A, input_cols=input_cols) @@ -38,7 +37,6 @@ def test_eval(self, matrix_shape, input_dtype, input_cols): @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("matrix_shape", [(3, 3), (3, 4)]) def test_adjoint(self, matrix_shape, input_dtype, input_cols): - A, key = randn(matrix_shape, dtype=input_dtype, key=self.key) Ao = MatrixOperator(A, input_cols=input_cols) @@ -262,6 +260,10 @@ def test_to_array(self): assert isinstance(A_array, np.ndarray) np.testing.assert_allclose(A_array, A) + A_array = jnp.array(Ao) + assert isinstance(A_array, jax.Array) + np.testing.assert_allclose(A_array, A) + @pytest.mark.parametrize("ord", ["fro", 2]) @pytest.mark.parametrize("axis", [None, 0, 1]) @pytest.mark.parametrize("keepdims", [True, False]) diff --git a/scico/test/test_numpy_util.py b/scico/test/test_numpy_util.py index f274d71c4..be1c9f5ab 100644 --- a/scico/test/test_numpy_util.py +++ b/scico/test/test_numpy_util.py @@ -15,6 +15,7 @@ is_complex_dtype, is_nested, is_real_dtype, + is_scalar_equiv, no_nan_divide, parse_axes, real_dtype, @@ -192,3 +193,11 @@ def test_broadcast_nested_shapes(): (1, 2, 3), (1, 7, 4, 3), ) + + +def test_is_scalar_equiv(): + assert is_scalar_equiv(1e0) + assert is_scalar_equiv(snp.array(1e0)) + assert is_scalar_equiv(snp.sum(snp.zeros(1))) + assert not is_scalar_equiv(snp.array([1e0])) + assert not is_scalar_equiv(snp.array([1e0, 2e0])) diff --git a/scico/test/test_operator.py b/scico/test/test_operator.py index ea97a97b8..1c86063ec 100644 --- a/scico/test/test_operator.py +++ b/scico/test/test_operator.py @@ -15,6 +15,8 @@ from scico.operator import Abs, Angle, Exp, Operator, operator_from_function from scico.random import randn +SCALARS = (2, 1e0, snp.array(1.0)) + class AbsOperator(Operator): def _eval(self, x): @@ -43,8 +45,6 @@ def __init__(self, dtype): self.mat = randn(self.A.input_shape, dtype=dtype, key=key) self.x, key = randn((N,), dtype=dtype, key=key) - scalar, key = randn((1,), dtype=dtype, key=key) - self.scalar = scalar.item() # jax array -> actual scalar self.z, key = randn((2 * N,), dtype=dtype, key=key) @@ -85,21 +85,23 @@ def test_binary_op_same(testobj, operator): @pytest.mark.parametrize("operator", [op.mul, op.truediv]) -def test_scalar_left(testobj, operator): +@pytest.mark.parametrize("scalar", SCALARS) +def test_scalar_left(testobj, operator, scalar): x = testobj.x - comp_op = operator(testobj.A, testobj.scalar) - res = operator(testobj.A(x), testobj.scalar) + comp_op = operator(testobj.A, scalar) + res = operator(testobj.A(x), scalar) assert comp_op.output_dtype == res.dtype np.testing.assert_allclose(comp_op(x), res, rtol=5e-5) @pytest.mark.parametrize("operator", [op.mul, op.truediv]) -def test_scalar_right(testobj, operator): +@pytest.mark.parametrize("scalar", SCALARS) +def test_scalar_right(testobj, operator, scalar): if operator == op.truediv: pytest.xfail("scalar / Operator is not supported") x = testobj.x - comp_op = operator(testobj.scalar, testobj.A) - res = operator(testobj.scalar, testobj.A(x)) + comp_op = operator(scalar, testobj.A) + res = operator(scalar, testobj.A(x)) assert comp_op.output_dtype == res.dtype np.testing.assert_allclose(comp_op(x), res, rtol=5e-5) diff --git a/scico/test/test_solver.py b/scico/test/test_solver.py index d5b179b62..f220482df 100644 --- a/scico/test/test_solver.py +++ b/scico/test/test_solver.py @@ -319,7 +319,7 @@ def test_solve_atai(cho_factor, wide, weighted, alpha): D = alpha * snp.ones((A.shape[1],)) ATAD = A.T @ (Wa * A) + alpha * snp.identity(A.shape[1]) b = ATAD @ x0 - slv = solver.ATADSolver(A, D, W=W, cho_factor=cho_factor) + slv = solver.MatrixATADSolver(A, D, W=W, cho_factor=cho_factor) x1 = slv.solve(b) assert metric.rel_res(x0, x1) < 5e-5 @@ -338,7 +338,7 @@ def test_solve_aati(cho_factor, wide, alpha): D = alpha * snp.ones((A.shape[0],)) AATD = A @ A.T + alpha * snp.identity(A.shape[0]) b = AATD @ x0 - slv = solver.ATADSolver(A.T, D) + slv = solver.MatrixATADSolver(A.T, D) x1 = slv.solve(b) assert metric.rel_res(x0, x1) < 5e-5 @@ -365,7 +365,7 @@ def test_solve_atad(cho_factor, wide, vector): D = snp.abs(D) # only required for Cholesky, but improved accuracy for LU ATAD = A.T @ A + snp.diag(D) b = ATAD @ x0 - slv = solver.ATADSolver(A, D, cho_factor=cho_factor) + slv = solver.MatrixATADSolver(A, D, cho_factor=cho_factor) x1 = slv.solve(b) assert metric.rel_res(x0, x1) < 5e-5 assert slv.accuracy(x1, b) < 5e-5