diff --git a/src/levanter/optim/kron.py b/src/levanter/optim/kron.py index 6cbf583bf..2de182052 100644 --- a/src/levanter/optim/kron.py +++ b/src/levanter/optim/kron.py @@ -527,6 +527,7 @@ def broadcast_qs(_, ps, q, s): def update_fn(updates: base.Updates, state: dict, params: base.Params = None): del params count_inc = safe_int32_increment(state["count"]) + key, subkey = jax.random.split(state["key"]) # unbox if haliax style partitioned scanned_layers_ = scanned_layers @@ -801,7 +802,7 @@ def add_dims_to_spec(_, qss, sds): ) # maybe update preconditioner - def update_preconditioner_fn(key, Qs, grads_in, bal_counter): + def update_preconditioner_fn(rngkey, Qs, grads_in, bal_counter): with jax.default_matmul_precision(precond_update_precision): # balance preconditioners about every 100 updates def balance_Qs(Qs_to_bal): @@ -828,8 +829,7 @@ def _balance_Q(Q): Qs = _safe_sharding_constraint(Qs, Qs_sharding) # create random vectors - key, subkey = jax.random.split(key) - Vs = _tree_random_like(subkey, grads_in) + Vs = _tree_random_like(rngkey, grads_in) # apply params sharding to random vectors if have_params_sharding: Vs = _safe_sharding_constraint(Vs, partitioned_sharding) @@ -882,22 +882,22 @@ def _balance_Q(Q): new_Qs = _safe_sharding_constraint(new_Qs, Qs_sharding) new_Qs = otu.tree_cast(new_Qs, precond_dtype) - return key, new_Qs, balance_counter_inc + return new_Qs, balance_counter_inc - def pass_through_fn(key, qs, grads_in, bal_counter): + def pass_through_fn(rngkey, qs, grads_in, bal_counter): if have_qs_sharding: qs = _safe_sharding_constraint(qs, Qs_sharding) - return key, qs, bal_counter + return qs, bal_counter # update preconditioner deterministically update_counter_inc = safe_int32_increment(state["update_counter"]) do_update = update_counter_inc >= 1 / update_prob_in update_counter_inc = jnp.where(do_update, 0, update_counter_inc) - key, Qs, balance_counter_inc = jax.lax.cond( + Qs, balance_counter_inc = jax.lax.cond( do_update, update_preconditioner_fn, pass_through_fn, - state["key"], + subkey, Qs, momentum_updates, state["balance_counter"],