Implement Eisenstat-Walker
brentyi committed Oct 12, 2024
1 parent 2c41aef commit ea8e2da
Showing 2 changed files with 65 additions and 24 deletions.
README.md: 9 changes (5 additions, 4 deletions)
@@ -11,15 +11,16 @@ problems. We accelerate optimization by analyzing the structure of graphs:
repeated factor and variable types are vectorized, and the sparsity of adjacency
in the graph is translated into sparse matrix operations.

-Features:
+Currently supported:

- Automatic sparse Jacobians.
- Optimization on manifolds; SO(2), SO(3), SE(2), and SE(3) implementations
included.
- Nonlinear solvers: Levenberg-Marquardt and Gauss-Newton.
-- Linear solvers: both direct (sparse Cholesky via CHOLMOD, on CPU) and
-  iterative (Conjugate Gradient).
-- Preconditioning: block and point Jacobi.
+- Direct linear solves via sparse Cholesky / CHOLMOD, on CPU.
+- Iterative linear solves via Conjugate Gradient.
+- Preconditioning: block and point Jacobi.
+- Inexact Newton via Eisenstat-Walker.
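
For context, the Eisenstat-Walker criterion named in the new feature list picks the
CG tolerance for each outer iteration from the relative change in the normal-equation
right-hand side. A minimal sketch of the update rule, with illustrative names and the
default constants introduced in `src/jaxls/_solvers.py` below:

```python
import jax.numpy as jnp

gamma, alpha = 0.9, 2.0        # eisenstat_walker_gamma / eisenstat_walker_alpha
eta_min, eta_max = 1e-7, 1e-2  # tolerance_min / tolerance_max

def forcing_term(ATb_norm, ATb_norm_prev, eta_prev):
    """Eisenstat-Walker forcing term: tighten the CG tolerance as the residual shrinks."""
    eta = gamma * (ATb_norm / (ATb_norm_prev + 1e-7)) ** alpha
    # Clamp to [eta_min, eta_max] and keep the sequence non-increasing.
    return jnp.maximum(eta_min, jnp.minimum(jnp.minimum(eta, eta_max), eta_prev))
```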

Use cases are primarily in least squares problems that are inherently (1)
sparse and (2) inefficient to solve with gradient-based methods. These are
src/jaxls/_solvers.py: 80 changes (60 additions, 20 deletions)
@@ -79,17 +79,36 @@ def _solve_on_host(
return factor.solve_A(ATb)


@jdc.pytree_dataclass
class ConjugateGradientState:
"""State used for Eisenstat-Walker criterion in ConjugateGradientLinearSolver."""

ATb_norm_prev: float | jax.Array
"""Previous norm of ATb."""
eta: float | jax.Array
"""Current tolerance."""


@jdc.pytree_dataclass
class ConjugateGradientLinearSolver:
"""Iterative solver for sparse linear systems. Can run on CPU or GPU."""
"""Iterative solver for sparse linear systems. Can run on CPU or GPU.
tolerance: float = 1e-7
inexact_step_eta: float | None = None # 1e-2
"""Forcing sequence parameter for inexact Newton steps. CG tolerance is set to
`eta / iteration #`.
For inexact steps, we use the Eisenstat-Walker criterion. For reference,
see "Choosing the Forcing Terms in an Inexact Newton Method", Eisenstat &
Walker, 1996."
"""

For reference, see AN INEXACT LEVENBERG-MARQUARDT METHOD FOR LARGE SPARSE NONLINEAR
LEAST SQUARES, Wright & Holt 1983."""
tolerance_min: float = 1e-7
tolerance_max: float = 1e-2

eisenstat_walker_gamma: float = 0.9
"""Eisenstat-Walker criterion gamma term. Controls how quickly the tolerance
decreases. Typical values range from 0.5 to 0.9. Higher values lead to more
aggressive tolerance reduction."""
eisenstat_walker_alpha: float = 2.0
""" Eisenstat-Walker criterion alpha term. Determines rate at which the
tolerance changes based on residual reduction. Typical values are 1.5 or
2.0. Higher values make the tolerance more sensitive to residual changes."""

preconditioner: jdc.Static[Literal["block-jacobi", "point-jacobi"] | None] = (
"block-jacobi"
@@ -102,8 +121,8 @@ def _solve(
A_blocksparse: BlockRowSparseMatrix,
ATA_multiply: Callable[[jax.Array], jax.Array],
ATb: jax.Array,
-iterations: int | jax.Array,
-) -> jax.Array:
+prev_linear_state: ConjugateGradientState,
+) -> tuple[jax.Array, ConjugateGradientState]:
assert len(ATb.shape) == 1, "ATb should be 1D!"

# Preconditioning setup.
@@ -116,6 +135,18 @@ def _solve(
else:
assert_never(self.preconditioner)

# Calculate tolerance using Eisenstat-Walker criterion.
ATb_norm = jnp.linalg.norm(ATb)
current_eta = jnp.minimum(
self.eisenstat_walker_gamma
* (ATb_norm / (prev_linear_state.ATb_norm_prev + 1e-7))
** self.eisenstat_walker_alpha,
self.tolerance_max,
)
current_eta = jnp.maximum(
self.tolerance_min, jnp.minimum(current_eta, prev_linear_state.eta)
)

# Solve with conjugate gradient.
initial_x = jnp.zeros(ATb.shape)
solution_values, _ = jax.scipy.sparse.linalg.cg(
@@ -124,16 +155,12 @@
x0=initial_x,
# https://en.wikipedia.org/wiki/Conjugate_gradient_method#Convergence_properties
maxiter=len(initial_x),
-tol=cast(
-float,
-jnp.maximum(self.tolerance, self.inexact_step_eta / (iterations + 1)),
-)
-if self.inexact_step_eta is not None
-else self.tolerance,
+tol=cast(float, current_eta),
M=preconditioner,
# M=lambda x: x / ATA_diagonals, # Jacobi preconditioner.
)
-return solution_values
+return solution_values, ConjugateGradientState(
+ATb_norm_prev=ATb_norm, eta=current_eta
+)
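
To make the forcing sequence concrete: a standalone walk-through of how `eta` evolves
across two outer iterations, using the same arithmetic as `_solve` above with the
default parameters (the residual norms 4.0 and 0.2 are made up for illustration):

```python
import jax.numpy as jnp

gamma, alpha = 0.9, 2.0
eta_min, eta_max = 1e-7, 1e-2

# Iteration 0: `solve()` below seeds ConjugateGradientState with ATb_norm_prev=0.0
# and eta=tolerance_max, so the ratio term is huge and eta stays at the loosest
# value, 1e-2.
eta = jnp.minimum(gamma * (4.0 / (0.0 + 1e-7)) ** alpha, eta_max)
eta0 = jnp.maximum(eta_min, jnp.minimum(eta, eta_max))

# Iteration 1: ||ATb|| dropped from 4.0 to 0.2, so the tolerance tightens to
# roughly 0.9 * 0.05**2 ≈ 2.3e-3 for the next CG solve.
eta = jnp.minimum(gamma * (0.2 / (4.0 + 1e-7)) ** alpha, eta_max)
eta1 = jnp.maximum(eta_min, jnp.minimum(eta, eta0))
```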


# Nonlinear solvers.
@@ -148,6 +175,8 @@ class NonlinearSolverState:
done: bool | jax.Array
lambd: float | jax.Array

linear_state: ConjugateGradientState | None


@jdc.pytree_dataclass
class NonlinearSolver:
@@ -171,6 +200,11 @@ def solve(self, graph: FactorGraph, initial_vals: VarValues) -> VarValues:
lambd=self.trust_region.lambda_initial
if self.trust_region is not None
else 0.0,
linear_state=None
if isinstance(self.linear_solver, CholmodLinearSolver)
else ConjugateGradientState(
ATb_norm_prev=0.0, eta=self.linear_solver.tolerance_max
),
)

# Optimization.
@@ -212,15 +246,17 @@ def step(
# Compute right-hand side of normal equation.
ATb = -AT_multiply(state.residual_vector)

linear_state = None
if isinstance(self.linear_solver, ConjugateGradientLinearSolver):
-local_delta = self.linear_solver._solve(
+assert isinstance(state.linear_state, ConjugateGradientState)
+local_delta, linear_state = self.linear_solver._solve(
graph,
A_blocksparse,
# We could also use (lambd * ATA_diagonals * vec) for
# scale-invariant damping. But this is hard to match with CHOLMOD.
lambda vec: AT_multiply(A_multiply(vec)) + state.lambd * vec,
ATb=ATb,
-iterations=state.iterations,
+prev_linear_state=state.linear_state,
)
elif isinstance(self.linear_solver, CholmodLinearSolver):
A_csr = SparseCsrMatrix(jac_values, graph.jac_coords_csr)
@@ -258,6 +294,10 @@ def step(
proposed_residual_vector = graph.compute_residual_vector(vals)
proposed_cost = jnp.sum(proposed_residual_vector**2)

# Update ATb_norm for Eisenstat-Walker criterion.
if linear_state is not None:
state_next.linear_state = linear_state

# Always accept Gauss-Newton steps.
if self.trust_region is None:
state_next.vals = vals
@@ -327,7 +367,7 @@ class TrustRegionConfig:
@jdc.pytree_dataclass
class TerminationConfig:
# Termination criteria.
-max_iterations: int = 10  # 100
+max_iterations: int = 100
cost_tolerance: float = 1e-6
"""We terminate if `|cost change| / cost < cost_tolerance`."""
gradient_tolerance: float = 1e-7
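
A quick numeric illustration of the cost-tolerance check described in the docstring
above (standalone sketch, not the library's actual termination code, which is outside
this diff):

```python
import jax.numpy as jnp

cost_tolerance = 1e-6
cost_prev, cost = 10.0, 9.999995  # cost barely changed between iterations
converged = jnp.abs(cost_prev - cost) / cost < cost_tolerance
print(bool(converged))  # True: |cost change| / cost ≈ 5e-7 < 1e-6
```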
