diff --git a/cubed/array_api/array_object.py b/cubed/array_api/array_object.py index b3e8abe9..d7363241 100644 --- a/cubed/array_api/array_object.py +++ b/cubed/array_api/array_object.py @@ -399,12 +399,18 @@ def __int__(self, /): # Utility methods def _check_allowed_dtypes(self, other, dtype_category, op): - if self.dtype not in _dtype_categories[dtype_category]: + if ( + dtype_category != "all" + and self.dtype not in _dtype_categories[dtype_category] + ): raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}") if isinstance(other, (int, complex, float, bool)): other = self._promote_scalar(other) elif isinstance(other, CoreArray): - if other.dtype not in _dtype_categories[dtype_category]: + if ( + dtype_category != "all" + and other.dtype not in _dtype_categories[dtype_category] + ): raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}") else: return NotImplemented diff --git a/cubed/array_api/statistical_functions.py b/cubed/array_api/statistical_functions.py index 78ff6ae2..7ee6525e 100644 --- a/cubed/array_api/statistical_functions.py +++ b/cubed/array_api/statistical_functions.py @@ -1,6 +1,7 @@ import math from cubed.array_api.dtypes import ( + _boolean_dtypes, _numeric_dtypes, _real_floating_dtypes, _real_numeric_dtypes, @@ -124,10 +125,13 @@ def min(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None) def prod( x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None ): - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in prod") + # boolean is allowed by numpy + if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes: + raise TypeError("Only numeric or boolean dtypes are allowed in prod") if dtype is None: - if x.dtype in _signed_integer_dtypes: + if x.dtype in _boolean_dtypes: + dtype = int64 + elif x.dtype in _signed_integer_dtypes: dtype = int64 elif x.dtype in _unsigned_integer_dtypes: dtype = uint64 @@ -153,10 +157,13 @@ def prod( def sum( x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None ): - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in sum") + # boolean is allowed by numpy + if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes: + raise TypeError("Only numeric or boolean dtypes are allowed in sum") if dtype is None: - if x.dtype in _signed_integer_dtypes: + if x.dtype in _boolean_dtypes: + dtype = int64 + elif x.dtype in _signed_integer_dtypes: dtype = int64 elif x.dtype in _unsigned_integer_dtypes: dtype = uint64 diff --git a/cubed/tests/test_types.py b/cubed/tests/test_types.py new file mode 100644 index 00000000..b21eca6e --- /dev/null +++ b/cubed/tests/test_types.py @@ -0,0 +1,10 @@ +from numpy.testing import assert_array_equal + +import cubed.array_api as xp + + +# This is less strict than the spec, but is supported by implementations like NumPy +def test_prod_sum_bool(): + a = xp.ones((2,), dtype=xp.bool) + assert_array_equal(xp.prod(a).compute(), xp.asarray([1], dtype=xp.int64)) + assert_array_equal(xp.sum(a).compute(), xp.asarray([2], dtype=xp.int64))