Skip to content

Commit

Permalink
final debugging of tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Hoppe committed Nov 14, 2024
1 parent 4f67a93 commit cec049c
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 22 deletions.
21 changes: 0 additions & 21 deletions ausprobieren.py

This file was deleted.

2 changes: 1 addition & 1 deletion heat/core/linalg/qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def qr(
k = A.shape[-1]
else:
last_row_reached = min(torch.argwhere(lshapes_cum >= A.shape[-2]))[0]
k = A.shape[0]
k = A.shape[-2]

if mode == "reduced":
Q = factories.zeros(
Expand Down
30 changes: 30 additions & 0 deletions heat/core/linalg/tests/test_qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,36 @@ def test_qr_split0(self):
)
)

def test_batched_qr_splitNone(self):
# two batch dimensions, float64 data type, "split = None" (split batch axis)
x = ht.random.rand(2, 2 * ht.MPI_WORLD.size, 10, 9, dtype=ht.float32, split=1)
_, r = ht.linalg.qr(x, mode="r")
self.assertEqual(r.shape, (2, 2 * ht.MPI_WORLD.size, 9, 9))
self.assertEqual(r.split, 1)

def test_batched_qr_split1(self):
# two batch dimensions, float64 data type, "split = 1" (last dimension)
ht.random.seed(0)
x = ht.random.rand(3, 2, 50, ht.MPI_WORLD.size * 5 + 3, dtype=ht.float64, split=3)
q, r = ht.linalg.qr(x)
batched_id = ht.stack([ht.eye(q.shape[3], dtype=ht.float64) for _ in range(6)]).reshape(
3, 2, q.shape[3], q.shape[3]
)

self.assertTrue(
ht.allclose(q.transpose([0, 1, 3, 2]) @ q, batched_id, atol=1e-6, rtol=1e-6)
)
self.assertTrue(ht.allclose(q @ r, x, atol=1e-6, rtol=1e-6))

def test_batched_qr_split0(self):
# one batch dimension, float32 data type, "split = 0" (second last dimension)
x = ht.random.randn(8, ht.MPI_WORLD.size * 10 + 3, 9, dtype=ht.float32, split=1)
q, r = ht.linalg.qr(x)
batched_id = ht.stack([ht.eye(q.shape[2], dtype=ht.float32) for _ in range(q.shape[0])])

self.assertTrue(ht.allclose(q.transpose([0, 2, 1]) @ q, batched_id, atol=1e-3, rtol=1e-3))
self.assertTrue(ht.allclose(q @ r, x, atol=1e-3, rtol=1e-3))

def test_wronginputs(self):
# test wrong input type
with self.assertRaises(TypeError):
Expand Down

0 comments on commit cec049c

Please sign in to comment.