diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 8a118f549..9c8f2d2f2 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -358,7 +358,6 @@ def init_state(partial_state, model_init, training_key, is_trainable): return eqx.combine(partial_state, fresh_state) state = init_state(state, model_init, training_key, self.is_trainable_param) - print(state.step.sharding) return state @@ -385,8 +384,6 @@ def training_steps( with capture_time() as loading_time: example = next(iter_data) - levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) - info = self.train_step(state, example) state = info.state @@ -394,7 +391,9 @@ def training_steps( with capture_time() as hook_time: self.run_hooks(info) - levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) + levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=info.step) + + levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=info.step) yield info