add tests for 2d tensor & mixed dim tensor params; optimize dual norm calculation
leloykun committed Jan 2, 2025
1 parent 0b42fba commit 049c954
Showing 2 changed files with 54 additions and 11 deletions.
41 changes: 39 additions & 2 deletions optax/contrib/_common_test.py
@@ -46,7 +46,7 @@
{'opt_name': 'dowg', 'opt_kwargs': {'learning_rate': 1.0}},
{'opt_name': 'momo', 'opt_kwargs': {'learning_rate': 1e-1}},
{'opt_name': 'momo_adam', 'opt_kwargs': {'learning_rate': 1e-1}},
- {'opt_name': 'muon', 'opt_kwargs': {'learning_rate': 1e-3}},
+ {'opt_name': 'muon', 'opt_kwargs': {'learning_rate': 1e-2}},
{'opt_name': 'prodigy', 'opt_kwargs': {'learning_rate': 1e-1}},
{
'opt_name': 'schedule_free_sgd',
@@ -178,11 +178,48 @@ def obj_fn(params):
  return initial_params, final_params, obj_fn


+def _setup_matrix_parabola(dtype):
+  """Quadratic function as an optimization target with 2D tensor parameters."""
+  initial_params = jnp.zeros((2, 2), dtype=dtype)
+  final_params = jnp.array([[3.0, -2.0], [1.0, 4.0]], dtype=dtype)
+
+  def obj_fn(params):
+    return jnp.sum(numerics.abs_sq(params - final_params))
+
+  return initial_params, final_params, obj_fn
+
+
+def _setup_mixed_tensor_target(dtype):
+  """Optimization target combining 1D and 2D tensor parameters."""
+  initial_1d_params = jnp.zeros((3,), dtype=dtype)
+  final_1d_params = jnp.array([1.0, -1.0, 2.0], dtype=dtype)
+
+  initial_2d_params = jnp.zeros((2, 2), dtype=dtype)
+  final_2d_params = jnp.array([[1.0, 0.0], [-1.0, 1.0]], dtype=dtype)
+
+  def obj_fn(params):
+    """Objective function combining 1D and 2D parameters."""
+    params_1d, params_2d = params
+    loss_1d = jnp.sum(numerics.abs_sq(params_1d - final_1d_params))
+    loss_2d = jnp.sum(numerics.abs_sq(params_2d - final_2d_params))
+    return loss_1d + loss_2d
+
+  initial_params = (initial_1d_params, initial_2d_params)
+  final_params = (final_1d_params, final_2d_params)
+
+  return initial_params, final_params, obj_fn


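As an aside (not part of the commit), here is a minimal sketch of how a target like _setup_matrix_parabola is exercised: run an optimizer's init/update loop on the objective and check that the parameters approach final_params. The optimizer choice, step count, and tolerance below are arbitrary stand-ins for whatever test_optimizers actually does.

import jax
import jax.numpy as jnp
import optax

# Standalone version of the 2D "matrix parabola" target above.
final_params = jnp.array([[3.0, -2.0], [1.0, 4.0]])
obj_fn = lambda p: jnp.sum((p - final_params) ** 2)

params = jnp.zeros((2, 2))
opt = optax.adam(1e-1)  # stand-in; the real test sweeps every contrib optimizer
opt_state = opt.init(params)

for _ in range(1000):
  grads = jax.grad(obj_fn)(params)
  updates, opt_state = opt.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)

assert jnp.allclose(params, final_params, atol=1e-2)
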
class ContribTest(chex.TestCase):

  @parameterized.product(
      _ALL_OPTIMIZERS_UNDER_TEST,
-      target=(_setup_parabola, _setup_rosenbrock),
+      target=(
+          _setup_parabola,
+          _setup_rosenbrock,
+          _setup_matrix_parabola,
+          _setup_mixed_tensor_target,
+      ),
      dtype=('float32',),
  )
  def test_optimizers(
24 changes: 15 additions & 9 deletions optax/contrib/_muon.py
@@ -106,10 +106,6 @@ def scale_by_muon(
  p=infty, it is equivalent to Shampoo without accumulation, or steepest
  descent under the Spectral norm.
-  References:
-    Jordan, `modded-nanogpt: Speedrunning the NanoGPT baseline
-    https://github.com/KellerJordan/modded-nanogpt`_, 2024
  Args:
    ns_coeffs: Coefficients for the Newton-schulz method.
    ns_steps: Number of Newton-schulz iterations.
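Aside (not part of the commit): ns_coeffs and ns_steps parameterize a Newton-Schulz iteration that approximately orthogonalizes each 2D momentum matrix. Below is a minimal sketch in the spirit of the modded-nanogpt reference cited in this diff; the quintic coefficients, Frobenius-norm scaling, and transpose handling are assumptions about a typical implementation, not the exact optax code.

import jax.numpy as jnp

def newton_schulz_orthogonalize(g, coeffs=(3.4445, -4.7750, 2.0315), steps=5, eps=1e-7):
  """Approximately maps g to the nearest semi-orthogonal matrix (~U V^T of its SVD)."""
  a, b, c = coeffs
  x = g / (jnp.linalg.norm(g) + eps)  # Frobenius scaling keeps the iteration in its basin
  transpose = x.shape[0] > x.shape[1]
  if transpose:  # keep the Gram matrix x @ x.T as small as possible
    x = x.T
  for _ in range(steps):
    xxt = x @ x.T
    x = a * x + (b * xxt + c * xxt @ xxt) @ x
  return x.T if transpose else x

Each step applies an odd polynomial to the singular values of x while leaving the singular vectors fixed, pushing the singular values toward 1.
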
@@ -123,6 +119,13 @@ def scale_by_muon(
  Returns:
    A `GradientTransformation` object.
+  References:
+    Jordan, `modded-nanogpt: Speedrunning the NanoGPT baseline
+    https://github.com/KellerJordan/modded-nanogpt`_, 2024
+    Bernstein et al., `Old Optimizer, New Norm: An Anthology
+    https://arxiv.org/abs/2409.20325`_, 2024
  """
  mu_dtype = utils.canonicalize_dtype(mu_dtype)
  ns_coeffs_ = jnp.asarray(ns_coeffs)
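Side note (my summary of the cited derivation, not text from the commit): for a momentum matrix G with SVD G = U \Sigma V^\top, steepest descent under the spectral norm picks the orthogonalized direction U V^\top, and the step size is governed by the dual (nuclear) norm of G:

\[
  \max_{\|T\|_2 \le 1} \langle G, T \rangle
  = \langle G, U V^\top \rangle
  = \operatorname{tr}(\Sigma)
  = \|G\|_{*},
\]

which is why update_fn below rescales the Newton-Schulz output by its inner product with the momentum.
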
@@ -156,7 +159,7 @@ def update_fn(updates, state, params=None):
    # Scale the orthogonalized updates by the dual norm of the original
    # updates. See https://arxiv.org/abs/2409.20325 for the derivation.
    updates = jax.tree.map(
-        lambda x, y: jnp.linalg.trace(x.T @ y) * y, mu_hat, updates
+        lambda x, y: jnp.einsum('ij,ij,ab->ab', x, y, y), mu_hat, updates
    )
    mu = otu.tree_cast(mu, mu_dtype)
    return updates, MuonState(count=count_inc, mu=mu)
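A quick equivalence check (illustrative, not part of the commit): trace(x^T y) is just the Frobenius inner product sum_ij x_ij * y_ij, so the einsum form computes the same scalar factor without materializing the x.T @ y matrix product. jnp.trace is used here, which agrees with jnp.linalg.trace on 2D inputs.

import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)          # stand-in for a mu_hat leaf
y = jnp.arange(6.0).reshape(2, 3) - 2.5    # stand-in for the orthogonalized update

old = jnp.trace(x.T @ y) * y               # forms a 3x3 intermediate matrix
new = jnp.einsum('ij,ij,ab->ab', x, y, y)  # contracts x and y straight to a scalar
assert jnp.allclose(old, new)
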
@@ -193,10 +196,6 @@ def muon(
  This is because the Newton-Schulz iterator expects a matrix as input.
  The non-2D parameters are instead passed through an Adam optimizer.
-  References:
-    Jordan, `modded-nanogpt: Speedrunning the NanoGPT baseline
-    https://github.com/KellerJordan/modded-nanogpt`_, 2024
  Args:
    learning_rate: A global scaling factor, either fixed or evolving along
      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
@@ -216,6 +215,13 @@
  Returns:
    The corresponding `GradientTransformation`.
+  References:
+    Jordan, `modded-nanogpt: Speedrunning the NanoGPT baseline
+    https://github.com/KellerJordan/modded-nanogpt`_, 2024
+    Bernstein et al., `Old Optimizer, New Norm: An Anthology
+    https://arxiv.org/abs/2409.20325`_, 2024
  """
  return combine.multi_transform(
      transforms={
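The return value above is truncated in this view, but the routing it sets up is described in the docstring: 2D parameters go through the Muon update and everything else falls back to Adam. Below is a rough sketch of that multi_transform pattern under stated assumptions; the label function, learning rates, and the use of optax.sgd as a stand-in for the Muon chain are illustrative, not the actual optax internals.

import jax
import jax.numpy as jnp
import optax

def _label_params(params):
  # Route 2D (matrix) leaves to the 'muon' branch, everything else to 'adam'.
  return jax.tree.map(lambda p: 'muon' if p.ndim == 2 else 'adam', params)

tx = optax.multi_transform(
    transforms={
        'muon': optax.sgd(1e-2),   # stand-in for the Muon update chain
        'adam': optax.adam(3e-4),  # non-2D parameters use Adam
    },
    param_labels=_label_params,
)

params = {'w': jnp.ones((4, 4)), 'b': jnp.zeros((4,))}
state = tx.init(params)
grads = jax.tree.map(jnp.ones_like, params)
updates, state = tx.update(grads, state, params)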
