From cc1623abf27e2ff9cb00d88094819b2ee8250cd1 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Sat, 9 Nov 2024 07:30:48 -0800 Subject: [PATCH] Although, maybe better to use their add_low_rank --- .../util/{operators.py => more_operators.py} | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) rename src/dartsort/util/{operators.py => more_operators.py} (95%) diff --git a/src/dartsort/util/operators.py b/src/dartsort/util/more_operators.py similarity index 95% rename from src/dartsort/util/operators.py rename to src/dartsort/util/more_operators.py index fb9f7a55..a67af7ef 100644 --- a/src/dartsort/util/operators.py +++ b/src/dartsort/util/more_operators.py @@ -32,14 +32,14 @@ def __init__(self, *linear_ops, preconditioner_override=None): is_first = isinstance(linear_ops[0], LowRankRootLinearOperator) self._root_op = linear_ops[1 - is_first] self._other_op = linear_ops[is_first] - super().__init__(linear_ops, preconditioner_override=preconditioner_override) + super().__init__(*linear_ops, preconditioner_override=preconditioner_override) @property @cached(name="chol_cap_mat") def chol_cap_mat(self): U = self._root_op.root V = self._root_op.root.mT - Ainv_U = self._other_op.solve(U) + Ainv_U = self._other_op.solve(U.to_dense()) C = ConstantDiagLinearOperator( torch.ones(*V.batch_shape, 1, device=V.device, dtype=V.dtype), V.shape[-2] @@ -87,15 +87,16 @@ def _solve( ], ]: A = self._other_op - U = self._linear_op.root - V = self._linear_op.root.mT + U = self._root_op.root + V = self._root_op.root.mT chol_cap_mat = self.chol_cap_mat + Ainv_rhs = A.solve(rhs) - res = V.matmul(A.solve(rhs)) + res = V.matmul(Ainv_rhs) res = torch.cholesky_solve(res, chol_cap_mat) - res = A.solve(U.matmul(res)) - - solve = A.solve(rhs) - res + # res = A.solve(U.matmul(res)) + # solve = A.solve(rhs) - res + solve = Ainv_rhs.sub_(A.solve(U.matmul(res))) return solve