From 5d5c30ff8f428dab20169b3d47689d910de2115b Mon Sep 17 00:00:00 2001
From: David Hall
Date: Fri, 9 Feb 2024 14:46:40 -0800
Subject: [PATCH] missed a few spots

---
 examples/alpaca/alpaca.py         |  2 +-
 examples/gsm8k-lora/gsm8k_lora.py | 91 ++++++++++++++++---------------
 src/levanter/__init__.py          |  1 +
 src/levanter/main/lora_lm.py      | 13 +++--
 4 files changed, 58 insertions(+), 49 deletions(-)

diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py
index 36a6dd943..a20f357fe 100644
--- a/examples/alpaca/alpaca.py
+++ b/examples/alpaca/alpaca.py
@@ -194,7 +194,7 @@ def get_prompts(prompt_path) -> dict:
 
 
 def train(config: TrainArgs):
-    config.trainer.initialize(config)
+    levanter.initialize(config)
 
     # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint.
     # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary.
diff --git a/examples/gsm8k-lora/gsm8k_lora.py b/examples/gsm8k-lora/gsm8k_lora.py
index 5e4927d2f..febfd2013 100644
--- a/examples/gsm8k-lora/gsm8k_lora.py
+++ b/examples/gsm8k-lora/gsm8k_lora.py
@@ -9,7 +9,6 @@
 import jax.random as jrandom
 import numpy as np
 import transformers
-import wandb
 
 import haliax as hax
 
@@ -127,7 +126,7 @@ def format_output(ex):
 
 
 def train(config: TrainArgs):
-    config.trainer.initialize(config)
+    levanter.initialize(config)
 
     # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint.
     # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary.
@@ -169,53 +168,57 @@ def loraize_hf_model(model):
     def compute_loss(model: LmHeadModel, example: LmExample, key=None):
         return model.compute_loss(example, key=key).scalar()
 
-    trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter)
-    # end major difference from Alpaca
-    trainer.add_default_hooks()
-    state = trainer.initial_state(training_key, model=model)
-
-    # log some info about the model
-    all_param_count = parameter_count(state.model)
-    just_lora_params = parameter_count(trainer.trainable_params_only(state.model))
-
-    wandb.summary["parameter_count"] = all_param_count
-    wandb.summary["trainable_parameter_count"] = just_lora_params
-    logger.info(f"Total parameter count: {all_param_count}")
-    logger.info(f"Trainable parameter count: {just_lora_params}")
-    logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}")
-
-    # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for
-    # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large
-    # datasets. We use replicated here since the dataset is small.
-    loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch)
-    loader = non_caching_cycle(loader)
-
-    if state.step != 0:
-        logger.info(f"Resuming training from step {state.step}")
-        for i in range(state.step):
-            next(loader)  # type: ignore
-
-    # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights
-    if config.hf_save_path is not None:
-        full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
-        trainer.add_hook(
-            save_peft_checkpoint_callback(
-                full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload
-            ),
-            every=config.hf_save_steps,
-        )
+    with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer:
+        state = trainer.initial_state(training_key, model=model)
+
+        # log some info about the model
+        all_param_count = parameter_count(state.model)
+        just_lora_params = parameter_count(trainer.trainable_params_only(state.model))
 
-    # Save merged HF checkpoints if requested
-    if config.merged_hf_save_path is not None:
-        full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id)
-        trainer.add_hook(
-            save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload),
-            every=config.hf_save_steps,
+        levanter.tracker.log_summary(
+            {
+                "parameter_count": all_param_count,
+                "trainable_parameter_count": just_lora_params,
+                "fraction_trainable": just_lora_params * 1.0 / all_param_count,
+            }
         )
 
-    trainer.train(state, loader)
+        logger.info(f"Total parameter count: {all_param_count}")
+        logger.info(f"Trainable parameter count: {just_lora_params}")
+        logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3}")
+
+        # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for
+        # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large
+        # datasets. We use replicated here since the dataset is small.
+        loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch)
+        loader = non_caching_cycle(loader)
+
+        if state.step != 0:
+            logger.info(f"Resuming training from step {state.step}")
+            for i in range(state.step):
+                next(loader)  # type: ignore
+
+        # Save HF PEFT checkpoints periodically (and at the end of training), which are just the LoRA weights
+        if config.hf_save_path is not None:
+            full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
+            trainer.add_hook(
+                save_peft_checkpoint_callback(
+                    full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload
+                ),
+                every=config.hf_save_steps,
+            )
+
+        # Save merged HF checkpoints if requested
+        if config.merged_hf_save_path is not None:
+            full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id)
+            trainer.add_hook(
+                save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload),
+                every=config.hf_save_steps,
+            )
+
+        trainer.train(state, loader)
 
 
 if __name__ == "__main__":
diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py
index a7def0acb..548a113a0 100644
--- a/src/levanter/__init__.py
+++ b/src/levanter/__init__.py
@@ -5,6 +5,7 @@
 import levanter.logging as logging
 import levanter.models as models
 import levanter.optim as optim
+import levanter.tracker as tracker
 import levanter.trainer as trainer
 import levanter.visualization as visualization
 from levanter.trainer import initialize
diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py
index 93d60588a..babe7d2fa 100644
--- a/src/levanter/main/lora_lm.py
+++ b/src/levanter/main/lora_lm.py
@@ -4,7 +4,6 @@
 from typing import Optional
 
 import jax.random as jrandom
-import wandb
 
 import haliax.random
 
@@ -47,6 +46,7 @@ class LoraLmConfig:
 
 
 def main(config: LoraLmConfig):
+    levanter.initialize(config)
     tokenizer = config.data.the_tokenizer
 
     converter = HFCheckpointConverter.from_hf(config.initialize_from_hf, trust_remote_code=config.trust_remote_code)
@@ -55,7 +55,6 @@ def main(config: LoraLmConfig):
 
     converter = converter.replaced(tokenizer=tokenizer)
 
-    config.trainer.initialize(config)
     model_config = converter.default_config
 
     # randomness in jax is tightly controlled by "keys" which are the states of the random number generators
@@ -96,8 +95,14 @@ def compute_loss(model, example: LmExample, key=None):
     all_param_count = parameter_count(state.model)
     just_lora_params = parameter_count(trainer.trainable_params_only(state.model))
 
-    wandb.summary["parameter_count"] = all_param_count
-    wandb.summary["trainable_parameter_count"] = just_lora_params
+    levanter.tracker.log_summary(
+        {
+            "parameter_count": all_param_count,
+            "trainable_parameter_count": just_lora_params,
+            "fraction_trainable": just_lora_params * 1.0 / all_param_count,
+        }
+    )
+
     logger.info(f"Total parameter count: {all_param_count}")
     logger.info(f"Trainable parameter count: {just_lora_params}")
     logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}")
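
Note for anyone porting other training scripts: the migration this patch applies is the same two substitutions everywhere. A minimal sketch, assuming only the `levanter.initialize` and `levanter.tracker.log_summary` calls shown in the diff above; the `train` function body and the placeholder count below are illustrative, not Levanter API:

    import levanter

    def train(config):
        # was: config.trainer.initialize(config)
        levanter.initialize(config)

        # ... build model, trainer, and state as before ...
        all_param_count = 123  # placeholder; compute with parameter_count(state.model)

        # was: wandb.summary["parameter_count"] = all_param_count
        levanter.tracker.log_summary({"parameter_count": all_param_count})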