Skip to content

Commit

Permalink
Merge branch 'main' into brendt/ray
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Oct 27, 2023
2 parents 65e36a6 + 3973e1b commit c19783d
Show file tree
Hide file tree
Showing 13 changed files with 124 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest_ubuntu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
mamba install -c astra-toolbox astra-toolbox
mamba install -c conda-forge pyyaml
pip install --upgrade --force-reinstall scipy>=1.6.0 # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version
pip install bm4d>=4.0.0
pip install bm3d>=4.0.0
pip install bm4d>=4.2.2
pip install "ray[tune]>=2.0.0"
pip install hyperopt
Expand Down
3 changes: 2 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ SCICO Release Notes
Version 0.0.5 (unreleased)
----------------------------

• No significant changes yet.
• Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.19.



Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ scipy>=1.6.0
tifffile
imageio>=2.17
matplotlib
jaxlib>=0.4.3,<=0.4.16
jax>=0.4.3,<=0.4.16
jaxlib>=0.4.3,<=0.4.19
jax>=0.4.3,<=0.4.19
flax>=0.6.1,<=0.6.9
svmbir>=0.3.3
pyabel>=0.9.0
13 changes: 6 additions & 7 deletions scico/linop/_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

import numpy as np

import jax
import jax.numpy as jnp
from jax.dtypes import result_type
from jax.typing import ArrayLike

import scico.numpy as snp

Expand Down Expand Up @@ -65,7 +65,7 @@ def wrapper(a, b):
class MatrixOperator(LinearOperator):
"""Linear operator implementing matrix multiplication."""

def __init__(self, A: snp.Array, input_cols: int = 0):
def __init__(self, A: ArrayLike, input_cols: int = 0):
"""
Args:
A: Dense array. The action of the created
Expand All @@ -80,17 +80,16 @@ def __init__(self, A: snp.Array, input_cols: int = 0):
self.A: snp.Array #: Dense array implementing this matrix

# if A is an ndarray, make sure it gets converted to a jax array
if isinstance(A, jnp.ndarray):
self.A = A
elif isinstance(A, np.ndarray):
self.A = jax.device_put(A) # TODO: ensure_on_device?
else:
if not snp.util.is_arraylike(A):
raise TypeError(f"Expected numpy or jax array, got {type(A)}.")
self.A = jnp.array(A)

# Can only do rank-2 arrays
if A.ndim != 2:
raise TypeError(f"Expected a two-dimensional array, got array of shape {A.shape}.")

self.__array__ = A.__array__ # enables jnp.array(H)

if input_cols == 0:
input_shape = A.shape[1]
output_shape = A.shape[0]
Expand Down
30 changes: 30 additions & 0 deletions scico/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,21 @@ def shape_to_size(shape: Union[Shape, BlockShape]) -> int:
return prod(shape)


def is_arraylike(x: Any) -> bool:
"""Check if input is of type :class:`jax.ArrayLike`.
`isinstance(x, jax.typing.ArrayLike)` does not work in Python < 3.10,
see https://jax.readthedocs.io/en/latest/jax.typing.html#jax-typing-best-practices.
Args:
x: Object to be tested.
Returns:
``True`` if `x` is an ArrayLike, ``False`` otherwise.
"""
return isinstance(x, (np.ndarray, jax.Array)) or np.isscalar(x)


def is_nested(x: Any) -> bool:
"""Check if input is a list/tuple containing at least one list/tuple.
Expand Down Expand Up @@ -320,3 +335,18 @@ def complex_dtype(dtype: DType) -> DType:
"""

return (snp.zeros(1, dtype) + 1j).dtype


def is_scalar_equiv(s: Any) -> bool:
"""Determine whether an object is a scalar or is scalar-equivalent.
Determine whether an object is a scalar or a singleton array.
Args:
s: Object to be tested.
Returns:
``True`` if the object is a scalar or a singleton array,
otherwise ``False``.
"""
return snp.isscalar(s) or (isinstance(s, jax.Array) and s.ndim == 0)
2 changes: 1 addition & 1 deletion scico/operator/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _wrap_mul_div_scalar(func: Callable) -> Callable:

@wraps(func)
def wrapper(a, b):
if np.isscalar(b) or isinstance(b, jax.core.Tracer):
if snp.util.is_scalar_equiv(b):
return func(a, b)

raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.")
Expand Down
10 changes: 5 additions & 5 deletions scico/optimize/_admmaux.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from scico.loss import SquaredL2Loss
from scico.numpy import Array, BlockArray
from scico.numpy.util import ensure_on_device, is_real_dtype
from scico.solver import ATADSolver, ConvATADSolver
from scico.solver import ConvATADSolver, MatrixATADSolver
from scico.solver import cg as scico_cg
from scico.solver import minimize

Expand Down Expand Up @@ -296,14 +296,14 @@ class MatrixSubproblemSolver(LinearSubproblemSolver):
\mb{u}^{(k)}_i) \;,
which is solved by factorization of the left hand side of the
equation, using :class:`.ATADSolver`.
equation, using :class:`.MatrixATADSolver`.
Attributes:
admm (:class:`.ADMM`): ADMM solver object to which the solver is
attached.
solve_kwargs (dict): Dictionary of arguments for solver
:class:`.ATADSolver` initialization.
:class:`.MatrixATADSolver` initialization.
"""

