diff --git a/src/levanter/optim/kron.py b/src/levanter/optim/kron.py index c7de998ff..bf5662d97 100644 --- a/src/levanter/optim/kron.py +++ b/src/levanter/optim/kron.py @@ -896,7 +896,7 @@ def pass_through_fn(key, qs, grads_in, bal_counter): do_update, update_preconditioner_fn, pass_through_fn, - key, + state["key"], Qs, momentum_updates, state["balance_counter"],