From f5533d678dd021b344aaf00d2767f2e080ccf060 Mon Sep 17 00:00:00 2001
From: Ahmed Ahmed
Date: Tue, 29 Oct 2024 16:31:32 -0700
Subject: [PATCH] add back option for hf models on sft

---
 examples/sft/alpaca-llama-sft.yaml |  8 ++---
 examples/sft/sft.py                | 56 +++++++++++++++++++-----------
 2 files changed, 40 insertions(+), 24 deletions(-)

diff --git a/examples/sft/alpaca-llama-sft.yaml b/examples/sft/alpaca-llama-sft.yaml
index 72c9aad78..8f1c408b4 100644
--- a/examples/sft/alpaca-llama-sft.yaml
+++ b/examples/sft/alpaca-llama-sft.yaml
@@ -12,7 +12,7 @@ model:
   use_bias: false
   use_layer_norm_weight: false
 
-# Training configuration 
+# Training configuration
 trainer:
   mp: p=f32,c=bfloat16
   tracker:
@@ -21,7 +21,7 @@ trainer:
     tags: ["llama", "sft"]
   num_train_steps: 1218
   train_batch_size: 64
-  tensor_parallel_axes: ["mlp", "heads"] 
+  tensor_parallel_axes: ["mlp", "heads"]
   fsdp_axis: "embed"
   batch_axis: "batch"
   steps_per_eval: 1000
@@ -33,7 +33,7 @@ optimizer:
   min_lr_ratio: 0.1
   warmup: 100
 
-# Supervised data configuration 
+# Supervised data configuration
 supervised_data:
   cache_dir: "gs://levanter-checkpoints/marin/sft_cache/alpaca-olmo"
   input_field: "instruction"
@@ -49,4 +49,4 @@ tokenizer: "allenai/OLMo-1B"
 max_tune_length: 2048
 epoch: 3
 
-initialize_from_hf: false
\ No newline at end of file
+initialize_from_hf: false
diff --git a/examples/sft/sft.py b/examples/sft/sft.py
index 74d4f6dc9..53638db12 100644
--- a/examples/sft/sft.py
+++ b/examples/sft/sft.py
@@ -5,16 +5,17 @@
 import jax.random as jrandom
 import transformers
 
+import haliax as hax
 from haliax import Axis
 from haliax.partitioning import round_axis_for_partitioning
 
 import levanter
 from levanter import callbacks
-from levanter.compat.hf_checkpoints import HFCheckpointConverter, save_hf_checkpoint_callback
+from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback
 from levanter.data import PermutationDataset
 from levanter.data.text import EpochDataset, mk_supervised_dataset
 from levanter.main.train_lm import TrainLmConfig
-from levanter.models.lm_model import compute_next_token_loss
+from levanter.models.lm_model import LmHeadModel, compute_next_token_loss
 from levanter.trainer import Trainer
 from levanter.utils.py_utils import non_caching_cycle
 
@@ -33,23 +34,41 @@ class SFTConfig(TrainLmConfig):
     # inherit most of the config from TrainLmConfig
     max_tune_length: int = 2048  # maximum length of the input to the model during tuning
     model_name_or_path: str = "meta-llama/Llama-2-7b-hf"
-    tokenizer: str = "gpt2"  # Tokenizer to use
+    tokenizer: str = "meta-llama/Llama-2-7b-hf"  # Tokenizer to use
 
 
 def train(config: SFTConfig):
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        config.tokenizer,
+        model_max_length=config.max_tune_length,
+        padding_side="right",
+        trust_remote_code=True,
+    )
+    logger.info(f"Loaded tokenizer {tokenizer}")
 
     if config.initialize_from_hf:
         if config.trainer.initialize_from is not None:
             raise ValueError("Cannot use both --initialize_from_hf and --initialize_from")
 
-        converter = HFCheckpointConverter.from_hf(
-            config.model_name_or_path, trust_remote_code=config.trust_remote_code
-        )
+        assert isinstance(config.model, HFCompatConfig)
+
+        converter = HFCheckpointConverter.from_hf(config.model_name_or_path, trust_remote_code=True)
+        if hasattr(tokenizer, "vocab") and tokenizer.vocab != converter.tokenizer.vocab:
+            logger.warning("The tokenizers appear to be different. You may want to check this.")
+        if isinstance(config.initialize_from_hf, str):
+            converter = converter.replaced(reference_checkpoint=config.initialize_from_hf, tokenizer=tokenizer)
+        else:
+            converter = converter.replaced(tokenizer=tokenizer)
+
+        model_config = converter.default_config
+
     else:
         converter = None
 
     levanter.initialize(config)
+    num_new_tokens = add_special_tokens(tokenizer)
+    logger.info(f"Added {num_new_tokens} new tokens")
 
     # randomness in jax is tightly controlled by "keys" which are the states of the random number generators
     # this makes deterministic training pretty easy
     seed = config.trainer.seed
@@ -59,16 +78,6 @@ def train(config: SFTConfig):
         logger.info(f"Overriding data seed with {config.data_seed}")
         data_key = jrandom.PRNGKey(config.data_seed)
 
-    tokenizer = transformers.AutoTokenizer.from_pretrained(
-        config.tokenizer,
-        model_max_length=config.max_tune_length,
-        padding_side="right",
-        trust_remote_code=True,
-    )
-    logger.info(f"Loaded tokenizer {tokenizer}")
-    num_new_tokens = add_special_tokens(tokenizer)
-    logger.info(f"Added {num_new_tokens} new tokens")
-
     # Configure supervised dataset
     supervised_config = config.supervised_data
 
@@ -105,10 +114,17 @@
     # tokens: gpt-2 has 50257, for example. So we round up.
     vocab_size = len(tokenizer)
     Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping)
-    if vocab_size != Vocab.size:
-        logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning")
-
-    state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))
+    if config.initialize_from_hf:
+        logger.info(f"Loading pretrained model from {converter.reference_checkpoint}")
+        model: LmHeadModel = converter.load_pretrained(
+            model_config.model_type, axis_mapping=parameter_axis_mapping, dtype=trainer.mp.param_dtype
+        )
+        model = hax.named_jit(lambda m: m.resize_vocab(len(tokenizer)))(model)
+        state = trainer.initial_state(training_key, model=model)
+    else:
+        if vocab_size != Vocab.size:
+            logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning")
+        state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))
 
     flops_per_token = config.model.flops_per_token(vocab_size)
     flops_per_example = 3 * flops_per_token * Pos.size if flops_per_token is not None else None
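
Usage sketch (editor's addition, not part of the patch): with this change applied, the SFT example can again initialize from a pretrained Hugging Face checkpoint instead of building a model from scratch. Assuming the example config above and that sft.py exposes Levanter's usual draccus-style command line (the script's entry point is not shown in this diff), enabling it might look like:

    # in examples/sft/alpaca-llama-sft.yaml
    initialize_from_hf: true
    # optionally point at the HF checkpoint to start from
    # (defaults to SFTConfig.model_name_or_path, "meta-llama/Llama-2-7b-hf")
    # model_name_or_path: "meta-llama/Llama-2-7b-hf"

    # hypothetical invocation, assuming the standard --config_path CLI
    python examples/sft/sft.py --config_path examples/sft/alpaca-llama-sft.yaml

Per the diff, initialize_from_hf may also be set to a string naming a specific reference checkpoint, in which case the converter is re-pointed at it via converter.replaced(reference_checkpoint=...), and the loaded model's vocabulary is resized to the tokenizer with resize_vocab before training begins.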