Skip to content

Commit

Permalink
Expose the linalg namespace and include in status page (#581)
Browse files Browse the repository at this point in the history
* Add matmul, matrix_transpose, tensordot, vecdot to linalg namespace

* Move outer to linalg namespace

* Remove flip from list of unimplemented functions since it was added in #114

* Remove unstack from list of unimplemented functions since it was added in #575

* Add link to cumulative_sum PR

* Add linalg table to status page
  • Loading branch information
tomwhite authored Sep 24, 2024
1 parent d5b40b3 commit c1391c0
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 14 deletions.
36 changes: 33 additions & 3 deletions api_status.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## Array API Coverage Implementation Status

Cubed supports version [2022.12](https://data-apis.org/array-api/2022.12/index.html) of the Python array API standard, with a few exceptions noted below. The [linear algebra extensions](https://data-apis.org/array-api/2022.12/extensions/linear_algebra_functions.html) and [Fourier transform functions](https://data-apis.org/array-api/2022.12/extensions/fourier_transform_functions.html) are *not* supported.
Cubed supports version [2022.12](https://data-apis.org/array-api/2022.12/index.html) of the Python array API standard, with a few exceptions noted below. The [Fourier transform functions](https://data-apis.org/array-api/2022.12/extensions/fourier_transform_functions.html) are *not* supported.

Support for version [2023.12](https://data-apis.org/array-api/2023.12/index.html) is tracked in Cubed issue [#438](https://github.com/cubed-dev/cubed/issues/438).

Expand Down Expand Up @@ -67,7 +67,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
| | `squeeze` | :white_check_mark: | | |
| | `stack` | :white_check_mark: | | |
| | `tile` | :x: | 2023.12 | |
| | `unstack` | :x: | 2023.12 | |
| | `unstack` | :white_check_mark: | 2023.12 | |
| Searching Functions | `argmax` | :white_check_mark: | | |
| | `argmin` | :white_check_mark: | | |
| | `nonzero` | :x: | | Shape is data dependent |
Expand All @@ -79,7 +79,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
| | `unique_values` | :x: | | Shape is data dependent |
| Sorting Functions | `argsort` | :x: | | Not in Dask |
| | `sort` | :x: | | Not in Dask |
| Statistical Functions | `cumulative_sum` | :x: | 2023.12 | |
| Statistical Functions | `cumulative_sum` | :x: | 2023.12 | WIP [#531](https://github.com/cubed-dev/cubed/pull/531) |
| | `max` | :white_check_mark: | | |
| | `mean` | :white_check_mark: | | |
| | `min` | :white_check_mark: | | |
Expand All @@ -89,3 +89,33 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
| | `var` | :x: | | Like `mean`, [#29](https://github.com/cubed-dev/cubed/issues/29) |
| Utility Functions | `all` | :white_check_mark: | | |
| | `any` | :white_check_mark: | | |

### Linear Algebra Extension

A few of the [linear algebra extension](https://data-apis.org/array-api/2022.12/extensions/linear_algebra_functions.html) functions are supported, as indicated in this table.

| Category | Object/Function | Implemented | Version | Notes |
| ------------------------ | ------------------- | ------------------ | ---------- | ---------------------------- |
| Linear Algebra Functions | `cholesky` | :x: | | |
| | `cross` | :x: | | |
| | `det` | :x: | | |
| | `diagonal` | :x: | | |
| | `eigh` | :x: | | |
| | `eigvalsh` | :x: | | |
| | `inv` | :x: | | |
| | `matmul` | :white_check_mark: | | |
| | `matrix_norm` | :x: | | |
| | `matrix_power` | :x: | | |
| | `matrix_rank` | :x: | | |
| | `matrix_transpose` | :white_check_mark: | | |
| | `outer` | :white_check_mark: | | |
| | `pinv` | :x: | | |
| | `qr` | :white_check_mark: | | |
| | `slogdet` | :x: | | |
| | `solve` | :x: | | |
| | `svd` | :x: | | |
| | `svdvals` | :x: | | |
| | `tensordot` | :white_check_mark: | | |
| | `trace` | :x: | | |
| | `vecdot` | :white_check_mark: | | |
| | `vectornorm` | :x: | | |
3 changes: 1 addition & 2 deletions cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,11 @@
from .array_api.linear_algebra_functions import (
matmul,
matrix_transpose,
outer,
tensordot,
vecdot,
)

__all__ += ["matmul", "matrix_transpose", "outer", "tensordot", "vecdot"]
__all__ += ["matmul", "matrix_transpose", "tensordot", "vecdot"]

from .array_api.manipulation_functions import (
broadcast_arrays,
Expand Down
4 changes: 2 additions & 2 deletions cubed/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,9 @@

__all__ += ["take"]

from .linear_algebra_functions import matmul, matrix_transpose, outer, tensordot, vecdot
from .linear_algebra_functions import matmul, matrix_transpose, tensordot, vecdot

__all__ += ["matmul", "matrix_transpose", "outer", "tensordot", "vecdot"]
__all__ += ["matmul", "matrix_transpose", "tensordot", "vecdot"]

from .manipulation_functions import (
broadcast_arrays,
Expand Down
14 changes: 13 additions & 1 deletion cubed/array_api/linalg.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
from typing import NamedTuple

from cubed.array_api.array_object import Array

# These functions are in both the main and linalg namespaces
from cubed.array_api.linear_algebra_functions import ( # noqa: F401
matmul,
matrix_transpose,
tensordot,
vecdot,
)
from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import general_blockwise, map_direct, merge_chunks
from cubed.core.ops import blockwise, general_blockwise, map_direct, merge_chunks
from cubed.utils import array_memory, get_item


def outer(x1, x2, /):
return blockwise(nxp.linalg.outer, "ij", x1, "i", x2, "j", dtype=x1.dtype)


class QRResult(NamedTuple):
Q: Array
R: Array
Expand Down
4 changes: 0 additions & 4 deletions cubed/array_api/linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,6 @@ def matrix_transpose(x, /):
return permute_dims(x, axes)


def outer(x1, x2, /):
return blockwise(nxp.linalg.outer, "ij", x1, "i", x2, "j", dtype=x1.dtype)


def tensordot(x1, x2, /, *, axes=2, use_new_impl=True, split_every=None):
from cubed.array_api.statistical_functions import sum

Expand Down
2 changes: 1 addition & 1 deletion cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def test_matmul_modal(modal_executor):
def test_outer(spec, executor):
a = xp.asarray([0, 1, 2], chunks=2, spec=spec)
b = xp.asarray([10, 50, 100], chunks=2, spec=spec)
c = xp.outer(a, b)
c = xp.linalg.outer(a, b)
assert_array_equal(c.compute(executor=executor), np.outer([0, 1, 2], [10, 50, 100]))


Expand Down
1 change: 0 additions & 1 deletion docs/array-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ The following parts of the standard are not implemented:
| Array object | In-place Ops |
| Creation Functions | `from_dlpack` |
| Indexing | Boolean array |
| Manipulation Functions | `flip` |
| Searching Functions | `nonzero` |
| Set Functions | `unique_all` |
| | `unique_counts` |
Expand Down

0 comments on commit c1391c0

Please sign in to comment.