Skip to content

Commit

Permalink
Fix Levenberg-Marquardt termination edge cases
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Oct 14, 2024
1 parent d2b7b7b commit 0ed83aa
Showing 1 changed file with 32 additions and 21 deletions.
53 changes: 32 additions & 21 deletions src/jaxls/_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,23 +345,12 @@ def step(
if linear_state is not None:
state_next.cg_state = linear_state

# Compute termination criteria.
state_next.termination_criteria, state_next.termination_deltas = (
self.termination._check_convergence(
state,
cost_updated=proposed_cost,
tangent=local_delta,
tangent_ordering=graph.tangent_ordering,
ATb=ATb,
)
)

# Always accept Gauss-Newton steps.
if self.trust_region is None:
state_next.vals = vals
state_next.residual_vector = proposed_residual_vector
state_next.cost = proposed_cost

accept_flag = None
# For Levenberg-Marquardt, we need to evaluate the step quality.
else:
step_quality = (proposed_cost - state.cost) / (
Expand All @@ -373,11 +362,6 @@ def step(
)
accept_flag = step_quality >= self.trust_region.step_quality_min

# Should not terminate if we're rejecting step.
state_next.termination_criteria = jnp.logical_and(
accept_flag, state_next.termination_criteria
)

state_next.vals = jax.tree_map(
lambda proposed, current: jnp.where(accept_flag, proposed, current),
vals,
Expand All @@ -401,6 +385,18 @@ def step(
),
)

# Compute termination criteria.
state_next.termination_criteria, state_next.termination_deltas = (
self.termination._check_convergence(
state,
cost_updated=proposed_cost,
tangent=local_delta,
tangent_ordering=graph.tangent_ordering,
ATb=ATb,
accept_flag=accept_flag,
)
)

state_next.iterations += 1
return state_next

Expand Down Expand Up @@ -441,12 +437,14 @@ def _check_convergence(
tangent: jax.Array,
tangent_ordering: VarTypeOrdering,
ATb: jax.Array,
accept_flag: jax.Array | None = None,
) -> tuple[jax.Array, jax.Array]:
"""Check for convergence!"""

# Cost tolerance
cost_delta = jnp.abs(cost_updated - state_prev.cost) / state_prev.cost
converged_cost = cost_delta < self.cost_tolerance
cost_absdelta = jnp.abs(cost_updated - state_prev.cost)
cost_reldelta = cost_absdelta / state_prev.cost
converged_cost = cost_reldelta < self.cost_tolerance

# Gradient tolerance
flat_vals = jax.flatten_util.ravel_pytree(state_prev.vals)[0]
Expand All @@ -468,11 +466,24 @@ def _check_convergence(
)
converged_parameters = param_delta < self.parameter_tolerance

return jnp.array(
# Check termination flags. We'll terminate if any of the conditions are met.
term_flags = jnp.array(
[
converged_cost,
converged_gradient,
converged_parameters,
state_prev.iterations >= (self.max_iterations - 1),
]
), jnp.array([cost_delta, gradient_mag, param_delta])
)

# Only consider the first three conditions if steps are accepted.
if accept_flag is not None:
term_flags = term_flags.at[:3].set(
jnp.logical_and(
term_flags[:3],
# We ignore accept_flag if the cost _actually_ didn't change at all.
jnp.logical_or(accept_flag, cost_absdelta == 0.0),
)
)

return term_flags, jnp.array([cost_reldelta, gradient_mag, param_delta])

0 comments on commit 0ed83aa

Please sign in to comment.