Skip to content

Commit

Permalink
Modify conditional for scalar equivalence and corresponding test func…
Browse files Browse the repository at this point in the history
…tion
  • Loading branch information
bwohlberg committed Oct 23, 2023
1 parent 13eafc2 commit 7da14bf
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
4 changes: 1 addition & 3 deletions scico/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

import jax

import jaxlib

import scico.numpy as snp
from scico.typing import ArrayIndex, Axes, AxisIndex, BlockShape, DType, Shape

Expand Down Expand Up @@ -339,5 +337,5 @@ def is_scalar_equiv(s: Any) -> bool:
return (
snp.isscalar(s)
or isinstance(s, jax.core.Tracer)
or (isinstance(s, jaxlib.xla_extension.ArrayImpl) and s.size == 1)
or (isinstance(s, jax.Array) and s.ndim == 0)
)
2 changes: 1 addition & 1 deletion scico/test/test_numpy_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,5 +198,5 @@ def test_broadcast_nested_shapes():
def test_is_scalar_equiv():
assert is_scalar_equiv(1e0)
assert is_scalar_equiv(snp.array(1e0))
assert is_scalar_equiv(snp.array([1e0]))
assert not is_scalar_equiv(snp.array([1e0]))
assert not is_scalar_equiv(snp.array([1e0, 2e0]))

0 comments on commit 7da14bf

Please sign in to comment.