Skip to content

Commit

Permalink
Add tests for operator mult/div by singleton arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Oct 23, 2023
1 parent 08dbf5c commit 13eafc2
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 16 deletions.
19 changes: 11 additions & 8 deletions scico/test/linop/test_linop.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from scico.random import randn
from scico.typing import PRNGKey

SCALARS = (2, 1e0, snp.array(1.0))


def adjoint_test(
A: linop.LinearOperator,
Expand Down Expand Up @@ -67,8 +69,7 @@ def __init__(self, dtype):

self.x, key = randn((N,), dtype=dtype, key=key)
self.y, key = randn((M,), dtype=dtype, key=key)
scalar, key = randn((1,), dtype=dtype, key=key)
self.scalar = scalar.item()

self.Ao = AbsMatOp(self.A)
self.Bo = AbsMatOp(self.B)
self.Co = AbsMatOp(self.C)
Expand Down Expand Up @@ -101,9 +102,10 @@ def test_binary_op(testobj, operator):


@pytest.mark.parametrize("operator", [op.mul, op.truediv])
def test_scalar_left(testobj, operator):
comp_mat = operator(testobj.A, testobj.scalar)
comp_op = operator(testobj.Ao, testobj.scalar)
@pytest.mark.parametrize("scalar", SCALARS)
def test_scalar_left(testobj, operator, scalar):
comp_mat = operator(testobj.A, scalar)
comp_op = operator(testobj.Ao, scalar)
assert isinstance(comp_op, linop.LinearOperator) # Ensure we don't get a Map
assert comp_op.input_dtype == testobj.A.dtype
np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5)
Expand All @@ -112,11 +114,12 @@ def test_scalar_left(testobj, operator):


@pytest.mark.parametrize("operator", [op.mul, op.truediv])
def test_scalar_right(testobj, operator):
@pytest.mark.parametrize("scalar", SCALARS)
def test_scalar_right(testobj, operator, scalar):
if operator == op.truediv:
pytest.xfail("scalar / LinearOperator is not supported")
comp_mat = operator(testobj.scalar, testobj.A)
comp_op = operator(testobj.scalar, testobj.Ao)
comp_mat = operator(scalar, testobj.A)
comp_op = operator(scalar, testobj.Ao)
assert comp_op.input_dtype == testobj.A.dtype
np.testing.assert_allclose(comp_mat @ testobj.x, comp_op @ testobj.x, rtol=5e-5)

Expand Down
18 changes: 10 additions & 8 deletions scico/test/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from scico.operator import Abs, Angle, Exp, Operator, operator_from_function
from scico.random import randn

SCALARS = (2, 1e0, snp.array(1.0))


class AbsOperator(Operator):
def _eval(self, x):
Expand Down Expand Up @@ -43,8 +45,6 @@ def __init__(self, dtype):

self.mat = randn(self.A.input_shape, dtype=dtype, key=key)
self.x, key = randn((N,), dtype=dtype, key=key)
scalar, key = randn((1,), dtype=dtype, key=key)
self.scalar = scalar.item() # jax array -> actual scalar

self.z, key = randn((2 * N,), dtype=dtype, key=key)

Expand Down Expand Up @@ -85,21 +85,23 @@ def test_binary_op_same(testobj, operator):


@pytest.mark.parametrize("operator", [op.mul, op.truediv])
def test_scalar_left(testobj, operator):
@pytest.mark.parametrize("scalar", SCALARS)
def test_scalar_left(testobj, operator, scalar):
x = testobj.x
comp_op = operator(testobj.A, testobj.scalar)
res = operator(testobj.A(x), testobj.scalar)
comp_op = operator(testobj.A, scalar)
res = operator(testobj.A(x), scalar)
assert comp_op.output_dtype == res.dtype
np.testing.assert_allclose(comp_op(x), res, rtol=5e-5)


@pytest.mark.parametrize("operator", [op.mul, op.truediv])
def test_scalar_right(testobj, operator):
@pytest.mark.parametrize("scalar", SCALARS)
def test_scalar_right(testobj, operator, scalar):
if operator == op.truediv:
pytest.xfail("scalar / Operator is not supported")
x = testobj.x
comp_op = operator(testobj.scalar, testobj.A)
res = operator(testobj.scalar, testobj.A(x))
comp_op = operator(scalar, testobj.A)
res = operator(scalar, testobj.A(x))
assert comp_op.output_dtype == res.dtype
np.testing.assert_allclose(comp_op(x), res, rtol=5e-5)

Expand Down

0 comments on commit 13eafc2

Please sign in to comment.