From 9b2c11fe469cfd152718ef190e33c7b485bb772d Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Fri, 19 Jan 2024 21:46:07 -0800 Subject: [PATCH 01/19] bugfix for model parallelism (#423) --- src/levanter/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index c615dc1a3..2a8c5e93c 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -678,7 +678,7 @@ def _validate_and_set_defaults(self): raise ValueError("either model_axis_size or local_device_count must be divisible by the other") if self.per_device_parallelism == -1: - self.per_device_parallelism = self.train_batch_size // jax.device_count() + self.per_device_parallelism = self.train_batch_size // self.data_axis_size # validate size of per_device_parallelism if self.train_batch_size % (self.per_device_parallelism * self.data_axis_size) != 0: From e27d39ca4519cd54872672277c2759da4aefa81a Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 22 Jan 2024 15:41:15 -0800 Subject: [PATCH 02/19] Alpaca tutorial and qol improvements (#425) * make the prompts in alpaca.yaml explicit * accept compressed json * update docs for alpaca --- docs/Fine-Tuning.md | 385 ++++++++++++++++------------- examples/alpaca/alpaca-llama2.yaml | 19 ++ examples/alpaca/alpaca.py | 11 +- examples/alpaca/alpaca.yaml | 19 ++ 4 files changed, 260 insertions(+), 174 deletions(-) diff --git a/docs/Fine-Tuning.md b/docs/Fine-Tuning.md index 4903e3c1e..58e8f455d 100644 --- a/docs/Fine-Tuning.md +++ b/docs/Fine-Tuning.md @@ -2,31 +2,35 @@ While Levanter's main focus is pretraining, we can also use it for fine-tuning. As an example, we'll show how to reproduce [Stanford Alpaca](https://crfm.stanford.edu/2023/03/13/alpaca.html), -using [Levanter](https://github.com/stanford-crfm/levanter) and either Llama 1 or [Llama 2](https://ai.meta.com/llama/). +using [Levanter](https://github.com/stanford-crfm/levanter) and either [Llama 1](https://arxiv.org/abs/2302.13971) or [Llama 2](https://ai.meta.com/llama/) 7B. The script we develop will be designed for Alpaca, defaulting to using its dataset and prompts, but it should work for any single-turn instruction-following task. -This tutorial is meant to cover "full finetuning", where you start with a pretrained model and modify -all of its parameters to fit some final task, rather than something like LoRA (though see our [LoRA tutorial](./LoRA.md) for that). -It also documents how to work with datasets that aren't just single `"text"`s, as we use in pretraining. +This tutorial is meant to cover "full finetuning," where you start with a pretrained model and modify +all of its parameters to fit some final task, rather than something like LoRA that adds a (small) number +of additional parameters. (See our [LoRA tutorial](./LoRA.md) for that.) +It also documents how to work with datasets that aren't just single `"text"`s, which is what we use in pretraining. ## Overview of Alpaca -[Alpaca](https://crfm.stanford.edu/2023/03/13/alpaca.html) is a lightweight fine tune of Llama 1 on a +[Alpaca](https://crfm.stanford.edu/2023/03/13/alpaca.html) is a fine tune of Llama 1 on a [dataset of 52000 input/output pairs](https://huggingface.co/datasets/tatsu-lab/alpaca), which were generated by taking [a seed set from self-instruct](https://github.com/yizhongw/self-instruct) and asking `text-davinci-003` to generate more examples. +The original Alpaca script is [here](https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py). + ![Schematic diagram of how the Alpaca model was created](https://crfm.stanford.edu/static/img/posts/2023-03-13-alpaca/alpaca_main.jpg) ### The Foundation Model -Llama 1 is a 7B parameter causal language model trained on 1T tokens from various mostly English sources. It's described -in [the Llama paper](https://arxiv.org/abs/2302.13971). +Llama 1 7B is a ≈7 billion parameter causal language model trained on 1 trillion tokens from various mostly English sources. +It's described in [the Llama paper](https://arxiv.org/abs/2302.13971). [Llama 2](https://ai.meta.com/llama/) is a similar model, +just trained on more data (and with some slight tweaks to the architecture for larger models). ### The Data -More precisely, the dataset is composed of triples of (instruction, input, output), where the instruction is a prompt +The Alpaca dataset is composed of triples of (instruction, input, output), where the instruction is a prompt describing the task. A bit less than 40% of the examples have inputs, and the rest are just the instruction and output. Here are some example inputs, instructions, and outputs: @@ -44,7 +48,8 @@ generated by an LLM after all.) But it's a good example of the kind of data you ### Preprocessing Because Llama is a causal language model, we need to do some preprocessing to turn the pairs/triples into -a single sequence. The usual thing is to interpolate the strings into a prompt. We'll have two prompts, +a single sequence. The usual thing is to interpolate the strings into a prompt that provides +some context/guidance to the LM. We'll have two prompts, depending on whether or not there's an input or just an instruction and output. For example, the first example above would be turned into: @@ -74,10 +79,11 @@ Compute the area of a rectangle with length 10cm and width 5cm. The area of the rectangle is 50 cm2. ``` -From there [original Alpaca script](https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py) *masks out the loss* for all tokens before the start of the output. This gets -the model to learn to mimic outputs conditioned on inputs, rather than spending time learning to generate inputs and outputs. +From there, [the original Alpaca script](https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py) *masks out the loss* for all tokens before the start of the output. This gets +the model to learn to mimic outputs conditioned on inputs, rather than getting the model to learn the prompts and +inputs along with the outputs. -## Setup +## Running the script Rather than going through the code first, we'll jump straight to running the script. We'll cover the code in the [Code Walkthrough](#code-walkthrough) section below. @@ -86,61 +92,170 @@ Rather than going through the code first, we'll jump straight to running the scr Make sure you go through either the [GPU](./Getting-Started-GPU.md) or [TPU](./Getting-Started-TPU-VM.md) setup, depending on what you want to use. -### Environment Setup +### NVIDIA GPU -#### \[GPU\] Environment Setup +#### Environment Setup Follow the instructions in the [Getting Started with GPUs](./Getting-Started-GPU.md) guide to create a conda environment or virtualenv -and to install JAX with CUDA. +and to install JAX with CUDA. Then, if you haven't already, clone the Levanter repo and install it in editable mode: + +```bash +git clone https://github.com/stanford-crfm/levanter.git +cd levanter +pip install -e . +``` + +You'll also want to log into [WANDB](https://wandb.ai/). + +```bash +wandb login +``` + +To use Llama 2, you'll need to request access to the model from [Llama 2's Hugging Face page](https://huggingface.co/meta-llama/Llama-2-7b-hf). +Then, you'll need to log into the Hugging Face CLI: + +```bash +huggingface-cli login +``` + +#### Running the Script + +The example commands below demonstrate how to launch a training job on a node with 8 A100 GPUs, but should work for +other single node GPU configurations. For example, we've also tested Alpaca replication with +a node of 8 RTX 6000 Ada Generation 49.1GB GPUs. Levanter works best with Ada or later generation NVIDIA GPUs. +To replicate Alpaca, you can run the following command: + +```bash +python examples/alpaca/alpaca.py --config_path levanter/examples/alpaca/alpaca.yaml +``` + +To use Llama 2: + +```bash +python examples/alpaca/alpaca.py --config_path levanter/examples/alpaca/alpaca-llama2.yaml +``` + +Alternatively: + +```bash +python examples/alpaca/alpaca.py --config_path levanter/examples/alpaca/alpaca-llama2.yaml --model_name_or_path meta-llama/Llama-2-7b-hf +``` + +!!! warning + + Fine-tuning a 7B parameter model needs **a lot** of accelerator memory, you will need more than 80GB of GPU memory in + aggregate to run this job. Because Levanter makes heavy use of FSDP, you can use several smaller cards. + If you don't have enough memory, you can try reducing the `train_batch_size` or the `per_device_parallelism` in + the config. + + +At some point the run will spit out a WandB link. You can click on that to see the training progress. There's +not a ton to see there (yet), but you can see the training loss go down over time. + +On an 8xA100 box, training should take about ~3.5 hours, similar to the original Alpaca script. +It should take ~8.5 hours on 8 RTX 6000 Ada Generation GPUs. + + +### TPUs -#### \[TPU\] Environment Setup +#### Environment Setup -For TPUs, please follow the instructions in the [Getting Started with TPUs](./Getting-Started-TPU-VM.md) guide to -get started with a TPU VM. Once you have, you can just run the -following command from a source checkout of Levanter: +For TPUs, please follow the instructions in the [Getting Started with TPUs](./Getting-Started-TPU-VM.md). +Once you have, you can run something like this to get a v3-32 TPU VM: ```bash bash infra/spin-up-vm.sh llama-32 -z us-east1-d -t v3-32 --preemptible ``` -### Install Levanter +You might need to change the zone and/or the TPU type depending on what's available. You can also use preemptible +TPUs if you want to save money (or that's what your quota is). I think training Alpaca should work on a v3-8, +but we don't have any of those. + +#### Running the Script + +Launching the run on TPU is a bit more complex because you need to specify a lot of paths to GCS buckets. You will +also likely need to run the command on multiple machines, because a v3-32 VM is actually 4 distinct machines, each +controlling 8 TPUs. + +This is what the command looks like: ```bash +export GCS_BASE="gs://" +gcloud compute tpus tpu-vm ssh llama-32 --zone us-east1-d --worker=all \ +--command="WANDB_API_KEY=${YOUR TOKEN HERE} \ +HUGGING_FACE_HUB_TOKEN=${YOUR TOKEN HERE} \ +bash levanter/infra/run.sh python \ +levanter/examples/alpaca/alpaca.py \ +--config_path levanter/examples/alpaca/alpaca-llama2.yaml \ +--data_cache_dir ${GCS_BASE}/data \ +--trainer.checkpointer.base_path ${GCS_BASE}/ckpts \ +--hf_save_path ${GCS_BASE}/hf_ckpts +``` + +If you're using preemptible or TRC TPUs, you'll want to add `--trainer.id ` to the command line. +Alternatively, you can use the [babysitting script](./Getting-Started-TPU-VM.md#babysitting-script) to automatically restart the +VM and job if it gets preempted. (It will also set a run id automatically.) That would look like this: ```bash -git clone https://github.com/stanford-crfm/levanter.git -cd levanter -pip install -e . +infra/babysit-tpu-vm.sh llama-32 -z us-east1-d -t v3-32 --preemptible -- \ +WANDB_API_KEY=${YOUR TOKEN HERE} \ +HUGGING_FACE_HUB_TOKEN=${YOUR TOKEN HERE} \ +bash levanter/infra/run.sh python \ +levanter/examples/alpaca/alpaca.py \ +--config_path levanter/examples/alpaca/alpaca-llama2.yaml \ +--trainer.checkpointer.base_path gs:// \ +--hf_save_path gs:// \ ``` +You should see a link to the WandB run in the output. You can click on that to see the training progress. +Similar to an 8xA100 box, training should take about ~3.5 hours on a v3-32. + + ## Configuration -We have two configs for Alpaca: one for Llama 1 and one for Llama 2. The only difference is the model id. +That should be all you need to run the script and replicate Alpaca. +However, if you want to customize the script, you can do so by modifying the config. +We have two configs for Alpaca: one for Llama 1 and one for Llama 2. The only difference is the `model_name_or_path` field. -### Config to Replicate Alpaca +### Base Config ```yaml # cf https://github.com/tatsu-lab/stanford_alpaca#fine-tuning data: tatsu-lab/alpaca model_name_or_path: huggyllama/llama-7b trainer: - mp: p=f32,c=bfloat16 + mp: p=f32,c=bfloat16 # Mixed precision training with fp32 parameters/optimizer state and bf16 activations wandb: project: "levanter-alpaca" num_train_steps: 1218 # 128 * 1218 = 155904, which is almost but not quite 3 epochs, which is what alpaca did train_batch_size: 128 - # if using model parallelism, this is useful: - tensor_parallel_axes: ["mlp", "heads"] optimizer: learning_rate: 2e-5 weight_decay: 0.0 +prompts: + # |- means multiline string, keeping all but the final newline + prompt_input: |- + Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. + + ### Instruction: + {instruction} + + ### Input: + {input} + + ### Response: + prompt_no_input: |- + Below is an instruction that describes a task. Write a response that appropriately completes the request. + + ### Instruction: + {instruction} + + ### Response: ``` This config uses mixed fp32/bf16 precision and sets the number of training steps to be roughly 3 epochs. It sets up the optimizer to use a learning rate of 2e-5 and no weight decay. `trainer.per_device_parallelism` is roughly equivalent to HF's -`per_device_train_batch_size`. If you want to use model parallelism, you can set `trainer.model_axis_size` to something -like 2. (This will split the model across two devices. This might be useful if you're using a v3-64 or something similar and -want to maintain the same batch size.) +`per_device_train_batch_size`. ### Llama 2 Config @@ -150,156 +265,79 @@ If you haven't already, go to [Llama 2's Hugging Face page](https://huggingface. Once you have access, go to [Hugging Face's Tokens page](https://huggingface.co/settings/tokens) to get an API token. You'll need to provide this to the TPU VM as an environment variable. (We'll show you how to do this later.) -### Customizing the Config - -If you have your own dataset, you'll want to change the `data` field in the config to point to your dataset. -You'll also want to change the `model_name_or_path` field to point to the model you want to use. -Currently, Levanter supports GPT-2, Llama, MPT, and Backpack checkpoints. - -```yaml -data: # Path to the training data, or huggingface dataset name. -data_cache_dir: -model_name_or_path: "meta-llama/Llama-2-7b-chat-hf" -trainer: - ... -``` - -#### Custom Prompts - -If you want to use your own prompts, you can add them to the config, like this: - -```yaml -prompts: - prompt_input: |- - ### Instruction: {instruction} - ### Input: {input} - ### Output: - prompt_no_input: |- - ### Instruction: {instruction} - ### Output: -``` - -#### Custom Datasets +### Custom Datasets -The script we develop in this tutorial is designed for Alpaca, but it should work for any single-turn -instruction-following task. For instance, to train a code alpaca model, you could use the following config: +The script in this tutorial is designed for Alpaca, but it should work for any single-turn +instruction-following task. For instance, to train a Code Alpaca model, you could modify the config like this: ```yaml data: lucasmccabe-lmi/CodeAlpaca-20k # a dataset on the Hugging Face hub data_cache_dir: code_alpaca_cache # some path to store the cache ``` -The dataset can also be a JSON path. +The dataset can also be a path to a JSON or JSONL file, or compressed versions of those. -#### \[TPU\] Using a Modified Config +### Custom Models -If you make changes to the config, you'll need to get the config file to all the workers. For TPU, the best way to do this -is to copy it to Google Cloud Storage so that it persists when the machine is preempted. You can do this with: +You can also change the `model_name_or_path` field to point to the model you want to use. This +can be any Hugging Face model, or a path to a local checkpoint. Currently, Levanter supports GPT-2, Llama, MPT, and +Backpack checkpoints. -```bash -gsutil cp examples/alpaca/alpaca.yaml gs:///train-alpaca.yaml +```yaml +model_name_or_path: "meta-llama/Llama-2-7b-chat-hf" ``` -If using Llama 2: +Or on the command line: ```bash -gsutil cp examples/alpaca/alpaca-llama2.yaml gs:///train-alpaca.yaml +python examples/alpaca/alpaca.py --config_path levanter/examples/alpaca/alpaca.yaml --model_name_or_path "meta-llama/Llama-2-7b-chat-hf" ``` -And then using `--config_path gs:///alpaca.yaml` instead of `--config_path levanter/examples/alpaca/train-alpaca.yaml` -in the command line below. Levanter knows how to read from Google Cloud Storage, so you don't need to do anything else. +### Custom Prompts -## Launching the Job +If you want to use your own prompts, you can modify the `prompts` field. By default, the prompts are set to be the same +as the original Alpaca, but you can change them to whatever you want. They are formatted using Python's +[format strings](https://docs.python.org/3/library/string.html#format-string-syntax), meaning you can use `{instruction}` and `{input}`. +You should have two prompts: one for when there's an input and one for when there isn't. For example, here is +a more minimal prompt: -### \[GPU\] Launching the Job - -Right now, Levanter is only known to work with single node GPU training. The example commands below demonstrate how to launch a training job -on a node with 8 A100 GPUs, but should work for other single node GPU configurations. For example, we've also tested Alpaca replication with -a node of RTX 6000 Ada Generation 49.1GB GPUs. - -Before running your training bash command, ensure you are in your `levanter` conda environment, you've created a directory for saving checkpoints -during training, you are logged into your wandb account with the following two commands: - -!!! warning - - Fine-tuning a 7B parameter model needs **a lot** of accelerator memory, you will need more than 80GB of GPU memory in - aggregate to run this job. Because Levanter makes heavy use of FSDP, you can use several smaller cards. - If you don't have enough memory, you can try reducing the `train_batch_size` or the `per_device_parallelism` in - the config. - - - -```bash -conda activate levanter -wandb login ${YOUR TOKEN HERE} +```yaml +prompts: + prompt_input: |- + ### Instruction: {instruction} + ### Input: {input} + ### Output: + prompt_no_input: |- + ### Instruction: {instruction} + ### Output: ``` -Now you can run the training command: -```bash -python examples/alpaca/alpaca/alpaca.py \ ---config_path examples/alpaca/alpaca.yaml \ ---trainer.checkpointer.base_path levanter/checkpoints \ ---hf_save_path levanter/checkpoints -``` +We use YAML's [multiline string syntax](https://yaml-multiline.info/) to make the prompts easier to read. +You can also specify a path to a json file containing the prompts if you'd prefer. + -You can change `--trainer.checkpointer.base_path` and `--hf_save_path` to your desired model checkpoint directories. +### \[TPU\] Using a Modified Config -If you're using Llama 2, you'll need to first request access to the model, and then export your Hugging Face API token: +On a single machine, you can just modify the config and run the script. On TPU, however, you'll need to upload the config +to a Google Cloud Storage bucket so that all the workers can access it. You can do this with: ```bash -export HUGGING_FACE_HUB_TOKEN=${YOUR TOKEN HERE} +gsutil cp my-config.yaml gs:///my-config.yaml ``` -or log into the HF CLI with `huggingface-cli login`. +And then using `--config_path gs:///my-config` instead of `--config_path levanter/examples/alpaca/train-alpaca.yaml` +in the command line. Levanter knows how to read from Google Cloud Storage, so you don't need to do anything else. -### \[GPU\] NLP-Group Slurm Cluster Launch Example +### Aside: Running on Slurm -Say you save the above Alpaca training command as a bash script called `train_alpaca.sh`, then +Say you save the above Alpaca training command as a bash script called `train_alpaca.sh`. Then you could launch a training job on a slurm cluster with `srun` as follows: ```bash srun --account=nlp --cpus-per-task=32 --gpus-per-node=8 --mem=400G --open-mode=append --partition=sphinx --nodes=1 --pty bash train_alpaca.sh ``` -### \[TPU\] Launching the Job - -For TPU, we need just a little bit of ceremony to get the Hugging Face and WANDB API tokens in the environment: -(If you're using Llama 1, you don't need the `HUGGING_FACE_HUB_TOKEN` line unless you're using a private model -or uploading to Hugging Face.) - -```bash -gcloud compute tpus tpu-vm ssh llama-32 --zone us-east1-d --worker=all \ ---command="WANDB_API_KEY=${YOUR TOKEN HERE} \ -HUGGING_FACE_HUB_TOKEN=${YOUR TOKEN HERE} \ -bash levanter/infra/run.sh python \ -levanter/examples/alpaca/alpaca.py \ ---config_path levanter/examples/alpaca/alpaca.yaml \ ---trainer.checkpointer.base_path gs:// \ ---hf_save_path gs:// -``` - -If you're using preemptible or TRC TPUs, you'll want to add `--trainer.id ` to the command line. -Alternatively, you can use the [babysitting script](./Getting-Started-TPU-VM.md#babysitting-script) to automatically restart the -VM and job if it gets preempted. That would look like this: - -```bash -infra/babysit-tpu-vm.sh llama-32 -z us-east1-d -t v3-32 --preemptible -- \ -WANDB_API_KEY=${YOUR TOKEN HERE} \ -HUGGING_FACE_HUB_TOKEN=${YOUR TOKEN HERE} \ -bash levanter/infra/run.sh python \ -levanter/examples/alpaca/alpaca.py \ ---config_path levanter/examples/alpaca/alpaca-llama2.yaml \ ---trainer.checkpointer.base_path gs:// \ ---hf_save_path gs:// \ -``` - -## Waiting - -At some point the run will spit out a WandB link. You can click on that to see the training progress. There's -not a ton to see there (yet), but you can see the training loss go down over time. - -On a v3-32 or an 8xA100 box, training should take about ~3.5 hours, similarly on a v3-32 TPU VM. -It should take ~8.5 hours on 8 RTX 6000 Ada Generation GPUs. +(This is for the Stanford NLP Cluster. Adjust as necessary for your cluster.) ## Using the Model @@ -309,7 +347,7 @@ When you're done, you can copy out the Hugging Face model with: gsutil cp -r gs:////step- ./my-alpaca ``` -The model should work out-of-the-box as a Hugging Face model. You can use it like this: +The model should work out-of-the-box as a Hugging Face model. For a quick test, you can use it like this: ```python from transformers import AutoModelForCausalLM, AutoTokenizer @@ -344,7 +382,7 @@ If you want to just run the script, you can skip to the [Setup](#setup) section. ### Approach Levanter's existing main entry points are designed for "pure" causal language modeling, where you have a single sequence -and don't want to mask out any tokens. So we'll instead write a custom script that does the following: +and don't have any prompts or custom formatting. So we'll instead write a custom script that does the following: * Preprocesses the dataset into a single sequence, interpolating prompts as we go. We'll also construct a `loss_mask` and do any padding. * Loads the model and resizes the vocabulary to match the tokenizer. @@ -359,7 +397,7 @@ if you want more information. The first step is to get the dataset. We'll use the [Hugging Face Dataset version](https://huggingface.co/datasets/tatsu-lab/alpaca) to do this. (You can also download it directly from the [dataset page](https://huggingface.co/datasets/tatsu-lab/alpaca), but -Levanter's integration with Hugging Face datasets makes it a bit easier to use.) +Levanter's integration with Hugging Face datasets means we don't need to do that.) ```python def _get_data_source(path_or_id): @@ -371,8 +409,8 @@ def _get_data_source(path_or_id): return levanter.data.dataset_from_hf(path_or_id, split="train") ``` -Preprocessing in Levanter typically comes in two phases: -* creating the on-disk cache, +Preprocessing in Levanter typically happens in two phases: +* creating an on-disk cache of the "heavy" preprocessing, like tokenization; and * transforming examples from the cache into the examples that the model expects. Here's the first phase, where we create the cache. We basically want to interpolate the prompt with the input @@ -380,19 +418,19 @@ and instructions, and then tokenize the result. We also want to keep track of th can mask out the loss appropriately. ```python -def mk_dataset(data_path_or_id: str, cache_dir: str, tokenizer): - # wrap an HF dataset with Levanter's native dataset class for fancier preprocessing. - # Levanter's native dataset class supports streaming, deteriministic, distributed preprocessing out of the box, - # which is a bit overkill for this dataset, but it's a good example of how to use it. - dataset = _get_data_source(data_path_or_id) +def mk_dataset(config: TrainArgs, tokenizer: transformers.PreTrainedTokenizerBase): + dataset = _get_data_source(config.data) - prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] + prompts = get_prompts(config.prompts) def preprocess(batch): - sources = [ - prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example) - for example in batch - ] + def format_example(ex): + if ex.get("input", "") == "": + return prompts["prompt_no_input"].format_map(ex) + else: + return prompts["prompt_input"].format_map(ex) + + sources = [format_example(example) for example in batch] targets = [f"{example['output']}{tokenizer.eos_token}" for example in batch] # TODO: this seems pretty wasteful since you end up tokenizing twice, but it's how the original code does it. examples = [s + t for s, t in zip(sources, targets)] @@ -407,16 +445,15 @@ def mk_dataset(data_path_or_id: str, cache_dir: str, tokenizer): } dataset = dataset.map_batches(preprocess, batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer)) - dataset = dataset.build_cache(cache_dir, await_finished=True) + dataset = dataset.build_cache(config.data_cache_dir, await_finished=True) - # SupervisedDataset does last minute padding and masking - dataset = SupervisedDataset(dataset, tokenizer) + dataset = SupervisedDataset(dataset, tokenizer, mask_inputs=config.mask_inputs) return dataset ``` -In the second, we create [levanter.models.lm.LmExample][] objects from the cache. These are the inputs to the model. -`LmExample`s look like this: +`SupervisedDataset` is a class that we'll define later that does the final transformation from the cache to the +`LmExample` objects that the model expects. `LmExample`s look like this: ```python class LmExample(eqx.Module): @@ -425,8 +462,9 @@ class LmExample(eqx.Module): attn_mask: AttentionMask = AttentionMask.causal() ``` -So we need to populate the first two fields. We'll do that with a dataset whose job is to take the cache and turn it into -`LmExample`s. +So we need to populate the first two fields. `tokens` is the input sequence, and `loss_mask` is a boolean mask +that tells the model which tokens to compute the loss for. (We mask out the loss for everything before the start +of the output.) ```python class SupervisedDataset(Dataset[LmExample]): @@ -458,3 +496,12 @@ class SupervisedDataset(Dataset[LmExample]): The rest is boilerplate: setting up the model, optimizer, and trainer, and then running the training loop. We'll skip over that in this tutorial, but you can see the full script [here](https://github.com/stanford-crfm/levanter/blob/main/examples/alpaca/alpaca.py) if you want to see how it works. + + +## Conclusion + +That's it for this tutorial: you should now be able to fine-tune Llama 1 or Llama 2 on Alpaca or any other single-turn +instruction-following task. If you want to learn more about Levanter, check out the [Levanter docs](https://levanter.readthedocs.io/en/latest/) +or the [Levanter repo](https://github.com/stanford-crfm/levanter). For discussion, you can find us on [Discord](https://discord.gg/CKazXcbbBm). + +Let us know what you'd like to see next! diff --git a/examples/alpaca/alpaca-llama2.yaml b/examples/alpaca/alpaca-llama2.yaml index 88ea6c917..2527de03f 100644 --- a/examples/alpaca/alpaca-llama2.yaml +++ b/examples/alpaca/alpaca-llama2.yaml @@ -13,3 +13,22 @@ trainer: optimizer: learning_rate: 2e-5 weight_decay: 0.0 +prompts: + # |- means multiline string, keeping all but the final newline + prompt_input: |- + Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. + + ### Instruction: + {instruction} + + ### Input: + {input} + + ### Response: + prompt_no_input: |- + Below is an instruction that describes a task. Write a response that appropriately completes the request. + + ### Instruction: + {instruction} + + ### Response: diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index 73d147ac6..e02b0738a 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -35,6 +35,7 @@ # Ways this script could be improved: # * Could tune hparams more for throughput +# Original # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -93,7 +94,7 @@ class TrainArgs: model_cache_dir: Optional[str] = None # Path to cache the model. must be local. - hf_save_path: Optional[str] = None # Path to save the HuggingFace checkpoint. + hf_save_path: Optional[str] = "alpaca_hf_ckpts" # Path to save the HuggingFace checkpoint, can be gcs hf_upload: Union[bool, str] = False # Name of the HuggingFace repo to upload to (if any). hf_save_steps: int = 1000 # How often to save the HuggingFace checkpoint. @@ -134,14 +135,14 @@ def _get_data_source(path_or_id): """The original alpaca.py used a json file, but it's since been moved to the HF dataset hub. You can use any dataset that's compatible with the structure of the alpaca dataset.""" if fsspec_utils.exists(path_or_id): - # get file format: jsonl or json - if path_or_id.endswith(".jsonl"): + # we're a bit generous here b/c we support compression + if ".jsonl" in path_or_id: return JsonlDataset([path_or_id]) - elif path_or_id.endswith(".json"): + elif ".json" in path_or_id: return JsonDataset([path_or_id]) else: raise ValueError( - f"We only support HF Dataset or a data file with .json or .jsonl extensions, not {path_or_id}!" + f"We only support HF Datasets or a data file with .json or .jsonl extensions, not {path_or_id}!" ) else: return WrappedHFDataset(path_or_id, split="train") diff --git a/examples/alpaca/alpaca.yaml b/examples/alpaca/alpaca.yaml index 3eae596fc..fd0e2fb7c 100644 --- a/examples/alpaca/alpaca.yaml +++ b/examples/alpaca/alpaca.yaml @@ -13,3 +13,22 @@ trainer: optimizer: learning_rate: 2e-5 weight_decay: 0.0 +prompts: + # |- means multiline string, keeping all but the final newline + prompt_input: |- + Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. + + ### Instruction: + {instruction} + + ### Input: + {input} + + ### Response: + prompt_no_input: |- + Below is an instruction that describes a task. Write a response that appropriately completes the request. + + ### Instruction: + {instruction} + + ### Response: From 8250b776c28f2d929c4c095dc571b12f128ec11d Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 23 Jan 2024 11:26:54 -0800 Subject: [PATCH 03/19] Alpaca qol (#430) * make the prompts in alpaca.yaml explicit * accept compressed json * update docs for alpaca * misc typos --- docs/Fine-Tuning.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/Fine-Tuning.md b/docs/Fine-Tuning.md index 58e8f455d..0bd3545b0 100644 --- a/docs/Fine-Tuning.md +++ b/docs/Fine-Tuning.md @@ -13,8 +13,7 @@ It also documents how to work with datasets that aren't just single `"text"`s, w ## Overview of Alpaca -[Alpaca](https://crfm.stanford.edu/2023/03/13/alpaca.html) is a fine tune of Llama 1 on a -[dataset of 52000 input/output pairs](https://huggingface.co/datasets/tatsu-lab/alpaca), which +[Alpaca](https://crfm.stanford.edu/2023/03/13/alpaca.html) is a fine tune of Llama 1 on a [dataset of 52000 input/output pairs](https://huggingface.co/datasets/tatsu-lab/alpaca), which were generated by taking [a seed set from self-instruct](https://github.com/yizhongw/self-instruct) and asking `text-davinci-003` to generate more examples. @@ -25,7 +24,7 @@ The original Alpaca script is [here](https://github.com/tatsu-lab/stanford_alpac ### The Foundation Model Llama 1 7B is a ≈7 billion parameter causal language model trained on 1 trillion tokens from various mostly English sources. -It's described in [the Llama paper](https://arxiv.org/abs/2302.13971). [Llama 2](https://ai.meta.com/llama/) is a similar model, +It's described in [the Llama 1 paper](https://arxiv.org/abs/2302.13971). [Llama 2](https://ai.meta.com/llama/) is a similar model, just trained on more data (and with some slight tweaks to the architecture for larger models). ### The Data @@ -97,7 +96,7 @@ Rather than going through the code first, we'll jump straight to running the scr #### Environment Setup Follow the instructions in the [Getting Started with GPUs](./Getting-Started-GPU.md) guide to create a conda environment or virtualenv -and to install JAX with CUDA. Then, if you haven't already, clone the Levanter repo and install it in editable mode: +and to install JAX with CUDA. Then, if you haven't already done so, clone the Levanter repository and install it in editable mode: ```bash git clone https://github.com/stanford-crfm/levanter.git @@ -122,7 +121,8 @@ huggingface-cli login The example commands below demonstrate how to launch a training job on a node with 8 A100 GPUs, but should work for other single node GPU configurations. For example, we've also tested Alpaca replication with -a node of 8 RTX 6000 Ada Generation 49.1GB GPUs. Levanter works best with Ada or later generation NVIDIA GPUs. +a node of 8 RTX 6000 Ada Generation 49.1GB GPUs. (Levanter works best with Ada or later generation NVIDIA GPUs.) + To replicate Alpaca, you can run the following command: ```bash @@ -143,7 +143,7 @@ python examples/alpaca/alpaca.py --config_path levanter/examples/alpaca/alpaca-l !!! warning - Fine-tuning a 7B parameter model needs **a lot** of accelerator memory, you will need more than 80GB of GPU memory in + Fine-tuning a 7B parameter model needs **a lot** of accelerator memory: you will need more than 80GB of GPU memory in aggregate to run this job. Because Levanter makes heavy use of FSDP, you can use several smaller cards. If you don't have enough memory, you can try reducing the `train_batch_size` or the `per_device_parallelism` in the config. @@ -168,7 +168,7 @@ bash infra/spin-up-vm.sh llama-32 -z us-east1-d -t v3-32 --preemptible ``` You might need to change the zone and/or the TPU type depending on what's available. You can also use preemptible -TPUs if you want to save money (or that's what your quota is). I think training Alpaca should work on a v3-8, +TPUs if you want to save money (or that's what your quota is). Training Alpaca should work on a v3-8, but we don't have any of those. #### Running the Script @@ -341,7 +341,7 @@ srun --account=nlp --cpus-per-task=32 --gpus-per-node=8 --mem=400G --open-mode=a ## Using the Model -When you're done, you can copy out the Hugging Face model with: +When you're done, you can download the Hugging Face model with: ```bash gsutil cp -r gs:////step- ./my-alpaca From efef064e5529f465c26aac1b0e79de56bc3dad98 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 24 Jan 2024 22:27:18 -0800 Subject: [PATCH 04/19] Tweaks to improve multiprocess gpu outside of slurm (#431) --- pyproject.toml | 2 +- src/levanter/checkpoint.py | 4 +- src/levanter/distributed.py | 92 +++++++++++++++++++++++++--------- src/levanter/utils/py_utils.py | 2 +- tests/test_checkpoint.py | 3 +- 5 files changed, 74 insertions(+), 29 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e3d8be709..96272c628 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ dependencies = [ "ray[default]", "pydantic<2", # temporary pin until Ray supports pydantic 2.0 "rich>=13", -# "chex>=0.1.85" + "filelock", ] [tool.hatch.build] diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 70087af75..15f16a203 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -231,7 +231,7 @@ def save_checkpoint(self, info, destination: str): logger.info(f"Saved checkpoint at step {info.step} to {path}. Save time is {self._last_save_time}") -def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike, *, exist_ok: bool = False): +def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike): """ Save a checkpoint to a given path using TensorStore. If exist_ok is True, the checkpoint will be saved even if a checkpoint already exists at the given path. @@ -247,7 +247,7 @@ def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike, fs: AbstractFileSystem fs, plain_path = _get_fs_and_plain_path(checkpoint_path) - fs.makedirs(plain_path, exist_ok=exist_ok) + fs.makedirs(plain_path, exist_ok=True) tree_serialize_leaves_tensorstore(os.path.join(checkpoint_path, "model"), model) if training_state is not None: diff --git a/src/levanter/distributed.py b/src/levanter/distributed.py index f4a86f8b9..c0442b45e 100644 --- a/src/levanter/distributed.py +++ b/src/levanter/distributed.py @@ -224,29 +224,30 @@ def _munge_address_port(address: str): # this is no longer the case, so instead we need to check if we are the coordinator # and if so, start the head - if _is_this_machine(host): - logger.info(f"Starting ray head on port {ray_port}. We are process the coordinator {host}.") - logger.info(f"Starting ray with num_cpus set to {num_cpus}.") - ret = os.system( - f"ray start --head --port {ray_port} --num-cpus {num_cpus} --dashboard-host=0.0.0.0" - ) - if ret != 0: - raise RuntimeError(f"Failed to start ray head with exit code {ret}") - else: - logger.info(f"Successfully started ray head on port {ray_port}.") - - # install an atexit handler to kill the head when we exit - atexit.register(lambda: os.system("ray stop -g 10 --force")) - elif start_workers: - logger.info( - f"Starting ray worker and connecting to {address}. We are process {jax.process_index()}." - ) - logger.info(f"Starting ray with num_cpus set to {num_cpus}.") - ret = os.system(f"ray start --address {address} --num-cpus {num_cpus}") - if ret != 0: - raise RuntimeError(f"Failed to start ray head with exit code {ret}") - else: - logger.info(f"Successfully started ray worker and connected to {address}.") + if _is_local_leader(): + if _is_this_machine(host): + logger.info(f"Starting ray head on port {ray_port}. We are process the coordinator {host}.") + logger.info(f"Starting ray head with num_cpus set to {num_cpus}.") + ret = os.system( + f"ray start --head --port {ray_port} --num-cpus {num_cpus} --dashboard-host=0.0.0.0" + ) + if ret != 0: + raise RuntimeError(f"Failed to start ray head with exit code {ret}") + else: + logger.info(f"Successfully started ray head on port {ray_port}.") + + # install an atexit handler to kill the head when we exit + atexit.register(lambda: os.system("ray stop -g 10 --force")) + elif start_workers: + logger.info( + f"Starting ray worker and connecting to {address}. We are process {jax.process_index()}." + ) + logger.info(f"Starting ray worker with num_cpus set to {num_cpus}.") + ret = os.system(f"ray start --address {address} --num-cpus {num_cpus}") + if ret != 0: + raise RuntimeError(f"Failed to start ray head with exit code {ret}") + else: + logger.info(f"Successfully started ray worker and connected to {address}.") logger.info(f"ray.init(address={repr(address)}, namespace={repr(namespace)}, **{repr(kwargs)})") # Ray has retry logic, so we don't need to retry here :fingers-crossed: @@ -318,6 +319,9 @@ def _is_this_machine(host): """ Checks if the given host identifies this machine. """ + if host == "localhost" or host == "0.0.0.0": + return True + try: # Get IP addresses of all interfaces machine_ips = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)] @@ -330,3 +334,45 @@ def _is_this_machine(host): # Check if the host IP matches any of the machine IPs return any(host_ip == machine_ip for machine_ip in machine_ips) + + +def _remove_if_possible(path): + try: + os.remove(path) + except OSError: + pass + + +def _touch(file_path): + with open(file_path, "a"): + os.utime(file_path, None) + + +def _is_local_leader(): + import atexit + + import filelock + from jax.experimental.multihost_utils import broadcast_one_to_all + + if jax.process_count() == 1: + return True + + import random + + random_id = random.randint(0, 1000000) + random_id = broadcast_one_to_all(random_id) + + lock = filelock.FileLock(f"/tmp/levanter_local_process_zero_lock.{random_id}") + action_performed_file = f"/tmp/levanter_local_process_zero_action_performed.{random_id}" + + try: + with lock.acquire(timeout=0.1): + if not os.path.exists(action_performed_file): + _touch(action_performed_file) + return True # Action needs to be performed + else: + return False # Action already performed + atexit.register(_remove_if_possible, lock.lock_file) + atexit.register(_remove_if_possible, action_performed_file) + except filelock.Timeout: + return False diff --git a/src/levanter/utils/py_utils.py b/src/levanter/utils/py_utils.py index a172b4498..38ecfc49c 100644 --- a/src/levanter/utils/py_utils.py +++ b/src/levanter/utils/py_utils.py @@ -5,7 +5,7 @@ def logical_cpu_core_count(): """Returns the number of logical CPU cores available to the process.""" - num_cpus = os.getenv("SLURM_CPUS_PER_TASK", None) + num_cpus = os.getenv("SLURM_CPUS_ON_NODE", None) if num_cpus is not None: return int(num_cpus) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index b8f588df4..c22525fd6 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -161,7 +161,6 @@ def make_state(key): (initial_opt_state, initial_key), step=10, checkpoint_path=tmpdir, - exist_ok=True, ) restored_model, (restored_optstate, rkey), step = load_checkpoint( rep_model, @@ -212,7 +211,7 @@ def loss_fn(model, data): assert_trees_not_close(state, rep_state) with tempfile.TemporaryDirectory() as tmpdir: - save_checkpoint(model, state, step=3, checkpoint_path=tmpdir, exist_ok=True) + save_checkpoint(model, state, step=3, checkpoint_path=tmpdir) restored_model, restored_optstate, step = load_checkpoint( rep_model, rep_state, checkpoint_path=tmpdir, discover_latest=False ) From 0e1f2dcb39a33d251992438360357b65c53de941 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 29 Jan 2024 23:17:09 -0800 Subject: [PATCH 05/19] Create dependabot.yml (#434) --- .github/dependabot.yml | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..9d866e392 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: + - package-ecosystem: "pip" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "weekly" From 7c9257f08713be8e893a7be1907d373fa2c4f549 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Tue, 30 Jan 2024 13:28:05 -0800 Subject: [PATCH 06/19] Add Model: Mistral-7B (no MOE) (#435) * draft for mistral.py * add mistral7b config & tests * reuse modules from llama --- config/mistral_7b.yaml | 28 +++++ src/levanter/models/mistral.py | 218 +++++++++++++++++++++++++++++++++ tests/test_mistral.py | 144 ++++++++++++++++++++++ 3 files changed, 390 insertions(+) create mode 100644 config/mistral_7b.yaml create mode 100644 src/levanter/models/mistral.py create mode 100644 tests/test_mistral.py diff --git a/config/mistral_7b.yaml b/config/mistral_7b.yaml new file mode 100644 index 000000000..f4b3eab79 --- /dev/null +++ b/config/mistral_7b.yaml @@ -0,0 +1,28 @@ +data: + train_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" + validation_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" + cache_dir: "gs://levanter-data/tokenized/openwebtext_llama/" + tokenizer: "mistralai/Mistral-7B-v0.1" +model: + type: mistral +# TODO: uncomment this once we resolve the resource exhaustion issue +# initialize_from_hf: "mistralai/Mistral-7B-v0.1" +# use_hf_model_config: true +trainer: + wandb: + project: "levanter" + tags: ["openwebtext", "mistral"] + + mp: p=f32,c=bfloat16 + train_batch_size: 256 # set for v4-64 TPU + num_train_steps: 1000 + steps_per_eval: 50 + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" +optimizer: + learning_rate: 1.2E-5 # set low for fine-tuning + weight_decay: 0.1 + min_lr_ratio: 0.1 diff --git a/src/levanter/models/mistral.py b/src/levanter/models/mistral.py new file mode 100644 index 000000000..404c956a8 --- /dev/null +++ b/src/levanter/models/mistral.py @@ -0,0 +1,218 @@ +import dataclasses +from dataclasses import dataclass +from typing import Dict, Optional, Type, Union + +import equinox as eqx +import jax.random as jrandom + +import haliax as hax +import haliax.nn as hnn +from haliax import Axis, NamedArray +from haliax.jax_utils import maybe_rng_split + +from levanter.compat.hf_checkpoints import HFCheckpointConverter +from levanter.compat.torch_serialization import ( + StateDict, + StateDictSerializationMixin, + apply_prefix, + flatten_linear_layers, + unflatten_linear_layers, +) +from levanter.logging import silence_transformer_nag +from levanter.models.attention import AttentionMask +from levanter.models.llama import LlamaConfig, LlamaEmbedding, LlamaTransformer +from levanter.models.lm_model import LmConfig, LmHeadModel +from levanter.utils.py_utils import cached_classproperty + + +silence_transformer_nag() +from transformers import MistralConfig as HfMistralConfig # noqa: E402 +from transformers import PretrainedConfig as HfConfig # noqa: E402 + + +@LmConfig.register_subclass("mistral") +@dataclass(frozen=True) +class MistralConfig(LlamaConfig): + """Config for MistralModel + + Args: + seq_len (int, optional): maximum length of the input sequence. Defaults to 8192. + hidden_dim (int, optional): dimension of the hidden state. Defaults to 4096. + intermediate_dim (int, optional): dimension of the intermediate state. Defaults to 14336. + num_layers (int, optional): number of hidden layers in the Transformer encoder. Defaults to 32. + num_heads (int, optional): number of attention heads for each attention layer. Defaults to 32. + num_kv_heads (int, optional): number of attention heads for keys and values in each attention layer. + Setting to 1 means MQA. Setting to num_heads means MHA. Otherwise GQA. + Note that num_heads must be divisible by this number. Defaults to 8. + activation_function (str, optional): activation function for the hidden layer. Defaults to "silu". + sliding_window (int, optional): window size of sliding window attention. Defaults to 4096. + """ + + seq_len: int = 8192 + hidden_dim: int = 4096 + intermediate_dim: int = 14336 + num_layers: int = 32 + num_heads: int = 32 + num_kv_heads: int = 8 + activation_function: str = "silu" + initializer_range: float = 0.02 + layer_norm_epsilon: float = 1e-6 + sliding_window: int = 4096 + + # Attention-related config + upcast_attn: bool = False + use_flash_attention: bool = False + flash_attention_block_size: Optional[int] = None + + gradient_checkpointing: bool = True + gradient_checkpointing_block_size: int = 5 + + use_bias: bool = False + rope_scaling: Optional[dict] = None + + # Axis + Pos = property(lambda self: Axis(name="position", size=self.seq_len)) + KeyPos = property(lambda self: self.Pos.alias("key_position")) + Embed = property(lambda self: Axis(name="embed", size=self.hidden_dim)) + Heads = property(lambda self: Axis(name="heads", size=self.num_heads)) + KVHeads = property(lambda self: Axis(name="kv_heads", size=self.num_kv_heads)) + Layers = property(lambda self: Axis(name="layers", size=self.num_layers)) + Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_dim)) + HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads)) + + @cached_classproperty + def default_hf_checkpoint_converter(cls) -> HFCheckpointConverter["MistralConfig"]: # type: ignore + return HFCheckpointConverter( + cls, # type: ignore + "mistralai/Mistral-7B-v0.1", + trust_remote_code=True, + tokenizer="mistralai/Mistral-7B-v0.1", + HfConfigClass=HfMistralConfig, + ) + + @classmethod + def from_hf_config(cls, hf_config: HfConfig): + return MistralConfig( + seq_len=hf_config.max_position_embeddings, # this might be too big... + hidden_dim=hf_config.hidden_size, + intermediate_dim=hf_config.intermediate_size, + num_layers=hf_config.num_hidden_layers, + num_heads=hf_config.num_attention_heads, + num_kv_heads=hf_config.num_key_value_heads, + activation_function=hf_config.hidden_act, + initializer_range=hf_config.initializer_range, + layer_norm_epsilon=hf_config.rms_norm_eps, + sliding_window=hf_config.sliding_window, + ) + + def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfMistralConfig: + """Convert to HuggingFace's MistralConfig + + Args: + vocab_size (int, optional): Vocabulary size of the tokenizer. Defaults to 32000. + config_overrides (dict, optional): Overrides for the config. Defaults to None. + + Returns: + HfMistralConfig: HuggingFace's MistralConfig + """ + if config_overrides is None: + config_overrides = {} + + return HfMistralConfig( + max_position_embeddings=self.seq_len, + hidden_size=self.hidden_dim, + intermediate_size=self.intermediate_dim, + num_hidden_layers=self.num_layers, + num_attention_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + hidden_act=self.activation_function, + initializer_range=self.initializer_range, + rms_norm_eps=self.layer_norm_epsilon, + sliding_window=self.sliding_window, + vocab_size=vocab_size, + **config_overrides, + ) + + @property + def model_type(cls) -> Type["MistralLMHeadModel"]: + return MistralLMHeadModel + + +class MistralLMHeadModel(eqx.Module, LmHeadModel[MistralConfig], StateDictSerializationMixin): + transformer: LlamaTransformer + embeddings: LlamaEmbedding + lm_head: hnn.Linear + + @property + def config(self): + return self.transformer.config + + @property + def vocab_size(self) -> int: + return self.Vocab.size + + @property + def Vocab(self) -> Axis: + return self.embeddings.Vocab + + @classmethod + def init(cls, Vocab: Axis, config: MistralConfig, *, key) -> "MistralLMHeadModel": + k_t, k_emb = jrandom.split(key, 2) + transformer = LlamaTransformer.init(config, key=k_t) + embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb) + lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True) + return MistralLMHeadModel(transformer, embeddings, lm_head) + + def __call__( + self, + input_ids: NamedArray, + attn_mask: Optional[Union[NamedArray, AttentionMask]] = None, + *, + key=None, + ) -> NamedArray: + """ + Args: + input_ids (NamedArray): [batch, position] + Indices of input sequence tokens in the vocabulary. + attn_mask (Union[NamedArray, AttentionMask], optional): [batch, position] + Mask to avoid performing attention on the padding token indices of the encoder input. + The attn_mask from training pipeline may be an AttentionMask object instead of NamedArray + """ + k_t, k_head = maybe_rng_split(key, 2) + x = self.embeddings.embed(input_ids) + x = self.transformer(x, attn_mask=attn_mask, key=k_t) + lm_logits = self.lm_head(x, key=k_head) + return lm_logits + + def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[MistralConfig]": + new_Vocab = self.Vocab.resize(new_size) + k1, k2 = maybe_rng_split(key, 2) + new_embeddings = self.embeddings.resize_embeddings(new_size, key=k1) + new_lm_matrix = hax.tree_util.resize_axis(self.lm_head.weight, self.Vocab, new_size, key=k2) + new_lm_head = dataclasses.replace(self.lm_head, Out=new_Vocab, weight=new_lm_matrix) + + return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head) + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"transformer": "model", "embeddings": None} + + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + # unflatten the linear layers of HF state_dict to match the shape of MistralMlp + d = state_dict.copy() + d.update( + unflatten_linear_layers( + apply_prefix(prefix, "lm_head"), state_dict, self.lm_head, out_dims_first_in_dict=True + ) + ) + return super().from_state_dict(d, prefix) + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + my_dict: StateDict = {} + super().update_state_dict(my_dict, prefix=prefix) + + my_dict.update( + flatten_linear_layers(apply_prefix(prefix, "lm_head"), self.lm_head, out_dims_first_in_dict=True) + ) + + state_dict.update(my_dict) + return state_dict diff --git a/tests/test_mistral.py b/tests/test_mistral.py new file mode 100644 index 000000000..c758e3555 --- /dev/null +++ b/tests/test_mistral.py @@ -0,0 +1,144 @@ +import tempfile + +import jax +import numpy as np +import pytest +import transformers +from jax import random + +import haliax as hax + +from levanter.models.mistral import MistralConfig, MistralLMHeadModel +from test_utils import check_load_config, check_model_works_with_seqlen, parameterize_with_configs, skip_if_no_torch + + +@skip_if_no_torch +def test_mistral_config(): + # load HF config and convert to levanter config + hf_config = transformers.MistralConfig.from_pretrained("mistralai/Mistral-7B-v0.1") + mistral_config = MistralConfig.from_hf_config(hf_config) + + # convert back to HF config + config_overrides = { + "_name_or_path": hf_config._name_or_path, + "architectures": hf_config.architectures, + "torch_dtype": hf_config.torch_dtype, + } + new_hf_config = mistral_config.to_hf_config( + vocab_size=hf_config.vocab_size, + config_overrides=config_overrides, + ) + + # assert the content in new_hf_config is the same as hf_config + for k in new_hf_config.__dict__.keys(): + if k in ["_commit_hash", "transformers_version"]: + continue + assert getattr(new_hf_config, k) == getattr( + hf_config, k + ), f"{k} {getattr(new_hf_config, k)} != {getattr(hf_config, k)}" + + +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_mistral_lm_head_model(num_kv_heads): + mistral_config = _get_mistral_config(num_kv_heads=num_kv_heads) + Batch = hax.Axis("batch", 2) + Vocab = hax.Axis("vocab", 1000) + Pos = mistral_config.Pos + input_ids = hax.random.randint(random.PRNGKey(0), (Batch, Pos), 0, Vocab.size) + mask = hax.nn.attention.causal_mask(Pos, mistral_config.KeyPos) + + mistral_model = MistralLMHeadModel.init(Vocab=Vocab, config=mistral_config, key=random.PRNGKey(0)) + out = mistral_model(input_ids, mask) + assert out.array.shape == (Batch.size, Pos.size, Vocab.size) + + +@skip_if_no_torch +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_mistral_roundtrip(num_kv_heads): + import torch + from transformers import AutoModelForCausalLM, MistralForCausalLM + + converter = MistralConfig.default_hf_checkpoint_converter + + config = MistralConfig( + seq_len=128, + hidden_dim=16, + num_heads=4, + num_kv_heads=num_kv_heads, + gradient_checkpointing=False, + ) + Vocab = hax.Axis("vocab", 1000) + hf_config = config.to_hf_config(Vocab.size) + + # Make input and attn_mask + input = hax.random.randint(random.PRNGKey(0), config.Pos, 0, Vocab.size) + attn_mask = hax.nn.attention.causal_mask(config.Pos, config.KeyPos) + input_torch = torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0) + + torch.random.manual_seed(0) + + torch_model = MistralForCausalLM(hf_config) + torch_model.eval() + + torch_out = torch_model(input_torch) + torch_out = torch_out.logits[0].detach().cpu().numpy() + torch_out = jax.nn.softmax(torch_out, axis=-1) + + with tempfile.TemporaryDirectory() as tmpdir: + torch_model.save_pretrained(f"{tmpdir}/torch_model") + + model = converter.load_pretrained( + MistralLMHeadModel, f"{tmpdir}/torch_model", resize_vocab_to_match_tokenizer=False + ) + + def compute(input): + model_output = model(input, attn_mask=attn_mask) + return hax.nn.softmax(model_output, axis=model.Vocab) + + compute = jax.jit(compute) + jax_out = compute(input).array + + assert torch_out.shape == jax_out.shape, f"{torch_out.shape} != {jax_out.shape}" + assert np.isclose(torch_out, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out} != {jax_out}" + + converter.save_pretrained(model, f"{tmpdir}/lev_model", save_reference_code=False) + torch_model2 = AutoModelForCausalLM.from_pretrained(f"{tmpdir}/lev_model") + torch_model2.eval() + + torch_out2 = torch_model2(input_torch) + torch_out2 = torch_out2.logits[0].detach().cpu().numpy() + torch_out2 = jax.nn.softmax(torch_out2, axis=-1) + assert torch_out2.shape == jax_out.shape, f"{torch_out2.shape} != {jax_out.shape}" + assert np.isclose(torch_out2, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out2} != {jax_out}" + + +def _get_mistral_config(use_flash=False, num_kv_heads=4) -> MistralConfig: + return MistralConfig( + seq_len=128, + hidden_dim=16, + num_heads=4, + num_kv_heads=num_kv_heads, + gradient_checkpointing=False, # disable for tests so debugging is easier + use_flash_attention=use_flash, + ) + + +@parameterize_with_configs("mistral*.yaml") +def test_mistral_configs(config_file): + from levanter.main.train_lm import TrainLmConfig + + config_class = TrainLmConfig + + check_load_config(config_class, config_file) + + +@pytest.mark.parametrize("num_kv_heads", [1, 2]) +def test_pass_different_length_seq(num_kv_heads): + config = MistralConfig( + seq_len=32, + hidden_dim=16, + intermediate_dim=32, + num_heads=2, + num_kv_heads=num_kv_heads, + ) + check_model_works_with_seqlen(MistralLMHeadModel, config, 16) From 4cdf5afcf2ad3b5b2afc02cd157bae613836ec59 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 30 Jan 2024 13:29:12 -0800 Subject: [PATCH 07/19] Dependabot bumps (#441) * Bump datasets from 2.11.0 to 2.16.1 (#440) Bumps [datasets](https://github.com/huggingface/datasets) from 2.11.0 to 2.16.1. - [Release notes](https://github.com/huggingface/datasets/releases) - [Commits](https://github.com/huggingface/datasets/compare/2.11.0...2.16.1) --- updated-dependencies: - dependency-name: datasets dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Update pydantic requirement from <2 to <3 (#439) Updates the requirements on [pydantic](https://github.com/pydantic/pydantic) to permit the latest version. - [Release notes](https://github.com/pydantic/pydantic/releases) - [Changelog](https://github.com/pydantic/pydantic/blob/main/HISTORY.md) - [Commits](https://github.com/pydantic/pydantic/compare/v0.0.2...v2.6.0) --- updated-dependencies: - dependency-name: pydantic dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Update tblib requirement from <2.0.0,>=1.7.0 to >=1.7.0,<4.0.0 (#438) Updates the requirements on [tblib](https://github.com/ionelmc/python-tblib) to permit the latest version. - [Release notes](https://github.com/ionelmc/python-tblib/releases) - [Changelog](https://github.com/ionelmc/python-tblib/blob/master/CHANGELOG.rst) - [Commits](https://github.com/ionelmc/python-tblib/compare/v1.7.0...v3.0.0) --- updated-dependencies: - dependency-name: tblib dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Update gcsfs requirement from <2023.10.0 to <2023.13.0 (#437) Updates the requirements on [gcsfs](https://github.com/fsspec/gcsfs) to permit the latest version. - [Commits](https://github.com/fsspec/gcsfs/compare/0.0.1...2023.12.2post1) --- updated-dependencies: - dependency-name: gcsfs dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: David Hall * Update fsspec requirement from <2023.10.0 to <2023.13.0 (#436) Updates the requirements on [fsspec](https://github.com/fsspec/filesystem_spec) to permit the latest version. - [Commits](https://github.com/fsspec/filesystem_spec/compare/0.0.1...2023.12.2) --- updated-dependencies: - dependency-name: fsspec dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 96272c628..cf351f24c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,11 +34,11 @@ dependencies = [ "draccus>=0.6", "pyarrow>=11.0.0", "zstandard>=0.20.0", - "datasets==2.11.0", - "gcsfs<2023.10.0", + "datasets==2.16.1", + "gcsfs<2023.13.0", "braceexpand>=0.1.7", "jmp>=0.0.3", - "fsspec<2023.10.0", + "fsspec<2023.13.0", # TODO: minimize and report an issue to tensorstore # causes hangs when serializing to GCS "tensorstore==0.1.45", @@ -46,10 +46,10 @@ dependencies = [ "humanfriendly==10.0", "safetensors[numpy]", "matplotlib>=3.7.0", - "tblib>=1.7.0,<2.0.0", + "tblib>=1.7.0,<4.0.0", "dataclasses-json", "ray[default]", - "pydantic<2", # temporary pin until Ray supports pydantic 2.0 + "pydantic<3", # temporary pin until Ray supports pydantic 2.0 "rich>=13", "filelock", ] From 5fb676718bde29d790c85c4dc593bd5ac2ac1064 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 1 Feb 2024 16:51:13 -0800 Subject: [PATCH 08/19] bring over the improvement to ray preproc from dev (#445) --- src/levanter/data/dataset.py | 11 + src/levanter/data/mixture.py | 3 +- src/levanter/data/shard_cache.py | 444 ++++++++++++++++++++------- src/levanter/data/sharded_dataset.py | 7 +- src/levanter/data/text.py | 18 +- src/levanter/main/cache_dataset.py | 4 +- 6 files changed, 363 insertions(+), 124 deletions(-) diff --git a/src/levanter/data/dataset.py b/src/levanter/data/dataset.py index 3c49910a6..14c8979b3 100644 --- a/src/levanter/data/dataset.py +++ b/src/levanter/data/dataset.py @@ -24,6 +24,17 @@ def __iter__(self) -> Iterator[T]: raise NotImplementedError +class InMemoryDataset(ShardableDataset[T]): + def __init__(self, items: List[T]): + self.items = items + + def __iter__(self) -> Iterator[T]: + return iter(self.items) + + def shard(self, shard_id: int, num_shards: int) -> "InMemoryDataset[T]": + return InMemoryDataset(self.items[shard_id::num_shards]) + + class ShuffleDataset(ShardableDataset[T]): def __init__(self, dataset: Dataset[T], key: PRNGKey, buffer_size: int): self.dataset = dataset diff --git a/src/levanter/data/mixture.py b/src/levanter/data/mixture.py index dbe255748..71556833a 100644 --- a/src/levanter/data/mixture.py +++ b/src/levanter/data/mixture.py @@ -2,7 +2,6 @@ import jax.random import numpy as np -from jax.random import PRNGKey from jaxtyping import PRNGKeyArray from haliax.util import StringHolderEnum @@ -48,7 +47,7 @@ def __init__( self.stop_strategy = stop_strategy if not isinstance(key, int): - key = jax.random.randint(PRNGKey(key)[0], (), 0, 2**31).item() + key = jax.random.randint(key, (), 0, 2**20).item() self.key = key diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 6fa1898c2..62615b8a8 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -2,7 +2,7 @@ import asyncio import dataclasses import heapq -import logging +import logging as pylogging import os import threading import time @@ -31,10 +31,10 @@ TimeRemainingColumn, ) -from levanter.utils.ray_utils import ExceptionInfo, RefBox, current_actor_handle, ser_exc_info - -from . import ShardableDataset +from .. import logging +from ..utils.ray_utils import ExceptionInfo, RefBox, current_actor_handle, ser_exc_info from ._preprocessor import BatchProcessor, BatchResult, as_record_batch, dict_from_record_batch +from .dataset import ShardableDataset from .sharded_dataset import ShardedDataset @@ -42,9 +42,10 @@ T_co = TypeVar("T_co", covariant=True) -logger = logging.getLogger(__name__) +logger = pylogging.getLogger(__name__) -DEFAULT_ROWS_PER_CHUNK = 1024 * 32 +DEFAULT_ROWS_PER_CHUNK = 8192 +DEFAULT_MAX_BYTES_PER_BATCH = 256 * 1024 * 1024 # 256 MB, this is pre-preprocessing python object size LEDGER_FILE_NAME = "cache_ledger.json" @@ -75,7 +76,7 @@ def build_cache( from shard names to iterators over the data in that shard. processor: A BatchProcessor that will be used to process batches of data. This is the main place where you can customize the preprocessing pipeline. - batch_size: The number of input examples to process at once. + batch_size: When reading from the cache, how many examples to read at a time. rows_per_chunk: The number of rows to write to each chunk. May be smaller at the end of a shard. await_finished: If True, this function will block until the cache is finished. If False, it will return immediately. @@ -316,100 +317,273 @@ def _shard_reader_generator(shard_source: ShardedDataset[T], shard_idx: int, sta yield batch -# This class is responsible for reading batches from a set of shards, prioritizing earlier -# chunks and earlier shards. (So that we approximately generate following the global order.) +class PriorityWorkTaskGroupSpec(Protocol): + name: str + + def build(self) -> "PriorityWorkTaskGroup": + raise NotImplementedError() + + +class PriorityWorkTaskGroup(Protocol): + name: str + spec: PriorityWorkTaskGroupSpec + + def items(self) -> Sequence["PriorityWorkItem"]: + raise NotImplementedError() + + +class PriorityWorkItem(Protocol): + name: str + priority: float + spec: PriorityWorkTaskGroupSpec + + def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: + """ + Returns true if the item is finished, false if it should be rescheduled. + The object ref is used (1) to block shutting down the actor too early + and (2) for backpressure. + """ + raise NotImplementedError() + + # needs to be sortable by priority + def __lt__(self, other: "PriorityWorkItem"): + if self.priority == other.priority: + return self.name < other.name + else: + return self.priority < other.priority + + def __le__(self, other: "PriorityWorkItem"): + if self.priority == other.priority: + return self.name <= other.name + else: + return self.priority <= other.priority + + @ray.remote(num_cpus=1, scheduling_strategy="SPREAD") -def _alternating_shard_reader( - builder_ref: ActorHandle, # _ChunkCacheBuilder - shard_writers: ActorHandle, # _GroupedShardWriter - shard_source: ShardedDataset[T], - shard_names: Sequence[str], - priority_fn: Callable[[int, int], float], - processor_actor: ActorHandle, # BatchProcessorQueue - batch_size, - num_rows_per_chunk, -): - shard_pqueue: list[tuple[int, int]] = [] # heapq of (num_chunks, shard_idx) - shard_readers: dict[int, Iterator[list[T]]] = {} - try: - shard_metadatas = _initial_shard_metadatas(shard_source, shard_names, shard_writers) - except Exception as e: - builder_ref.other_failed.remote(ser_exc_info()) - raise e +class PriorityProcessorActor: + def __init__(self, max_in_flight: Optional[int] = 200): + pylogging.basicConfig(level=pylogging.INFO) + self._queue: list[PriorityWorkItem] = [] # heapq + self._queue_lock = threading.Lock() + self._shutdown_event = threading.Event() + self._current_item: Optional[PriorityWorkItem] = None + self._max_in_flight = max_in_flight + + self._processing_thread = threading.Thread(target=self._loop, daemon=True) + self._processing_thread.start() + + def add_work_group(self, group: PriorityWorkTaskGroupSpec): + items = group.build().items() + with self._queue_lock: + for item in items: + heapq.heappush(self._queue, item) + + def is_group_finished(self, group: PriorityWorkTaskGroupSpec): + with self._queue_lock: + if any(item.spec == group for item in self._queue): + return False + + if self._current_item is not None and self._current_item.spec == group: + return False + + logger.info(f"Group {group.name} is finished.") - batch_size = min(batch_size, num_rows_per_chunk) + return True + + def cancel_work_group(self, group: PriorityWorkTaskGroupSpec): + # kill all the items in the group + with self._queue_lock: + self._queue = [item for item in self._queue if item.spec != group] + heapq.heapify(self._queue) + + def shutdown(self): + if not self._shutdown_event.is_set(): + self._shutdown_event.set() + + if self._processing_thread.is_alive(): + self._processing_thread.join() + + def _loop(self: "PriorityProcessorActor"): + should_sleep = False + backpressure_queue: list[ray.ObjectRef] = [] + + def drain_backpressure_to(count): + nonlocal backpressure_queue + while len(backpressure_queue) > count: + finished, remaining = ray.wait(backpressure_queue, num_returns=1, fetch_local=False) + backpressure_queue = remaining + + while not self._shutdown_event.is_set(): + if should_sleep: + time.sleep(0.1) + + drain_backpressure_to(self._max_in_flight) + + with self._queue_lock: + if len(self._queue) == 0: + should_sleep = True + continue + else: + should_sleep = False + + item = heapq.heappop(self._queue) + self._current_item = item + + try: + item_is_finished, ref = item.execute() + if ref is not None: + backpressure_queue.append(ref) + except Exception: + logger.exception(f"Error while processing {item.name}. Killing all associated work.") + self.cancel_work_group(item.spec) + continue + + with self._queue_lock: + self._current_item = None + if not item_is_finished: + heapq.heappush(self._queue, item) + + logger.info("Shutting down PriorityProcessorActor. Waiting for backpressure to drain.") + drain_backpressure_to(0) + logger.info("Backpressure drained. Shutting down PriorityProcessorActor.") + + +@dataclass +class ShardGroupToBeProcessed(PriorityWorkTaskGroupSpec): + name: str + builder_ref: ray.actor.ActorHandle # _ChunkCacheBuilder + writer: ray.actor.ActorHandle # _GroupedShardWriter + shard_source: ShardedDataset + shard_names: Sequence[str] + priority_fn: Callable[[int, int], float] + processor_actor: ray.actor.ActorHandle # BatchProcessorQueue + batch_size: int + num_rows_per_chunk: int + + def build(self) -> "PriorityWorkTaskGroup": + return ShardGroupTaskGroup(self) + + +class ShardGroupTaskGroup(PriorityWorkTaskGroup): + def __init__(self, spec: ShardGroupToBeProcessed): + self.spec = spec + self.logger = pylogging.getLogger(f"shard_reader.{self.spec.name}") - for shard_name in shard_names: - shard_idx = shard_source.shard_names.index(shard_name) try: - shard_metadata = shard_metadatas[shard_name] - heapq.heappush(shard_pqueue, (len(shard_metadata.chunks), shard_idx)) - shard_readers[shard_idx] = _shard_reader_generator( - shard_source, shard_idx, shard_metadata.total_rows, batch_size + metadata: dict[str, ShardMetadata] = _initial_shard_metadatas( + self.spec.shard_source, self.spec.shard_names, self.spec.writer ) - except Exception as e: # noqa - logger.exception(f"Error while initializing shard {shard_name}") - ray.get(shard_writers[shard_name].shard_failed.remote(ser_exc_info())) + except Exception as e: + self.spec.builder_ref.other_failed.remote(ser_exc_info()) raise e - MAX_INFLIGHT = 30 - back_pressure_queue: list[ray.ObjectRef] = [] + batch_size = min(self.spec.batch_size, self.spec.num_rows_per_chunk) + + self._items: list[PriorityWorkItem] = [] + + for shard_name in self.spec.shard_names: + shard_idx = self.spec.shard_source.shard_names.index(shard_name) + try: + shard_metadata = metadata[shard_name] + reader = _shard_reader_generator( + self.spec.shard_source, shard_idx, shard_metadata.total_rows, batch_size + ) + + if shard_metadata.is_finished: + self.logger.info(f"Shard {shard_name} already finished. Skipping.") + + task_name = f"shard_reader.{self.spec.name}.{shard_name}" + + chunk_idx = len(shard_metadata.chunks) + item = ShardReaderItem(self, task_name, shard_name, shard_idx, chunk_idx, reader) + + heapq.heappush(self._items, item) + except Exception as e: + self.logger.exception(f"Error while initializing shard {shard_name}") + self.spec.writer[shard_name].shard_failed.remote(ser_exc_info()) + raise e + + @property + def name(self): + return self.spec.name + + def items(self) -> Sequence["PriorityWorkItem"]: + return self._items - while len(shard_pqueue) > 0: - chunk_id, shard_idx = heapq.heappop(shard_pqueue) - shard_name = shard_source.shard_names[shard_idx] - try: - shard_iter = shard_readers[shard_idx] - exhausted_shard = False +# NB This class is stateful +@dataclass +class ShardReaderItem(PriorityWorkItem): + """ + Each time execute is called, this class reads one chunk's worth of batches from the shard + and dispatches them to the processor. + """ + + group: ShardGroupTaskGroup + name: str + shard_name: str + shard_idx: int + chunk_idx: int + reader: Iterator[list] + + @property + def priority(self): + return self.group.spec.priority_fn(self.shard_idx, self.chunk_idx) + + @property + def spec(self): + return self.group.spec + + def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: + exhausted_shard = False + writer = self.spec.writer - chunk_batch_idx = 0 - chunk_filled = False - total_chunk_rows = 0 + chunk_batch_idx = 0 # the index of the batch within the chunk + chunk_filled = False # whether or not we've filled the chunk to max size + total_chunk_rows = 0 # the total number of rows in the chunk + batch_result_ref = None + try: while not chunk_filled: - batch = next(shard_iter, None) + batch = next(self.reader, None) if batch is None: exhausted_shard = True break - exhausted_shard = len(batch) < batch_size + exhausted_shard = len(batch) < self.spec.batch_size total_chunk_rows += len(batch) if batch: - # we want to limit the number of pending tasks, so we wait until we're below the limit - # before we start reading the next batch - while len(back_pressure_queue) >= MAX_INFLIGHT: - finished_ref, back_pressure_queue = ray.wait(back_pressure_queue, num_returns=1) - - priority = priority_fn(shard_idx, chunk_id) - batch = ray.put(batch) - batch_result_ref = ray.get(processor_actor.submit.remote(priority=priority, batch=RefBox(batch))) - shard_writers.chunk_batch_finished.remote( - shard_name, chunk_id, chunk_batch_idx, RefBox(batch_result_ref) + priority = self.spec.priority_fn(self.shard_idx, self.chunk_idx) + batch_result_ref = ray.get( + self.spec.processor_actor.submit.remote(priority=priority, batch=RefBox(ray.put(batch))) + ) + writer.chunk_batch_finished.remote( + self.shard_name, self.chunk_idx, chunk_batch_idx, RefBox(batch_result_ref) ) - back_pressure_queue.append(batch_result_ref) - chunk_batch_idx += 1 + # enqueue_to_backpressure(batch, batch_result_ref) + del batch - if total_chunk_rows >= num_rows_per_chunk or exhausted_shard: + if total_chunk_rows >= self.spec.num_rows_per_chunk or exhausted_shard: chunk_filled = True if chunk_batch_idx > 0: - shard_writers.chunk_finished_reading.remote(shard_name, chunk_id, chunk_batch_idx) - chunk_id += 1 + writer.chunk_finished_reading.remote(self.shard_name, self.chunk_idx, chunk_batch_idx) + old_prio = self.priority + self.chunk_idx += 1 + assert self.priority > old_prio if exhausted_shard: - shard_writers.shard_finished_reading.remote(shard_name, chunk_id) - del shard_readers[shard_idx] - del shard_metadatas[shard_name] - else: - # we're not done with this shard, so put it back in the queue - heapq.heappush(shard_pqueue, (chunk_id, shard_idx)) + writer.shard_finished_reading.remote(self.shard_name, self.chunk_idx) + + logger.debug(f"Finished reading one chunk of shard {self.shard_name}: {self.chunk_idx} {exhausted_shard}") + return exhausted_shard, batch_result_ref except Exception as e: # noqa - logger.exception(f"Error while processing shard {shard_name}") - ray.get(shard_writers.shard_failed.remote(shard_name, ser_exc_info())) + self.group.logger.exception(f"Error while processing shard {self.shard_name}") + # fire and forget + writer.shard_failed.remote(self.shard_name, ser_exc_info()) raise e @@ -510,7 +684,7 @@ def _init_progress(self, metrics): self.progress.start() -class WandbMetricsMonitor(MetricsMonitor): +class LoggingMetricsMonitor(MetricsMonitor): last_metrics: Optional[InProgressCacheMetrics] last_time: Optional[float] @@ -558,10 +732,10 @@ def __call__(self, metrics: InProgressCacheMetrics): class LoggerMetricsMonitor(MetricsMonitor): # TODO: I'd like to get the trainer pbar migrated to rich and just use rich everywhere, but until then, # we have separate logging - def __init__(self, logger: Optional[Union[logging.Logger, str]] = None, level=logging.INFO): + def __init__(self, logger: Optional[Union[pylogging.Logger, str]] = None, level=pylogging.INFO): if isinstance(logger, str): - logger = logging.getLogger(logger) - self.logger = logger or logging.getLogger(__name__) + logger = pylogging.getLogger(logger) + self.logger = logger or pylogging.getLogger(__name__) self.level = level def __call__(self, metrics: InProgressCacheMetrics): @@ -597,7 +771,7 @@ def is_finished_and_buffer_empty(self): def _mk_queue_aware_process_task(processor: BatchProcessor[T], queue: ActorHandle): @ray.remote(num_cpus=processor.num_cpus, num_gpus=processor.num_gpus, resources=processor.resources) def process_task(batch: List[T]) -> pa.RecordBatch: - logging.basicConfig(level=logging.INFO) + pylogging.basicConfig(level=pylogging.INFO) queue.task_running.remote() result = processor(batch) del batch @@ -667,10 +841,10 @@ def task_running(self): # Ray does poorly with large numbers of actors (grumble grumble), so we can't have one actor per shard. # This class wraps a map of shard names to _ShardWriterWorkers, and manages the lifecycle of the workers. -@ray.remote(num_cpus=1, scheduling_strategy="SPREAD") # type: ignore +@ray.remote(num_cpus=0.0, scheduling_strategy="SPREAD") # type: ignore class _GroupShardWriterWorker: def __init__(self, parent_ref, cache_dir: str, shard_names: Sequence[str]): - logging.basicConfig(level=logging.INFO) + pylogging.basicConfig(level=pylogging.INFO) self.cache_dir = cache_dir self.shard_names = shard_names self.shard_writers: dict[str, _ShardWriterWorker] = { @@ -714,7 +888,7 @@ def __init__( cache_dir: str, shard_name: str, ): - logging.basicConfig(level=logging.INFO) + pylogging.basicConfig(level=pylogging.INFO) self.parent_ref = parent_ref self.cache_dir = cache_dir self.shard_name = shard_name @@ -727,9 +901,11 @@ def __init__( self.parent_ref.new_chunk.remote(shard_name, *self.metadata_writer.chunks) if self.metadata_writer.is_finished: - logger.info(f"Shard {shard_name} already finished. Skipping.") self._expected_num_chunks = self.metadata_writer.num_chunks self.parent_ref.shard_finished.remote(self.shard_name, self._expected_num_chunks) + self.finished = True + else: + self.finished = False self.collator = _ChunkCollator(cache_dir, shard_name) @@ -788,12 +964,15 @@ def _attempt_to_commit_chunks(self): chunks_committed.append(chunk) if len(chunks_committed) > 0: + if self.finished: + raise RuntimeError("Tried to commit chunks after shard finished") # TODO: this is called inside an async call so we need to not block, but we do need to sequence # this to come before the shard_finished self.parent_ref.new_chunk.remote(self.shard_name, *chunks_committed) - if self._expected_num_chunks is not None and self.metadata_writer.num_chunks == self._expected_num_chunks: + if not self.finished and self.metadata_writer.num_chunks == self._expected_num_chunks: self.metadata_writer.finish() + self.finished = True self.parent_ref.shard_finished.remote(self.shard_name, self._expected_num_chunks) @@ -874,7 +1053,7 @@ def _attempt_to_write_chunk_fragments(self, chunk_id) -> Optional[ChunkMetadata] return None -@ray.remote +@ray.remote(num_cpus=0.0) # keep this small b/c it doesn't do a lot class ChunkCacheBuilder: """ Actor that manages the in-progress global ordering on chunks. ChunkCacheWriter's job is to hold the list of all @@ -889,11 +1068,13 @@ def __init__( self, broker_ref, cache_dir: str, + name: str, source: ShardedDataset[T], processor: BatchProcessor[T], rows_per_chunk: int, ): - logging.basicConfig(level=logging.INFO) + pylogging.basicConfig(level=pylogging.INFO) + self.logger = pylogging.getLogger(f"{__name__}.{name}") self.broker_ref = broker_ref self.shard_status: Dict[str, _ShardStatus] = dict() self._current_round_robin = [] @@ -903,10 +1084,10 @@ def __init__( self_ref = current_actor_handle() if len(source.shard_names) == 0: - logger.warning("No shards to index?!?") + self.logger.warning("No shards to index?!?") self._finish() else: - logger.info(f"Starting cache build for {len(source.shard_names)} shards") + self.logger.info(f"Starting cache build for {len(source.shard_names)} shards") self._shard_writers = [] self._shard_readers = [] @@ -927,24 +1108,47 @@ def priority_fn(shard_idx, chunk_idx): self._current_round_robin.append(shard_name) shard_groups[i % num_shard_groups].append(shard_name) - for shard_group in shard_groups: + for group_id, shard_group in enumerate(shard_groups): writer = _GroupShardWriterWorker.remote(self_ref, cache_dir, shard_group) # type: ignore self._shard_writers.append(writer) + # TODO: would probably be better if we didn't create one of these per shard group processor_actor = _BatchProcessorQueue.remote(processor) # type: ignore self._processor_actors.append(processor_actor) - reader = _alternating_shard_reader.remote( - self_ref, - writer, - source, - shard_group, - priority_fn, - processor_actor, - processor.batch_size, - rows_per_chunk, + work_item = ShardGroupToBeProcessed( + name=name, + builder_ref=self_ref, + writer=writer, + shard_source=source, + shard_names=shard_group, + priority_fn=priority_fn, + processor_actor=processor_actor, + batch_size=processor.batch_size, + num_rows_per_chunk=rows_per_chunk, ) - self._shard_readers.append(reader) + + # we want global names so that different tasks can coordinate priorities + priority_actor_name = f"priority_processor.{group_id}" + + reader_actor = PriorityProcessorActor.options( # type: ignore + name=priority_actor_name, get_if_exists=True + ).remote() + + ray.get(reader_actor.add_work_group.remote(work_item)) + + # reader = _alternating_shard_reader.remote( + # name, + # self_ref, + # writer, + # source, + # shard_group, + # priority_fn, + # processor_actor, + # processor.batch_size, + # rows_per_chunk, + # ) + self._shard_readers.append(reader_actor) def new_chunk(self, shard_name: str, *chunks: ChunkMetadata): """Callback method for when a shard worker has produced a new chunk.""" @@ -966,6 +1170,9 @@ def new_chunk(self, shard_name: str, *chunks: ChunkMetadata): def shard_finished(self, shard_name: str, expected_num_chunks: int): """Callback method for when a shard worker has finished.""" shard_status = self.shard_status[shard_name] + assert ( + shard_status.expected_num_chunks is None + ), f"Shard {shard_name} already finished: {shard_status.expected_num_chunks} {expected_num_chunks}" shard_status.expected_num_chunks = expected_num_chunks # we might still have buffered chunks, so we need to check if we can append them @@ -991,17 +1198,17 @@ def other_failed(self, error: ExceptionInfo): def _attempt_to_flush_buffers(self): # this is the most complex logic in this class. - # The global order on chunks is defined as "roundrobin" over shards, until one shard is done. - # after that, that shard is removed from the roundrobin and the process continues. - # roundrobin order is determined by self.source.shard_names - - # we are happy to release chunks that form a prefix of the global order so that they can be read - # to do that, we maintain the roundrobin order in self._current_round_robin - # and we maintain the current buffer for each shard in self.shard_status - # when we get a new chunk, we append it to the buffer for that shard - # when we get a finished message, we mark that shard as finished - # in either case, we check if we can send any chunks from the front of the roundrobin - # if we can, we send them to the broker + # The global order on chunks is defined as a roundrobin over shards, until one shard is done. + # After that, that shard is removed from the roundrobin and the process continues. + # Roundrobin order is determined by self.source.shard_names + + # We are happy to release chunks that form a prefix of the global order so that they can be read. + # To do that, we maintain the roundrobin order in self._current_round_robin + # and we maintain the current buffer for each shard in self.shard_status. + # When we get a new chunk, we append it to the buffer for that shard. + # When we get a finished message, we mark that shard as finished. + # In either case, we check if we can send any chunks from the front of the roundrobin. + # If we can, we send them to the broker # here "finished" means that the shard has sent all of its chunks and has told us that it's done. @@ -1049,7 +1256,7 @@ class ChunkCacheBroker: _finished_promise: asyncio.Future[None] def __init__(self, cache_dir: str, source: ShardedDataset[T], processor: BatchProcessor[T], rows_per_chunk: int): - logging.basicConfig(level=logging.INFO) + pylogging.basicConfig(level=pylogging.INFO) self.chunks = [] self._reader_promises = {} self._is_finished = False @@ -1071,7 +1278,10 @@ def __init__(self, cache_dir: str, source: ShardedDataset[T], processor: BatchPr self._finished_promise.set_result(None) except FileNotFoundError: self_ref = ray.runtime_context.get_runtime_context().current_actor - self._builder_actor = ChunkCacheBuilder.remote(self_ref, self._cache_dir, self._source, self._processor, self._rows_per_chunk) # type: ignore + # only use the last two components of the name since it gets kind of long + path_for_name = os.path.join(*self._cache_dir.split("/")[-2:]) + name = f"builder::{path_for_name}" + self._builder_actor = ChunkCacheBuilder.remote(self_ref, self._cache_dir, name, self._source, self._processor, self._rows_per_chunk) # type: ignore def is_finished(self): return self._is_finished @@ -1259,6 +1469,8 @@ def __init__( self._num_readers = num_readers self._reader_offset = reader_offset + name = os.path.join(*cache_dir.split("/")[-2:]) + self.logger = pylogging.getLogger(f"ShardCache.{name}") @staticmethod def load(cache_dir: str, batch_size: int) -> "ShardCache": @@ -1317,13 +1529,15 @@ def _get_chunk_unmapped(self, mapped_index: int, *, timeout: Optional[float] = N time_in = time.time() # we want to also log if we're waiting for a long time, so we do this in a loop while timeout is None or time.time() - time_in < timeout: - current_timeout = 20.0 # be generous + current_timeout = 20.0 if timeout is not None: current_timeout = min(current_timeout, timeout - (time.time() - time_in)) try: chunk = ray.get(self._broker.get_chunk.remote(mapped_index), timeout=current_timeout) except GetTimeoutError: - logger.warning(f"Waiting for chunk {mapped_index} after {int(time.time() - time_in)} seconds") + self.logger.warning(f"Waiting for chunk {mapped_index} for {int(time.time() - time_in)} seconds") + current_timeout *= 2 + current_timeout = min(current_timeout, 80) continue if chunk is None: @@ -1378,7 +1592,7 @@ def iter_batches_from_chunks(self, loop: bool = False): i = shard_offset while True: try: - logger.debug(f"Reading chunk {i}") + self.logger.debug(f"Reading chunk {i}") chunk = self._get_chunk_unmapped(i) i += self._num_readers yield from self._read_chunk(chunk) @@ -1391,7 +1605,7 @@ def iter_batches_from_chunks(self, loop: bool = False): else: break except Exception as e: - logger.exception("Error while reading from shard cache.") + self.logger.exception("Error while reading from shard cache.") raise e def __iter__(self): @@ -1452,7 +1666,7 @@ def _monitor_metrics(self): if metrics.is_finished: break except Exception as e: - logger.exception("Error while reading metrics from shard cache.") + self.logger.exception("Error while reading metrics from shard cache.") raise e diff --git a/src/levanter/data/sharded_dataset.py b/src/levanter/data/sharded_dataset.py index 3f3f8c036..1ceae6366 100644 --- a/src/levanter/data/sharded_dataset.py +++ b/src/levanter/data/sharded_dataset.py @@ -93,7 +93,12 @@ def build_cache( source, processor = _construct_composite_batch_processor(self) cache = build_cache( - path, source, processor, rows_per_chunk=rows_per_chunk, await_finished=await_finished, monitors=monitors + path, + source, + processor, + rows_per_chunk=rows_per_chunk, + await_finished=await_finished, + monitors=monitors, ) return DictCacheDataset(cache) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 8ac061eb3..212318433 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -42,9 +42,10 @@ from levanter.data.shard_cache import ( # noqa ChunkMetadata, LoggerMetricsMonitor, + LoggingMetricsMonitor, MetricsMonitor, ShardCache, - WandbMetricsMonitor, + _serialize_json_and_commit, build_cache, ) from levanter.data.sharded_dataset import ShardedDataset, TextUrlDataset, WrappedHFDataset # noqa @@ -567,9 +568,11 @@ def token_seq_dataset( return TokenSeqDataset(cache, seq_len) def build_or_load_cache( - self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True + self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True, logger_name: Optional[str] = None ) -> Optional[TokenizedDocumentCache]: split_cache_dir = os.path.join(self.cache_dir, split) + name = logger_name or os.path.basename(self.cache_dir) + try: return TokenizedDocumentCache.load(split_cache_dir, flatten_docs=True) except FileNotFoundError: @@ -584,8 +587,8 @@ def build_or_load_cache( if monitors is True: monitors = [ - WandbMetricsMonitor(prefix=f"preprocessing/{split}", commit=False), - LoggerMetricsMonitor(f"preprocessing.{split}"), + LoggingMetricsMonitor(prefix=f"preprocessing/{name}/{split}", commit=False), + LoggerMetricsMonitor(f"preprocessing.{name}.{split}"), ] elif monitors is False: monitors = [] @@ -657,6 +660,13 @@ def train_set( token_datasets = {name: TokenSeqDataset(cache, seq_len, stride=None) for name, cache in doc_caches.items()} return MixtureDataset(datasets=token_datasets, weights=self.train_weights) + def training_sets( + self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True + ) -> Mapping[str, ShardableDataset[np.ndarray]]: + doc_caches = self.build_caches("train", monitors=monitors) + token_datasets = {name: TokenSeqDataset(cache, seq_len, stride=None) for name, cache in doc_caches.items()} + return token_datasets + def validation_sets( self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True ) -> Mapping[str, ShardableDataset[np.ndarray]]: diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py index bcf755b92..0b0636f4b 100644 --- a/src/levanter/main/cache_dataset.py +++ b/src/levanter/main/cache_dataset.py @@ -5,7 +5,7 @@ import wandb import levanter -from levanter.data.shard_cache import RichMetricsMonitor, WandbMetricsMonitor, build_cache +from levanter.data.shard_cache import LoggingMetricsMonitor, RichMetricsMonitor, build_cache from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.distributed import RayConfig from levanter.logging import init_logger @@ -40,7 +40,7 @@ def main(args: RayCachedLMDatasetConfig): logger.warning(f"Skipping {split} because it is empty.") continue - monitors = [RichMetricsMonitor(source.num_shards), WandbMetricsMonitor("preprocess/" + split, commit=True)] + monitors = [RichMetricsMonitor(source.num_shards), LoggingMetricsMonitor("preprocess/" + split, commit=True)] cache = build_cache( cache_dir=split_cache_dir, From 14b514d7653f6e83af4eb01394a823e5cfb09702 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 2 Feb 2024 10:01:53 -0800 Subject: [PATCH 09/19] pull in the microbatched implementation (#448) --- src/levanter/grad_accum.py | 151 ++++++++++++++++++++++++------------- src/levanter/trainer.py | 22 +++++- tests/test_grad_accum.py | 8 +- 3 files changed, 121 insertions(+), 60 deletions(-) diff --git a/src/levanter/grad_accum.py b/src/levanter/grad_accum.py index 9280e5234..8ac6e9395 100644 --- a/src/levanter/grad_accum.py +++ b/src/levanter/grad_accum.py @@ -1,4 +1,6 @@ +import enum import functools +from typing import Callable, Optional, ParamSpec, TypeVar import equinox as eqx import jax @@ -8,93 +10,122 @@ import haliax as hax from haliax import Axis -from haliax.jax_utils import named_call from haliax.partitioning import ResourceAxis -from haliax.util import is_named_array +from haliax.util import is_jax_array_like, is_named_array -from levanter.types import M, ValAndGradFn, ValFn, X +Args = ParamSpec("Args") +R = TypeVar("R") + + +class ReductionType(enum.Enum): + SUM = enum.auto() + MEAN = enum.auto() + # TODO: add MAX? + + +# TODO: should we use a custom_jvp on microbatched? # cf https://github.com/google-research/t5x/blob/main/t5x/trainer.py#L617 -@named_call -def accumulate_gradients_sharded( - f: ValFn[M, X], +def microbatched( + fn: Callable[Args, R], Batch: Axis, - per_device_parallelism: int, - parameter_axis_mapping, -) -> ValAndGradFn[M, X]: + microbatch_size: int, + accum_axis_mapping, + compute_axis_mapping, + patch_in_rng_key: Optional[str] = "key", + reduce: ReductionType = ReductionType.MEAN, + accum_dtype: Optional[jnp.dtype] = None, +) -> Callable[Args, R]: """ - Accumulate gradients across a sharded batch, keeping a local copy of the gradient on each row of the data - parallel axis. (If the model is not sharded, then a copy of the gradient is on each individual device.) + Wraps a function that takes a batch and changes it to instead take microbatches and accumulate the results + This function has to reduce the batch axis, so it can't be used for functions that need to keep the batch axis. - Parameters: - f: a function whose gradients are to be accumulated - per_device_parallelism: how many examples to process at once on each device - inputs: inputs with the batch axis. non-named arrays assume that the 0th axis is the batch axis. - parameter_axis_mapping: the axis mapping for the model parameters - key: an optional PRNG key for the random number generator. - If provided, this key will be split, 1 for each accum step - kwargs: passed to f + Can be used as a decorator with functools.partial, e.g.: + >>> @functools.partial(microbatched, Batch=Batch, per_device_parallelism=4) + >>> def my_fn(x): + >>> return hax.mean(x + 1) + + + Args: + fn: a function to wrap + Batch: the batch axis + per_device_parallelism: how many examples to process at once on each device + accum_axis_mapping: the axis mapping for the accumulator (typically this is the same as the params) + compute_axis_mapping: the axis mapping for the computation (typically this is the same as the inputs) + patch_in_rng_key: if provided, this kwarg will be split, 1 for each accum step. It won't work if the + PRNGKey is passed in as a positional argument. + reduce: whether to sum or average the results + accum_dtype: the dtype of floating point values in the accumulator. If None, this will be inferred from the return type of `fn`. + + Returns: + a function that splits the batch into microbatches, calls the function on each microbatch, and + accumulates the results. """ batch_size = Batch.size - data_axis_size = hax.partitioning.physical_axis_size(Batch, parameter_axis_mapping) + data_axis_size = hax.partitioning.physical_axis_size(Batch, compute_axis_mapping) if data_axis_size is None: raise ValueError(f"{Batch} axis must be sharded") - physical_axis_name = hax.partitioning.physical_axis_name(Batch, parameter_axis_mapping) + physical_axis_name = hax.partitioning.physical_axis_name(Batch, compute_axis_mapping) assert physical_axis_name is not None - microbatch_size = data_axis_size * per_device_parallelism + if microbatch_size <= 0: + raise ValueError(f"Bad value for {microbatch_size=}") + num_micro_steps = batch_size // microbatch_size - assert batch_size % data_axis_size == 0, f"batch_size % data_axis_size != 0: {batch_size} % {data_axis_size} != 0" - assert ( - batch_size % microbatch_size == 0 - ), f"batch_size % microbatch_size != 0: {batch_size} % {microbatch_size} != 0" + if num_micro_steps == 1: + return fn - Microbatch = Axis(Batch.name, microbatch_size) + Microbatch = Batch.resize(microbatch_size) AccumStep = Axis("accum_step", num_micro_steps) assert num_micro_steps * microbatch_size == batch_size - grad_fn = eqx.filter_value_and_grad(f, has_aux=False) + if reduce not in ReductionType: + raise ValueError(f"accum_type must be one of {ReductionType}") - @functools.wraps(grad_fn) - def fn(model, *inputs, key=None, **batch_kwargs): + @functools.wraps(fn) + def wrapped_fn(*args, **kwargs): + + # first, determine the shape and make accumulator arrays + r_shape = eqx.filter_eval_shape(fn, *args, **kwargs) + acc = _zeros_like_tree(r_shape, accum_axis_mapping, accum_dtype) + + # then, reshape the inputs from (Batch, ...) to (AccumStep, Microbatch, ...) + + # Special handling for PRNGKey: it comes in as a single key, but we need to split it for each microbatch + key = kwargs.get(patch_in_rng_key, None) if key is not None: key = jax.random.split(key, num_micro_steps) + kwargs = kwargs.copy() + kwargs.pop(patch_in_rng_key) - # first things first, we want a copy of our gradient sharded like our model, along with a loss value - loss = jnp.zeros(()) - with jax.named_scope("zeros"): - grad = jax.tree_util.tree_map(jnp.zeros_like, eqx.filter(model, eqx.is_inexact_array_like)) - grad = hax.shard_with_axis_mapping(grad, parameter_axis_mapping) + args = _reshape_for_microbatch(Batch, Microbatch, AccumStep, args, compute_axis_mapping) - # second, we want to reshape our data to (num_micro_steps, micro_batch_size, ...), sharded along the data axis - inputs = _reshape_for_microbatch(Batch, Microbatch, AccumStep, inputs, parameter_axis_mapping) - - # third, we want to do compute. def loop(acc, microbatch_and_key): - loss, grad = acc microbatch, microbatch_kwargs, key = microbatch_and_key - with jax.named_scope("grad"): + with jax.named_scope("compute"): microbatch_kwargs = microbatch_kwargs.copy() if key is not None: - microbatch_kwargs["key"] = key - this_loss, this_grad = grad_fn(model, *microbatch, **microbatch_kwargs) - this_grad = hax.shard_with_axis_mapping(this_grad, parameter_axis_mapping) + microbatch_kwargs[patch_in_rng_key] = key + this_r = fn(*microbatch, **microbatch_kwargs) with jax.named_scope("accum"): - loss += this_loss - grad = eqx.apply_updates(grad, this_grad) - grad = hax.shard_with_axis_mapping(grad, parameter_axis_mapping) + acc = eqx.apply_updates(acc, this_r) + acc = hax.shard(acc, accum_axis_mapping) + + return acc - return loss, grad + with jax.named_scope("microbatched"): + acc = hax.fold(loop, AccumStep)(acc, (args, kwargs, key)) - loss, grad = hax.fold(loop, AccumStep)((loss, grad), (inputs, batch_kwargs, key)) + if reduce == ReductionType.MEAN: + acc = jax.tree_util.tree_map(lambda x: x / num_micro_steps, acc) - return loss / num_micro_steps, jax.tree_map(lambda x: x / num_micro_steps, grad) + return acc - return fn + return wrapped_fn def _reshape_for_microbatch(Batch: Axis, Microbatch: Axis, AccumStep: Axis, inputs, axis_mapping): @@ -103,7 +134,7 @@ def _reshape(x): if not x.has_axis(Batch.name): return x x = x.unflatten_axis(Batch, (AccumStep, Microbatch)) - return hax.shard_with_axis_mapping(x, axis_mapping) + return hax.shard(x, axis_mapping) elif isinstance(x, jnp.ndarray): x = x.reshape((AccumStep.size, Microbatch.size) + x.shape[1:]) return with_sharding_constraint(x, PartitionSpec(None, ResourceAxis.DATA, *(None,) * (len(x.shape) - 2))) @@ -112,3 +143,19 @@ def _reshape(x): return x return jax.tree_util.tree_map(_reshape, inputs, is_leaf=is_named_array) + + +def _zeros_like_tree(r_shape, axis_mapping, accum_dtype): + _zeros = functools.partial(_zeros_like, axis_mapping, accum_dtype) + acc = jax.tree_util.tree_map(_zeros, r_shape, is_leaf=is_named_array) + return acc + + +def _zeros_like(mapping, dtype, n): + if isinstance(n, hax.NamedArray): + return hax.shard(hax.zeros_like(n, dtype=dtype), mapping) + elif is_jax_array_like(n): + return jnp.zeros_like(n, dtype) + else: + assert jnp.isscalar(n) + return 0.0 diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 2a8c5e93c..5fe6e0302 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -29,13 +29,14 @@ import haliax as hax from haliax import Axis from haliax.partitioning import ResourceAxis, ResourceMapping, named_jit +from haliax.types import Scalar import levanter.logging from levanter.checkpoint import CheckpointerConfig from levanter.config import JsonAtom from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader from levanter.distributed import DistributedConfig, RayConfig -from levanter.grad_accum import accumulate_gradients_sharded +from levanter.grad_accum import microbatched from levanter.logging import WandbConfig, capture_time from levanter.types import FilterSpec from levanter.utils import cloud_utils @@ -406,9 +407,7 @@ def split_loss_fn(trainable_model, *batch, **batch_kwargs): model = eqx.combine(trainable_model, rest_model) return self.loss_fn(model, *batch, **batch_kwargs) - loss, grads = accumulate_gradients_sharded( - split_loss_fn, self.TrainBatch, self.config.per_device_parallelism, self.parameter_axis_mapping - )(trainable_model, *batch, **batch_kwargs) + loss, grads = self._compute_gradients_microbatched(split_loss_fn, trainable_model, batch, **batch_kwargs) updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) model = eqx.apply_updates(model, updates) @@ -417,6 +416,17 @@ def split_loss_fn(trainable_model, *batch, **batch_kwargs): return train_step + def _compute_gradients_microbatched(self, loss_fn, model: M, batch, **batch_kwargs) -> tuple[Scalar, M]: + grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False) + grad_fn = microbatched( + grad_fn, + self.TrainBatch, + self.config.microbatch_size, + self.parameter_axis_mapping, + self.compute_axis_mapping, + ) + return grad_fn(model, *batch, **batch_kwargs) + def _init_model_and_opt_state(self, model_init): model = model_init() # only force trainable params to param precision. Other params are cast to compute precision @@ -562,6 +572,10 @@ def TrainBatch(self): def EvalBatch(self): return Axis("batch", self.eval_batch_size) + @property + def microbatch_size(self): + return self.per_device_parallelism * self.data_axis_size + def initialize(self, all_config): """Initializes jax, wandb, logging, setting the run name/id in the process""" # Can't do full logging setup until we've initialized jax b/c we use jax for rank id diff --git a/tests/test_grad_accum.py b/tests/test_grad_accum.py index aec568a21..131ca3b89 100644 --- a/tests/test_grad_accum.py +++ b/tests/test_grad_accum.py @@ -7,7 +7,7 @@ import haliax as hax import haliax.nn as hnn -from levanter.grad_accum import accumulate_gradients_sharded +from levanter.grad_accum import microbatched class Mlp(eqx.Module): @@ -56,9 +56,9 @@ def loss_fn(mlp, x): @hax.partitioning.named_jit(axis_resources=axis_mapping) def jit_grad_accum(mlp, x): - acc_v, acc_g = accumulate_gradients_sharded( - loss_fn, Batch, per_device_parallelism=parallelism, parameter_axis_mapping=axis_mapping - )( + grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False) + grad_fn = microbatched(grad_fn, Batch, parallelism, axis_mapping, axis_mapping) + acc_v, acc_g = grad_fn( mlp, x, ) From 768a7e623db44f655214c4dfeaf242d47578c210 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 2 Feb 2024 11:38:39 -0800 Subject: [PATCH 10/19] use hnn.Embeddings in gpt2 (#447) * use hnn.Embeddings in gpt2 * delete commented out stuff --- src/levanter/models/gpt2.py | 29 ++++++++++++++++------------- tests/test_weight_decay_mask.py | 8 ++++---- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/levanter/models/gpt2.py b/src/levanter/models/gpt2.py index 5822fab30..a6f27c7a5 100644 --- a/src/levanter/models/gpt2.py +++ b/src/levanter/models/gpt2.py @@ -312,40 +312,43 @@ class Gpt2Embeddings(StateDictSerializationMixin, eqx.Module): Vocab: Axis = eqx.static_field() config: Gpt2Config = eqx.static_field() - token_embeddings: NamedArray - position_embeddings: NamedArray + token_embeddings: hnn.Embedding + position_embeddings: hnn.Embedding dropout: hnn.Dropout @staticmethod def init(Vocab: Axis, config: Gpt2Config, *, key) -> "Gpt2Embeddings": k_wte, k_wpe, k_out = jrandom.split(key, 3) - token_embeddings = hax.random.normal(k_wte, (Vocab, config.Embed)) * config.initializer_range - position_embeddings = hax.random.normal(k_wpe, (config.Pos, config.Embed)) * (config.initializer_range / 2) + token_embeddings = hnn.Embedding.init( + Vocab, config.Embed, key=k_wte, initializer_range=config.initializer_range + ) + position_embeddings = hnn.Embedding.init( + config.Pos, config.Embed, key=k_wpe, initializer_range=config.initializer_range / 2 + ) dropout = hnn.Dropout(pdrop=config.embed_pdrop) return Gpt2Embeddings(Vocab, config, token_embeddings, position_embeddings, dropout) @named_call def embed(self, input_ids, *, key): - input_embeds = self.token_embeddings.take("vocab", input_ids) - position_embeds = self.position_embeddings - - input_len = input_ids.resolve_axis("position").size - x = input_embeds + position_embeds["position", hax.dslice(0, input_len)] + input_embeds = self.token_embeddings(input_ids) + input_Pos = input_ids.resolve_axis("position") + position_embeds = self.position_embeddings.embed(hax.arange(input_Pos)) + x = input_embeds + position_embeds x = self.dropout(x, key=key) return x def unembed(self, x: NamedArray): - return hax.dot("embed", x, self.token_embeddings) + return hax.dot("embed", x, self.token_embeddings.weight) def _state_dict_key_map(self) -> Dict[str, Optional[str]]: - return {"token_embeddings": "wte.weight", "position_embeddings": "wpe.weight"} + return {"token_embeddings": "wte", "position_embeddings": "wpe"} def resize_embeddings(self, new_size: int, key: Optional[PRNGKeyArray] = None): - new_weights = hax.tree_util.resize_axis(self.token_embeddings, self.Vocab, new_size, key=key) - return dataclasses.replace(self, Vocab=self.Vocab.resize(new_size), token_embeddings=new_weights) + new_token_embeddings = self.token_embeddings.resize_embeddings(new_size, key=key) + return dataclasses.replace(self, Vocab=self.Vocab.resize(new_size), token_embeddings=new_token_embeddings) class Gpt2LMHeadModel(eqx.Module, LmWithHfSerializationMixin[Gpt2Config]): diff --git a/tests/test_weight_decay_mask.py b/tests/test_weight_decay_mask.py index 0c0f00e5c..52834e679 100644 --- a/tests/test_weight_decay_mask.py +++ b/tests/test_weight_decay_mask.py @@ -18,8 +18,8 @@ def apply_weight_decay(tree): nodes = [] # apply on embedding - nodes.append(tree.embeddings.token_embeddings.array) - nodes.append(tree.embeddings.position_embeddings.array) + nodes.append(tree.embeddings.token_embeddings.weight.array) + nodes.append(tree.embeddings.position_embeddings.weight.array) # apply on attention nodes.append(tree.transformer.blocks.stacked.attn.c_attn.weight.array) @@ -49,8 +49,8 @@ def apply_weight_decay(tree): "attn.c_proj.weight", "mlp.c_fc.weight", "mlp.c_proj.weight", - "token_embeddings", - "position_embeddings", + "token_embeddings.weight", + "position_embeddings.weight", ] ) regex_config = OptimizerConfig( From c1e7b24aa47e073af0aa48606581f817b59e1b45 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Sat, 3 Feb 2024 11:05:53 -0800 Subject: [PATCH 11/19] Bugfix: GQA with FA (#449) * fix GQA FA * add fa block size to mistral test * remove attention backward test (because LMHeadModel backward is tested) * fix bug in fa summation * rename variable --- src/levanter/models/flash_attention.py | 12 +++++++-- tests/test_flash_attention.py | 35 ++++++++++++++++++++++++++ tests/test_llama.py | 23 ++++++++++++++++- tests/test_mistral.py | 21 ++++++++++++++++ 4 files changed, 88 insertions(+), 3 deletions(-) diff --git a/src/levanter/models/flash_attention.py b/src/levanter/models/flash_attention.py index b1adea0b3..c1be0091e 100644 --- a/src/levanter/models/flash_attention.py +++ b/src/levanter/models/flash_attention.py @@ -304,8 +304,16 @@ def do_inner_block(state): dAttn_ij = p_ij * (dP_ij - D_i) dAttn_ij = dAttn_ij.astype(dQ_i.dtype) - dV_j = dV_j + hax.dot(QPos.name, p_ij, dO_i).astype(dV_j.dtype) - dK_j = dK_j + hax.dot(QPos.name, dAttn_ij, q_i).astype(dK_j.dtype) + dV_ji = hax.dot(QPos.name, p_ij, dO_i).astype(dV_j.dtype) + dK_ji = hax.dot(QPos.name, dAttn_ij, q_i).astype(dK_j.dtype) + + # GQA-specific: eliminate unnecessary axes (e.g. 'q_heads_per_group') + unnecessary_axes = hax.eliminate_axes(dV_ji.axes, v.axes) + dV_ji = hax.sum(dV_ji, unnecessary_axes) + dK_ji = hax.sum(dK_ji, unnecessary_axes) + + dV_j = dV_j + dV_ji + dK_j = dK_j + dK_ji dQ_i = dQ_i + hax.dot(KPos.name, dAttn_ij, k_j).astype(dQ.dtype) # dQ[i*block_size:(i+1)*block_size] = dQi diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 8d1f5aab0..467ac6cba 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -3,6 +3,7 @@ import equinox import jax.numpy as jnp import jax.random as jrandom +import pytest import haliax as hax import haliax.nn as hnn @@ -75,6 +76,40 @@ def d_attn(qkv, fn): assert jnp.allclose(hax_dv.array, fa_dv.array, atol=1e-4, rtol=1e-4) +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_grad_group_query_attention(num_kv_heads): + Batch = hax.Axis("batch", 2) + KVHeads = hax.Axis("kv_heads", num_kv_heads) + QHeadsPerGroup = hax.Axis("q_heads_per_group", 4 // num_kv_heads) + Key = hax.Axis("Key", 8) + QPos = hax.Axis("QPos", BLOCK_SIZE * 2) + KPos = hax.Axis("KPos", BLOCK_SIZE * 2) + + mask = hax.nn.attention.causal_mask(QPos, KPos) + + q = hax.random.normal(jrandom.PRNGKey(0), (Batch, KVHeads, QHeadsPerGroup, QPos, Key)) + k = hax.random.normal(jrandom.PRNGKey(1), (Batch, KVHeads, KPos, Key)) + v = hax.random.normal(jrandom.PRNGKey(2), (Batch, KVHeads, KPos, Key)) + + @equinox.filter_value_and_grad + def d_attn(qkv, fn): + q, k, v = qkv + x_out = fn(KPos, Key, q, k, v, mask=mask) + return (x_out * x_out).sum().scalar() + + hax_val, (hax_dq, hax_dk, hax_dv) = d_attn((q, k, v), hnn.attention.dot_product_attention) + fa_val, (fa_dq, fa_dk, fa_dv) = d_attn((q, k, v), functools.partial(flash_attention, QPos, inference=True)) + + assert jnp.allclose(hax_val, fa_val, atol=1e-4, rtol=1e-4) + assert hax_dq.axes == fa_dq.axes + assert hax_dk.axes == fa_dk.axes + assert hax_dv.axes == fa_dv.axes + + assert jnp.allclose(hax_dq.array, fa_dq.array, atol=1e-4, rtol=1e-4) + assert jnp.allclose(hax_dk.array, fa_dk.array, atol=1e-4, rtol=1e-4) + assert jnp.allclose(hax_dv.array, fa_dv.array, atol=1e-4, rtol=1e-4) + + def test_fa_dropout_does_something(): Key = hax.Axis("Key", 8) QPos = hax.Axis("QPos", BLOCK_SIZE * 2) diff --git a/tests/test_llama.py b/tests/test_llama.py index 7224a3ac1..1ace5c63c 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -1,5 +1,6 @@ import tempfile +import equinox as eqx import jax import numpy as np import pytest @@ -193,7 +194,7 @@ def test_llama_decoder_layer(num_kv_heads): state = llama_decoder_layer.to_state_dict() state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()} - hf_decoder_layer = HFLlamaDecoderLayer(llama_config.to_hf_config(32000)) + hf_decoder_layer = HFLlamaDecoderLayer(llama_config.to_hf_config(32000), layer_idx=0) hf_decoder_layer.load_state_dict(state, strict=True) x, mask = _get_random_inputs(llama_config) @@ -224,6 +225,25 @@ def test_llama_lm_head_model(num_kv_heads): assert out.array.shape == (Batch.size, Pos.size, Vocab.size) +@pytest.mark.parametrize("use_flash", [True, False]) +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_llama_lm_head_model_bwd(use_flash, num_kv_heads): + llama_config = _get_llama_config(use_flash=use_flash, num_kv_heads=num_kv_heads) + Batch = hax.Axis("batch", 2) + Vocab = hax.Axis("vocab", 1000) + Pos = llama_config.Pos + input_ids = hax.random.randint(random.PRNGKey(0), (Batch, Pos), 0, Vocab.size) + mask = hax.nn.attention.causal_mask(Pos, llama_config.KeyPos) + + llama_model = LlamaLMHeadModel.init(Vocab=Vocab, config=llama_config, key=random.PRNGKey(0)) + + def f(llama_model, input_ids, mask): + out = llama_model(input_ids, mask) + return hax.sum(out).scalar() + + _, grads = eqx.filter_value_and_grad(f)(llama_model, input_ids, mask) + + @skip_if_no_torch @pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) def test_llama_roundtrip(num_kv_heads): @@ -297,6 +317,7 @@ def _get_llama_config(use_flash=False, num_kv_heads=4) -> LlamaConfig: rope_scaling=rope_scaling, gradient_checkpointing=False, # disable for tests so debugging is easier use_flash_attention=use_flash, + flash_attention_block_size=8 if use_flash else None, ) diff --git a/tests/test_mistral.py b/tests/test_mistral.py index c758e3555..76630849f 100644 --- a/tests/test_mistral.py +++ b/tests/test_mistral.py @@ -1,5 +1,6 @@ import tempfile +import equinox as eqx import jax import numpy as np import pytest @@ -52,6 +53,25 @@ def test_mistral_lm_head_model(num_kv_heads): assert out.array.shape == (Batch.size, Pos.size, Vocab.size) +@pytest.mark.parametrize("use_flash", [True, False]) +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_mistral_lm_head_model_bwd(use_flash, num_kv_heads): + llama_config = _get_mistral_config(use_flash=use_flash, num_kv_heads=num_kv_heads) + Batch = hax.Axis("batch", 2) + Vocab = hax.Axis("vocab", 1000) + Pos = llama_config.Pos + input_ids = hax.random.randint(random.PRNGKey(0), (Batch, Pos), 0, Vocab.size) + mask = hax.nn.attention.causal_mask(Pos, llama_config.KeyPos) + + llama_model = MistralLMHeadModel.init(Vocab=Vocab, config=llama_config, key=random.PRNGKey(0)) + + def f(llama_model, input_ids, mask): + out = llama_model(input_ids, mask) + return hax.sum(out).scalar() + + _, grads = eqx.filter_value_and_grad(f)(llama_model, input_ids, mask) + + @skip_if_no_torch @pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) def test_mistral_roundtrip(num_kv_heads): @@ -120,6 +140,7 @@ def _get_mistral_config(use_flash=False, num_kv_heads=4) -> MistralConfig: num_kv_heads=num_kv_heads, gradient_checkpointing=False, # disable for tests so debugging is easier use_flash_attention=use_flash, + flash_attention_block_size=8 if use_flash else None, ) From f87828365bd75056ab050c5cfd4ca66a883b0582 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 6 Feb 2024 11:22:47 -0800 Subject: [PATCH 12/19] update for latest datasets (#454) --- src/levanter/data/sharded_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/data/sharded_dataset.py b/src/levanter/data/sharded_dataset.py index 1ceae6366..0ec178e08 100644 --- a/src/levanter/data/sharded_dataset.py +++ b/src/levanter/data/sharded_dataset.py @@ -171,7 +171,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: dataset = self._load_dataset() if isinstance(dataset, datasets.IterableDataset) and shard_name != "data": # ex_iterable has a key that gets discarded typically - shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources([int(shard_name)])) + shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources(int(shard_name), dataset.n_shards)) else: shard = dataset From 157db95dfc1dbe4513804ee98614750db7e5ca9c Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 6 Feb 2024 11:23:20 -0800 Subject: [PATCH 13/19] add pile mixture configs (#450) * add pile mixture configs * mclmcmlca --- config/data/pile_mixture.yaml | 137 +++++++++++++++++++++++++++ config/gpt2_small_pile_mixture.yaml | 23 +++++ scripts/preproc/split-pile-shards.py | 75 +++++++++++++++ 3 files changed, 235 insertions(+) create mode 100644 config/data/pile_mixture.yaml create mode 100644 config/gpt2_small_pile_mixture.yaml create mode 100644 scripts/preproc/split-pile-shards.py diff --git a/config/data/pile_mixture.yaml b/config/data/pile_mixture.yaml new file mode 100644 index 000000000..ff75b8941 --- /dev/null +++ b/config/data/pile_mixture.yaml @@ -0,0 +1,137 @@ +cache_dir: "gs://levanter-data/tokenized/pile-domains/" +tokenizer: "EleutherAI/gpt-neox-20b" +configs: + arxiv: + train_urls: + - gs://levanter-data/pile-domains/arxiv/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/arxiv/val.jsonl.zst + books2: + train_urls: + - gs://levanter-data/pile-domains/books2/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/books2/val.jsonl.zst + books3: + train_urls: + - gs://levanter-data/pile-domains/books3/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/books3/val.jsonl.zst + dm_math: + train_urls: + - gs://levanter-data/pile-domains/dm_math/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/dm_math/val.jsonl.zst + enron: + train_urls: + - gs://levanter-data/pile-domains/enron/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/enron/val.jsonl.zst + europarl: + train_urls: + - gs://levanter-data/pile-domains/europarl/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/europarl/val.jsonl.zst + free_law: + train_urls: + - gs://levanter-data/pile-domains/freelaw/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/freelaw/val.jsonl.zst + github: + train_urls: + - gs://levanter-data/pile-domains/github/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/github/val.jsonl.zst + hackernews: + train_urls: + - gs://levanter-data/pile-domains/hackernews/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/hackernews/val.jsonl.zst + nih: + train_urls: + - gs://levanter-data/pile-domains/nih/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/nih/val.jsonl.zst + opensubtitles: + train_urls: + - gs://levanter-data/pile-domains/opensubtitles/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/opensubtitles/val.jsonl.zst + owt2: + train_urls: + - gs://levanter-data/pile-domains/owt2/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/owt2/val.jsonl.zst + pg_19: + train_urls: + - gs://levanter-data/pile-domains/pg_19/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/pg_19/val.jsonl.zst + philpapers: + train_urls: + - gs://levanter-data/pile-domains/philpapers/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/philpapers/val.jsonl.zst + pile_cc: + train_urls: + - gs://levanter-data/pile-domains/pile_cc/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/pile_cc/val.jsonl.zst + pubmed_abs: + train_urls: + - gs://levanter-data/pile-domains/pubmed_abs/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/pubmed_abs/val.jsonl.zst + pubmed_central: + train_urls: + - gs://levanter-data/pile-domains/pubmed_central/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/pubmed_central/val.jsonl.zst + stack_exchange: + train_urls: + - gs://levanter-data/pile-domains/stack_exchange/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/stack_exchange/val.jsonl.zst + ubuntu_irc: + train_urls: + - gs://levanter-data/pile-domains/ubuntu_irc/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/ubuntu_irc/val.jsonl.zst + uspto: + train_urls: + - gs://levanter-data/pile-domains/uspto/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/uspto/val.jsonl.zst + wiki_en: + train_urls: + - gs://levanter-data/pile-domains/wiki_en/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/wiki_en/val.jsonl.zst + youtube_subtitles: + train_urls: + - gs://levanter-data/pile-domains/youtube_subtitles/{00..29}.jsonl.zst + validation_urls: + - gs://levanter-data/pile-domains/youtube_subtitles/val.jsonl.zst +train_weights: + # these weights come from the paper https://arxiv.org/pdf/2101.00027.pdf + pile_cc: 0.1811 + pubmed_central: 0.1440 + books3: 0.1207 + owt2: 0.1001 + arxiv: 0.0896 + github: 0.0759 + free_law: 0.0612 + stack_exchange: 0.0513 + uspto: 0.0365 + pubmed_abs: 0.0307 + pg_19: 0.0217 + opensubtitles: 0.0155 + wiki_en: 0.0153 + dm_math: 0.0124 + ubuntu_irc: 0.0088 + books2: 0.0075 + europarl: 0.0073 + hackernews: 0.0062 + youtube_subtitles: 0.0060 + philpapers: 0.0038 + nih: 0.0030 + enron: 0.0014 diff --git a/config/gpt2_small_pile_mixture.yaml b/config/gpt2_small_pile_mixture.yaml new file mode 100644 index 000000000..e02e4bd1f --- /dev/null +++ b/config/gpt2_small_pile_mixture.yaml @@ -0,0 +1,23 @@ +data: !include data/pile_mixture.yaml +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 2048 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + wandb: + project: "levanter" + tags: [ "pile", "gpt2"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: 8 + + train_batch_size: 256 + num_train_steps: 50000 +optimizer: + learning_rate: 6e-4 + weight_decay: 0.1 diff --git a/scripts/preproc/split-pile-shards.py b/scripts/preproc/split-pile-shards.py new file mode 100644 index 000000000..768b24874 --- /dev/null +++ b/scripts/preproc/split-pile-shards.py @@ -0,0 +1,75 @@ +import json +import os +import sys +from pathlib import Path + +import fsspec +import tqdm + + +OUT_PATH = "gs://levanter-data/pile-domains" + +categories_to_out_names = { + "ArXiv": "arxiv", + "BookCorpus2": "books2", + "Books3": "books3", + "DM Mathematics": "dm_math", + "Enron Emails": "enron", + "EuroParl": "europarl", + "FreeLaw": "freelaw", + "Github": "github", + "Gutenberg (PG-19)": "pg_19", + "HackerNews": "hackernews", + "NIH ExPorter": "nih", + "OpenSubtitles": "opensubtitles", + "OpenWebText2": "owt2", + "PhilPapers": "philpapers", + "Pile-CC": "pile_cc", + "PubMed Abstracts": "pubmed_abs", + "PubMed Central": "pubmed_central", + "StackExchange": "stack_exchange", + "USPTO Backgrounds": "uspto", + "Ubuntu IRC": "ubuntu_irc", + "Wikipedia (en)": "wiki_en", + "YoutubeSubtitles": "youtube_subtitles", +} + + +def format_category(category): + return categories_to_out_names[category] + + +def process_file(input_file_path): + base_file = Path(input_file_path).stem + compressors = {} + + with fsspec.open(input_file_path, "r", compression="infer") as text_stream: + for line in tqdm.tqdm(text_stream): + if not line.strip(): + continue # Skip empty lines + + # Decode line to string and load as JSON + data = json.loads(line) + category = data["meta"]["pile_set_name"] + category = format_category(category) + output_file_path = os.path.join(OUT_PATH, category, f"{base_file}.zst") + + # Check if compressor exists for this category, if not create it + if category not in compressors: + # output_file = open(output_file_path, 'wb') + output_file = fsspec.open(str(output_file_path), "wb", compression="infer").open() + print("opened", output_file_path) + compressors[category] = output_file + + # Write to the compressor + compressors[category].write(line.encode("utf-8")) + compressors[category].flush() + + # Close all open compressors + for compressor in compressors.values(): + compressor.close() + + +if __name__ == "__main__": + for path in sys.argv[1:]: + process_file(path) From f74fc5c589cc4a58dfc4d431a32c01a254ed0d33 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 6 Feb 2024 14:28:31 -0800 Subject: [PATCH 14/19] Tokenization perf tweaks (#455) * Bump tensorstore from 0.1.45 to 0.1.53 Bumps [tensorstore](https://github.com/google/tensorstore) from 0.1.45 to 0.1.53. - [Commits](https://github.com/google/tensorstore/compare/v0.1.45...v0.1.53) --- updated-dependencies: - dependency-name: tensorstore dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * Update fsspec requirement from <2023.13.0 to <2024.3.0 Updates the requirements on [fsspec](https://github.com/fsspec/filesystem_spec) to permit the latest version. - [Commits](https://github.com/fsspec/filesystem_spec/compare/0.0.1...2024.2.0) --- updated-dependencies: - dependency-name: fsspec dependency-type: direct:production ... Signed-off-by: dependabot[bot] * Update gcsfs requirement from <2023.13.0 to <2024.3.0 Updates the requirements on [gcsfs](https://github.com/fsspec/gcsfs) to permit the latest version. - [Commits](https://github.com/fsspec/gcsfs/compare/0.0.1...2024.2.0) --- updated-dependencies: - dependency-name: gcsfs dependency-type: direct:production ... Signed-off-by: dependabot[bot] * improve logging/add some tweaks in shard_cache to improve stability * work around llama tokenizer being quadratic * fix short timeouts * add some janky retry logic to ray init * missed a log spam * reduce batch size * fix an assertion error when we have failures * that assert is wrong --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 6 +- src/levanter/data/shard_cache.py | 205 +++++++++++++++++++++++-------- src/levanter/data/text.py | 157 +++++++++++++++++++---- src/levanter/distributed.py | 13 +- tests/test_text.py | 31 ++++- 5 files changed, 332 insertions(+), 80 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cf351f24c..14f010c1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,13 +35,13 @@ dependencies = [ "pyarrow>=11.0.0", "zstandard>=0.20.0", "datasets==2.16.1", - "gcsfs<2023.13.0", + "gcsfs<2024.3.0", "braceexpand>=0.1.7", "jmp>=0.0.3", - "fsspec<2023.13.0", + "fsspec<2024.3.0", # TODO: minimize and report an issue to tensorstore # causes hangs when serializing to GCS - "tensorstore==0.1.45", + "tensorstore==0.1.53", "pytimeparse>=1.1.8", "humanfriendly==10.0", "safetensors[numpy]", diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 62615b8a8..569bbe711 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -48,6 +48,8 @@ DEFAULT_MAX_BYTES_PER_BATCH = 256 * 1024 * 1024 # 256 MB, this is pre-preprocessing python object size LEDGER_FILE_NAME = "cache_ledger.json" +LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + def build_cache( cache_dir: str, @@ -297,7 +299,7 @@ def _commit(self): # The difficulty is that we want parallelism and we want to control the order of chunks. # reading batches requires CPU and network. This means we should limit the number to roughly the number of nodes, maybe times 2. -# We want to prioritize so that we read 1 chunks worth of batches from each shard before reading more from any shard. +# We want to prioritize so that we read 1 chunks worth of batches from each shard before reading more from another shard. # We also want to prioritize reading earlier shards before later shards (within a chunk generation round). # Ray also seems to get upset about having too many processes, and we can't serialize the iterators over shards. @@ -362,7 +364,7 @@ def __le__(self, other: "PriorityWorkItem"): @ray.remote(num_cpus=1, scheduling_strategy="SPREAD") class PriorityProcessorActor: def __init__(self, max_in_flight: Optional[int] = 200): - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self._queue: list[PriorityWorkItem] = [] # heapq self._queue_lock = threading.Lock() self._shutdown_event = threading.Event() @@ -386,7 +388,7 @@ def is_group_finished(self, group: PriorityWorkTaskGroupSpec): if self._current_item is not None and self._current_item.spec == group: return False - logger.info(f"Group {group.name} is finished.") + logger.debug(f"Group {group.name} is finished.") return True @@ -443,9 +445,9 @@ def drain_backpressure_to(count): if not item_is_finished: heapq.heappush(self._queue, item) - logger.info("Shutting down PriorityProcessorActor. Waiting for backpressure to drain.") + logger.debug("Shutting down PriorityProcessorActor. Waiting for backpressure to drain.") drain_backpressure_to(0) - logger.info("Backpressure drained. Shutting down PriorityProcessorActor.") + logger.debug("Backpressure drained. Shutting down PriorityProcessorActor.") @dataclass @@ -459,6 +461,7 @@ class ShardGroupToBeProcessed(PriorityWorkTaskGroupSpec): processor_actor: ray.actor.ActorHandle # BatchProcessorQueue batch_size: int num_rows_per_chunk: int + group_id: int def build(self) -> "PriorityWorkTaskGroup": return ShardGroupTaskGroup(self) @@ -467,7 +470,7 @@ def build(self) -> "PriorityWorkTaskGroup": class ShardGroupTaskGroup(PriorityWorkTaskGroup): def __init__(self, spec: ShardGroupToBeProcessed): self.spec = spec - self.logger = pylogging.getLogger(f"shard_reader.{self.spec.name}") + self.logger = pylogging.getLogger(f"shard_reader.{spec.group_id}.{spec.name}") try: metadata: dict[str, ShardMetadata] = _initial_shard_metadatas( @@ -543,6 +546,8 @@ def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: total_chunk_rows = 0 # the total number of rows in the chunk batch_result_ref = None + self.group.logger.debug(f"Reading one chunk of shard {self.shard_name}: {self.chunk_idx}") + try: while not chunk_filled: batch = next(self.reader, None) @@ -555,14 +560,20 @@ def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: if batch: priority = self.spec.priority_fn(self.shard_idx, self.chunk_idx) + # these times aren't exact because the times might be from different machines + # but they're just for logging + time_in = time.time() batch_result_ref = ray.get( - self.spec.processor_actor.submit.remote(priority=priority, batch=RefBox(ray.put(batch))) + self.spec.processor_actor.submit.remote( + priority=priority, + desc=f"{self.shard_name}.{self.chunk_idx}.{chunk_batch_idx}", + batch=RefBox(ray.put(batch)), + ) ) writer.chunk_batch_finished.remote( - self.shard_name, self.chunk_idx, chunk_batch_idx, RefBox(batch_result_ref) + self.shard_name, self.chunk_idx, chunk_batch_idx, RefBox(batch_result_ref), time_in ) chunk_batch_idx += 1 - # enqueue_to_backpressure(batch, batch_result_ref) del batch if total_chunk_rows >= self.spec.num_rows_per_chunk or exhausted_shard: @@ -577,7 +588,9 @@ def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: if exhausted_shard: writer.shard_finished_reading.remote(self.shard_name, self.chunk_idx) - logger.debug(f"Finished reading one chunk of shard {self.shard_name}: {self.chunk_idx} {exhausted_shard}") + self.group.logger.debug( + f"Finished reading one chunk of shard {self.shard_name}: {self.chunk_idx} {exhausted_shard}" + ) return exhausted_shard, batch_result_ref except Exception as e: # noqa @@ -768,14 +781,49 @@ def is_finished_and_buffer_empty(self): return self.expected_num_chunks is not None and self.num_chunks_sent >= self.expected_num_chunks +class WaitTimeReportingThread(threading.Thread): + def __init__(self, report, interval=60): + super().__init__() + self.report = report + self.interval = interval + self.shutdown_event = threading.Event() + + def run(self): + total_waited = 0 + while True: + if self.shutdown_event.wait(self.interval): + break + if total_waited > 0: + self.report(total_waited) + total_waited += self.interval + + def shutdown(self): + self.shutdown_event.set() + + def _mk_queue_aware_process_task(processor: BatchProcessor[T], queue: ActorHandle): @ray.remote(num_cpus=processor.num_cpus, num_gpus=processor.num_gpus, resources=processor.resources) - def process_task(batch: List[T]) -> pa.RecordBatch: - pylogging.basicConfig(level=pylogging.INFO) + def process_task(desc, batch: List[T]) -> pa.RecordBatch: + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) + logger.debug(f"Processing batch {desc}") queue.task_running.remote() - result = processor(batch) - del batch - return as_record_batch(result) + # timer_thread = WaitTimeReportingThread( + # lambda t: logger.info(f"Waiting for {desc} to be processed for {t} seconds"), interval=30 + # ) + # timer_thread.start() + try: + result = processor(batch) + del batch + result = as_record_batch(result) + logger.debug(f"Finished processing batch {desc}") + return result + except Exception as e: + logger.exception(f"Error while processing batch {desc}") + raise e + finally: + # timer_thread.shutdown() + # timer_thread.join() + pass return process_task @@ -783,6 +831,7 @@ def process_task(batch: List[T]) -> pa.RecordBatch: @dataclass(order=True, frozen=True) class _QueueItem: priority: float + desc: str batch: ray.ObjectRef = dataclasses.field(compare=False) task_id: int task_future: asyncio.Future = dataclasses.field(compare=False) @@ -817,13 +866,13 @@ def __init__(self, batch_processor: BatchProcessor[T]): # we don't need/want to dereference the batch, so we wrap it in a RefBox # one virtue of doing things this way is that we can let Ray try to schedule the compute near the data. - async def submit(self, priority: float, batch: RefBox): + async def submit(self, priority: float, desc: str, batch: RefBox): """Returns a future that is set to the *ObjectRef* of the processed batch. The future is "complete" when the task starts, not when it finishes. You then call ray.get on the future's result to get the actual batch.""" task_id = self._next_task_id self._next_task_id += 1 f: asyncio.Future = asyncio.Future() - self.pqueue.put(_QueueItem(priority, batch.ref, task_id, f)) + self.pqueue.put(_QueueItem(priority, desc, batch.ref, task_id, f)) self._maybe_start_task() return await f @@ -832,7 +881,7 @@ def _maybe_start_task(self): self.ready = False item = self.pqueue.get() batch = item.batch - item.task_future.set_result(self._task_processor.remote(batch)) + item.task_future.set_result(self._task_processor.remote(item.desc, batch)) def task_running(self): self.ready = True @@ -844,7 +893,7 @@ def task_running(self): @ray.remote(num_cpus=0.0, scheduling_strategy="SPREAD") # type: ignore class _GroupShardWriterWorker: def __init__(self, parent_ref, cache_dir: str, shard_names: Sequence[str]): - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self.cache_dir = cache_dir self.shard_names = shard_names self.shard_writers: dict[str, _ShardWriterWorker] = { @@ -854,10 +903,40 @@ def __init__(self, parent_ref, cache_dir: str, shard_names: Sequence[str]): def current_metadata(self, shard_name: str): return self.shard_writers[shard_name].current_metadata() - async def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: int, batch: RefBox): + async def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: int, batch: RefBox, time_in): # batch is a pa.RecordBatch ref box try: - batch = await batch.ref + time_mid = time.time() + logger.debug( + f"Received in progress batch {batch_idx} of chunk {chunk_id} of shard {shard_name} in" + f" {time_mid - time_in}" + ) + # do a backoff loop until the batch is actually processed. log if it's been a while + timeout_interval = 20 + total_time_waited = 0 + + while True: + try: + # batch = await asyncio.wait_for(asyncio.shield(batch.ref), timeout_interval) + batch = await batch.ref + break + except asyncio.TimeoutError: + # to keep to round numbers, we log how much we asked for rather than how much we got + total_time_waited += timeout_interval + timeout_interval = min(2 * timeout_interval, 100) + logger.info( + f"Waiting for {shard_name}.{chunk_id}.{batch_idx} to be processed. " + f"Waited {total_time_waited} seconds." + ) + + if logger.isEnabledFor(pylogging.DEBUG): + logger.debug( + f"Received finished {shard_name}.{chunk_id}.{batch_idx} in {(time.time() - time_in):.2f} seconds." + ) + elif total_time_waited > 40: + logger.info( + f"Waited {total_time_waited} seconds for {shard_name}.{chunk_id}.{batch_idx} to be processed." + ) return self.shard_writers[shard_name].chunk_batch_finished(chunk_id, batch_idx, batch) except Exception as e: print(f"Error while processing batch {batch_idx} of chunk {chunk_id} of shard {shard_name}", flush=True) @@ -888,7 +967,7 @@ def __init__( cache_dir: str, shard_name: str, ): - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self.parent_ref = parent_ref self.cache_dir = cache_dir self.shard_name = shard_name @@ -933,13 +1012,9 @@ def chunk_failed(self, chunk_id: int, error: ExceptionInfo): self.parent_ref.shard_failed.remote(self.shard_name, error) def _finished_chunk(self, idx: int, chunk: ChunkMetadata): - if idx < self.metadata_writer.num_chunks: - logger.error(f"Received chunk {idx} for {self.shard_name} but it's already finished") - error = RuntimeError(f"Received chunk {idx} for {self.shard_name} but it's already finished") - self.parent_ref.shard_failed.remote(self.shard_name, ser_exc_info(error)) - raise error - - if self._expected_num_chunks is not None and idx >= self._expected_num_chunks: + if (idx < self.metadata_writer.num_chunks) or ( + self._expected_num_chunks is not None and idx >= self._expected_num_chunks + ): logger.error(f"Received chunk {idx} for {self.shard_name} but it's already finished") error = RuntimeError(f"Received chunk {idx} for {self.shard_name} but it's already finished") self.parent_ref.shard_failed.remote(self.shard_name, ser_exc_info(error)) @@ -960,6 +1035,8 @@ def _attempt_to_commit_chunks(self): chunks_committed = [] while len(self.uncommited_chunks) > 0 and self.uncommited_chunks[0][0] == self.metadata_writer.num_chunks: _, chunk = heapq.heappop(self.uncommited_chunks) + chunk_number = self.metadata_writer.num_chunks + logger.debug(f"Committing chunk {chunk.name} of shard {self.shard_name}. It is chunk {chunk_number}") self.metadata_writer.commit_chunk(chunk) chunks_committed.append(chunk) @@ -996,6 +1073,7 @@ def __init__(self, cache_dir: str, shard_name: str): self.chunk_writers: dict[int, _ChunkWriter] = {} # chunk index -> writer self.batch_counts: dict[int, int] = {} # chunk index -> number of batches written self.expected_totals: dict[int, int] = {} # chunk index -> expected num batches. + self.failed_chunks: dict[int, ExceptionInfo] = {} # chunk index -> error self.chunk_partial_batches: dict[ int, list[tuple[int, pa.RecordBatch]] ] = {} # chunk index -> heapq of (batch index, batch) @@ -1014,21 +1092,29 @@ def chunk_finished_reading(self, chunk_id, expected_num_batches) -> Optional[Chu return self._attempt_to_write_chunk_fragments(chunk_id) def chunk_failed(self, chunk_id, error: ExceptionInfo): + self.failed_chunks[chunk_id] = error if chunk_id in self.chunk_writers: self.chunk_writers[chunk_id].__exit__(*error.restore()) del self.chunk_writers[chunk_id] def _attempt_to_write_chunk_fragments(self, chunk_id) -> Optional[ChunkMetadata]: + if chunk_id in self.failed_chunks: + logger.error(f"Chunk {chunk_id} of shard {self.shard_name} already failed, not writing more") + raise self.failed_chunks[chunk_id].restore() if chunk_id in self.chunk_partial_batches: chunk_batches = self.chunk_partial_batches[chunk_id] - while len(chunk_batches) > 0 and chunk_batches[0][0] == self.batch_counts[chunk_id]: + while len(chunk_batches) > 0: + batch_id, batch = chunk_batches[0] + if batch_id != self.batch_counts[chunk_id]: + break + # we can write this batch batch_id, batch = heapq.heappop(chunk_batches) if chunk_id not in self.chunk_writers: - assert batch_id == 0 + assert batch_id == 0, f"Expected batch 0 but got {batch_id}" chunk_name = os.path.join(self.shard_name, f"chunk-{chunk_id}") writer = _ChunkWriter(self.cache_dir, chunk_name, batch.schema) writer.__enter__() @@ -1073,7 +1159,7 @@ def __init__( processor: BatchProcessor[T], rows_per_chunk: int, ): - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self.logger = pylogging.getLogger(f"{__name__}.{name}") self.broker_ref = broker_ref self.shard_status: Dict[str, _ShardStatus] = dict() @@ -1094,20 +1180,25 @@ def __init__( self._processor_actors = [] for shard_name in source.shard_names: + self._current_round_robin.append(shard_name) self.shard_status[shard_name] = _ShardStatus() num_shards = len(source.shard_names) + num_worker_groups = len(ray.nodes()) + num_shard_groups = max(min(num_worker_groups, num_shards), 1) - def priority_fn(shard_idx, chunk_idx): - return chunk_idx * num_shards + shard_idx - - num_shard_groups = max(min(len(ray.nodes()), num_shards), 1) + # if we have a bunch of caches to build with one shard, we don't want them all + # assigned to the same node, so we use an offset based on the hash of the name (for stability) + # in an attempt to spread them out + group_offset = int(hash(name) % num_worker_groups) shard_groups: list[list[str]] = [[] for _ in range(num_shard_groups)] for i, shard_name in enumerate(source.shard_names): - self._current_round_robin.append(shard_name) shard_groups[i % num_shard_groups].append(shard_name) + def priority_fn(shard_idx, chunk_idx): + return chunk_idx * num_shards + shard_idx + for group_id, shard_group in enumerate(shard_groups): writer = _GroupShardWriterWorker.remote(self_ref, cache_dir, shard_group) # type: ignore self._shard_writers.append(writer) @@ -1126,10 +1217,12 @@ def priority_fn(shard_idx, chunk_idx): processor_actor=processor_actor, batch_size=processor.batch_size, num_rows_per_chunk=rows_per_chunk, + group_id=group_id, ) # we want global names so that different tasks can coordinate priorities - priority_actor_name = f"priority_processor.{group_id}" + worker_to_assign = (group_id + group_offset) % num_worker_groups + priority_actor_name = f"priority_processor.{worker_to_assign}" reader_actor = PriorityProcessorActor.options( # type: ignore name=priority_actor_name, get_if_exists=True @@ -1227,14 +1320,23 @@ def _attempt_to_flush_buffers(self): next_chunk = status.pop_chunk_to_send() if next_chunk is not None: # we can send a chunk from this shard - logger.debug(f"Sending chunk from {name}") self._current_round_robin.pop(0) self._current_round_robin.append(name) chunks_to_send.append(next_chunk) continue else: - logger.debug(f"Shard {name} has no chunks to send and is not known to be finished") # we can't send a chunk from this shard, so we can't send any additional chunks + if self.logger.level <= pylogging.DEBUG: + chunks_waiting = [ + f"{n2} ({len(s2.current_buffer)})" + for n2, s2 in self.shard_status.items() + if len(s2.current_buffer) > 0 + ] + msg = ( + f"Shard {name} has no chunks to send and is not known to be finished. We have this many queued" + f" chunks: {chunks_waiting}" + ) + self.logger.debug(msg) break if len(chunks_to_send) > 0: @@ -1256,7 +1358,7 @@ class ChunkCacheBroker: _finished_promise: asyncio.Future[None] def __init__(self, cache_dir: str, source: ShardedDataset[T], processor: BatchProcessor[T], rows_per_chunk: int): - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self.chunks = [] self._reader_promises = {} self._is_finished = False @@ -1268,6 +1370,9 @@ def __init__(self, cache_dir: str, source: ShardedDataset[T], processor: BatchPr # used to subscribe to metrics updates self._latest_metrics = InProgressCacheMetrics() self._metrics_condition = asyncio.Condition() + path_for_name = os.path.join(*self._cache_dir.split("/")[-2:]) + name = f"broker::{path_for_name}" + self.logger = pylogging.getLogger(f"{name}") # initialize writer task # first see if we need to do anything: check the ledger for is_finished @@ -1279,7 +1384,6 @@ def __init__(self, cache_dir: str, source: ShardedDataset[T], processor: BatchPr except FileNotFoundError: self_ref = ray.runtime_context.get_runtime_context().current_actor # only use the last two components of the name since it gets kind of long - path_for_name = os.path.join(*self._cache_dir.split("/")[-2:]) name = f"builder::{path_for_name}" self._builder_actor = ChunkCacheBuilder.remote(self_ref, self._cache_dir, name, self._source, self._processor, self._rows_per_chunk) # type: ignore @@ -1306,7 +1410,6 @@ async def get_chunk(self, chunk_idx: int) -> Optional[ChunkMetadata]: elif self._is_finished: return None else: - # we don't have this chunk yet, so we need to wait if chunk_idx not in self._reader_promises: self._reader_promises[chunk_idx] = asyncio.Future() return await self._reader_promises[chunk_idx] @@ -1321,7 +1424,9 @@ def _append_chunks(self, *chunks: ChunkMetadata): for chunk in chunks: self.chunks.append(chunk) chunk_idx = len(self.chunks) - 1 + self.logger.debug(f"Received chunk {chunk_idx}") if chunk_idx in self._reader_promises: + self.logger.debug(f"Resolving promise for chunk {chunk_idx}") self._reader_promises[chunk_idx].set_result(chunk) del self._reader_promises[chunk_idx] @@ -1358,7 +1463,9 @@ def _finalize(self): _serialize_json_and_commit(os.path.join(self._cache_dir, LEDGER_FILE_NAME), CacheLedger(self.chunks)) self._reader_promises = {} - self._builder_actor = None + # TODO: For some reason this crashes other actors with weird reference counting assertion errors. + # pretty sure it's a ray bug + # self._builder_actor = None self._finished_promise.set_result(None) # notify metrics subscribers @@ -1527,17 +1634,19 @@ def _get_chunk_unmapped(self, mapped_index: int, *, timeout: Optional[float] = N else: assert self._broker is not None time_in = time.time() + next_time = time_in # we want to also log if we're waiting for a long time, so we do this in a loop - while timeout is None or time.time() - time_in < timeout: + while timeout is None or next_time - time_in < timeout: current_timeout = 20.0 if timeout is not None: - current_timeout = min(current_timeout, timeout - (time.time() - time_in)) + current_timeout = min(current_timeout, timeout - (next_time - time_in)) try: chunk = ray.get(self._broker.get_chunk.remote(mapped_index), timeout=current_timeout) except GetTimeoutError: - self.logger.warning(f"Waiting for chunk {mapped_index} for {int(time.time() - time_in)} seconds") + self.logger.warning(f"Waiting for chunk {mapped_index} for {int(next_time - time_in)} seconds") + next_time = time.time() current_timeout *= 2 - current_timeout = min(current_timeout, 80) + current_timeout = min(current_timeout, 100) continue if chunk is None: diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 212318433..00e17eb58 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -11,12 +11,15 @@ import braceexpand import datasets +import equinox as eqx import fsspec import jax import numpy as np import pyarrow as pa +import regex from draccus import field from jaxtyping import PRNGKeyArray +from tokenizers import normalizers import haliax as hax from haliax import Axis @@ -91,29 +94,32 @@ def shard(self, shard_id: int, num_shards: int) -> "CausalLmDataset": def __iter__(self) -> Iterator[LmExample]: key = self.key - for tokens in self.dataset: - with use_cpu_device(): - example = self._create_lm_example(tokens, key) - yield example + sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0]) + + with use_cpu_device(): - @functools.partial(jax.jit, static_argnums=(0)) - def _create_lm_example(self, tokens, key): - tokens = hax.named(tokens, self.QPos) + @functools.partial(eqx.filter_jit, out_shardings=sharding) + def _create_lm_example(tokens, key): + tokens = hax.named(tokens, self.QPos) - example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) + example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) - if self.fcm_prob > 0: - # masks for attention - # We support forgetful causal masking (FCM) which is a technique that improves training speed by - # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention - # mask instead of the activations. It's described in https://arxiv.org/abs/2210.13432 - assert self.key is not None - this_key, key = jax.random.split(key) - fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key) - attn_mask = example.attn_mask & AttentionMask.explicit(fcm_mask) - example = dataclasses.replace(example, attn_mask=attn_mask) + if self.fcm_prob > 0: + # masks for attention + # We support forgetful causal masking (FCM) which is a technique that improves training speed by + # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention + # mask instead of the activations. It's described in https://arxiv.org/abs/2210.13432 + assert self.key is not None + this_key, key = jax.random.split(key) + fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key) + attn_mask = example.attn_mask & AttentionMask.explicit(fcm_mask) + example = dataclasses.replace(example, attn_mask=attn_mask) - return example + return example + + for tokens in self.dataset: + example = _create_lm_example(tokens, key) + yield example class TokenSeqDataset(ShardableDataset[np.ndarray]): @@ -306,13 +312,27 @@ def _maybe_force_tokenizer_parallelism(tokenizer: PreTrainedTokenizerBase): os.environ["TOKENIZERS_PARALLELISM"] = "true" +LONG_STRING_WORKAROUND = 100_000 + + +ws = regex.compile(r"\s") + + class BatchTokenizer(BatchProcessor[str]): """ A batch processor that tokenizes a batch of strings using a tokenizer. By default, this will append eos to the end of the string, even if the tokenizer doesn't. """ - def __init__(self, tokenizer: PreTrainedTokenizerBase, enforce_eos=True, override_resources=None): + def __init__( + self, + tokenizer: PreTrainedTokenizerBase, + enforce_eos=True, + *, + batch_size=128, + override_resources=None, + _workaround_len=LONG_STRING_WORKAROUND, + ): _maybe_force_tokenizer_parallelism(tokenizer) self.tokenizer = tokenizer self.override_resources = override_resources @@ -326,17 +346,104 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, enforce_eos=True, overrid else: should_append_eos = False + self._batch_size = batch_size + self._need_to_add_eos = should_append_eos + self._workaround_len = _workaround_len def __call__(self, batch: Sequence[str]) -> BatchEncoding: + orig_lengths = [len(d) for d in batch] if self._need_to_add_eos: - encoding = self.tokenizer( - [d + " " + self.tokenizer.eos_token for d in batch], return_attention_mask=False, verbose=False - ) + batch = [d + " " + self.tokenizer.eos_token for d in batch] + + if self._needs_long_sequence_workaround: + # break any strings that are longer than 50K characters into smaller chunks + orig_batch = batch + batch = [] + needs_merge = [] + for i, d in enumerate(orig_batch): + needs_merge.append(False) + orig_len = orig_lengths[i] + while len(d) > self._workaround_len: + # we'd rather break strings at whitespace, so find the first whitespace + match = ws.search(d, self._workaround_len) + # this is vanishingly unlikely, but if we can't find a whitespace, just break it at the limit + if match is None: + split = len(d) + else: + split = match.start() + + batch.append(d[:split]) + needs_merge.append(True) + + d = d[split:] + orig_len -= split + + batch.append(d) else: - encoding = self.tokenizer(batch, return_attention_mask=False, verbose=False) # type: ignore + needs_merge = [] + + encoding = self.tokenizer(batch, return_attention_mask=False, verbose=False) # type: ignore + + if needs_merge: + new_encoding = self._merge_split_encodings(batch, encoding, needs_merge) + encoding = BatchEncoding(new_encoding) + return encoding + @staticmethod + def _merge_split_encodings(batch, encoding, needs_merge): + # merge the encodings back together + # we might need to merge multiple encodings together + # needs merge marks the first n-1 encodings that need to be merged for each document + new_encoding = {} + for k, v in encoding.items(): + if len(v) == 0: + continue + if isinstance(v[0], np.ndarray): + assert len(v) == len(batch) + v_out = [] + vs_to_merge = [] + for i in range(len(batch)): + if not needs_merge[i]: + v_out.append(np.concatenate(vs_to_merge)) + vs_to_merge = [] + vs_to_merge.append(v[i]) + + if len(vs_to_merge) > 0: + v_out.append(np.concatenate(vs_to_merge)) + + new_encoding[k] = v_out + elif isinstance(v[0], list): + v_out = [] + vs_to_merge = [] + for i in range(len(batch)): + if not needs_merge[i]: + if len(vs_to_merge) > 0: + v_out.append(list(chain(*vs_to_merge))) + vs_to_merge = [] + vs_to_merge.append(v[i]) + + if len(vs_to_merge) > 0: + v_out.append(list(chain(*vs_to_merge))) + new_encoding[k] = v_out + else: + raise ValueError(f"Unknown type {type(v[0])}") + return new_encoding + + # TODO remove this when it's resolved https://github.com/huggingface/tokenizers/issues/1449 + @cached_property + def _needs_long_sequence_workaround(self): + if isinstance(self.tokenizer, PreTrainedTokenizerFast): + normalizer = self.tokenizer.backend_tokenizer.normalizer + if normalizer is None: + return False + # if there's a "Replace" normalizer, then we need to do the workaround + # inexplicably there's no way to see inside a Sequence so we also have to assume it needs it + return isinstance(normalizer, (normalizers.Replace, normalizers.Sequence)) + else: + return False + @property def num_cpus(self) -> int: if self.override_resources is not None: @@ -353,7 +460,7 @@ def num_gpus(self) -> int: @property def batch_size(self) -> int: - return 1024 + return self._batch_size def concatenate_and_group_texts( diff --git a/src/levanter/distributed.py b/src/levanter/distributed.py index c0442b45e..eefb71fc4 100644 --- a/src/levanter/distributed.py +++ b/src/levanter/distributed.py @@ -250,8 +250,17 @@ def _munge_address_port(address: str): logger.info(f"Successfully started ray worker and connected to {address}.") logger.info(f"ray.init(address={repr(address)}, namespace={repr(namespace)}, **{repr(kwargs)})") - # Ray has retry logic, so we don't need to retry here :fingers-crossed: - ray.init(address=address, namespace=namespace, **kwargs) + # Ray has retry logic, but it doesn't seem to work super well, so we retry manually + for i in range(0, 5): + try: + ray.init(address=address, namespace=namespace, **kwargs) + break + except Exception as e: + if i == 4: + raise e + else: + logger.warning(f"Failed to initialize ray with address {address}. Retrying...") + continue atexit.register(lambda: ray.shutdown()) _already_initialized = True diff --git a/tests/test_text.py b/tests/test_text.py index 26fe98aa5..70b2d26a7 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -1,17 +1,18 @@ import tempfile import jax.numpy as jnp +from transformers import AutoTokenizer import haliax as hax -from levanter.data.text import LMDatasetConfig +from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.models.lm_model import LmExample from levanter.models.loss import next_token_loss +from test_utils import skip_if_hf_model_not_accessible def test_dont_blow_up_without_validation_set(): with tempfile.TemporaryDirectory() as tmpdir: - config = LMDatasetConfig( train_urls=["kaa"], validation_urls=[], @@ -40,3 +41,29 @@ def test_lm_example_handles_ignore_id(): no_ignore_loss = next_token_loss(Pos, Vocab, distr, tokens, loss_mask=ex_no_ignore.loss_mask) assert no_ignore_loss.item() >= ignored_loss.item() + 100 / Pos.size + + +def test_merge_split_encodings(): + tokenizer = AutoTokenizer.from_pretrained("gpt2") + # make this very short for testing + + lorem = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.""" + + short_batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=len(lorem) // 3) + # force this + short_batch_tokenizer._needs_long_sequence_workaround = True + + batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=50000) + batch = [lorem] + + short_out = short_batch_tokenizer(batch) + reg_out = batch_tokenizer(batch) + + assert short_out == reg_out + + +@skip_if_hf_model_not_accessible("meta-llama/Llama-2-7b-hf") +def test_llama_tokenizer_needs_long_sequence_workaround(): + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + batch_tokenizer = BatchTokenizer(tokenizer) + assert batch_tokenizer._needs_long_sequence_workaround From 24c13097a4558da3c2bc2e8902e599eb2942a4e4 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 7 Feb 2024 10:58:28 -0800 Subject: [PATCH 15/19] Optim package (#457) --- docs/Configuration-Guide.md | 6 +- examples/alpaca/alpaca.py | 3 +- examples/gsm8k-lora/gsm8k_lora.py | 3 +- src/levanter/main/lora_lm.py | 5 +- src/levanter/main/train_lm.py | 5 +- src/levanter/optim/__init__.py | 1 + src/levanter/optim/config.py | 160 ++++++++++++++++++++++++++++++ src/levanter/trainer.py | 139 +------------------------- tests/test_hf_gpt2_serialize.py | 4 +- tests/test_weight_decay_mask.py | 6 +- 10 files changed, 183 insertions(+), 149 deletions(-) create mode 100644 src/levanter/optim/__init__.py create mode 100644 src/levanter/optim/config.py diff --git a/docs/Configuration-Guide.md b/docs/Configuration-Guide.md index 7927da154..bbc7d7cb5 100644 --- a/docs/Configuration-Guide.md +++ b/docs/Configuration-Guide.md @@ -13,7 +13,7 @@ class TrainLmConfig: data: LMDatasetConfig = field(default_factory=LMDatasetConfig) trainer: TrainerConfig = field(default_factory=TrainerConfig) model: LmConfig = field(default_factory=Gpt2Config) - optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + optimizer: OptimizerConfig = field(default_factory=AdamConfig) ``` Your training run will typically be associated with a single config file. For instance, you might have a file @@ -238,7 +238,7 @@ If you're not using SLURM or TPUs, you can specify the cluster manually using th ## Optimizer -[levanter.trainer.OptimizerConfig][] is a dataclass that specifies the optimizer configuration. It has the following fields: +[levanter.optim.OptimizerConfig][] is a dataclass that specifies the optimizer configuration. It has the following fields: | Parameter | Description | Default | |-----------------|-------------------------------------------------------------------|----------| @@ -288,7 +288,7 @@ We won't go into detail here. You can see the auto-generated docs below. ### Optimizer -::: levanter.trainer.OptimizerConfig +::: levanter.optim.OptimizerConfig ### LM Model diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index e02b0738a..36a6dd943 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -18,7 +18,8 @@ from levanter.data import Dataset from levanter.data.sharded_dataset import JsonDataset, JsonlDataset, WrappedHFDataset from levanter.models.lm_model import LmExample, LmHeadModel -from levanter.trainer import OptimizerConfig, Trainer, TrainerConfig +from levanter.optim import OptimizerConfig +from levanter.trainer import Trainer, TrainerConfig from levanter.utils import fsspec_utils from levanter.utils.hf_utils import num_cpus_used_by_tokenizer from levanter.utils.py_utils import non_caching_cycle diff --git a/examples/gsm8k-lora/gsm8k_lora.py b/examples/gsm8k-lora/gsm8k_lora.py index 6b369bf77..5e4927d2f 100644 --- a/examples/gsm8k-lora/gsm8k_lora.py +++ b/examples/gsm8k-lora/gsm8k_lora.py @@ -25,7 +25,8 @@ save_peft_checkpoint_callback, ) from levanter.models.lm_model import LmExample, LmHeadModel -from levanter.trainer import OptimizerConfig, Trainer, TrainerConfig +from levanter.optim import OptimizerConfig +from levanter.trainer import Trainer, TrainerConfig from levanter.utils.hf_utils import num_cpus_used_by_tokenizer from levanter.utils.jax_utils import parameter_count from levanter.utils.py_utils import non_caching_cycle diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index d19c80943..93d60588a 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -19,7 +19,8 @@ save_merged_hf_checkpoint_callback, save_peft_checkpoint_callback, ) -from levanter.trainer import OptimizerConfig, Trainer, TrainerConfig +from levanter.optim import AdamConfig, OptimizerConfig +from levanter.trainer import Trainer, TrainerConfig from levanter.utils.jax_utils import parameter_count from levanter.utils.py_utils import non_caching_cycle @@ -33,7 +34,7 @@ class LoraLmConfig: lora: LoraConfig = field(default_factory=LoraConfig) data: LMDatasetConfig = field(default_factory=LMDatasetConfig) trainer: TrainerConfig = field(default_factory=TrainerConfig) - optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + optimizer: OptimizerConfig = field(default_factory=AdamConfig) peft_save_path: Optional[str] = None # path to save peft-compatible checkpoints peft_hf_upload: Optional[str] = None diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 982f72358..f5b6e83b4 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -17,7 +17,8 @@ from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig from levanter.models.gpt2 import Gpt2Config from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel -from levanter.trainer import OptimizerConfig, Trainer, TrainerConfig +from levanter.optim import AdamConfig, OptimizerConfig +from levanter.trainer import Trainer, TrainerConfig from levanter.utils.jax_utils import parameter_count @@ -29,7 +30,7 @@ class TrainLmConfig: data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig) trainer: TrainerConfig = field(default_factory=TrainerConfig) model: LmConfig = field(default_factory=Gpt2Config) - optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + optimizer: OptimizerConfig = field(default_factory=AdamConfig) # config related to continued pretraining initialize_from_hf: Union[bool, str] = False diff --git a/src/levanter/optim/__init__.py b/src/levanter/optim/__init__.py new file mode 100644 index 000000000..749c4e642 --- /dev/null +++ b/src/levanter/optim/__init__.py @@ -0,0 +1 @@ +from .config import AdamConfig, OptimizerConfig diff --git a/src/levanter/optim/config.py b/src/levanter/optim/config.py new file mode 100644 index 000000000..ad8708c6a --- /dev/null +++ b/src/levanter/optim/config.py @@ -0,0 +1,160 @@ +import abc +import re +import warnings +from dataclasses import dataclass +from typing import Optional + +import draccus +import equinox as eqx +import jax +import optax +from jax import numpy as jnp + +from levanter.utils.jax_utils import leaf_key_paths + + +@dataclass +class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC): + learning_rate: float = 6e-4 + weight_decay: float = 0.0 + + min_lr_ratio: float = 0.1 + warmup_ratio: Optional[float] = None # Deprecated. fraction of training steps to use as warmup + warmup: float = 0.01 + """fraction of training steps to use as warmup, or steps to use. 0.0 means no warmup""" + cooldown: float = 0.0 + """fraction of training steps to use as cooldown, or steps to use. 0.0 means no cooldown""" + lr_schedule: str = "cosine" # constant, cosine, linear + weight_decay_modules: Optional[list[str] | str] = None + """A regex or a list of strings to identify where to mask weight. + For nano-GPT, this field can be set as `r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings"`""" + + @classmethod + def default_choice_name(cls) -> Optional[str]: + return "adam" + + @abc.abstractmethod + def build(self, num_train_steps: int): + raise NotImplementedError + + def build_weight_decay_mask(self): + if self.weight_decay_modules is None: + return None + else: + # mask based on regex or module path + def _apply_on(x, key_path): + if isinstance(self.weight_decay_modules, str): + compiled_regex = re.compile(self.weight_decay_modules) + return compiled_regex.match(key_path) is not None + else: + return any(key_path.__contains__(target) for target in self.weight_decay_modules) + + def mask_fn(model): + return jax.tree_util.tree_map( + _apply_on, + model, + leaf_key_paths(model, is_leaf=eqx.is_array), + is_leaf=eqx.is_array, + ) + + return mask_fn + + def lr_scheduler(self, num_train_steps): + warmup_steps = self._convert_warmup(num_train_steps) + cooldown_steps = _convert_ratio_or_steps(self.cooldown, num_train_steps) + lr_decay_steps = num_train_steps - warmup_steps - cooldown_steps + min_lr = self.learning_rate * self.min_lr_ratio + + match self.lr_schedule: + case "constant": + schedule = optax.constant_schedule(self.learning_rate) + case "cosine": + schedule = optax.cosine_decay_schedule(self.learning_rate, lr_decay_steps, self.min_lr_ratio) + case "linear": + schedule = optax.linear_schedule(self.learning_rate, min_lr, lr_decay_steps - warmup_steps) + case "inv_sqrt": + schedule = _inv_sqrt_decay_schedule(self.learning_rate, min_lr, warmup_steps, 10000) + case _: + raise ValueError(f"Unknown lr_schedule: {self.lr_schedule}") + + schedules = [] + boundaries = [] + + if warmup_steps != 0: + warmup = optax.linear_schedule(0.0, self.learning_rate, warmup_steps) + schedules.append(warmup) + boundaries.append(warmup_steps) + + schedules.append(schedule) + + if cooldown_steps != 0: + final_main_lr = schedule(lr_decay_steps) + cooldown = optax.linear_schedule(final_main_lr, min_lr, cooldown_steps) + schedules.append(cooldown) + boundaries.append(num_train_steps - cooldown_steps) + + if len(schedules) > 1: + schedule = optax.join_schedules(schedules, boundaries) + + return schedule + + def _convert_warmup(self, num_train_steps: int): + if self.warmup_ratio is not None: + warnings.warn("warmup_ratio is deprecated. Use warmup instead") + return int(self.warmup_ratio * num_train_steps) + else: + return _convert_ratio_or_steps(self.warmup, num_train_steps) + + +def _inv_sqrt_decay_schedule(lr: float, min_lr: float, warmup_steps: int, timescale: float = 10000): + def schedule(count): + decay = jnp.minimum(1.0, 1.0 / jnp.sqrt(jnp.maximum(count + warmup_steps, 1) / timescale)) + return jnp.maximum(lr * decay, min_lr) + + return schedule + + +def _convert_ratio_or_steps(ratio_or_steps: float, num_train_steps: int): + if ratio_or_steps < 1.0: + return int(ratio_or_steps * num_train_steps) + else: + return int(ratio_or_steps) + + +@dataclass +class HessianOptConfig(OptimizerConfig, abc.ABC): + update_interval: int = 10 + """How often to update the hessian approximation.""" + + +@OptimizerConfig.register_subclass("adam") +@dataclass +class AdamConfig(OptimizerConfig): + weight_decay: float = 0.1 + beta1: float = 0.9 + beta2: float = 0.999 + epsilon: float = 1e-8 + max_grad_norm: Optional[float] = 1.0 + + def build(self, num_train_steps): + """Creates the optimizer""" + # indirection makes it work with optax.inject_hyperparams so we can log the learning rate + def _optimizer(learning_rate): + components = [] + + if self.max_grad_norm: + components.append(optax.clip_by_global_norm(self.max_grad_norm)) + + components.append(optax.scale_by_adam(self.beta1, self.beta2, self.epsilon)) + + if self.weight_decay > 0: + components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask())) + + # - learning rate for descent + components.append(optax.scale(-learning_rate)) + + optimizer = optax.chain(*components) + + return optimizer + + return optax.inject_hyperparams(_optimizer)(learning_rate=self.lr_scheduler(num_train_steps)) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 5fe6e0302..e2083a750 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -3,10 +3,8 @@ import functools import logging as pylogging import os -import re import sys import typing -import warnings from dataclasses import dataclass from functools import cached_property from pathlib import Path @@ -14,10 +12,8 @@ import equinox as eqx import jax -import jax.numpy as jnp import jmp import numpy as np -import optax import wandb from draccus import field from jax import ShapeDtypeStruct @@ -38,9 +34,12 @@ from levanter.distributed import DistributedConfig, RayConfig from levanter.grad_accum import microbatched from levanter.logging import WandbConfig, capture_time + +# import for backward compatibility +from levanter.optim import OptimizerConfig # noqa: F401 from levanter.types import FilterSpec from levanter.utils import cloud_utils -from levanter.utils.jax_utils import is_inexact_arrayish, leaf_key_paths +from levanter.utils.jax_utils import is_inexact_arrayish from levanter.utils.tree_utils import inference_mode @@ -705,135 +704,5 @@ def _validate_and_set_defaults(self): self.per_device_eval_parallelism = self.per_device_parallelism -@dataclass -class OptimizerConfig: - # Config related to optimizer (always adam for now) - learning_rate: float = 6e-4 - weight_decay: float = 0.0 - beta1: float = 0.9 - beta2: float = 0.999 - epsilon: float = 1e-8 - max_grad_norm: Optional[float] = 1.0 - - min_lr_ratio: float = 0.1 - warmup_ratio: Optional[float] = None # Deprecated. fraction of training steps to use as warmup - warmup: float = 0.01 - """fraction of training steps to use as warmup, or steps to use. 0.0 means no warmup""" - cooldown: float = 0.0 - """fraction of training steps to use as cooldown, or steps to use. 0.0 means no cooldown""" - lr_schedule: str = "cosine" # constant, cosine, linear - """a regex or a list of strings to identify where to mask weight. """ - """For nano-GPT, this field can be set as - `r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings"`""" - weight_decay_modules: Optional[Union[List[str], str]] = None - - def build(self, num_train_steps: int) -> GradientTransformation: - """Creates the optimizer""" - - # indirection makes it work with optax.inject_hyperparams so we can log the learning rate - def _optimizer(learning_rate): - components = [] - - if self.max_grad_norm: - components.append(optax.clip_by_global_norm(self.max_grad_norm)) - - components.append(optax.scale_by_adam(self.beta1, self.beta2, self.epsilon)) - - if self.weight_decay > 0: - components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask())) - - # - learning rate for descent - components.append(optax.scale(-learning_rate)) - - optimizer = optax.chain(*components) - - return optimizer - - return optax.inject_hyperparams(_optimizer)(learning_rate=self.lr_scheduler(num_train_steps)) - - def build_weight_decay_mask(self): - if self.weight_decay_modules is None: - return None - else: - # mask based on regex or module path - def _apply_on(x, key_path): - if isinstance(self.weight_decay_modules, str): - compiled_regex = re.compile(self.weight_decay_modules) - return compiled_regex.match(key_path) is not None - else: - return any(key_path.__contains__(target) for target in self.weight_decay_modules) - - def mask_fn(model): - return jax.tree_util.tree_map( - _apply_on, - model, - leaf_key_paths(model, is_leaf=eqx.is_array), - is_leaf=eqx.is_array, - ) - - return mask_fn - - def lr_scheduler(self, num_train_steps): - warmup_steps = self._convert_warmup(num_train_steps) - cooldown_steps = _convert_ratio_or_steps(self.cooldown, num_train_steps) - lr_decay_steps = num_train_steps - warmup_steps - cooldown_steps - min_lr = self.learning_rate * self.min_lr_ratio - - match self.lr_schedule: - case "constant": - schedule = optax.constant_schedule(self.learning_rate) - case "cosine": - schedule = optax.cosine_decay_schedule(self.learning_rate, lr_decay_steps, self.min_lr_ratio) - case "linear": - schedule = optax.linear_schedule(self.learning_rate, min_lr, lr_decay_steps - warmup_steps) - case "inv_sqrt": - schedule = _inv_sqrt_decay_schedule(self.learning_rate, min_lr, warmup_steps, 10000) - case _: - raise ValueError(f"Unknown lr_schedule: {self.lr_schedule}") - - schedules = [] - boundaries = [] - - if warmup_steps != 0: - warmup = optax.linear_schedule(0.0, self.learning_rate, warmup_steps) - schedules.append(warmup) - boundaries.append(warmup_steps) - - schedules.append(schedule) - - if cooldown_steps != 0: - final_main_lr = schedule(lr_decay_steps) - cooldown = optax.linear_schedule(final_main_lr, min_lr, cooldown_steps) - schedules.append(cooldown) - boundaries.append(num_train_steps - cooldown_steps) - - if len(schedules) > 1: - schedule = optax.join_schedules(schedules, boundaries) - - return schedule - - def _convert_warmup(self, num_train_steps: int): - if self.warmup_ratio is not None: - warnings.warn("warmup_ratio is deprecated. Use warmup instead") - return int(self.warmup_ratio * num_train_steps) - else: - return _convert_ratio_or_steps(self.warmup, num_train_steps) - - -def _inv_sqrt_decay_schedule(lr: float, min_lr: float, warmup_steps: int, timescale: float = 10000): - def schedule(count): - decay = jnp.minimum(1.0, 1.0 / jnp.sqrt(jnp.maximum(count + warmup_steps, 1) / timescale)) - return jnp.maximum(lr * decay, min_lr) - - return schedule - - def _params_only(t): return eqx.filter(t, is_inexact_arrayish) - - -def _convert_ratio_or_steps(ratio_or_steps: float, num_train_steps: int): - if ratio_or_steps < 1.0: - return int(ratio_or_steps * num_train_steps) - else: - return int(ratio_or_steps) diff --git a/tests/test_hf_gpt2_serialize.py b/tests/test_hf_gpt2_serialize.py index c9c2bbc61..34c4bb941 100644 --- a/tests/test_hf_gpt2_serialize.py +++ b/tests/test_hf_gpt2_serialize.py @@ -17,7 +17,7 @@ from levanter.compat.hf_checkpoints import HFCheckpointConverter, RepoRef from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel from levanter.models.loss import next_token_loss -from levanter.trainer import OptimizerConfig +from levanter.optim import AdamConfig from levanter.utils.tree_utils import inference_mode from test_utils import skip_if_no_torch @@ -142,7 +142,7 @@ def compute_loss(model, input_ids): assert onp.isclose(jax_g, torch_g.detach().cpu().numpy(), rtol=1e-2, atol=1e-2).all(), f"{jax_g} != {torch_g}" # now we also want to check that the optimizers do similar things - optimizer_config = OptimizerConfig(weight_decay=0.0, learning_rate=1e-3, warmup_ratio=0.0, lr_schedule="constant") + optimizer_config = AdamConfig(weight_decay=0.0, learning_rate=1e-3, warmup_ratio=0.0, lr_schedule="constant") if optimizer_config.max_grad_norm is not None: torch.nn.utils.clip_grad_norm_(torch_model.parameters(), optimizer_config.max_grad_norm) diff --git a/tests/test_weight_decay_mask.py b/tests/test_weight_decay_mask.py index 52834e679..cc94c5749 100644 --- a/tests/test_weight_decay_mask.py +++ b/tests/test_weight_decay_mask.py @@ -5,7 +5,7 @@ import haliax as hax from levanter.models.gpt2 import Gpt2Config -from levanter.trainer import OptimizerConfig +from levanter.optim import AdamConfig def test_weight_decay_masking(): @@ -43,7 +43,7 @@ def apply_weight_decay(tree): gpt_config = Gpt2Config() Vocab = hax.Axis("vocab", 100) model = gpt_config.build(Vocab, key=jrandom.PRNGKey(0)) - string_list_config = OptimizerConfig( + string_list_config = AdamConfig( weight_decay_modules=[ "attn.c_attn.weight", "attn.c_proj.weight", @@ -53,7 +53,7 @@ def apply_weight_decay(tree): "position_embeddings.weight", ] ) - regex_config = OptimizerConfig( + regex_config = AdamConfig( weight_decay_modules=r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings", ) # masking using `equinox.tree_at` From 7078f80880884d332379589a9ac80dabf2e2c49a Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 9 Feb 2024 12:21:18 -0800 Subject: [PATCH 16/19] Sophia (#458) --- .flake8 | 2 +- .github/workflows/run_entry_tests.yaml | 2 - .github/workflows/run_tests.yaml | 4 +- README.md | 7 +- config/gpt2_large_sophia_h.yaml | 21 ++ config/gpt2_small_fast_sophia_h.yaml | 24 ++ config/gpt2_small_fast_sophiah.yaml | 26 ++ config/gpt2_small_pile.yaml | 23 ++ config/gpt2_small_sophiah.yaml | 19 + config/optim/sophia-h_large.yaml | 7 + config/optim/sophia-h_medium.yaml | 7 + config/optim/sophia-h_small.yaml | 7 + config/optim/sophia-h_xl.yaml | 7 + docs/Configuration-Guide.md | 22 +- docs/Levanter-1.0-Release.md | 4 +- docs/{ => dev}/Port-Models.md | 4 +- mkdocs.yml | 3 +- src/levanter/__init__.py | 3 + src/levanter/logging.py | 2 + src/levanter/lora.py | 16 +- src/levanter/optim/__init__.py | 6 + src/levanter/optim/sophia.py | 463 +++++++++++++++++++++++++ src/levanter/optim/util.py | 23 ++ src/levanter/trainer.py | 32 +- src/levanter/types.py | 45 ++- src/levanter/utils/hf_utils.py | 3 +- src/levanter/utils/jax_utils.py | 42 ++- tests/data/hero_data.npy | Bin 0 -> 32128 bytes tests/test_logging.py | 1 - tests/test_mpt.py | 11 +- tests/test_sophia.py | 66 ++++ 31 files changed, 836 insertions(+), 66 deletions(-) create mode 100644 config/gpt2_large_sophia_h.yaml create mode 100644 config/gpt2_small_fast_sophia_h.yaml create mode 100644 config/gpt2_small_fast_sophiah.yaml create mode 100644 config/gpt2_small_pile.yaml create mode 100644 config/gpt2_small_sophiah.yaml create mode 100644 config/optim/sophia-h_large.yaml create mode 100644 config/optim/sophia-h_medium.yaml create mode 100644 config/optim/sophia-h_small.yaml create mode 100644 config/optim/sophia-h_xl.yaml rename docs/{ => dev}/Port-Models.md (98%) create mode 100644 src/levanter/optim/sophia.py create mode 100644 src/levanter/optim/util.py create mode 100644 tests/data/hero_data.npy create mode 100644 tests/test_sophia.py diff --git a/.flake8 b/.flake8 index d067c43ce..636dc598f 100644 --- a/.flake8 +++ b/.flake8 @@ -1,7 +1,7 @@ [flake8] exclude = .git max-line-length = 120 -ignore = E203, E501, W503, W605, F821, E266 +ignore = E203, E501, W503, W605, F821, E266, E731 per-file-ignores = */__init__.py: F401 examples/*.py: E402 diff --git a/.github/workflows/run_entry_tests.yaml b/.github/workflows/run_entry_tests.yaml index c958e9bf2..dbde2dbd1 100644 --- a/.github/workflows/run_entry_tests.yaml +++ b/.github/workflows/run_entry_tests.yaml @@ -21,8 +21,6 @@ jobs: run: | python -m pip install --upgrade pip pip install flake8 pytest - # install haliax from source b/c it's changing in parallel with this repo - pip install git+https://github.com/stanford-crfm/haliax.git pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" - name: Run entry tests with pytest run: | diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 46828d5b8..3af69bacf 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -21,9 +21,7 @@ jobs: run: | python -m pip install --upgrade pip pip install flake8 pytest - # install haliax from source b/c it's changing in parallel with this repo - pip install git+https://github.com/stanford-crfm/haliax.git pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" - name: Test with pytest run: | - XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=tests:src:. pytest tests -m "not entry" + XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=tests:src:. pytest tests -m "not entry and not slow" diff --git a/README.md b/README.md index b240f3560..5a6b89cf6 100644 --- a/README.md +++ b/README.md @@ -36,12 +36,13 @@ Haliax's documentation is available at [haliax.readthedocs.io](https://haliax.re * **Distributed Training**: We support distributed training on TPUs (and soon, GPUs), including FSDP and tensor parallelism. * **Compatibility**: Levanter supports importing and exporting models to/from the Hugging Face ecosystem, including tokenizers, datasets, and models via [SafeTensors](https://github.com/huggingface/safetensors). * **Performance**: Levanter's performance rivals commercially-backed frameworks like MosaicML's Composer or Google's MaxText. -* **Reproducibility**: Levanter is bitwise deterministic, meaning that the same configuration will always produce the same results, even in the face of preemption and resumption. * **Cached On-Demand Data Preprocessing**: We preprocess corpora online, but we cache the results of preprocessing so that resumes are much faster and so that subsequent runs are even faster. As soon as the first part of the cache is complete, Levanter will start training. -* **Logging**: Logging is done with [WandB](https://wandb.ai/), complete with a fancy online visualization of the validation set during training. +* **Optimization**: Levanter supports the new [Sophia](https://arxiv.org/abs/2305.14342) optimizer, which can be 2x as fast as Adam. We also support ses [Optax](https://github.com/deepmind/optax) for optimization with AdamW, etc. +* **Logging**: Levanter supports a few different logging backends, including [WandB](https://wandb.ai/site) and [TensorBoard](https://www.tensorflow.org/tensorboard). (Adding a new logging backend is easy!) Levanter even exposes the ability +to log inside of JAX `jit`-ted functions. +* **Reproducibility**: On TPU, Levanter is bitwise deterministic, meaning that the same configuration will always produce the same results, even in the face of preemption and resumption. * **Distributed Checkpointing**: Distributed checkpointing is supported via Google's [TensorStore](https://google.github.io/tensorstore/) library. Training can even be resumed on a different number of hosts, though this breaks reproducibility for now. -* **Optimization**: Levanter uses [Optax](https://github.com/deepmind/optax) for optimization. Our new optimizer, [Sophia](https://arxiv.org/abs/2305.14342), is available in the [dev branch](https://github.com/stanford-crfm/levanter/tree/dev). diff --git a/config/gpt2_large_sophia_h.yaml b/config/gpt2_large_sophia_h.yaml new file mode 100644 index 000000000..314801728 --- /dev/null +++ b/config/gpt2_large_sophia_h.yaml @@ -0,0 +1,21 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 1280 + num_heads: 20 + num_layers: 36 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + wandb: + project: "levanter" + tags: [ "openwebtext", "gpt2", "sophia-h"] + + num_train_steps: 200000 + mp: p=f32,c=bfloat16 + +optimizer: + type: sophia-h + learning_rate: 1.7E-4 + weight_decay: 0.2 diff --git a/config/gpt2_small_fast_sophia_h.yaml b/config/gpt2_small_fast_sophia_h.yaml new file mode 100644 index 000000000..671acec8f --- /dev/null +++ b/config/gpt2_small_fast_sophia_h.yaml @@ -0,0 +1,24 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + wandb: + project: "levanter" + tags: [ "openwebtext", "gpt2", "itest", "sophia-h"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: 8 + + train_batch_size: 256 + num_train_steps: 20000 +optimizer: + type: sophia-h + learning_rate: .85E-3 + weight_decay: 0.2 diff --git a/config/gpt2_small_fast_sophiah.yaml b/config/gpt2_small_fast_sophiah.yaml new file mode 100644 index 000000000..71675312c --- /dev/null +++ b/config/gpt2_small_fast_sophiah.yaml @@ -0,0 +1,26 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + wandb: + project: "levanter" + tags: [ "openwebtext", "gpt2", "itest"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: -1 + + train_batch_size: 256 + num_train_steps: 20000 +optimizer: + type: sophia-h + learning_rate: 0.8E-3 + weight_decay: 0.1 + warmup: 0.01 + gamma: 0.005 diff --git a/config/gpt2_small_pile.yaml b/config/gpt2_small_pile.yaml new file mode 100644 index 000000000..ab7503871 --- /dev/null +++ b/config/gpt2_small_pile.yaml @@ -0,0 +1,23 @@ +data: !include data/pile_source_old.yaml +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 2048 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + wandb: + project: "levanter" + tags: [ "pile", "gpt2"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: 8 + + train_batch_size: 256 + num_train_steps: 50000 +optimizer: + learning_rate: 6e-4 + weight_decay: 0.1 diff --git a/config/gpt2_small_sophiah.yaml b/config/gpt2_small_sophiah.yaml new file mode 100644 index 000000000..1dd5824c3 --- /dev/null +++ b/config/gpt2_small_sophiah.yaml @@ -0,0 +1,19 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + wandb: + project: "levanter" + tags: [ "openwebtext", "gpt2", "sophia-h"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + + train_batch_size: 512 +optimizer: !include optim/sophia-h_small.yaml diff --git a/config/optim/sophia-h_large.yaml b/config/optim/sophia-h_large.yaml new file mode 100644 index 000000000..6644f20b8 --- /dev/null +++ b/config/optim/sophia-h_large.yaml @@ -0,0 +1,7 @@ +type: sophia-h +learning_rate: 3E-4 +weight_decay: 0.2 +min_lr_ratio: 0.1 +gamma: 0.01 +# sophia needs a minimum amount of warmup or it doesn't do well +warmup: 2000 diff --git a/config/optim/sophia-h_medium.yaml b/config/optim/sophia-h_medium.yaml new file mode 100644 index 000000000..5c411f109 --- /dev/null +++ b/config/optim/sophia-h_medium.yaml @@ -0,0 +1,7 @@ +type: sophia-h +learning_rate: 4E-4 +weight_decay: 0.2 +min_lr_ratio: 0.1 +gamma: 0.01 +# sophia needs a minimum amount of warmup or it doesn't do well +warmup: 2000 diff --git a/config/optim/sophia-h_small.yaml b/config/optim/sophia-h_small.yaml new file mode 100644 index 000000000..0bb8ea2a7 --- /dev/null +++ b/config/optim/sophia-h_small.yaml @@ -0,0 +1,7 @@ +type: sophia-h +learning_rate: 6E-4 +weight_decay: 0.2 +min_lr_ratio: 0.1 +gamma: 0.01 +# sophia needs a minimum amount of warmup or it doesn't do well +warmup: 2000 diff --git a/config/optim/sophia-h_xl.yaml b/config/optim/sophia-h_xl.yaml new file mode 100644 index 000000000..fe2c868b3 --- /dev/null +++ b/config/optim/sophia-h_xl.yaml @@ -0,0 +1,7 @@ +type: sophia-h +learning_rate: 1.2E-4 +weight_decay: 0.2 +min_lr_ratio: 0.1 +gamma: 0.01 +# sophia needs a minimum amount of warmup or it doesn't do well +warmup: 2000 diff --git a/docs/Configuration-Guide.md b/docs/Configuration-Guide.md index bbc7d7cb5..607129e1a 100644 --- a/docs/Configuration-Guide.md +++ b/docs/Configuration-Guide.md @@ -213,11 +213,11 @@ the machines being used for training. This is useful for distributed preprocessing. You can disable this behavior using `auto_start_cluster: false`. -| Parameter | Description | Default | -|---------------------|-----------------------------------------------------------------------------|---------| -| `address` | The address of the Ray cluster to connect to. | `None` | -| `start_workers` | Whether to start Ray workers. If `False`, you must start them yourself. | `True` | -| `auto_start_cluster`| Whether to start a Ray cluster automatically. | `True` | +| Parameter | Description | Default | +|----------------------|-------------------------------------------------------------------------|---------| +| `address` | The address of the Ray cluster to connect to. | `None` | +| `start_workers` | Whether to start Ray workers. If `False`, you must start them yourself. | `True` | +| `auto_start_cluster` | Whether to start a Ray cluster automatically. | `True` | ## Distributed Config @@ -227,12 +227,12 @@ If you're not using SLURM or TPUs, you can specify the cluster manually using th **Don't use this on TPU, and possibly not on SLURM either.** -| Parameter | Description | Default | -|---------------------|-----------------------------------------------------------------------------|-------------------------| -| `coordinator_address`| The address of the coordinator. If `None`, we'll use the default address. | `None` | -| `num_processes` | The number of processes in the cluster. | `None` | -| `process_id` | The process id of this process. | `None` | -| `local_device_ids` | The local device ids of this process. | ${CUDA_VISIBLE_DEVICES} | +| Parameter | Description | Default | +|-----------------------|---------------------------------------------------------------------------|-------------------------| +| `coordinator_address` | The address of the coordinator. If `None`, we'll use the default address. | `None` | +| `num_processes` | The number of processes in the cluster. | `None` | +| `process_id` | The process id of this process. | `None` | +| `local_device_ids` | The local device ids of this process. | ${CUDA_VISIBLE_DEVICES} | diff --git a/docs/Levanter-1.0-Release.md b/docs/Levanter-1.0-Release.md index 8fed293dd..05c66683a 100644 --- a/docs/Levanter-1.0-Release.md +++ b/docs/Levanter-1.0-Release.md @@ -539,7 +539,7 @@ learn differently from Transformers. ## A few other features * **Training**: Levanter uses [Optax](https://github.com/deepmind/optax) for optimization, - though our new optimizer, [Sofia](https://arxiv.org/abs/2305.14342), is coming to Levanter soon! + though our new optimizer, [Sophia](https://arxiv.org/abs/2305.14342), is coming to Levanter soon! * **Logging**: Logging is done with [WandB](https://wandb.ai/), complete with a fancy online visualization of the validation set during training. * **Checkpointing**: Distributed checkpointing is supported via Google's [TensorStore](https://google.github.io/tensorstore/) library. Training can even be resumed on a different number of hosts, though this breaks reproducibility for now. * **Export**: We also support exporting models to the Hugging Face Hub, with export compatible with Pytorch and Transformers via [SafeTensors](https://github.com/huggingface/safetensors). @@ -627,7 +627,7 @@ trained on the [Lakh MIDI](https://colinraffel.com/projects/lmd/) corpus. The la This is just the beginning for Levanter. In the future, look for: * more models on interesting problem domains, * scaled up versions of new architectures developed here at Stanford and elsewhere, -* new training techniques, including the newly released [Sofia](https://arxiv.org/abs/2305.14342) optimizer, +* new training techniques, including the newly released [Sophia](https://arxiv.org/abs/2305.14342) optimizer, * and larger models! Levanter is still a work in progress, but we are excited to share it with the community. We hope that Levanter will be diff --git a/docs/Port-Models.md b/docs/dev/Port-Models.md similarity index 98% rename from docs/Port-Models.md rename to docs/dev/Port-Models.md index f75fa7534..41228c2a3 100644 --- a/docs/Port-Models.md +++ b/docs/dev/Port-Models.md @@ -287,7 +287,7 @@ model: num_layers: 2 ``` -For more details on the training configuration, please refer to [Configuration Guide](./Configuration-Guide.md). +For more details on the training configuration, please refer to [Configuration Guide](../Configuration-Guide.md). ### Launch Training Job Once you have your training configuration ready and your training environment set up, you can launch a training job with the following command: @@ -299,7 +299,7 @@ HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN \ python levanter/src/levanter/main/train_lm.py --config_path $CONFIG_PATH ``` -Check out [Training on Your Own Data](./Training-On-Your-Data.md) for more detailed guide on how to spin off a training cluster and launch a training job. +Check out [Training on Your Own Data](../Training-On-Your-Data.md) for more detailed guide on how to spin off a training cluster and launch a training job. ### Profile Your Model If you are interested in profiling the training throughput of your model, good news is that it comes for free with automatic job monitoring in Levanter, powered through Weights & Biases. diff --git a/mkdocs.yml b/mkdocs.yml index 35fdaf5c4..568716ac4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -97,7 +97,8 @@ nav: - "tutorials/Fine-Tuning-Semantic-Parsing.md" - "Hardware-Agnostic-Training.md" - 'Developer Guide': - - 'Port-Models.md' + - 'dev/Port-Models.md' +# - 'dev/Trackers.md' - 'FAQ' : 'faq.md' - Other: - 'Levanter-1.0-Release.md' diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py index 33bcd249d..30c32a712 100644 --- a/src/levanter/__init__.py +++ b/src/levanter/__init__.py @@ -3,4 +3,7 @@ import levanter.data as data import levanter.distributed as distributed import levanter.logging as logging +import levanter.models as models +import levanter.optim as optim +import levanter.trainer as trainer import levanter.visualization as visualization diff --git a/src/levanter/logging.py b/src/levanter/logging.py index 7ffa90c91..4fbb4a618 100644 --- a/src/levanter/logging.py +++ b/src/levanter/logging.py @@ -50,6 +50,8 @@ def init_logger(path: Union[str, Path], level: int = pylogging.INFO) -> None: # use ISO 8601 format for timestamps, except no TZ, because who cares date_format = "%Y-%m-%dT%H:%M:%S" + os.makedirs(os.path.dirname(path), exist_ok=True) + handlers: List[pylogging.Handler] = [pylogging.FileHandler(path, mode="a"), pylogging.StreamHandler()] # Create Root Logger w/ Base Formatting diff --git a/src/levanter/lora.py b/src/levanter/lora.py index 0af676ef5..3e0dee750 100644 --- a/src/levanter/lora.py +++ b/src/levanter/lora.py @@ -367,14 +367,14 @@ def save_peft_checkpoint_callback( If hf_repo is provided, this will upload the checkpoint to the huggingface hub, passing any additional kwargs to the huggingface_hub.upload_folder function. - Args - base_path: the base path to save the checkpoint to. `/step-` will be appended to this. base_path - may be a GCS bucket path, in which case the checkpoint will be uploaded to GCS after being written to a tmp - config: the LoRA config to use - base_model_name_or_path: the name or path of the base model - tokenizer: If provided, will save the tokenizer to the checkpoint - upload_to_hf: the repo to upload to. If a string, will be interpreted as a repo name + branch - hf_upload_kwargs: kwargs to pass to the upload function + Args: + base_path: the base path to save the checkpoint to. `/step-` will be appended to this. base_path + may be a GCS bucket path, in which case the checkpoint will be uploaded to GCS after being written to a tmp + config: the LoRA config to use + base_model_name_or_path: the name or path of the base model + tokenizer: If provided, will save the tokenizer to the checkpoint + upload_to_hf: the repo to upload to. If a string, will be interpreted as a repo name + branch + hf_upload_kwargs: kwargs to pass to the upload function """ def cb(step: StepInfo): diff --git a/src/levanter/optim/__init__.py b/src/levanter/optim/__init__.py index 749c4e642..7dec2ebb4 100644 --- a/src/levanter/optim/__init__.py +++ b/src/levanter/optim/__init__.py @@ -1 +1,7 @@ from .config import AdamConfig, OptimizerConfig +from .sophia import ( # SophiaGConfig,; SophiaGObjective, + ScaleBySophiaState, + SophiaHConfig, + scale_by_sophia_g, + scale_by_sophia_h, +) diff --git a/src/levanter/optim/sophia.py b/src/levanter/optim/sophia.py new file mode 100644 index 000000000..8895942a2 --- /dev/null +++ b/src/levanter/optim/sophia.py @@ -0,0 +1,463 @@ +import abc +import functools +from dataclasses import dataclass +from typing import Any, NamedTuple, Optional, TypeVar + +import equinox as eqx +import jax +import jaxtyping +import optax +from jax import numpy as jnp +from jax.random import PRNGKey +from jaxtyping import PRNGKeyArray + +# import levanter.tracker +from levanter.optim.config import HessianOptConfig, OptimizerConfig +from levanter.optim.util import hvp, tree_gaussian_like +from levanter.utils.jax_utils import parameter_count, tree_filter_like + + +M = TypeVar("M") +Ex = TypeVar("Ex") + +GAMMA_SOPHIA_G = 0.05 +GAMMA_SOPHIA_H = 0.01 + + +class ScaleBySophiaState(NamedTuple): + """State for Sophia and similar.""" + + count: jaxtyping.Array # shape=(), dtype=jnp.int32. + hessian_count: jaxtyping.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates # momentum + h: optax.Updates # EMA of hessian diagonal + hess_key: PRNGKey + + +# @runtime_checkable +# class SophiaGObjective(typing.Protocol): +# """ +# Class for objective functions that can be used with Sophia-G +# +# Sophia-G is a second order optimizer that uses the Gauss-Newton-Bartlett approximation to the Hessian +# to compute the second order update. This requires the objective function be of the form loss(logits(x)) +# where logits(x) is the activation of the model for the given example x. This is the case for most models +# that are trained with "typical" losses. +# """ +# +# def logits(self, parameters: M, *args, **kwargs) -> Any: +# """ +# Returns the logits/activations of the model for the given example, +# or just sufficient statistics for the example for non-categorical models. +# """ +# ... +# +# def sample(self, logits, *example, key: PRNGKey, **kwargs) -> Ex: +# """ +# Samples a new example with the same shape as the original example, but with +# the "labels" replaced with some sampled values +# """ +# ... +# +# def loss(self, logits, *example: Ex, **kwargs) -> jnp.ndarray: +# """ +# Just computes the loss, e.g. cross entropy. +# +# Should return the mean loss over the batch, not the sum. +# +# TODO: should we reconsider this? +# """ +# ... +# +# def __call__(self, parameters: M, *args, **kwargs) -> jnp.ndarray: +# """ +# Just a convenience method for invoking the objective for "normal" training w/o sophia-g +# """ +# logits = self.logits(parameters, *args, **kwargs) +# return self.loss(logits, *args, **kwargs) +# +# def num_data_points(self, example: Ex) -> int: +# """ +# Returns the number of data points in the example. This should take into account the loss mask +# or any other masking that might be applied to the example. +# +# By default, we just return 1, and you can just pull the term into the hyperparams of Sophia if you want. +# +# Returns: +# The number of data points in the example +# """ +# return 1 +# +# +# def apply_partial(self, *args, **kwargs) -> "SophiaGObjective": +# """ +# Returns a new objective that is a partial application of the current objective, used for +# passing in the data points. +# """ +# +# +# +# class PartialSophiaG(SophiaGObjective): +# def __init__(self, objective: SophiaGObjective, *args, **kwargs): +# self.objective = objective +# self.args = args +# self.kwargs = kwargs +# +# def logits(self, parameters: M, *args, **kwargs) -> Any: +# return self.objective.logits(parameters, *self.args, *args, **self.kwargs, **kwargs) +# +# def sample(self, logits, *example, key: PRNGKey, **kwargs) -> Ex: +# return self.objective.sample(logits, *self.args, *example, key=key, **self.kwargs, **kwargs) +# +# def loss(self, logits, *example: Ex, **kwargs) -> jnp.ndarray: +# return self.objective.loss(logits, *self.args, *example, **self.kwargs, **kwargs) +# +# def __call__(self, parameters: M, *args, **kwargs) -> jnp.ndarray: +# return self.objective(parameters, *self.args, *args, **self.kwargs, **kwargs) +# +# def num_data_points(self, example: Ex) -> int: +# return self.objective.num_data_points(*self.args, example, **self.kwargs) +# +# def apply_partial(self, *args, **kwargs) -> SophiaGObjective: +# return PartialSophiaG(self.objective, *self.args, *args, **self.kwargs, **kwargs) + + +@dataclass +class BaseSophiaConfig(HessianOptConfig): + """Base class for sophia variants. Doesn't implement the state update""" + + weight_decay: float = 0.1 + beta1: float = 0.96 + beta2: float = 0.99 + + epsilon: float = 1e-12 + clip_threshold: Optional[float] = 1.0 + rng_seed: int = 0 + + @abc.abstractmethod + def compute_hessian( + self, + fn, + model, + *batch, + hess_key: PRNGKey, + **batch_kwargs, + ): + raise NotImplementedError + + def build(self, num_train_steps: int): + def _optimizer(learning_rate, gamma) -> optax.GradientTransformation: + components = [] + key = jax.random.PRNGKey(self.rng_seed) + + components.append( + _sophia_gradient_transform( + sophia_hess_fn=self.compute_hessian, + update_interval=self.update_interval, + b1=self.beta1, + b2=self.beta2, + eps=self.epsilon, + gamma=gamma, + initial_key=key, + clip_threshold=self.clip_threshold, + ) + ) + + # Algorithm 3, step 11 (Note, this comes after clipping b/c it's not supposed to be clipped) + # In the paper, it comes as a prior step, but doesn't get clipped + if self.weight_decay > 0: + components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask())) + + # - learning rate for descent + components.append(optax.scale(-learning_rate)) + + optimizer = optax.chain(*components) + + return optimizer + + # Hong suggested using cosine decay for gamma + # gamma_decay_schedule = optax.cosine_decay_schedule(self.gamma, num_train_steps // 2, 0) # type: ignore + constant_gamma_schedule = optax.constant_schedule(self.gamma) # type: ignore + # gamma_schedule = optax.join_schedules([constant_gamma_schedule, gamma_decay_schedule], [num_train_steps // 2]) + + return optax.inject_hyperparams(_optimizer)( + learning_rate=self.lr_scheduler(num_train_steps), gamma=constant_gamma_schedule + ) + + +# @OptimizerConfig.register_subclass("sophia-g") +# @dataclass +# class SophiaGConfig(BaseSophiaConfig): +# gamma: float = GAMMA_SOPHIA_G +# +# def compute_hessian(self, fn, model, *batch, hess_key: PRNGKey, **batch_kwargs): +# return stochastic_diag_gauss_newton(fn, model, *batch, **batch_kwargs, hess_key=hess_key) +# + + +@OptimizerConfig.register_subclass("sophia-h") +@dataclass +class SophiaHConfig(BaseSophiaConfig): + gamma: float = GAMMA_SOPHIA_H + + def compute_hessian(self, fn, model, *batch, hess_key: PRNGKey, **batch_kwargs): + return stochastic_hessian_diagonal(fn, model, *batch, **batch_kwargs, hess_key=hess_key) + + +def sophia_h( + lr: float = 0.85e-3, + *, + b1: float = 0.965, + b2: float = 0.99, + eps: float = 1e-8, + gamma: float = GAMMA_SOPHIA_H, + weight_decay: float = 0.0, + clip_threshold: Optional[float] = 1.0, + update_interval: int = 10, + key: PRNGKey, +) -> optax.GradientTransformation: + """Sophia-H: https://arxiv.org/pdf/2305.14342.pdf Algorithm 1&3""" + components = [] + + components.append(scale_by_sophia_h(b1, b2, eps, gamma, clip_threshold, update_interval, key=key)) + + if weight_decay > 0: + components.append(optax.add_decayed_weights(weight_decay)) + + components.append(optax.scale(-lr)) + + return optax.chain(*components) + + +def scale_by_sophia_h( + b1=0.965, + b2=0.99, + eps=1e-8, + gamma=GAMMA_SOPHIA_H, + clip_threshold: Optional[float] = 1.0, + update_interval=10, + *, + key: PRNGKey, +): + + return _sophia_gradient_transform( + sophia_hess_fn=stochastic_hessian_diagonal, + update_interval=update_interval, + b1=b1, + b2=b2, + eps=eps, + gamma=gamma, + clip_threshold=clip_threshold, + initial_key=key, + ) + + +def sophia_g( + lr: float = 1e-3, + *, + b1: float = 0.99, + b2: float = 0.99, + eps: float = 1e-8, + gamma: float = GAMMA_SOPHIA_G, + weight_decay: float = 0.0, + clip_threshold: Optional[float] = 1.0, + update_interval: int = 10, + key: PRNGKey, +) -> optax.GradientTransformation: + """Sophia-G: https://arxiv.org/pdf/2305.14342.pdf Algorithm 2&3""" + components = [] + + components.append(scale_by_sophia_g(b1, b2, eps, gamma, clip_threshold, update_interval, key=key)) + + if weight_decay > 0: + components.append(optax.add_decayed_weights(weight_decay)) + + components.append(optax.scale(-lr)) + + return optax.chain(*components) + + +def scale_by_sophia_g( + b1: float = 0.99, + b2: float = 0.99, + eps: float = 1e-8, + gamma: float = GAMMA_SOPHIA_G, + clip_threshold: Optional[float] = 1.0, + update_interval=10, + *, + key: PRNGKeyArray, +): + + return _sophia_gradient_transform( + sophia_hess_fn=stochastic_diag_gauss_newton, + update_interval=update_interval, + b1=b1, + b2=b2, + eps=eps, + gamma=gamma, + clip_threshold=clip_threshold, + initial_key=key, + ) + + +def _sophia_gradient_transform( + sophia_hess_fn, + update_interval: int, + b1: float, + b2: float, + eps: float, + gamma: float, + clip_threshold: Optional[float], + initial_key: PRNGKeyArray, + mu_dtype: Optional[Any] = None, +) -> optax.GradientTransformation: + mu_dtype = jax.canonicalize_dtype(mu_dtype) if mu_dtype is not None else None + + def init_fn(params): + mu = jax.tree_util.tree_map(lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) # First moment + h = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment + return ScaleBySophiaState( + count=jnp.zeros([], jnp.int32), hessian_count=jnp.zeros([], jnp.int32), mu=mu, h=h, hess_key=initial_key + ) + + def update_fn(updates, state, params=None, *, obj_fn, **kwargs): + mu = update_moment(updates, state.mu, b1, 1) + # nu = update_moment_per_elem_norm(updates, state.nu, b2, 2) + mu_hat = bias_correction(mu, b1, state.count + 1) + h_hat = state.h + # track how often hessian is used + mu_leaves = jax.tree_util.tree_leaves(mu_hat) + h_leaves = jax.tree_util.tree_leaves(h_hat) + + stats: dict[str, Any] = { + "optim/param_norm": jnp.sqrt(sum(jnp.sum(p**2) for p in jax.tree_util.tree_leaves(params))), + "optim/momentum_norm": jnp.sqrt(sum(jnp.sum(m**2) for m in mu_leaves)), + "optim/hessian_norm": jnp.sqrt(sum(jnp.sum(h**2) for h in h_leaves)), + } + + # with sophia-g the max(h, 0) is not needed but no harm + updates = jax.tree_util.tree_map( + # lambda m, v: m / jnp.maximum(jnp.maximum(jnp.abs(m), gamma * jnp.maximum(v, 0)), eps), mu_hat, h_hat + lambda m, h: m / jnp.maximum(gamma * h, eps), + mu_hat, + h_hat, + ) + + if clip_threshold is not None: + unclipped_count = sum(jnp.sum(jnp.abs(u) < clip_threshold) for u in jax.tree_util.tree_leaves(updates)) + updates = jax.tree_util.tree_map(lambda u: jnp.clip(u, -clip_threshold, clip_threshold), updates) + stats["optim/unclipped_fraction"] = unclipped_count / parameter_count(updates) + + # this doesn't work well on CPU, so skip if cpu + # if jax.lib.xla_bridge.get_backend().platform != "cpu": + # levanter.tracker.jit_log_metrics(stats, step=state.count) + + if mu_dtype is not None: + mu = jax.tree_util.tree_map(lambda t: t.astype(mu_dtype), mu) + + state = ScaleBySophiaState( + count=state.count + 1, hessian_count=state.hessian_count, mu=mu, h=h_hat, hess_key=state.hess_key + ) + state = update_hessian(state, params, obj_fn=obj_fn, **kwargs) + return updates, state + + def update_hessian(state, params, *, obj_fn, **kwargs): + def _do_update(): + key, next_key = jax.random.split(state.hess_key) + new_hess = sophia_hess_fn(obj_fn, params, hess_key=key, **kwargs) + + new_hess = tree_filter_like(state.h, new_hess) + + # EMAs of hessian + nu = update_moment(new_hess, state.h, b2, 1) + return ScaleBySophiaState( + count=state.count, hessian_count=state.hessian_count + 1, mu=state.mu, h=nu, hess_key=next_key + ) + + def _dont_update(): + return state + + return jax.lax.cond( + jnp.equal(state.count % update_interval, 0), + lambda _: _do_update(), + lambda _: _dont_update(), + state.count, + ) + + return optax.GradientTransformationExtraArgs(init_fn, update_fn) + + +# use this for Sophia-G +def stochastic_diag_gauss_newton(fn, model, *args, hess_key: PRNGKey, **kwargs): + """ + + Approximate the diagonal of the Hessian using an approximation to the Gauss Newton matrix. + This is Algorithm 2 of https://arxiv.org/pdf/2305.14342.pdf + + Args: + fn (SophiaGObjective): objective function + model: model whose Hessian to compute + hess_key: key for sampling + *args, **kwargs: passed to fn's logits + """ + raise NotImplementedError("This is not implemented yet") + # if not isinstance(fn, SophiaGObjective): + # raise ValueError("objective must be a SophiaGObjective") + + # Step 3 + logits, model_backward = eqx.filter_vjp(lambda model: fn.logits(model, *args, **kwargs), model) + + # Step 4 + y_hat = fn.sample(logits, key=hess_key) + + # Step 5 + grad_loss_logits = eqx.filter_grad(fn.loss)(logits, y_hat) + pseudo_g = model_backward(grad_loss_logits)[0] + + # Step 6 + bs = fn.num_data_points() + h = jax.tree_util.tree_map(lambda x: x**2 * bs, pseudo_g) + + return h + + +# Use this for Sophia-H +def stochastic_hessian_diagonal(fn, model, *args, hess_key: PRNGKey, **kwargs): + """Compute the diagonal of the Hessian of a function using a normal distribution. + + https://arxiv.org/pdf/2305.14342.pdf Algorithm 1 + + Args: + fn: function to compute the Hessian of + model: model to compute the Hessian of + hess_key: key for the normal distribution + """ + # cf https://arxiv.org/pdf/2006.00719.pdf eqn 9 + # https://www-users.cse.umn.edu/~saad/PDF/umsi-2005-082.pdf + # https://arxiv.org/pdf/2208.03268.pdf + g = tree_gaussian_like(hess_key, model) + # TODO: consider allowing for n > 1 gaussians? + product = hvp(lambda m: fn(m, *args, **kwargs), model, g) + hessian = jax.tree_util.tree_map(lambda grad, gaussian: grad * gaussian, product, g) + + return hessian + + +# Cribbed from optax._src.transform +def update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order`-th moment.""" + return jax.tree_util.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +@functools.partial(jax.jit, inline=True) +def bias_correction(moment, decay, count): + """Performs bias correction. It becomes a no-op as count goes to infinity.""" + # The conversion to the data type of the moment ensures that bfloat16 remains + # bfloat16 in the optimizer state. This conversion has to be done after + # `bias_correction_` is calculated as calculating `decay**count` in low + # precision can result in it being rounded to 1 and subsequently a + # "division by zero" error. + bias_correction_ = 1 - decay**count + + # Perform division in the original precision. + return jax.tree_util.tree_map(lambda t: t / bias_correction_.astype(t.dtype), moment) diff --git a/src/levanter/optim/util.py b/src/levanter/optim/util.py new file mode 100644 index 000000000..7fd3a41df --- /dev/null +++ b/src/levanter/optim/util.py @@ -0,0 +1,23 @@ +import equinox as eqx +import jax + +from levanter.utils.jax_utils import is_inexact_arrayish + + +def hvp(f, x, v): + """Compute the Hessian-vector product of a function.""" + return eqx.filter_jvp(eqx.filter_grad(f), (x,), (v,))[1] + + +def tree_gaussian_like(key, tree): + """ + Samples a tree of gaussian noise with the same structure as `tree`, except for leaves which are not inexact arrays, + for which it returns None + """ + leaves, structure = jax.tree_util.tree_flatten(tree) + keys = jax.random.split(key, len(leaves)) + rand_n = lambda x, key: jax.random.normal(key, x.shape) if is_inexact_arrayish(x) else None + g = jax.tree_util.tree_map(rand_n, leaves, list(keys)) + g = jax.tree_util.tree_unflatten(structure, g) + + return g diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index e2083a750..d9db8dc91 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -34,9 +34,6 @@ from levanter.distributed import DistributedConfig, RayConfig from levanter.grad_accum import microbatched from levanter.logging import WandbConfig, capture_time - -# import for backward compatibility -from levanter.optim import OptimizerConfig # noqa: F401 from levanter.types import FilterSpec from levanter.utils import cloud_utils from levanter.utils.jax_utils import is_inexact_arrayish @@ -169,6 +166,10 @@ def mp(self) -> jmp.Policy: """Returns the mixed precision policy""" return self.config.mp + @property + def num_train_steps(self) -> int: + return self.config.num_train_steps + @typing.overload def add_hook(self, fn: Callable[[StepInfo], Any], *, every: int = 1): ... @@ -408,7 +409,9 @@ def split_loss_fn(trainable_model, *batch, **batch_kwargs): loss, grads = self._compute_gradients_microbatched(split_loss_fn, trainable_model, batch, **batch_kwargs) - updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) + partial_fn = lambda model: split_loss_fn(model, *batch, **batch_kwargs) + + updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model, obj_fn=partial_fn) model = eqx.apply_updates(model, updates) return loss, model, opt_state @@ -577,21 +580,24 @@ def microbatch_size(self): def initialize(self, all_config): """Initializes jax, wandb, logging, setting the run name/id in the process""" + self._initialize_jax_config() # Can't do full logging setup until we've initialized jax b/c we use jax for rank id pylogging.basicConfig(level=pylogging.INFO) self.distributed.initialize() - self._maybe_set_id() - self._initialize_logging() - self.ray.initialize() - self._initialize_jax_config() self._validate_and_set_defaults() - self.wandb.init(self.id, all_config) + + id = self._maybe_set_id() + levanter.logging.init_logger(f"{self.log_dir}/{id}.log") + self.wandb.init(id, all_config) + + self.ray.initialize() if self.require_accelerator is None: self.require_accelerator = not sys.platform.startswith("darwin") if self.require_accelerator: - assert jax.default_backend() != "cpu", "Accelerator required but not found" + if jax.default_backend() == "cpu": + raise RuntimeError("No accelerator found. Please run on a TPU or GPU.") if self.shutdown_at_exit is not False: if isinstance(self.shutdown_at_exit, bool): @@ -650,10 +656,6 @@ def _initialize_jax_config(self): for key, value in self.jax_config.items(): jax.config.update(key, value) - def _initialize_logging(self): - self.log_dir.mkdir(parents=True, exist_ok=True) - levanter.logging.init_logger(self.log_dir / f"{self.id}.log") - def _maybe_set_id(self): # always do this so we don't get weird hangs if the id isn't set right # for random ids, we want to ensure that all hosts have the same id @@ -677,6 +679,8 @@ def _maybe_set_id(self): logger.info(f"Setting run id to {self.id}") + return self.id + # we can't do this in post_init because we don't want to call jax.device_count before calling distributed.initialize def _validate_and_set_defaults(self): if jax.device_count() % self.model_axis_size != 0: diff --git a/src/levanter/types.py b/src/levanter/types.py index 954578d27..60d7b82a0 100644 --- a/src/levanter/types.py +++ b/src/levanter/types.py @@ -1,17 +1,21 @@ -from typing import Any, Callable, Protocol, Tuple, TypeVar, Union +from typing import Any, Callable, Optional, Protocol, Tuple, TypeVar, Union + +import haliax as hax +from haliax.types import Scalar M = TypeVar("M") # Model +M_con = TypeVar("M_con", contravariant=True) # Model X = TypeVar("X", contravariant=True) # Input class ValAndGradFn(Protocol[M, X]): - def __call__(self, model: M, *inputs: X, **input_kwargs) -> Tuple[float, M]: + def __call__(self, model: M, *inputs: X, **input_kwargs) -> Tuple[Scalar, M]: ... -class ValFn(Protocol[M, X]): - def __call__(self, model: M, *inputs: X, **input_kwargs) -> Tuple[float, M]: +class ValFn(Protocol[M_con, X]): + def __call__(self, model: M_con, *inputs: X, **input_kwargs) -> Scalar: ... @@ -21,3 +25,36 @@ def __call__(self, model: M, *inputs: X, **input_kwargs) -> Tuple[float, M]: treated as-is, while callables are called on each element of the pytree. If the callable returns True, the element is kept, otherwise it is filtered out. """ + + +class ComputeLossFunction(Protocol[M_con, X]): + """ + Function signature for "compute_loss" functions in Levanter: these + couple the computation of the logits and the evaluation of the loss + """ + + def __call__( + self, + model: M_con, + *inputs: X, + reduction: Optional[hax.ReductionFunction] = hax.mean, + reduction_axis: Optional[hax.AxisSelection] = None, + **kwargs, + ) -> Scalar | hax.NamedArray: + ... + + +class ModuleComputeLoss(ComputeLossFunction[M, X]): + """ + Loss that just delegates to the model's compute_loss method. + """ + + def __call__( + self, + model, + *inputs: X, + reduction: Optional[hax.ReductionFunction] = hax.mean, + reduction_axis: Optional[hax.AxisSelection] = None, + **kwargs, + ) -> Scalar | hax.NamedArray: + return model.compute_loss(*inputs, reduction=reduction, reduction_axis=reduction_axis, **kwargs) diff --git a/src/levanter/utils/hf_utils.py b/src/levanter/utils/hf_utils.py index ff9fdb7af..408a8c8da 100644 --- a/src/levanter/utils/hf_utils.py +++ b/src/levanter/utils/hf_utils.py @@ -18,8 +18,7 @@ def num_cpus_used_by_tokenizer(tokenizer) -> int: else: # This is a bit hacky, but HF's fast tokenizers are parallelized under the hood. # we reserve a couple of cores just so Ray has somewhere to run the coordinator. - # Empirically I never see it get past 10 (usually more like 5-8), so we'll say 8 - return min(max(1, logical_cpu_core_count() - 2), 8) + return min(max(1, logical_cpu_core_count() - 2), 32) else: return 1 diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 258318497..a1253b500 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -1,5 +1,6 @@ import contextlib import json +import warnings from dataclasses import fields from typing import Any, Callable, Optional, TypeVar @@ -30,6 +31,11 @@ def use_cpu_device(): yield +def is_inside_jit(): + """Returns True if we're currently inside a jit""" + return isinstance(jnp.zeros(()), jax.core.Tracer) + + def flops_estimate(fn, *args, **kwargs): """Estimates the flop count of a function""" return jax.jit(fn).lower(*args).cost_analysis()["flops"] @@ -135,7 +141,12 @@ def leaf_key_paths( rec_value = rec(field, field_name) rec_values.append(rec_value) - return eqx.tree_at(lambda m: [getattr(m, name) for name in names], pytree, rec_values) + + _, tree_def = eqx.tree_flatten_one_level(pytree) + out = jax.tree_util.tree_unflatten(tree_def, rec_values) + return out + # this doesn't work reliably because tree_at doesn't like none values + # return eqx.tree_at(lambda m: [getattr(m, name) for name in names], pytree, rec_values, is_leaf=lambda x: x is None) else: leaves, treedef = jax.tree_util.tree_flatten(pytree, is_leaf=is_leaf) if len(leaves) == 1: @@ -150,7 +161,9 @@ def join_key(prefix, k): return f"{prefix}.{k}" if prefix else k -def key_iterator(key: PRNGKeyArray): +def key_iterator(key: PRNGKeyArray | int): + if isinstance(key, int): + key = jax.random.PRNGKey(key) while True: key, subkey = jax.random.split(key) yield subkey @@ -167,3 +180,28 @@ def is_inexact_arrayish(x): return jnp.issubdtype(x.dtype, jnp.inexact) else: return False + + +def tree_filter_like(template: X, tree: X) -> X: + """ + Filters a tree to only include the leaves that are not None in the template. + + This is useful for filtering out nontrainable parameters from a tree. + """ + + def match_like(templ_leaf, tree_leaf): + if templ_leaf is None: + return None + else: + if tree_leaf is None: + warnings.warn(f"Template has a non-None value where tree is None. Template value: {templ_leaf}") + return tree_leaf + + return jax.tree_util.tree_map(match_like, template, tree, is_leaf=lambda x: x is None) + + +def as_arrayish(x): + if hasattr(x, "shape") and hasattr(x, "dtype"): + return x + else: + return jnp.asarray(x) diff --git a/tests/data/hero_data.npy b/tests/data/hero_data.npy new file mode 100644 index 0000000000000000000000000000000000000000..f39678d793efddbd3b2df287f35ce5e38dccfe2a GIT binary patch literal 32128 zcmbSS_dnI|`?u0SL(!6zRYsCGlDa4=GNO`|Qg$e$fy&5gXc!G!B&&>UxsJUukG=Oe z*5Pn$zUT8Vd|$u4UO$}2x$o<`ulstAD?syx##McKy2o@b!lq`9_Z)=fd4(0s<%Gp} zh0Py0JaRCxyZyky)a-xXFPhjmn$f;HTADmGqrHoql$4Yd}NWMsXPfn(1X%#0Zf1^~jdEl1T`dVH1pSpMm(CE1~627ce#H(Zk(-YbaJ! zyq&dR0y{F_>Pm2xfRy^(E#}Ajz|Tu6y768%{0Io==6XB}0U;lH4E2XG=bIJDv~?73 z@zGz@c(aInUiNpmz72q&!OxYxgf*Oam3QH``#cP0N<5c#-aywK*&gM?gXs3#^`P_A zFivcvYe^E#gM>Fa2e`VrQE*vD?6K7V=AJzu$*M7cf%30jXR}nprN|^=d3y(5xE!LC z{%IME52(KAcJ08rJItTmU!(&QJ4>rU3mL4{1`Ja^H(-xFd#_np1|*8Q=l*i=Lp|?L z$G**P!uhbz{El3mC|e_&M&u^@aFO*zH-feDrobL<%J62pCL4 zP)oG2c}WX!oR+xmr?!>gG*JIF#eE+4>KLau@9M%={K3~6_mSb6eE;5`=Ud>e)3@zO zcP269^lw@E!V0)JAhI(kp%Y4T!{VeR8?o)A3VRZN7IYh==WD#8SFB*{sqolbMB;=2 zSG~z7E{5cm+|BBMpXv+}{+lz%_aQ$$Wup(2iv;_|zRZH@(U&9NOlHyImiw4k?jRm| zEo^qol#bxI!%U_5X&=6ht)m>AnZreTXC6PnVeC1`f3I9>5zg@>bybZ|3|4BzjFptc&VGnP_yUN3b z1_oWY-ka=V5-dl$jNQ$voMEgkqD4EVCUjA5N7V?@`pO;o%yTPM7G9Fk1x zG&_t&Q8wYogy7;Ba6OCW%B2iorP0&>PDf9}biva5Bl{Q?zn#7s7b`IfH(s*P-w$LW zytX*6dgkvqvIqi`Q*SlsXf=zK%;dplVNG^})d=Ld^H14ZPa#2!kT~vLgPa<6>ySHN~lG7%0foTwKi|X+?OK!rh&L^w3 zXO_Tj)Q2O}sRNA%FXn!FLqIm85M^Pm7Wk5~_~^1wC^j@SH|rZv&_ynwUO;IcIpjSY zOD3kV;iaE8-)IlqU<+4O(CCNsbN@xm_*3DyZ)T;*+7RxkIOsPMTZECP^dF_4>xV~4 z3G90!2eDA_`#-x!Wbn}1p?^?k3XF61%kO+LkMp6u%hLz*VWFf#Lyx@}_*W|nEwkFu z`?N5TDKZ0=4~G_hc|41DhZcpvdj*R>oHZb5rQrEY&5stxTA_I@r=e$a0zWfWRXRRR zfjgZGO<#|VqY(9G%BrKaaut$YzFTqIR8B0F^osHeo23$L5sJ5SxFal79%Bh z7^CP6Ofo8yx3TtPO|1QLbHE@n7>QikB3=*wO`U%&<~|DvdtGIhN*{uq!FKNg#|cnQ z724i>ijJ_PIn`E|MZ#~Q>Nibv$WU_LJa4>r0M^+SLz^Thu;YuXOgrx=2nzeA)|aJ$ z+VOP%^Zx!w$5A-yn3;jvJ8vEJ`;~>rGJ1aaPb5gscIkSa9mg}7{_iJNhvD?WXCp_& z$++)?s&Z0)DZXWsx|Yf@hpqk>KImrTqrcuypY3^M*!m>ASEVKc+{$j@(_ z?&bh=ZQAwk!U8a>WlSoU%Y}#{r~ZWeK71H|Q>AF33q|$jxa8I{h&Iyr%c9vI5$ehqd}VDWGh!UnV8K3T)-p`dqmNaH5I(vizYj zpbuhwFwFW7=Izg>%5EEji<-(hVJRyZf6zG1bdCsK5(ga@Bo|?>cp)gpeio{X_w0;J zoktbI-9NAId_)G7mxI{?xe)BG7oVF%MRp+Niur z!Izh+Ez;|zph#I-Z`QmE&K--)=8Y;u@leid=jmr4;1S)@FaBXnSnY4ia$kb!;O&3f zdTIMobE7G9+dR}?m-PE6G5~UKw@O^yItTI-skhV3)}ZxUwDULeEIc$8C{dSr3F~O8g(Ad~nvKA6#@;!S0V_f6k6!xX&p0frpXEv}o7*TWw34^?6Yg%Fk~v)EF; z2tL1jgo;!aaD^Ycx3 z`gqVJn{Ee!cGuB67J2wbkWakt_Y^7%`NuVV97Pd}-eT+MeoQ_4eRL14&VC%*{PslR z7yf<6G%M@4gwmB8l@df|f?($Hr%reKVfT%V%!6uV6l#j37oqh7)!O=5ZB>h?c3)Oe z#I_&)@<{D@wtWJ_T2FN7%=H5{y?s|9I|T=Xk`kR>BiNH76T?Hgv2&8j{gA5~cO7sE zoNFN9q9kD(F|ZTwwq}kmNKb%V!z=6BZT;9mA1(1fbqUk69CeM(Y$7u?CF(C{EA|jR zY**}V!Na7KMcwZth&|n<7qLbHKK4=KYg~r#!f(`~tqmY!N=m#&9mWR{2?jn_8qxP2 z(~f;5ZMgl3_Bky+3O@51P9EV}0pT>GBh9a(3@ZcPB$pc!Jsi z*Fp@kDy?{5M#hOJ6#c*&GCaIBox>Vfh*^ap4z({Qf#uQbu1u*FoDuoHHSkF}(m#+b z&;8p0GVf$(6B1@Is=>u8g|=P|3F*!6+$|X7e!oAutrdQ1FxUAXr&Cl)5Rlh6-j2ym z)F6E|Dukt#JuaQeLB34?z-#ww@!c_Nspb3vkYf0(n!mppx&APg>~5!E*UdjYz0WC7 z;d$j-cu*#;=7l~Ge*FiSW*v$>1Lzf_O(Mza7Z*Xe`%At|OAZ8v$7cQ9J&1dAO^qi- zSJ7gBTJH;EDu}R2hi7hS0}>TdodH2SbVdZ2$8s;(WO-)`HD-yB5MJ#*qw%Ooq_ zK!8H4qZ&Har#o9RD?M-xef*P~{~~l#tM?b;V?C{#5E_Vah-VPIt6FX153reYyz@ zD`ea|`22gCLO-51;PvZaUB;f3+7`ZpB=nkpbp8G^31jv!UK{z?0-puG=H=1yU}VCD zvcXXTf48Ns>@%7Ly5Q}e#YM{)=l7oEur&{lJzaimYB!31vkXlnYsXPCa&~ltR);np z2AfVinZ!JiQ@*k9DX{)yyh3M~NwHz2;p!Hf4j61(|NhRV5x?KRTI@Ygh!Y3ON%w7M z;e%7Xiqy?!toeE8mP79tid9TsdH*~M78ayy-~_GC8;>15d76q*$^!ZPt$EP>om1)C zg+9zR8`xi3Mgo^F+;Td$Nnr6u^3+QF5Sp)K1PB(cA-`$dPld=ncz(>u^sWXOUX&&s zDb(x)#@Wi?uFYk5_iy>2{N_56izBGR!UpJ+nLXgi-va_ily_w&%)n-nhV;tA3LIzF z%fH_?hvF|j?%YY90Ed$#E&@^GB0T*J7b`rrU zTl5R-wh?ToJhf{1bOWs2bB1u5h!sokJ~ge);;u{d1)XMA7X*;IsO|0O)?#dD8 z^P|vu*}+Dzmx_&!-8M4^n?b-=?v8-`B)%zpnyQ?G zNP0}%@0ZYr8oX~aT0#a;{ESbOz?Nw+V@!UF^|eM$wb%2L0Il*Irr~E2e#hiPv7!y4la2e%;~)O3&mDi*mOM>P>n<6!R@Ur zINNca^3s0_HrKB3C@rplM+C!DInNR7x|FXpr`?0$uChT|@A~2PvHf2J83!TLbBLd3 z+c56!4(~HOoPv){LJSteI$^PZUaR8E6a?||$u|hkLTJt*-4=slJg)L-Uh6V54KM#b zbZldecK*SH)odD-nx0-|WotwA_CKYUOLK9*buxWK)Ho>kBr*x6&7&lhZKO4Q2}#wI z8g2pw=N`))K6G#q>_6JixP|9HVa@}yA+9CZ`c~nBB*O~C2A$m2_NW%BggLEd^p~-F zy{oM6TI9Cj&r+XzwgUNY_BXVyH{i{6|D}A6 z2G}`%I_19u9a#E!hwG8LWwhGy#w6&(EUI;+ON<>X#0$ryT$d$hkT}|QVkmVCcTcUi z1slxZkK!POEY^9rd|cPaaOF8>NakJR{`Qn+^;g~S7Jv2k2#?e(f8U>r>gRm!r30ksuF;hd^*KHA>ICo>RqstB%fmO zuoE6+)erMnRb%v6`{!xC8O#o+tFnyj!>8x?ocSf&G2Vyz!Ru@zeu-Y>^X;C(%`a(h zJnwCSP5Z~SuA(f69qdX=%3Q)i8(x^I9)^6HHT&f5_5(HrSO(b%wzxd zCcLTUle^z?6HmJS{&%N?3|*%k+>ZUFQ)GVeP<3ix5}bX1j~_Cl;MB8(-y+FVFtqoH ztFvJbzApd3Ua?Gq_wzfS1c$caBV9+HW{v@fVzq5jDw@Uhi&Ij z&k?+Hw?w(}*%H38jO9B_UIpgSZO=a>QGjh5wdkht5I!=hi4V7+@mPWdMF-b=pij9m zb3vjP+WNRU+-AGLd8c^T?FC{ zr!QG|_ahEUl=dmDV(48_>HIHb5SQw{#2r2dX)StQY|kfggZXRMS*JQU=~l>cyX`Mp z3OiV}-W)}fZ}f$yFHJ-CfgY{946VrcIp#zn7X|)_ZDaYQxdPwltTijWyKqPM&wQsB zy>MiUWoq^}THaCqY>SDf!bLhybDpVi43F^oqrGJob48X!1nk?eo%!^G(8ELs$+WVV z-`Yyc%Rm#SU$eA&PqvKoU&FO1mSCeT!&uH5-$NXrfLz4I|JITgVW%Q@t)*cNeoY+S zwbOe9a*vRX?50~o#m*v~GK&^?`Hd~U$$S|9GnpA_weP~Cy(aCyGn?@ zncU2~CQzv4i~M@p0=7?9eLH^VC%&Q7Up`B226qjwTU91KD4tp6AaQ&WETmG|=*H_& z=oA|w=b4LM4nl7oqkG2o zR{WU9ddRbP87+3GMoyWx;Ad%XmTA#mv_1E}a-4?0FJ;VM5z85YcOFA)d7XhM%^F+8c{r4;IGZOes`W3b(=Ao9^(jIHl9A3X`O8D)(gnl`# zt|J|zP{Z$FD$p8)%xwy?FP`*5@o&vn=L*}wM*db}wL1~awg0VomC*LniBWQw+8Q{! zopbtSTMNYtDaUH)E76=d-Kq637hiv>Rxvz23mhRFXHGJX;I*tbd!`$H0{Nf7z!TXS z_^=>(;MRF6th`u=S9?V}r#xTci1izo`>y3N`N9Ga@~YX%^DW4FzVpFLrXT3&-WMph z^b`d`NcP8=I-obf>kdc7I3!jy+pS7;L7RL;Xu9+m@JAn4NKDxTf^CXmjYKhYK6r1a zTt!EC_#|+~`|=jWabN$>vOH9{`g>Jh>dGP{23}CV7KZ8 zljiY%koREZhg;+#TCs^m{zt>RRk6b?Jm-gj_vcX#Nsf6utgUL0cbtK6$TI!j1alLJ z*(RR6PoBmwdDrizom03glMQs5YH@qXOO-zD0U*5{r}tdxfXmF@$wOLyAoFj3-p9)Y zP-uUh)J%&DUyS#Hyfg`3XBKR`UYdYyuMP?YCiSC=F7xY;MHS$?`Qw{BXC~V4g!4qL zw1WrpS?&sHGCUrUJ!f`y01OFRgKl)xqnpO0cV|im1n8}Eb?S^EvAT!;HK7?M&z$Je zzc++S@p-YGH|D|lx=OzQHyNJAt^bLfZ^h@KHOGIS>;r>0tJ5b>P2eT|@$+n&6ksB) zSf{GwfsTvo7QWY1*qe1^bWCgtjXOvv=*?^y06fu7mz$coe5;JaVP_-#oKbUMTt<)55_U&jjFZDdpN zw>HZt_4xq4E;v^4qUiCJ$_CWbA{p;o1O7UyqX%XY< zaeQieugGp675-~zFm7IJ#a<1kN9t)5;0@OcOB5-F@5~onG|tg0dZ-m{z4Ub!zXvxf zexN5o5}j1?>eopK-c0-S`~w-^kSJFIOy~&V0YOZ);>&obQU>5G{#+qM7&iqSK?&1GeOKj1^{ncAH@mZqQj%`-NrzJN#yoe~dvk zTlTGLE&?!%q^B|PufT!HJ*UqA5!trstEgP+2m9)Aud0kiY|QPbmaA@rEf+RS)XWm` znXe{m%04RYQonMXnSBYb<$Bi{_cmj!%f4pc-0$Gt!`#UgIfPbxC)Cw;tl{$28#88U zBk*+J@14LtiyRwfQN?vcWL+?P^iyF9BPTwP9))k3@19|UO4nNcX4gWb8niz4nLO{!w zh+USWuvo`>Th(Y1TG$4(EX`)nrx%^#d#!uC*DjJ_@wJdZTs56y>NC8nOC1GE z`B(QSe+Qs}JBfE>Y7E{U8Y2H&qwxg7GxLXy24TKJ_D;gDRcwPDr=8i#fiiZ!Oy8sf zS7vu!=6E;{!+nDJ+O&PR;z|)xJogVl%XdUgo)(|?Ov-LqVesMKrG&KhdFV$T}l52p!u z5_|TINpdd)K2hmj9N2_VJMX*w;w{imPKw#ZH3v=ACtrxjPJ`U>AMm-b32Qt4uwJ=8 z0*x2;USZrZ1PUxfeRb^>;4TaQJbHKtKJRC8yYZtHdJfR~2)-fEvXwC~`P75=$7bHW zU#In|IIA!G1JX(!)zx*+W7rsM$f)$) zY0uxrR(!zdZyeN92_Z?$aWNC~=o#p<+<`?Hbgl3^-8Wi)!d@w6e}4cL*k9&uX`6=o zVGhV|nF86#h1Y|6sHkZr5%cv;J^qazsC4+vsQ3|H@6R|)L^{8oH{NU%*qlFb%(*E6 z^iEC3y4H5W$e$N#g)<2tNAXft(;)z|?itj7yoxUkDBxp0hR%;|CY2(INRTe*Vw{UY z?U;GXEcz-a7x&}O)L8&gro8>T+D5Q`-0*}2_XOIvk~rd1da?PVzq>JO3|?%^`E<-< z8i&;VTX!%HVRn1v@z1>dI5#JA>b?sZd;X4Pej^UUXCsXRcwzx+^2?w8S2hbS_SD0% zBPIA@uZrK<2s*;$#AijZ+lSziG+h)2OD~)`E>F1GJpw4mt`CId3s4dXQ2r(m?6D%u1>3ing=zRsgsUbBfcjn}VO zy0P>`>6|HX`mC03F zfXSTcH;Q_naobO;11`}$=vY!Jb@I#r?wS4SV{&T_{=U@K+WDy;Lb}%}BZu0+eIsJ& zL|{Ia$uG-XenLX-tEu-k4txX-vbeF^pHAedPpOL09HsGG<#Z96jga*7Usn020la&E z(|DzB1zj&Zbm?B3gus)3FX&55g8z>!tEy}w+)y_-I}%B!_*VDl$Z5ZSIQq9UVtIWQ za=%L%=>GTtz$g4NO|TR>uJ7i~_(6u+wYLF-5fj+p#rjCWG8=xIn|ZF>8v>WZN8F=I zmO%054~noj2}(13wVWf`r~NSC{o!$T*fic2H&CeKq4^xHy!Cm!M$@xrUkSRsrBl4| zOfovDwi_5dyy7lcRKwO-qj~wCL`b+i_Q}5_8lS3cpv251*iu*&Iv4su&JUgAgnvNG zM0*ksPaj@L6O(qG9fB9dZ3w^Kj0yOBe5uf#i>H5gyHT+qWph9T*5&E$PF zKg8xyoKx-s&E!t@jDrLOdA_oD^xvU%(fi_&6T{FM{e=CxDH)y_>m9sGVA#*2&ukip3qKbuu%sVW^O9V$zC_OFehIcrt^ZVt9E) zx%^N*W?3}Fw-^(&AFCeA?}57GN@-t5XusoI?;zLl|s=zl?_mDvdd%h>j%v23!i+J^AAclbVt3O@BBXV#5u0fBLT=8Kv2&yw z)|Cm$cEm;u)S1rtpxF!sM@GhF2PYwWv1@yrRWnjr0{WlnjlfPTX@!A>Hr!jfMd*2H z8g3A@pZ}PgMsmyTh}XD^UNz(znQrR;B)SG z;Hd?eP^~PK_)CJ@ExL!fFZZDo|GRPlbsB#&RU>^X4RKoBRK;;>6(5h)RrMv2AnoEW zf~{;n+HU8VsA^(Ve7Ea|{PFt~6h3jVzv*xzre0u+zRt9ar#CbXxk?UVe@}W_kt-G7 zucv-dIUS8hm~G#t@7$^=vA4EVO?V!TqFix=P%l37+a>zie-8HLN$ZdM)B}UPis#=G z9W?zo-9y1T2PZ$Dc@^+&3Ckz09u<`t0bi;j|5b$rkT406Ihwo zOs{DAhfeN{%7Q0n@ODYYrE|SKcr@y0wx3uZWUl9qB_x%B+y#E4ol0rgaFp0OaIhYw zCB%^3AL5!1`EnBJY1Vsoztkc|RFFXBTVbZi!(CUBBEaWP*sg!pO*#QoX2_RVghQlGg} z(EScL)KO$UWZwxq3a5pF-J3C5@aCB$-yy74XZkp^KUxYTe)VgtYXWRhlBF}dK||$p_l%Ao8f=cFnlNEBG`3czN(1b16q7W#<n`MHHaJQG36@FJE}Hetg3=|<(y4O|-eY~7dQk14jNmclb>IOO3= zo3q^g7?x!GWYdgq5I|8g6_!P8FsmJdjH^FvX3pm#Ya7I;UT`#y?gJQ@x?)IqrF zd0VI=j)o6!EslJC)(`odJ1!?LmO!~i%aG@zV#wP*U+}kM2?e>jnST{c0Nhs4l48rk zgz;_dAMUn8!()-QwzyKjEc43L`9@5zS5yxhNJrza^fR}%4ne6}mI34RFf1!w)cnZV z2n)@xUuc*XgY1NA6)Sr?T<>~!rlmX{Iu_r1*O<)VjP50iOBo~ZUAMe!@AM>yTGR!V z?5o7Bp9XpaPSGhcJY73MFdM|vf7@)RbzxvrEdJE8w;A$sidpSW57O=h()+Z7y_oD7 zb7iS+3QS_dR8fl;~J&Wp5lr?%8>+3wy6Oa~GS>JF1&8xO~u$jmm>`Pi+=M>S=pd+}GO~$ex4#2)U&yrudRbigI%4IP^ z53U944hDW(fHT)r{+`s2hn;e!0pfEc5Zm+#8oJwo`yc#H;Ys<8i(D(&$7s5w9M|mO z_NOzDZgE!NXK^dYUF}K!`KkqHdgt~l*v{h$&w%4Iw>;R&dPt0qhCfFL*(;L`tH7pM z55i;VaAwdZ$NO~w8l01`Z>pf+`;^U3Q>ha$(Y4)$!*2{T@4qjeq3eRZFVm$G_-0Y{ z_{$!*n@i}+GDF8m--pBHd$akr*Ff5*Q#~eJ{m|>w_PN1+0XPmtM|FQ5fbbSyX`!_t zxbm#}yFySUq~;7-3=J$}QU1P6r$hZP&#u>(Av6IM?EjrSvDOd#A*u1A%8U4D=)yhe z#wPsc+cz&l!4#M_ z1sC3(ySApjh_{dII<9h}2};tAq8q zY-XBY_FrxGz&rzCk!vDzkl2gmHlOK=+6KV);%?_g8jrm{*Um4Vw+k8)qrcZgQXte* zL31p&5vhWW$#EZtAY5^0%VI_^cy%BCmrUA3ZSpF^*-w4w&(P#y$eD;2N8julacM%8 z7KP<8Svp1A*CwAbzeVHWC;JlY)JE_n{jFE$8i$~2;pl3u%&{s9eK!TgskT>b34J+QaVPnyb9Bva?Mng zETC;hbCHwzD5N~?XZzYU0-tn_G};&}!q%jW`RdVf^!X%}qjtO*Uj`Q%oK9vSgeQTW z)hGpezFd#z&Rl^bq0n%Za|Lp$ukS3`9D!^jKOd_Q8h&@tPF_zRfK=!GS$TsLJRj-r zaU&`d{T2k`RD$G^hAJ-mU$xTCdG}=DZuMUyzgQ9M@ z{P45x?IBO_J9XrJ*M~`zI`h2rirXlD6DYP*EX)VPqbt&4w0t_RDaHbZQ_!BMQS&#Q zUeRm+cOo@!091}Jv@5w$VLfbpf35cd=50z8zH1mq-lAG-dm;snDz~l~_>oX3;ip~8 z){kg-W|iBGSPH}`)mZbSb?iFyi=Z2q1NFL}UW75zd`m~?8Y5(e(1%o}?t7^LMP8bp z*m-##o^gPH>I-^BX?d|6{rMSet+F4JEg!>*hWVXds}%UUr1?N0VisKz?LAveCorPu z;};*%HoW!bS$OdFWh8BljXzG)dzaU|wI(-u!A<1z{a+lju-9&|wAXDB#SfpzyEs6F zSK*?L(!G7aXV*7%-hK{;Uu!MAlU+uIhUY;m!`j zgJ`HG-B>L&j#7`R-k;K*!ipxXgt+D<*t~x8IZrc_qP92l*PENeaF@r5?eo3~-1c9^ za#I-+vrkIkn~0)97y8q`$ur;5LY?Lyzf47n<;{^}f3~Pe`D+Feg+@w7Uzl4kH9X#R{qe&o3PgTRemS+b z5SleVT)Do5=I0t=li<&31bd^uhut08;D~_t){)o4pqW>^KuxMey7J$V(-M6!v1WbO z$7UKicJ_K-{W^vLlRd#oN<`Qm+h0jGXom2`uQ68we!|;9GE&jrp6;mM~04ll#T(52FGklrZ5)N`X1H?MxZv2l15!U9?@uy{-% zcShF|g)oo2RR-lJY#M<6R&;+_KZD{H)IR@Qb{))vBe#tjPU52vmB%6_#z1F>$l78t z2`y5-c>VC8()@yni=paliq;!hJ5#u3;nyy+IriHt=-JLzd@-^Ry{~;5wr3y55Ejh? zmHYZ3XXlm2lQ({#*gmTWURobj;cI8yk<^8;XO|qWZMNaq`1h>MjU_zO7Z_6Au!!nX zpRaB|--qhsI^&jn>7Aj&vG9BITcnGDUp$vU1a^gju||5DspCzLNnB_n(coR zIfbU06`MxiJo~*aU7x)B(A_c@mswhF*3R8 z<45x;l%1d)RPF;2UAe8&Gi_jC#dbC;ZW!OzA9qeGnZ{{>8IHSmT9NDyGFla*7`I&i zXUgwk2Gil&xJRngDFn_FY+KWs;_Ya_|w>gpn1`z^`Cbz}su*BF|ws;|IY z*^X~>!wiJ%s4EGliwE$@#E0$u$~4~a)3u$%=i^w)7Ud_C+k<7iQEZ*(M^N_9lP48} zt*=Lj|1OW~8A<<2CH2t{h0Kpv`TF@_*%`#^Ha# zpP%jP?gs%yE@FZC5Ip02tI76}f|ORiY%AVTOuMaG=#&%*dKFc@Vt*SkrY&#kKvg>& z8Ztzy_3v15L+(*%-~dpod?z;@7zs~bH1iGEP{DO=tz23C7j`>Im%XkJq{U0{PjdDQ z>}P5cDLUVRtuL?4Wt<|RZ4S?|qNP>XT#)U%bF&$)sL?&Wpg4|!?_=c-MoeQyXqwML z@&Fht-%p36IjF5L1jDK|Shnrs{ZKFrPi^8_uVr?Rw3`k$zizj zpX8{eGHsq}lfF+ABado_hj>w6tWYFlZ39t}ptpJk;jFn@0DZO@?; zG&I|!Cs7(f?n$;<1OFTfY~A(n^28kcXr=hE^wJUXudJFr^;*JF8?sgVm0?^yka0aJ za2{^V>Z^^ui@+(Fw|#bhMqy^{qPf;)KlUG554=%D(@T~lf+7`5;TKOJy*A4Vgx*Bz zrFTWRE7CTWU+D*4rSd3WnjXX-Gy29i)sw*AyN$$d+78Fek6w5b--=Ltzbi3(9i05X zD9rAkhAMLF?cdX6P=dFw0zC$B{g3R;GJX<<&z6?z_VmNV$;HMxGcp+Dh2BxQG>ArL zKItDjQVU_azjRy3AMxy9h?ITc2$IAr`SdbIkuJL-q3-1hFmLL(=Hv~amV9*e_wowN zYr8zLlVu35(^+WJJ?w+(V4=KkagA`o{ovt#&QZLmL?q81p@Ot{s4vNA1nySH?T9!- zr?>&9;;Mpb!DK-qvN&b{@657rN*`Eu#IjFzuCw*y1(>R6C)19FZZ4ca$7 z1hEa)YRtC=(QbT-GC)}aKeI0pLZks?c}HJx?*;{4UrJd@jqAinp${fM9?qcu8;2^d zuLVeyKN{j|Oax`=iw6s0n=yAH#aZ`#F=VCr%#1XZ!n>;lUve`>@Rp$HIatJzJ1D_jR9N*Am3dGm^7&=fI)To?VMJFS5CjpZ)gL5{`gkeoXN?q~^+Q zHDIm;x=BrLv}I6~e&8#|o7n(5r(92KYji>1=lb5>l1V%$9-~&XYn3*|*>?Ed*D6}+ixQY-up-JD#19DcW4EkPbPm2x1=Dc>c6aW&3(Y2G&w)zLWa${ z`>NjFUC>qUZl51tk7sn=j?0!+fyZ0Ne+gJhtH&Pdk-R0)CYZ4qZTl)IAU=w zn~c@pj_{=@w?b!tJ>BSfJy>1ipHlwN4mt-G*j(HDVc)L!&nfq&a7I(+d-b+ba3kFL zSsu`Zyk|^Fcj!jIj%ed9!`KKNy}s2B0k!z&^Ygg%(J7dfovCc=ErV1+PT~5KRP=K$ zsr@Ur1T1?VFKzR0MWy62!|!Lt;6n4u)1K^2XvRAgIiXF)Umtzq_44N6_a$pK)AB~x zS8vv&Ur|Au59;rFr_>Al`Jla)@2uAO;qyH(<9lz{c`GuQHQxO6 zKClsjFQzm}M0UZo1E(z5!h2}`b2fS9&&yDvL*Mznj@FNjEbk4b&3!oNc0AF)k_kMI zW}7UW$*7?z%gkcEh?cK~b@4<49uTgv{q%{EAh;woeVE2$iC&RSj?P`hdw&_CLWe3b zDYspHpK%-NPltlvQWt#jimHFsPJ;d$etcj44M2XTrSU%=651yXjIZDtR45q{^CJdf zd)LhpzJzYLGF85+UB3j|cNqJsJI(^1oD>W{G#L^Duc8^YBS^ecNhoVWiBHHLc z5=#T8mD9W8C;IVG^O)q^o)Y+~?m#qtMS^#Gw=XO_tU}KJeAF)l4S;gM!OIy7S_v$g4wd zdA}Z_W%9m@&D3LVns{3y{{%F!e~Gg_*$I0eetO2SI)d(t!Y#UY8t}Z>$7e=+H=v*H z_fUIO82B&ql2sL2@X@8Xg9DD!aE|U)^QG<~5OjI}FScM5MB{czCsLYm?@g%?f?Fr1 z3&;ghvs!UP#6L-&qY13FrT!aY9RQb4s<~VEitxORfM?6Z03Io8W6s!;0{K^`D2!YT ziiS4jadu0U7&!&p?zFjuIm0Mhn^htnN!!x@<=!InWb5>`tgdW9oq)oea=)oT(PiMxkQsnX%@5IoDC(bj-tvr zDQ#w*0K6u4VSaj<1kY`%Z9}Umq%uPlDffgX;>$hOd+V3R@reuHG@oKW6kllmx}Tf) zd#tAvyZojxCCG%n_gpS~nm8epG+Y9f%jOwWe&aZ9WR#g6x{CY(Z8FVeJ-BVxf91ng zKarc-OjP0?!f*2GjnC8{qKG|H{eFuD>>BYhZaqbu&lVl|HF>sS-vzDfH9ccs{+%^^ zLv|9ia!zC)_H2f|S{*OX)Q-Y-b1VPct9dk?_;td_p$`kfH`woOSA%Jv#TfPGA{wWQ z$ommF2=R9LcjjCvVBYkCA}1sXmLFXqw?6VkN%tGKI8V)CNxc+Up6Y?jKgvM`j_tVH zRyP}NcS7;9;?Jhz1Hk$6Jm2o~WN>lrXu8*tgR+XxuJfiupkCzXM%8+v|IIw@)6De> zjVTlk=@ETUy*_z=Ey8!tZQWsA(?0_%P4_oHySL%#4(5BdMnr#1f0*0oR6CkPeJ-fM zIr#C^x8Keo9EE;Z9D8cD23M=Uz7@YfE%hcj?P%I?Cq7}KkGx$wgUs~5L!!>6;|N1M z|9^e;m?xC<_^M?k@XUPSKFK%+&Wr(K-w1zP&hLVu*{@z~yU*NrhS1q0Eu?=LH&>x@ zTK+ZlZWIDmUZU%lT+Dv`C*D({4YoHOKF_$$V!3=@{t2NLEZB8PKQD;r zD@&IqM7NHCW}$+bZEY`}d~9Z-;xGikUv97b_qGp@-S?n)mRSx~Y(edgQRB$`V_;j4 zy&7b0Y`s-g?8C2fYdjBIm%+@>Z0~Q{MLZ>bQUB;0B41;=@5%0fX1HwHDBR7Tfi8Dq zx`XVO(SM3U#U^nM0zVY1x2cSy?x9{g`}TYom>&i05i%-XJtXuyfQ%(`6b8m|Q^xwz1D!Hd8w}@r^?W!PKS|3@!tcUaaek-hDf2$QM zZFD!Dk`q{!{RcK(=Vl~@F7jkx1Hz}w%(@Ml(8(wQa*Kyf&)~fiwRxRHH2R%cJ1X?Yw>m1 zFlr*ds;Ys1lA`W53G?7mDt)vCVk5p0E=3F z@1i5?cYjS@sRNhSuS04i#VO{X($kmpT&==8MOY=`vPMPeY}m;`+>y zI;be@2>zP3fFFODI4ZErVVPgX#}^?ZpjCW~?`}{V-j!bT6ZpId!@ojJ&&JGwgSy?^ zp0RrLV6O_fajyrMd#+aI5q$vLc6EO9gPo`qT~2wLXA-YA?>^e!o{PUfP6gW)_Myo? zl^>obTHzOSeC~g|b5OrVw#nSHiBBT6&i!|dQfi;=qlbH@8-P|`q2UIRKYF)uA;f#2 z2I&WS&$>7cK&ym@Em?mE1`cz2-BYTA$#_<|2VyHlVA{!wWaXA`<+f^4otisuacM?Gtz`fU{NP~5mr_hbN-Lp!yKq>28W z8-s29lO_zT)X@3f^bg*${S{EDo<#pJYmCAtx`H_9QRD#e?K-wqgK1- z-v{SbHmthO_rZu_=)cwLi%=U-v2%`CPaxu!c;a3gXq=oB8?IOa#rJWmwBBuC_d9xr zHKD^B_3xjbBzl2Dd-Sfj#`j>?yAqMV%NwvB8YDEuR}DV`4Y}r&W+Cj-(Y_`9CQ$m| zTA&d-fF=QNKR)|JMJj3jbM-*(5I8nVy8V*Lfaks@KUbv}z~av`mojM`rryNHCDZ%? zI*Z^k!@??<9=tmytd|0r&Afl@mpgDin(qtezc%2}ne8)WUIR^o&Rh?P4aDakuyAD+ zHt2%6duV^cP$fswhnF+pV3$59R5K0dZ75aL;)l@CWPyX~aS7aHzvtjpwG8WPkxOL# zdHlW^V?KX#7TOdKhe|Uyz+|S2rq`Qg^tg17a;J7b@UhlJ?cQ4nYR7a__T_EiSA{=K z;vt=Ivix=K-h~ll9lrYXw*MU1|9BpG*K`S#8S(9kNDUm|5Mp22(}tMIE$^pC2JN%{ zw>tLqquNQ8*s#ZcadWqyFWGGtj#BSc4!>6k@|(H!3-_ilfZ|W>PBDT9!NuZ9{!F|- z9n;~TCg<_#E6M%!!`&D$#6CeaQh{ag zxFtCY(R?@I#O9x)!5iHW`Bf~oXg~Y~%iT-b&Wy`*FSANkF9xXv9OM&Z_}sh=H-c-1xL#h~60YCiQGJC@Q0-`-_ulM6;c`n~+H z+o>tsy49piDLx5uvF1T{X!_7|s3w#B_Xw;AtT3c~?!u$~toH>OIA|wX# zq%9{5p~a~>NyVUHaR1UD(Mz|0YJwY%end{dgE6FG>+&#MKe}dc@JS8o+eQSgL}#Fb z(uV8^S15|ww|_0XO-Hh>x62NtAatC{dnZp;5`CVU9Ni!L%FsvEqWfj@68u=~Ja9yG z1p@hI1QyRZ#_7@Mxi_)kh)=o)t$I z=ZUYt`P4~Y@y%ge+tcn}(A0%FW|wn%wb zzRB0?3W@pi*__ng&`-tq@<#2bY}^cZ>)hQP5kceugt?DZcF!Q4{=Ou`mu;Zl=J0`S zsswuv{)0)8CH%CnOS+E8-6Y#<(-h|<16lm>Rv7&TSY^;Dgm#WX@YydXJt-GZdrUek ztiA>2zp9yhJJf}JMxDJBqp8@)uVx=D`3BrKpo+F9dF99C$j8;TWc#&U;a@4>?Ij(sD?uhKhMYN zBP-vz-Z#NzHM!r(Z#Td>Nmt5o-z>o?&1jO}A)GGoG5>qn@Kq zqu$5fxu!GQx4NrXdcVB)8_gtR4DNOADPGX~v(=_Md5N3=D z5dI6vVYOTKiQ;F|9&G zb1t111V7TKMZAZQvA?YhCF{pL$q>qi!2wco1xhLEm zhbv9D!xGE-p_k)a!wg?9@-M9#9`}fV(Z-P6y;ylcm}lZ#J1z_UlJwx9BsG_(48)jB!>2cT$M$%x;<@j3 z&q99-(%l_$YrtrXdfl%M5+mmIiUvTcoxRqD(ytt4SCl@vpP7@ zKK<9={4Dm0-tT1;NyGV%Ldqmv!iW24lKahe9n^yICy%RYY|C z{NCa$`0|$w-8NAdJ`x<86wBCACRTdVVs5kOJm)+yPyX%P#X$5iBT7wNJ}+Q0Rr`AM zj#`x9oxY<|)&*z#-Wtyd?T|W7NlAWwZVJY_-x~B9k|Bp*b<#NB3wjuq?7x^SA?2sL z{{l%HxJ9e=vml@VY>sO>>)Xu0(>+$Tchl$L*PZh>7xXEmUU&a|XdAkN^xpfJ&h&A@ z*Hw$F3`y94c%jmZ_sygZ4j^R;USKQ19_iqO9_ z-dY?NW@tIyyNL2eM-{xhlhJdw;HE|xrBq<}_PfhnGf?cd+ivD6(W|`hgQBHx4t&?P z1y>|z&^7nzZn|i~&r3T0{@l}ONIlQ|FV<-QO(jLQ*9pD%(S&KY??f6XUdW7Iq4&Uz z)W!GacYE-9`!%cfj8%~6eYDAVqZb|xKZ_nRCG>mwYZ{#oOCjLDtKWJPmr$*Qi=tv| z5S>`+e1$&`LR1*LdGN{rD23=+FiCFWE7{jQoTPDBme~0nZBnSqAVWDd4U zLLc>aGhH05hcyAq-?lLXw=D1OS7+@7JX5S%Yg0Fl+1WkrmrfO9m8qL%-_aFxKPec| zKwgF^7imKmf-@4SGq;PT@OaDl4^0sn`R*r_eb4R(tJ27Z_ZPargHCufmEfeb zJj`Ir=v~EA>W6|!T11{RY%?%|(8&Xf<^sHYlYm;~49WQY5-c8!*%9}p8saim?p3bm z;m!a-E$AG?jlJ;;-oM+RZIjM=tiBJp*U$4k<)b9o?F^w-^J@i~6VkUl>_*>_=L~#td8lyt@|_!Ei+Eh->t>nJBIrEP%lxKY1`osk zy$O5W3h}W9I#Mf3cqC(xc}K82@|@XE|F*mc3yqs}S6SELRI*IMvAAj6W|qDFXcE7L`t89oRJ>79FsY0YCqDUevHud zY-;pnQ=5V9F`4xR@%sqKy^K>H>4ye|CRc$Sy>Nq;vb?mzBx)x~gjsmC zn0^kg-H7uywJU>XEj87OM4t77h>}l0*Z}VI)lFPHN8|&`KPfv$6S@qpDQui3<5-K( zS@Dui)P657dystu>u%)tdKnTvh!3rx=wcssI?(H1^T~%X-*uHQ0+l#gT+qtm*#yrO zHuq=MPvXZfN4j`sLhj#bw2Q z(KyXPusaqzRr|6R)>!$XnxD^s)9;)Fqx^u$Fl;;f8M6}y{ZM7rQ_D7)mbpct@K^P zvlA|=r~ir}bmO+`KU4TWPaxOFU96soL>}kU`jfiyQ8WzgT;?ojg&yzkUdE1Dc!k1d zpK3z~)X8t_Z*rOAg~Y{)22wtZJU+}MJRJve)JjhmJ-Sd<_u_$@2YNA;y2AfvZVx`A zi3mu0yNE|T{KS9g65q*0R zVapRjL2;>dn3fiJb%$vlQ0}rx-*tWl<5h&@B%F8HuQwM@DOKY}DhnlLcxPCL_Qj_Wbh*T^ zi;r;-n_CuN)Eg3ha>b|Ze#3rCaeb<&XXqqdILDr_QKNN5&I)&fqm z+>Ea`>BfodGlc?ETeyiuCFhwaq)z-ED_gwq500_x{-PL-N2i{>kIh-jVL+?F5SfY~ zkCsZwnV3KF-+6H>n0^Cx-9@tZ`U1w^n;!BY6X%3`i5+u~rg1~a;5a|_La{^CGRwJj zC}@BBN5ORrg^PyE=r1qBnR80NW(^uKQHNesZG_OT%};AP`}V`RQza=|Eo1nkPC!xm zbT^tEy4=u5Ka1DZ9iQn7CPA^1&gCB=Q^;7gnoN`4k9=SFc<%mO1ZvA|aZ@QWWPj4W zAorT!v^*_;vtN%|>V=_IvI2P&)OVUy88}l*ndR+Hm)}1R?UrmkIxjl$-WOgb+QW4i zL_-tUL__G`=L^odoz90HdG}Y0Z2FMd_Wfeck#|_NVZfajk%&jz|1*22yo?L;>hzx! zMxm9;)m)No8{HDG&I|DMphA~kobc}rJp9U2XMKJIIc>UIqEmffYP(3#WML3oV^zH$ zKNy0G8z+i9XJ_m2`p$vdr#Qu6gNEqObi_?nMR91itX> zYu51^gstVu2kBEZQi}0{ryQI}ARtJcc;NXMvKo4n#s6J|>z97$U0@zU;T9fY9<3Gh zU{;!WC$R-#)gc^d$xE1>JT}=ao(V4=SU5LUwW0Bz9?B8wVGO7i(Ul2Ghrk<6m$e25 z;2had#hCD~n`2^fMLi=yp&=_%;A}E##9aK$NBrJW=NUb_+t@ z*XUoxFhV=C0Cr|~Cfq!uQ7rPm>eR+oq@IfO70xMx?*&E{GELKP_Vd5&xcCiR(vj`E zBv1?WIj_Heq*=y1v$_uTm<+IfTM?v@v<2ouLvj9Rs)+o{%tB+z9CF)Pc#ZAfM9sX> zHYxQ@FwzjB4xAc?!lN6G`$LA294_kGYSoJy*P2eazUxBKkX6YSid`T%m|57AwuJLx z`_EeuJ@fSCJNMMIh9OfWOs&Ld7T+$ut{)0UJEm5 z^g$}sH;4NSqd?!Z8dty98SXh8sW5CM_$fcV7&NkeAy>dqk89s5up0bPN)aSsji6fK z!ij9G5!qkd@6nIJ@}lpJ9?YQ+zi|!Y#R@3X;xEZ|ngPmfpMJ*6Ur_MTf4CNmVG|27TNy;3ZG4e2Dk!8W@@}qkVP|nm^V=GTv&& z6qS0XM!qR<=B!l-W}C$L15H03D8*yPu~&+M2Zw-lXrkrs^$p~HXnvV($1Fa}tbbJT zhsec^Gm3GPu7KF?`)Z}v;&G5Y*}g}A8A|rc2FkI|KwHPjLvMP=L2EIWUSfI?uO}W+ z(1|2tFb#?CbHN5^)crPKp6Wn157Y0pXJ*iTIoSQE=q%bZsxWG+l|bHiowCxa3lQbK zbcu((7nPXH6e6EdN|nEEH@r65j!nrq2Ut9ZFx_bK1<#!=bglJyLeJF$2jhJFstMoj zDeqIOmbzZJcC{+v?NToqpCNshk?BAh=AzgmETdq>s&eJ9?K~#mW0?E?q!y3L-znke zngRW_{CZW|e<Wc0E=a1ohwdsET_xLq#u*$@cL%%zo1UxlN=9c6?z@__RSy`u-&J znu%gFG;TF&>h3Iu2ak>0mYVxOyriY&IAbT;W*xO+%IbsJQ&s7aT_upPFK~>{p2%^f zeiSSV-hdy@e7*n0q~p$xgZnIkhcHWosw-!q6t&pxo5d!FKx@@rmLhis&seC-ePS6z zg&E3uW8fTM_#Vpo(>8mdioUe;&V*BJ139i}6S7-TN z6~O3#P*+#UBDk(!XrVJ$$9y`5MTzSZ1Q*RItCG-b5*Jh^E%l`rBpLGdgJ1jnm`v?C4>uoZJmI? zu>Y=tNFCDsr4qDaZv_Z3U#Y)d348??t5V)3!O7r%@1I>>MaKgt1C{QUL)cXKc}t5< ze7sLycBZiz{Z9QjVr5u~7W<*G`g%9AvohJ8vmS+I1GBoZTZJH?Zrrq&u^0Ez@)xlY<-n;5Q_e*4bHq#iERw4d4JBQ;73~LYcfY#p=-?Ym9*;u+O%gp zcdl%Mr0Jx3DzO>pl6heiLrn(h{dxSO zYT(7`RI|fRUP1r;6M=1qMj*Y2QHs*zDfZ@q4VUN7 z;Hh~S-`tm?!AWqY#W{r{U-!cSk8dl2r6E8imatdqKqKDJybxGA&<0CuR6+5Zf3Vc_ zyqosmC=RwBQ=ceUhl_C^M0XMAF$Kr}>V~b`uzJ%z{YTs=PVKSze)%UEmL66t{$y`L zHlu!8d0A?bUsKc5BR)x>*C`sZnB9kN<1a@Kb}VCA;qMzM{iDF}DYN0C711l93(2OF zC3tu?DigaRXCV3~AKI>_T$_02iBT2M<81` zww*m-6qMVmvK$r1a|KG7RBIzt+XVyeYgX z=a4{KNK1NfB6UWIasnMXDc>v+=L8WWW#?_5d~~l1KEd2QjG68URx6sF=*1ALm@h$w zz4sZf`IwJD(Se*_+<#KB(o5*t6~0E?o;f^H5;l*PK8cMUM2@svOJiU7-6FVP?NxtV zIv-*rewSO^qL30eG<@x1cp|D8rE_HxIZNHoM=G2OYO%)hP{IV;EM$0Xe;aTna+s`k z@1K6`!-xC$5?!wmzUZrWk~)ShL~dqHSGj@EpCp#E?H*7_T^&r&u66Fm(CT>MG}T5> zJ)-S}#zg*C@_53w)&x%F`yN-m?T6I1pN_T;Y{H_#KfPJvyxG>ncK-o!{%7-)DBoeZ zg?C-9UgB}B0!Cw|YRbVmeDCema!_v-L$3N|EAo?xoZU{1@>AJ(vc5FM*L4wN)5dR( z94~;==4Fbjwgku5wIWTIBL*%9PKdwSLGZyh-!)JQ&O_XpYt8+)t5D$2%pb36Dw1;U zVIS-BsklZA_F__#NF&=JL{ry=11*yjl#Vr+L0vyx{Cbhtzi+yw6TXmn!gi;P^BTTQ ztR4+>q)UPg}1o1X6Lh_wCb&fPR9nQ8fQEjE*E{tupE_5DBGsepk{e zlz_AC6|CG!2lb|b36Y2@=wo)C-#Sk0%X{{#5tB3t6Mphj+~|Ohf16{^bu9qj*qM&i z-g!KK_+TW1J0+=~-L;n--UOc2w^^1c`rz(LC*PgU8FbOJ*O(E?gx(Z$R<0m4w;0X;=o;bK%oGN1z(GOKRSz!dj;jJ)0)sxbdzN? zpask>Uz|&R)Qf?|tBM=r2-Bp42M4_dAh;?2qJ!mcxO+wJZ1mSo^a$g!I5JA`Bh%z0 zqP}$_kG5<;S5zmYv#VMh`QD0-bXQMy>J#@yRC`&yEboVcLcCTYI)EX4GvuJlLts^` zHN(NaiisM}SsYfE(96}8la_xO8G>opF0^)G=4QM5A%cT>F!yvd5q4)s8+-4o)D20hidRAr}-z2ErWT^^W9p6zt`U^TYSHA5O<9lL>Mr2<4F1|Z|VGO zR97*eblaVPcUSP4UKf=VJt=G)QF zHYMovNH$78X{Y=b-VNkEWtU79n}G9-=>tDPe~l@4FKake4m~%@A9YCfLxS>i$(>q# zNXkkXKIugz6+y=p%1P{p;&%V)jNbhM*4H7+Rr_l|`SvF3qpmUdR;D5*TxOS{Z(bCq!o1LB-=c50B3{Am)UBLlN zSB&VY;2A>BF8*{*{w}y;cRhm3ZxmZMDmwgZ%YZlSkZ8AgJF+;REmC>dk38!4(|-S$ zBKT8-&U#RVeF3NZ9q0*;#%Y;fp}YU!z5}DJPwVP&RuwMpzB`Hhl~!C0=EQqFJEah( zI}axmpV1l&wPS>%2w#KuG}dZJe6b?>c*1%0*$g%tP!+${#{O&wp43>VsE-YyT zKi;1ct8gS?$UZ=#4`X-Kocp$C720$h$Ua2>V?=9xa$=6)*B&~iR5&{TVm6H<#`V-v ze;Y~70aP8RB06qf`l1S!+;5XsdOPt*2$gWb_YIU#whNfJUIG@ycQ|iY{{yoLc^1z% zOJKYiUhFF~13QJj1eK-sAV)@!U&;{*(vz#w?-K{+Ai;_u2=uyOA~;7*-r_G>g`d4E zM|__SUVPx&&Nq(s7dfh<2w#AeQR9GyOcSm=o`}q8?7_=tX`a45+5_GqJ|+_nr}1!p zS8a;I0*)NW{1>=u0c}eHy$>9X!4)&|6q=E6E&i6f^WQ48Z;`nwHAMJ3(|c9DzydGzjyC?VtLbiF-|||1&WobmEt_JbXSZ6>!7$^9v^;cR8+WKP__l8~pY*r#(>f z6%`EX4SmZ#qbVK>;8kb_kzkQ;)S5((Nh9-EYSA`mR&?_S3=>@8_Px7Tvzqak(M#nZ zL2tNuVfPNvibhCeGp>B;vkA%Tx*rl8W|7;LzEDt{@Z&f$>!>GcQO-N$O6vDkn7`U? zbx~m$>2{&-Z!PN)SNq*NZ!f{uA_2@^n}#G(l9rlHG&qD`pq{jxK&uv6 z6&kB4{AvE5CS6k{%pH;PR%99lD&`t;B5y6!Ezv0_6cYTC+?~`4B2*;rg`O{3B5lNZ zCbrr>EDIcNZOz?gAtTE|u6|(F3^=Z)#z|Ul;pBa4L5s8Fc$z2H?&)MRkgoV&=NDds z%xGhg;G18N`qFDQ9XmpwI~4lN!gCn;)kKpw?a1gDQ575BIEX2ij*>344&lBt>}zfl z8(@BF+&bpS0MZXf$E-b{!s^>bGi>!$ApGdy67|T?R9tOV3G@ z_ii4%bd^FEwc21_)_{geY!PTK6@)pJ5L}bLw;7&uZ{e!Vw4aOF8lFs}>(M;EPW129 z|CSS+Yq8|ptSP>E`227*9}mIL;C=I=L`0t8Df_!q&wKmAgS%SXr(hAnv~R@-&n)0^ z>eAa`b3|TYO#X-du5tV==5@z7d={fCKRvJ`@)tTkDei2wPv8R6=Z8o3O=E@}yQ&xc zJZ$y%HOH)zF}`oWCz;?VpFJqA{6QxQez|UD6*&&0oDLU{i~|`QW4j}|7W1%gx^HBF zeGS#5pJ~*Zt^=@;_d0&f#+Ar}?!1py;ECR4W9FhEd~x&Hu};}>T)(}K!snnJwi7klD86BFZjeuKypze)a2^Ya3DPq1uD6FI>0T}4WAV(s{k@{+7pRSzbx zF_dE1H1h6|vXLVC=LsIaUw$Y0`>nK3lDZ&!Byzn4WNf~6*w3tCUTNxp%?fcJhlgCz3GY-iT^4oo zY@J8Bnb5M>BW2K)Q>nkfP>x~!jaO=-SKv$;S?_FlEz+x>as5QRzq%j(B)xSV1eR-u zZtYVcgVQGcDef~wpOTMmjCtQEXuh4&pOS9IUbXv5k4gytx3f;H{W}?~lG)5Wrd!eQ zZ@){G`XG8W#`>RiAICEt>AFdx8<5$?M|FB&20AKst91OE2D6|i5+W6);NcPXS4+AK zUg%Rx#yBm5_GB`y?NOU#GM+OMa#yH*LFrJw4# zt7GuuS*L#6#SM6Jb6T)beg=(}vM1UbMlp9^LUY!mR{Y{JIVAOg=M-{ca}uL`1y*In2~m%k!gR$d$JjsmWNfTh`!*mMPS6_^cbEn za(TS_5G6_4?$_;_))YA7%6@%2d>yqrk`8>ACD!*&tFoc)I!GOJ^7yVn%yW*DJzS~- zz%0=0sr{h_2F!M;xNtXPhU}$3bIQwrL2Oxf>82n=$~|(0{Ewpm+(`%5k4H^{H)76 z32gh>WGubc(TJ!1TrROr!!yR)4w(}7g5~ni7(Jgsjz_DX7Ay$jt+BTqUpmF~;^pTn`pBvtz8FzYB<%C#K2qnZjlwt;^p^h?ou;-R(v@d%W< zkag}Oe$;>o|P>Y|zS+32eIh7^sgC_e%XNcdtA^ z#!m)ZQ92Jv`0?a#wv%VqVAla}U9plzVDN9I8&sMDUgIs7pF~dL!C}U`oq1GJr)ckT zT8I$3Y4-8XM@n1hTgO=Za#uO*DvH?kyk!<>JvMoq1m{t#zjn_383pO9@{^b+vJ+Tv zeGf^ML;{(Oh-4MwJl$`#uuXA>m^+0H8rLIRai<^yW!%LU%p5Itd$)^1s^W>#kAGb) zn1696n)3tQYN}|3~QiABVh30^&)QWx*>6i z&>1(SPO zG9z%M!caYZvK-wGa_wzB+K4m!&0pT^nZ)ydp0gEocVO}RQKwz5xma3SFQ-3F^a3(& zm|PN^1?3QLo#nPb7}5_T7Zr_Rv+uIUL)|PKV(62ee_0E$_fAY{Z%iV&y|0S5s~e5F zu8jJzEP-?OSswEfZBWAgOVl)hQfkrm(M~F=F$}NuB_)@3!mH+6FFq#p-y z=I2r9MgN^jVqVZ`o%>->wu&?Q!~Ew@6M3<>3&L6BgCI#MFDJA#hmt+_dcq7kh@A6b zscUIBAoQ(JCjz_pFufofgRuSI9E{xF_8ou0< zj^~XkKI~pAL()=jt)2b^D1^0@t|!&OZQi8GMR(%+bl~Q==@G9LXa$Eh!9$_9HepCqCM?Il8C<55{^N&nEDw&KwrlBwwgi@6 z756vL^OBUruzxzvmM4y5R2{q=r})`P%&%j2e->lh2AunIvRQGe1FDZ}+8P)n<5yX^ zzQvX@B9EA6B@s3N%Z7@UCkE=lsN(wB82)M^S7=nWdSVRo^p-!?QxU#jrHr2wsTZy3 z=|rX9^$>Zx^Wh6WOK^ocHRNA4(Gyz!DZTZ94D~xt9e6f31JLX3`dDiU?wG^P1tO1i zIP8I#I70*K&u{;i;xdFcM9jqfjC0V|{8tFG$TBke(ouF3-}$tm)CbYgi{q<9oewv2V@XBKF^iUhfO|JD1RiKUr>M8KYu^`XVk==tC1hzv)ka3r zjyEf!lzlM$*Gy^P+ziyKQ`E0YtpN4Gs+856OQ_0MU!irn5*LmY_Z}LXg(+ckzPUX^ zU{Lci&q?_!`X~K*Zj?_Um3oZocs|WIdbrT`=}OOHd%TR`C!+s$g7MOscR^E7aOb2r z#p@X;XP0f_G8^UvDqEqoCvH(0a~1 zyuNy6%<5Y&eq%|eDIs)V!`i{^b(w4!omx&kz&H(KrG1B`-p`@%`LDr0vk-)Rp2>D* zzUneC#ZKI5f4ExY=LFpPka5@f8{uoZOR;3MOo4gF-nq3$%|Pu# z7Ca?2fe%8<8r9Ropi?t${2bQ9dmT%C9>H`BzW1NsE~gp{J{IH8*xCutIX!%6tN?DlAk@Fz=pnaqbFEA+Q% zZ$>?4m$&Et2(2UTJF?u)^)AMmuLg{5FB`zU=I=#gqF>+Ivm9$D-vzuCh0soioKeeay4#*M$A9^keiIZE++S8>DJ)h zI=hs7o*xi@f6q#!)HJ;FQK=j%rI30$A!>FilY*4r)oT`}-~`9-T6a1n<-xW3Q^xt_ z6X0@YuV=bF;gc)y(Of@12dN>4_zvYKqWBGI>nYk1oKBW>zs7BeS7q}8FUweAE~BVe zk?<&z+C~*r2))LT;!vxh>oSaqYfk>!)dHM{Ue{XI&SCfOFS8QK6|g-XVH>M63%&iR z=N_M$2JxX4JSe{g^27huvw+}?aRlXOTJ&OJzej-HrAg=y{I%#TGzlajL90sXbsV=@ z^mSyeK*x%Gl)SuM_^Cd`DWk0h1I%~*Zuja(Zi)I~Bh7k<;{DjoYBdE+okDqX)Ps0p zflG14pOR#oMYW@rXK<4G<>SU3Z79en-dKX}5){n|*ZsnmX>o3NM11qS%?Xq|OjY)i!Q#=uGMMYYX_-fu!p8)SdgR+fU7g3!h zvgT9qIJB=;FS%8xLrK2>x`o*!>{dF&e7Jz%V5$2wNR>>0a=8aNE|SP=-fiW>y;vPkZFhU;GW_TN{mgA$TcolvEw{Pf G2mb?DtmEqd literal 0 HcmV?d00001 diff --git a/tests/test_logging.py b/tests/test_logging.py index dc74c78ed..cf99b8c35 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -20,5 +20,4 @@ def test_infer_experiment_git_root(): assert pathlib.Path(root).exists() repo = Repo(root) assert repo.working_dir == root - print(root, __file__) assert pathlib.Path(__file__).is_relative_to(root), f"{__file__} is not relative to {root}" diff --git a/tests/test_mpt.py b/tests/test_mpt.py index 8b1fd2bb9..8f3384c3d 100644 --- a/tests/test_mpt.py +++ b/tests/test_mpt.py @@ -10,7 +10,7 @@ from levanter.models.mpt import MptConfig, MptLmHeadModel from levanter.utils.tree_utils import inference_mode -from test_utils import check_load_config, check_model_works_with_seqlen, parameterize_with_configs, skip_if_no_torch +from test_utils import check_model_works_with_seqlen, skip_if_no_torch @pytest.mark.skip(reason="MPT is broken in the latest version of transformers") @@ -104,15 +104,6 @@ def test_mpt_nano_compare(attn_impl): # lev_model = MptLmHeadModel.from_hf_pretrained("mosaicml/mpt-7b") -@parameterize_with_configs("mpt*.yaml") -def test_mpt_configs(config_file): - from levanter.main.train_lm import TrainLmConfig - - config_class = TrainLmConfig - - check_load_config(config_class, config_file) - - def test_pass_different_length_seq(): config = MptConfig( max_seq_len=32, diff --git a/tests/test_sophia.py b/tests/test_sophia.py new file mode 100644 index 000000000..1ca3a7265 --- /dev/null +++ b/tests/test_sophia.py @@ -0,0 +1,66 @@ +import functools +import os + +import equinox as eqx +import equinox.nn as nn +import jax +import jax.numpy as jnp +import numpy as np + +import levanter +import levanter.optim.sophia + + +def test_sophia_h(): + key = jax.random.PRNGKey(0) + model = nn.Linear(4, 4, use_bias=False, key=key) + data = np.load(f"{os.path.dirname(__file__)}/data/hero_data.npy").astype("float32") + optimizer = levanter.optim.sophia.sophia_h( + lr=1, + b1=0, + b2=0.99, + gamma=2, + weight_decay=0.0, + clip_threshold=1, + key=key, + update_interval=1, + ) + model = jax.tree_util.tree_map(lambda x: jnp.ones_like(x), model) + zero_grad = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), model) + + opt_state = optimizer.init(model) + + def loss_fn(model, data): + out = eqx.filter_vmap(model)(data) + return jnp.mean(out**2) * 4 + + jit_update = eqx.filter_jit(optimizer.update) + + obj_fn = functools.partial(loss_fn, data=data) + for i in range(1000): + _, opt_state = jit_update(zero_grad, opt_state, params=model, obj_fn=obj_fn) + + # print('Test-estimated hessian: most coordinates should be approximately 2') + # print('Estimated hessian:', opt_state[0].h.weight) + assert jnp.allclose(opt_state[0].h.weight, 2, rtol=0.2, atol=0.3) # this is very approximate + + grad_loss_fn = eqx.filter_jit(eqx.filter_value_and_grad(loss_fn)) + + loss, grad = grad_loss_fn(model, data) + model_updates, opt_state = optimizer.update(grad, opt_state, params=model, obj_fn=obj_fn) + model = eqx.apply_updates(model, model_updates) + + # loss should be 15.74834156036377 + assert jnp.allclose(loss, 15.74834156036377) + + # print("Test-model param after 1 step: most coordinates should be very loosely 0.5") + assert jnp.allclose(model.weight, 0.5, rtol=0.2, atol=0.1) # this is very approximate + + # print("Test-loss: loss should shrink by approximately 75% after each iteration") + for i in range(10): + loss, grad = grad_loss_fn(model, data) + model_updates, opt_state = optimizer.update(grad, opt_state, params=model, obj_fn=obj_fn) + model = eqx.apply_updates(model, model_updates) + + # print('Step:', i , "Loss:", loss.item()) + assert loss < 15.74834156036377 * 0.75 ** (i + 1) From 474206e7901bfded9a9e532611434a86e0a42b74 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 9 Feb 2024 13:35:21 -0800 Subject: [PATCH 17/19] Trackers let us abstract out TB vs wandb --- README.md | 3 +- config/backpack.yaml | 2 +- config/gpt2_1536.yaml | 2 +- config/gpt2_20b.yaml | 2 +- config/gpt2_7b.yaml | 2 +- config/gpt2_large.yaml | 4 +- config/gpt2_large_sophia_g.yaml | 21 +++ config/gpt2_medium.yaml | 2 +- config/gpt2_micro.yaml | 2 +- config/gpt2_nano.yaml | 3 +- config/gpt2_nano_tb.yaml | 25 +++ config/gpt2_small.yaml | 4 +- config/gpt2_small_fast.yaml | 7 +- config/gpt2_small_fast_mix.yaml | 2 +- config/gpt2_small_fast_pile.yaml | 2 +- config/gpt2_small_fast_wiki.yaml | 2 +- config/gpt2_small_sophiah.yaml | 2 +- config/gpt2_xl.yaml | 2 +- config/llama2_7b.yaml | 3 +- config/llama2_7b_continued.yaml | 3 +- config/llama2_nano.yaml | 2 +- config/lora/mpt_biomed.yaml | 3 +- config/mpt_7b_continued.yaml | 22 --- config/mpt_7b_continued_biomedlm.yaml | 27 --- docs/Configuration-Guide.md | 84 ++++++++- docs/Training-On-Your-Data.md | 3 +- docs/dev/Trackers.md | 104 +++++++++++ examples/alpaca-lora/alpaca_lora.py | 92 +++++----- mkdocs.yml | 2 +- pyproject.toml | 2 +- src/levanter/__init__.py | 1 + src/levanter/callbacks.py | 61 ++++--- src/levanter/logging.py | 237 ++------------------------ src/levanter/main/cache_dataset.py | 13 +- src/levanter/main/eval_lm.py | 2 +- src/levanter/main/train_lm.py | 56 +++--- src/levanter/main/viz_logprobs.py | 6 +- src/levanter/tracker/__init__.py | 29 ++++ src/levanter/tracker/helpers.py | 75 ++++++++ src/levanter/tracker/tensorboard.py | 81 +++++++++ src/levanter/tracker/tracker.py | 117 +++++++++++++ src/levanter/tracker/tracker_fns.py | 235 +++++++++++++++++++++++++ src/levanter/tracker/wandb.py | 199 +++++++++++++++++++++ src/levanter/trainer.py | 120 ++++++++++--- tests/test_eval_lm.py | 2 +- tests/test_export_to_hf.py | 3 +- tests/test_logging.py | 4 +- tests/test_tracker.py | 80 +++++++++ tests/test_train_lm.py | 2 +- tests/test_viz_lm.py | 8 +- 50 files changed, 1326 insertions(+), 441 deletions(-) create mode 100644 config/gpt2_large_sophia_g.yaml create mode 100644 config/gpt2_nano_tb.yaml delete mode 100644 config/mpt_7b_continued.yaml delete mode 100644 config/mpt_7b_continued_biomedlm.yaml create mode 100644 docs/dev/Trackers.md create mode 100644 src/levanter/tracker/__init__.py create mode 100644 src/levanter/tracker/helpers.py create mode 100644 src/levanter/tracker/tensorboard.py create mode 100644 src/levanter/tracker/tracker.py create mode 100644 src/levanter/tracker/tracker_fns.py create mode 100644 src/levanter/tracker/wandb.py create mode 100644 tests/test_tracker.py diff --git a/README.md b/README.md index 5a6b89cf6..13097d7dd 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/backpack.yaml b/config/backpack.yaml index 5b6cef3cb..493be77a3 100644 --- a/config/backpack.yaml +++ b/config/backpack.yaml @@ -10,7 +10,7 @@ model: num_senses: 16 sense_intermediate_scale: 4 trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "backpack" ] diff --git a/config/gpt2_1536.yaml b/config/gpt2_1536.yaml index 50ccbd882..a3633bf65 100644 --- a/config/gpt2_1536.yaml +++ b/config/gpt2_1536.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_20b.yaml b/config/gpt2_20b.yaml index 76bf6ba96..6f5f40e1b 100644 --- a/config/gpt2_20b.yaml +++ b/config/gpt2_20b.yaml @@ -12,7 +12,7 @@ model: use_bias: false fcm_prob: 0.15 trainer: - wandb: + tracker: project: "levanter" tags: ["pile", "gpt2"] diff --git a/config/gpt2_7b.yaml b/config/gpt2_7b.yaml index affb67aa5..36a3d4fd2 100644 --- a/config/gpt2_7b.yaml +++ b/config/gpt2_7b.yaml @@ -11,7 +11,7 @@ model: resid_pdrop: 0.0 fcm_prob: 0.15 trainer: - wandb: + tracker: project: "levanter" tags: ["pile", "gpt2"] diff --git a/config/gpt2_large.yaml b/config/gpt2_large.yaml index 525a92c99..8a8aea8d7 100644 --- a/config/gpt2_large.yaml +++ b/config/gpt2_large.yaml @@ -8,13 +8,13 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] mp: p=f32,c=bfloat16 model_axis_size: 1 - per_device_parallelism: 16 + per_device_parallelism: -1 optimizer: learning_rate: 2E-4 weight_decay: 0.1 diff --git a/config/gpt2_large_sophia_g.yaml b/config/gpt2_large_sophia_g.yaml new file mode 100644 index 000000000..53a1d0806 --- /dev/null +++ b/config/gpt2_large_sophia_g.yaml @@ -0,0 +1,21 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 1280 + num_heads: 20 + num_layers: 36 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + wandb: + project: "levanter" + tags: [ "openwebtext", "gpt2", "sophia-g"] + + num_train_steps: 200000 + mp: p=f32,c=bfloat16 + +optimizer: + type: sophia-g + learning_rate: 2E-4 + weight_decay: 0.15 diff --git a/config/gpt2_medium.yaml b/config/gpt2_medium.yaml index 9ea4408bc..47e21799c 100644 --- a/config/gpt2_medium.yaml +++ b/config/gpt2_medium.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_micro.yaml b/config/gpt2_micro.yaml index 274ecddaa..0a8283e78 100644 --- a/config/gpt2_micro.yaml +++ b/config/gpt2_micro.yaml @@ -6,7 +6,7 @@ model: num_heads: 8 num_layers: 4 trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_nano.yaml b/config/gpt2_nano.yaml index 993302670..1ad0ceb3b 100644 --- a/config/gpt2_nano.yaml +++ b/config/gpt2_nano.yaml @@ -14,8 +14,7 @@ trainer: - every: 50 save_interval: 5m - per_device_eval_parallelism: 1 - per_device_parallelism: 1 + per_device_parallelism: -1 train_batch_size: 32 tensor_parallel_axes: ["mlp", "heads"] diff --git a/config/gpt2_nano_tb.yaml b/config/gpt2_nano_tb.yaml new file mode 100644 index 000000000..f6847d693 --- /dev/null +++ b/config/gpt2_nano_tb.yaml @@ -0,0 +1,25 @@ +data: + id: dlwh/wikitext_103_detokenized +model: + type: gpt2 + hidden_dim: 32 + num_heads: 4 + num_layers: 2 +trainer: + mp: f32 + num_train_steps: 100 + + checkpointer: + keep: + - every: 50 + save_interval: 5m + + per_device_parallelism: -1 + train_batch_size: 32 + + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" + tracker: + type: tensorboard + logdir: tb_logs/ diff --git a/config/gpt2_small.yaml b/config/gpt2_small.yaml index 74d0e031a..b3e0295af 100644 --- a/config/gpt2_small.yaml +++ b/config/gpt2_small.yaml @@ -8,13 +8,13 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] mp: p=f32,c=bfloat16 model_axis_size: 1 - per_device_parallelism: 4 + per_device_parallelism: -1 train_batch_size: 512 optimizer: diff --git a/config/gpt2_small_fast.yaml b/config/gpt2_small_fast.yaml index 4c8434f38..6242a37bc 100644 --- a/config/gpt2_small_fast.yaml +++ b/config/gpt2_small_fast.yaml @@ -8,9 +8,10 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: - project: "levanter" - tags: [ "openwebtext", "gpt2", "itest"] + tracker: + - type: wandb + project: "levanter" + tags: [ "openwebtext", "gpt2", "itest"] mp: p=f32,c=bfloat16 model_axis_size: 1 diff --git a/config/gpt2_small_fast_mix.yaml b/config/gpt2_small_fast_mix.yaml index 0785e9103..ca9fa2ca6 100644 --- a/config/gpt2_small_fast_mix.yaml +++ b/config/gpt2_small_fast_mix.yaml @@ -21,7 +21,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext+wiki", "gpt2", "itest"] diff --git a/config/gpt2_small_fast_pile.yaml b/config/gpt2_small_fast_pile.yaml index f30743c1d..a0336da45 100644 --- a/config/gpt2_small_fast_pile.yaml +++ b/config/gpt2_small_fast_pile.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "pile", "gpt2", "itest"] diff --git a/config/gpt2_small_fast_wiki.yaml b/config/gpt2_small_fast_wiki.yaml index 407d8705b..a25736434 100644 --- a/config/gpt2_small_fast_wiki.yaml +++ b/config/gpt2_small_fast_wiki.yaml @@ -9,7 +9,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2", "itest"] diff --git a/config/gpt2_small_sophiah.yaml b/config/gpt2_small_sophiah.yaml index 1dd5824c3..fd82ab226 100644 --- a/config/gpt2_small_sophiah.yaml +++ b/config/gpt2_small_sophiah.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2", "sophia-h"] diff --git a/config/gpt2_xl.yaml b/config/gpt2_xl.yaml index 8230b56a5..026fc077e 100644 --- a/config/gpt2_xl.yaml +++ b/config/gpt2_xl.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] mp: p=f32,c=bfloat16 diff --git a/config/llama2_7b.yaml b/config/llama2_7b.yaml index 68931f3fa..b4ebe705f 100644 --- a/config/llama2_7b.yaml +++ b/config/llama2_7b.yaml @@ -11,7 +11,8 @@ model: # initialize_from_hf: "meta-llama/Llama-2-7b-hf" # use_hf_model_config: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["openwebtext", "llama"] diff --git a/config/llama2_7b_continued.yaml b/config/llama2_7b_continued.yaml index e03be7168..edb72a7e4 100644 --- a/config/llama2_7b_continued.yaml +++ b/config/llama2_7b_continued.yaml @@ -6,7 +6,8 @@ model: initialize_from_hf: true use_hf_model_config: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["pile", "llama2"] diff --git a/config/llama2_nano.yaml b/config/llama2_nano.yaml index c3ae4cdb8..58415022e 100644 --- a/config/llama2_nano.yaml +++ b/config/llama2_nano.yaml @@ -12,7 +12,7 @@ model: num_kv_heads: 4 num_layers: 2 trainer: - wandb: + tracker: project: "levanter" tags: ["openwebtext", "llama"] mp: p=f32 diff --git a/config/lora/mpt_biomed.yaml b/config/lora/mpt_biomed.yaml index f49267ca1..6b19d0ab5 100644 --- a/config/lora/mpt_biomed.yaml +++ b/config/lora/mpt_biomed.yaml @@ -11,7 +11,8 @@ lora: alpha: 32.0 target_modules: ["Wqkv"] trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["mpt", "lora", "pubmed"] diff --git a/config/mpt_7b_continued.yaml b/config/mpt_7b_continued.yaml deleted file mode 100644 index a7eaf800b..000000000 --- a/config/mpt_7b_continued.yaml +++ /dev/null @@ -1,22 +0,0 @@ -data: !include data/pile_source_old.yaml -model: - type: mpt -initialize_from_hf: true -use_hf_model_config: true -trainer: - wandb: - project: "levanter" - tags: ["pile", "mpt"] - - mp: p=f32,c=bfloat16 - - model_axis_size: 1 - per_device_parallelism: 4 - per_device_eval_parallelism: 4 - - train_batch_size: 1024 - num_train_steps: 10000 - steps_per_eval: 500 -optimizer: - learning_rate: 1.2e-4 - weight_decay: 0.1 diff --git a/config/mpt_7b_continued_biomedlm.yaml b/config/mpt_7b_continued_biomedlm.yaml deleted file mode 100644 index 44961df46..000000000 --- a/config/mpt_7b_continued_biomedlm.yaml +++ /dev/null @@ -1,27 +0,0 @@ -data: - train_urls: - - "gs://pubmed-mosaic/pubmed-sharded/pubmedRandomized_train.{1..128}-of-128.jsonl.gz" - validation_urls: - - "gs://pubmed-mosaic/pubmed-sharded/pubmedRandomized_val.{1..8}-of-8.jsonl.gz" - cache_dir: "gs://pubmed-mosaic/tokenized/pubmed-sharded-neox/" - tokenizer: "EleutherAI/gpt-neox-20b" -model: - type: mpt -initialize_from_hf: "mosaicml/mpt-7b@68e1a8e0ebb9b30f3c45c1ef6195980f29063ae2" -use_hf_model_config: true -trainer: - wandb: - project: "levanter" - tags: ["pubmed", "mpt", "continued"] - - mp: p=f32,c=bfloat16 - - model_axis_size: 1 - per_device_parallelism: 8 - - train_batch_size: 2048 - num_train_steps: 50000 - steps_per_eval: 1000 -optimizer: - learning_rate: 1.2e-5 - weight_decay: 0.1 diff --git a/docs/Configuration-Guide.md b/docs/Configuration-Guide.md index 607129e1a..bdb09e4f1 100644 --- a/docs/Configuration-Guide.md +++ b/docs/Configuration-Guide.md @@ -35,7 +35,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "gpt2"] @@ -179,12 +180,34 @@ The default step-based checkpoint policy is to save a checkpoint every 10,000 st -## WandB +## Trackers and Logging -We mostly use wandb for logging, including using wandb for allocating the run id. We may change this. -These all live in a nested object `wandb` inside `trainer`. Most of these are the same as the corresponding `wandb.init` -parameters. +We mostly use [W&B](https://wandb.ai/site) for tracking values and other metadata about a run. However, we also support +Tensorboard and a few other trackers. You can also use multiple trackers at once, or even write your own. +See [Trackers](dev/Trackers.md) for more information. + +### W&B + +Wandb is the default tracker and is installed by default. To use it, you can configure it in your config file: + +```yaml +trainer: + tracker: + type: wandb + project: my-project + entity: my-entity +``` + +Because wandb is the default, you can also just do: + +```yaml +trainer: + tracker: + project: my-project + entity: my-entity +``` + | Parameter | Description | Default | @@ -206,6 +229,35 @@ of your main script. To use it, you must also set the right environment variables. Something like `XLA_FLAGS="--xla_dump_to=/tmp/output_folder/xla_dumps --xla_dump_hlo_pass_re=.*`. We will automatically parse out the env variable. +### Tensorboard + +Tensorboard is also supported. To use it, you can configure it in your config file: + +```yaml +trainer: + tracker: + type: tensorboard + logdir: logs +``` + +### Multiple Trackers + +In some cases, you may want to use multiple trackers at once. +For example, you may want to use both W&B and Tensorboard. + +To do this, you can use the [levanter.tracker.tracker.CompositeTracker][] class, or, if using a config file, you +can specify multiple trackers: + +```yaml +trainer: + tracker: + - type: wandb + project: my-project + entity: my-entity + - type: tensorboard + logdir: logs +``` + ## Ray Config Levanter will by default automatically start a Ray cluster with all @@ -277,8 +329,26 @@ We won't go into detail here. You can see the auto-generated docs below. ::: levanter.checkpoint.Checkpointer -### Wandb -::: levanter.logging.WandbConfig +### Trackers and Metrics + +See also [Trackers](dev/Trackers.md) for more information. Basic configuration is shown below. + +#### Single Tracker + +```yaml +trainer: + tracker: + type: wandb + project: my-project + entity: my-entity +``` + + + +::: levanter.tracker.wandb.WandbConfig + +::: levanter.tracker.tensorboard.TensorboardConfig + ### Distributed and Ray diff --git a/docs/Training-On-Your-Data.md b/docs/Training-On-Your-Data.md index edf33e0af..4c543b04f 100644 --- a/docs/Training-On-Your-Data.md +++ b/docs/Training-On-Your-Data.md @@ -214,7 +214,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" # TODO tags: ["gpt2"] diff --git a/docs/dev/Trackers.md b/docs/dev/Trackers.md new file mode 100644 index 000000000..1f1677d52 --- /dev/null +++ b/docs/dev/Trackers.md @@ -0,0 +1,104 @@ +# Trackers and Metrics + +Logging values and other metadata about a run is a core requirement for any ML framework. +Until recently, Levanter had a hard dependency on [W&B](https://wandb.ai/site) for tracking such values. + +In the latest version, we introduce the [levanter.tracker.Tracker][] interface, which allows you to use any tracking backend you want. +The interface name is taken from the [HuggingFace Accelerate](https://github.com/huggingface/accelerate/blob/0f2686c8d3e6d949c4b7efa15d7f2dee44f7ce91/src/accelerate/tracking.py#L395) +framework. + +Given Levanter's historical dependency on W&B, the interface is designed to look similar to W&B's API. +The methods currently exposed are: + +* [levanter.tracker.current_tracker][]: returns the current tracker instance or sets it. +* [levanter.tracker.log_metrics][]: logs a dictionary of metrics for a given step. +* [levanter.tracker.log_summary][]: logs a dictionary of "summary" information, analogous to W&B's version. +* [levanter.tracker.get_tracker][]: returns a tracker with the given name. +* [levanter.tracker.jit_log_metrics][]: a version of [levanter.tracker.log_metrics][] that works inside JAX jit. + +A basic example of using the tracker interface is shown below: + +```python +import wandb +from levanter.tracker import current_tracker, log_metrics, log_summary +from levanter.tracker.wandb import WandbTracker + +with current_tracker(WandbTracker(wandb.init())): + for step in range(100): + log_metrics({"loss": 100 -0.01 * step}, step=step) + + log_summary({"best_loss": 0.0}) +``` + +A more typical example would be to use it in a config file, as we do with Trainer: + +```yaml +trainer: + tracker: + type: wandb + project: my-project + entity: my-entity +``` + +### Multiple Trackers + +In some cases, you may want to use multiple trackers at once. +For example, you may want to use both W&B and Tensorboard. + +To do this, you can use the [levanter.tracker.tracker.CompositeTracker][] class, or, if using a config file, you +can specify multiple trackers: + +```yaml +trainer: + tracker: + - type: wandb + project: my-project + entity: my-entity + - type: tensorboard + logdir: logs +``` + +## Adding your own tracker + +To add your own tracker, you need to implement the [levanter.tracker.Tracker][] interface. +You will also want to register your config with TrackerConfig as a "choice" in the choice type. +Follow the pattern for Tensorboard and W&B. + +TODO: expand this section. + + +## API Reference + +### Core Functions + +::: levanter.tracker.current_tracker + +::: levanter.tracker.log_metrics + +::: levanter.tracker.log_summary + +::: levanter.tracker.get_tracker + +::: levanter.tracker.jit_log_metrics + +### Trackers + +::: levanter.tracker.Tracker + +::: levanter.tracker.tracker.CompositeTracker + +::: levanter.tracker.tracker.NoopTracker + +::: levanter.tracker.tensorboard.TensorboardTracker + +::: levanter.tracker.wandb.WandbTracker + +### Tracker Config + +::: levanter.tracker.TrackerConfig + +::: levanter.tracker.tracker.NoopConfig + +::: levanter.tracker.tensorboard.TensorboardConfig + +::: levanter.tracker.wandb.WandbConfig diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index a4380a92b..0e7c5790e 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -8,7 +8,6 @@ import jax.random as jrandom import transformers -import wandb import haliax as hax @@ -49,7 +48,7 @@ class TrainArgs(alpaca.TrainArgs): def train(config: TrainArgs): - config.trainer.initialize(config) + levanter.initialize(config) # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint. # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary. @@ -101,53 +100,58 @@ def loraize_hf_model(model): def compute_loss(model: LmHeadModel, example: LmExample, key=None): return model.compute_loss(example, key=key).scalar() - trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) - # end major difference from Alpaca - trainer.add_default_hooks() - state = trainer.initial_state(training_key, model=model) - - # log some info about the model - all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - - wandb.summary["parameter_count"] = all_param_count - wandb.summary["trainable_parameter_count"] = just_lora_params - logger.info(f"Total parameter count: {all_param_count}") - logger.info(f"Trainable parameter count: {just_lora_params}") - logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") - - # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for - # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large - # datasets. We use replicated here since the dataset is small. - loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) - loader = non_caching_cycle(loader) - - if state.step != 0: - logger.info(f"Resuming training from step {state.step}") - for i in range(state.step): - next(loader) # type: ignore - - # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights - if config.hf_save_path is not None: - full_save_path = os.path.join(config.hf_save_path, trainer.run_id) - trainer.add_hook( - save_peft_checkpoint_callback( - full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload - ), - every=config.hf_save_steps, - ) + with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer: + trainer.add_default_hooks() + state = trainer.initial_state(training_key, model=model) + + # log some info about the model + all_param_count = parameter_count(state.model) + just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - # Save merged HF checkpoints if requested - if config.merged_hf_save_path is not None: - full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) - trainer.add_hook( - save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), - every=config.hf_save_steps, + levanter.tracker.log_summary( + { + "parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } ) - trainer.train(state, loader) + logger.info(f"Total parameter count: {all_param_count}") + logger.info(f"Trainable parameter count: {just_lora_params}") + logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") + + # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for + # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large + # datasets. We use replicated here since the dataset is small. + loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) + loader = non_caching_cycle(loader) + + if state.step != 0: + logger.info(f"Resuming training from step {state.step}") + for i in range(state.step): + next(loader) # type: ignore + + # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights + if config.hf_save_path is not None: + full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + trainer.add_hook( + save_peft_checkpoint_callback( + full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload + ), + every=config.hf_save_steps, + ) + + # Save merged HF checkpoints if requested + if config.merged_hf_save_path is not None: + full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) + trainer.add_hook( + save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), + every=config.hf_save_steps, + ) + + trainer.train(state, loader) if __name__ == "__main__": diff --git a/mkdocs.yml b/mkdocs.yml index 568716ac4..28fdb9849 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -98,7 +98,7 @@ nav: - "Hardware-Agnostic-Training.md" - 'Developer Guide': - 'dev/Port-Models.md' -# - 'dev/Trackers.md' + - 'dev/Trackers.md' - 'FAQ' : 'faq.md' - Other: - 'Levanter-1.0-Release.md' diff --git a/pyproject.toml b/pyproject.toml index 14f010c1b..a717d9d97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "transformers>=4.22.0", "optax", "wandb", - "draccus>=0.6", + "draccus>=0.7.1", "pyarrow>=11.0.0", "zstandard>=0.20.0", "datasets==2.16.1", diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py index 30c32a712..a7def0acb 100644 --- a/src/levanter/__init__.py +++ b/src/levanter/__init__.py @@ -7,3 +7,4 @@ import levanter.optim as optim import levanter.trainer as trainer import levanter.visualization as visualization +from levanter.trainer import initialize diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 2292c714a..b0244e0e3 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -1,5 +1,5 @@ import copy -import logging +import logging as pylogging import os import re import subprocess @@ -11,20 +11,24 @@ import humanfriendly import jax -import wandb from tqdm import tqdm -from levanter.logging import WandbConfig, log_optimizer_hyperparams, save_xla_dumps_to_wandb +import levanter.tracker +from levanter.logging import save_xla_dumps_to_wandb +from levanter.tracker.helpers import log_optimizer_hyperparams +from levanter.tracker.wandb import WandbConfig from levanter.trainer import StepInfo from levanter.utils.jax_utils import jnp_to_python from levanter.visualization import compute_and_visualize_log_probs as viz_probs -logger = logging.getLogger(__name__) +logger = pylogging.getLogger(__name__) def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None): total_loss = 0.0 + total_load_time = 0.0 + total_loss_time = 0.0 n = 0 if name is not None: @@ -33,10 +37,20 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n desc = "eval" pbar = tqdm(dataset, desc=desc, position=1, leave=False, total=max_batches) - for batch in pbar: + iter_ = iter(pbar) + while True: + time_in = time.time() + batch = next(iter_, None) + if batch is None: + break + load_time = time.time() - time_in + total_load_time += load_time loss = loss_fn(model, batch) total_loss += loss.item() n += 1 + loss_time = time.time() - time_in - load_time + total_loss_time += loss_time + pbar.set_postfix(loss=total_loss / n) if max_batches is not None and n >= max_batches: @@ -45,6 +59,9 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n if n > 0: total_loss /= n + # logger.info(f"eval loading time: {total_load_time / n:.3f} s/ba") + # logger.info(f"eval loss time: {total_loss_time / n:.3f} s/ba") + return total_loss @@ -57,11 +74,10 @@ def compute_validation_loss( def compute_loss(info: StepInfo): loss = eval_loss_loop(loss_fn, info.model, dataset, max_batches=max_batches, name=name) - if wandb.run is not None: - prefix = "eval" - if name: - prefix += "/" + name - wandb.log({f"{prefix}/loss": loss}, step=info.step) + prefix = "eval" + if name: + prefix += "/" + name + levanter.tracker.log_metrics({f"{prefix}/loss": loss}, step=info.step) if name: logger.info(f"{name} validation loss: {loss:.3f}") @@ -73,12 +89,14 @@ def compute_loss(info: StepInfo): return compute_loss -def log_to_wandb(step: StepInfo): - wandb.log({"train/loss": step.loss, "global_step": step.step}, step=step.step) +def log_step_info(step: StepInfo): + levanter.tracker.log_metrics({"train/loss": step.loss, "global_step": step.step}, step=step.step) log_optimizer_hyperparams(step.opt_state, step=step.step, prefix="optim") def wandb_xla_logger(config: WandbConfig): + import wandb + last_mtime = wandb.run and wandb.run.start_time or time.time() def log_xla_to_wandb(step: StepInfo): @@ -108,14 +126,14 @@ def log_performance_stats(step_info: StepInfo): # log these totals because it's useful for comparing different seqlens, batch sizes, etc total_tokens = tokens_per_example * batch_size * step_info.step - wandb.log({wrap_key("total_tokens"): total_tokens}, step=step_info.step) + levanter.tracker.log_metrics({wrap_key("total_tokens"): total_tokens}, step=step_info.step) if flops_per_example: total_flops = flops_per_example * batch_size * step_info.step - wandb.log({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) + levanter.tracker.log_metrics({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) if step_info.step_duration != 0.0: - wandb.log( + levanter.tracker.log_metrics( { wrap_key("examples_per_second"): float(batch_size) / step_info.step_duration, wrap_key("tokens_per_second"): float(tokens_per_example) / step_info.step_duration * batch_size, @@ -125,7 +143,7 @@ def log_performance_stats(step_info: StepInfo): ) if flops_per_example is not None: - wandb.log( + levanter.tracker.log_metrics( { wrap_key("gflops_per_second"): flops_per_example / 1e9 / step_info.step_duration * batch_size, }, @@ -152,7 +170,7 @@ def update_pbar(step: StepInfo): def log_memory_usage(sample_interval: float = 1.0, log_individual_devices: bool = False): """ - Logs memory usage to wandb. This runs a loop that samples memory usage every `sample_interval` seconds. + Logs memory usage. This runs a loop that samples memory usage every `sample_interval` seconds. We only log when hooks are invoked, so there's not much point in running this much more frequently than you invoke the hook. @@ -218,7 +236,7 @@ def log_memory_usage(step: StepInfo): match = regex.search(by_kind) if match: memory_usage = humanfriendly.parse_size(match.group(1)) - wandb.log({"memory/total": memory_usage / 1e6}, step=step.step) + levanter.tracker.log_metrics({"memory/total": memory_usage / 1e6}, step=step.step) # this works for the "kind" and the individual devices regex = re.compile(r"([\d.]+[a-zA-Z]+) \(([\d.]+)%\): ([\w\d:_]+)") @@ -229,14 +247,14 @@ def log_memory_usage(step: StepInfo): for match in regex.finditer(per_device): memory_usage = humanfriendly.parse_size(match.group(1)) device_name = match.group(3) - wandb.log({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) + levanter.tracker.log_metrics({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) # now, get the memory usage per kind. # same regex as above for match in regex.finditer(by_kind): memory_usage = match.group(1) memory_usage = humanfriendly.parse_size(memory_usage) - wandb.log({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step) + levanter.tracker.log_metrics({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step) return log_memory_usage @@ -262,6 +280,9 @@ def compute_and_viz_log_probs(step: StepInfo): path = os.path.join(html_dir, f"step_{step}.html") viz_probs(path, model, tokenizer, log_prob_fn, test_data, max_docs=max_docs) + # TODO: convert to generic logging + import wandb + wandb.log({"log_probs": wandb.Html(path)}, step=step.step) return compute_and_viz_log_probs diff --git a/src/levanter/logging.py b/src/levanter/logging.py index 4fbb4a618..78588669f 100644 --- a/src/levanter/logging.py +++ b/src/levanter/logging.py @@ -1,57 +1,32 @@ import contextlib -import dataclasses -import logging import logging as pylogging import os -import tempfile import time -import warnings -from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Union +from typing import List, Union -import draccus import jax -import wandb -from draccus import field -from git import InvalidGitRepositoryError, NoSuchPathError, Repo -from optax import MultiStepsState -from levanter.utils import jax_utils -from levanter.utils.jax_utils import jnp_to_python +pylogger = pylogging.getLogger(__name__) -logger = pylogging.getLogger(__name__) - -def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None): - if isinstance(opt_state, MultiStepsState): - opt_state = opt_state.inner_opt_state - - def wrap_key(key): - if prefix: - return f"{prefix}/{key}" - return key - - if hasattr(opt_state, "hyperparams"): - params = {wrap_key(k): jnp_to_python(v) for k, v in opt_state.hyperparams.items()} - wandb.log(params, step=step) - - -def init_logger(path: Union[str, Path], level: int = pylogging.INFO) -> None: +def init_logging(log_dir: Union[str, Path], run_id: str, level: int = pylogging.INFO) -> None: """ Initialize logging.Logger with the appropriate name, console, and file handlers. :param path: Path for writing log file :param level: Default logging level """ + log_dir = Path(log_dir) + log_dir.mkdir(parents=True, exist_ok=True) + path = log_dir / f"{run_id}.log" + process_index = jax.process_index() log_format = f"%(asctime)s - {process_index} - %(name)s - %(filename)s:%(lineno)d - %(levelname)s :: %(message)s" # use ISO 8601 format for timestamps, except no TZ, because who cares date_format = "%Y-%m-%dT%H:%M:%S" - os.makedirs(os.path.dirname(path), exist_ok=True) - handlers: List[pylogging.Handler] = [pylogging.FileHandler(path, mode="a"), pylogging.StreamHandler()] # Create Root Logger w/ Base Formatting @@ -64,13 +39,21 @@ def init_logger(path: Union[str, Path], level: int = pylogging.INFO) -> None: def save_xla_dumps_to_wandb(initial_time: float): import os + from levanter.tracker.wandb import is_wandb_available + + if not is_wandb_available(): + pylogger.warning("Wandb is not available, so we can't save XLA dumps") + return + + import wandb + # attempt to parse xla_flags to see if we're dumping assembly files flags = os.getenv("XLA_FLAGS", None) if flags is not None and "xla_dump_to" in flags: # parse the path # this isn't robust to quotes path = flags.split("xla_dump_to=")[1].split(" ")[0] - logger.info(f"Found xla_dump_to={path}, logging to wandb") + pylogger.info(f"Found xla_dump_to={path}, logging to wandb") if wandb.run: # only want to save the files that were generated during this run # XLA_FLAGS has to be set before the first jax call, so we can't just set it in the middle of the run @@ -82,7 +65,7 @@ def include_file(path: str): wandb.run.log_code(root=path, name="xla_dumps", include_fn=include_file) else: - logger.warning("XLA_FLAGS is not set to dump to a path, so we can't save the dumps to wandb") + pylogger.warning("XLA_FLAGS is not set to dump to a path, so we can't save the dumps to wandb") @contextlib.contextmanager @@ -100,23 +83,6 @@ def fn(): end = time.time() -@contextlib.contextmanager -def log_time_to_wandb(name: str, *, step=None): - with capture_time() as fn: - yield fn - wandb.log({name: fn()}, step=step) - - -def jittable_wandb_log(data, *, step=None): - """uses jax effect callback to log to wandb from the host""" - if is_wandb_available(): - jax.debug.callback(wandb.log, data, step=step) - - -def is_wandb_available(): - return wandb is not None and wandb.run is not None - - def silence_transformer_nag(): # this is a hack to silence the transformers' "None of PyTorch, TensorFlow 2.0 or Flax have been found..." thing # which is annoying and not useful @@ -125,172 +91,3 @@ def silence_transformer_nag(): os.environ["TRANSFORMERS_VERBOSITY"] = "error" import transformers # noqa: F401 - - -@dataclass -class WandbConfig: - """ - Configuration for wandb. - """ - - entity: Optional[str] = None # An entity is a username or team name where you send runs - project: Optional[str] = None # The name of the project where you are sending the enw run. - name: Optional[str] = None # A short display name for this run, which is how you'll identify this run in the UI. - tags: List[str] = field(default_factory=list) # Will populate the list of tags on this run in the UI. - id: Optional[str] = None # A unique ID for this run, used for resuming. It must be unique in the project - group: Optional[str] = None # Specify a group to organize individual runs into a larger experiment. - mode: Optional[str] = None # Can be "online", "offline" or "disabled". If None, it will be online. - resume: Optional[Union[bool, str]] = None # - """ - Set the resume behavior. Options: "allow", "must", "never", "auto" or None. - By default, if the new run has the same ID as a previous run, this run overwrites that data. - Please refer to [init](https://docs.wandb.ai/ref/python/init) and [resume](https://docs.wandb.ai/guides/runs/resuming) - document for more details. - """ - - save_code: Union[bool, str] = True - """If string, will save code from that directory. If True, will attempt to sniff out the main directory (since we - typically don't run from the root of the repo).""" - - save_xla_dumps: bool = False - """If True, will save the XLA code to wandb (as configured by XLA_FLAGS). This is useful for debugging.""" - - def init(self, run_id: Optional[str], hparams=None, **extra_hparams): - import wandb - - if run_id is not None and self.id is not None and run_id != self.id: - warnings.warn( - f"Both trainer's id {run_id} and WandB's id {self.id} are set. WandB will use the id set in its" - " config." - ) - - id = self.id - if id is None: - id = run_id - - if hparams is None: - hparams_to_save = {} - elif dataclasses.is_dataclass(hparams): - hparams_to_save = dataclasses.asdict(hparams) - else: - hparams_to_save = dict(hparams) - - if extra_hparams: - hparams_to_save.update(extra_hparams) - - # for distributed runs, we only want the primary worker to use wandb, so we make everyone else be disabled - # however, we do share information about the run id, so that we can link to it from the other workers - mode = self.mode - if jax.process_index() != 0: - mode = "disabled" - - if isinstance(self.save_code, str): - code_dir = self.save_code - elif self.save_code: - code_dir = WandbConfig._infer_experiment_git_root() or "." # type: ignore - else: - code_dir = None - - other_settings = dict() - if code_dir is not None: - logger.info(f"Setting wandb code_dir to {code_dir}") - other_settings["code_dir"] = code_dir - other_settings["git_root"] = code_dir - # for some reason, wandb isn't populating the git commit, so we do it here - try: - repo = Repo(code_dir) - other_settings["git_commit"] = repo.head.commit.hexsha - hparams_to_save["git_commit"] = repo.head.commit.hexsha - except (NoSuchPathError, InvalidGitRepositoryError): - logger.warning(f"Could not find git repo at {code_dir}") - pass - - r = wandb.init( - entity=self.entity, - project=self.project, - name=self.name, - tags=self.tags, - id=id, - group=self.group, - resume=self.resume, - mode=mode, - config=hparams_to_save, - settings=other_settings, - allow_val_change=True, - ) - - assert r is not None - - if jax.process_count() > 1: - # we need to share wandb run information across all hosts, because we use it for checkpoint paths and things - metadata_to_share = dict( - entity=r.entity, - project=r.project, - name=r.name, - tags=r.tags, - id=r.id, - group=r.group, - ) - metadata_to_share = jax_utils.multihost_broadcast_sync( - metadata_to_share, is_source=jax.process_index() == 0 - ) - - if jax.process_index() != 0: - assert r.mode == "disabled" - for k, v in metadata_to_share.items(): - setattr(r, k, v) - - logger.info(f"Synced wandb run information from process 0: {r.name} {r.id}") - - if dataclasses.is_dataclass(hparams): - with tempfile.TemporaryDirectory() as tmpdir: - config_path = os.path.join(tmpdir, "config.yaml") - with open(config_path, "w") as f: - draccus.dump(hparams, f, encoding="utf-8") - if wandb.run is not None: - wandb.run.log_artifact(str(config_path), name="config.yaml", type="config") - - # generate a pip freeze - with tempfile.TemporaryDirectory() as tmpdir: - requirements_path = os.path.join(tmpdir, "requirements.txt") - requirements = _generate_pip_freeze() - with open(requirements_path, "w") as f: - f.write(requirements) - if wandb.run is not None: - wandb.run.log_artifact(str(requirements_path), name="requirements.txt", type="requirements") - - wandb.summary["num_devices"] = jax.device_count() - wandb.summary["num_hosts"] = jax.process_count() - wandb.summary["backend"] = jax.default_backend() - - @staticmethod - def _infer_experiment_git_root() -> Optional[str | os.PathLike[str]]: - # sniff out the main directory (since we typically don't run from the root of the repo) - # we'll walk the stack and directories for the files in the stack the until we're at a git root - import os - import traceback - - stack = traceback.extract_stack() - # start from the top of the stack and work our way down since we want to hit the main file first - top_git_root = None - for frame in stack: - dirname = os.path.dirname(frame.filename) - # bit hacky but we want to skip anything that's in the python env - if any(x in dirname for x in ["site-packages", "dist-packages", "venv", "opt/homebrew", "conda", "pyenv"]): - continue - # see if it's under a git root - try: - repo = Repo(dirname, search_parent_directories=True) - top_git_root = repo.working_dir - break - except (NoSuchPathError, InvalidGitRepositoryError): - logger.debug(f"Skipping {dirname} since it's not a git root") - pass - return top_git_root - - -def _generate_pip_freeze(): - from importlib.metadata import distributions - - dists = distributions() - return "\n".join(f"{dist.name}=={dist.version}" for dist in dists) diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py index 0b0636f4b..9ee6614ca 100644 --- a/src/levanter/main/cache_dataset.py +++ b/src/levanter/main/cache_dataset.py @@ -1,14 +1,13 @@ import logging import os -from dataclasses import dataclass - -import wandb +from dataclasses import dataclass, field import levanter from levanter.data.shard_cache import LoggingMetricsMonitor, RichMetricsMonitor, build_cache from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.distributed import RayConfig -from levanter.logging import init_logger +from levanter.logging import init_logging +from levanter.tracker import NoopConfig, TrackerConfig logger = logging.getLogger(__name__) @@ -16,19 +15,17 @@ @dataclass class RayCachedLMDatasetConfig(LMDatasetConfig, RayConfig): - pass + tracker: TrackerConfig = field(default_factory=NoopConfig) @levanter.config.main() def main(args: RayCachedLMDatasetConfig): """Caches two different kinds of datasets. It can cache a dataset from a list of urls, or a dataset from a hf dataset""" - init_logger("cache_dataset.log") + init_logging(".", "cache_dataset.log") args.initialize() tokenizer = args.the_tokenizer - wandb.init(mode="offline") - for split in ["train", "validation"]: print(f"Caching {split} to {args.cache_dir}.") # connect or start the actor diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index 6262eb428..ab6d9d6b9 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -41,7 +41,7 @@ class EvalLmConfig: def main(config: EvalLmConfig): - config.trainer.initialize(config) + levanter.initialize(config) tokenizer = config.data.the_tokenizer Batch = Axis("batch", config.trainer.eval_batch_size) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index f5b6e83b4..42c415b75 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -5,7 +5,6 @@ from typing import Optional, Union import jax.random as jrandom -import wandb import haliax as hax from haliax import Axis @@ -76,39 +75,40 @@ def main(config: TrainLmConfig): else: converter = None - # initialize training config *after* we've done the hf stuff b/c we might have changed the model config - config.trainer.initialize(config) + levanter.initialize(config) - # randomness in jax is tightly controlled by "keys" which are the states of the random number generators - # this makes deterministic training pretty easy - seed = config.trainer.seed - data_key, loader_key, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4) - - # some axes we need - Batch = config.trainer.TrainBatch - EvalBatch = config.trainer.EvalBatch - Pos = config.model.Pos - KeyPos = config.model.KeyPos - - # We have two axis_mappings: one for storing the model and optimizer states, and one for compute - # This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh - compute_axis_mapping = config.trainer.compute_axis_mapping - parameter_axis_mapping = config.trainer.parameter_axis_mapping + optimizer = config.optimizer.build(config.trainer.num_train_steps) def compute_loss(model: LmHeadModel, example: LmExample, key=None): return model.compute_loss(example, key=key).scalar() - optimizer = config.optimizer.build(config.trainer.num_train_steps) - # Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp - trainer = Trainer(config.trainer, optimizer, compute_loss) - - eval_datasets = config.data.validation_sets(Pos.size) - train_dataset = CausalLmDataset( - config.data.train_set(Pos.size), Pos, KeyPos, ignore_index=config.data.ignore_token_id - ) + # Using the trainer as a context manager does 3 things: + # 1. Sets the device mesh + # 2. Sets the axis mapping (for fsdp) + # 3. Sets the global metrics tracker + with Trainer(config.trainer, optimizer, compute_loss) as trainer: + # randomness in jax is tightly controlled by "keys" which are the states of the random number generators + # this makes deterministic training pretty easy + seed = config.trainer.seed + data_key, loader_key, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4) + + # We have two axis_mappings: one for storing the model and optimizer states, and one for compute + # This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh + compute_axis_mapping = trainer.compute_axis_mapping + parameter_axis_mapping = trainer.parameter_axis_mapping + + # some axes we need + Batch = config.trainer.TrainBatch + EvalBatch = config.trainer.EvalBatch + Pos = config.model.Pos + KeyPos = config.model.KeyPos + + eval_datasets = config.data.validation_sets(Pos.size) + train_dataset = CausalLmDataset( + config.data.train_set(Pos.size), Pos, KeyPos, ignore_index=config.data.ignore_token_id + ) - with trainer.device_mesh: # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of # tokens: gpt-2 has 50257, for example. So we round up. @@ -135,7 +135,7 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): else: logger.info("No checkpoint found. Starting from scratch.") - wandb.summary["parameter_count"] = parameter_count(state.model) + levanter.tracker.log_summary({"parameter_count": parameter_count(state.model)}) # boilerplate hooks and such trainer.add_default_hooks() diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index 370b20d59..b992cd3f5 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -36,12 +36,11 @@ class VizGpt2Config: def main(config: VizGpt2Config): - config.trainer.initialize(config) + levanter.initialize(config) tokenizer = config.data.the_tokenizer - EvalBatch = Axis("batch", config.trainer.eval_batch_size) - # some axes we use outside the model proper + EvalBatch = config.trainer.EvalBatch Pos = config.model.Pos KeyPos = config.model.KeyPos @@ -53,7 +52,6 @@ def main(config: VizGpt2Config): # some axes we use outside the model proper Pos = config.model.Pos - KeyPos = config.model.KeyPos compute_axis_mapping = config.trainer.compute_axis_mapping parameter_axis_mapping = config.trainer.parameter_axis_mapping diff --git a/src/levanter/tracker/__init__.py b/src/levanter/tracker/__init__.py new file mode 100644 index 000000000..69156c6a6 --- /dev/null +++ b/src/levanter/tracker/__init__.py @@ -0,0 +1,29 @@ +from levanter.tracker.helpers import log_optimizer_hyperparams +from levanter.tracker.tracker import CompositeTracker, NoopConfig, NoopTracker, Tracker, TrackerConfig +from levanter.tracker.tracker_fns import ( + current_tracker, + get_tracker, + jit_log_metrics, + log_configuration, + log_hyperparameters, + log_metrics, + log_summary, + set_global_tracker, +) + + +__all__ = [ + "Tracker", + "TrackerConfig", + "CompositeTracker", + "log_optimizer_hyperparams", + "NoopTracker", + "current_tracker", + "get_tracker", + "jit_log_metrics", + "log_configuration", + "log_metrics", + "log_summary", + "log_hyperparameters", + "set_global_tracker", +] diff --git a/src/levanter/tracker/helpers.py b/src/levanter/tracker/helpers.py new file mode 100644 index 000000000..1091840c5 --- /dev/null +++ b/src/levanter/tracker/helpers.py @@ -0,0 +1,75 @@ +import dataclasses +import logging +import os +from typing import Optional + +from git import InvalidGitRepositoryError, NoSuchPathError, Repo + +import levanter.tracker +from levanter.utils.jax_utils import jnp_to_python + + +logger = logging.getLogger(__name__) + + +def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None): + try: + from optax._src.wrappers import MultiStepsState + + if isinstance(opt_state, MultiStepsState): + opt_state = opt_state.inner_opt_state + except ImportError: + pass + + def wrap_key(key): + if prefix: + return f"{prefix}/{key}" + return key + + if hasattr(opt_state, "hyperparams"): + params = {wrap_key(k): jnp_to_python(v) for k, v in opt_state.hyperparams.items()} + levanter.tracker.log_metrics(params, step=step) + + +def hparams_to_dict(hparams, **extra_hparams): + if hparams is None: + hparams_to_save = {} + elif dataclasses.is_dataclass(hparams): + hparams_to_save = dataclasses.asdict(hparams) + else: + hparams_to_save = dict(hparams) + if extra_hparams: + hparams_to_save.update(extra_hparams) + return hparams_to_save + + +def infer_experiment_git_root() -> Optional[str | os.PathLike[str]]: + # sniff out the main directory (since we typically don't run from the root of the repo) + # we'll walk the stack and directories for the files in the stack the until we're at a git root + import os + import traceback + + stack = traceback.extract_stack() + # start from the top of the stack and work our way down since we want to hit the main file first + top_git_root = None + for frame in stack: + dirname = os.path.dirname(frame.filename) + # bit hacky but we want to skip anything that's in the python env + if any(x in dirname for x in ["site-packages", "dist-packages", "venv", "opt/homebrew", "conda", "pyenv"]): + continue + # see if it's under a git root + try: + repo = Repo(dirname, search_parent_directories=True) + top_git_root = repo.working_dir + break + except (NoSuchPathError, InvalidGitRepositoryError): + logger.debug(f"Skipping {dirname} since it's not a git root") + pass + return top_git_root + + +def generate_pip_freeze(): + from importlib.metadata import distributions + + dists = distributions() + return "\n".join(f"{dist.name}=={dist.version}" for dist in dists) diff --git a/src/levanter/tracker/tensorboard.py b/src/levanter/tracker/tensorboard.py new file mode 100644 index 000000000..bd3ee70ba --- /dev/null +++ b/src/levanter/tracker/tensorboard.py @@ -0,0 +1,81 @@ +import logging +import os +import typing +from dataclasses import dataclass +from typing import Any, Optional + +from levanter.tracker import Tracker, TrackerConfig + + +pylogger = logging.getLogger(__name__) + +if typing.TYPE_CHECKING: + from tensorboardX import SummaryWriter # noqa: F401 + + +class TensorboardTracker(Tracker): + name: str = "tensorboard" + + def __init__(self, writer: "SummaryWriter"): + self.writer = writer + + def log_hyperparameters(self, hparams: dict[str, Any]): + self.writer.add_hparams(hparams, {"dummy": 0}) + + def log(self, metrics: dict[str, Any], *, step, commit=None): + del commit + for k, v in metrics.items(): + self.writer.add_scalar(k, v, step) + + def log_summary(self, metrics: dict[str, Any]): + for k, v in metrics.items(): + self.writer.add_scalar(k, v, global_step=None) + + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): + pylogger.error("TensorboardLogger does not support logging artifacts yet") + pass + + +@TrackerConfig.register_subclass("tensorboard") +@dataclass +class TensorboardConfig(TrackerConfig): + logdir: str = "tblogs" + comment: Optional[str] = "" + purge_step: Optional[int] = None + max_queue: Optional[int] = 10 + flush_secs: Optional[int] = 120 + filename_suffix: Optional[str] = "" + write_to_disk: Optional[bool] = True + + def init(self, run_id: Optional[str]) -> TensorboardTracker: + dir_to_write = self.logdir + if run_id is not None: + dir_to_write = os.path.join(dir_to_write, run_id) + + pylogger.info(f"Writing Tensorboard logs to {dir_to_write}") + + from tensorboardX import SummaryWriter # noqa: F811 + + writer = SummaryWriter( + dir_to_write, + comment=self.comment, + purge_step=self.purge_step, + max_queue=self.max_queue, + flush_secs=self.flush_secs, + filename_suffix=self.filename_suffix, + write_to_disk=self.write_to_disk, + ) + + return TensorboardTracker(writer) + + +def _flatten_nested_dict(d): + def items(): + for key, value in d.items(): + if isinstance(value, dict): + for subkey, subvalue in _flatten_nested_dict(value).items(): + yield key + "/" + subkey, subvalue + else: + yield key, value + + return dict(items()) diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py new file mode 100644 index 000000000..8b6816f17 --- /dev/null +++ b/src/levanter/tracker/tracker.py @@ -0,0 +1,117 @@ +import abc +import dataclasses +import typing +from typing import Any, List, Optional + +import draccus + + +class Tracker(abc.ABC): + """ + A tracker is responsible for logging metrics, hyperparameters, and artifacts. + Meant to be used with the [levanter.tracker.current_tracker][] context manager, but can also be used directly. + + The name is borrowed from HF Accelerate. + + Examples: + >>> from levanter.tracker import current_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> with current_tracker(WandbTracker()): + ... log_metrics({"foo": 1}, step=0) + """ + + name: str + + @abc.abstractmethod + def log_hyperparameters(self, hparams: dict[str, Any]): + pass + + @abc.abstractmethod + def log(self, metrics: dict[str, typing.Any], *, step: Optional[int], commit: Optional[bool] = None): + """ + Log metrics to the tracker. Step is always required. + + Args: + metrics: Metrics to log + step: Step to log at + commit: Whether to commit the metrics. If None, uses the default for the tracker. + """ + pass + + @abc.abstractmethod + def log_summary(self, metrics: dict[str, Any]): + pass + + @abc.abstractmethod + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): + pass + + def __enter__(self): + import levanter.tracker.tracker_fns as tracker_fns + + if hasattr(self, "_tracker_cm"): + raise RuntimeError("This tracker is already set as the global tracker") + setattr(self, "_tracker_cm", tracker_fns.current_tracker(self)) + self._tracker_cm.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + if not hasattr(self, "_tracker_cm"): + raise RuntimeError("This tracker is not set as the global tracker") + self._tracker_cm.__exit__(exc_type, exc_val, exc_tb) + delattr(self, "_tracker_cm") + + +class CompositeTracker(Tracker): + def __init__(self, loggers: List[Tracker]): + self.loggers = loggers + + def log_hyperparameters(self, hparams: dict[str, Any]): + for tracker in self.loggers: + tracker.log_hyperparameters(hparams) + + def log(self, metrics: dict[str, Any], *, step, commit=None): + for tracker in self.loggers: + tracker.log(metrics, step=step, commit=commit) + + def log_summary(self, metrics: dict[str, Any]): + for tracker in self.loggers: + tracker.log_summary(metrics) + + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): + for tracker in self.loggers: + tracker.log_artifact(artifact_path, name=name, type=type) + + +class TrackerConfig(draccus.PluginRegistry, abc.ABC): + discover_packages_path = "levanter.tracker" + + @abc.abstractmethod + def init(self, run_id: Optional[str]) -> Tracker: + raise NotImplementedError + + @classmethod + def default_choice_name(cls) -> Optional[str]: + return "wandb" + + +class NoopTracker(Tracker): + name: str = "noop" + + def log_hyperparameters(self, hparams: dict[str, Any]): + pass + + def log(self, metrics: dict[str, Any], *, step, commit: Optional[bool] = None): + pass + + def log_summary(self, metrics: dict[str, Any]): + pass + + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): + pass + + +@TrackerConfig.register_subclass("noop") +@dataclasses.dataclass +class NoopConfig(TrackerConfig): + def init(self, run_id: Optional[str]) -> Tracker: + return NoopTracker() diff --git a/src/levanter/tracker/tracker_fns.py b/src/levanter/tracker/tracker_fns.py new file mode 100644 index 000000000..e3b6a1f71 --- /dev/null +++ b/src/levanter/tracker/tracker_fns.py @@ -0,0 +1,235 @@ +import dataclasses +import logging +import os +import tempfile +import typing +import warnings +from contextlib import AbstractContextManager +from typing import Any, Literal, Optional + +import draccus +import jax + +from levanter.tracker import CompositeTracker, Tracker +from levanter.tracker.helpers import hparams_to_dict +from levanter.tracker.tensorboard import TensorboardTracker +from levanter.tracker.wandb import WandbTracker +from levanter.utils.jax_utils import is_inside_jit + + +logger = logging.getLogger(__name__) + + +_global_tracker: Optional["Tracker"] = None + + +def log_metrics(metrics: dict[str, Any], *, step: Optional[int], commit: Optional[bool] = None): + """ + Log metrics to the global tracker. + + Args: + metrics: Metrics to log + step: Step to log at + commit: Whether to commit the metrics. If None, uses the default for the tracker. + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + + if is_inside_jit(): + # we're inside a jit, so we need to log from the host + if commit: + raise ValueError("Cannot commit from inside jit") + jit_log_metrics(metrics, step=step) + else: + # TODO: do we need to coerce to np here? + _global_tracker.log(metrics, step=step) + + +def _no_throw_log_metrics(metrics: dict[str, Any], *, step: Optional[int], commit: Optional[bool] = None): + try: + if _global_tracker is None: + raise RuntimeError("No global tracker set") + _global_tracker.log(metrics, step=step, commit=False) + except Exception: + logger.exception("Error logging metrics") + + +def jit_log_metrics(metrics, *, step=None): + """uses jax effect callback to log to wandb from the host""" + jax.debug.callback(_no_throw_log_metrics, metrics, step=step) + + +def log_summary(metrics: dict[str, Any]): + """ + Log summary metrics to the global tracker. + + Args: + metrics: Metrics to log + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + _global_tracker.log_summary(metrics) + + +def log_hyperparameters(hparams: dict[str, Any]): + """ + Log hyperparameters to the global tracker. + + Args: + hparams: Hyperparameters to log + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + + _global_tracker.log_hyperparameters(hparams) + + +def log_configuration(hparams: Any, config_name: Optional[str] = None): + """ + Logs a configuration object to the global tracker. If the configuration object is a dataclass, + it is dumped to a yaml file and logged as an artifact. + + Args: + hparams: Hyperparameters to log + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + + hparams_dict = hparams_to_dict(hparams) + _global_tracker.log_hyperparameters(hparams_dict) + + if dataclasses.is_dataclass(hparams): + with tempfile.TemporaryDirectory() as tmpdir: + config_path = os.path.join(tmpdir, "config.yaml") + with open(config_path, "w") as f: + draccus.dump(hparams, f, encoding="utf-8") + name = config_name or "config.yaml" + _global_tracker.log_artifact(config_path, name=name, type="config") + + +def set_global_tracker(tracker: Tracker): + """ + Set the global tracker. Note that setting the global tracker is not thread-safe, + and using a tracker from multiple threads is only supported if the tracker itself is thread-safe. + + In general, it's preferred to use the context manager returned by `current_tracker` instead of this function + except for once at the beginning of the program. + + Args: + tracker: The tracker to set as the global tracker + force: Whether to force setting the global tracker even if it is already set + + Examples: + >>> from levanter.tracker import set_global_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> set_global_tracker(WandbTracker()) + >>> log_metrics({"foo": 1}, step=0) + """ + global _global_tracker + if _global_tracker is not None: + warnings.warn("Global tracker is already set. Overwriting it.") + _global_tracker = tracker + + +@typing.overload +def current_tracker() -> "Tracker": + ... + + +@typing.overload +def current_tracker(tracker: "Tracker") -> typing.ContextManager: + """Returns a context manager for setting the global tracker""" + ... + + +def current_tracker( + tracker: Optional[Tracker] = None, +) -> Tracker | typing.ContextManager: + """ + Get or set the global tracker. Note that setting the global tracker is not thread-safe, + and using a tracker from multiple threads is only supported if the tracker itself is thread-safe. + + Args: + tracker: If provided, returns a context manager that sets the global tracker to the provided tracker when used. + + Returns: + If no tracker is provided, returns the current global tracker. + If a tracker is provided, returns a context manager that sets the global tracker to the provided tracker when used. + + Examples: + >>> from levanter.tracker import current_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> with current_tracker(WandbTracker()): + ... log_metrics({"foo": 1}, step=0) + ... current_tracker().log({"foo": 2}, step=1) + """ + global _global_tracker + if tracker is None: + if _global_tracker is None: + raise RuntimeError("No global tracker set") + return _global_tracker + else: + return _GlobalLoggerContextManager(tracker) + + +@typing.overload +def get_tracker(name: Literal["wandb"]) -> WandbTracker: + ... + + +@typing.overload +def get_tracker(name: Literal["tensorboard"]) -> TensorboardTracker: + ... + + +@typing.overload +def get_tracker(name: str) -> Tracker: + ... + + +def get_tracker(name: str) -> Tracker: + """ + Lookup a tracker in the current global tracker with the provided name. + + Args: + name: Name of the tracker to lookup + + Returns: + The tracker with the provided name + + Examples: + >>> from levanter.tracker import get_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> with current_tracker(WandbTracker()): + ... log_metrics({"foo": 1}, step=0) + ... get_tracker("wandb").log_metrics({"foo": 2}, step=1) + """ + tracker = current_tracker() + if isinstance(tracker, CompositeTracker): + for t in tracker.loggers: + if t.name == name: + return t + elif tracker.name == name: + return tracker + + raise KeyError(f"Tracker with name {name} not found") + + +class _GlobalLoggerContextManager(AbstractContextManager): + def __init__(self, tracker: "Tracker"): + self.tracker = tracker + + def __enter__(self): + global _global_tracker + self.old_tracker = _global_tracker + _global_tracker = self.tracker + + return self.tracker + + def __exit__(self, exc_type, exc_val, exc_tb): + global _global_tracker + _global_tracker = self.old_tracker diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py new file mode 100644 index 000000000..d217ab000 --- /dev/null +++ b/src/levanter/tracker/wandb.py @@ -0,0 +1,199 @@ +import logging +import os +import tempfile +import typing +import warnings +from dataclasses import dataclass +from typing import Any, List, Optional, Union + +import jax +from draccus import field +from git import InvalidGitRepositoryError, NoSuchPathError, Repo + +from levanter.tracker import Tracker +from levanter.tracker.helpers import generate_pip_freeze, infer_experiment_git_root +from levanter.tracker.tracker import TrackerConfig +from levanter.utils import jax_utils + + +if typing.TYPE_CHECKING: + import wandb + import wandb.sdk.lib.disabled + + +logger = logging.getLogger(__name__) + +WandbRun = Union["wandb.sdk.wandb_run.Run", "wandb.sdk.lib.disabled.RunDisabled"] + + +class WandbTracker(Tracker): + name: str = "wandb" + run: WandbRun + + def __init__(self, run: Optional[WandbRun]): + import wandb + + if run is None: + if wandb.run is None: + logger.warning("Wandb run is not initialized. Initializing a new run.") + runx = wandb.init() + if runx is None: + raise RuntimeError("Wandb run is not initialized.") + self.run = runx + else: + self.run = wandb.run + else: + self.run = run + + def log_hyperparameters(self, hparams: dict[str, Any]): + self.run.config.update(hparams, allow_val_change=True) + + def log(self, metrics: dict[str, Any], *, step, commit=None): + if step is None and not commit: + step = self.run.step + + self.run.log(metrics, step=step, commit=commit) + + def log_summary(self, metrics: dict[str, Any]): + self.run.summary.update(metrics) + + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): + self.run.log_artifact(artifact_path, name=name, type=type) + + +def is_wandb_available(): + try: + import wandb + except ImportError: + return False + return wandb is not None and wandb.run is not None + + +@TrackerConfig.register_subclass("wandb") +@dataclass +class WandbConfig(TrackerConfig): + """ + Configuration for wandb. + """ + + entity: Optional[str] = None # An entity is a username or team name where you send runs + project: Optional[str] = None # The name of the project where you are sending the enw run. + name: Optional[str] = None # A short display name for this run, which is how you'll identify this run in the UI. + tags: List[str] = field(default_factory=list) # Will populate the list of tags on this run in the UI. + id: Optional[str] = None # A unique ID for this run, used for resuming. It must be unique in the project + group: Optional[str] = None # Specify a group to organize individual runs into a larger experiment. + mode: Optional[str] = None # Can be "online", "offline" or "disabled". If None, it will be whatever W&B decides. + resume: Optional[Union[bool, str]] = None + """ + Set the resume behavior. Options: "allow", "must", "never", "auto" or None. + By default, if the new run has the same ID as a previous run, this run overwrites that data. + Please refer to [init](https://docs.wandb.ai/ref/python/init) and [resume](https://docs.wandb.ai/guides/runs/resuming) + document for more details. + """ + + save_code: Union[bool, str] = True + """If string, will save code from that directory. If True, will attempt to sniff out the main directory (since we + typically don't run from the root of the repo).""" + + save_xla_dumps: bool = False + """If True, will save the XLA code to wandb (as configured by XLA_FLAGS). This is useful for debugging.""" + + def init(self, run_id: Optional[str]) -> WandbTracker: + import wandb + + if run_id is not None and self.id is not None and run_id != self.id: + warnings.warn( + f"Both trainer's id {run_id} and WandB's id {self.id} are set. WandB will use the id set in its" + " config." + ) + + id = self.id + if id is None: + id = run_id + + hparams_to_save = {} + + # for distributed runs, we only want the primary worker to use wandb, so we make everyone else be disabled + # however, we do share information about the run id, so that we can link to it from the other workers + if jax.process_index() == 0: + mode = self.mode + else: + mode = "disabled" + + git_settings = self._git_settings() + + if "git_commit" in git_settings: + hparams_to_save["git_commit"] = git_settings["git_commit"] + + r = wandb.init( + entity=self.entity, + project=self.project, + name=self.name, + tags=self.tags, + id=id, + group=self.group, + resume=self.resume, + mode=mode, + config=hparams_to_save, + settings=git_settings, + allow_val_change=True, + ) + + assert r is not None + + if jax.process_count() > 1: + # we need to share wandb run information across all hosts, because we use it for checkpoint paths and things + metadata_to_share = dict( + entity=r.entity, + project=r.project, + name=r.name, + tags=r.tags, + id=r.id, + group=r.group, + ) + metadata_to_share = jax_utils.multihost_broadcast_sync( + metadata_to_share, is_source=jax.process_index() == 0 + ) + + if jax.process_index() != 0: + assert r.mode == "disabled" + for k, v in metadata_to_share.items(): + setattr(r, k, v) + + logger.info(f"Synced wandb run information from process 0: {r.name} {r.id}") + + # generate a pip freeze + with tempfile.TemporaryDirectory() as tmpdir: + requirements_path = os.path.join(tmpdir, "requirements.txt") + requirements = generate_pip_freeze() + with open(requirements_path, "w") as f: + f.write(requirements) + if wandb.run is not None: + wandb.run.log_artifact(str(requirements_path), name="requirements.txt", type="requirements") + + wandb.summary["num_devices"] = jax.device_count() + wandb.summary["num_hosts"] = jax.process_count() + wandb.summary["backend"] = jax.default_backend() + + return WandbTracker(r) + + def _git_settings(self): + other_settings = dict() + if isinstance(self.save_code, str): + code_dir = self.save_code + elif self.save_code: + code_dir = infer_experiment_git_root() or "." # type: ignore + else: + code_dir = None + if code_dir is not None: + logger.info(f"Setting wandb code_dir to {code_dir}") + other_settings["code_dir"] = code_dir + other_settings["git_root"] = code_dir + # for some reason, wandb isn't populating the git commit, so we do it here + try: + repo = Repo(code_dir) + other_settings["git_commit"] = repo.head.commit.hexsha + except (NoSuchPathError, InvalidGitRepositoryError): + logger.warning(f"Could not find git repo at {code_dir}") + pass + return other_settings diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index d9db8dc91..5577c6406 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -5,16 +5,30 @@ import os import sys import typing +import warnings from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import Any, Callable, Dict, Generic, Iterable, List, Mapping, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + List, + Mapping, + Optional, + Protocol, + Sequence, + Tuple, + TypeVar, + Union, +) import equinox as eqx import jax import jmp import numpy as np -import wandb from draccus import field from jax import ShapeDtypeStruct from jax.experimental import multihost_utils @@ -28,12 +42,16 @@ from haliax.types import Scalar import levanter.logging +import levanter.tracker +import levanter.tracker.wandb +from levanter import tracker from levanter.checkpoint import CheckpointerConfig from levanter.config import JsonAtom from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader from levanter.distributed import DistributedConfig, RayConfig from levanter.grad_accum import microbatched -from levanter.logging import WandbConfig, capture_time +from levanter.logging import capture_time +from levanter.tracker import TrackerConfig from levanter.types import FilterSpec from levanter.utils import cloud_utils from levanter.utils.jax_utils import is_inexact_arrayish @@ -112,8 +130,10 @@ class Trainer: config: "TrainerConfig" optimizer: GradientTransformation hooks: TrainerHooks + tracker: levanter.tracker.Tracker is_trainable_param: Optional[PyTree[FilterSpec]] _raw_loss_function: Callable + _cmanagers: List[typing.ContextManager] = [] def __init__( self, @@ -140,6 +160,8 @@ def __init__( self.optimizer = optimizer self.is_trainable_param = is_trainable + self._cmanagers = [] + @cached_property def loss_fn(self): """ @@ -204,6 +226,34 @@ def TrainBatch(self): def EvalBatch(self): return self.config.EvalBatch + def __enter__(self): + if len(self._cmanagers) > 0: + raise RuntimeError("Trainer is already entered") + + self._cmanagers = [ + # levanter.current_tracker(self.tracker), + self.device_mesh, + hax.axis_mapping(self.parameter_axis_mapping), + ] + + for cmanager in self._cmanagers: + cmanager.__enter__() + + return self + + def __exit__(self, *args): + problems = [] + for cmanager in reversed(self._cmanagers): + try: + cmanager.__exit__(*args) + except Exception as e: + problems.append(e) + + self._cmanagers = [] + + if len(problems) > 0: + raise RuntimeError("Exception(s) occurred while exiting trainer", problems) from problems[0] + def initial_state( self, training_key: PRNGKeyArray, model: Optional[M] = None, model_init: Optional[Callable[[], M]] = None ) -> TrainerState: @@ -213,7 +263,6 @@ def initial_state( Returns: model, opt_state, key, resume_step """ - if model is not None and model_init is not None: raise ValueError("only one of model and model_init should be specified") elif model is None and model_init is None: @@ -306,8 +355,7 @@ def training_steps( with capture_time() as loading_time: example = next(iter_data) - # TODO: refactor logging - wandb.log({"throughput/loading_time": loading_time()}, step=state.step) + levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) info = self.train_step(state, example) state = info.state @@ -316,7 +364,7 @@ def training_steps( with capture_time() as hook_time: self.run_hooks(info) - wandb.log({"throughput/hook_time": hook_time()}, step=state.step) + levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) yield info @@ -337,10 +385,9 @@ def add_default_hooks(self, eval_dataset: Optional[Iterable[X]] = None): from levanter import callbacks self.add_hook(callbacks.pbar_logger(total=self.config.num_train_steps), every=1) - self.add_hook(callbacks.log_to_wandb, every=1) + self.add_hook(callbacks.log_step_info, every=1) if eval_dataset is not None: self.add_eval_hook(eval_dataset) - self.add_hook(callbacks.wandb_xla_logger(self.config.wandb), every=self.config.steps_per_eval) # engine.add_hook(callbacks.log_memory_usage(), every=1) checkpointer = self.config.checkpointer.create(self.run_id, self.is_trainable_param) self.add_hook(checkpointer.on_step, every=1) # checkpointer manages its own frequency @@ -409,7 +456,9 @@ def split_loss_fn(trainable_model, *batch, **batch_kwargs): loss, grads = self._compute_gradients_microbatched(split_loss_fn, trainable_model, batch, **batch_kwargs) - partial_fn = lambda model: split_loss_fn(model, *batch, **batch_kwargs) + updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) + + partial_fn = lambda model: self.loss_fn(model, *batch, **batch_kwargs) updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model, obj_fn=partial_fn) model = eqx.apply_updates(model, updates) @@ -500,16 +549,27 @@ def maybe_load_checkpoint( return None +def _initialize_global_tracker(config, run_id): + if isinstance(config, Sequence): + tracker = levanter.tracker.CompositeTracker([c.init(run_id) for c in config]) + else: + tracker = config.init(run_id) + + levanter.tracker.set_global_tracker(tracker) + + @dataclass class TrainerConfig: seed: int = 0 # random seed mp: jmp.Policy = jmp.get_policy("f32") # mixed precision policy - wandb: WandbConfig = field(default_factory=WandbConfig) + wandb: Optional[tracker.wandb.WandbConfig] = None log_dir: Path = Path("logs/") run_base_dir: Path = Path("runs/") id: Optional[str] = None # run id. if None, will be set to a random string + tracker: TrackerConfig | Tuple[TrackerConfig, ...] = field(default_factory=tracker.wandb.WandbConfig) + # config related to partitioning batch_axis: Optional[str] = "batch" # Batch axis for data parallel. @@ -557,15 +617,6 @@ class TrainerConfig: # whether or not to shutdown the tpu at exit. If a float, shutdown after that many seconds. True = 5 minutes shutdown_at_exit: Union[bool, float] = False - @property - def run_name(self) -> str: - try: - import wandb - - return wandb.run and (wandb.run.name or wandb.run.id) or "unnamed" - except ImportError: - return "unnamed" - @property def TrainBatch(self): return Axis("batch", self.train_batch_size) @@ -578,7 +629,12 @@ def EvalBatch(self): def microbatch_size(self): return self.per_device_parallelism * self.data_axis_size - def initialize(self, all_config): + def __post_init__(self): + if self.wandb is not None: + warnings.warn("wandb is deprecated. use tracker with type wandb instead", DeprecationWarning) + self.tracker = self.wandb + + def initialize(self): """Initializes jax, wandb, logging, setting the run name/id in the process""" self._initialize_jax_config() # Can't do full logging setup until we've initialized jax b/c we use jax for rank id @@ -587,8 +643,8 @@ def initialize(self, all_config): self._validate_and_set_defaults() id = self._maybe_set_id() - levanter.logging.init_logger(f"{self.log_dir}/{id}.log") - self.wandb.init(id, all_config) + levanter.logging.init_logging(self.log_dir, f"{id}.log") + _initialize_global_tracker(self.tracker, id) self.ray.initialize() @@ -668,7 +724,7 @@ def _maybe_set_id(self): # TODO: this doesn't work with wandb sweeps. need to reconcile when we merge if "RUN_ID" in os.environ: self.id = os.environ["RUN_ID"] - elif self.wandb.id is not None: + elif self.wandb is not None and self.wandb.id is not None: self.id = self.wandb.id else: # wandb run ids are 8 characters [a-z0-9], which we'll emulate here @@ -708,5 +764,21 @@ def _validate_and_set_defaults(self): self.per_device_eval_parallelism = self.per_device_parallelism +class AllConfig(Protocol): + trainer: TrainerConfig + + +def initialize(config: TrainerConfig | AllConfig): + """Initializes jax, logging, setting the run name/id in the process. Also initializes tracking and saves config + as hyperparameters and an artifact""" + if isinstance(config, TrainerConfig): + trainer_config = config + else: + trainer_config = config.trainer + + trainer_config.initialize() + levanter.tracker.log_configuration(config) + + def _params_only(t): return eqx.filter(t, is_inexact_arrayish) diff --git a/tests/test_eval_lm.py b/tests/test_eval_lm.py index f1193f4f4..178069f26 100644 --- a/tests/test_eval_lm.py +++ b/tests/test_eval_lm.py @@ -11,8 +11,8 @@ import tiny_test_corpus from levanter.checkpoint import save_checkpoint from levanter.distributed import RayConfig -from levanter.logging import WandbConfig from levanter.models.gpt2 import Gpt2LMHeadModel +from levanter.tracker.wandb import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count diff --git a/tests/test_export_to_hf.py b/tests/test_export_to_hf.py index b50bde9cb..3ce092789 100644 --- a/tests/test_export_to_hf.py +++ b/tests/test_export_to_hf.py @@ -50,8 +50,7 @@ def test_export_lm_to_hf(): export_lm_to_hf.main(config) if has_torch(): - m = AutoModelForCausalLM.from_pretrained(f"{tmpdir}/output") - print(m) + AutoModelForCausalLM.from_pretrained(f"{tmpdir}/output") finally: try: diff --git a/tests/test_logging.py b/tests/test_logging.py index cf99b8c35..ab7cc35f2 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -3,7 +3,7 @@ import pytest from git import InvalidGitRepositoryError, NoSuchPathError, Repo -from levanter.logging import WandbConfig +from levanter.tracker.helpers import infer_experiment_git_root def test_infer_experiment_git_root(): @@ -13,7 +13,7 @@ def test_infer_experiment_git_root(): except (InvalidGitRepositoryError, NoSuchPathError): pytest.skip("test not running in a git repo") - root = WandbConfig._infer_experiment_git_root() + root = infer_experiment_git_root() # ensure that 1) this is a git root and 2) this source file is underneath assert root is not None diff --git a/tests/test_tracker.py b/tests/test_tracker.py new file mode 100644 index 000000000..15485b83e --- /dev/null +++ b/tests/test_tracker.py @@ -0,0 +1,80 @@ +# NOTE: Do not explicitly import wandb/other trackers here, as this will cause the tests to trivially pass. +import dataclasses +from typing import Tuple + +import pytest +import yaml + +import levanter.tracker +from levanter.tracker import CompositeTracker, TrackerConfig + + +def test_tracker_plugin_stuff_works(): + assert TrackerConfig.get_choice_class("wandb") is not None + with pytest.raises(KeyError): + TrackerConfig.get_choice_class("foo") + + +def test_tracker_plugin_default_works(): + config = """ + tracker: + entity: foo + """ + parsed = yaml.safe_load(config) + + @dataclasses.dataclass + class ConfigHolder: + tracker: TrackerConfig + + import draccus + + tconfig = draccus.decode(ConfigHolder, parsed).tracker + + assert isinstance(tconfig, TrackerConfig.get_choice_class("wandb")) + + assert tconfig.entity == "foo" # type: ignore + + +def test_tracker_plugin_multi_parsing_work(): + config = """ + tracker: + type: noop + """ + parsed = yaml.safe_load(config) + + @dataclasses.dataclass + class ConfigHolder: + tracker: TrackerConfig | Tuple[TrackerConfig, ...] + + import draccus + + from levanter.tracker.tracker import NoopConfig + + assert isinstance(draccus.decode(ConfigHolder, parsed).tracker, NoopConfig) + + config = """ + tracker: + - type: noop + - type: wandb + """ + parsed = yaml.safe_load(config) + decoded = draccus.decode(ConfigHolder, parsed).tracker + assert decoded == (NoopConfig(), TrackerConfig.get_choice_class("wandb")()) + + +def test_get_tracker_by_name(): + wandb_config = TrackerConfig.get_choice_class("wandb") + if wandb_config is None: + pytest.skip("wandb not installed") + + from levanter.tracker import NoopTracker + + wandb1 = wandb_config(mode="disabled").init(None) + tracker = CompositeTracker([wandb1, NoopTracker()]) + + with tracker: + assert levanter.tracker.get_tracker("wandb") is wandb1 + assert levanter.tracker.get_tracker("noop") is not None + + with pytest.raises(KeyError): + levanter.tracker.get_tracker("foo") diff --git a/tests/test_train_lm.py b/tests/test_train_lm.py index 3cd762d8b..f95b27efb 100644 --- a/tests/test_train_lm.py +++ b/tests/test_train_lm.py @@ -8,7 +8,7 @@ import levanter.main.train_lm as train_lm import tiny_test_corpus from levanter.distributed import RayConfig -from levanter.logging import WandbConfig +from levanter.tracker.wandb import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count diff --git a/tests/test_viz_lm.py b/tests/test_viz_lm.py index 665c98772..29d8f943c 100644 --- a/tests/test_viz_lm.py +++ b/tests/test_viz_lm.py @@ -11,14 +11,18 @@ import tiny_test_corpus from levanter.checkpoint import save_checkpoint from levanter.distributed import RayConfig -from levanter.logging import WandbConfig from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel +from levanter.tracker.wandb import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count def setup_module(module): ray_designated_cores = max(1, logical_cpu_core_count()) - ray.init("local", num_cpus=ray_designated_cores) + try: + ray.init("local", num_cpus=ray_designated_cores) + except AssertionError: + # don't get upset if ray is already running + pass def teardown_module(module): From 5d5c30ff8f428dab20169b3d47689d910de2115b Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 9 Feb 2024 14:46:40 -0800 Subject: [PATCH 18/19] missed a few spots --- examples/alpaca/alpaca.py | 2 +- examples/gsm8k-lora/gsm8k_lora.py | 91 ++++++++++++++++--------------- src/levanter/__init__.py | 1 + src/levanter/main/lora_lm.py | 13 +++-- 4 files changed, 58 insertions(+), 49 deletions(-) diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index 36a6dd943..a20f357fe 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -194,7 +194,7 @@ def get_prompts(prompt_path) -> dict: def train(config: TrainArgs): - config.trainer.initialize(config) + levanter.initialize(config) # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint. # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary. diff --git a/examples/gsm8k-lora/gsm8k_lora.py b/examples/gsm8k-lora/gsm8k_lora.py index 5e4927d2f..febfd2013 100644 --- a/examples/gsm8k-lora/gsm8k_lora.py +++ b/examples/gsm8k-lora/gsm8k_lora.py @@ -9,7 +9,6 @@ import jax.random as jrandom import numpy as np import transformers -import wandb import haliax as hax @@ -127,7 +126,7 @@ def format_output(ex): def train(config: TrainArgs): - config.trainer.initialize(config) + levanter.initialize(config) # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint. # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary. @@ -169,53 +168,57 @@ def loraize_hf_model(model): def compute_loss(model: LmHeadModel, example: LmExample, key=None): return model.compute_loss(example, key=key).scalar() - trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) - # end major difference from Alpaca - trainer.add_default_hooks() - state = trainer.initial_state(training_key, model=model) - - # log some info about the model - all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - - wandb.summary["parameter_count"] = all_param_count - wandb.summary["trainable_parameter_count"] = just_lora_params - logger.info(f"Total parameter count: {all_param_count}") - logger.info(f"Trainable parameter count: {just_lora_params}") - logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") - - # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for - # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large - # datasets. We use replicated here since the dataset is small. - loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) - loader = non_caching_cycle(loader) - - if state.step != 0: - logger.info(f"Resuming training from step {state.step}") - for i in range(state.step): - next(loader) # type: ignore - - # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights - if config.hf_save_path is not None: - full_save_path = os.path.join(config.hf_save_path, trainer.run_id) - trainer.add_hook( - save_peft_checkpoint_callback( - full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload - ), - every=config.hf_save_steps, - ) + with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer: + state = trainer.initial_state(training_key, model=model) + + # log some info about the model + all_param_count = parameter_count(state.model) + just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - # Save merged HF checkpoints if requested - if config.merged_hf_save_path is not None: - full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) - trainer.add_hook( - save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), - every=config.hf_save_steps, + levanter.tracker.log_summary( + { + "parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } ) - trainer.train(state, loader) + logger.info(f"Total parameter count: {all_param_count}") + logger.info(f"Trainable parameter count: {just_lora_params}") + logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") + + # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for + # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large + # datasets. We use replicated here since the dataset is small. + loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) + loader = non_caching_cycle(loader) + + if state.step != 0: + logger.info(f"Resuming training from step {state.step}") + for i in range(state.step): + next(loader) # type: ignore + + # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights + if config.hf_save_path is not None: + full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + trainer.add_hook( + save_peft_checkpoint_callback( + full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload + ), + every=config.hf_save_steps, + ) + + # Save merged HF checkpoints if requested + if config.merged_hf_save_path is not None: + full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) + trainer.add_hook( + save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), + every=config.hf_save_steps, + ) + + trainer.train(state, loader) if __name__ == "__main__": diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py index a7def0acb..548a113a0 100644 --- a/src/levanter/__init__.py +++ b/src/levanter/__init__.py @@ -5,6 +5,7 @@ import levanter.logging as logging import levanter.models as models import levanter.optim as optim +import levanter.tracker as tracker import levanter.trainer as trainer import levanter.visualization as visualization from levanter.trainer import initialize diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index 93d60588a..babe7d2fa 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -4,7 +4,6 @@ from typing import Optional import jax.random as jrandom -import wandb import haliax.random @@ -47,6 +46,7 @@ class LoraLmConfig: def main(config: LoraLmConfig): + levanter.initialize(config) tokenizer = config.data.the_tokenizer converter = HFCheckpointConverter.from_hf(config.initialize_from_hf, trust_remote_code=config.trust_remote_code) @@ -55,7 +55,6 @@ def main(config: LoraLmConfig): converter = converter.replaced(tokenizer=tokenizer) - config.trainer.initialize(config) model_config = converter.default_config # randomness in jax is tightly controlled by "keys" which are the states of the random number generators @@ -96,8 +95,14 @@ def compute_loss(model, example: LmExample, key=None): all_param_count = parameter_count(state.model) just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - wandb.summary["parameter_count"] = all_param_count - wandb.summary["trainable_parameter_count"] = just_lora_params + levanter.tracker.log_summary( + { + "parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } + ) + logger.info(f"Total parameter count: {all_param_count}") logger.info(f"Trainable parameter count: {just_lora_params}") logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") From 7ba2b39a519b16039f965e2b08f0dbfee3dc9dfb Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 9 Feb 2024 15:11:16 -0800 Subject: [PATCH 19/19] remove old config --- config/gpt2_large_sophia_g.yaml | 21 --------------------- 1 file changed, 21 deletions(-) delete mode 100644 config/gpt2_large_sophia_g.yaml diff --git a/config/gpt2_large_sophia_g.yaml b/config/gpt2_large_sophia_g.yaml deleted file mode 100644 index 53a1d0806..000000000 --- a/config/gpt2_large_sophia_g.yaml +++ /dev/null @@ -1,21 +0,0 @@ -data: !include data/openwebtext_source.yaml -model: - type: gpt2 - hidden_dim: 1280 - num_heads: 20 - num_layers: 36 - seq_len: 1024 - gradient_checkpointing: true - scale_attn_by_inverse_layer_idx: true -trainer: - wandb: - project: "levanter" - tags: [ "openwebtext", "gpt2", "sophia-g"] - - num_train_steps: 200000 - mp: p=f32,c=bfloat16 - -optimizer: - type: sophia-g - learning_rate: 2E-4 - weight_decay: 0.15