Implement var and std using a numerically stable parallel algorithm (#596)

* Implementation of `std` and `var`.

Fixes #29. Here, I use existing cubed operations to implement `var` and `std`. Please let me know if I should reimplement the primitives as pure reductions.

List order follows import order.

Add correction

* Parameterize tests

* Add test for correction (failing)

* Implement `var` and `std` using a numerically stable parallel algorithm

Passes cubed/tests/test_array_api.py::test_var[False-0.0-0]

Test example of poorly conditioned case for var

* Update status page

* Exclude std and var from JAX tests

---------

Co-authored-by: Alex Merose <[email protected]>
tomwhite and alxmrs authored Oct 28, 2024
1 parent a0aa47f commit bc44d38
Showing 8 changed files with 133 additions and 9 deletions.
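
For context, the "numerically stable parallel algorithm" referenced in the commit message is the pairwise scheme described on the Wikipedia page linked in the new `statistical_functions.py`: each chunk is reduced to a count `n`, a mean `mu`, and a sum of squared deviations from the mean `M2`, and pairs of partial results are merged without ever forming the ill-conditioned difference `E[x**2] - E[x]**2`. A minimal NumPy sketch of the idea (illustrative names, not cubed code):

```python
import numpy as np

def partial_stats(x):
    # Reduce one chunk to (count, mean, sum of squared deviations from the mean).
    n = x.size
    mu = x.mean()
    M2 = np.sum((x - mu) ** 2)
    return n, mu, M2

def combine(a, b):
    # Merge two partial results without touching the raw data again
    # (the pairwise update from the Wikipedia article).
    n_a, mu_a, M2_a = a
    n_b, mu_b, M2_b = b
    n = n_a + n_b
    delta = mu_b - mu_a
    mu = (n_a * mu_a + n_b * mu_b) / n
    M2 = M2_a + M2_b + delta**2 * n_a * n_b / n
    return n, mu, M2

x = np.arange(10.0)
n, mu, M2 = combine(partial_stats(x[:6]), partial_stats(x[6:]))
print(M2 / n, x.var())  # both 8.25; with correction c, divide by n - c instead
```

The cubed implementation below stores these same three quantities in a structured intermediate dtype and applies exactly this update in `_var_combine`.
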
2 changes: 1 addition & 1 deletion .github/workflows/jax-tests.yml
@@ -44,7 +44,7 @@ jobs:
- name: Run tests
run: |
# exclude tests that rely on structured types since JAX doesn't support these
-pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby and not object_dtype"
+pytest -k "not argmax and not argmin and not mean and not std and not var and not apply_reduction and not broadcast_trick and not groupby and not object_dtype"
env:
CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy
JAX_ENABLE_X64: True
4 changes: 2 additions & 2 deletions api_status.md
@@ -84,9 +84,9 @@ This table shows which parts of the [Array API](https://data-apis.org/array-
| | `mean` | :white_check_mark: | | |
| | `min` | :white_check_mark: | | |
| | `prod` | :white_check_mark: | | |
-| | `std` | :x: | | Like `mean`, [#29](https://github.com/cubed-dev/cubed/issues/29) |
+| | `std` | :white_check_mark: | | |
| | `sum` | :white_check_mark: | | |
-| | `var` | :x: | | Like `mean`, [#29](https://github.com/cubed-dev/cubed/issues/29) |
+| | `var` | :white_check_mark: | | |
| Utility Functions | `all` | :white_check_mark: | | |
| | `any` | :white_check_mark: | | |

4 changes: 2 additions & 2 deletions cubed/__init__.py
@@ -322,9 +322,9 @@

__all__ += ["argmax", "argmin", "where"]

-from .array_api.statistical_functions import max, mean, min, prod, sum
+from .array_api.statistical_functions import max, mean, min, prod, std, sum, var

-__all__ += ["max", "mean", "min", "prod", "sum"]
+__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"]

from .array_api.utility_functions import all, any

4 changes: 2 additions & 2 deletions cubed/array_api/__init__.py
@@ -264,9 +264,9 @@

__all__ += ["argmax", "argmin", "where"]

-from .statistical_functions import max, mean, min, prod, sum
+from .statistical_functions import max, mean, min, prod, sum, std, var

-__all__ += ["max", "mean", "min", "prod", "sum"]
+__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"]

from .utility_functions import all, any

82 changes: 82 additions & 0 deletions cubed/array_api/statistical_functions.py
@@ -14,6 +14,7 @@
int64,
uint64,
)
from cubed.array_api.elementwise_functions import sqrt
from cubed.backend_array_api import namespace as nxp
from cubed.core import reduction

@@ -148,6 +149,18 @@ def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
)


def std(x, /, *, axis=None, correction=0.0, keepdims=False, split_every=None):
return sqrt(
var(
x,
axis=axis,
correction=correction,
keepdims=keepdims,
split_every=split_every,
)
)


def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
# boolean is allowed by numpy
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
@@ -175,3 +188,72 @@ def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)


def var(
x,
/,
*,
axis=None,
correction=0.0,
keepdims=False,
split_every=None,
):
# This implementation follows https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

if x.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in var")
dtype = x.dtype
intermediate_dtype = [("n", nxp.int64), ("mu", nxp.float64), ("M2", nxp.float64)]
extra_func_kwargs = dict(dtype=intermediate_dtype, correction=correction)
return reduction(
x,
_var_func,
combine_func=_var_combine,
aggregate_func=_var_aggregate,
axis=axis,
intermediate_dtype=intermediate_dtype,
dtype=dtype,
keepdims=keepdims,
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)


def _var_func(a, correction=None, **kwargs):
dtype = dict(kwargs.pop("dtype"))
n = _numel(a, dtype=dtype["n"], **kwargs)
mu = nxp.mean(a, dtype=dtype["mu"], **kwargs)
M2 = nxp.sum(nxp.square(a - mu), dtype=dtype["M2"], **kwargs)
return {"n": n, "mu": mu, "M2": M2}


def _var_combine(a, axis=None, correction=None, **kwargs):
# _var_combine is called by _partial_reduce which concatenates along the first axis
axis = axis[0]
if a["n"].shape[axis] == 1: # nothing to combine
return a
if a["n"].shape[axis] != 2:
raise ValueError(f"Expected two elements in {axis} axis to combine")

n_a = nxp.take(a["n"], 0, axis=axis)
n_b = nxp.take(a["n"], 1, axis=axis)
mu_a = nxp.take(a["mu"], 0, axis=axis)
mu_b = nxp.take(a["mu"], 1, axis=axis)
M2_a = nxp.take(a["M2"], 0, axis=axis)
M2_b = nxp.take(a["M2"], 1, axis=axis)

n_ab = n_a + n_b
delta = mu_b - mu_a
mu_ab = (n_a * mu_a + n_b * mu_b) / n_ab
M2_ab = M2_a + M2_b + delta**2 * n_a * n_b / n_ab

n = nxp.expand_dims(n_ab, axis=axis)
mu = nxp.expand_dims(mu_ab, axis=axis)
M2 = nxp.expand_dims(M2_ab, axis=axis)

return {"n": n, "mu": mu, "M2": M2}


def _var_aggregate(a, correction=None, **kwargs):
return nxp.divide(a["M2"], a["n"] - correction)
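
With these functions in place, `std` and `var` behave like the other cubed reductions. A minimal usage sketch, mirroring the new tests (it assumes `cubed.array_api` is imported as `xp`, as in the test suite, and relies on the default spec and executor):

```python
import cubed.array_api as xp

a = xp.asarray(
    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2)
)
v = xp.var(a, axis=0, correction=1.0)  # sample variance (ddof=1) down each column
s = xp.std(a, keepdims=True)           # population standard deviation, shape (1, 1)
print(v.compute(), s.compute())
```
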
4 changes: 3 additions & 1 deletion cubed/core/ops.py
@@ -1193,7 +1193,9 @@ def reduction(

# aggregate final chunks
if aggregate_func is not None:
-result = map_blocks(aggregate_func, result, dtype=dtype)
+result = map_blocks(
+    partial(aggregate_func, **(extra_func_kwargs or {})), result, dtype=dtype
+)

if not keepdims:
axis_to_squeeze = tuple(i for i in axis if result.shape[i] == 1)
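
The `cubed/core/ops.py` change above is what routes `correction` to the final step: `reduction` now forwards `extra_func_kwargs` into the aggregate call as well, so aggregate functions are expected to accept (and may ignore) extra keyword arguments, which is also why `_mean_groupby_aggregate` in the groupby tests below gains `**kwargs`. A hedged sketch of the expected shape of such a callback (illustrative, not a cubed API definition):

```python
# A hypothetical aggregate callback for a reduction built like var():
# it receives the structured intermediate result plus whatever was supplied
# in extra_func_kwargs (here `correction`), and tolerates any other kwargs.
def example_aggregate(a, correction=0.0, **kwargs):
    return a["M2"] / (a["n"] - correction)
```
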
40 changes: 40 additions & 0 deletions cubed/tests/test_array_api.py
@@ -727,6 +727,46 @@ def test_sum_axis_0(spec, executor):
assert_array_equal(b.compute(executor=executor), np.array([12, 15, 18]))


@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
@pytest.mark.parametrize("correction", [0.0, 1.0])
@pytest.mark.parametrize("keepdims", [False, True])
def test_var(spec, axis, correction, keepdims):
a = xp.asarray(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec
)
b = xp.var(a, axis=axis, correction=correction, keepdims=keepdims)
assert_array_equal(
b.compute(optimize_graph=False),
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).var(
axis=axis, ddof=correction, keepdims=keepdims
),
)


@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
@pytest.mark.parametrize("correction", [0.0, 1.0])
@pytest.mark.parametrize("keepdims", [False, True])
def test_std(spec, axis, correction, keepdims):
a = xp.asarray(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec
)
b = xp.std(a, axis=axis, correction=correction, keepdims=keepdims)
assert_array_equal(
b.compute(),
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).std(
axis=axis, ddof=correction, keepdims=keepdims
),
)


def test_var__poorly_conditioned(spec):
# from https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Example
npa = np.array([4.0, 7.0, 13.0, 16.0]) + 1e9
a = xp.asarray(npa, chunks=2, spec=spec)
b = xp.var(a, axis=0)
assert_array_equal(b.compute(), npa.var(axis=0))


# Utility functions


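
The `test_var__poorly_conditioned` case above uses the example from the same Wikipedia article: shifting the data by 1e9 makes the textbook formula `E[x**2] - E[x]**2` lose every significant digit in float64, while mean-subtracting formulations (NumPy's `var`, or the M2-based combine used in this commit) are unaffected. A small illustration, not part of the commit:

```python
import numpy as np

x = np.array([4.0, 7.0, 13.0, 16.0]) + 1e9

naive = np.mean(x**2) - np.mean(x) ** 2  # catastrophic cancellation: nowhere near 22.5
stable = np.var(x)                       # subtracts the mean first: exactly 22.5
print(naive, stable)
```
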
2 changes: 1 addition & 1 deletion cubed/tests/test_groupby.py
@@ -62,7 +62,7 @@ def _mean_groupby_combine(a, axis, dummy_axis, dtype, keepdims):
return {"n": n, "total": total}


-def _mean_groupby_aggregate(a):
+def _mean_groupby_aggregate(a, **kwargs):
return nxp.divide(a["total"], a["n"])


