diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index f227816b1..a93c36ca8 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -33,6 +33,7 @@ def scale_by_muon( ] = (3.4445, -4.7750, 2.0315), newton_schulz_steps: Optional[int] = 5, mumentum: float = 0.95, + eps: float = 1e-8, mu_dtype: Optional[chex.ArrayDType] = None, *, nesterov: bool = True, @@ -56,6 +57,7 @@ def scale_by_muon( newton_schulz_steps: Number of Newton-schulz iterations. mumentum: Exponential decay rate to track the first moment of past gradients. + eps: Term added to the denominator to improve numerical stability. mu_dtype: Data type of the momentum accumulator. nesterov: Whether to use Nesterov momentum. @@ -90,15 +92,12 @@ def update_fn(updates, state, params=None): ) else: mu_hat = otu.tree_bias_correction(mu, mumentum, count_inc) - updates = jax.tree.map( - lambda x: ( - x / jnp.linalg.norm(x, ord='fro') - if len(x.shape) > 1 - else x / jnp.linalg.norm(x, ord=2) - ), - mu_hat, - ) + # Ensure that the spectral norm of the updates is at most 1. + updates = jax.tree.map(lambda x: x / (jnp.linalg.norm(x) + eps), mu_hat) + # Apply Newton-Schulz orthogonalization. updates, _ = jax.lax.scan(muon_iterator, updates, muon_coeffs) + # Scale the orthogonalized updates by the dual norm of the original updates. + updates = jax.tree.map(lambda x, y: jnp.linalg.norm(x.T @ y) * y, mu_hat, updates) mu = otu.tree_cast(mu, mu_dtype) return updates, MuonState(count=count_inc, mu=mu) return base.GradientTransformation(init_fn, update_fn)