diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py
index 6b845b516..c92b61d6c 100644
--- a/src/levanter/main/lora_lm.py
+++ b/src/levanter/main/lora_lm.py
@@ -111,13 +111,19 @@ def compute_loss(model, example: LmExample, key=None):
         logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}")
 
         # data loaders
-        eval_dataset = CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos)  # type: ignore
+        eval_datasets = config.data.validation_sets(Pos.size)
         train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos)
 
         train_loader = trainer.sharded_loader(train_dataset, Batch)
 
         # boilerplate hooks and such
-        trainer.add_eval_hook(eval_dataset)
+        if len(eval_datasets) == 0:
+            logger.warning("No evaluation datasets provided.")
+
+        for name, eval_dataset in eval_datasets.items():
+            eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id)
+            trainer.add_eval_hook(eval_dataset, name=name)
+
         trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1)
         if config.peft_save_path is not None:
             full_save_path = os.path.join(config.peft_save_path, trainer.run_id)