Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QR decomposition #577

Merged
merged 10 commits into from
Sep 23, 2024
6 changes: 6 additions & 0 deletions cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,9 @@
from .array_api.utility_functions import all, any

__all__ += ["all", "any"]

# extensions: exposed as submodules (e.g. ``cubed.linalg``) rather than as
# top-level functions

from .array_api import linalg

__all__ += ["linalg"]
172 changes: 172 additions & 0 deletions cubed/array_api/linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from typing import NamedTuple

from cubed.array_api.array_object import Array
from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import general_blockwise, map_direct, merge_chunks
from cubed.utils import array_memory, get_item


class QRResult(NamedTuple):
    """Named pair ``(Q, R)`` holding the two factors of a QR decomposition."""

    # the Q factor
    Q: Array
    # the R factor
    R: Array


def qr(x, /, *, mode="reduced") -> QRResult:
    """Return the reduced QR decomposition of a tall-and-skinny 2-D array.

    The work is delegated to :func:`tsqr`, so only inputs that algorithm can
    handle are accepted.

    Parameters
    ----------
    x
        Two-dimensional array with a single chunk along the column axis, and
        whose row chunks each contain at least as many rows as ``x`` has
        columns.
    mode
        Decomposition mode. Only ``"reduced"`` is supported.

    Returns
    -------
    QRResult
        Named tuple ``(Q, R)``.

    Raises
    ------
    ValueError
        If ``x`` is not 2-D, if ``mode`` is not ``"reduced"``, if ``x`` has
        more than one column chunk, or if any row chunk has fewer rows than
        ``x`` has columns.
    """
    if x.ndim != 2:
        raise ValueError("qr requires x to have 2 dimensions.")

    if mode != "reduced":
        raise ValueError("qr only supports mode='reduced'")

    if x.numblocks[1] > 1:
        raise ValueError(
            "qr only supports tall-and-skinny (single column chunk) arrays. "
            "Consider rechunking so there is only a single column chunk."
        )

    # The tall-skinny steps declare each per-block Q with the block's own shape
    # and each per-block R as (n, n). A reduced QR of a block with fewer than n
    # rows yields smaller factors, so reject such chunking up front instead of
    # failing with a shape mismatch when the plan is executed.
    n = x.shape[1]
    if any(rows < n for rows in x.chunks[0]):
        raise ValueError(
            "qr requires every row chunk to have at least as many rows as x has "
            f"columns ({n}). Consider rechunking so that each row chunk has at "
            f"least {n} rows."
        )

    return tsqr(x)


def tsqr(x) -> QRResult:
    """Direct Tall-and-Skinny QR algorithm

    From:

        Direct QR factorizations for tall-and-skinny matrices in MapReduce architectures
        Austin R. Benson, David F. Gleich, James Demmel
        Proceedings of the IEEE International Conference on Big Data, 2013
        https://arxiv.org/abs/1301.1071

    Follows Algorithm 2 of that paper in three steps: a blockwise QR of ``x``,
    a QR of the stacked R factors (recursing if they are judged too large to
    process as one chunk), and a final product that forms Q.
    """
    # step 1: blockwise QR of the input
    Q1, R1 = _qr_first_step(x)

    # step 2: QR of the stacked R factors
    if not _r1_is_too_big(R1):
        Q2, R2 = _qr_second_step(R1)
    else:
        # too big for one chunk: coarsen the row chunks and recurse
        Q2, R2 = tsqr(_rechunk_r1(R1))

    # step 3: combine the step-1 and step-2 Q factors
    return QRResult(_qr_third_step(Q1, Q2), R2)


def _qr_first_step(A):
    """Apply ``nxp.linalg.qr`` to each block of ``A``.

    Returns a ``QRResult`` whose ``Q`` is declared with the same shape and
    chunks as ``A``, and whose ``R`` is declared as one ``(n, n)`` chunk per
    row block of ``A``, stacked along axis 0 (``n`` is the column chunk size).
    """
    _, ncols = A.chunksize
    nblocks, _ = A.numblocks

    # one (ncols, ncols) chunk per row block of A
    stacked_r_shape = (ncols * nblocks, ncols)
    stacked_r_chunks = ((ncols,) * nblocks, (ncols,))

    block_q, stacked_r = map_blocks_multiple_outputs(
        nxp.linalg.qr,
        A,
        # Q1 has same shape and chunks as A
        shapes=[A.shape, stacked_r_shape],
        dtypes=[nxp.float64, nxp.float64],
        chunkss=[A.chunks, stacked_r_chunks],
        # qr implementation creates internal array buffers
        extra_projected_mem=A.chunkmem * 4,
    )
    return QRResult(block_q, stacked_r)


