
Commit

missed a few spots
dlwh committed Feb 9, 2024
1 parent 474206e commit 5d5c30f
Showing 4 changed files with 58 additions and 49 deletions.
2 changes: 1 addition & 1 deletion examples/alpaca/alpaca.py
@@ -194,7 +194,7 @@ def get_prompts(prompt_path) -> dict:


def train(config: TrainArgs):
-    config.trainer.initialize(config)
+    levanter.initialize(config)

    # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint.
    # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary.
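
The change in alpaca.py swaps the per-script config.trainer.initialize(config) call for the library-level levanter.initialize(config); the other scripts below get the same treatment. A minimal sketch of the new entry-point pattern, assuming a TrainArgs-style config dataclass as in the example script:

import levanter


def train(config):  # config: the script's TrainArgs dataclass
    # one setup call at the top of the script; replaces config.trainer.initialize(config)
    levanter.initialize(config)
    # ... the rest of the training script is unchanged by this hunk
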
91 changes: 47 additions & 44 deletions examples/gsm8k-lora/gsm8k_lora.py
@@ -9,7 +9,6 @@
import jax.random as jrandom
import numpy as np
import transformers
-import wandb

import haliax as hax

@@ -127,7 +126,7 @@ def format_output(ex):


def train(config: TrainArgs):
-    config.trainer.initialize(config)
+    levanter.initialize(config)

    # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint.
    # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary.
@@ -169,53 +168,57 @@ def loraize_hf_model(model):
    def compute_loss(model: LmHeadModel, example: LmExample, key=None):
        return model.compute_loss(example, key=key).scalar()

-    trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter)
-
-    # end major difference from Alpaca
-
-    trainer.add_default_hooks()
-    state = trainer.initial_state(training_key, model=model)
-
-    # log some info about the model
-    all_param_count = parameter_count(state.model)
-    just_lora_params = parameter_count(trainer.trainable_params_only(state.model))
-
-    wandb.summary["parameter_count"] = all_param_count
-    wandb.summary["trainable_parameter_count"] = just_lora_params
-    logger.info(f"Total parameter count: {all_param_count}")
-    logger.info(f"Trainable parameter count: {just_lora_params}")
-    logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}")
-
-    # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for
-    # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large
-    # datasets. We use replicated here since the dataset is small.
-    loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch)
-    loader = non_caching_cycle(loader)
-
-    if state.step != 0:
-        logger.info(f"Resuming training from step {state.step}")
-        for i in range(state.step):
-            next(loader)  # type: ignore
-
-    # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights
-    if config.hf_save_path is not None:
-        full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
-        trainer.add_hook(
-            save_peft_checkpoint_callback(
-                full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload
-            ),
-            every=config.hf_save_steps,
-        )
-
-    # Save merged HF checkpoints if requested
-    if config.merged_hf_save_path is not None:
-        full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id)
-        trainer.add_hook(
-            save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload),
-            every=config.hf_save_steps,
-        )
-
-    trainer.train(state, loader)
+    with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer:
+        state = trainer.initial_state(training_key, model=model)
+
+        # log some info about the model
+        all_param_count = parameter_count(state.model)
+        just_lora_params = parameter_count(trainer.trainable_params_only(state.model))
+
+        levanter.tracker.log_summary(
+            {
+                "parameter_count": all_param_count,
+                "trainable_parameter_count": just_lora_params,
+                "fraction_trainable": just_lora_params * 1.0 / all_param_count,
+            }
+        )
+
+        logger.info(f"Total parameter count: {all_param_count}")
+        logger.info(f"Trainable parameter count: {just_lora_params}")
+        logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}")
+
+        # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for
+        # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large
+        # datasets. We use replicated here since the dataset is small.
+        loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch)
+        loader = non_caching_cycle(loader)
+
+        if state.step != 0:
+            logger.info(f"Resuming training from step {state.step}")
+            for i in range(state.step):
+                next(loader)  # type: ignore
+
+        # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights
+        if config.hf_save_path is not None:
+            full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
+            trainer.add_hook(
+                save_peft_checkpoint_callback(
+                    full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload
+                ),
+                every=config.hf_save_steps,
+            )
+
+        # Save merged HF checkpoints if requested
+        if config.merged_hf_save_path is not None:
+            full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id)
+            trainer.add_hook(
+                save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload),
+                every=config.hf_save_steps,
+            )
+
+        trainer.train(state, loader)


if __name__ == "__main__":
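
In gsm8k_lora.py the Trainer is now built as a context manager and summary metrics are reported through levanter.tracker instead of wandb.summary, which is why the import wandb line goes away. A condensed sketch of that pattern follows; it assumes the objects built earlier in the script (config, optimizer, compute_loss, lora_param_filter, model, training_key, train_dataset) are passed in, and the parameter_count import path is taken from the Levanter example scripts rather than from this diff:

import levanter
from levanter.trainer import Trainer
from levanter.utils.jax_utils import parameter_count  # import path assumed, not shown in this diff


def run_lora_training(config, optimizer, compute_loss, lora_param_filter, model, training_key, train_dataset):
    # the Trainer is used as a context manager, as in the added lines above
    with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer:
        state = trainer.initial_state(training_key, model=model)

        # summary metrics go through the tracker abstraction rather than wandb.summary
        levanter.tracker.log_summary(
            {
                "parameter_count": parameter_count(state.model),
                "trainable_parameter_count": parameter_count(trainer.trainable_params_only(state.model)),
            }
        )

        loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch)
        trainer.train(state, loader)
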
1 change: 1 addition & 0 deletions src/levanter/__init__.py
@@ -5,6 +5,7 @@
import levanter.logging as logging
import levanter.models as models
import levanter.optim as optim
+import levanter.tracker as tracker
import levanter.trainer as trainer
import levanter.visualization as visualization
from levanter.trainer import initialize
13 changes: 9 additions & 4 deletions src/levanter/main/lora_lm.py
@@ -4,7 +4,6 @@
from typing import Optional

import jax.random as jrandom
-import wandb

import haliax.random

@@ -47,6 +46,7 @@ class LoraLmConfig:


def main(config: LoraLmConfig):
+    levanter.initialize(config)
    tokenizer = config.data.the_tokenizer

    converter = HFCheckpointConverter.from_hf(config.initialize_from_hf, trust_remote_code=config.trust_remote_code)
@@ -55,7 +55,6 @@ def main(config: LoraLmConfig):

    converter = converter.replaced(tokenizer=tokenizer)

-    config.trainer.initialize(config)
    model_config = converter.default_config

    # randomness in jax is tightly controlled by "keys" which are the states of the random number generators
@@ -96,8 +95,14 @@ def compute_loss(model, example: LmExample, key=None):
        all_param_count = parameter_count(state.model)
        just_lora_params = parameter_count(trainer.trainable_params_only(state.model))

-        wandb.summary["parameter_count"] = all_param_count
-        wandb.summary["trainable_parameter_count"] = just_lora_params
+        levanter.tracker.log_summary(
+            {
+                "parameter_count": all_param_count,
+                "trainable_parameter_count": just_lora_params,
+                "fraction_trainable": just_lora_params * 1.0 / all_param_count,
+            }
+        )
+
        logger.info(f"Total parameter count: {all_param_count}")
        logger.info(f"Trainable parameter count: {just_lora_params}")
        logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}")
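
The lora_lm.py change is the same migration in miniature: the two wandb.summary assignments become one levanter.tracker.log_summary call, with the trainable fraction logged as a third key. A small helper distilled from the added lines, as a sketch (the function name and the parameter_count import path are assumptions, not part of the commit):

import levanter
from levanter.utils.jax_utils import parameter_count  # import path assumed, not shown in this diff


def log_param_summary(trainer, model):
    # one summary call replaces the two wandb.summary assignments
    all_param_count = parameter_count(model)
    just_lora_params = parameter_count(trainer.trainable_params_only(model))
    levanter.tracker.log_summary(
        {
            "parameter_count": all_param_count,
            "trainable_parameter_count": just_lora_params,
            "fraction_trainable": just_lora_params * 1.0 / all_param_count,
        }
    )
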
