diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py
index 982f72358..ab6fbd63c 100644
--- a/src/levanter/main/train_lm.py
+++ b/src/levanter/main/train_lm.py
@@ -95,7 +95,10 @@ def main(config: TrainLmConfig):
     parameter_axis_mapping = config.trainer.parameter_axis_mapping
 
     def compute_loss(model: LmHeadModel, example: LmExample, key=None):
-        return model.compute_loss(example, key=key).scalar()
+        if key is None:
+            return model.compute_loss(example, key=None).scalar()
+        x, y = model.compute_loss(example, key=key)
+        return x.scalar(), y.scalar()
 
     optimizer = config.optimizer.build(config.trainer.num_train_steps)
 
@@ -192,4 +195,4 @@ def compute_log_probs(model, example: LmExample):
 
 
 if __name__ == "__main__":
-    levanter.config.main(main)()
+    levanter.config.main(main)()
\ No newline at end of file