From ea8e2da204521d96aca1258a2cb164a91fb709de Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sat, 12 Oct 2024 00:05:56 -0700 Subject: [PATCH] Implement Eisenstat-Walker --- README.md | 9 ++--- src/jaxls/_solvers.py | 80 ++++++++++++++++++++++++++++++++----------- 2 files changed, 65 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index f912162..4c5eef3 100644 --- a/README.md +++ b/README.md @@ -11,15 +11,16 @@ problems. We accelerate optimization by analyzing the structure of graphs: repeated factor and variable types are vectorized, and the sparsity of adjacency in the graph is translated into sparse matrix operations. -Features: +Currently supported: - Automatic sparse Jacobians. - Optimization on manifolds; SO(2), SO(3), SE(2), and SE(3) implementations included. - Nonlinear solvers: Levenberg-Marquardt and Gauss-Newton. -- Linear solvers: both direct (sparse Cholesky via CHOLMOD, on CPU) and - iterative (Conjugate Gradient). -- Preconditioning: block and point Jacobi. +- Direct linear solves via sparse Cholesky / CHOLMOD, on CPU. +- Iterative linear solves via Conjugate Gradient. + - Preconditioning: block and point Jacobi. + - Inexact Newton via Eisenstat-Walker. Use cases are primarily in least squares problems that are inherently (1) sparse and (2) inefficient to solve with gradient-based methods. These are diff --git a/src/jaxls/_solvers.py b/src/jaxls/_solvers.py index b29a952..8716531 100644 --- a/src/jaxls/_solvers.py +++ b/src/jaxls/_solvers.py @@ -79,17 +79,36 @@ def _solve_on_host( return factor.solve_A(ATb) +@jdc.pytree_dataclass +class ConjugateGradientState: + """State used for Eisenstat-Walker criterion in ConjugateGradientLinearSolver.""" + + ATb_norm_prev: float | jax.Array + """Previous norm of ATb.""" + eta: float | jax.Array + """Current tolerance.""" + + @jdc.pytree_dataclass class ConjugateGradientLinearSolver: - """Iterative solver for sparse linear systems. Can run on CPU or GPU.""" + """Iterative solver for sparse linear systems. Can run on CPU or GPU. - tolerance: float = 1e-7 - inexact_step_eta: float | None = None # 1e-2 - """Forcing sequence parameter for inexact Newton steps. CG tolerance is set to - `eta / iteration #`. + For inexact steps, we use the Eisenstat-Walker criterion. For reference, + see "Choosing the Forcing Terms in an Inexact Newton Method", Eisenstat & + Walker, 1996." + """ - For reference, see AN INEXACT LEVENBERG-MARQUARDT METHOD FOR LARGE SPARSE NONLINEAR - LEAST SQUARES, Wright & Holt 1983.""" + tolerance_min: float = 1e-7 + tolerance_max: float = 1e-2 + + eisenstat_walker_gamma: float = 0.9 + """Eisenstat-Walker criterion gamma term. Controls how quickly the tolerance + decreases. Typical values range from 0.5 to 0.9. Higher values lead to more + aggressive tolerance reduction.""" + eisenstat_walker_alpha: float = 2.0 + """ Eisenstat-Walker criterion alpha term. Determines rate at which the + tolerance changes based on residual reduction. Typical values are 1.5 or + 2.0. Higher values make the tolerance more sensitive to residual changes.""" preconditioner: jdc.Static[Literal["block-jacobi", "point-jacobi"] | None] = ( "block-jacobi" @@ -102,8 +121,8 @@ def _solve( A_blocksparse: BlockRowSparseMatrix, ATA_multiply: Callable[[jax.Array], jax.Array], ATb: jax.Array, - iterations: int | jax.Array, - ) -> jax.Array: + prev_linear_state: ConjugateGradientState, + ) -> tuple[jax.Array, ConjugateGradientState]: assert len(ATb.shape) == 1, "ATb should be 1D!" # Preconditioning setup. @@ -116,6 +135,18 @@ def _solve( else: assert_never(self.preconditioner) + # Calculate tolerance using Eisenstat-Walker criterion. + ATb_norm = jnp.linalg.norm(ATb) + current_eta = jnp.minimum( + self.eisenstat_walker_gamma + * (ATb_norm / (prev_linear_state.ATb_norm_prev + 1e-7)) + ** self.eisenstat_walker_alpha, + self.tolerance_max, + ) + current_eta = jnp.maximum( + self.tolerance_min, jnp.minimum(current_eta, prev_linear_state.eta) + ) + # Solve with conjugate gradient. initial_x = jnp.zeros(ATb.shape) solution_values, _ = jax.scipy.sparse.linalg.cg( @@ -124,16 +155,12 @@ def _solve( x0=initial_x, # https://en.wikipedia.org/wiki/Conjugate_gradient_method#Convergence_properties maxiter=len(initial_x), - tol=cast( - float, - jnp.maximum(self.tolerance, self.inexact_step_eta / (iterations + 1)), - ) - if self.inexact_step_eta is not None - else self.tolerance, + tol=cast(float, current_eta), M=preconditioner, - # M=lambda x: x / ATA_diagonals, # Jacobi preconditioner. ) - return solution_values + return solution_values, ConjugateGradientState( + ATb_norm_prev=ATb_norm, eta=current_eta + ) # Nonlinear solvers. @@ -148,6 +175,8 @@ class NonlinearSolverState: done: bool | jax.Array lambd: float | jax.Array + linear_state: ConjugateGradientState | None + @jdc.pytree_dataclass class NonlinearSolver: @@ -171,6 +200,11 @@ def solve(self, graph: FactorGraph, initial_vals: VarValues) -> VarValues: lambd=self.trust_region.lambda_initial if self.trust_region is not None else 0.0, + linear_state=None + if isinstance(self.linear_solver, CholmodLinearSolver) + else ConjugateGradientState( + ATb_norm_prev=0.0, eta=self.linear_solver.tolerance_max + ), ) # Optimization. @@ -212,15 +246,17 @@ def step( # Compute right-hand side of normal equation. ATb = -AT_multiply(state.residual_vector) + linear_state = None if isinstance(self.linear_solver, ConjugateGradientLinearSolver): - local_delta = self.linear_solver._solve( + assert isinstance(state.linear_state, ConjugateGradientState) + local_delta, linear_state = self.linear_solver._solve( graph, A_blocksparse, # We could also use (lambd * ATA_diagonals * vec) for # scale-invariant damping. But this is hard to match with CHOLMOD. lambda vec: AT_multiply(A_multiply(vec)) + state.lambd * vec, ATb=ATb, - iterations=state.iterations, + prev_linear_state=state.linear_state, ) elif isinstance(self.linear_solver, CholmodLinearSolver): A_csr = SparseCsrMatrix(jac_values, graph.jac_coords_csr) @@ -258,6 +294,10 @@ def step( proposed_residual_vector = graph.compute_residual_vector(vals) proposed_cost = jnp.sum(proposed_residual_vector**2) + # Update ATb_norm for Eisenstat-Walker criterion. + if linear_state is not None: + state_next.linear_state = linear_state + # Always accept Gauss-Newton steps. if self.trust_region is None: state_next.vals = vals @@ -327,7 +367,7 @@ class TrustRegionConfig: @jdc.pytree_dataclass class TerminationConfig: # Termination criteria. - max_iterations: int = 10 # 100 + max_iterations: int = 100 cost_tolerance: float = 1e-6 """We terminate if `|cost change| / cost < cost_tolerance`.""" gradient_tolerance: float = 1e-7