Skip to content

Commit

Permalink
Implement var and std using a numerically stable parallel algorithm
Browse files Browse the repository at this point in the history
Passes cubed/tests/test_array_api.py::test_var[False-0.0-0]

Test example of poorly conditioned case for var
  • Loading branch information
tomwhite committed Oct 27, 2024
1 parent c332431 commit 32164db
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 11 deletions.
84 changes: 78 additions & 6 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
int64,
uint64,
)
from cubed.array_api.elementwise_functions import sqrt, square, subtract
from cubed.array_api.elementwise_functions import sqrt
from cubed.backend_array_api import namespace as nxp
from cubed.core import reduction

Expand Down Expand Up @@ -149,6 +149,18 @@ def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
)


def std(x, /, *, axis=None, correction=0.0, keepdims=False, split_every=None):
    """Return the standard deviation of ``x``: the square root of its variance.

    ``axis``, ``correction``, ``keepdims`` and ``split_every`` are forwarded
    unchanged to ``var``; the result of that call is passed through ``sqrt``.
    """
    variance = var(
        x,
        axis=axis,
        correction=correction,
        keepdims=keepdims,
        split_every=split_every,
    )
    return sqrt(variance)


def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
# boolean is allowed by numpy
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
Expand Down Expand Up @@ -178,10 +190,70 @@ def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
)


def var(x, /, *, axis=None, correction=0.0, keepdims=False):
mu = mean(x, axis=axis, keepdims=True)
return mean(square(subtract(x, mu)), axis=axis, keepdims=keepdims)
def var(x, /, *, axis=None, correction=0.0, keepdims=False, split_every=None):
    """Return the variance of ``x`` computed with a parallel reduction.

    The reduction carries a structured intermediate with three fields, ``n``
    (int64), ``mu`` (float64) and ``M2`` (float64), which are produced by
    ``_var_func``, merged by ``_var_combine`` and turned into the final value
    by ``_var_aggregate``. The scheme follows
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

    The result is requested with the same dtype as ``x``. ``correction`` is
    handed to the reduction functions through ``extra_func_kwargs``.

    Raises:
        TypeError: if ``x.dtype`` is not one of the real floating-point dtypes.
    """
    if x.dtype not in _real_floating_dtypes:
        raise TypeError("Only real floating-point dtypes are allowed in var")

    # The same field description is used both as the reduction's intermediate
    # dtype and as the "dtype" keyword seen by the per-chunk functions.
    partial_dtype = [("n", nxp.int64), ("mu", nxp.float64), ("M2", nxp.float64)]
    return reduction(
        x,
        _var_func,
        combine_func=_var_combine,
        aggregate_func=_var_aggregate,
        axis=axis,
        intermediate_dtype=partial_dtype,
        dtype=x.dtype,
        keepdims=keepdims,
        split_every=split_every,
        extra_func_kwargs=dict(dtype=partial_dtype, correction=correction),
    )


def _var_func(a, correction=None, **kwargs):
    """Reduce one chunk ``a`` to its partial statistics ``n``, ``mu`` and ``M2``.

    ``kwargs`` must contain ``dtype`` as a sequence of ``(field, dtype)`` pairs;
    it is removed from ``kwargs`` and the remaining keyword arguments are
    passed on to ``_numel``, ``nxp.mean`` and ``nxp.sum``. ``correction`` is
    accepted but not used at this stage.

    Returns a dict with keys ``"n"`` (result of ``_numel``), ``"mu"`` (mean of
    the chunk) and ``"M2"`` (sum of squared differences from that mean).
    """
    # Turn the (field, dtype) pairs into a lookup keyed by field name.
    field_types = dict(kwargs.pop("dtype"))

    count = _numel(a, dtype=field_types["n"], **kwargs)
    chunk_mean = nxp.mean(a, dtype=field_types["mu"], **kwargs)
    squared_dev = nxp.square(a - chunk_mean)
    sum_squared_dev = nxp.sum(squared_dev, dtype=field_types["M2"], **kwargs)

    return {"n": count, "mu": chunk_mean, "M2": sum_squared_dev}


