Skip to content

Commit

Permalink
QR recursion improvements and test
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Sep 19, 2024
1 parent e6849a9 commit 76cbe5c
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 13 deletions.
10 changes: 7 additions & 3 deletions cubed/array_api/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,12 @@ def _r1_is_too_big(R1):
return array_mem > max_mem


def _rechunk_r1(R1, split_every=10):
def _rechunk_r1(R1, split_every=4):
# expand R1's chunk size in axis 0 so that new R1 will be smaller by factor of split_every
if R1.numblocks[0] == 1:
raise ValueError(
"Can't expand R1 chunk size further. Try increasing allowed_mem"
)
chunks = (R1.chunksize[0] * split_every, R1.chunksize[1])
return merge_chunks(R1, chunks=chunks)

Expand All @@ -107,10 +111,10 @@ def _qr_second_step(R1):
return QRResult(Q2, R2)


def _merge_into_single_chunk(x, split_every=10):
def _merge_into_single_chunk(x, split_every=4):
# do a tree merge along first axis
while x.numblocks[0] > 1:
chunks = (min(x.chunksize[0] * split_every, x.shape[0]),) + x.chunksize[1:]
chunks = (x.chunksize[0] * split_every,) + x.chunksize[1:]
x = merge_chunks(x, chunks)
return x

Expand Down
10 changes: 9 additions & 1 deletion cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,11 @@ def _finalize(
compile_function: Optional[Decorator] = None,
array_names=None,
) -> "FinalizedPlan":
dag = self.optimize(optimize_function, array_names).dag if optimize_graph else self.dag
dag = (
self.optimize(optimize_function, array_names).dag
if optimize_graph
else self.dag
)
# create a copy since _create_lazy_zarr_arrays mutates the dag
dag = dag.copy()
if callable(compile_function):
Expand Down Expand Up @@ -501,6 +505,10 @@ def num_arrays(self) -> int:
"""Return the number of arrays in this plan."""
return sum(d.get("type") == "array" for _, d in self.dag.nodes(data=True))

def num_primitive_ops(self) -> int:
"""Return the number of primitive operations in this plan."""
return len(list(visit_nodes(self.dag)))

def num_tasks(self, resume=None):
"""Return the number of tasks needed to execute this plan."""
tasks = 0
Expand Down
37 changes: 28 additions & 9 deletions cubed/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@

import cubed
import cubed.array_api as xp
from cubed.core.plan import arrays_to_plan


def test_qr():
A = np.reshape(np.arange(32, dtype=np.float64), (16, 2))
Q, R = xp.linalg.qr(xp.asarray(A, chunks=(4, 2)))

cubed.visualize(Q, R, optimize_graph=False)
plan_unopt = arrays_to_plan(Q, R)._finalize()
assert plan_unopt.num_primitive_ops() == 4

Q, R = cubed.compute(Q, R)

assert_allclose(Q @ R, A, atol=1e-08)
Expand All @@ -19,16 +22,32 @@ def test_qr():


def test_qr_recursion():
spec = cubed.Spec(allowed_mem=128 * 4 * 1.5, reserved_mem=0)
A = np.reshape(np.arange(64, dtype=np.float64), (32, 2))
Q, R = xp.linalg.qr(xp.asarray(A, chunks=(8, 2), spec=spec))
A = np.reshape(np.arange(128, dtype=np.float64), (64, 2))

cubed.visualize(Q, R, optimize_graph=False)
Q, R = cubed.compute(Q, R)
# find a memory setting where recursion happens
found = False
for factor in range(4, 16):
spec = cubed.Spec(allowed_mem=128 * factor, reserved_mem=0)

assert_allclose(Q @ R, A, atol=1e-08)
assert_allclose(Q.T @ Q, np.eye(2, 2), atol=1e-08) # Q must be orthonormal
assert_allclose(R, np.triu(R), atol=1e-08) # R must be upper triangular
try:
Q, R = xp.linalg.qr(xp.asarray(A, chunks=(8, 2), spec=spec))

found = True
plan_unopt = arrays_to_plan(Q, R)._finalize()
assert plan_unopt.num_primitive_ops() > 4 # more than without recursion

Q, R = cubed.compute(Q, R)

assert_allclose(Q @ R, A, atol=1e-08)
assert_allclose(Q.T @ Q, np.eye(2, 2), atol=1e-08) # Q must be orthonormal
assert_allclose(R, np.triu(R), atol=1e-08) # R must be upper triangular

break

except ValueError:
pass # not enough memory

assert found


def test_qr_chunking():
Expand Down

0 comments on commit 76cbe5c

Please sign in to comment.