def __init__(self, check_solve: bool = False, solve_kwargs: Optional[dict[str, Any]] = None):
Expand All @@ -313,7 +313,7 @@ def __init__(self, check_solve: bool = False, solve_kwargs: Optional[dict[str, A
check_solve: If ``True``, compute solver accuracy after each
solve.
solve_kwargs: Dictionary of arguments for solver
:class:`.ATADSolver` initialization.
:class:`.MatrixATADSolver` initialization.
"""
self.check_solve = check_solve
default_solve_kwargs = {"cho_factor": False}
Expand Down Expand Up @@ -352,7 +352,7 @@ def internal_init(self, admm: soa.ADMM):
Csum = reduce(
lambda a, b: a + b, [rhoi * Ci.gram_op for rhoi, Ci in zip(admm.rho_list, admm.C_list)]
)
self.solver = ATADSolver(A, Csum, W, **self.solve_kwargs)
self.solver = MatrixATADSolver(A, Csum, W, **self.solve_kwargs)

def solve(self, x0: Array) -> Array:
"""Solve the ADMM step.
Expand Down
58 changes: 40 additions & 18 deletions scico/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

import jax
import jax.experimental.host_callback as hcb
import jax.numpy as jnp
import jax.scipy.linalg as jsl

import scico.numpy as snp
Expand Down Expand Up @@ -260,14 +261,13 @@ def fun(x0):

def minimize_scalar(
func: Callable,
bracket: Optional[Union[Sequence[float]]] = None,
bracket: Optional[Sequence[float]] = None,
bounds: Optional[Sequence[float]] = None,
args: Union[Tuple, Tuple[Any]] = (),
method: str = "brent",
tol: Optional[float] = None,
options: Optional[dict] = None,
) -> spopt.OptimizeResult:

"""Minimization of scalar function of one variable.
Wrapper around :func:`scipy.optimize.minimize_scalar`.
Expand Down Expand Up @@ -579,8 +579,8 @@ def golden(
return r


class ATADSolver:
r"""Solver for linear system involving a symmetric product plus a diagonal.
class MatrixATADSolver:
r"""Solver for linear system involving a symmetric product.
Solve a linear system of the form
Expand All @@ -596,12 +596,18 @@ class ATADSolver:
where :math:`A \in \mbb{R}^{M \times N}`,
:math:`W \in \mbb{R}^{M \times M}` and
:math:`D \in \mbb{R}^{N \times N}`. The solution is computed by
factorization of matrix :math:`A^T W A + D` and solution via Gaussian
elimination. If :math:`D` is diagonal and :math:`N < M` (i.e.
:math:`A W A^T` is smaller than :math:`A^T W A`), then
:math:`A W A^T + D` is factorized and the original problem is solved
via the Woodbury matrix identity
:math:`D \in \mbb{R}^{N \times N}`. :math:`A` must be an instance of
:class:`.MatrixOperator` or an array; :math:`D` must be an instance
of :class:`.MatrixOperator`, :class:`.Diagonal`, or an array, and
:math:`W`, if specified, must be an instance of :class:`.Diagonal`
or an array.
The solution is computed by factorization of matrix
:math:`A^T W A + D` and solution via Gaussian elimination. If
:math:`D` is diagonal and :math:`N < M` (i.e. :math:`A W A^T` is
smaller than :math:`A^T W A`), then :math:`A W A^T + D` is factorized
and the original problem is solved via the Woodbury matrix identity
.. math::
Expand Down Expand Up @@ -698,8 +704,12 @@ def __init__(
r"""
Args:
A: Matrix :math:`A`.
D: Matrix :math:`D`.
W: Matrix :math:`W`.
D: Matrix :math:`D`. If a 2D array or :class:`MatrixOperator`,
specifies the 2D matrix :math:`D`. If 1D array or
:class:`Diagonal`, specifies the diagonal elements
of :math:`D`.
W: Matrix :math:`W`. Specifies the diagonal elements of
:math:`W`. Defaults to an array with unit entries.
cho_factor: Flag indicating whether to use Cholesky
(``True``) or LU (``False``) factorization.
lower: Flag indicating whether lower (``True``) or upper
Expand All @@ -708,16 +718,28 @@ def __init__(
check_finite: Flag indicating whether the input array should
be checked for ``Inf`` and ``NaN`` values.
"""
if isinstance(A, MatrixOperator):
A = A.to_array()
if isinstance(D, MatrixOperator):
D = D.to_array()
elif isinstance(D, Diagonal):
A = jnp.array(A)

if isinstance(D, Diagonal):
D = D.diagonal
if not D.ndim == 1:
raise ValueError("If Diagonal, D should have a 1D diagonal.")
else:
D = jnp.array(D)
if not D.ndim in [1, 2]:
raise ValueError("If array or MatrixOperator, D should be 1D or 2D.")

if W is None:
W = snp.ones(A.shape[0], dtype=A.dtype)
elif isinstance(W, Diagonal):
W = W.diagonal
if not W.ndim == 1:
raise ValueError("If Diagonal, W should have a 1D diagonal.")
elif not isinstance(W, Array):
raise TypeError(
f"Operator W is required to be None, a Diagonal, or an array; got a {type(W)}."
)

self.A = A
self.D = D
self.W = W
Expand Down Expand Up @@ -796,7 +818,7 @@ def accuracy(self, x: Array, b: Array) -> float:


class ConvATADSolver:
r"""Solver for sum of convolutions plus diagonal linear system.
r"""Solver for a linear system involving a sum of convolutions.
Solve a linear system of the form
Expand Down
19 changes: 11 additions & 8 deletions scico/test/linop/test_linop.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from scico.random import randn
from scico.typing import PRNGKey

SCALARS = (2, 1e0, snp.array(1.0))


def adjoint_test(
A: linop.LinearOperator,
Expand Down Expand Up @@ -67,8 +69,7 @@ def __init__(self, dtype):

self.x, key = randn((N,), dtype=dtype, key=key)
self.y, key = randn((M,), dtype=dtype, key=key)
scalar, key = randn((1,), dtype=dtype, key=key)
self.scalar = scalar.item()

self.Ao = AbsMatOp(self.A)
self.Bo = AbsMatOp(self.B)
self.Co = AbsMatOp(self.C)
Expand Down Expand Up @@ -101,9 +102,10 @@ def test_binary_op(testobj, operator):


@pytest.mark.parametrize("operator", [op.mul, op.truediv])
def test_scalar_left(testobj, operator):
comp_mat = operator(testobj.A, testobj.scalar)
comp_op = operator(testobj.Ao, testobj.scalar)
@pytest.mark.parametrize("scalar", SCALARS)
def test_scalar_left(testobj, operator, scalar):
comp_mat = operator(testobj.A, scalar)
comp_op = operator(testobj.Ao, scalar)
assert isinstance(comp_op, linop.LinearOperator) # Ensure we don't get a Map
assert comp_op.input_dtype == testobj.A.dtype
np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5)
Expand All @@ -112,11 +114,12 @@ def test_scalar_left(testobj, operator):


@pytest.mark.parametrize("operator", [op.mul, op.truediv])
def test_scalar_right(testobj, operator):
@pytest.mark.parametrize("scalar", SCALARS)
def test_scalar_right(testobj, operator, scalar):
if operator == op.truediv:
pytest.xfail("scalar / LinearOperator is not supported")
comp_mat = operator(testobj.scalar, testobj.A)
comp_op = operator(testobj.scalar, testobj.Ao)
comp_mat = operator(scalar, testobj.A)
comp_op = operator(scalar, testobj.Ao)
assert comp_op.input_dtype == testobj.A.dtype
np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5)

Expand Down
6 changes: 4 additions & 2 deletions scico/test/linop/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def setup_method(self, method):
@pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
@pytest.mark.parametrize("matrix_shape", [(3, 3), (3, 4)])
def test_eval(self, matrix_shape, input_dtype, input_cols):

A, key = randn(matrix_shape, dtype=input_dtype, key=self.key)
Ao = MatrixOperator(A, input_cols=input_cols)

Expand All @@ -38,7 +37,6 @@ def test_eval(self, matrix_shape, input_dtype, input_cols):
@pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
@pytest.mark.parametrize("matrix_shape", [(3, 3), (3, 4)])
def test_adjoint(self, matrix_shape, input_dtype, input_cols):

A, key = randn(matrix_shape, dtype=input_dtype, key=self.key)
Ao = MatrixOperator(A, input_cols=input_cols)

Expand Down Expand Up @@ -262,6 +260,10 @@ def test_to_array(self):
assert isinstance(A_array, np.ndarray)
np.testing.assert_allclose(A_array, A)

A_array = jnp.array(Ao)
assert isinstance(A_array, jax.Array)
np.testing.assert_allclose(A_array, A)

@pytest.mark.parametrize("ord", ["fro", 2])
@pytest.mark.parametrize("axis", [None, 0, 1])
@pytest.mark.parametrize("keepdims", [True, False])
Expand Down
9 changes: 9 additions & 0 deletions scico/test/test_numpy_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
is_complex_dtype,
is_nested,
is_real_dtype,
is_scalar_equiv,
no_nan_divide,
parse_axes,
real_dtype,
Expand Down Expand Up @@ -192,3 +193,11 @@ def test_broadcast_nested_shapes():
(1, 2, 3),
(1, 7, 4, 3),
)


def test_is_scalar_equiv():
assert is_scalar_equiv(1e0)
assert is_scalar_equiv(snp.array(1e0))
assert is_scalar_equiv(snp.sum(snp.zeros(1)))
assert not is_scalar_equiv(snp.array([1e0]))
assert not is_scalar_equiv(snp.array([1e0, 2e0]))
Loading

0 comments on commit c19783d

Please sign in to comment.