add back option for hf models on sft
ahmeda14960 committed Oct 29, 2024
1 parent 5f36eb8 commit f5533d6
Showing 2 changed files with 40 additions and 24 deletions.
8 changes: 4 additions & 4 deletions examples/sft/alpaca-llama-sft.yaml
@@ -12,7 +12,7 @@ model:
   use_bias: false
   use_layer_norm_weight: false
 
-# Training configuration
+# Training configuration
 trainer:
   mp: p=f32,c=bfloat16
   tracker:
@@ -21,7 +21,7 @@ trainer:
     tags: ["llama", "sft"]
   num_train_steps: 1218
   train_batch_size: 64
-  tensor_parallel_axes: ["mlp", "heads"]
+  tensor_parallel_axes: ["mlp", "heads"]
   fsdp_axis: "embed"
   batch_axis: "batch"
   steps_per_eval: 1000
@@ -33,7 +33,7 @@ optimizer:
   min_lr_ratio: 0.1
   warmup: 100
 
-# Supervised data configuration
+# Supervised data configuration
 supervised_data:
   cache_dir: "gs://levanter-checkpoints/marin/sft_cache/alpaca-olmo"
   input_field: "instruction"
@@ -49,4 +49,4 @@ tokenizer: "allenai/OLMo-1B"
 max_tune_length: 2048
 epoch: 3
 
-initialize_from_hf: false
+initialize_from_hf: false
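
This example config leaves HF initialization off. To use the code path this commit restores, the relevant keys would look something like the following, a hedged sketch with illustrative values: per the sft.py change below, initialize_from_hf accepts either a boolean or a checkpoint string, and the string form overrides the converter's reference checkpoint.

    # hypothetical override: initialize the SFT model from a Hugging Face checkpoint
    model_name_or_path: "meta-llama/Llama-2-7b-hf"
    initialize_from_hf: true
    # or point the converter at a specific checkpoint:
    # initialize_from_hf: "meta-llama/Llama-2-7b-hf"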
56 changes: 36 additions & 20 deletions examples/sft/sft.py
@@ -5,16 +5,17 @@
 import jax.random as jrandom
 import transformers
 
+import haliax as hax
 from haliax import Axis
 from haliax.partitioning import round_axis_for_partitioning
 
 import levanter
 from levanter import callbacks
-from levanter.compat.hf_checkpoints import HFCheckpointConverter, save_hf_checkpoint_callback
+from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback
 from levanter.data import PermutationDataset
 from levanter.data.text import EpochDataset, mk_supervised_dataset
 from levanter.main.train_lm import TrainLmConfig
-from levanter.models.lm_model import compute_next_token_loss
+from levanter.models.lm_model import LmHeadModel, compute_next_token_loss
 from levanter.trainer import Trainer
 from levanter.utils.py_utils import non_caching_cycle

@@ -33,23 +34,41 @@ class SFTConfig(TrainLmConfig):
     # inherit most of the config from TrainLmConfig
     max_tune_length: int = 2048  # maximum length of the input to the model during tuning
     model_name_or_path: str = "meta-llama/Llama-2-7b-hf"
-    tokenizer: str = "gpt2"  # Tokenizer to use
+    tokenizer: str = "meta-llama/Llama-2-7b-hf"  # Tokenizer to use
 
 
 def train(config: SFTConfig):
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        config.tokenizer,
+        model_max_length=config.max_tune_length,
+        padding_side="right",
+        trust_remote_code=True,
+    )
+    logger.info(f"Loaded tokenizer {tokenizer}")
+
     if config.initialize_from_hf:
         if config.trainer.initialize_from is not None:
             raise ValueError("Cannot use both --initialize_from_hf and --initialize_from")
 
-        converter = HFCheckpointConverter.from_hf(
-            config.model_name_or_path, trust_remote_code=config.trust_remote_code
-        )
+        assert isinstance(config.model, HFCompatConfig)
+
+        converter = HFCheckpointConverter.from_hf(config.model_name_or_path, trust_remote_code=True)
+        if hasattr(tokenizer, "vocab") and tokenizer.vocab != converter.tokenizer.vocab:
+            logger.warning("The tokenizers appear to be different. You may want to check this.")
+        if isinstance(config.initialize_from_hf, str):
+            converter = converter.replaced(reference_checkpoint=config.initialize_from_hf, tokenizer=tokenizer)
+        else:
+            converter = converter.replaced(tokenizer=tokenizer)
+
+        model_config = converter.default_config
+
     else:
         converter = None
 
     levanter.initialize(config)
 
+    num_new_tokens = add_special_tokens(tokenizer)
+    logger.info(f"Added {num_new_tokens} new tokens")
     # randomness in jax is tightly controlled by "keys" which are the states of the random number generators
     # this makes deterministic training pretty easy
     seed = config.trainer.seed
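
The ValueError in the hunk above implies both fields can also be passed as command-line overrides. A hedged usage sketch, assuming levanter's usual draccus-style --config_path entry point and an illustrative checkpoint name:

    python examples/sft/sft.py \
        --config_path examples/sft/alpaca-llama-sft.yaml \
        --initialize_from_hf meta-llama/Llama-2-7b-hf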
@@ -59,16 +78,6 @@ def train(config: SFTConfig):
         logger.info(f"Overriding data seed with {config.data_seed}")
         data_key = jrandom.PRNGKey(config.data_seed)
 
-    tokenizer = transformers.AutoTokenizer.from_pretrained(
-        config.tokenizer,
-        model_max_length=config.max_tune_length,
-        padding_side="right",
-        trust_remote_code=True,
-    )
-    logger.info(f"Loaded tokenizer {tokenizer}")
-    num_new_tokens = add_special_tokens(tokenizer)
-    logger.info(f"Added {num_new_tokens} new tokens")
-
     # Configure supervised dataset
     supervised_config = config.supervised_data
@@ -105,10 +114,17 @@ def train(config: SFTConfig):
     # tokens: gpt-2 has 50257, for example. So we round up.
     vocab_size = len(tokenizer)
     Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping)
-    if vocab_size != Vocab.size:
-        logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning")
 
-    state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))
+    if config.initialize_from_hf:
+        logger.info(f"Loading pretrained model from {converter.reference_checkpoint}")
+        model: LmHeadModel = converter.load_pretrained(
+            model_config.model_type, axis_mapping=parameter_axis_mapping, dtype=trainer.mp.param_dtype
+        )
+        model = hax.named_jit(lambda m: m.resize_vocab(len(tokenizer)))(model)
+        state = trainer.initial_state(training_key, model=model)
+    else:
+        if vocab_size != Vocab.size:
+            logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning")
+        state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))
 
     flops_per_token = config.model.flops_per_token(vocab_size)
     flops_per_example = 3 * flops_per_token * Pos.size if flops_per_token is not None else None
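
One detail worth spelling out: the two branches handle vocabulary size differently. A freshly built model gets the partition-rounded Vocab axis, while an HF checkpoint is loaded at its native vocabulary size and then resized to len(tokenizer), which may have grown when add_special_tokens ran. As a worked example of the rounding, using the gpt-2 figure from the comment above and assuming the axis is padded up to a multiple of 128 (the exact rule belongs to round_axis_for_partitioning):

    # hypothetical illustration of the padding round_axis_for_partitioning performs
    vocab_size = 50257                       # gpt-2 tokenizer size, per the comment above
    padded = ((vocab_size + 127) // 128) * 128
    assert padded == 50304                   # 393 * 128; fresh models get the padded axis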
