Skip to content

Commit

Permalink
Tweaks/Fixes to Lora_LM, add a basic config for lora_lm (#466)
Browse files Browse the repository at this point in the history
* require a new enough optax

(didn't i do that?)

* add a lora_llama2.yaml

* forgot to set up the data stuff

* support multiple evaluation sets in lora_lm

* try not failing if the ray node is already up

* don't raise on hf datasets with no validation set

* fix trainable_param_count invocations

* reduce batch size for lora_llama2

* fix serialization of lora, which i had broken
  • Loading branch information
dlwh authored Feb 14, 2024
1 parent 3cd37f6 commit f613009
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 12 deletions.
18 changes: 18 additions & 0 deletions config/lora_llama2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
data:
# you should set a data.id or train_urls and validation_urls
# id: math-ai/AutoMathText
tokenizer: "meta-llama/Llama-2-70b-hf"
initialize_from_hf: "meta-llama/Llama-2-7b-hf"
trainer:
mp: p=f32,c=bfloat16
wandb:
project: "levanter-lora"
tags: ["lora", "llama2"]
num_train_steps: 5000 # tune to suit your needs
train_batch_size: 64

# if using model parallelism, this is useful:
tensor_parallel_axes: ["mlp", "heads"]
optimizer:
learning_rate: 3e-4
weight_decay: 0.0
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies = [
"equinox>=0.10.7",
"jaxtyping>=0.2.20",
"transformers>=4.22.0",
"optax",
"optax>=0.1.9",
"wandb",
"draccus>=0.7.1",
"pyarrow>=11.0.0",
Expand Down
9 changes: 8 additions & 1 deletion src/levanter/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def _rm_checkpoint(self, checkpoint):
def save_checkpoint(self, info, destination: str):
path = os.path.join(self.base_path, destination)
logger.info(f"Saving checkpoint at step {info.step} to {path}")
state = equinox.filter(info.state, info.state.is_trainable)
state = saveable_state(info.state)
save_checkpoint(
state,
step=info.step,
Expand All @@ -227,6 +227,13 @@ def save_checkpoint(self, info, destination: str):
logger.info(f"Saved checkpoint at step {info.step} to {path}. Save time is {self._last_save_time}")


def saveable_state(state):
to_keep = jax.tree_util.tree_map(lambda _: True, state)
to_keep = dataclasses.replace(to_keep, model=state.is_trainable)
state = equinox.filter(state, to_keep)
return state


def save_checkpoint(tree: M, step: int, checkpoint_path: PathLike):
"""
Save a checkpoint to a given path using TensorStore. If exist_ok is True, the checkpoint
Expand Down
10 changes: 9 additions & 1 deletion src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,15 @@ class LMDatasetSourceConfig:

def get_shard_source(self, split) -> Optional[ShardedDataset[str]]:
if self.id is not None:
ds = WrappedHFDataset(self.id, split=split, name=self.name, streaming=self.stream)
try:
ds = WrappedHFDataset(self.id, split=split, name=self.name, streaming=self.stream)
except ValueError as e:
# if the message starts with Bad split, then just return None
if str(e).startswith("Bad split"):
logger.warning(f"Splits {split} not found for {self.id} {self.name}")
return None
else:
raise

if len(ds.shard_names) == 0:
return None
Expand Down
16 changes: 10 additions & 6 deletions src/levanter/main/lora_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,7 @@ def compute_loss(model, example: LmExample, key=None):

all_param_count = parameter_count(state.model)
# TODO: remove this once we put this in trainer itself
just_lora_params = parameter_count(
levanter.trainer._partition_trainable_params(state.model, lora_param_filter)
)
just_lora_params = parameter_count(state.trainable_model)

levanter.tracker.log_summary(
{
Expand All @@ -108,16 +106,22 @@ def compute_loss(model, example: LmExample, key=None):

logger.info(f"Total parameter count: {all_param_count}")
logger.info(f"Trainable parameter count: {just_lora_params}")
logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}")
logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3e}")

# data loaders
eval_dataset = CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos) # type: ignore
eval_datasets = config.data.validation_sets(Pos.size)

train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos)
train_loader = trainer.sharded_loader(train_dataset, Batch)

# boilerplate hooks and such
trainer.add_eval_hook(eval_dataset)
if len(eval_datasets) == 0:
logger.warning("No evaluation datasets provided.")

for name, eval_dataset in eval_datasets.items():
eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id)
trainer.add_eval_hook(eval_dataset, name=name)

trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1)
if config.peft_save_path is not None:
full_save_path = os.path.join(config.peft_save_path, trainer.run_id)
Expand Down
16 changes: 13 additions & 3 deletions src/levanter/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import equinox as eqx
import jax
import jax.numpy as jnp
import jmp
import numpy as np
from draccus import field
Expand All @@ -54,7 +55,7 @@
from levanter.tracker import TrackerConfig
from levanter.types import FilterSpec
from levanter.utils import cloud_utils
from levanter.utils.jax_utils import as_arrayish, is_inexact_arrayish
from levanter.utils.jax_utils import is_inexact_arrayish
from levanter.utils.tree_utils import inference_mode


Expand All @@ -68,6 +69,15 @@
"jax_softmax_custom_jvp": True,
}


def _ensure_int_is_array(x):
# who tf decided that bools are ints
if isinstance(x, int) and not isinstance(x, bool):
return jnp.array(x)
else:
return x


# A note on the semantics of "step" vs "next_step":
# The "step" of a TrainerState is the state after `step` steps have been taken.
# A "StepInfo"'s step is the step that was just completed. If you want the next step, use `next_step`.
Expand All @@ -76,13 +86,13 @@
class TrainerState(eqx.Module, Generic[M]):
"""
This is the state of the trainer. It contains the model, optimizer state, and random key.
It is an equinox Module becaues it is a PyTree that gets passed to the core `train_step` method
It is an equinox Module because it is a PyTree that gets passed to the core `train_step` method
of the Trainer. This unfortunately means that `step` is an Array and not an int, hence the IntScalar.
It's designed to be extended by subclasses.
"""

_step: IntScalar = eqx.field(converter=lambda x: as_arrayish(x))
_step: IntScalar = eqx.field(converter=_ensure_int_is_array)
model: M
opt_state: OptState
training_key: PRNGKeyArray
Expand Down
31 changes: 31 additions & 0 deletions tests/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,26 @@
import equinox as eqx
import jax
import numpy as np
import optax
from transformers import AutoModelForCausalLM

import haliax as hax
import haliax.nn as hnn

from levanter.checkpoint import Checkpointer
from levanter.compat.hf_checkpoints import HFCheckpointConverter
from levanter.lora import (
LoraConfig,
LoraLinear,
lora_state_dict,
lora_trainable_params_filter,
loraize,
merge_lora_modules,
save_merged_hf_model,
save_peft_pretrained,
)
from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel
from levanter.trainer import StepInfo, TrainerState
from levanter.utils.tree_utils import inference_mode
from test_utils import skip_if_no_torch

Expand Down Expand Up @@ -240,3 +244,30 @@ def test_lora_merged_load_in_hf():

assert np.allclose(lev_lora_out, hf_lora_out, atol=1e-4)
assert not np.allclose(lev_lora_out, hf_out, atol=1e-4)


def test_lora_works_with_checkpointer():
with tempfile.TemporaryDirectory() as tempdir:
k0 = jax.random.PRNGKey(0)
k1 = jax.random.PRNGKey(1)

class Module(eqx.Module):
first: hnn.Linear
second: hnn.Linear

def __call__(self, x):
return self.second(self.first(x))

module = Module(first=hnn.Linear.init(In, Mid, key=k0), second=hnn.Linear.init(Mid, Out, key=k1))

loraized = loraize(module, LoraConfig(r=8, target_modules=["first"]), key=k0)
lora_filter = lora_trainable_params_filter(loraized)

optimizer = optax.adam(1e-3)
opt_state = optimizer.init(eqx.filter(loraized, lora_filter))

trainer_state = TrainerState(0, loraized, opt_state, jax.random.PRNGKey(0), lora_filter)
info = StepInfo(trainer_state, 0.0, 0.0)

checkpointer = Checkpointer(tempdir, None, [])
checkpointer.save_checkpoint(info, "loraized")

0 comments on commit f613009

Please sign in to comment.