From f61300960795513d341c44d560c574cd6fa0b381 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 13 Feb 2024 20:24:24 -0800 Subject: [PATCH] Tweaks/Fixes to Lora_LM, add a basic config for lora_lm (#466) * require a new enough optax (didn't i do that?) * add a lora_llama2.yaml * forgot to set up the data stuff * support multiple evaluation sets in lora_lm * try not failing if the ray node is already up * don't raise on hf datasets with no validation set * fix trainable_param_count invocations * reduce batch size for lora_llama2 * fix serialization of lora, which i had broken --- config/lora_llama2.yaml | 18 ++++++++++++++++++ pyproject.toml | 2 +- src/levanter/checkpoint.py | 9 ++++++++- src/levanter/data/text.py | 10 +++++++++- src/levanter/main/lora_lm.py | 16 ++++++++++------ src/levanter/trainer.py | 16 +++++++++++++--- tests/test_lora.py | 31 +++++++++++++++++++++++++++++++ 7 files changed, 90 insertions(+), 12 deletions(-) create mode 100644 config/lora_llama2.yaml diff --git a/config/lora_llama2.yaml b/config/lora_llama2.yaml new file mode 100644 index 000000000..cf6592153 --- /dev/null +++ b/config/lora_llama2.yaml @@ -0,0 +1,18 @@ +data: + # you should set a data.id or train_urls and validation_urls + # id: math-ai/AutoMathText + tokenizer: "meta-llama/Llama-2-70b-hf" +initialize_from_hf: "meta-llama/Llama-2-7b-hf" +trainer: + mp: p=f32,c=bfloat16 + wandb: + project: "levanter-lora" + tags: ["lora", "llama2"] + num_train_steps: 5000 # tune to suit your needs + train_batch_size: 64 + + # if using model parallelism, this is useful: + tensor_parallel_axes: ["mlp", "heads"] +optimizer: + learning_rate: 3e-4 + weight_decay: 0.0 diff --git a/pyproject.toml b/pyproject.toml index c4a109b21..7444c26b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "equinox>=0.10.7", "jaxtyping>=0.2.20", "transformers>=4.22.0", - "optax", + "optax>=0.1.9", "wandb", "draccus>=0.7.1", "pyarrow>=11.0.0", diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 5c243946e..e63da9c46 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -216,7 +216,7 @@ def _rm_checkpoint(self, checkpoint): def save_checkpoint(self, info, destination: str): path = os.path.join(self.base_path, destination) logger.info(f"Saving checkpoint at step {info.step} to {path}") - state = equinox.filter(info.state, info.state.is_trainable) + state = saveable_state(info.state) save_checkpoint( state, step=info.step, @@ -227,6 +227,13 @@ def save_checkpoint(self, info, destination: str): logger.info(f"Saved checkpoint at step {info.step} to {path}. Save time is {self._last_save_time}") +def saveable_state(state): + to_keep = jax.tree_util.tree_map(lambda _: True, state) + to_keep = dataclasses.replace(to_keep, model=state.is_trainable) + state = equinox.filter(state, to_keep) + return state + + def save_checkpoint(tree: M, step: int, checkpoint_path: PathLike): """ Save a checkpoint to a given path using TensorStore. If exist_ok is True, the checkpoint diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 00e17eb58..d347e9793 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -551,7 +551,15 @@ class LMDatasetSourceConfig: def get_shard_source(self, split) -> Optional[ShardedDataset[str]]: if self.id is not None: - ds = WrappedHFDataset(self.id, split=split, name=self.name, streaming=self.stream) + try: + ds = WrappedHFDataset(self.id, split=split, name=self.name, streaming=self.stream) + except ValueError as e: + # if the message starts with Bad split, then just return None + if str(e).startswith("Bad split"): + logger.warning(f"Splits {split} not found for {self.id} {self.name}") + return None + else: + raise if len(ds.shard_names) == 0: return None diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index 6b845b516..b111d098f 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -94,9 +94,7 @@ def compute_loss(model, example: LmExample, key=None): all_param_count = parameter_count(state.model) # TODO: remove this once we put this in trainer itself - just_lora_params = parameter_count( - levanter.trainer._partition_trainable_params(state.model, lora_param_filter) - ) + just_lora_params = parameter_count(state.trainable_model) levanter.tracker.log_summary( { @@ -108,16 +106,22 @@ def compute_loss(model, example: LmExample, key=None): logger.info(f"Total parameter count: {all_param_count}") logger.info(f"Trainable parameter count: {just_lora_params}") - logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") + logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3e}") # data loaders - eval_dataset = CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos) # type: ignore + eval_datasets = config.data.validation_sets(Pos.size) train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) train_loader = trainer.sharded_loader(train_dataset, Batch) # boilerplate hooks and such - trainer.add_eval_hook(eval_dataset) + if len(eval_datasets) == 0: + logger.warning("No evaluation datasets provided.") + + for name, eval_dataset in eval_datasets.items(): + eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id) + trainer.add_eval_hook(eval_dataset, name=name) + trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1) if config.peft_save_path is not None: full_save_path = os.path.join(config.peft_save_path, trainer.run_id) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index f85336df4..d682430a8 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -28,6 +28,7 @@ import equinox as eqx import jax +import jax.numpy as jnp import jmp import numpy as np from draccus import field @@ -54,7 +55,7 @@ from levanter.tracker import TrackerConfig from levanter.types import FilterSpec from levanter.utils import cloud_utils -from levanter.utils.jax_utils import as_arrayish, is_inexact_arrayish +from levanter.utils.jax_utils import is_inexact_arrayish from levanter.utils.tree_utils import inference_mode @@ -68,6 +69,15 @@ "jax_softmax_custom_jvp": True, } + +def _ensure_int_is_array(x): + # who tf decided that bools are ints + if isinstance(x, int) and not isinstance(x, bool): + return jnp.array(x) + else: + return x + + # A note on the semantics of "step" vs "next_step": # The "step" of a TrainerState is the state after `step` steps have been taken. # A "StepInfo"'s step is the step that was just completed. If you want the next step, use `next_step`. @@ -76,13 +86,13 @@ class TrainerState(eqx.Module, Generic[M]): """ This is the state of the trainer. It contains the model, optimizer state, and random key. - It is an equinox Module becaues it is a PyTree that gets passed to the core `train_step` method + It is an equinox Module because it is a PyTree that gets passed to the core `train_step` method of the Trainer. This unfortunately means that `step` is an Array and not an int, hence the IntScalar. It's designed to be extended by subclasses. """ - _step: IntScalar = eqx.field(converter=lambda x: as_arrayish(x)) + _step: IntScalar = eqx.field(converter=_ensure_int_is_array) model: M opt_state: OptState training_key: PRNGKeyArray diff --git a/tests/test_lora.py b/tests/test_lora.py index 5ba011bce..cd74a363a 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -3,22 +3,26 @@ import equinox as eqx import jax import numpy as np +import optax from transformers import AutoModelForCausalLM import haliax as hax import haliax.nn as hnn +from levanter.checkpoint import Checkpointer from levanter.compat.hf_checkpoints import HFCheckpointConverter from levanter.lora import ( LoraConfig, LoraLinear, lora_state_dict, + lora_trainable_params_filter, loraize, merge_lora_modules, save_merged_hf_model, save_peft_pretrained, ) from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel +from levanter.trainer import StepInfo, TrainerState from levanter.utils.tree_utils import inference_mode from test_utils import skip_if_no_torch @@ -240,3 +244,30 @@ def test_lora_merged_load_in_hf(): assert np.allclose(lev_lora_out, hf_lora_out, atol=1e-4) assert not np.allclose(lev_lora_out, hf_out, atol=1e-4) + + +def test_lora_works_with_checkpointer(): + with tempfile.TemporaryDirectory() as tempdir: + k0 = jax.random.PRNGKey(0) + k1 = jax.random.PRNGKey(1) + + class Module(eqx.Module): + first: hnn.Linear + second: hnn.Linear + + def __call__(self, x): + return self.second(self.first(x)) + + module = Module(first=hnn.Linear.init(In, Mid, key=k0), second=hnn.Linear.init(Mid, Out, key=k1)) + + loraized = loraize(module, LoraConfig(r=8, target_modules=["first"]), key=k0) + lora_filter = lora_trainable_params_filter(loraized) + + optimizer = optax.adam(1e-3) + opt_state = optimizer.init(eqx.filter(loraized, lora_filter)) + + trainer_state = TrainerState(0, loraized, opt_state, jax.random.PRNGKey(0), lora_filter) + info = StepInfo(trainer_state, 0.0, 0.0) + + checkpointer = Checkpointer(tempdir, None, []) + checkpointer.save_checkpoint(info, "loraized")