Fix assertion in block preconditioner, tune termination criteria
brentyi committed Oct 12, 2024
1 parent 8576a1e commit 6a5c110
Showing 4 changed files with 75 additions and 56 deletions.
2 changes: 1 addition & 1 deletion examples/pose_graph_simple.py
@@ -48,7 +48,7 @@
graph = jaxls.FactorGraph.make(factors, vars)

# Solve the optimization problem.
solution = graph.solve()
solution = graph.solve(linear_solver=jaxls.ConjugateGradientLinearSolver())
print("All solutions", solution)
print("Pose 0", solution[vars[0]])
print("Pose 1", solution[vars[1]])
4 changes: 2 additions & 2 deletions src/jaxls/_preconditioning.py
@@ -85,8 +85,8 @@ def make_block_jacobi_precoditioner(
assert gram_blocks.shape == (
num_factors,
num_vars,
factor.residual_dim,
factor.residual_dim,
var_type.tangent_dim,
var_type.tangent_dim,
)

start_concat_col = end_concat_col
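
The assertion fix above reflects that block-Jacobi preconditioner blocks are Gram matrices J^T J of each factor's Jacobian with respect to a single variable, so they are square in that variable's tangent dimension rather than the factor's residual dimension. A minimal standalone sketch of the shape reasoning (plain JAX, not jaxls internals; all sizes are made up, and the real assertion also carries a num_vars axis):

    import jax.numpy as jnp

    # Illustrative sizes for one stacked factor type touching one variable type.
    num_factors, residual_dim, tangent_dim = 8, 3, 6

    # Per-factor Jacobian block of the residual w.r.t. one variable's tangent.
    J = jnp.zeros((num_factors, residual_dim, tangent_dim))

    # Gram (J^T J) blocks used by a block-Jacobi preconditioner: one square
    # tangent_dim x tangent_dim block per factor, matching the corrected shape.
    gram_blocks = jnp.einsum("fri,frj->fij", J, J)
    assert gram_blocks.shape == (num_factors, tangent_dim, tangent_dim)
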
123 changes: 71 additions & 52 deletions src/jaxls/_solvers.py
@@ -172,10 +172,12 @@ class NonlinearSolverState:
vals: VarValues
cost: float | jax.Array
residual_vector: jax.Array
done: bool | jax.Array
termination_criteria: jax.Array
termination_deltas: jax.Array
lambd: float | jax.Array

linear_state: ConjugateGradientState | None
# Conjugate gradient state. Not used for other solvers.
cg_state: ConjugateGradientState | None


@jdc.pytree_dataclass
@@ -196,11 +198,12 @@ def solve(self, graph: FactorGraph, initial_vals: VarValues) -> VarValues:
vals=vals,
cost=jnp.sum(residual_vector**2),
residual_vector=residual_vector,
done=False,
termination_criteria=jnp.array([False, False, False, False]),
termination_deltas=jnp.zeros(3),
lambd=self.trust_region.lambda_initial
if self.trust_region is not None
else 0.0,
linear_state=None
cg_state=None
if isinstance(self.linear_solver, CholmodLinearSolver)
else ConjugateGradientState(
ATb_norm_prev=0.0, eta=self.linear_solver.tolerance_max
@@ -209,15 +212,19 @@ def solve(self, graph: FactorGraph, initial_vals: VarValues) -> VarValues:

# Optimization.
state = jax.lax.while_loop(
cond_fun=lambda state: jnp.logical_not(state.done),
cond_fun=lambda state: jnp.logical_not(jnp.any(state.termination_criteria)),
body_fun=functools.partial(self.step, graph),
init_val=state,
)
if self.verbose:
jax_log(
"Terminated @ iteration #{i}: cost={cost:.4f}",
"Terminated @ iteration #{i}: cost={cost:.4f} criteria={criteria}, term_deltas={cost_delta:.1e},{grad_mag:.1e},{param_delta:.1e}",
i=state.iterations,
cost=state.cost,
criteria=state.termination_criteria.astype(jnp.int32),
cost_delta=state.termination_deltas[0],
grad_mag=state.termination_deltas[1],
param_delta=state.termination_deltas[2],
)
return state.vals

@@ -248,15 +255,15 @@ def step(

linear_state = None
if isinstance(self.linear_solver, ConjugateGradientLinearSolver):
assert isinstance(state.linear_state, ConjugateGradientState)
assert isinstance(state.cg_state, ConjugateGradientState)
local_delta, linear_state = self.linear_solver._solve(
graph,
A_blocksparse,
# We could also use (lambd * ATA_diagonals * vec) for
# scale-invariant damping. But this is hard to match with CHOLMOD.
lambda vec: AT_multiply(A_multiply(vec)) + state.lambd * vec,
ATb=ATb,
prev_linear_state=state.linear_state,
prev_linear_state=state.cg_state,
)
elif isinstance(self.linear_solver, CholmodLinearSolver):
A_csr = SparseCsrMatrix(jac_values, graph.jac_coords_csr)
@@ -266,13 +273,29 @@

vals = state.vals._retract(local_delta, graph.tangent_ordering)
if self.verbose:
jax_log(
" step #{i}: cost={cost:.4f} lambd={lambd:.4f}",
i=state.iterations,
cost=state.cost,
lambd=state.lambd,
ordered=True,
)
if state.cg_state is None:
jax_log(
" step #{i}: cost={cost:.4f} lambd={lambd:.4f} term_deltas={cost_delta:.1e},{grad_mag:.1e},{param_delta:.1e}",
i=state.iterations,
cost=state.cost,
lambd=state.lambd,
cost_delta=state.termination_deltas[0],
grad_mag=state.termination_deltas[1],
param_delta=state.termination_deltas[2],
ordered=True,
)
else:
jax_log(
" step #{i}: cost={cost:.4f} lambd={lambd:.4f} term_deltas={cost_delta:.1e},{grad_mag:.1e},{param_delta:.1e} inexact_tol={inexact_tol:.1e}",
i=state.iterations,
cost=state.cost,
lambd=state.lambd,
cost_delta=state.termination_deltas[0],
grad_mag=state.termination_deltas[1],
param_delta=state.termination_deltas[2],
inexact_tol=state.cg_state.eta,
ordered=True,
)
residual_index = 0
for f, count in zip(graph.stacked_factors, graph.factor_counts):
stacked_dim = count * f.residual_dim
@@ -296,7 +319,7 @@ def step(

# Update ATb_norm for Eisenstat-Walker criterion.
if linear_state is not None:
state_next.linear_state = linear_state
state_next.cg_state = linear_state

# Always accept Gauss-Newton steps.
if self.trust_region is None:
@@ -339,12 +362,14 @@ def step(
)

state_next.iterations += 1
state_next.done = self.termination._check_convergence(
state,
cost_updated=state_next.cost,
tangent=local_delta,
tangent_ordering=graph.tangent_ordering,
ATb=ATb,
state_next.termination_criteria, state_next.termination_deltas = (
self.termination._check_convergence(
state,
cost_updated=state_next.cost,
tangent=local_delta,
tangent_ordering=graph.tangent_ordering,
ATb=ATb,
)
)
return state_next
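
Annotation: the `inexact_tol` value logged in the verbose branch above is the forcing term `eta` carried in `ConjugateGradientState`, and the comment near the `ATb_norm` update ties it to the Eisenstat-Walker criterion. For reference only, this is roughly what the textbook Eisenstat-Walker "choice 2" update looks like; the actual jaxls update rule is not shown in this hunk, and the constants below are conventional defaults rather than values from this repository:

    import jax
    import jax.numpy as jnp

    def eisenstat_walker_eta(
        eta_prev: jax.Array,
        grad_norm: jax.Array,       # e.g. ||A^T b|| at the current iterate
        grad_norm_prev: jax.Array,  # the same norm at the previous iterate
        eta_max: float = 1e-2,
        gamma: float = 0.9,
        alpha: float = 2.0,
    ) -> jax.Array:
        """Textbook Eisenstat-Walker "choice 2" forcing term (illustrative only).

        The inner CG solve runs until its relative residual drops below eta;
        eta tightens as the outer gradient norm shrinks, so early steps are
        solved loosely and later steps more accurately.
        """
        eta = gamma * (grad_norm / grad_norm_prev) ** alpha
        # Simplified safeguard: keep eta from dropping too abruptly.
        eta = jnp.maximum(eta, gamma * eta_prev**alpha)
        return jnp.minimum(eta, eta_max)
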

@@ -368,14 +393,14 @@ class TrustRegionConfig:
class TerminationConfig:
# Termination criteria.
max_iterations: int = 100
cost_tolerance: float = 1e-6
cost_tolerance: float = 1e-5
"""We terminate if `|cost change| / cost < cost_tolerance`."""
gradient_tolerance: float = 1e-7
gradient_tolerance: float = 1e-4
"""We terminate if `norm_inf(x - rplus(x, linear delta)) < gradient_tolerance`."""
gradient_tolerance_start_step: int = 10
"""When to start checking the gradient tolerance condition. Helps solve precision
issues caused by inexact Newton steps."""
parameter_tolerance: float = 1e-7
parameter_tolerance: float = 1e-6
"""We terminate if `norm_2(linear delta) < (norm2(x) + parameter_tolerance) * parameter_tolerance`."""

def _check_convergence(
@@ -385,44 +410,38 @@ def _check_convergence(
tangent: jax.Array,
tangent_ordering: VarTypeOrdering,
ATb: jax.Array,
) -> jax.Array:
) -> tuple[jax.Array, jax.Array]:
"""Check for convergence!"""

# Cost tolerance
converged_cost = (
jnp.abs(cost_updated - state_prev.cost) / state_prev.cost
< self.cost_tolerance
)
cost_delta = jnp.abs(cost_updated - state_prev.cost) / state_prev.cost
converged_cost = cost_delta < self.cost_tolerance

# Gradient tolerance
flat_vals = jax.flatten_util.ravel_pytree(state_prev.vals)[0]
gradient_mag = jnp.max(
flat_vals
- jax.flatten_util.ravel_pytree(
state_prev.vals._retract(ATb, tangent_ordering)
)[0]
)
converged_gradient = jnp.where(
state_prev.iterations >= self.gradient_tolerance_start_step,
jnp.max(
flat_vals
- jax.flatten_util.ravel_pytree(
state_prev.vals._retract(ATb, tangent_ordering)
)[0]
)
< self.gradient_tolerance,
gradient_mag < self.gradient_tolerance,
False,
)

# Parameter tolerance
converged_parameters = (
jnp.linalg.norm(jnp.abs(tangent))
< (jnp.linalg.norm(flat_vals) + self.parameter_tolerance)
* self.parameter_tolerance
param_delta = jnp.linalg.norm(jnp.abs(tangent)) / (
jnp.linalg.norm(flat_vals) + self.parameter_tolerance
)
converged_parameters = param_delta < self.parameter_tolerance

return jnp.any(
jnp.array(
[
converged_cost,
converged_gradient,
converged_parameters,
state_prev.iterations >= (self.max_iterations - 1),
]
),
axis=0,
)
return jnp.array(
[
converged_cost,
converged_gradient,
converged_parameters,
state_prev.iterations >= (self.max_iterations - 1),
]
), jnp.array([cost_delta, gradient_mag, param_delta])
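
In summary, `_check_convergence` now returns a boolean vector of criteria plus the raw deltas that the new logging prints. A condensed, standalone paraphrase of the three retuned checks (hypothetical inputs; the real method also gates the gradient check on `gradient_tolerance_start_step` and appends the max-iteration criterion):

    import jax.numpy as jnp

    def check_convergence_sketch(
        cost_prev, cost_updated, flat_vals, retracted_by_ATb, tangent,
        cost_tolerance=1e-5, gradient_tolerance=1e-4, parameter_tolerance=1e-6,
    ):
        # Relative cost change between iterations.
        cost_delta = jnp.abs(cost_updated - cost_prev) / cost_prev
        # Largest parameter change induced by retracting along A^T b (as in the hunk).
        gradient_mag = jnp.max(flat_vals - retracted_by_ATb)
        # Local step size relative to the current parameter norm.
        param_delta = jnp.linalg.norm(tangent) / (
            jnp.linalg.norm(flat_vals) + parameter_tolerance
        )
        criteria = jnp.array([
            cost_delta < cost_tolerance,
            gradient_mag < gradient_tolerance,
            param_delta < parameter_tolerance,
        ])
        deltas = jnp.array([cost_delta, gradient_mag, param_delta])
        return criteria, deltas
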
2 changes: 1 addition & 1 deletion src/jaxls/utils.py
@@ -20,7 +20,7 @@ def stopwatch(label: str = "unlabeled block") -> Generator[None, None, None]:


def _log(fmt: str, *args, **kwargs) -> None:
logger.bind(function="jakljk").info(fmt, *args, **kwargs)
logger.bind(function="log").info(fmt, *args, **kwargs)


def jax_log(fmt: str, *args, **kwargs) -> None:
