From 82590fbf9ae2ab03df529659950194a265b25e8e Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 8 Sep 2023 22:36:58 -0600 Subject: [PATCH] Resolve #442 --- scico/operator/_operator.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/scico/operator/_operator.py b/scico/operator/_operator.py index 2fc01fe7d..a99de4231 100644 --- a/scico/operator/_operator.py +++ b/scico/operator/_operator.py @@ -20,6 +20,8 @@ import jax.numpy as jnp from jax.dtypes import result_type +import jaxlib + import scico.numpy as snp from scico.numpy import Array, BlockArray from scico.numpy.util import is_nested, shape_to_size @@ -47,7 +49,11 @@ def _wrap_mul_div_scalar(func: Callable) -> Callable: @wraps(func) def wrapper(a, b): - if np.isscalar(b) or isinstance(b, jax.core.Tracer): + if ( + snp.isscalar(b) + or isinstance(b, jax.core.Tracer) + or (isinstance(b, jaxlib.xla_extension.ArrayImpl) and b.size == 1) + ): return func(a, b) raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.")