Skip to content

Commit

Permalink
Bugfix for xdiag
Browse files Browse the repository at this point in the history
  • Loading branch information
peekxc committed Dec 2, 2024
1 parent ef7687e commit ce402f8
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/primate/diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def xdiag(
Z = A.T @ Q
T = Z.T @ N
R_inv = np.linalg.inv(R)
S = R_inv / np.linalg.norm(R_inv, axis=0)
S = R_inv.T / np.linalg.norm(R_inv, axis=1)
QS = Q @ S

## Vector quantities
Expand Down
6 changes: 3 additions & 3 deletions tests/test_diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ def test_xdiag():

## Ensure error is decreasing
errors = []
budget = np.linspace(2, A.shape[0], 10).astype(int)
budget = np.linspace(2, 2 * A.shape[0], 10).astype(int)
for m in budget:
d = xdiag(A, m, pdf="sphere", seed=rng)
d = xdiag(A, m, pdf="signs", seed=rng)
errors.append(np.linalg.norm(np.diag(A) - d))
# print(f"Error: {np.linalg.norm(np.diag(A) - d)}")

y = np.array(errors)
B = np.c_[budget, np.ones(len(budget))]
m, c = np.linalg.lstsq(B, y)[0]
assert m < -0.50, "Error is not decreasing appreciably"
assert m < -0.10, "Error is not decreasing appreciably"


# def test_diagonal():
Expand Down

0 comments on commit ce402f8

Please sign in to comment.