Skip to content

Commit

Permalink
bugfix arg scheduling-related issue
Browse files Browse the repository at this point in the history
  • Loading branch information
leloykun committed Jan 2, 2025
1 parent 2541e45 commit 0b42fba
Showing 1 changed file with 18 additions and 26 deletions.
44 changes: 18 additions & 26 deletions optax/contrib/_muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

def orthogonalize_via_newton_schulz(
x: jnp.ndarray,
coeffs: jnp.ndarray,
ns_dtype: Optional[chex.ArrayDType] = None,
ns_coeffs: jnp.ndarray,
ns_steps: int = 5,
eps: float = 1e-8,
) -> jnp.ndarray:
r"""Orthogonalize via Newton-Schulz iteration.
Expand All @@ -40,23 +40,22 @@ def orthogonalize_via_newton_schulz(
Args:
x: A matrix to orthogonalize.
coeffs: Coefficients for the Newton-schulz iterators.
ns_coeffs: Coefficients for the Newton-Schulz iterators.
Must have shape (3,) or (n, 3), where n is the number of iterations.
ns_dtype: Data type to do the Newton-Schulz iteration in.
If None, the data type of x is used.
ns_steps: Number of Newton-Schulz iterations.
Ignored if `ns_coeffs` is a 2D array.
eps: Term added to denominators to improve numerical stability.
Returns:
The orthogonalized matrix.
"""
if x.ndim != 2:
raise ValueError(f'Input must have shape (m, n), got {x.shape}')
if coeffs.ndim != 2 or coeffs.shape[-1] != 3:
if ns_coeffs.ndim > 2 or ns_coeffs.shape[-1] != 3:
raise ValueError(
f'Newton-Schulz coefficients must have shape (n, 3), got {coeffs.shape}'
'Newton-Schulz coefficients must have shape (3,) or (n, 3), '
f'got {ns_coeffs.shape}'
)
if ns_dtype is None:
ns_dtype = x.dtype
def newton_schulz_iterator(x: jnp.ndarray, coeffs: jnp.ndarray):
a = x @ x.T
b = coeffs[1] * a + coeffs[2] * a @ a
Expand All @@ -66,11 +65,15 @@ def newton_schulz_iterator(x: jnp.ndarray, coeffs: jnp.ndarray):
x = x.T
transposed = True
x /= jnp.linalg.norm(x) + eps # Ensure spectral norm is at most 1
x, _ = jax.lax.scan(
lambda x, abc: (newton_schulz_iterator(x, abc), None),
x.astype(ns_dtype),
coeffs.astype(ns_dtype),
)
ns_coeffs = ns_coeffs.astype(x.dtype)
if ns_coeffs.ndim == 1:
x = jax.lax.fori_loop(
0, ns_steps, lambda _, x: newton_schulz_iterator(x, ns_coeffs), x
)
else:
x, _ = jax.lax.scan(
lambda x, abc: (newton_schulz_iterator(x, abc), None), x, ns_coeffs
)
if transposed:
x = x.T
return x
Expand All @@ -88,7 +91,6 @@ def scale_by_muon(
Tuple[Tuple[float, float, float], ...],
] = (3.4445, -4.7750, 2.0315),
ns_steps: int = 5,
ns_dtype: Optional[chex.ArrayDType] = None,
beta: float = 0.95,
eps: float = 1e-8,
mu_dtype: Optional[chex.ArrayDType] = None,
Expand All @@ -109,13 +111,9 @@ def scale_by_muon(
https://github.com/KellerJordan/modded-nanogpt`_, 2024
Args:
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
ns_coeffs: Coefficients for the Newton-schulz method.
ns_steps: Number of Newton-schulz iterations.
Ignored if `ns_coeffs` is a tuple of tuples.
ns_dtype: Data type to do the Newton-Schulz iteration in.
If None, the data type of x is used.
beta: Decay rate for the exponentially weighted average of grads.
eps: Term added to denominators to improve numerical stability.
mu_dtype: Data type of the momentum accumulator.
Expand All @@ -132,8 +130,6 @@ def scale_by_muon(
raise ValueError(
f'ns_coeffs must have shape (3,) or (n, 3), got {ns_coeffs_.shape}'
)
if ns_coeffs_.ndim == 1:
ns_coeffs_ = jnp.tile(ns_coeffs_, (ns_steps, 1))

def init_fn(params):
mu = otu.tree_zeros_like(params, dtype=mu_dtype) # First moment
Expand All @@ -153,7 +149,7 @@ def update_fn(updates, state, params=None):
mu_hat = otu.tree_bias_correction(mu, beta, count_inc)
# Apply Newton-schulz orthogonalization.
updates = jax.tree.map(
lambda x: orthogonalize_via_newton_schulz(x, ns_coeffs_, ns_dtype, eps),
lambda x: orthogonalize_via_newton_schulz(x, ns_coeffs_, ns_steps, eps),
updates,
)
if adaptive:
Expand All @@ -174,7 +170,6 @@ def muon(
Tuple[Tuple[float, float, float], ...],
] = (3.4445, -4.7750, 2.0315),
ns_steps: int = 5,
ns_dtype: Optional[chex.ArrayDType] = None,
beta: float = 0.95,
eps: float = 1e-8,
mu_dtype: Optional[Any] = None,
Expand Down Expand Up @@ -208,8 +203,6 @@ def muon(
ns_coeffs: Coefficients for the Newton-schulz method.
ns_steps: Number of Newton-schulz iterations.
Ignored if `ns_coeffs` is a tuple of tuples.
ns_dtype: Data type to do the Newton-Schulz iteration in.
If None, the data type of x is used.
beta: Decay rate for the exponentially weighted average of grads.
eps: Term added to the denominator to improve numerical stability.
mu_dtype: Data type of the momentum accumulator.
Expand All @@ -230,7 +223,6 @@ def muon(
scale_by_muon(
ns_coeffs=ns_coeffs,
ns_steps=ns_steps,
ns_dtype=ns_dtype,
beta=beta,
eps=eps,
mu_dtype=mu_dtype,
Expand Down

0 comments on commit 0b42fba

Please sign in to comment.