
Commit

add support for the muon optimizer
leloykun committed Nov 2, 2024
1 parent d4592a6 commit 338357f
Showing 2 changed files with 116 additions and 2 deletions.
46 changes: 45 additions & 1 deletion optax/_src/alias.py
@@ -16,7 +16,7 @@

from collections.abc import Callable
import functools
from typing import Any, Optional, Union
from typing import Any, List, Optional, Tuple, Union

import jax.numpy as jnp
from optax._src import base
@@ -2470,3 +2470,47 @@ def lbfgs(
      base_scaling,
      linesearch,
  )


def muon(
    learning_rate: base.ScalarOrSchedule,
    newton_schulz_coeffs: Union[
        Tuple[float, float, float], List[Tuple[float, float, float]]
    ] = (3.4445, -4.7750, 2.0315),
    newton_schulz_steps: Optional[int] = 5,
    momentum: float = 0.95,
    mu_dtype: Optional[Any] = None,
    *,
    nesterov: bool = True,
) -> base.GradientTransformation:
  r"""Muon: Momentum Orthogonalized by Newton-Schulz.

  Muon is a variant of Shampoo that uses a Newton-Schulz iteration to
  orthogonalize the momentum accumulated by the optimizer. Mathematically, it
  performs steepest descent under the Schatten-p norm for some large p. With
  p=infty, it is equivalent to Shampoo without accumulation, i.e. steepest
  descent under the spectral norm.

  References:
    Jordan, `modded-nanogpt
    <https://github.com/KellerJordan/modded-nanogpt>`_, 2024

  Args:
    learning_rate: A global scaling factor, either fixed or evolving along
      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
    newton_schulz_coeffs: Coefficients for the Newton-Schulz iteration, either
      a single triple reused at every step or one triple per step.
    newton_schulz_steps: Number of Newton-Schulz iterations (ignored when a
      list of coefficient triples is given).
    momentum: Exponential decay rate used to track the first moment of past
      gradients.
    mu_dtype: Data type of the momentum accumulator.
    nesterov: Whether to use Nesterov momentum.

  Returns:
    The corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.scale_by_muon(
          newton_schulz_coeffs,
          newton_schulz_steps,
          momentum,
          mu_dtype,
          nesterov=nesterov,
      ),
      transform.scale_by_learning_rate(learning_rate),
  )
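
For reference, a minimal usage sketch of the new optimizer (not part of the commit). It assumes the function is re-exported at the top level as `optax.muon`, which this diff does not show; the toy `params` and `loss_fn` are purely illustrative, and Muon is applied here to a single 2-D weight matrix, the shape its Newton-Schulz orthogonalization expects.

import jax
import jax.numpy as jnp
import optax

# Hypothetical toy setup: a single 2-D weight matrix.
params = {'w': jnp.ones((64, 32))}

def loss_fn(p):
  return jnp.sum(p['w'] ** 2)

# Assumes `muon` is exported as `optax.muon`.
optimizer = optax.muon(learning_rate=0.02, momentum=0.95)
opt_state = optimizer.init(params)

grads = jax.grad(loss_fn)(params)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)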
72 changes: 71 additions & 1 deletion optax/_src/transform.py
@@ -15,7 +15,7 @@
"""Gradient transformations."""

import functools
from typing import NamedTuple, Optional, Union
from typing import List, NamedTuple, Optional, Tuple, Union

import chex
import jax
@@ -1758,6 +1758,76 @@ def update_fn(
  return base.GradientTransformation(base.init_empty_state, update_fn)


class ScaleByMuonState(NamedTuple):
  """State for the Muon algorithm."""
  count: chex.Array  # shape=(), dtype=jnp.int32.
  mu: base.Updates


def scale_by_muon(
    newton_schulz_coeffs: Union[
        Tuple[float, float, float], List[Tuple[float, float, float]]
    ] = (3.4445, -4.7750, 2.0315),
    newton_schulz_steps: Optional[int] = 5,
    momentum: float = 0.95,
    mu_dtype: Optional[chex.ArrayDType] = None,
    *,
    nesterov: bool = True,
) -> base.GradientTransformation:
r"""Rescale updates according to the Muon algorithm.
Muon is a variant of Shampoo that uses the Newton-schulz method to orthogonalize
the momentum accumulated by the optimizer. Mathematically, it does steepest descent
under the Schatten-p norm, for some large p. With p=infty, it is equivalent to
Shampoo without accumulation, or steepest descent under the Spectral norm.
References:
Jordan, `Overview of mini-batch gradient descent
https://github.com/KellerJordan/modded-nanogpt`_, 2024
Args:
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
newton_schulz_coeffs: Coefficients for the Newton-schulz method.
newton_schulz_steps: Number of Newton-schulz iterations.
mumentum: Exponential decay rate to track the first moment of past gradients.
mu_dtype: Data type of the momentum accumulator.
nesterov: Whether to use Nesterov momentum.
Returns:
A `GradientTransformation` object.
"""
  muon_coeffs = jnp.asarray(
      newton_schulz_coeffs
      if isinstance(newton_schulz_coeffs, list)
      else [newton_schulz_coeffs] * newton_schulz_steps
  )
  # One quintic Newton-Schulz step: x <- a*x + b*(x x^T) x + c*(x x^T)^2 x.
  muon_iterator = lambda x, abc: (
      abc[0] * x + abc[1] * (x @ x.T) @ x + abc[2] * (x @ x.T) @ (x @ x.T) @ x,
      0,
  )
  mu_dtype = utils.canonicalize_dtype(mu_dtype)

  def init_fn(params):
    mu = otu.tree_zeros_like(params, dtype=mu_dtype)  # First moment
    return ScaleByMuonState(count=jnp.zeros([], jnp.int32), mu=mu)

  def update_fn(updates, state, params=None):
    del params
    mu = otu.tree_update_moment(updates, state.mu, momentum, 1)
    count_inc = numerics.safe_int32_increment(state.count)
    if nesterov:
      mu_hat = jax.tree.map(
          lambda m, g: momentum * m + (1 - momentum) * g,
          otu.tree_bias_correction(
              mu, momentum, numerics.safe_int32_increment(count_inc)
          ),
          otu.tree_bias_correction(updates, momentum, count_inc),
      )
    else:
      mu_hat = otu.tree_bias_correction(mu, momentum, count_inc)

    def orthogonalize(x):
      # Normalize the momentum so its singular values lie in the region where
      # the Newton-Schulz iteration converges, then run the fixed number of
      # quintic iterations on each matrix-shaped leaf.
      x = x / jnp.linalg.norm(x, ord='fro')
      x, _ = jax.lax.scan(muon_iterator, x, muon_coeffs)
      return x

    updates = jax.tree.map(orthogonalize, mu_hat)
    mu = otu.tree_cast(mu, mu_dtype)
    return updates, ScaleByMuonState(count=count_inc, mu=mu)

  return base.GradientTransformation(init_fn, update_fn)
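
As a sanity check on the quintic iteration implemented by `muon_iterator`, the following standalone sketch (not part of the commit; names such as `newton_schulz_step` and `g` are illustrative) runs five Newton-Schulz steps with the default coefficients on a random matrix and prints the resulting singular values, which are pushed towards 1, i.e. the result is approximately orthogonal:

import jax
import jax.numpy as jnp

coeffs = jnp.asarray([(3.4445, -4.7750, 2.0315)] * 5)

def newton_schulz_step(x, abc):
  # x <- a*x + b*(x x^T) x + c*(x x^T)^2 x; acts independently on each
  # singular value of x while preserving its singular vectors.
  a, b, c = abc
  xxt = x @ x.T
  return a * x + b * xxt @ x + c * xxt @ xxt @ x, 0

g = jax.random.normal(jax.random.PRNGKey(0), (16, 8))
x = g / jnp.linalg.norm(g, ord='fro')  # normalize, as in update_fn
x, _ = jax.lax.scan(newton_schulz_step, x, coeffs)

# Singular values of the orthogonalized update cluster near 1 (the tuned
# coefficients trade exactness for speed, so they are not exactly 1).
print(jnp.linalg.svd(x, compute_uv=False))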


### Legacy symbols to be removed. ###

