Skip to content

Commit

Permalink
sigh
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Feb 13, 2024
1 parent c241da4 commit 2cddfd5
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions src/levanter/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,6 @@ def init_state(partial_state, model_init, training_key, is_trainable):
return eqx.combine(partial_state, fresh_state)

state = init_state(state, model_init, training_key, self.is_trainable_param)
print(state.step.sharding)

return state

Expand All @@ -385,16 +384,16 @@ def training_steps(
with capture_time() as loading_time:
example = next(iter_data)

levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=state.step)

info = self.train_step(state, example)
state = info.state

if run_hooks:
with capture_time() as hook_time:
self.run_hooks(info)

levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step)
levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=info.step)

levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=info.step)

yield info

Expand Down

0 comments on commit 2cddfd5

Please sign in to comment.