def _var_combine(a, axis=None, correction=None, **kwargs):
    """Merge two partial ``(n, mu, M2)`` results stacked along ``axis[0]``.

    ``a`` is indexed by the field names ``"n"``, ``"mu"`` and ``"M2"``. If there
    is only one entry along the combining axis, ``a`` is returned untouched.
    Otherwise exactly two entries are expected and they are merged into one,
    keeping the combining axis with length 1. ``correction`` and ``kwargs``
    are accepted but not used.

    Raises:
        ValueError: if the combining axis holds neither one nor two entries.
    """
    # _var_combine is called by _partial_reduce which concatenates along the first axis
    concat_axis = axis[0]
    entries = a["n"].shape[concat_axis]
    if entries == 1:  # nothing to combine
        return a
    if entries != 2:
        raise ValueError(f"Expected two elements in {concat_axis} axis to combine")

    def _pair(field):
        # Pull out the first and second partial result for one field.
        values = a[field]
        first = nxp.take(values, 0, axis=concat_axis)
        second = nxp.take(values, 1, axis=concat_axis)
        return first, second

    count_l, count_r = _pair("n")
    mean_l, mean_r = _pair("mu")
    m2_l, m2_r = _pair("M2")

    total = count_l + count_r
    mean_gap = mean_r - mean_l
    # Count-weighted mean of the two partial means.
    merged_mean = (count_l * mean_l + count_r * mean_r) / total
    # Sum of both M2 terms plus the between-partition contribution.
    merged_m2 = m2_l + m2_r + mean_gap**2 * count_l * count_r / total

    # Restore the combining axis (length 1) that take() removed.
    return {
        "n": nxp.expand_dims(total, axis=concat_axis),
        "mu": nxp.expand_dims(merged_mean, axis=concat_axis),
        "M2": nxp.expand_dims(merged_m2, axis=concat_axis),
    }


def std(x, /, *, axis=None, correction=0.0, keepdims=False):
return sqrt(var(x, axis=axis, correction=correction, keepdims=keepdims))
def _var_aggregate(a, correction=None, **kwargs):
    """Turn the final partial statistics into the variance ``M2 / (n - correction)``.

    ``a`` is indexed by the field names ``"n"`` and ``"M2"``. Wherever
    ``n - correction`` is less than or equal to zero the result is NaN, as the
    array API specification of ``var`` requires; a plain division would give
    ``inf`` or a negative "variance" there. ``kwargs`` is accepted but not used.
    """
    dof = a["n"] - correction
    defined = dof > 0
    # Divide by a harmless placeholder where the variance is undefined so no
    # divide-by-zero warning is raised; those entries are replaced by NaN below.
    safe_dof = nxp.where(defined, dof, 1)
    return nxp.where(defined, nxp.divide(a["M2"], safe_dof), nxp.nan)
4 changes: 3 additions & 1 deletion cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,9 @@ def reduction(

# aggregate final chunks
if aggregate_func is not None:
result = map_blocks(aggregate_func, result, dtype=dtype)
result = map_blocks(
partial(aggregate_func, **(extra_func_kwargs or {})), result, dtype=dtype
)

if not keepdims:
axis_to_squeeze = tuple(i for i in axis if result.shape[i] == 1)
Expand Down
14 changes: 11 additions & 3 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ def test_sum_axis_0(spec, executor):
assert_array_equal(b.compute(executor=executor), np.array([12, 15, 18]))


@pytest.mark.parametrize("axis", [None, 0, 1])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
@pytest.mark.parametrize("correction", [0.0, 1.0])
@pytest.mark.parametrize("keepdims", [False, True])
def test_var(spec, axis, correction, keepdims):
Expand All @@ -736,14 +736,14 @@ def test_var(spec, axis, correction, keepdims):
)
b = xp.var(a, axis=axis, correction=correction, keepdims=keepdims)
assert_array_equal(
b.compute(),
b.compute(optimize_graph=False),
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).var(
axis=axis, ddof=correction, keepdims=keepdims
),
)


@pytest.mark.parametrize("axis", [None, 0, 1])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
@pytest.mark.parametrize("correction", [0.0, 1.0])
@pytest.mark.parametrize("keepdims", [False, True])
def test_std(spec, axis, correction, keepdims):
Expand All @@ -759,6 +759,14 @@ def test_std(spec, axis, correction, keepdims):
)


def test_var__poorly_conditioned(spec):
    """Check ``xp.var`` against NumPy on values that share a large offset (1e9).

    The four values are split into chunks of two, and the computed result is
    compared with ``ndarray.var`` using ``assert_array_equal``.
    """
    # from https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Example
    samples = np.array([4.0, 7.0, 13.0, 16.0]) + 1e9
    chunked = xp.asarray(samples, chunks=2, spec=spec)
    actual = xp.var(chunked, axis=0).compute()
    expected = samples.var(axis=0)
    assert_array_equal(actual, expected)


# Utility functions


Expand Down
2 changes: 1 addition & 1 deletion cubed/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _mean_groupby_combine(a, axis, dummy_axis, dtype, keepdims):
return {"n": n, "total": total}


def _mean_groupby_aggregate(a):
def _mean_groupby_aggregate(a, **kwargs):
    """Return the mean as ``a["total"]`` divided by ``a["n"]``.

    Extra keyword arguments are accepted and ignored.
    """
    totals, counts = a["total"], a["n"]
    return nxp.divide(totals, counts)


Expand Down

0 comments on commit 32164db

Please sign in to comment.