Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scale-invariant damping #23

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions examples/pose_graph_g2o.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ def main(
)

with jaxls.utils.stopwatch("Running solve"):
solution_vals = graph.solve(
initial_vals, trust_region=None, linear_solver=linear_solver
)
solution_vals = graph.solve(initial_vals, linear_solver=linear_solver)

with jaxls.utils.stopwatch("Running solve (again)"):
solution_vals = graph.solve(
Expand Down
13 changes: 5 additions & 8 deletions src/jaxls/_preconditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@
from ._sparse_matrices import BlockRowSparseMatrix


def make_point_jacobi_precoditioner(
A_blocksparse: BlockRowSparseMatrix,
) -> Callable[[jax.Array], jax.Array]:
"""Returns a point Jacobi (diagonal) preconditioner."""
ATA_diagonals = jnp.zeros(A_blocksparse.shape[1])
def get_ATA_diag(A_blocksparse: BlockRowSparseMatrix) -> jnp.ndarray:
"""Get 1D array of diagonal elements of the Gram matrix."""
ATA_diag = jnp.zeros(A_blocksparse.shape[1])

for block_row in A_blocksparse.block_rows:
(n_blocks, rows, cols_concat) = block_row.blocks_concat.shape
Expand All @@ -32,9 +30,8 @@ def make_point_jacobi_precoditioner(
],
axis=1,
).flatten()
ATA_diagonals = ATA_diagonals.at[indices].add(block_l2_cols)

return lambda vec: vec / ATA_diagonals
ATA_diag = ATA_diag.at[indices].add(block_l2_cols)
return ATA_diag


def make_block_jacobi_precoditioner(
Expand Down
22 changes: 17 additions & 5 deletions src/jaxls/_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from jax import numpy as jnp

from jaxls._preconditioning import (
get_ATA_diag,
make_block_jacobi_precoditioner,
make_point_jacobi_precoditioner,
)

from ._sparse_matrices import BlockRowSparseMatrix, SparseCooMatrix, SparseCsrMatrix
Expand Down Expand Up @@ -122,6 +122,7 @@ def _solve(
A_blocksparse: BlockRowSparseMatrix,
ATA_multiply: Callable[[jax.Array], jax.Array],
ATb: jax.Array,
ATA_diagonals: jax.Array,
prev_linear_state: _ConjugateGradientState,
) -> tuple[jax.Array, _ConjugateGradientState]:
assert len(ATb.shape) == 1, "ATb should be 1D!"
Expand All @@ -130,7 +131,7 @@ def _solve(
if self.preconditioner == "block_jacobi":
preconditioner = make_block_jacobi_precoditioner(graph, A_blocksparse)
elif self.preconditioner == "point_jacobi":
preconditioner = make_point_jacobi_precoditioner(A_blocksparse)
preconditioner = lambda vec: vec / ATA_diagonals
elif self.preconditioner is None:
preconditioner = lambda x: x
else:
Expand Down Expand Up @@ -283,6 +284,9 @@ def step(
# Compute right-hand side of normal equation.
ATb = -AT_multiply(state.residual_vector)

# Used for CG + dense Cholesky.
ATA_diagonals = get_ATA_diag(A_blocksparse)

linear_state = None
if (
isinstance(self.linear_solver, ConjugateGradientConfig)
Expand All @@ -300,8 +304,10 @@ def step(
A_blocksparse,
# Scale-invariant damping: lambd is scaled by the diagonal of ATA, i.e. we
# apply (lambd * ATA_diagonals * vec). This is hard to match with CHOLMOD.
lambda vec: AT_multiply(A_multiply(vec)) + state.lambd * vec,
lambda vec: AT_multiply(A_multiply(vec))
+ state.lambd * ATA_diagonals * vec,
ATb=ATb,
ATA_diagonals=ATA_diagonals,
prev_linear_state=state.cg_state,
)
elif self.linear_solver == "cholmod":
Expand All @@ -312,7 +318,7 @@ def step(
A_dense = A_blocksparse.to_dense()
ATA = A_dense.T @ A_dense
diag_idx = jnp.arange(ATA.shape[0])
ATA = ATA.at[diag_idx, diag_idx].add(state.lambd)
ATA = ATA.at[diag_idx, diag_idx].add(state.lambd * ATA_diagonals)
cho_factor = jax.scipy.linalg.cho_factor(ATA)
local_delta = jax.scipy.linalg.cho_solve(cho_factor, ATb)
else:
Expand Down Expand Up @@ -428,7 +434,13 @@ def step(
class TrustRegionConfig:
# Levenberg-Marquardt parameters.
lambda_initial: float = 5e-4
"""Initial damping factor. Only used for Levenberg-Marquardt."""
"""Initial damping factor. Only used for Levenberg-Marquardt.

*Unfortunate note:* the damping behavior of LM currently differs depending
on which linear solver you use:

- For CHOLMOD, we apply damping naively to the normal equations by solving `(ATA + lambda * I) x = ATb`.
- For all other solvers, we apply scale-invariant damping by solving `(ATA + lambda * diag(ATA)) x = ATb`.
"""
lambda_factor: float = 2.0
"""Factor to increase or decrease damping. Only used for Levenberg-Marquardt."""
lambda_min: float = 1e-5
Expand Down
Loading