def _r1_is_too_big(R1):
    """Return True if the whole of ``R1`` exceeds a conservative memory budget.

    The budget is the memory left after ``reserved_mem`` is taken out of
    ``allowed_mem`` (both read from ``R1.spec``), divided by 8.
    """
    r1_bytes = array_memory(R1.dtype, R1.shape)
    spec = R1.spec
    # conservative budget (4 copies, doubled to give some slack)
    budget = (spec.allowed_mem - spec.reserved_mem) // (4 * 2)
    return r1_bytes > budget


def _rechunk_r1(R1, split_every=4):
    """Merge row chunks of ``R1`` so each chunk has ``split_every`` times as many rows.

    The intent, per the original comment, is that the R1 produced by the next
    round is smaller by a factor of ``split_every``.

    Raises ``ValueError`` when ``R1`` already has a single row chunk, since
    there is then nothing left to merge.
    """
    if R1.numblocks[0] == 1:
        raise ValueError(
            "Can't expand R1 chunk size further. Try increasing allowed_mem"
        )
    chunksize = R1.chunksize
    merged_chunks = (chunksize[0] * split_every, chunksize[1])
    return merge_chunks(R1, chunks=merged_chunks)


def _qr_second_step(R1):
    """Merge ``R1`` into a single chunk and apply ``nxp.linalg.qr`` to it.

    Returns a ``QRResult`` whose ``Q`` is declared with the shape of ``R1`` and
    whose ``R`` is declared as ``(n, n)``, where ``n`` is the number of columns
    of ``R1``. Both outputs are single-chunk.
    """
    single = _merge_into_single_chunk(R1)

    ncols = R1.shape[1]
    q_shape = R1.shape
    r_shape = (ncols, ncols)

    Q2, R2 = map_blocks_multiple_outputs(
        nxp.linalg.qr,
        single,
        shapes=[q_shape, r_shape],
        dtypes=[nxp.float64, nxp.float64],
        # each output is one chunk, so its chunks equal its shape
        chunkss=[q_shape, r_shape],
        # qr implementation creates internal array buffers
        extra_projected_mem=single.chunkmem * 4,
    )
    return QRResult(Q2, R2)


def _merge_into_single_chunk(x, split_every=4):
    """Repeatedly merge chunks along axis 0 until only one row chunk remains.

    Each round multiplies the axis-0 chunk size by ``split_every`` (a tree
    merge); chunk sizes along the other axes are left unchanged.
    """
    merged = x
    while merged.numblocks[0] > 1:
        rows, *other_axes = merged.chunksize
        merged = merge_chunks(merged, (rows * split_every, *other_axes))
    return merged


def _qr_third_step(Q1, Q2):
    """Build the final Q by multiplying blocks of ``Q1`` with slices of ``Q2``.

    The output is declared with the same shape and chunks as ``Q1``. ``Q2`` is
    addressed as if it were chunked into one ``(n, n)`` piece per row block of
    ``Q1``, where ``n`` is the column chunk size of ``Q1``.
    """
    _, ncols = Q1.chunksize
    nblocks, _ = Q1.numblocks

    q1_chunks = Q1.chunks
    # logical chunking used to slice Q2: one (ncols, ncols) piece per Q1 row block
    q2_slices = ((ncols,) * nblocks, (ncols,))

    return map_direct(
        _q_matmul,
        Q1,
        Q2,
        shape=Q1.shape,
        dtype=nxp.float64,
        chunks=q1_chunks,
        extra_projected_mem=0,
        q1_chunks=q1_chunks,
        q2_chunks=q2_slices,
    )


def _q_matmul(x, *arrays, q1_chunks=None, q2_chunks=None, block_id=None):
    """Return the matrix product of the ``block_id`` pieces of the two side inputs.

    ``x`` is not used. ``arrays[0]`` and ``arrays[1]`` are read through their
    ``zarray`` attribute, sliced at the positions that ``block_id`` selects in
    ``q1_chunks`` and ``q2_chunks`` respectively.
    """
    q1_source, q2_source = arrays[0], arrays[1]
    q1_block = q1_source.zarray[get_item(q1_chunks, block_id)]
    # this array only has a single chunk, but we need to get a slice corresponding to q2_chunks
    q2_block = q2_source.zarray[get_item(q2_chunks, block_id)]
    return q1_block @ q2_block


def map_blocks_multiple_outputs(func, *args, shapes, dtypes, chunkss, **kwargs):
    """Call ``general_blockwise`` so that ``func`` produces several output arrays.

    ``shapes``, ``dtypes`` and ``chunkss`` each hold one entry per output.
    Every output block is computed from the block at the same coordinates in
    each array of ``args``. Remaining keyword arguments are forwarded to
    ``general_blockwise``.
    """

    def key_function(out_key):
        # drop the output array's name and keep the block coordinates
        block_coords = out_key[1:]
        return tuple((array.name,) + block_coords for array in args)

    num_outputs = len(dtypes)
    return general_blockwise(
        func,
        key_function,
        *args,
        shapes=shapes,
        dtypes=dtypes,
        chunkss=chunkss,
        target_stores=[None] * num_outputs,
        **kwargs,
    )
