Skip to content

Commit

Permalink
add support for adaptive muon; align interface with other optimizers;…
Browse files Browse the repository at this point in the history
… optimize muon iterator
  • Loading branch information
leloykun committed Dec 31, 2024
1 parent 86c6ba9 commit 32be237
Showing 1 changed file with 41 additions and 26 deletions.
67 changes: 41 additions & 26 deletions optax/contrib/_muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@ class MuonState(NamedTuple):

def scale_by_muon(
newton_schulz_coeffs: Union[
Tuple[float, float, float],
List[Tuple[float, float, float]],
Tuple[float, float, float],
Tuple[Tuple[float, float, float], ...],
] = (3.4445, -4.7750, 2.0315),
newton_schulz_steps: Optional[int] = 5,
mumentum: float = 0.95,
newton_schulz_steps: int = 5,
beta: float = 0.95,
eps: float = 1e-8,
mu_dtype: Optional[chex.ArrayDType] = None,
*,
nesterov: bool = True,
adaptive: bool = False,
) -> base.GradientTransformation:
r"""Rescale updates according to the Muon algorithm.
Expand All @@ -55,49 +56,59 @@ def scale_by_muon(
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
newton_schulz_coeffs: Coefficients for the Newton-schulz method.
newton_schulz_steps: Number of Newton-schulz iterations.
mumentum: Exponential decay rate to track the first moment of past
gradients.
Ignored if `newton_schulz_coeffs` is a tuple of tuples.
beta: Decay rate for the exponentially weighted average of grads.
eps: Term added to the denominator to improve numerical stability.
mu_dtype: Data type of the momentum accumulator.
nesterov: Whether to use Nesterov momentum.
adaptive: Whether to scale the updates by the dual norm of the
original updates. See https://arxiv.org/abs/2409.20325
Returns:
A `GradientTransformation` object.
"""
mu_dtype = utils.canonicalize_dtype(mu_dtype)
muon_coeffs = jnp.asarray(
newton_schulz_coeffs
if isinstance(newton_schulz_coeffs, list)
else [newton_schulz_coeffs] * newton_schulz_steps
)
muon_iterator = (
lambda x, abc: (abc[0]*x + abc[1]*(x@x.T)@x + abc[2]*(x@x.T)@(x@x.T)@x, 0)
if isinstance(newton_schulz_coeffs[0], (tuple, list))
else [newton_schulz_coeffs] * newton_schulz_steps,
dtype=mu_dtype,
)
mu_dtype = utils.canonicalize_dtype(mu_dtype)
if muon_coeffs.ndim > 2 or muon_coeffs.shape[-1] != 3:
raise ValueError(
f"newton_schulz_coeffs must have shape (3,) or (n, 3), got {muon_coeffs.shape}"
)

def muon_iterator(x: jnp.ndarray, coeffs: jnp.ndarray):
A = x @ x.T
B = coeffs[1] * A + coeffs[2] * A @ A
return coeffs[0] * x + B @ x, None

def init_fn(params):
mu = otu.tree_zeros_like(params, dtype=mu_dtype) # First moment
return MuonState(count=jnp.zeros([], jnp.int32), mu=mu)

def update_fn(updates, state, params=None):
del params
mu = otu.tree_update_moment(updates, state.mu, mumentum, 1)
mu = otu.tree_update_moment(updates, state.mu, beta, 1)
count_inc = numerics.safe_int32_increment(state.count)
if nesterov:
mu_hat = jax.tree.map(
lambda m, g: mumentum * m + (1 - mumentum) * g,
lambda m, g: beta * m + (1 - beta) * g,
otu.tree_bias_correction(
mu, mumentum, numerics.safe_int32_increment(count_inc)
mu, beta, numerics.safe_int32_increment(count_inc)
),
otu.tree_bias_correction(updates, mumentum, count_inc),
otu.tree_bias_correction(updates, beta, count_inc),
)
else:
mu_hat = otu.tree_bias_correction(mu, mumentum, count_inc)
mu_hat = otu.tree_bias_correction(mu, beta, count_inc)
# Ensure that the spectral norm of the updates is at most 1.
updates = jax.tree.map(lambda x: x / (jnp.linalg.norm(x) + eps), mu_hat)
# Apply Newton-schulz orthogonalization.
updates, _ = jax.lax.scan(muon_iterator, updates, muon_coeffs)
# Scale the orthogonalized updates by the dual norm of the original updates.
updates = jax.tree.map(lambda x, y: jnp.linalg.norm(x.T @ y) * y, mu_hat, updates)
if adaptive:
# Scale the orthogonalized updates by the dual norm of the original updates.
updates = jax.tree.map(lambda x, y: jnp.linalg.norm(x.T @ y) * y, mu_hat, updates)
mu = otu.tree_cast(mu, mu_dtype)
return updates, MuonState(count=count_inc, mu=mu)
return base.GradientTransformation(init_fn, update_fn)
Expand All @@ -106,15 +117,16 @@ def update_fn(updates, state, params=None):
def muon(
learning_rate: base.ScalarOrSchedule,
newton_schulz_coeffs: Union[
Tuple[float, float, float],
List[Tuple[float, float, float]],
Tuple[float, float, float],
Tuple[Tuple[float, float, float], ...],
] = (3.4445, -4.7750, 2.0315),
newton_schulz_steps: Optional[int] = 5,
mumentum: float = 0.95,
newton_schulz_steps: int = 5,
beta: float = 0.95,
eps: float = 1e-8,
mu_dtype: Optional[Any] = None,
*,
nesterov: bool = True,
adaptive: bool = False,
) -> base.GradientTransformation:
r"""Muon: Momentum Orthogonalized by Newton-schulz
Expand All @@ -133,11 +145,13 @@ def muon(
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
newton_schulz_coeffs: Coefficients for the Newton-schulz method.
newton_schulz_steps: Number of Newton-schulz iterations.
mumentum: Exponential decay rate to track the first moment of past
gradients.
Ignored if `newton_schulz_coeffs` is a tuple of tuples.
beta: Decay rate for the exponentially weighted average of grads.
eps: Term added to the denominator to improve numerical stability.
mu_dtype: Data type of the momentum accumulator.
nesterov: Whether to use Nesterov momentum.
adaptive: Whether to scale the updates by the dual norm of the
original updates. See https://arxiv.org/abs/2409.20325
Returns:
The corresponding `GradientTransformation`.
Expand All @@ -146,10 +160,11 @@ def muon(
scale_by_muon(
newton_schulz_coeffs,
newton_schulz_steps,
mumentum,
beta,
eps,
mu_dtype,
nesterov=nesterov,
adaptive=adaptive,
),
transform.scale_by_learning_rate(learning_rate),
)

0 comments on commit 32be237

Please sign in to comment.