Skip to content

Commit

Permalink
Adding step rejection feature
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 695090781
  • Loading branch information
james-martens authored and KfacJaxDev committed Nov 11, 2024
1 parent 8a610fc commit 486c2cd
Showing 1 changed file with 45 additions and 4 deletions.
49 changes: 45 additions & 4 deletions kfac_jax/_src/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def __init__(
damping_upper_threshold: Numeric = 0.75,
always_use_exact_qmodel_for_damping_adjustment: bool = False,
precon_damping_mult: Numeric = 1.0,
use_step_rejection: bool = False,
reject_damping_increase_factor: float = 1.0,
norm_constraint: Numeric | None = None,
num_burnin_steps: int = 10,
estimation_mode: str | None = None,
Expand Down Expand Up @@ -283,6 +285,11 @@ def __init__(
precon_damping_mult: Scalar. Multiplies the damping used in the
preconditioner (vs the exact quadratic model) by this value.
(Default: 1.0)
use_step_rejection: Whether or not to reject the step whenever the loss
on the current batch goes up after the update. This option offers
robustness at the cost of doing more work per step (unless adaptive
damping with Levenberg-Marquardt is used). (Default: ``False``)
      reject_damping_increase_factor: Scalar. Factor by which to increase the
        damping whenever a step is rejected. Only used when
        ``use_step_rejection`` is ``True``. (Default: 1.0)
norm_constraint: Scalar. If specified, the update is scaled down so that
its approximate squared Fisher norm ``v^T F v`` is at most the specified
value. (Note that here ``F`` is the approximate curvature matrix, not
Expand Down Expand Up @@ -448,6 +455,9 @@ def __init__(
always_use_exact_qmodel_for_damping_adjustment)
self._precon_damping_mult = precon_damping_mult

self._use_step_rejection = use_step_rejection
self._reject_damping_increase_factor = reject_damping_increase_factor

self._norm_constraint = norm_constraint
self._num_burnin_steps = num_burnin_steps
self._curvature_ema = curvature_ema
Expand Down Expand Up @@ -1163,12 +1173,11 @@ def _step(
damping=damping,
func_args=func_args)

# Compute delta and update velocities
# Compute the parameter update (delta)
delta = self.weighted_sum_of_objects(vectors, coefficients)
state.velocities = delta

# Update parameters
params = jax.tree_util.tree_map(jnp.add, params, delta)
new_params = jax.tree_util.tree_map(jnp.add, params, delta)

# Optionally compute the reduction ratio and update the damping
if self._use_adaptive_damping:
Expand All @@ -1179,12 +1188,41 @@ def _step(
lambda args: (args[0], self._invalid_metric_value,
self._invalid_metric_value),
operand=(state.damping, loss, quad_model_change,
(params,) + func_args[1:])
(new_params,) + func_args[1:])
)

new_loss_is_valid = self.should_update_damping(state)

else:
# If not adjusting the damping we don't compute these here and just set
# them to self._invalid_metric_value.
new_loss, rho = self._invalid_metric_value, self._invalid_metric_value
new_loss_is_valid = False

if self._use_step_rejection:

new_loss = lax.cond(
new_loss_is_valid, # static eval when possible?
lambda: self.compute_loss_value((new_params,) + func_args[1:],
state=state),
lambda: new_loss,
)

# Sync (possibly redundant)
new_loss = utils.pmean_if_pmap(new_loss, self.pmap_axis_name)

reject_step = jnp.logical_or(jnp.isnan(new_loss), new_loss > loss)

params, state.velocities, state.damping = lax.cond(
reject_step,
lambda: (params, state.velocities, state.damping),
lambda: (new_params, delta,
self._reject_damping_increase_factor * state.damping))

else:
# stop the linter from complaining about uninitialized variable
reject_step = False
params, state.velocities = new_params, delta

# Compute per-device and total batch size
batch_size = self._batch_size_extractor(func_args[-1])
Expand Down Expand Up @@ -1217,6 +1255,9 @@ def _step(
scaled_grad_norm_sq=scaled_grad_norm_sq,
)

if self._use_step_rejection:
stats["step_rejected"] = reject_step

if aux is not None:
aux = utils.pmean_if_pmap(aux, self.pmap_axis_name)
stats["aux"] = aux
Expand Down

0 comments on commit 486c2cd

Please sign in to comment.