From ce402f853980b0b631cc8522a0c4b2e11aedd43f Mon Sep 17 00:00:00 2001 From: peekxc Date: Mon, 2 Dec 2024 16:01:56 -0500 Subject: [PATCH] Bugfix for xdiag --- src/primate/diagonal.py | 2 +- tests/test_diagonal.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/primate/diagonal.py b/src/primate/diagonal.py index 8f2543f..c81742c 100644 --- a/src/primate/diagonal.py +++ b/src/primate/diagonal.py @@ -121,7 +121,7 @@ def xdiag( Z = A.T @ Q T = Z.T @ N R_inv = np.linalg.inv(R) - S = R_inv / np.linalg.norm(R_inv, axis=0) + S = R_inv.T / np.linalg.norm(R_inv, axis=1) QS = Q @ S ## Vector quantities diff --git a/tests/test_diagonal.py b/tests/test_diagonal.py index af8ea06..17ba084 100644 --- a/tests/test_diagonal.py +++ b/tests/test_diagonal.py @@ -22,16 +22,16 @@ def test_xdiag(): ## Ensure error is decreasing errors = [] - budget = np.linspace(2, A.shape[0], 10).astype(int) + budget = np.linspace(2, 2 * A.shape[0], 10).astype(int) for m in budget: - d = xdiag(A, m, pdf="sphere", seed=rng) + d = xdiag(A, m, pdf="signs", seed=rng) errors.append(np.linalg.norm(np.diag(A) - d)) # print(f"Error: {np.linalg.norm(np.diag(A) - d)}") y = np.array(errors) B = np.c_[budget, np.ones(len(budget))] m, c = np.linalg.lstsq(B, y)[0] - assert m < -0.50, "Error is not decreasing appreciably" + assert m < -0.10, "Error is not decreasing appreciably" # def test_diagonal():