Implement var and std using a numerically stable parallel algorithm #596

Merged · 6 commits · Oct 28, 2024
2 changes: 1 addition & 1 deletion .github/workflows/jax-tests.yml
@@ -44,7 +44,7 @@ jobs:
- name: Run tests
run: |
# exclude tests that rely on structured types since JAX doesn't support these
pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby and not object_dtype"
pytest -k "not argmax and not argmin and not mean and not std and not var and not apply_reduction and not broadcast_trick and not groupby and not object_dtype"
env:
CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy
JAX_ENABLE_X64: True
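
Side note on why `std` and `var` join this exclusion list: the new implementation (shown further down) carries its partial results in a structured dtype, and, as the comment in the workflow step already notes, JAX does not support structured types. A minimal NumPy-only sketch of that intermediate record, using the field names from the new `var` code:

```python
import numpy as np

# Per-chunk state for the parallel variance algorithm:
# element count n, running mean mu, and sum of squared deviations M2.
intermediate_dtype = np.dtype([("n", np.int64), ("mu", np.float64), ("M2", np.float64)])

state = np.zeros(3, dtype=intermediate_dtype)  # fine as a NumPy array
# jax.numpy arrays cannot hold a structured dtype like this, so the std/var
# tests are skipped when the backend array module is jax.numpy.
```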
4 changes: 2 additions & 2 deletions api_status.md
@@ -84,9 +84,9 @@ This table shows which parts of the [Array API](https://data-apis.org/array-
| | `mean` | :white_check_mark: | | |
| | `min` | :white_check_mark: | | |
| | `prod` | :white_check_mark: | | |
-| | `std` | :x: | | Like `mean`, [#29](https://github.com/cubed-dev/cubed/issues/29) |
+| | `std` | :white_check_mark: | | |
| | `sum` | :white_check_mark: | | |
-| | `var` | :x: | | Like `mean`, [#29](https://github.com/cubed-dev/cubed/issues/29) |
+| | `var` | :white_check_mark: | | |
| Utility Functions | `all` | :white_check_mark: | | |
| | `any` | :white_check_mark: | | |

4 changes: 2 additions & 2 deletions cubed/__init__.py
@@ -322,9 +322,9 @@

__all__ += ["argmax", "argmin", "where"]

-from .array_api.statistical_functions import max, mean, min, prod, sum
+from .array_api.statistical_functions import max, mean, min, prod, std, sum, var

-__all__ += ["max", "mean", "min", "prod", "sum"]
+__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"]

from .array_api.utility_functions import all, any

4 changes: 2 additions & 2 deletions cubed/array_api/__init__.py
@@ -264,9 +264,9 @@

__all__ += ["argmax", "argmin", "where"]

-from .statistical_functions import max, mean, min, prod, sum
+from .statistical_functions import max, mean, min, prod, sum, std, var

-__all__ += ["max", "mean", "min", "prod", "sum"]
+__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"]

from .utility_functions import all, any

82 changes: 82 additions & 0 deletions cubed/array_api/statistical_functions.py
@@ -14,6 +14,7 @@
int64,
uint64,
)
from cubed.array_api.elementwise_functions import sqrt
from cubed.backend_array_api import namespace as nxp
from cubed.core import reduction

@@ -148,6 +149,18 @@ def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
)


def std(x, /, *, axis=None, correction=0.0, keepdims=False, split_every=None):
return sqrt(
var(
x,
axis=axis,
correction=correction,
keepdims=keepdims,
split_every=split_every,
)
)


def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
# boolean is allowed by numpy
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
@@ -175,3 +188,72 @@ def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)


def var(
x,
/,
*,
axis=None,
correction=0.0,
keepdims=False,
split_every=None,
):
# This implementation follows https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

if x.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in var")
dtype = x.dtype
intermediate_dtype = [("n", nxp.int64), ("mu", nxp.float64), ("M2", nxp.float64)]
extra_func_kwargs = dict(dtype=intermediate_dtype, correction=correction)
return reduction(
x,
_var_func,
combine_func=_var_combine,
aggregate_func=_var_aggregate,
axis=axis,
intermediate_dtype=intermediate_dtype,
dtype=dtype,
keepdims=keepdims,
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)


def _var_func(a, correction=None, **kwargs):
dtype = dict(kwargs.pop("dtype"))
n = _numel(a, dtype=dtype["n"], **kwargs)
mu = nxp.mean(a, dtype=dtype["mu"], **kwargs)
M2 = nxp.sum(nxp.square(a - mu), dtype=dtype["M2"], **kwargs)
return {"n": n, "mu": mu, "M2": M2}


def _var_combine(a, axis=None, correction=None, **kwargs):
# _var_combine is called by _partial_reduce which concatenates along the first axis
axis = axis[0]
if a["n"].shape[axis] == 1: # nothing to combine
return a
if a["n"].shape[axis] != 2:
raise ValueError(f"Expected two elements in {axis} axis to combine")

n_a = nxp.take(a["n"], 0, axis=axis)
n_b = nxp.take(a["n"], 1, axis=axis)
mu_a = nxp.take(a["mu"], 0, axis=axis)
mu_b = nxp.take(a["mu"], 1, axis=axis)
M2_a = nxp.take(a["M2"], 0, axis=axis)
M2_b = nxp.take(a["M2"], 1, axis=axis)

n_ab = n_a + n_b
delta = mu_b - mu_a
mu_ab = (n_a * mu_a + n_b * mu_b) / n_ab
M2_ab = M2_a + M2_b + delta**2 * n_a * n_b / n_ab

n = nxp.expand_dims(n_ab, axis=axis)
mu = nxp.expand_dims(mu_ab, axis=axis)
M2 = nxp.expand_dims(M2_ab, axis=axis)

return {"n": n, "mu": mu, "M2": M2}


def _var_aggregate(a, correction=None, **kwargs):
return nxp.divide(a["M2"], a["n"] - correction)
4 changes: 3 additions & 1 deletion cubed/core/ops.py
@@ -1193,7 +1193,9 @@ def reduction(

# aggregate final chunks
if aggregate_func is not None:
-result = map_blocks(aggregate_func, result, dtype=dtype)
+result = map_blocks(
+    partial(aggregate_func, **(extra_func_kwargs or {})), result, dtype=dtype
+)

if not keepdims:
axis_to_squeeze = tuple(i for i in axis if result.shape[i] == 1)
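
The `reduction` change above is what lets `correction` reach the aggregate step: the extra kwargs captured at the `var` call are bound onto `aggregate_func` with `functools.partial` before `map_blocks` runs it, which is also why `_var_aggregate` and `_mean_groupby_aggregate` now accept keyword arguments. A stripped-down sketch of the mechanism (plain dicts and scalars here, not Cubed's real block arrays):

```python
from functools import partial


def _var_aggregate(a, correction=0.0, **kwargs):
    # Final step of the variance reduction: M2 / (n - correction).
    return a["M2"] / (a["n"] - correction)


extra_func_kwargs = {"correction": 1.0}

# reduction() binds the extra kwargs up front, so whatever later applies the
# aggregate function per block does not need to know about them.
aggregate = partial(_var_aggregate, **(extra_func_kwargs or {}))

print(aggregate({"n": 4, "M2": 90.0}))  # 30.0 (sample variance with ddof=1)
```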
40 changes: 40 additions & 0 deletions cubed/tests/test_array_api.py
@@ -727,6 +727,46 @@ def test_sum_axis_0(spec, executor):
assert_array_equal(b.compute(executor=executor), np.array([12, 15, 18]))


@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
@pytest.mark.parametrize("correction", [0.0, 1.0])
@pytest.mark.parametrize("keepdims", [False, True])
def test_var(spec, axis, correction, keepdims):
a = xp.asarray(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec
)
b = xp.var(a, axis=axis, correction=correction, keepdims=keepdims)
assert_array_equal(
b.compute(optimize_graph=False),
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).var(
axis=axis, ddof=correction, keepdims=keepdims
),
)


@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
@pytest.mark.parametrize("correction", [0.0, 1.0])
@pytest.mark.parametrize("keepdims", [False, True])
def test_std(spec, axis, correction, keepdims):
a = xp.asarray(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec
)
b = xp.std(a, axis=axis, correction=correction, keepdims=keepdims)
assert_array_equal(
b.compute(),
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).std(
axis=axis, ddof=correction, keepdims=keepdims
),
)


def test_var__poorly_conditioned(spec):
# from https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Example
npa = np.array([4.0, 7.0, 13.0, 16.0]) + 1e9
a = xp.asarray(npa, chunks=2, spec=spec)
b = xp.var(a, axis=0)
assert_array_equal(b.compute(), npa.var(axis=0))


# Utility functions


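
The `test_var__poorly_conditioned` case above uses the data from the Wikipedia example it links to: values offset by 1e9, where the textbook one-pass formula `E[x**2] - E[x]**2` loses the answer to cancellation in float64, while the M2-based formulation used here does not. A quick standalone illustration (NumPy only, independent of Cubed):

```python
import numpy as np

x = np.array([4.0, 7.0, 13.0, 16.0]) + 1e9  # true population variance is 22.5

naive = np.mean(x**2) - np.mean(x) ** 2    # one-pass formula: cancellation ruins it
stable = np.mean((x - np.mean(x)) ** 2)    # two-pass formula, like np.var

print(naive)   # nowhere near 22.5 (often 0.0 or off by orders of magnitude)
print(stable)  # 22.5
```

The test asserts equality against `npa.var(axis=0)`, which NumPy computes stably, so the chunked combine has to be equally stable to pass.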
2 changes: 1 addition & 1 deletion cubed/tests/test_groupby.py
@@ -62,7 +62,7 @@ def _mean_groupby_combine(a, axis, dummy_axis, dtype, keepdims):
return {"n": n, "total": total}


-def _mean_groupby_aggregate(a):
+def _mean_groupby_aggregate(a, **kwargs):
return nxp.divide(a["total"], a["n"])

