support multiple evaluation sets in lora_lm
dlwh committed Feb 13, 2024
1 parent 25eb6d2 commit bbf7462
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/levanter/main/lora_lm.py
@@ -111,13 +111,19 @@ def compute_loss(model, example: LmExample, key=None):
    logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3}")

    # data loaders
-    eval_dataset = CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos)  # type: ignore
+    eval_datasets = config.data.validation_sets(Pos.size)

    train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos)
    train_loader = trainer.sharded_loader(train_dataset, Batch)

    # boilerplate hooks and such
-    trainer.add_eval_hook(eval_dataset)
+    if len(eval_datasets) == 0:
+        logger.warning("No evaluation datasets provided.")
+
+    for name, eval_dataset in eval_datasets.items():
+        eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id)
+        trainer.add_eval_hook(eval_dataset, name=name)

    trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1)
    if config.peft_save_path is not None:
        full_save_path = os.path.join(config.peft_save_path, trainer.run_id)
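In short, config.data.validation_sets(Pos.size) now returns a mapping from a set name to a validation dataset, and the new code registers one named eval hook per entry, warning when the mapping is empty. Below is a minimal, self-contained sketch of that pattern; the Trainer class and register_eval_hooks helper are hypothetical stand-ins for illustration, not Levanter's real classes.

import logging

logger = logging.getLogger(__name__)


class Trainer:
    """Hypothetical stand-in for a trainer that accepts named evaluation hooks."""

    def add_eval_hook(self, dataset, name=None):
        label = f" '{name}'" if name else ""
        print(f"registered eval hook{label} over {len(dataset)} examples")


def register_eval_hooks(trainer, eval_datasets):
    """Register one eval hook per named validation set; warn if there are none."""
    if len(eval_datasets) == 0:
        logger.warning("No evaluation datasets provided.")
    for name, eval_dataset in eval_datasets.items():
        trainer.add_eval_hook(eval_dataset, name=name)


# Usage: eval_datasets plays the role of config.data.validation_sets(Pos.size),
# i.e. a dict mapping a set name to its dataset.
register_eval_hooks(Trainer(), {"wikitext": [1, 2, 3], "c4": [4, 5]})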