Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmeda14960 committed Feb 16, 2024
1 parent 09a1f9f commit 347cc1f
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/levanter/main/train_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,10 @@ def main(config: TrainLmConfig):
parameter_axis_mapping = config.trainer.parameter_axis_mapping

def compute_loss(model: LmHeadModel, example: LmExample, key=None):
return model.compute_loss(example, key=key).scalar()
if key is None:
return model.compute_loss(example, key=None).scalar()
x, y = model.compute_loss(example, key=key)
return x.scalar(), y.scalar()

optimizer = config.optimizer.build(config.trainer.num_train_steps)

Expand Down Expand Up @@ -192,4 +195,4 @@ def compute_log_probs(model, example: LmExample):


if __name__ == "__main__":
levanter.config.main(main)()
levanter.config.main(main)()

0 comments on commit 347cc1f

Please sign in to comment.