Skip to content

Commit

Permalink
Resolve #445 (#459)
Browse files Browse the repository at this point in the history
* Resolve #442

* Add utility function for checking if an object is a scalar of an array of unit size

* Add tests for operator mult/div by singleton arrays

* Modify conditional for scalar equivalence and corresponding test function

* Typo fix

* Simplify conditional

* Add an assertion
  • Loading branch information
bwohlberg authored Oct 24, 2023
1 parent b8c8bab commit ea2aaf0
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest_ubuntu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
mamba install -c astra-toolbox astra-toolbox
mamba install -c conda-forge pyyaml
pip install --upgrade --force-reinstall scipy>=1.6.0 # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version
pip install bm4d>=4.0.0
pip install bm3d>=4.0.0
pip install bm4d>=4.2.2
pip install "ray[tune]>=2.0.0"
pip install hyperopt
Expand Down
15 changes: 15 additions & 0 deletions scico/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,18 @@ def complex_dtype(dtype: DType) -> DType:
"""

return (snp.zeros(1, dtype) + 1j).dtype


def is_scalar_equiv(s: Any) -> bool:
"""Determine whether an object is a scalar or is scalar-equivalent.
Determine whether an object is a scalar or a singleton array.
Args:
s: Object to be tested.
Returns:
``True`` if the object is a scalar or a singleton array,
otherwise ``False``.
"""
return snp.isscalar(s) or (isinstance(s, jax.Array) and s.ndim == 0)
2 changes: 1 addition & 1 deletion scico/operator/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _wrap_mul_div_scalar(func: Callable) -> Callable:

@wraps(func)
def wrapper(a, b):
if np.isscalar(b) or isinstance(b, jax.core.Tracer):
if snp.util.is_scalar_equiv(b):
return func(a, b)

raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.")
Expand Down
19 changes: 11 additions & 8 deletions scico/test/linop/test_linop.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from scico.random import randn
from scico.typing import PRNGKey

SCALARS = (2, 1e0, snp.array(1.0))


def adjoint_test(
A: linop.LinearOperator,
Expand Down Expand Up @@ -67,8 +69,7 @@ def __init__(self, dtype):

self.x, key = randn((N,), dtype=dtype, key=key)
self.y, key = randn((M,), dtype=dtype, key=key)
scalar, key = randn((1,), dtype=dtype, key=key)
self.scalar = scalar.item()

self.Ao = AbsMatOp(self.A)
self.Bo = AbsMatOp(self.B)
self.Co = AbsMatOp(self.C)
Expand Down Expand Up @@ -101,9 +102,10 @@ def test_binary_op(testobj, operator):


@pytest.mark.parametrize("operator", [op.mul, op.truediv])
def test_scalar_left(testobj, operator):
comp_mat = operator(testobj.A, testobj.scalar)
comp_op = operator(testobj.Ao, testobj.scalar)
@pytest.mark.parametrize("scalar", SCALARS)
def test_scalar_left(testobj, operator, scalar):
comp_mat = operator(testobj.A, scalar)
comp_op = operator(testobj.Ao, scalar)
assert isinstance(comp_op, linop.LinearOperator) # Ensure we don't get a Map
assert comp_op.input_dtype == testobj.A.dtype
np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5)
Expand All @@ -112,11 +114,12 @@ def test_scalar_left(testobj, operator):


@pytest.mark.parametrize("operator", [op.mul, op.truediv])
def test_scalar_right(testobj, operator):
@pytest.mark.parametrize("scalar", SCALARS)
def test_scalar_right(testobj, operator, scalar):
if operator == op.truediv:
pytest.xfail("scalar / LinearOperator is not supported")
comp_mat = operator(testobj.scalar, testobj.A)
comp_op = operator(testobj.scalar, testobj.Ao)
comp_mat = operator(scalar, testobj.A)
comp_op = operator(scalar, testobj.Ao)
assert comp_op.input_dtype == testobj.A.dtype
np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5)

Expand Down
9 changes: 9 additions & 0 deletions scico/test/test_numpy_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
is_complex_dtype,
is_nested,
is_real_dtype,
is_scalar_equiv,
no_nan_divide,
parse_axes,
real_dtype,
Expand Down Expand Up @@ -192,3 +193,11 @@ def test_broadcast_nested_shapes():
(1, 2, 3),
(1, 7, 4, 3),
)


def test_is_scalar_equiv():
assert is_scalar_equiv(1e0)
assert is_scalar_equiv(snp.array(1e0))
assert is_scalar_equiv(snp.sum(snp.zeros(1)))
assert not is_scalar_equiv(snp.array([1e0]))
assert not is_scalar_equiv(snp.array([1e0, 2e0]))
18 changes: 10 additions & 8 deletions scico/test/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from scico.operator import Abs, Angle, Exp, Operator, operator_from_function
from scico.random import randn

SCALARS = (2, 1e0, snp.array(1.0))


class AbsOperator(Operator):
def _eval(self, x):
Expand Down Expand Up @@ -43,8 +45,6 @@ def __init__(self, dtype):

self.mat = randn(self.A.input_shape, dtype=dtype, key=key)
self.x, key = randn((N,), dtype=dtype, key=key)
scalar, key = randn((1,), dtype=dtype, key=key)
self.scalar = scalar.item() # jax array -> actual scalar

self.z, key = randn((2 * N,), dtype=dtype, key=key)

Expand Down Expand Up @@ -85,21 +85,23 @@ def test_binary_op_same(testobj, operator):


@pytest.mark.parametrize("operator", [op.mul, op.truediv])
def test_scalar_left(testobj, operator):
@pytest.mark.parametrize("scalar", SCALARS)
def test_scalar_left(testobj, operator, scalar):
x = testobj.x
comp_op = operator(testobj.A, testobj.scalar)
res = operator(testobj.A(x), testobj.scalar)
comp_op = operator(testobj.A, scalar)
res = operator(testobj.A(x), scalar)
assert comp_op.output_dtype == res.dtype
np.testing.assert_allclose(comp_op(x), res, rtol=5e-5)


@pytest.mark.parametrize("operator", [op.mul, op.truediv])
def test_scalar_right(testobj, operator):
@pytest.mark.parametrize("scalar", SCALARS)
def test_scalar_right(testobj, operator, scalar):
if operator == op.truediv:
pytest.xfail("scalar / Operator is not supported")
x = testobj.x
comp_op = operator(testobj.scalar, testobj.A)
res = operator(testobj.scalar, testobj.A(x))
comp_op = operator(scalar, testobj.A)
res = operator(scalar, testobj.A(x))
assert comp_op.output_dtype == res.dtype
np.testing.assert_allclose(comp_op(x), res, rtol=5e-5)

Expand Down

0 comments on commit ea2aaf0

Please sign in to comment.