Skip to content

Commit

Permalink
add grad scaling by dual norm
Browse files Browse the repository at this point in the history
  • Loading branch information
leloykun committed Dec 31, 2024
1 parent ca6eeda commit ff3279e
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions optax/contrib/_muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def scale_by_muon(
] = (3.4445, -4.7750, 2.0315),
newton_schulz_steps: Optional[int] = 5,
mumentum: float = 0.95,
eps: float = 1e-8,
mu_dtype: Optional[chex.ArrayDType] = None,
*,
nesterov: bool = True,
Expand All @@ -56,6 +57,7 @@ def scale_by_muon(
newton_schulz_steps: Number of Newton-Schulz iterations.
mumentum: Exponential decay rate to track the first moment of past
gradients.
eps: Term added to the denominator to improve numerical stability.
mu_dtype: Data type of the momentum accumulator.
nesterov: Whether to use Nesterov momentum.
Expand Down Expand Up @@ -90,15 +92,12 @@ def update_fn(updates, state, params=None):
)
else:
mu_hat = otu.tree_bias_correction(mu, mumentum, count_inc)
updates = jax.tree.map(
lambda x: (
x / jnp.linalg.norm(x, ord='fro')
if len(x.shape) > 1
else x / jnp.linalg.norm(x, ord=2)
),
mu_hat,
)
# Ensure that the spectral norm of the updates is at most 1.
updates = jax.tree.map(lambda x: x / (jnp.linalg.norm(x) + eps), mu_hat)
# Apply Newton-Schulz orthogonalization.
updates, _ = jax.lax.scan(muon_iterator, updates, muon_coeffs)
# Scale the orthogonalized updates by the dual norm of the original updates.
updates = jax.tree.map(lambda x, y: jnp.linalg.norm(x.T @ y) * y, mu_hat, updates)
mu = otu.tree_cast(mu, mu_dtype)
return updates, MuonState(count=count_inc, mu=mu)
return base.GradientTransformation(init_fn, update_fn)
Expand Down

0 comments on commit ff3279e

Please sign in to comment.