minimize use of optax internals
dlwh committed Jan 17, 2024
1 parent 1edeeef commit 6148381
Showing 3 changed files with 40 additions and 17 deletions.
src/levanter/optim/second_order.py (14 changes: 9 additions, 5 deletions)
@@ -7,8 +7,7 @@
 import jax
 import optax
 from jax import numpy as jnp
-from optax._src import numerics
-from optax._src.schedule import InjectHyperparamsState, _convert_floats
+from optax import InjectHyperparamsState


 class HessianUpdateFn(typing.Protocol):
@@ -189,10 +188,8 @@ def update_fn(updates, state, params=None):
             hparams = {k: _convert_floats(v, dtype) for k, v in state.hyperparams.items()}
             hparams.update(schedule_fn(state.count, dtype))
             updates, inner_state = inner_factory(**other_hps, **hparams).update(updates, state.inner_state, params)
-            count_inc = numerics.safe_int32_increment(state.count)
-
             # pylint:disable=too-many-function-args
-            return updates, InjectHyperparamsState(count_inc, hparams, inner_state)
+            return updates, InjectHyperparamsState(state.count + 1, hparams, inner_state)
             # pylint:enable=too-many-function-args

         def _find_first_floating_dtype(updates):
@@ -226,3 +223,10 @@ def update_hessian(state, fn, model, *batch, **batch_kwargs):
         return SecondOrderTransformation(init_fn, update_fn, update_hessian)

     return wrapped_transform
+
+
+def _convert_floats(x, dtype):
+    """Convert float-like inputs to dtype, rest pass through."""
+    if jax.dtypes.scalar_type_of(x) == float:
+        return jnp.asarray(x, dtype=dtype)
+    return x
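Aside: the removed numerics.safe_int32_increment is optax's overflow-guarded counter bump. It saturates at the int32 maximum instead of wrapping to a negative value, while the plain state.count + 1 used above drops that guard. If the guard is ever wanted without reaching into optax._src, a minimal standalone equivalent (my sketch of the helper's documented semantics, not part of this commit) is:

import jax.numpy as jnp

def safe_int32_increment(count):
    # Bump an int32 step counter, saturating at 2**31 - 1 instead of
    # wrapping around to a negative value on overflow.
    max_int32 = jnp.iinfo(jnp.int32).max
    return jnp.where(count < max_int32, count + 1, max_int32)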
src/levanter/optim/sophia.py (33 changes: 24 additions, 9 deletions)
@@ -1,4 +1,5 @@
 import abc
+import functools
 import typing
 from dataclasses import dataclass
 from typing import Any, NamedTuple, Optional, TypeVar, runtime_checkable
@@ -11,10 +12,6 @@
 from jax.random import PRNGKey
 from jaxtyping import PRNGKeyArray

-# TODO: remove dependency on _src internals
-from optax._src import numerics
-from optax._src.transform import bias_correction, update_moment
-
 import levanter.tracker
 from levanter.optim.config import HessianOptConfig, OptimizerConfig
 from levanter.optim.second_order import SecondOrderTransformation, chain_second_order, inject_hyperparams
@@ -294,8 +291,7 @@ def init_fn(params):
     def update_fn(updates, state, params=None):
         mu = update_moment(updates, state.mu, b1, 1)
         # nu = update_moment_per_elem_norm(updates, state.nu, b2, 2)
-        count_inc = numerics.safe_int32_increment(state.count)
-        mu_hat = bias_correction(mu, b1, count_inc)
+        mu_hat = bias_correction(mu, b1, state.count + 1)
         h_hat = state.h
         # track how often hessian is used
         mu_leaves = jax.tree_util.tree_leaves(mu_hat)
@@ -328,7 +324,7 @@ def update_fn(updates, state, params=None):
         mu = jax.tree_util.tree_map(lambda t: t.astype(mu_dtype), mu)

         return updates, ScaleBySophiaState(
-            count=count_inc, hessian_count=state.hessian_count, mu=mu, h=h_hat, hess_key=state.hess_key
+            count=state.count + 1, hessian_count=state.hessian_count, mu=mu, h=h_hat, hess_key=state.hess_key
         )

     def update_hessian(state, fn, model, *batch, **batch_kwargs):
@@ -338,10 +334,9 @@ def _do_update():
            # new_hess = jax.tree_util.tree_map(lambda h: jnp.clip(h, -1, 1), new_hess)

            # EMAs of hessian
-            hessian_count_inc = numerics.safe_int32_increment(state.hessian_count)
            nu = update_moment(new_hess, state.h, b2, 1)
            return ScaleBySophiaState(
-                count=state.count, hessian_count=hessian_count_inc, mu=state.mu, h=nu, hess_key=next_key
+                count=state.count, hessian_count=state.hessian_count + 1, mu=state.mu, h=nu, hess_key=next_key
            )

        def _dont_update():
@@ -410,3 +405,23 @@ def stochastic_hessian_diagonal(fn, model, *args, hess_key: PRNGKey, **kwargs):
     hessian = jax.tree_util.tree_map(lambda grad, gaussian: grad * gaussian, product, g)

     return hessian
+
+
+# Cribbed from optax._src.transform
+def update_moment(updates, moments, decay, order):
+    """Compute the exponential moving average of the `order`-th moment."""
+    return jax.tree_util.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
+
+
+@functools.partial(jax.jit, inline=True)
+def bias_correction(moment, decay, count):
+    """Performs bias correction. It becomes a no-op as count goes to infinity."""
+    # The conversion to the data type of the moment ensures that bfloat16 remains
+    # bfloat16 in the optimizer state. This conversion has to be done after
+    # `bias_correction_` is calculated as calculating `decay**count` in low
+    # precision can result in it being rounded to 1 and subsequently a
+    # "division by zero" error.
+    bias_correction_ = 1 - decay**count
+
+    # Perform division in the original precision.
+    return jax.tree_util.tree_map(lambda t: t / bias_correction_.astype(t.dtype), moment)
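As a quick sanity check of the cribbed helpers, here is one Adam-style first-moment step written against update_moment and bias_correction exactly as defined above (a standalone sketch; the values are illustrative only):

from jax import numpy as jnp

# One EMA step starting from a zero first moment, with decay b1 = 0.9.
grads = {"w": jnp.ones(3)}
mu = update_moment(grads, {"w": jnp.zeros(3)}, 0.9, 1)  # 0.1 * g + 0.9 * 0
mu_hat = bias_correction(mu, 0.9, jnp.asarray(1))       # 0.1 * g / (1 - 0.9**1)

# At step 1 the corrected estimate recovers the raw gradient.
assert jnp.allclose(mu_hat["w"], grads["w"])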
src/levanter/tracker/helpers.py (10 changes: 7 additions, 3 deletions)
@@ -4,7 +4,6 @@
 from typing import Optional

 from git import InvalidGitRepositoryError, NoSuchPathError, Repo
-from optax._src.wrappers import MultiStepsState

 import levanter.tracker
 from levanter.utils.jax_utils import jnp_to_python
@@ -14,8 +13,13 @@


 def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None):
-    if isinstance(opt_state, MultiStepsState):
-        opt_state = opt_state.inner_opt_state
+    try:
+        from optax._src.wrappers import MultiStepsState
+
+        if isinstance(opt_state, MultiStepsState):
+            opt_state = opt_state.inner_opt_state
+    except ImportError:
+        pass

     def wrap_key(key):
         if prefix:
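The helpers.py change turns a hard import of optax's private wrappers module into a guarded one: if optax._src.wrappers moves or disappears in a later optax release, hyperparameter logging simply skips the MultiSteps unwrapping instead of failing at import time. The same pattern, distilled into a standalone helper (the name _unwrap_multisteps is hypothetical, for illustration only):

def _unwrap_multisteps(opt_state):
    # Unwrap an optax MultiSteps optimizer state when the private module is
    # importable; otherwise hand the state back unchanged.
    try:
        from optax._src.wrappers import MultiStepsState  # private; may move
    except ImportError:
        return opt_state
    if isinstance(opt_state, MultiStepsState):
        return opt_state.inner_opt_state
    return opt_state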