10 changes: 9 additions & 1 deletion cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,11 @@ def _finalize(
compile_function: Optional[Decorator] = None,
array_names=None,
) -> "FinalizedPlan":
dag = self.optimize(optimize_function, array_names).dag if optimize_graph else self.dag
dag = (
self.optimize(optimize_function, array_names).dag
if optimize_graph
else self.dag
)
# create a copy since _create_lazy_zarr_arrays mutates the dag
dag = dag.copy()
if callable(compile_function):
Expand Down Expand Up @@ -501,6 +505,10 @@ def num_arrays(self) -> int:
"""Return the number of arrays in this plan."""
return sum(d.get("type") == "array" for _, d in self.dag.nodes(data=True))

def num_primitive_ops(self) -> int:
    """Return the number of primitive operations in this plan."""
    # count the nodes yielded by visit_nodes without building a list
    return sum(1 for _ in visit_nodes(self.dag))

def num_tasks(self, resume=None):
"""Return the number of tasks needed to execute this plan."""
tasks = 0
Expand Down
59 changes: 59 additions & 0 deletions cubed/tests/test_linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose

import cubed
import cubed.array_api as xp
from cubed.core.plan import arrays_to_plan


def test_qr():
    """QR of a 16x2 array in four row chunks gives valid factors from a 4-op plan."""
    A = np.arange(32, dtype=np.float64).reshape((16, 2))
    lazy_q, lazy_r = xp.linalg.qr(xp.asarray(A, chunks=(4, 2)))

    plan_unopt = arrays_to_plan(lazy_q, lazy_r)._finalize()
    assert plan_unopt.num_primitive_ops() == 4

    Q, R = cubed.compute(lazy_q, lazy_r)

    # the factors reproduce the input
    assert_allclose(Q @ R, A, atol=1e-08)
    # Q must be orthonormal
    assert_allclose(Q.T @ Q, np.eye(2), atol=1e-08)
    # R must be upper triangular
    assert_allclose(R, np.triu(R), atol=1e-08)


def test_qr_recursion():
    """QR still gives valid factors when the plan has more than four primitive ops."""
    A = np.arange(128, dtype=np.float64).reshape((64, 2))

    # find a memory setting where recursion happens
    found = False
    for factor in range(4, 16):
        spec = cubed.Spec(allowed_mem=128 * factor, reserved_mem=0)

        # NOTE(review): the try also covers the checks below, so a ValueError
        # raised after `found = True` is treated as "not enough memory" too —
        # confirm that is intended.
        try:
            lazy_q, lazy_r = xp.linalg.qr(xp.asarray(A, chunks=(8, 2), spec=spec))

            found = True
            plan_unopt = arrays_to_plan(lazy_q, lazy_r)._finalize()
            # more than without recursion
            assert plan_unopt.num_primitive_ops() > 4

            Q, R = cubed.compute(lazy_q, lazy_r)

            assert_allclose(Q @ R, A, atol=1e-08)
            # Q must be orthonormal
            assert_allclose(Q.T @ Q, np.eye(2), atol=1e-08)
            # R must be upper triangular
            assert_allclose(R, np.triu(R), atol=1e-08)

            break

        except ValueError:
            continue  # not enough memory

    assert found


def test_qr_chunking():
    """qr rejects an input that has more than one chunk along the column axis."""
    two_column_chunks = xp.ones((32, 4), chunks=(4, 2))
    expected_message = (
        r"qr only supports tall-and-skinny \(single column chunk\) arrays."
    )
    with pytest.raises(ValueError, match=expected_message):
        xp.linalg.qr(two_column_chunks)
15 changes: 14 additions & 1 deletion cubed/tests/test_mem_utilization.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,19 @@ def test_sum_partial_reduce(tmp_path, spec, executor):
run_operation(tmp_path, executor, "sum_partial_reduce", b)


# Linear algebra extension


@pytest.mark.slow
def test_qr(tmp_path, spec, executor):
    """Run qr on a 40000x1000 random array through ``run_operation``."""
    # 40MB chunks
    a = cubed.random.random((40000, 1000), chunks=(5000, 1000), spec=spec)
    q_factor, r_factor = xp.linalg.qr(a)
    # don't optimize graph so we use as much memory as possible (reading from Zarr)
    run_operation(
        tmp_path, executor, "qr", q_factor, r_factor, optimize_graph=False
    )


# Multiple outputs


Expand Down Expand Up @@ -362,7 +375,7 @@ def run_operation(
# )
hist = HistoryCallback()
mem_warn = MemoryWarningCallback()
memray = MemrayCallback()
memray = MemrayCallback(mem_threshold=30_000_000)
# use None for each store to write to temporary zarr
cubed.store(
results,
Expand Down
Loading