diff --git a/.github/workflows/pytest_ubuntu.yml b/.github/workflows/pytest_ubuntu.yml index 4c0364cdd..db24830e5 100644 --- a/.github/workflows/pytest_ubuntu.yml +++ b/.github/workflows/pytest_ubuntu.yml @@ -65,7 +65,7 @@ jobs: mamba install -c astra-toolbox astra-toolbox mamba install -c conda-forge pyyaml pip install --upgrade --force-reinstall scipy>=1.6.0 # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version - pip install bm4d>=4.0.0 + pip install bm3d>=4.0.0 pip install bm4d>=4.2.2 pip install "ray[tune]>=2.0.0" pip install hyperopt diff --git a/scico/numpy/util.py b/scico/numpy/util.py index a9e31a00d..50fefdd4e 100644 --- a/scico/numpy/util.py +++ b/scico/numpy/util.py @@ -335,3 +335,18 @@ def complex_dtype(dtype: DType) -> DType: """ return (snp.zeros(1, dtype) + 1j).dtype + + +def is_scalar_equiv(s: Any) -> bool: + """Determine whether an object is a scalar or is scalar-equivalent. + + Determine whether an object is a scalar or a singleton array. + + Args: + s: Object to be tested. + + Returns: + ``True`` if the object is a scalar or a singleton array, + otherwise ``False``. + """ + return snp.isscalar(s) or (isinstance(s, jax.Array) and s.ndim == 0) diff --git a/scico/operator/_operator.py b/scico/operator/_operator.py index 2fc01fe7d..baef6ec91 100644 --- a/scico/operator/_operator.py +++ b/scico/operator/_operator.py @@ -47,7 +47,7 @@ def _wrap_mul_div_scalar(func: Callable) -> Callable: @wraps(func) def wrapper(a, b): - if np.isscalar(b) or isinstance(b, jax.core.Tracer): + if snp.util.is_scalar_equiv(b): return func(a, b) raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.") diff --git a/scico/test/linop/test_linop.py b/scico/test/linop/test_linop.py index aa7e21075..ebd0f2406 100644 --- a/scico/test/linop/test_linop.py +++ b/scico/test/linop/test_linop.py @@ -17,6 +17,8 @@ from scico.random import randn from scico.typing import PRNGKey +SCALARS = (2, 1e0, snp.array(1.0)) + def adjoint_test( A: linop.LinearOperator, @@ -67,8 +69,7 @@ def __init__(self, dtype): self.x, key = randn((N,), dtype=dtype, key=key) self.y, key = randn((M,), dtype=dtype, key=key) - scalar, key = randn((1,), dtype=dtype, key=key) - self.scalar = scalar.item() + self.Ao = AbsMatOp(self.A) self.Bo = AbsMatOp(self.B) self.Co = AbsMatOp(self.C) @@ -101,9 +102,10 @@ def test_binary_op(testobj, operator): @pytest.mark.parametrize("operator", [op.mul, op.truediv]) -def test_scalar_left(testobj, operator): - comp_mat = operator(testobj.A, testobj.scalar) - comp_op = operator(testobj.Ao, testobj.scalar) +@pytest.mark.parametrize("scalar", SCALARS) +def test_scalar_left(testobj, operator, scalar): + comp_mat = operator(testobj.A, scalar) + comp_op = operator(testobj.Ao, scalar) assert isinstance(comp_op, linop.LinearOperator) # Ensure we don't get a Map assert comp_op.input_dtype == testobj.A.dtype np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5) @@ -112,11 +114,12 @@ def test_scalar_left(testobj, operator): @pytest.mark.parametrize("operator", [op.mul, op.truediv]) -def test_scalar_right(testobj, operator): +@pytest.mark.parametrize("scalar", SCALARS) +def test_scalar_right(testobj, operator, scalar): if operator == op.truediv: pytest.xfail("scalar / LinearOperator is not supported") - comp_mat = operator(testobj.scalar, testobj.A) - comp_op = operator(testobj.scalar, testobj.Ao) + comp_mat = operator(scalar, testobj.A) + comp_op = operator(scalar, testobj.Ao) assert comp_op.input_dtype == testobj.A.dtype np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5) diff --git a/scico/test/test_numpy_util.py b/scico/test/test_numpy_util.py index f274d71c4..be1c9f5ab 100644 --- a/scico/test/test_numpy_util.py +++ b/scico/test/test_numpy_util.py @@ -15,6 +15,7 @@ is_complex_dtype, is_nested, is_real_dtype, + is_scalar_equiv, no_nan_divide, parse_axes, real_dtype, @@ -192,3 +193,11 @@ def test_broadcast_nested_shapes(): (1, 2, 3), (1, 7, 4, 3), ) + + +def test_is_scalar_equiv(): + assert is_scalar_equiv(1e0) + assert is_scalar_equiv(snp.array(1e0)) + assert is_scalar_equiv(snp.sum(snp.zeros(1))) + assert not is_scalar_equiv(snp.array([1e0])) + assert not is_scalar_equiv(snp.array([1e0, 2e0])) diff --git a/scico/test/test_operator.py b/scico/test/test_operator.py index ea97a97b8..1c86063ec 100644 --- a/scico/test/test_operator.py +++ b/scico/test/test_operator.py @@ -15,6 +15,8 @@ from scico.operator import Abs, Angle, Exp, Operator, operator_from_function from scico.random import randn +SCALARS = (2, 1e0, snp.array(1.0)) + class AbsOperator(Operator): def _eval(self, x): @@ -43,8 +45,6 @@ def __init__(self, dtype): self.mat = randn(self.A.input_shape, dtype=dtype, key=key) self.x, key = randn((N,), dtype=dtype, key=key) - scalar, key = randn((1,), dtype=dtype, key=key) - self.scalar = scalar.item() # jax array -> actual scalar self.z, key = randn((2 * N,), dtype=dtype, key=key) @@ -85,21 +85,23 @@ def test_binary_op_same(testobj, operator): @pytest.mark.parametrize("operator", [op.mul, op.truediv]) -def test_scalar_left(testobj, operator): +@pytest.mark.parametrize("scalar", SCALARS) +def test_scalar_left(testobj, operator, scalar): x = testobj.x - comp_op = operator(testobj.A, testobj.scalar) - res = operator(testobj.A(x), testobj.scalar) + comp_op = operator(testobj.A, scalar) + res = operator(testobj.A(x), scalar) assert comp_op.output_dtype == res.dtype np.testing.assert_allclose(comp_op(x), res, rtol=5e-5) @pytest.mark.parametrize("operator", [op.mul, op.truediv]) -def test_scalar_right(testobj, operator): +@pytest.mark.parametrize("scalar", SCALARS) +def test_scalar_right(testobj, operator, scalar): if operator == op.truediv: pytest.xfail("scalar / Operator is not supported") x = testobj.x - comp_op = operator(testobj.scalar, testobj.A) - res = operator(testobj.scalar, testobj.A(x)) + comp_op = operator(scalar, testobj.A) + res = operator(scalar, testobj.A(x)) assert comp_op.output_dtype == res.dtype np.testing.assert_allclose(comp_op(x), res, rtol=5e-5)