Skip to content

Commit

Permalink
Although, maybe better to use their add_low_rank
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Nov 9, 2024
1 parent 34f2f2c commit cc1623a
Showing 1 changed file with 9 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ def __init__(self, *linear_ops, preconditioner_override=None):
is_first = isinstance(linear_ops[0], LowRankRootLinearOperator)
self._root_op = linear_ops[1 - is_first]
self._other_op = linear_ops[is_first]
super().__init__(linear_ops, preconditioner_override=preconditioner_override)
super().__init__(*linear_ops, preconditioner_override=preconditioner_override)

@property
@cached(name="chol_cap_mat")
def chol_cap_mat(self):
U = self._root_op.root
V = self._root_op.root.mT
Ainv_U = self._other_op.solve(U)
Ainv_U = self._other_op.solve(U.to_dense())

C = ConstantDiagLinearOperator(
torch.ones(*V.batch_shape, 1, device=V.device, dtype=V.dtype), V.shape[-2]
Expand Down Expand Up @@ -87,15 +87,16 @@ def _solve(
],
]:
A = self._other_op
U = self._linear_op.root
V = self._linear_op.root.mT
U = self._root_op.root
V = self._root_op.root.mT
chol_cap_mat = self.chol_cap_mat
Ainv_rhs = A.solve(rhs)

res = V.matmul(A.solve(rhs))
res = V.matmul(Ainv_rhs)
res = torch.cholesky_solve(res, chol_cap_mat)
res = A.solve(U.matmul(res))

solve = A.solve(rhs) - res
# res = A.solve(U.matmul(res))
# solve = A.solve(rhs) - res
solve = Ainv_rhs.sub_(A.solve(U.matmul(res)))

return solve

Expand Down

0 comments on commit cc1623a

Please sign in to comment.