Skip to content

Commit

Permalink
Update kron.py
Browse files Browse the repository at this point in the history
  • Loading branch information
evanatyourservice committed Dec 15, 2024
1 parent 9ef0869 commit ed50cce
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/levanter/optim/kron.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ def broadcast_qs(_, ps, q, s):
def update_fn(updates: base.Updates, state: dict, params: base.Params = None):
del params
count_inc = safe_int32_increment(state["count"])
key, subkey = jax.random.split(state["key"])

# unbox if haliax style partitioned
scanned_layers_ = scanned_layers
Expand Down Expand Up @@ -801,7 +802,7 @@ def add_dims_to_spec(_, qss, sds):
)

# maybe update preconditioner
def update_preconditioner_fn(key, Qs, grads_in, bal_counter):
def update_preconditioner_fn(rngkey, Qs, grads_in, bal_counter):
with jax.default_matmul_precision(precond_update_precision):
# balance preconditioners about every 100 updates
def balance_Qs(Qs_to_bal):
Expand All @@ -828,8 +829,7 @@ def _balance_Q(Q):
Qs = _safe_sharding_constraint(Qs, Qs_sharding)

# create random vectors
key, subkey = jax.random.split(key)
Vs = _tree_random_like(subkey, grads_in)
Vs = _tree_random_like(rngkey, grads_in)
# apply params sharding to random vectors
if have_params_sharding:
Vs = _safe_sharding_constraint(Vs, partitioned_sharding)
Expand Down Expand Up @@ -882,22 +882,22 @@ def _balance_Q(Q):
new_Qs = _safe_sharding_constraint(new_Qs, Qs_sharding)

new_Qs = otu.tree_cast(new_Qs, precond_dtype)
return key, new_Qs, balance_counter_inc
return new_Qs, balance_counter_inc

def pass_through_fn(key, qs, grads_in, bal_counter):
def pass_through_fn(rngkey, qs, grads_in, bal_counter):
if have_qs_sharding:
qs = _safe_sharding_constraint(qs, Qs_sharding)
return key, qs, bal_counter
return qs, bal_counter

# update preconditioner deterministically
update_counter_inc = safe_int32_increment(state["update_counter"])
do_update = update_counter_inc >= 1 / update_prob_in
update_counter_inc = jnp.where(do_update, 0, update_counter_inc)
key, Qs, balance_counter_inc = jax.lax.cond(
Qs, balance_counter_inc = jax.lax.cond(
do_update,
update_preconditioner_fn,
pass_through_fn,
state["key"],
subkey,
Qs,
momentum_updates,
state["balance_counter"],
Expand Down

0 comments on commit ed50cce

Please sign in to comment.