From 347cc1fac186d8144edcbfe9c026983010955138 Mon Sep 17 00:00:00 2001
From: Ahmed Ahmed
Date: Fri, 16 Feb 2024 10:32:40 -0800
Subject: [PATCH] fix

---
 src/levanter/main/train_lm.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py
index 982f72358..ab6fbd63c 100644
--- a/src/levanter/main/train_lm.py
+++ b/src/levanter/main/train_lm.py
@@ -95,7 +95,10 @@ def main(config: TrainLmConfig):
     parameter_axis_mapping = config.trainer.parameter_axis_mapping
 
     def compute_loss(model: LmHeadModel, example: LmExample, key=None):
-        return model.compute_loss(example, key=key).scalar()
+        if key is None:
+            return model.compute_loss(example, key=None).scalar()
+        x, y = model.compute_loss(example, key=key)
+        return x.scalar(), y.scalar()
 
     optimizer = config.optimizer.build(config.trainer.num_train_steps)
 
@@ -192,4 +195,4 @@ def compute_log_probs(model, example: LmExample):
 
 
 if __name__ == "__main__":
-    levanter.config.main(main)()
+    levanter.config.main(main)()
\ No newline at end of file