diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py index eb84fd22..c753f5d9 100644 --- a/optax/contrib/_muon.py +++ b/optax/contrib/_muon.py @@ -28,15 +28,16 @@ class MuonState(NamedTuple): def scale_by_muon( newton_schulz_coeffs: Union[ - Tuple[float, float, float], - List[Tuple[float, float, float]], + Tuple[float, float, float], + Tuple[Tuple[float, float, float], ...], ] = (3.4445, -4.7750, 2.0315), - newton_schulz_steps: Optional[int] = 5, - mumentum: float = 0.95, + newton_schulz_steps: int = 5, + beta: float = 0.95, eps: float = 1e-8, mu_dtype: Optional[chex.ArrayDType] = None, *, nesterov: bool = True, + adaptive: bool = False, ) -> base.GradientTransformation: r"""Rescale updates according to the Muon algorithm. @@ -55,24 +56,33 @@ def scale_by_muon( iterations with a scheduler, see :func:`optax.scale_by_learning_rate`. newton_schulz_coeffs: Coefficients for the Newton-schulz method. newton_schulz_steps: Number of Newton-schulz iterations. - mumentum: Exponential decay rate to track the first moment of past - gradients. + Ignored if `newton_schulz_coeffs` is a tuple of tuples. + beta: Decay rate for the exponentially weighted average of grads. eps: Term added to the denominator to improve numerical stability. mu_dtype: Data type of the momentum accumulator. nesterov: Whether to use Nesterov momentum. + adaptive: Whether to scale the updates by the dual norm of the + original updates. See https://arxiv.org/abs/2409.20325 Returns: A `GradientTransformation` object. """ + mu_dtype = utils.canonicalize_dtype(mu_dtype) muon_coeffs = jnp.asarray( newton_schulz_coeffs - if isinstance(newton_schulz_coeffs, list) - else [newton_schulz_coeffs] * newton_schulz_steps - ) - muon_iterator = ( - lambda x, abc: (abc[0]*x + abc[1]*(x@x.T)@x + abc[2]*(x@x.T)@(x@x.T)@x, 0) + if isinstance(newton_schulz_coeffs[0], (tuple, list)) + else [newton_schulz_coeffs] * newton_schulz_steps, + dtype=mu_dtype, ) - mu_dtype = utils.canonicalize_dtype(mu_dtype) + if muon_coeffs.ndim > 2 or muon_coeffs.shape[-1] != 3: + raise ValueError( + f"newton_schulz_coeffs must have shape (3,) or (n, 3), got {muon_coeffs.shape}" + ) + + def muon_iterator(x: jnp.ndarray, coeffs: jnp.ndarray): + A = x @ x.T + B = coeffs[1] * A + coeffs[2] * A @ A + return coeffs[0] * x + B @ x, None def init_fn(params): mu = otu.tree_zeros_like(params, dtype=mu_dtype) # First moment @@ -80,24 +90,25 @@ def init_fn(params): def update_fn(updates, state, params=None): del params - mu = otu.tree_update_moment(updates, state.mu, mumentum, 1) + mu = otu.tree_update_moment(updates, state.mu, beta, 1) count_inc = numerics.safe_int32_increment(state.count) if nesterov: mu_hat = jax.tree.map( - lambda m, g: mumentum * m + (1 - mumentum) * g, + lambda m, g: beta * m + (1 - beta) * g, otu.tree_bias_correction( - mu, mumentum, numerics.safe_int32_increment(count_inc) + mu, beta, numerics.safe_int32_increment(count_inc) ), - otu.tree_bias_correction(updates, mumentum, count_inc), + otu.tree_bias_correction(updates, beta, count_inc), ) else: - mu_hat = otu.tree_bias_correction(mu, mumentum, count_inc) + mu_hat = otu.tree_bias_correction(mu, beta, count_inc) # Ensure that the spectral norm of the updates is at most 1. updates = jax.tree.map(lambda x: x / (jnp.linalg.norm(x) + eps), mu_hat) # Apply Newton-schulz orthogonalization. updates, _ = jax.lax.scan(muon_iterator, updates, muon_coeffs) - # Scale the orthogonalized updates by the dual norm of the original updates. - updates = jax.tree.map(lambda x, y: jnp.linalg.norm(x.T @ y) * y, mu_hat, updates) + if adaptive: + # Scale the orthogonalized updates by the dual norm of the original updates. + updates = jax.tree.map(lambda x, y: jnp.linalg.norm(x.T @ y) * y, mu_hat, updates) mu = otu.tree_cast(mu, mu_dtype) return updates, MuonState(count=count_inc, mu=mu) return base.GradientTransformation(init_fn, update_fn) @@ -106,15 +117,16 @@ def update_fn(updates, state, params=None): def muon( learning_rate: base.ScalarOrSchedule, newton_schulz_coeffs: Union[ - Tuple[float, float, float], - List[Tuple[float, float, float]], + Tuple[float, float, float], + Tuple[Tuple[float, float, float], ...], ] = (3.4445, -4.7750, 2.0315), - newton_schulz_steps: Optional[int] = 5, - mumentum: float = 0.95, + newton_schulz_steps: int = 5, + beta: float = 0.95, eps: float = 1e-8, mu_dtype: Optional[Any] = None, *, nesterov: bool = True, + adaptive: bool = False, ) -> base.GradientTransformation: r"""Muon: Momentum Orthogonalized by Newton-schulz @@ -133,11 +145,13 @@ def muon( iterations with a scheduler, see :func:`optax.scale_by_learning_rate`. newton_schulz_coeffs: Coefficients for the Newton-schulz method. newton_schulz_steps: Number of Newton-schulz iterations. - mumentum: Exponential decay rate to track the first moment of past - gradients. + Ignored if `newton_schulz_coeffs` is a tuple of tuples. + beta: Decay rate for the exponentially weighted average of grads. eps: Term added to the denominator to improve numerical stability. mu_dtype: Data type of the momentum accumulator. nesterov: Whether to use Nesterov momentum. + adaptive: Whether to scale the updates by the dual norm of the + original updates. See https://arxiv.org/abs/2409.20325 Returns: The corresponding `GradientTransformation`. @@ -146,10 +160,11 @@ def muon( scale_by_muon( newton_schulz_coeffs, newton_schulz_steps, - mumentum, + beta, eps, mu_dtype, nesterov=nesterov, + adaptive=adaptive, ), transform.scale_by_learning_rate(learning_rate), )