From 37d01ee8d033789ed4cdb009b838aab3775bb559 Mon Sep 17 00:00:00 2001 From: brentyi Date: Thu, 31 Oct 2024 17:07:10 -0700 Subject: [PATCH] Add scale-invariant damping --- examples/pose_graph_g2o.py | 4 +--- src/jaxls/_preconditioning.py | 13 +++++-------- src/jaxls/_solvers.py | 22 +++++++++++++++++----- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/examples/pose_graph_g2o.py b/examples/pose_graph_g2o.py index 4b2568d..9d107d0 100755 --- a/examples/pose_graph_g2o.py +++ b/examples/pose_graph_g2o.py @@ -42,9 +42,7 @@ def main( ) with jaxls.utils.stopwatch("Running solve"): - solution_vals = graph.solve( - initial_vals, trust_region=None, linear_solver=linear_solver - ) + solution_vals = graph.solve(initial_vals, linear_solver=linear_solver) with jaxls.utils.stopwatch("Running solve (again)"): solution_vals = graph.solve( diff --git a/src/jaxls/_preconditioning.py b/src/jaxls/_preconditioning.py index 2368740..98b8ca7 100644 --- a/src/jaxls/_preconditioning.py +++ b/src/jaxls/_preconditioning.py @@ -10,11 +10,9 @@ from ._sparse_matrices import BlockRowSparseMatrix -def make_point_jacobi_precoditioner( - A_blocksparse: BlockRowSparseMatrix, -) -> Callable[[jax.Array], jax.Array]: - """Returns a point Jacobi (diagonal) preconditioner.""" - ATA_diagonals = jnp.zeros(A_blocksparse.shape[1]) +def get_ATA_diag(A_blocksparse: BlockRowSparseMatrix) -> jnp.ndarray: + """Get 1D array of diagonal elements of the Gram matrix.""" + ATA_diag = jnp.zeros(A_blocksparse.shape[1]) for block_row in A_blocksparse.block_rows: (n_blocks, rows, cols_concat) = block_row.blocks_concat.shape @@ -32,9 +30,8 @@ def make_point_jacobi_precoditioner( ], axis=1, ).flatten() - ATA_diagonals = ATA_diagonals.at[indices].add(block_l2_cols) - - return lambda vec: vec / ATA_diagonals + ATA_diag = ATA_diag.at[indices].add(block_l2_cols) + return ATA_diag def make_block_jacobi_precoditioner( diff --git a/src/jaxls/_solvers.py b/src/jaxls/_solvers.py index d53a73d..6d97eb1 100644 --- a/src/jaxls/_solvers.py +++ b/src/jaxls/_solvers.py @@ -12,8 +12,8 @@ from jax import numpy as jnp from jaxls._preconditioning import ( + get_ATA_diag, make_block_jacobi_precoditioner, - make_point_jacobi_precoditioner, ) from ._sparse_matrices import BlockRowSparseMatrix, SparseCooMatrix, SparseCsrMatrix @@ -122,6 +122,7 @@ def _solve( A_blocksparse: BlockRowSparseMatrix, ATA_multiply: Callable[[jax.Array], jax.Array], ATb: jax.Array, + ATA_diagonals: jax.Array, prev_linear_state: _ConjugateGradientState, ) -> tuple[jax.Array, _ConjugateGradientState]: assert len(ATb.shape) == 1, "ATb should be 1D!" @@ -130,7 +131,7 @@ def _solve( if self.preconditioner == "block_jacobi": preconditioner = make_block_jacobi_precoditioner(graph, A_blocksparse) elif self.preconditioner == "point_jacobi": - preconditioner = make_point_jacobi_precoditioner(A_blocksparse) + preconditioner = lambda vec: vec / ATA_diagonals elif self.preconditioner is None: preconditioner = lambda x: x else: @@ -283,6 +284,9 @@ def step( # Compute right-hand side of normal equation. ATb = -AT_multiply(state.residual_vector) + # Used for CG + dense Cholesky. + ATA_diagonals = get_ATA_diag(A_blocksparse) + linear_state = None if ( isinstance(self.linear_solver, ConjugateGradientConfig) @@ -300,8 +304,10 @@ def step( A_blocksparse, # We could also use (lambd * ATA_diagonals * vec) for # scale-invariant damping. But this is hard to match with CHOLMOD. - lambda vec: AT_multiply(A_multiply(vec)) + state.lambd * vec, + lambda vec: AT_multiply(A_multiply(vec)) + + state.lambd * ATA_diagonals * vec, ATb=ATb, + ATA_diagonals=ATA_diagonals, prev_linear_state=state.cg_state, ) elif self.linear_solver == "cholmod": @@ -312,7 +318,7 @@ def step( A_dense = A_blocksparse.to_dense() ATA = A_dense.T @ A_dense diag_idx = jnp.arange(ATA.shape[0]) - ATA = ATA.at[diag_idx, diag_idx].add(state.lambd) + ATA = ATA.at[diag_idx, diag_idx].add(state.lambd * ATA_diagonals) cho_factor = jax.scipy.linalg.cho_factor(ATA) local_delta = jax.scipy.linalg.cho_solve(cho_factor, ATb) else: @@ -428,7 +434,13 @@ def step( class TrustRegionConfig: # Levenberg-Marquardt parameters. lambda_initial: float = 5e-4 - """Initial damping factor. Only used for Levenberg-Marquardt.""" + """Initial damping factor. Only used for Levenberg-Marquardt. + + *Unfortunate note:* the damping behavior of LM will currently be different + depending on which linear solver you use. + - For CHOLMOD, we apply damping naively to normal equations by solving `(ATA + lambda*I)x = ATb`. + - For all other solvers, we apply scale-invariant damping by solving `(ATA + lambda * diag(ATA)) x = ATb`. + """ lambda_factor: float = 2.0 """Factor to increase or decrease damping. Only used for Levenberg-Marquardt.""" lambda_min: float = 1e-5