Skip to content

Commit

Permalink
Allow bool in sum and prod
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Aug 12, 2024
1 parent 8c4ae55 commit 966181d
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math

from cubed.array_api.dtypes import (
_boolean_dtypes,
_numeric_dtypes,
_real_floating_dtypes,
_real_numeric_dtypes,
Expand Down Expand Up @@ -124,10 +125,13 @@ def min(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None)
def prod(
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
):
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in prod")
# boolean is allowed by numpy
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
raise TypeError("Only numeric or boolean dtypes are allowed in prod")
if dtype is None:
if x.dtype in _signed_integer_dtypes:
if x.dtype in _boolean_dtypes:
dtype = int64
elif x.dtype in _signed_integer_dtypes:
dtype = int64
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64
Expand All @@ -153,10 +157,13 @@ def prod(
def sum(
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
):
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in sum")
# boolean is allowed by numpy
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
raise TypeError("Only numeric or boolean dtypes are allowed in sum")
if dtype is None:
if x.dtype in _signed_integer_dtypes:
if x.dtype in _boolean_dtypes:
dtype = int64
elif x.dtype in _signed_integer_dtypes:
dtype = int64
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64
Expand Down

0 comments on commit 966181d

Please sign in to comment.