diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..9d866e392 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: + - package-ecosystem: "pip" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "weekly" diff --git a/config/gpt2_nano.yaml b/config/gpt2_nano.yaml index 267d83b51..1ad0ceb3b 100644 --- a/config/gpt2_nano.yaml +++ b/config/gpt2_nano.yaml @@ -1,5 +1,5 @@ -#data: -# id: dlwh/wikitext_103_detokenized +data: + id: dlwh/wikitext_103_detokenized model: type: gpt2 hidden_dim: 32 @@ -14,7 +14,7 @@ trainer: - every: 50 save_interval: 5m - per_device_parallelism: 16 + per_device_parallelism: -1 train_batch_size: 32 tensor_parallel_axes: ["mlp", "heads"] diff --git a/config/gpt2_nano_tb.yaml b/config/gpt2_nano_tb.yaml index 9ada16aa3..f6847d693 100644 --- a/config/gpt2_nano_tb.yaml +++ b/config/gpt2_nano_tb.yaml @@ -14,8 +14,7 @@ trainer: - every: 50 save_interval: 5m - per_device_eval_parallelism: 1 - per_device_parallelism: 1 + per_device_parallelism: -1 train_batch_size: 32 tensor_parallel_axes: ["mlp", "heads"] diff --git a/config/gpt2_small.yaml b/config/gpt2_small.yaml index c657fe787..b3e0295af 100644 --- a/config/gpt2_small.yaml +++ b/config/gpt2_small.yaml @@ -14,7 +14,7 @@ trainer: mp: p=f32,c=bfloat16 model_axis_size: 1 - per_device_parallelism: 4 + per_device_parallelism: -1 train_batch_size: 512 optimizer: diff --git a/config/mistral_7b.yaml b/config/mistral_7b.yaml new file mode 100644 index 000000000..f4b3eab79 --- /dev/null +++ b/config/mistral_7b.yaml @@ -0,0 +1,28 @@ +data: + train_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" + validation_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" + cache_dir: "gs://levanter-data/tokenized/openwebtext_llama/" + tokenizer: "mistralai/Mistral-7B-v0.1" +model: + type: mistral +# TODO: uncomment this once we resolve the resource exhaustion issue +# initialize_from_hf: "mistralai/Mistral-7B-v0.1" +# use_hf_model_config: true +trainer: + wandb: + project: "levanter" + tags: ["openwebtext", "mistral"] + + mp: p=f32,c=bfloat16 + train_batch_size: 256 # set for v4-64 TPU + num_train_steps: 1000 + steps_per_eval: 50 + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" +optimizer: + learning_rate: 1.2E-5 # set low for fine-tuning + weight_decay: 0.1 + min_lr_ratio: 0.1 diff --git a/docs/Configuration-Guide.md b/docs/Configuration-Guide.md index 8336b1eb8..bdb09e4f1 100644 --- a/docs/Configuration-Guide.md +++ b/docs/Configuration-Guide.md @@ -13,7 +13,7 @@ class TrainLmConfig: data: LMDatasetConfig = field(default_factory=LMDatasetConfig) trainer: TrainerConfig = field(default_factory=TrainerConfig) model: LmConfig = field(default_factory=Gpt2Config) - optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + optimizer: OptimizerConfig = field(default_factory=AdamConfig) ``` Your training run will typically be associated with a single config file. For instance, you might have a file diff --git a/docs/Fine-Tuning.md b/docs/Fine-Tuning.md index 4903e3c1e..0bd3545b0 100644 --- a/docs/Fine-Tuning.md +++ b/docs/Fine-Tuning.md @@ -2,31 +2,34 @@ While Levanter's main focus is pretraining, we can also use it for fine-tuning. As an example, we'll show how to reproduce [Stanford Alpaca](https://crfm.stanford.edu/2023/03/13/alpaca.html), -using [Levanter](https://github.com/stanford-crfm/levanter) and either Llama 1 or [Llama 2](https://ai.meta.com/llama/). +using [Levanter](https://github.com/stanford-crfm/levanter) and either [Llama 1](https://arxiv.org/abs/2302.13971) or [Llama 2](https://ai.meta.com/llama/) 7B. The script we develop will be designed for Alpaca, defaulting to using its dataset and prompts, but it should work for any single-turn instruction-following task. -This tutorial is meant to cover "full finetuning", where you start with a pretrained model and modify -all of its parameters to fit some final task, rather than something like LoRA (though see our [LoRA tutorial](./LoRA.md) for that). -It also documents how to work with datasets that aren't just single `"text"`s, as we use in pretraining. +This tutorial is meant to cover "full finetuning," where you start with a pretrained model and modify +all of its parameters to fit some final task, rather than something like LoRA that adds a (small) number +of additional parameters. (See our [LoRA tutorial](./LoRA.md) for that.) +It also documents how to work with datasets that aren't just single `"text"`s, which is what we use in pretraining. ## Overview of Alpaca -[Alpaca](https://crfm.stanford.edu/2023/03/13/alpaca.html) is a lightweight fine tune of Llama 1 on a -[dataset of 52000 input/output pairs](https://huggingface.co/datasets/tatsu-lab/alpaca), which +[Alpaca](https://crfm.stanford.edu/2023/03/13/alpaca.html) is a fine tune of Llama 1 on a [dataset of 52000 input/output pairs](https://huggingface.co/datasets/tatsu-lab/alpaca), which were generated by taking [a seed set from self-instruct](https://github.com/yizhongw/self-instruct) and asking `text-davinci-003` to generate more examples. +The original Alpaca script is [here](https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py). + ![Schematic diagram of how the Alpaca model was created](https://crfm.stanford.edu/static/img/posts/2023-03-13-alpaca/alpaca_main.jpg) ### The Foundation Model -Llama 1 is a 7B parameter causal language model trained on 1T tokens from various mostly English sources. It's described -in [the Llama paper](https://arxiv.org/abs/2302.13971). +Llama 1 7B is a ≈7 billion parameter causal language model trained on 1 trillion tokens from various mostly English sources. +It's described in [the Llama 1 paper](https://arxiv.org/abs/2302.13971). [Llama 2](https://ai.meta.com/llama/) is a similar model, +just trained on more data (and with some slight tweaks to the architecture for larger models). ### The Data -More precisely, the dataset is composed of triples of (instruction, input, output), where the instruction is a prompt +The Alpaca dataset is composed of triples of (instruction, input, output), where the instruction is a prompt describing the task. A bit less than 40% of the examples have inputs, and the rest are just the instruction and output. Here are some example inputs, instructions, and outputs: @@ -44,7 +47,8 @@ generated by an LLM after all.) But it's a good example of the kind of data you ### Preprocessing Because Llama is a causal language model, we need to do some preprocessing to turn the pairs/triples into -a single sequence. The usual thing is to interpolate the strings into a prompt. We'll have two prompts, +a single sequence. The usual thing is to interpolate the strings into a prompt that provides +some context/guidance to the LM. We'll have two prompts, depending on whether or not there's an input or just an instruction and output. For example, the first example above would be turned into: @@ -74,10 +78,11 @@ Compute the area of a rectangle with length 10cm and width 5cm. The area of the rectangle is 50 cm2. ``` -From there [original Alpaca script](https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py) *masks out the loss* for all tokens before the start of the output. This gets -the model to learn to mimic outputs conditioned on inputs, rather than spending time learning to generate inputs and outputs. +From there, [the original Alpaca script](https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py) *masks out the loss* for all tokens before the start of the output. This gets +the model to learn to mimic outputs conditioned on inputs, rather than getting the model to learn the prompts and +inputs along with the outputs. -## Setup +## Running the script Rather than going through the code first, we'll jump straight to running the script. We'll cover the code in the [Code Walkthrough](#code-walkthrough) section below. @@ -86,61 +91,171 @@ Rather than going through the code first, we'll jump straight to running the scr Make sure you go through either the [GPU](./Getting-Started-GPU.md) or [TPU](./Getting-Started-TPU-VM.md) setup, depending on what you want to use. -### Environment Setup +### NVIDIA GPU -#### \[GPU\] Environment Setup +#### Environment Setup Follow the instructions in the [Getting Started with GPUs](./Getting-Started-GPU.md) guide to create a conda environment or virtualenv -and to install JAX with CUDA. +and to install JAX with CUDA. Then, if you haven't already done so, clone the Levanter repository and install it in editable mode: + +```bash +git clone https://github.com/stanford-crfm/levanter.git +cd levanter +pip install -e . +``` + +You'll also want to log into [WANDB](https://wandb.ai/). + +```bash +wandb login +``` + +To use Llama 2, you'll need to request access to the model from [Llama 2's Hugging Face page](https://huggingface.co/meta-llama/Llama-2-7b-hf). +Then, you'll need to log into the Hugging Face CLI: + +```bash +huggingface-cli login +``` + +#### Running the Script + +The example commands below demonstrate how to launch a training job on a node with 8 A100 GPUs, but should work for +other single node GPU configurations. For example, we've also tested Alpaca replication with +a node of 8 RTX 6000 Ada Generation 49.1GB GPUs. (Levanter works best with Ada or later generation NVIDIA GPUs.) + +To replicate Alpaca, you can run the following command: + +```bash +python examples/alpaca/alpaca.py --config_path levanter/examples/alpaca/alpaca.yaml +``` + +To use Llama 2: + +```bash +python examples/alpaca/alpaca.py --config_path levanter/examples/alpaca/alpaca-llama2.yaml +``` + +Alternatively: + +```bash +python examples/alpaca/alpaca.py --config_path levanter/examples/alpaca/alpaca-llama2.yaml --model_name_or_path meta-llama/Llama-2-7b-hf +``` + +!!! warning + + Fine-tuning a 7B parameter model needs **a lot** of accelerator memory: you will need more than 80GB of GPU memory in + aggregate to run this job. Because Levanter makes heavy use of FSDP, you can use several smaller cards. + If you don't have enough memory, you can try reducing the `train_batch_size` or the `per_device_parallelism` in + the config. + + +At some point the run will spit out a WandB link. You can click on that to see the training progress. There's +not a ton to see there (yet), but you can see the training loss go down over time. + +On an 8xA100 box, training should take about ~3.5 hours, similar to the original Alpaca script. +It should take ~8.5 hours on 8 RTX 6000 Ada Generation GPUs. + + +### TPUs -#### \[TPU\] Environment Setup +#### Environment Setup -For TPUs, please follow the instructions in the [Getting Started with TPUs](./Getting-Started-TPU-VM.md) guide to -get started with a TPU VM. Once you have, you can just run the -following command from a source checkout of Levanter: +For TPUs, please follow the instructions in the [Getting Started with TPUs](./Getting-Started-TPU-VM.md). +Once you have, you can run something like this to get a v3-32 TPU VM: ```bash bash infra/spin-up-vm.sh llama-32 -z us-east1-d -t v3-32 --preemptible ``` -### Install Levanter +You might need to change the zone and/or the TPU type depending on what's available. You can also use preemptible +TPUs if you want to save money (or that's what your quota is). Training Alpaca should work on a v3-8, +but we don't have any of those. + +#### Running the Script + +Launching the run on TPU is a bit more complex because you need to specify a lot of paths to GCS buckets. You will +also likely need to run the command on multiple machines, because a v3-32 VM is actually 4 distinct machines, each +controlling 8 TPUs. + +This is what the command looks like: ```bash +export GCS_BASE="gs://" +gcloud compute tpus tpu-vm ssh llama-32 --zone us-east1-d --worker=all \ +--command="WANDB_API_KEY=${YOUR TOKEN HERE} \ +HUGGING_FACE_HUB_TOKEN=${YOUR TOKEN HERE} \ +bash levanter/infra/run.sh python \ +levanter/examples/alpaca/alpaca.py \ +--config_path levanter/examples/alpaca/alpaca-llama2.yaml \ +--data_cache_dir ${GCS_BASE}/data \ +--trainer.checkpointer.base_path ${GCS_BASE}/ckpts \ +--hf_save_path ${GCS_BASE}/hf_ckpts +``` + +If you're using preemptible or TRC TPUs, you'll want to add `--trainer.id ` to the command line. +Alternatively, you can use the [babysitting script](./Getting-Started-TPU-VM.md#babysitting-script) to automatically restart the +VM and job if it gets preempted. (It will also set a run id automatically.) That would look like this: ```bash -git clone https://github.com/stanford-crfm/levanter.git -cd levanter -pip install -e . +infra/babysit-tpu-vm.sh llama-32 -z us-east1-d -t v3-32 --preemptible -- \ +WANDB_API_KEY=${YOUR TOKEN HERE} \ +HUGGING_FACE_HUB_TOKEN=${YOUR TOKEN HERE} \ +bash levanter/infra/run.sh python \ +levanter/examples/alpaca/alpaca.py \ +--config_path levanter/examples/alpaca/alpaca-llama2.yaml \ +--trainer.checkpointer.base_path gs:// \ +--hf_save_path gs:// \ ``` +You should see a link to the WandB run in the output. You can click on that to see the training progress. +Similar to an 8xA100 box, training should take about ~3.5 hours on a v3-32. + + ## Configuration -We have two configs for Alpaca: one for Llama 1 and one for Llama 2. The only difference is the model id. +That should be all you need to run the script and replicate Alpaca. +However, if you want to customize the script, you can do so by modifying the config. +We have two configs for Alpaca: one for Llama 1 and one for Llama 2. The only difference is the `model_name_or_path` field. -### Config to Replicate Alpaca +### Base Config ```yaml # cf https://github.com/tatsu-lab/stanford_alpaca#fine-tuning data: tatsu-lab/alpaca model_name_or_path: huggyllama/llama-7b trainer: - mp: p=f32,c=bfloat16 + mp: p=f32,c=bfloat16 # Mixed precision training with fp32 parameters/optimizer state and bf16 activations wandb: project: "levanter-alpaca" num_train_steps: 1218 # 128 * 1218 = 155904, which is almost but not quite 3 epochs, which is what alpaca did train_batch_size: 128 - # if using model parallelism, this is useful: - tensor_parallel_axes: ["mlp", "heads"] optimizer: learning_rate: 2e-5 weight_decay: 0.0 +prompts: + # |- means multiline string, keeping all but the final newline + prompt_input: |- + Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. + + ### Instruction: + {instruction} + + ### Input: + {input} + + ### Response: + prompt_no_input: |- + Below is an instruction that describes a task. Write a response that appropriately completes the request. + + ### Instruction: + {instruction} + + ### Response: ``` This config uses mixed fp32/bf16 precision and sets the number of training steps to be roughly 3 epochs. It sets up the optimizer to use a learning rate of 2e-5 and no weight decay. `trainer.per_device_parallelism` is roughly equivalent to HF's -`per_device_train_batch_size`. If you want to use model parallelism, you can set `trainer.model_axis_size` to something -like 2. (This will split the model across two devices. This might be useful if you're using a v3-64 or something similar and -want to maintain the same batch size.) +`per_device_train_batch_size`. ### Llama 2 Config @@ -150,166 +265,89 @@ If you haven't already, go to [Llama 2's Hugging Face page](https://huggingface. Once you have access, go to [Hugging Face's Tokens page](https://huggingface.co/settings/tokens) to get an API token. You'll need to provide this to the TPU VM as an environment variable. (We'll show you how to do this later.) -### Customizing the Config +### Custom Datasets -If you have your own dataset, you'll want to change the `data` field in the config to point to your dataset. -You'll also want to change the `model_name_or_path` field to point to the model you want to use. -Currently, Levanter supports GPT-2, Llama, MPT, and Backpack checkpoints. - -```yaml -data: # Path to the training data, or huggingface dataset name. -data_cache_dir: -model_name_or_path: "meta-llama/Llama-2-7b-chat-hf" -trainer: - ... -``` - -#### Custom Prompts - -If you want to use your own prompts, you can add them to the config, like this: - -```yaml -prompts: - prompt_input: |- - ### Instruction: {instruction} - ### Input: {input} - ### Output: - prompt_no_input: |- - ### Instruction: {instruction} - ### Output: -``` - -#### Custom Datasets - -The script we develop in this tutorial is designed for Alpaca, but it should work for any single-turn -instruction-following task. For instance, to train a code alpaca model, you could use the following config: +The script in this tutorial is designed for Alpaca, but it should work for any single-turn +instruction-following task. For instance, to train a Code Alpaca model, you could modify the config like this: ```yaml data: lucasmccabe-lmi/CodeAlpaca-20k # a dataset on the Hugging Face hub data_cache_dir: code_alpaca_cache # some path to store the cache ``` -The dataset can also be a JSON path. +The dataset can also be a path to a JSON or JSONL file, or compressed versions of those. -#### \[TPU\] Using a Modified Config +### Custom Models -If you make changes to the config, you'll need to get the config file to all the workers. For TPU, the best way to do this -is to copy it to Google Cloud Storage so that it persists when the machine is preempted. You can do this with: +You can also change the `model_name_or_path` field to point to the model you want to use. This +can be any Hugging Face model, or a path to a local checkpoint. Currently, Levanter supports GPT-2, Llama, MPT, and +Backpack checkpoints. -```bash -gsutil cp examples/alpaca/alpaca.yaml gs:///train-alpaca.yaml +```yaml +model_name_or_path: "meta-llama/Llama-2-7b-chat-hf" ``` -If using Llama 2: +Or on the command line: ```bash -gsutil cp examples/alpaca/alpaca-llama2.yaml gs:///train-alpaca.yaml +python examples/alpaca/alpaca.py --config_path levanter/examples/alpaca/alpaca.yaml --model_name_or_path "meta-llama/Llama-2-7b-chat-hf" ``` -And then using `--config_path gs:///alpaca.yaml` instead of `--config_path levanter/examples/alpaca/train-alpaca.yaml` -in the command line below. Levanter knows how to read from Google Cloud Storage, so you don't need to do anything else. +### Custom Prompts -## Launching the Job +If you want to use your own prompts, you can modify the `prompts` field. By default, the prompts are set to be the same +as the original Alpaca, but you can change them to whatever you want. They are formatted using Python's +[format strings](https://docs.python.org/3/library/string.html#format-string-syntax), meaning you can use `{instruction}` and `{input}`. +You should have two prompts: one for when there's an input and one for when there isn't. For example, here is +a more minimal prompt: -### \[GPU\] Launching the Job - -Right now, Levanter is only known to work with single node GPU training. The example commands below demonstrate how to launch a training job -on a node with 8 A100 GPUs, but should work for other single node GPU configurations. For example, we've also tested Alpaca replication with -a node of RTX 6000 Ada Generation 49.1GB GPUs. - -Before running your training bash command, ensure you are in your `levanter` conda environment, you've created a directory for saving checkpoints -during training, you are logged into your wandb account with the following two commands: - -!!! warning - - Fine-tuning a 7B parameter model needs **a lot** of accelerator memory, you will need more than 80GB of GPU memory in - aggregate to run this job. Because Levanter makes heavy use of FSDP, you can use several smaller cards. - If you don't have enough memory, you can try reducing the `train_batch_size` or the `per_device_parallelism` in - the config. - - - -```bash -conda activate levanter -wandb login ${YOUR TOKEN HERE} +```yaml +prompts: + prompt_input: |- + ### Instruction: {instruction} + ### Input: {input} + ### Output: + prompt_no_input: |- + ### Instruction: {instruction} + ### Output: ``` -Now you can run the training command: -```bash -python examples/alpaca/alpaca/alpaca.py \ ---config_path examples/alpaca/alpaca.yaml \ ---trainer.checkpointer.base_path levanter/checkpoints \ ---hf_save_path levanter/checkpoints -``` +We use YAML's [multiline string syntax](https://yaml-multiline.info/) to make the prompts easier to read. +You can also specify a path to a json file containing the prompts if you'd prefer. -You can change `--trainer.checkpointer.base_path` and `--hf_save_path` to your desired model checkpoint directories. -If you're using Llama 2, you'll need to first request access to the model, and then export your Hugging Face API token: +### \[TPU\] Using a Modified Config + +On a single machine, you can just modify the config and run the script. On TPU, however, you'll need to upload the config +to a Google Cloud Storage bucket so that all the workers can access it. You can do this with: ```bash -export HUGGING_FACE_HUB_TOKEN=${YOUR TOKEN HERE} +gsutil cp my-config.yaml gs:///my-config.yaml ``` -or log into the HF CLI with `huggingface-cli login`. +And then using `--config_path gs:///my-config` instead of `--config_path levanter/examples/alpaca/train-alpaca.yaml` +in the command line. Levanter knows how to read from Google Cloud Storage, so you don't need to do anything else. -### \[GPU\] NLP-Group Slurm Cluster Launch Example +### Aside: Running on Slurm -Say you save the above Alpaca training command as a bash script called `train_alpaca.sh`, then +Say you save the above Alpaca training command as a bash script called `train_alpaca.sh`. Then you could launch a training job on a slurm cluster with `srun` as follows: ```bash srun --account=nlp --cpus-per-task=32 --gpus-per-node=8 --mem=400G --open-mode=append --partition=sphinx --nodes=1 --pty bash train_alpaca.sh ``` -### \[TPU\] Launching the Job - -For TPU, we need just a little bit of ceremony to get the Hugging Face and WANDB API tokens in the environment: -(If you're using Llama 1, you don't need the `HUGGING_FACE_HUB_TOKEN` line unless you're using a private model -or uploading to Hugging Face.) - -```bash -gcloud compute tpus tpu-vm ssh llama-32 --zone us-east1-d --worker=all \ ---command="WANDB_API_KEY=${YOUR TOKEN HERE} \ -HUGGING_FACE_HUB_TOKEN=${YOUR TOKEN HERE} \ -bash levanter/infra/run.sh python \ -levanter/examples/alpaca/alpaca.py \ ---config_path levanter/examples/alpaca/alpaca.yaml \ ---trainer.checkpointer.base_path gs:// \ ---hf_save_path gs:// -``` - -If you're using preemptible or TRC TPUs, you'll want to add `--trainer.id ` to the command line. -Alternatively, you can use the [babysitting script](./Getting-Started-TPU-VM.md#babysitting-script) to automatically restart the -VM and job if it gets preempted. That would look like this: - -```bash -infra/babysit-tpu-vm.sh llama-32 -z us-east1-d -t v3-32 --preemptible -- \ -WANDB_API_KEY=${YOUR TOKEN HERE} \ -HUGGING_FACE_HUB_TOKEN=${YOUR TOKEN HERE} \ -bash levanter/infra/run.sh python \ -levanter/examples/alpaca/alpaca.py \ ---config_path levanter/examples/alpaca/alpaca-llama2.yaml \ ---trainer.checkpointer.base_path gs:// \ ---hf_save_path gs:// \ -``` - -## Waiting - -At some point the run will spit out a WandB link. You can click on that to see the training progress. There's -not a ton to see there (yet), but you can see the training loss go down over time. - -On a v3-32 or an 8xA100 box, training should take about ~3.5 hours, similarly on a v3-32 TPU VM. -It should take ~8.5 hours on 8 RTX 6000 Ada Generation GPUs. +(This is for the Stanford NLP Cluster. Adjust as necessary for your cluster.) ## Using the Model -When you're done, you can copy out the Hugging Face model with: +When you're done, you can download the Hugging Face model with: ```bash gsutil cp -r gs:////step- ./my-alpaca ``` -The model should work out-of-the-box as a Hugging Face model. You can use it like this: +The model should work out-of-the-box as a Hugging Face model. For a quick test, you can use it like this: ```python from transformers import AutoModelForCausalLM, AutoTokenizer @@ -344,7 +382,7 @@ If you want to just run the script, you can skip to the [Setup](#setup) section. ### Approach Levanter's existing main entry points are designed for "pure" causal language modeling, where you have a single sequence -and don't want to mask out any tokens. So we'll instead write a custom script that does the following: +and don't have any prompts or custom formatting. So we'll instead write a custom script that does the following: * Preprocesses the dataset into a single sequence, interpolating prompts as we go. We'll also construct a `loss_mask` and do any padding. * Loads the model and resizes the vocabulary to match the tokenizer. @@ -359,7 +397,7 @@ if you want more information. The first step is to get the dataset. We'll use the [Hugging Face Dataset version](https://huggingface.co/datasets/tatsu-lab/alpaca) to do this. (You can also download it directly from the [dataset page](https://huggingface.co/datasets/tatsu-lab/alpaca), but -Levanter's integration with Hugging Face datasets makes it a bit easier to use.) +Levanter's integration with Hugging Face datasets means we don't need to do that.) ```python def _get_data_source(path_or_id): @@ -371,8 +409,8 @@ def _get_data_source(path_or_id): return levanter.data.dataset_from_hf(path_or_id, split="train") ``` -Preprocessing in Levanter typically comes in two phases: -* creating the on-disk cache, +Preprocessing in Levanter typically happens in two phases: +* creating an on-disk cache of the "heavy" preprocessing, like tokenization; and * transforming examples from the cache into the examples that the model expects. Here's the first phase, where we create the cache. We basically want to interpolate the prompt with the input @@ -380,19 +418,19 @@ and instructions, and then tokenize the result. We also want to keep track of th can mask out the loss appropriately. ```python -def mk_dataset(data_path_or_id: str, cache_dir: str, tokenizer): - # wrap an HF dataset with Levanter's native dataset class for fancier preprocessing. - # Levanter's native dataset class supports streaming, deteriministic, distributed preprocessing out of the box, - # which is a bit overkill for this dataset, but it's a good example of how to use it. - dataset = _get_data_source(data_path_or_id) +def mk_dataset(config: TrainArgs, tokenizer: transformers.PreTrainedTokenizerBase): + dataset = _get_data_source(config.data) - prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] + prompts = get_prompts(config.prompts) def preprocess(batch): - sources = [ - prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example) - for example in batch - ] + def format_example(ex): + if ex.get("input", "") == "": + return prompts["prompt_no_input"].format_map(ex) + else: + return prompts["prompt_input"].format_map(ex) + + sources = [format_example(example) for example in batch] targets = [f"{example['output']}{tokenizer.eos_token}" for example in batch] # TODO: this seems pretty wasteful since you end up tokenizing twice, but it's how the original code does it. examples = [s + t for s, t in zip(sources, targets)] @@ -407,16 +445,15 @@ def mk_dataset(data_path_or_id: str, cache_dir: str, tokenizer): } dataset = dataset.map_batches(preprocess, batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer)) - dataset = dataset.build_cache(cache_dir, await_finished=True) + dataset = dataset.build_cache(config.data_cache_dir, await_finished=True) - # SupervisedDataset does last minute padding and masking - dataset = SupervisedDataset(dataset, tokenizer) + dataset = SupervisedDataset(dataset, tokenizer, mask_inputs=config.mask_inputs) return dataset ``` -In the second, we create [levanter.models.lm.LmExample][] objects from the cache. These are the inputs to the model. -`LmExample`s look like this: +`SupervisedDataset` is a class that we'll define later that does the final transformation from the cache to the +`LmExample` objects that the model expects. `LmExample`s look like this: ```python class LmExample(eqx.Module): @@ -425,8 +462,9 @@ class LmExample(eqx.Module): attn_mask: AttentionMask = AttentionMask.causal() ``` -So we need to populate the first two fields. We'll do that with a dataset whose job is to take the cache and turn it into -`LmExample`s. +So we need to populate the first two fields. `tokens` is the input sequence, and `loss_mask` is a boolean mask +that tells the model which tokens to compute the loss for. (We mask out the loss for everything before the start +of the output.) ```python class SupervisedDataset(Dataset[LmExample]): @@ -458,3 +496,12 @@ class SupervisedDataset(Dataset[LmExample]): The rest is boilerplate: setting up the model, optimizer, and trainer, and then running the training loop. We'll skip over that in this tutorial, but you can see the full script [here](https://github.com/stanford-crfm/levanter/blob/main/examples/alpaca/alpaca.py) if you want to see how it works. + + +## Conclusion + +That's it for this tutorial: you should now be able to fine-tune Llama 1 or Llama 2 on Alpaca or any other single-turn +instruction-following task. If you want to learn more about Levanter, check out the [Levanter docs](https://levanter.readthedocs.io/en/latest/) +or the [Levanter repo](https://github.com/stanford-crfm/levanter). For discussion, you can find us on [Discord](https://discord.gg/CKazXcbbBm). + +Let us know what you'd like to see next! diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index c83beddaa..eef14e026 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -114,40 +114,40 @@ def loraize_hf_model(model): } ) - logger.info(f"Total parameter count: {all_param_count}") - logger.info(f"Trainable parameter count: {just_lora_params}") - logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") - - # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for - # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large - # datasets. We use replicated here since the dataset is small. - loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) - loader = non_caching_cycle(loader) - - if state.step != 0: - logger.info(f"Resuming training from step {state.step}") - for i in range(state.step): - next(loader) # type: ignore - - # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights - if config.hf_save_path is not None: - full_save_path = os.path.join(config.hf_save_path, trainer.run_id) - trainer.add_hook( - save_peft_checkpoint_callback( - full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload - ), - every=config.hf_save_steps, - ) - - # Save merged HF checkpoints if requested - if config.merged_hf_save_path is not None: - full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) - trainer.add_hook( - save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), - every=config.hf_save_steps, - ) - - trainer.train(state, loader) + logger.info(f"Total parameter count: {all_param_count}") + logger.info(f"Trainable parameter count: {just_lora_params}") + logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") + + # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for + # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large + # datasets. We use replicated here since the dataset is small. + loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) + loader = non_caching_cycle(loader) + + if state.step != 0: + logger.info(f"Resuming training from step {state.step}") + for i in range(state.step): + next(loader) # type: ignore + + # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights + if config.hf_save_path is not None: + full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + trainer.add_hook( + save_peft_checkpoint_callback( + full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload + ), + every=config.hf_save_steps, + ) + + # Save merged HF checkpoints if requested + if config.merged_hf_save_path is not None: + full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) + trainer.add_hook( + save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), + every=config.hf_save_steps, + ) + + trainer.train(state, loader) if __name__ == "__main__": diff --git a/examples/alpaca/alpaca-llama2.yaml b/examples/alpaca/alpaca-llama2.yaml index 88ea6c917..2527de03f 100644 --- a/examples/alpaca/alpaca-llama2.yaml +++ b/examples/alpaca/alpaca-llama2.yaml @@ -13,3 +13,22 @@ trainer: optimizer: learning_rate: 2e-5 weight_decay: 0.0 +prompts: + # |- means multiline string, keeping all but the final newline + prompt_input: |- + Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. + + ### Instruction: + {instruction} + + ### Input: + {input} + + ### Response: + prompt_no_input: |- + Below is an instruction that describes a task. Write a response that appropriately completes the request. + + ### Instruction: + {instruction} + + ### Response: diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index 3c5a7097e..58d5bc17c 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -36,6 +36,7 @@ # Ways this script could be improved: # * Could tune hparams more for throughput +# Original # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -94,7 +95,7 @@ class TrainArgs: model_cache_dir: Optional[str] = None # Path to cache the model. must be local. - hf_save_path: Optional[str] = None # Path to save the HuggingFace checkpoint. + hf_save_path: Optional[str] = "alpaca_hf_ckpts" # Path to save the HuggingFace checkpoint, can be gcs hf_upload: Union[bool, str] = False # Name of the HuggingFace repo to upload to (if any). hf_save_steps: int = 1000 # How often to save the HuggingFace checkpoint. @@ -135,14 +136,14 @@ def _get_data_source(path_or_id): """The original alpaca.py used a json file, but it's since been moved to the HF dataset hub. You can use any dataset that's compatible with the structure of the alpaca dataset.""" if fsspec_utils.exists(path_or_id): - # get file format: jsonl or json - if path_or_id.endswith(".jsonl"): + # we're a bit generous here b/c we support compression + if ".jsonl" in path_or_id: return JsonlDataset([path_or_id]) - elif path_or_id.endswith(".json"): + elif ".json" in path_or_id: return JsonDataset([path_or_id]) else: raise ValueError( - f"We only support HF Dataset or a data file with .json or .jsonl extensions, not {path_or_id}!" + f"We only support HF Datasets or a data file with .json or .jsonl extensions, not {path_or_id}!" ) else: return WrappedHFDataset(path_or_id, split="train") diff --git a/examples/alpaca/alpaca.yaml b/examples/alpaca/alpaca.yaml index 3eae596fc..fd0e2fb7c 100644 --- a/examples/alpaca/alpaca.yaml +++ b/examples/alpaca/alpaca.yaml @@ -13,3 +13,22 @@ trainer: optimizer: learning_rate: 2e-5 weight_decay: 0.0 +prompts: + # |- means multiline string, keeping all but the final newline + prompt_input: |- + Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. + + ### Instruction: + {instruction} + + ### Input: + {input} + + ### Response: + prompt_no_input: |- + Below is an instruction that describes a task. Write a response that appropriately completes the request. + + ### Instruction: + {instruction} + + ### Response: diff --git a/pyproject.toml b/pyproject.toml index 342e9bd45..789f35fad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,24 +34,24 @@ dependencies = [ "draccus>=0.7.1", "pyarrow>=11.0.0", "zstandard>=0.20.0", - "datasets==2.11.0", - "gcsfs<2023.10.0", + "datasets==2.16.1", + "gcsfs<2024.3.0", "braceexpand>=0.1.7", "jmp>=0.0.3", - "fsspec<2023.10.0", + "fsspec<2024.3.0", # TODO: minimize and report an issue to tensorstore # causes hangs when serializing to GCS - "tensorstore==0.1.45", + "tensorstore==0.1.53", "pytimeparse>=1.1.8", "humanfriendly==10.0", "safetensors[numpy]", "matplotlib>=3.7.0", - "tblib>=1.7.0,<2.0.0", + "tblib>=1.7.0,<4.0.0", "dataclasses-json", "ray[default]", - "pydantic<2", # temporary pin until Ray supports pydantic 2.0 + "pydantic<3", # temporary pin until Ray supports pydantic 2.0 "rich>=13", -# "chex>=0.1.85" + "filelock", ] [tool.hatch.build] diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 0c462a997..b0244e0e3 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -59,8 +59,8 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n if n > 0: total_loss /= n - logger.info(f"eval loading time: {total_load_time / n:.3f} s/ba") - logger.info(f"eval loss time: {total_loss_time / n:.3f} s/ba") + # logger.info(f"eval loading time: {total_load_time / n:.3f} s/ba") + # logger.info(f"eval loss time: {total_loss_time / n:.3f} s/ba") return total_loss diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index d1076f325..dcbb77270 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -226,7 +226,7 @@ def save_checkpoint(self, info, destination: str): logger.info(f"Saved checkpoint at step {info.step} to {path}. Save time is {self._last_save_time}") -def save_checkpoint(tree: M, step: int, checkpoint_path: PathLike, *, exist_ok: bool = False): +def save_checkpoint(tree: M, step: int, checkpoint_path: PathLike): """ Save a checkpoint to a given path using TensorStore. If exist_ok is True, the checkpoint will be saved even if a checkpoint already exists at the given path. @@ -242,7 +242,7 @@ def save_checkpoint(tree: M, step: int, checkpoint_path: PathLike, *, exist_ok: fs: AbstractFileSystem fs, plain_path = _get_fs_and_plain_path(checkpoint_path) - fs.makedirs(plain_path, exist_ok=exist_ok) + fs.makedirs(plain_path, exist_ok=True) tree_serialize_leaves_tensorstore(checkpoint_path, tree) save_metadata(checkpoint_path, fs, step) diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index f3565665b..c1d24c1a0 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -34,8 +34,8 @@ from .. import logging from ..utils.ray_utils import ExceptionInfo, RefBox, current_actor_handle, ser_exc_info -from . import ShardableDataset from ._preprocessor import BatchProcessor, BatchResult, as_record_batch, dict_from_record_batch +from .dataset import ShardableDataset from .sharded_dataset import ShardedDataset @@ -49,6 +49,8 @@ DEFAULT_MAX_BYTES_PER_BATCH = 256 * 1024 * 1024 # 256 MB, this is pre-preprocessing python object size LEDGER_FILE_NAME = "cache_ledger.json" +LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + def build_cache( cache_dir: str, @@ -298,7 +300,7 @@ def _commit(self): # The difficulty is that we want parallelism and we want to control the order of chunks. # reading batches requires CPU and network. This means we should limit the number to roughly the number of nodes, maybe times 2. -# We want to prioritize so that we read 1 chunks worth of batches from each shard before reading more from any shard. +# We want to prioritize so that we read 1 chunks worth of batches from each shard before reading more from another shard. # We also want to prioritize reading earlier shards before later shards (within a chunk generation round). # Ray also seems to get upset about having too many processes, and we can't serialize the iterators over shards. @@ -363,7 +365,7 @@ def __le__(self, other: "PriorityWorkItem"): @ray.remote(num_cpus=1, scheduling_strategy="SPREAD") class PriorityProcessorActor: def __init__(self, max_in_flight: Optional[int] = 200): - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self._queue: list[PriorityWorkItem] = [] # heapq self._queue_lock = threading.Lock() self._shutdown_event = threading.Event() @@ -387,7 +389,7 @@ def is_group_finished(self, group: PriorityWorkTaskGroupSpec): if self._current_item is not None and self._current_item.spec == group: return False - logger.info(f"Group {group.name} is finished.") + logger.debug(f"Group {group.name} is finished.") return True @@ -404,7 +406,7 @@ def shutdown(self): if self._processing_thread.is_alive(): self._processing_thread.join() - def _loop(self): + def _loop(self: "PriorityProcessorActor"): should_sleep = False backpressure_queue: list[ray.ObjectRef] = [] @@ -444,9 +446,9 @@ def drain_backpressure_to(count): if not item_is_finished: heapq.heappush(self._queue, item) - logger.info("Shutting down PriorityProcessorActor. Waiting for backpressure to drain.") + logger.debug("Shutting down PriorityProcessorActor. Waiting for backpressure to drain.") drain_backpressure_to(0) - logger.info("Backpressure drained. Shutting down PriorityProcessorActor.") + logger.debug("Backpressure drained. Shutting down PriorityProcessorActor.") @dataclass @@ -460,6 +462,7 @@ class ShardGroupToBeProcessed(PriorityWorkTaskGroupSpec): processor_actor: ray.actor.ActorHandle # BatchProcessorQueue batch_size: int num_rows_per_chunk: int + group_id: int def build(self) -> "PriorityWorkTaskGroup": return ShardGroupTaskGroup(self) @@ -468,7 +471,7 @@ def build(self) -> "PriorityWorkTaskGroup": class ShardGroupTaskGroup(PriorityWorkTaskGroup): def __init__(self, spec: ShardGroupToBeProcessed): self.spec = spec - self.logger = pylogging.getLogger(f"shard_reader.{self.spec.name}") + self.logger = pylogging.getLogger(f"shard_reader.{spec.group_id}.{spec.name}") try: metadata: dict[str, ShardMetadata] = _initial_shard_metadatas( @@ -544,6 +547,8 @@ def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: total_chunk_rows = 0 # the total number of rows in the chunk batch_result_ref = None + self.group.logger.debug(f"Reading one chunk of shard {self.shard_name}: {self.chunk_idx}") + try: while not chunk_filled: batch = next(self.reader, None) @@ -556,14 +561,20 @@ def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: if batch: priority = self.spec.priority_fn(self.shard_idx, self.chunk_idx) + # these times aren't exact because the times might be from different machines + # but they're just for logging + time_in = time.time() batch_result_ref = ray.get( - self.spec.processor_actor.submit.remote(priority=priority, batch=RefBox(ray.put(batch))) + self.spec.processor_actor.submit.remote( + priority=priority, + desc=f"{self.shard_name}.{self.chunk_idx}.{chunk_batch_idx}", + batch=RefBox(ray.put(batch)), + ) ) writer.chunk_batch_finished.remote( - self.shard_name, self.chunk_idx, chunk_batch_idx, RefBox(batch_result_ref) + self.shard_name, self.chunk_idx, chunk_batch_idx, RefBox(batch_result_ref), time_in ) chunk_batch_idx += 1 - # enqueue_to_backpressure(batch, batch_result_ref) del batch if total_chunk_rows >= self.spec.num_rows_per_chunk or exhausted_shard: @@ -578,7 +589,9 @@ def execute(self) -> tuple[bool, Optional[ray.ObjectRef]]: if exhausted_shard: writer.shard_finished_reading.remote(self.shard_name, self.chunk_idx) - logger.debug(f"Finished reading one chunk of shard {self.shard_name}: {self.chunk_idx} {exhausted_shard}") + self.group.logger.debug( + f"Finished reading one chunk of shard {self.shard_name}: {self.chunk_idx} {exhausted_shard}" + ) return exhausted_shard, batch_result_ref except Exception as e: # noqa @@ -769,14 +782,49 @@ def is_finished_and_buffer_empty(self): return self.expected_num_chunks is not None and self.num_chunks_sent >= self.expected_num_chunks +class WaitTimeReportingThread(threading.Thread): + def __init__(self, report, interval=60): + super().__init__() + self.report = report + self.interval = interval + self.shutdown_event = threading.Event() + + def run(self): + total_waited = 0 + while True: + if self.shutdown_event.wait(self.interval): + break + if total_waited > 0: + self.report(total_waited) + total_waited += self.interval + + def shutdown(self): + self.shutdown_event.set() + + def _mk_queue_aware_process_task(processor: BatchProcessor[T], queue: ActorHandle): @ray.remote(num_cpus=processor.num_cpus, num_gpus=processor.num_gpus, resources=processor.resources) - def process_task(batch: List[T]) -> pa.RecordBatch: - pylogging.basicConfig(level=pylogging.INFO) + def process_task(desc, batch: List[T]) -> pa.RecordBatch: + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) + logger.debug(f"Processing batch {desc}") queue.task_running.remote() - result = processor(batch) - del batch - return as_record_batch(result) + # timer_thread = WaitTimeReportingThread( + # lambda t: logger.info(f"Waiting for {desc} to be processed for {t} seconds"), interval=30 + # ) + # timer_thread.start() + try: + result = processor(batch) + del batch + result = as_record_batch(result) + logger.debug(f"Finished processing batch {desc}") + return result + except Exception as e: + logger.exception(f"Error while processing batch {desc}") + raise e + finally: + # timer_thread.shutdown() + # timer_thread.join() + pass return process_task @@ -784,6 +832,7 @@ def process_task(batch: List[T]) -> pa.RecordBatch: @dataclass(order=True, frozen=True) class _QueueItem: priority: float + desc: str batch: ray.ObjectRef = dataclasses.field(compare=False) task_id: int task_future: asyncio.Future = dataclasses.field(compare=False) @@ -818,13 +867,13 @@ def __init__(self, batch_processor: BatchProcessor[T]): # we don't need/want to dereference the batch, so we wrap it in a RefBox # one virtue of doing things this way is that we can let Ray try to schedule the compute near the data. - async def submit(self, priority: float, batch: RefBox): + async def submit(self, priority: float, desc: str, batch: RefBox): """Returns a future that is set to the *ObjectRef* of the processed batch. The future is "complete" when the task starts, not when it finishes. You then call ray.get on the future's result to get the actual batch.""" task_id = self._next_task_id self._next_task_id += 1 f: asyncio.Future = asyncio.Future() - self.pqueue.put(_QueueItem(priority, batch.ref, task_id, f)) + self.pqueue.put(_QueueItem(priority, desc, batch.ref, task_id, f)) self._maybe_start_task() return await f @@ -833,7 +882,7 @@ def _maybe_start_task(self): self.ready = False item = self.pqueue.get() batch = item.batch - item.task_future.set_result(self._task_processor.remote(batch)) + item.task_future.set_result(self._task_processor.remote(item.desc, batch)) def task_running(self): self.ready = True @@ -845,7 +894,7 @@ def task_running(self): @ray.remote(num_cpus=0.0, scheduling_strategy="SPREAD") # type: ignore class _GroupShardWriterWorker: def __init__(self, parent_ref, cache_dir: str, shard_names: Sequence[str]): - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self.cache_dir = cache_dir self.shard_names = shard_names self.shard_writers: dict[str, _ShardWriterWorker] = { @@ -855,10 +904,40 @@ def __init__(self, parent_ref, cache_dir: str, shard_names: Sequence[str]): def current_metadata(self, shard_name: str): return self.shard_writers[shard_name].current_metadata() - async def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: int, batch: RefBox): + async def chunk_batch_finished(self, shard_name: str, chunk_id: int, batch_idx: int, batch: RefBox, time_in): # batch is a pa.RecordBatch ref box try: - batch = await batch.ref + time_mid = time.time() + logger.debug( + f"Received in progress batch {batch_idx} of chunk {chunk_id} of shard {shard_name} in" + f" {time_mid - time_in}" + ) + # do a backoff loop until the batch is actually processed. log if it's been a while + timeout_interval = 20 + total_time_waited = 0 + + while True: + try: + # batch = await asyncio.wait_for(asyncio.shield(batch.ref), timeout_interval) + batch = await batch.ref + break + except asyncio.TimeoutError: + # to keep to round numbers, we log how much we asked for rather than how much we got + total_time_waited += timeout_interval + timeout_interval = min(2 * timeout_interval, 100) + logger.info( + f"Waiting for {shard_name}.{chunk_id}.{batch_idx} to be processed. " + f"Waited {total_time_waited} seconds." + ) + + if logger.isEnabledFor(pylogging.DEBUG): + logger.debug( + f"Received finished {shard_name}.{chunk_id}.{batch_idx} in {(time.time() - time_in):.2f} seconds." + ) + elif total_time_waited > 40: + logger.info( + f"Waited {total_time_waited} seconds for {shard_name}.{chunk_id}.{batch_idx} to be processed." + ) return self.shard_writers[shard_name].chunk_batch_finished(chunk_id, batch_idx, batch) except Exception as e: print(f"Error while processing batch {batch_idx} of chunk {chunk_id} of shard {shard_name}", flush=True) @@ -889,7 +968,7 @@ def __init__( cache_dir: str, shard_name: str, ): - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self.parent_ref = parent_ref self.cache_dir = cache_dir self.shard_name = shard_name @@ -934,13 +1013,9 @@ def chunk_failed(self, chunk_id: int, error: ExceptionInfo): self.parent_ref.shard_failed.remote(self.shard_name, error) def _finished_chunk(self, idx: int, chunk: ChunkMetadata): - if idx < self.metadata_writer.num_chunks: - logger.error(f"Received chunk {idx} for {self.shard_name} but it's already finished") - error = RuntimeError(f"Received chunk {idx} for {self.shard_name} but it's already finished") - self.parent_ref.shard_failed.remote(self.shard_name, ser_exc_info(error)) - raise error - - if self._expected_num_chunks is not None and idx >= self._expected_num_chunks: + if (idx < self.metadata_writer.num_chunks) or ( + self._expected_num_chunks is not None and idx >= self._expected_num_chunks + ): logger.error(f"Received chunk {idx} for {self.shard_name} but it's already finished") error = RuntimeError(f"Received chunk {idx} for {self.shard_name} but it's already finished") self.parent_ref.shard_failed.remote(self.shard_name, ser_exc_info(error)) @@ -961,6 +1036,8 @@ def _attempt_to_commit_chunks(self): chunks_committed = [] while len(self.uncommited_chunks) > 0 and self.uncommited_chunks[0][0] == self.metadata_writer.num_chunks: _, chunk = heapq.heappop(self.uncommited_chunks) + chunk_number = self.metadata_writer.num_chunks + logger.debug(f"Committing chunk {chunk.name} of shard {self.shard_name}. It is chunk {chunk_number}") self.metadata_writer.commit_chunk(chunk) chunks_committed.append(chunk) @@ -997,6 +1074,7 @@ def __init__(self, cache_dir: str, shard_name: str): self.chunk_writers: dict[int, _ChunkWriter] = {} # chunk index -> writer self.batch_counts: dict[int, int] = {} # chunk index -> number of batches written self.expected_totals: dict[int, int] = {} # chunk index -> expected num batches. + self.failed_chunks: dict[int, ExceptionInfo] = {} # chunk index -> error self.chunk_partial_batches: dict[ int, list[tuple[int, pa.RecordBatch]] ] = {} # chunk index -> heapq of (batch index, batch) @@ -1015,21 +1093,29 @@ def chunk_finished_reading(self, chunk_id, expected_num_batches) -> Optional[Chu return self._attempt_to_write_chunk_fragments(chunk_id) def chunk_failed(self, chunk_id, error: ExceptionInfo): + self.failed_chunks[chunk_id] = error if chunk_id in self.chunk_writers: self.chunk_writers[chunk_id].__exit__(*error.restore()) del self.chunk_writers[chunk_id] def _attempt_to_write_chunk_fragments(self, chunk_id) -> Optional[ChunkMetadata]: + if chunk_id in self.failed_chunks: + logger.error(f"Chunk {chunk_id} of shard {self.shard_name} already failed, not writing more") + raise self.failed_chunks[chunk_id].restore() if chunk_id in self.chunk_partial_batches: chunk_batches = self.chunk_partial_batches[chunk_id] - while len(chunk_batches) > 0 and chunk_batches[0][0] == self.batch_counts[chunk_id]: + while len(chunk_batches) > 0: + batch_id, batch = chunk_batches[0] + if batch_id != self.batch_counts[chunk_id]: + break + # we can write this batch batch_id, batch = heapq.heappop(chunk_batches) if chunk_id not in self.chunk_writers: - assert batch_id == 0 + assert batch_id == 0, f"Expected batch 0 but got {batch_id}" chunk_name = os.path.join(self.shard_name, f"chunk-{chunk_id}") writer = _ChunkWriter(self.cache_dir, chunk_name, batch.schema) writer.__enter__() @@ -1074,7 +1160,7 @@ def __init__( processor: BatchProcessor[T], rows_per_chunk: int, ): - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self.logger = pylogging.getLogger(f"{__name__}.{name}") self.broker_ref = broker_ref self.shard_status: Dict[str, _ShardStatus] = dict() @@ -1095,20 +1181,25 @@ def __init__( self._processor_actors = [] for shard_name in source.shard_names: + self._current_round_robin.append(shard_name) self.shard_status[shard_name] = _ShardStatus() num_shards = len(source.shard_names) + num_worker_groups = len(ray.nodes()) + num_shard_groups = max(min(num_worker_groups, num_shards), 1) - def priority_fn(shard_idx, chunk_idx): - return chunk_idx * num_shards + shard_idx - - num_shard_groups = max(min(len(ray.nodes()), num_shards), 1) + # if we have a bunch of caches to build with one shard, we don't want them all + # assigned to the same node, so we use an offset based on the hash of the name (for stability) + # in an attempt to spread them out + group_offset = int(hash(name) % num_worker_groups) shard_groups: list[list[str]] = [[] for _ in range(num_shard_groups)] for i, shard_name in enumerate(source.shard_names): - self._current_round_robin.append(shard_name) shard_groups[i % num_shard_groups].append(shard_name) + def priority_fn(shard_idx, chunk_idx): + return chunk_idx * num_shards + shard_idx + for group_id, shard_group in enumerate(shard_groups): writer = _GroupShardWriterWorker.remote(self_ref, cache_dir, shard_group) # type: ignore self._shard_writers.append(writer) @@ -1127,10 +1218,12 @@ def priority_fn(shard_idx, chunk_idx): processor_actor=processor_actor, batch_size=processor.batch_size, num_rows_per_chunk=rows_per_chunk, + group_id=group_id, ) # we want global names so that different tasks can coordinate priorities - priority_actor_name = f"priority_processor.{group_id}" + worker_to_assign = (group_id + group_offset) % num_worker_groups + priority_actor_name = f"priority_processor.{worker_to_assign}" reader_actor = PriorityProcessorActor.options( # type: ignore name=priority_actor_name, get_if_exists=True @@ -1138,17 +1231,6 @@ def priority_fn(shard_idx, chunk_idx): ray.get(reader_actor.add_work_group.remote(work_item)) - # reader = _alternating_shard_reader.remote( - # name, - # self_ref, - # writer, - # source, - # shard_group, - # priority_fn, - # processor_actor, - # processor.batch_size, - # rows_per_chunk, - # ) self._shard_readers.append(reader_actor) def new_chunk(self, shard_name: str, *chunks: ChunkMetadata): @@ -1228,14 +1310,23 @@ def _attempt_to_flush_buffers(self): next_chunk = status.pop_chunk_to_send() if next_chunk is not None: # we can send a chunk from this shard - logger.debug(f"Sending chunk from {name}") self._current_round_robin.pop(0) self._current_round_robin.append(name) chunks_to_send.append(next_chunk) continue else: - logger.debug(f"Shard {name} has no chunks to send and is not known to be finished") # we can't send a chunk from this shard, so we can't send any additional chunks + if self.logger.level <= pylogging.DEBUG: + chunks_waiting = [ + f"{n2} ({len(s2.current_buffer)})" + for n2, s2 in self.shard_status.items() + if len(s2.current_buffer) > 0 + ] + msg = ( + f"Shard {name} has no chunks to send and is not known to be finished. We have this many queued" + f" chunks: {chunks_waiting}" + ) + self.logger.debug(msg) break if len(chunks_to_send) > 0: @@ -1257,7 +1348,7 @@ class ChunkCacheBroker: _finished_promise: asyncio.Future[None] def __init__(self, cache_dir: str, source: ShardedDataset[T], processor: BatchProcessor[T], rows_per_chunk: int): - pylogging.basicConfig(level=pylogging.INFO) + pylogging.basicConfig(level=pylogging.INFO, format=LOG_FORMAT) self.chunks = [] self._reader_promises = {} self._is_finished = False @@ -1269,6 +1360,9 @@ def __init__(self, cache_dir: str, source: ShardedDataset[T], processor: BatchPr # used to subscribe to metrics updates self._latest_metrics = InProgressCacheMetrics() self._metrics_condition = asyncio.Condition() + path_for_name = os.path.join(*self._cache_dir.split("/")[-2:]) + name = f"broker::{path_for_name}" + self.logger = pylogging.getLogger(f"{name}") # initialize writer task # first see if we need to do anything: check the ledger for is_finished @@ -1280,7 +1374,6 @@ def __init__(self, cache_dir: str, source: ShardedDataset[T], processor: BatchPr except FileNotFoundError: self_ref = ray.runtime_context.get_runtime_context().current_actor # only use the last two components of the name since it gets kind of long - path_for_name = os.path.join(*self._cache_dir.split("/")[-2:]) name = f"builder::{path_for_name}" self._builder_actor = ChunkCacheBuilder.remote(self_ref, self._cache_dir, name, self._source, self._processor, self._rows_per_chunk) # type: ignore @@ -1307,7 +1400,6 @@ async def get_chunk(self, chunk_idx: int) -> Optional[ChunkMetadata]: elif self._is_finished: return None else: - # we don't have this chunk yet, so we need to wait if chunk_idx not in self._reader_promises: self._reader_promises[chunk_idx] = asyncio.Future() return await self._reader_promises[chunk_idx] @@ -1322,7 +1414,9 @@ def _append_chunks(self, *chunks: ChunkMetadata): for chunk in chunks: self.chunks.append(chunk) chunk_idx = len(self.chunks) - 1 + self.logger.debug(f"Received chunk {chunk_idx}") if chunk_idx in self._reader_promises: + self.logger.debug(f"Resolving promise for chunk {chunk_idx}") self._reader_promises[chunk_idx].set_result(chunk) del self._reader_promises[chunk_idx] @@ -1359,7 +1453,9 @@ def _finalize(self): _serialize_json_and_commit(os.path.join(self._cache_dir, LEDGER_FILE_NAME), CacheLedger(self.chunks)) self._reader_promises = {} - self._builder_actor = None + # TODO: For some reason this crashes other actors with weird reference counting assertion errors. + # pretty sure it's a ray bug + # self._builder_actor = None self._finished_promise.set_result(None) # notify metrics subscribers @@ -1528,17 +1624,19 @@ def _get_chunk_unmapped(self, mapped_index: int, *, timeout: Optional[float] = N else: assert self._broker is not None time_in = time.time() + next_time = time_in # we want to also log if we're waiting for a long time, so we do this in a loop - while timeout is None or time.time() - time_in < timeout: + while timeout is None or next_time - time_in < timeout: current_timeout = 20.0 if timeout is not None: - current_timeout = min(current_timeout, timeout - (time.time() - time_in)) + current_timeout = min(current_timeout, timeout - (next_time - time_in)) try: chunk = ray.get(self._broker.get_chunk.remote(mapped_index), timeout=current_timeout) except GetTimeoutError: - self.logger.warning(f"Waiting for chunk {mapped_index} for {int(time.time() - time_in)} seconds") + self.logger.warning(f"Waiting for chunk {mapped_index} for {int(next_time - time_in)} seconds") + next_time = time.time() current_timeout *= 2 - current_timeout = min(current_timeout, 80) + current_timeout = min(current_timeout, 100) continue if chunk is None: diff --git a/src/levanter/data/sharded_dataset.py b/src/levanter/data/sharded_dataset.py index 1ceae6366..0ec178e08 100644 --- a/src/levanter/data/sharded_dataset.py +++ b/src/levanter/data/sharded_dataset.py @@ -171,7 +171,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: dataset = self._load_dataset() if isinstance(dataset, datasets.IterableDataset) and shard_name != "data": # ex_iterable has a key that gets discarded typically - shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources([int(shard_name)])) + shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources(int(shard_name), dataset.n_shards)) else: shard = dataset diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 212318433..00e17eb58 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -11,12 +11,15 @@ import braceexpand import datasets +import equinox as eqx import fsspec import jax import numpy as np import pyarrow as pa +import regex from draccus import field from jaxtyping import PRNGKeyArray +from tokenizers import normalizers import haliax as hax from haliax import Axis @@ -91,29 +94,32 @@ def shard(self, shard_id: int, num_shards: int) -> "CausalLmDataset": def __iter__(self) -> Iterator[LmExample]: key = self.key - for tokens in self.dataset: - with use_cpu_device(): - example = self._create_lm_example(tokens, key) - yield example + sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0]) + + with use_cpu_device(): - @functools.partial(jax.jit, static_argnums=(0)) - def _create_lm_example(self, tokens, key): - tokens = hax.named(tokens, self.QPos) + @functools.partial(eqx.filter_jit, out_shardings=sharding) + def _create_lm_example(tokens, key): + tokens = hax.named(tokens, self.QPos) - example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) + example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) - if self.fcm_prob > 0: - # masks for attention - # We support forgetful causal masking (FCM) which is a technique that improves training speed by - # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention - # mask instead of the activations. It's described in https://arxiv.org/abs/2210.13432 - assert self.key is not None - this_key, key = jax.random.split(key) - fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key) - attn_mask = example.attn_mask & AttentionMask.explicit(fcm_mask) - example = dataclasses.replace(example, attn_mask=attn_mask) + if self.fcm_prob > 0: + # masks for attention + # We support forgetful causal masking (FCM) which is a technique that improves training speed by + # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention + # mask instead of the activations. It's described in https://arxiv.org/abs/2210.13432 + assert self.key is not None + this_key, key = jax.random.split(key) + fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key) + attn_mask = example.attn_mask & AttentionMask.explicit(fcm_mask) + example = dataclasses.replace(example, attn_mask=attn_mask) - return example + return example + + for tokens in self.dataset: + example = _create_lm_example(tokens, key) + yield example class TokenSeqDataset(ShardableDataset[np.ndarray]): @@ -306,13 +312,27 @@ def _maybe_force_tokenizer_parallelism(tokenizer: PreTrainedTokenizerBase): os.environ["TOKENIZERS_PARALLELISM"] = "true" +LONG_STRING_WORKAROUND = 100_000 + + +ws = regex.compile(r"\s") + + class BatchTokenizer(BatchProcessor[str]): """ A batch processor that tokenizes a batch of strings using a tokenizer. By default, this will append eos to the end of the string, even if the tokenizer doesn't. """ - def __init__(self, tokenizer: PreTrainedTokenizerBase, enforce_eos=True, override_resources=None): + def __init__( + self, + tokenizer: PreTrainedTokenizerBase, + enforce_eos=True, + *, + batch_size=128, + override_resources=None, + _workaround_len=LONG_STRING_WORKAROUND, + ): _maybe_force_tokenizer_parallelism(tokenizer) self.tokenizer = tokenizer self.override_resources = override_resources @@ -326,17 +346,104 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, enforce_eos=True, overrid else: should_append_eos = False + self._batch_size = batch_size + self._need_to_add_eos = should_append_eos + self._workaround_len = _workaround_len def __call__(self, batch: Sequence[str]) -> BatchEncoding: + orig_lengths = [len(d) for d in batch] if self._need_to_add_eos: - encoding = self.tokenizer( - [d + " " + self.tokenizer.eos_token for d in batch], return_attention_mask=False, verbose=False - ) + batch = [d + " " + self.tokenizer.eos_token for d in batch] + + if self._needs_long_sequence_workaround: + # break any strings that are longer than 50K characters into smaller chunks + orig_batch = batch + batch = [] + needs_merge = [] + for i, d in enumerate(orig_batch): + needs_merge.append(False) + orig_len = orig_lengths[i] + while len(d) > self._workaround_len: + # we'd rather break strings at whitespace, so find the first whitespace + match = ws.search(d, self._workaround_len) + # this is vanishingly unlikely, but if we can't find a whitespace, just break it at the limit + if match is None: + split = len(d) + else: + split = match.start() + + batch.append(d[:split]) + needs_merge.append(True) + + d = d[split:] + orig_len -= split + + batch.append(d) else: - encoding = self.tokenizer(batch, return_attention_mask=False, verbose=False) # type: ignore + needs_merge = [] + + encoding = self.tokenizer(batch, return_attention_mask=False, verbose=False) # type: ignore + + if needs_merge: + new_encoding = self._merge_split_encodings(batch, encoding, needs_merge) + encoding = BatchEncoding(new_encoding) + return encoding + @staticmethod + def _merge_split_encodings(batch, encoding, needs_merge): + # merge the encodings back together + # we might need to merge multiple encodings together + # needs merge marks the first n-1 encodings that need to be merged for each document + new_encoding = {} + for k, v in encoding.items(): + if len(v) == 0: + continue + if isinstance(v[0], np.ndarray): + assert len(v) == len(batch) + v_out = [] + vs_to_merge = [] + for i in range(len(batch)): + if not needs_merge[i]: + v_out.append(np.concatenate(vs_to_merge)) + vs_to_merge = [] + vs_to_merge.append(v[i]) + + if len(vs_to_merge) > 0: + v_out.append(np.concatenate(vs_to_merge)) + + new_encoding[k] = v_out + elif isinstance(v[0], list): + v_out = [] + vs_to_merge = [] + for i in range(len(batch)): + if not needs_merge[i]: + if len(vs_to_merge) > 0: + v_out.append(list(chain(*vs_to_merge))) + vs_to_merge = [] + vs_to_merge.append(v[i]) + + if len(vs_to_merge) > 0: + v_out.append(list(chain(*vs_to_merge))) + new_encoding[k] = v_out + else: + raise ValueError(f"Unknown type {type(v[0])}") + return new_encoding + + # TODO remove this when it's resolved https://github.com/huggingface/tokenizers/issues/1449 + @cached_property + def _needs_long_sequence_workaround(self): + if isinstance(self.tokenizer, PreTrainedTokenizerFast): + normalizer = self.tokenizer.backend_tokenizer.normalizer + if normalizer is None: + return False + # if there's a "Replace" normalizer, then we need to do the workaround + # inexplicably there's no way to see inside a Sequence so we also have to assume it needs it + return isinstance(normalizer, (normalizers.Replace, normalizers.Sequence)) + else: + return False + @property def num_cpus(self) -> int: if self.override_resources is not None: @@ -353,7 +460,7 @@ def num_gpus(self) -> int: @property def batch_size(self) -> int: - return 1024 + return self._batch_size def concatenate_and_group_texts( diff --git a/src/levanter/distributed.py b/src/levanter/distributed.py index f4a86f8b9..eefb71fc4 100644 --- a/src/levanter/distributed.py +++ b/src/levanter/distributed.py @@ -224,33 +224,43 @@ def _munge_address_port(address: str): # this is no longer the case, so instead we need to check if we are the coordinator # and if so, start the head - if _is_this_machine(host): - logger.info(f"Starting ray head on port {ray_port}. We are process the coordinator {host}.") - logger.info(f"Starting ray with num_cpus set to {num_cpus}.") - ret = os.system( - f"ray start --head --port {ray_port} --num-cpus {num_cpus} --dashboard-host=0.0.0.0" - ) - if ret != 0: - raise RuntimeError(f"Failed to start ray head with exit code {ret}") - else: - logger.info(f"Successfully started ray head on port {ray_port}.") - - # install an atexit handler to kill the head when we exit - atexit.register(lambda: os.system("ray stop -g 10 --force")) - elif start_workers: - logger.info( - f"Starting ray worker and connecting to {address}. We are process {jax.process_index()}." - ) - logger.info(f"Starting ray with num_cpus set to {num_cpus}.") - ret = os.system(f"ray start --address {address} --num-cpus {num_cpus}") - if ret != 0: - raise RuntimeError(f"Failed to start ray head with exit code {ret}") - else: - logger.info(f"Successfully started ray worker and connected to {address}.") + if _is_local_leader(): + if _is_this_machine(host): + logger.info(f"Starting ray head on port {ray_port}. We are process the coordinator {host}.") + logger.info(f"Starting ray head with num_cpus set to {num_cpus}.") + ret = os.system( + f"ray start --head --port {ray_port} --num-cpus {num_cpus} --dashboard-host=0.0.0.0" + ) + if ret != 0: + raise RuntimeError(f"Failed to start ray head with exit code {ret}") + else: + logger.info(f"Successfully started ray head on port {ray_port}.") + + # install an atexit handler to kill the head when we exit + atexit.register(lambda: os.system("ray stop -g 10 --force")) + elif start_workers: + logger.info( + f"Starting ray worker and connecting to {address}. We are process {jax.process_index()}." + ) + logger.info(f"Starting ray worker with num_cpus set to {num_cpus}.") + ret = os.system(f"ray start --address {address} --num-cpus {num_cpus}") + if ret != 0: + raise RuntimeError(f"Failed to start ray head with exit code {ret}") + else: + logger.info(f"Successfully started ray worker and connected to {address}.") logger.info(f"ray.init(address={repr(address)}, namespace={repr(namespace)}, **{repr(kwargs)})") - # Ray has retry logic, so we don't need to retry here :fingers-crossed: - ray.init(address=address, namespace=namespace, **kwargs) + # Ray has retry logic, but it doesn't seem to work super well, so we retry manually + for i in range(0, 5): + try: + ray.init(address=address, namespace=namespace, **kwargs) + break + except Exception as e: + if i == 4: + raise e + else: + logger.warning(f"Failed to initialize ray with address {address}. Retrying...") + continue atexit.register(lambda: ray.shutdown()) _already_initialized = True @@ -318,6 +328,9 @@ def _is_this_machine(host): """ Checks if the given host identifies this machine. """ + if host == "localhost" or host == "0.0.0.0": + return True + try: # Get IP addresses of all interfaces machine_ips = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)] @@ -330,3 +343,45 @@ def _is_this_machine(host): # Check if the host IP matches any of the machine IPs return any(host_ip == machine_ip for machine_ip in machine_ips) + + +def _remove_if_possible(path): + try: + os.remove(path) + except OSError: + pass + + +def _touch(file_path): + with open(file_path, "a"): + os.utime(file_path, None) + + +def _is_local_leader(): + import atexit + + import filelock + from jax.experimental.multihost_utils import broadcast_one_to_all + + if jax.process_count() == 1: + return True + + import random + + random_id = random.randint(0, 1000000) + random_id = broadcast_one_to_all(random_id) + + lock = filelock.FileLock(f"/tmp/levanter_local_process_zero_lock.{random_id}") + action_performed_file = f"/tmp/levanter_local_process_zero_action_performed.{random_id}" + + try: + with lock.acquire(timeout=0.1): + if not os.path.exists(action_performed_file): + _touch(action_performed_file) + return True # Action needs to be performed + else: + return False # Action already performed + atexit.register(_remove_if_possible, lock.lock_file) + atexit.register(_remove_if_possible, action_performed_file) + except filelock.Timeout: + return False diff --git a/src/levanter/grad_accum.py b/src/levanter/grad_accum.py index 39258665a..db6f9508e 100644 --- a/src/levanter/grad_accum.py +++ b/src/levanter/grad_accum.py @@ -24,11 +24,13 @@ class ReductionType(enum.Enum): # TODO: add MAX? +# TODO: should we use a custom_jvp on microbatched? + # cf https://github.com/google-research/t5x/blob/main/t5x/trainer.py#L617 def microbatched( fn: Callable[Args, R], Batch: Axis, - per_device_parallelism: int, + microbatch_size: int, accum_axis_mapping, compute_axis_mapping, patch_in_rng_key: Optional[str] = "key", @@ -39,6 +41,13 @@ def microbatched( Wraps a function that takes a batch and changes it to instead take microbatches and accumulate the results This function has to reduce the batch axis, so it can't be used for functions that need to keep the batch axis. + Can be used as a decorator with functools.partial, e.g.: + + >>> @functools.partial(microbatched, Batch=Batch, per_device_parallelism=4) + >>> def my_fn(x): + >>> return hax.mean(x + 1) + + Args: fn: a function to wrap Batch: the batch axis @@ -61,11 +70,14 @@ def microbatched( physical_axis_name = hax.partitioning.physical_axis_name(Batch, compute_axis_mapping) assert physical_axis_name is not None - if per_device_parallelism < 0: - raise ValueError(f"Bad value for {per_device_parallelism=}") + if microbatch_size <= 0: + raise ValueError(f"Bad value for {microbatch_size=}") - microbatch_size = data_axis_size * per_device_parallelism num_micro_steps = batch_size // microbatch_size + + if num_micro_steps == 1: + return fn + Microbatch = Batch.resize(microbatch_size) AccumStep = Axis("accum_step", num_micro_steps) assert num_micro_steps * microbatch_size == batch_size @@ -122,7 +134,7 @@ def _reshape(x): if not x.has_axis(Batch.name): return x x = x.unflatten_axis(Batch, (AccumStep, Microbatch)) - return hax.shard_with_axis_mapping(x, axis_mapping) + return hax.shard(x, axis_mapping) elif isinstance(x, jnp.ndarray): x = x.reshape((AccumStep.size, Microbatch.size) + x.shape[1:]) return with_sharding_constraint(x, PartitionSpec(None, ResourceAxis.DATA, *(None,) * (len(x.shape) - 2))) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 3c82e2fe2..b1b2bd3e8 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -49,8 +49,6 @@ class TrainLmConfig: def main(config: TrainLmConfig): - levanter.initialize(config) - tokenizer = config.data.the_tokenizer # this is some unpleasant code to allow us to initialize from a hf checkpoint. If this is your first read through, @@ -79,6 +77,7 @@ def main(config: TrainLmConfig): else: converter = None + levanter.initialize(config) optimizer = config.optimizer.build(config.trainer.num_train_steps) # Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp @@ -134,11 +133,7 @@ def main(config: TrainLmConfig): else: logger.info("No checkpoint found. Starting from scratch.") - levanter.tracker.log_summary( - { - "parameter_count": parameter_count(state.model), - } - ) + levanter.tracker.log_summary({"parameter_count": parameter_count(state.model)}) if len(eval_datasets) == 0: logger.warning("No evaluation datasets provided.") @@ -147,7 +142,6 @@ def main(config: TrainLmConfig): eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id) trainer.add_eval_hook(eval_dataset, name=name) - # Register hooks trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1) if config.hf_save_path is not None: full_save_path = os.path.join(config.hf_save_path, trainer.run_id) diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index bc89620f3..ef16a7238 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -50,6 +50,9 @@ def main(config: VizGpt2Config): EvalBatch, ) + # some axes we use outside the model proper + Pos = config.model.Pos + compute_axis_mapping = config.trainer.compute_axis_mapping parameter_axis_mapping = config.trainer.parameter_axis_mapping diff --git a/src/levanter/models/flash_attention.py b/src/levanter/models/flash_attention.py index b1adea0b3..c1be0091e 100644 --- a/src/levanter/models/flash_attention.py +++ b/src/levanter/models/flash_attention.py @@ -304,8 +304,16 @@ def do_inner_block(state): dAttn_ij = p_ij * (dP_ij - D_i) dAttn_ij = dAttn_ij.astype(dQ_i.dtype) - dV_j = dV_j + hax.dot(QPos.name, p_ij, dO_i).astype(dV_j.dtype) - dK_j = dK_j + hax.dot(QPos.name, dAttn_ij, q_i).astype(dK_j.dtype) + dV_ji = hax.dot(QPos.name, p_ij, dO_i).astype(dV_j.dtype) + dK_ji = hax.dot(QPos.name, dAttn_ij, q_i).astype(dK_j.dtype) + + # GQA-specific: eliminate unnecessary axes (e.g. 'q_heads_per_group') + unnecessary_axes = hax.eliminate_axes(dV_ji.axes, v.axes) + dV_ji = hax.sum(dV_ji, unnecessary_axes) + dK_ji = hax.sum(dK_ji, unnecessary_axes) + + dV_j = dV_j + dV_ji + dK_j = dK_j + dK_ji dQ_i = dQ_i + hax.dot(KPos.name, dAttn_ij, k_j).astype(dQ.dtype) # dQ[i*block_size:(i+1)*block_size] = dQi diff --git a/src/levanter/models/gpt2.py b/src/levanter/models/gpt2.py index 5822fab30..a6f27c7a5 100644 --- a/src/levanter/models/gpt2.py +++ b/src/levanter/models/gpt2.py @@ -312,40 +312,43 @@ class Gpt2Embeddings(StateDictSerializationMixin, eqx.Module): Vocab: Axis = eqx.static_field() config: Gpt2Config = eqx.static_field() - token_embeddings: NamedArray - position_embeddings: NamedArray + token_embeddings: hnn.Embedding + position_embeddings: hnn.Embedding dropout: hnn.Dropout @staticmethod def init(Vocab: Axis, config: Gpt2Config, *, key) -> "Gpt2Embeddings": k_wte, k_wpe, k_out = jrandom.split(key, 3) - token_embeddings = hax.random.normal(k_wte, (Vocab, config.Embed)) * config.initializer_range - position_embeddings = hax.random.normal(k_wpe, (config.Pos, config.Embed)) * (config.initializer_range / 2) + token_embeddings = hnn.Embedding.init( + Vocab, config.Embed, key=k_wte, initializer_range=config.initializer_range + ) + position_embeddings = hnn.Embedding.init( + config.Pos, config.Embed, key=k_wpe, initializer_range=config.initializer_range / 2 + ) dropout = hnn.Dropout(pdrop=config.embed_pdrop) return Gpt2Embeddings(Vocab, config, token_embeddings, position_embeddings, dropout) @named_call def embed(self, input_ids, *, key): - input_embeds = self.token_embeddings.take("vocab", input_ids) - position_embeds = self.position_embeddings - - input_len = input_ids.resolve_axis("position").size - x = input_embeds + position_embeds["position", hax.dslice(0, input_len)] + input_embeds = self.token_embeddings(input_ids) + input_Pos = input_ids.resolve_axis("position") + position_embeds = self.position_embeddings.embed(hax.arange(input_Pos)) + x = input_embeds + position_embeds x = self.dropout(x, key=key) return x def unembed(self, x: NamedArray): - return hax.dot("embed", x, self.token_embeddings) + return hax.dot("embed", x, self.token_embeddings.weight) def _state_dict_key_map(self) -> Dict[str, Optional[str]]: - return {"token_embeddings": "wte.weight", "position_embeddings": "wpe.weight"} + return {"token_embeddings": "wte", "position_embeddings": "wpe"} def resize_embeddings(self, new_size: int, key: Optional[PRNGKeyArray] = None): - new_weights = hax.tree_util.resize_axis(self.token_embeddings, self.Vocab, new_size, key=key) - return dataclasses.replace(self, Vocab=self.Vocab.resize(new_size), token_embeddings=new_weights) + new_token_embeddings = self.token_embeddings.resize_embeddings(new_size, key=key) + return dataclasses.replace(self, Vocab=self.Vocab.resize(new_size), token_embeddings=new_token_embeddings) class Gpt2LMHeadModel(eqx.Module, LmWithHfSerializationMixin[Gpt2Config]): diff --git a/src/levanter/models/mistral.py b/src/levanter/models/mistral.py new file mode 100644 index 000000000..404c956a8 --- /dev/null +++ b/src/levanter/models/mistral.py @@ -0,0 +1,218 @@ +import dataclasses +from dataclasses import dataclass +from typing import Dict, Optional, Type, Union + +import equinox as eqx +import jax.random as jrandom + +import haliax as hax +import haliax.nn as hnn +from haliax import Axis, NamedArray +from haliax.jax_utils import maybe_rng_split + +from levanter.compat.hf_checkpoints import HFCheckpointConverter +from levanter.compat.torch_serialization import ( + StateDict, + StateDictSerializationMixin, + apply_prefix, + flatten_linear_layers, + unflatten_linear_layers, +) +from levanter.logging import silence_transformer_nag +from levanter.models.attention import AttentionMask +from levanter.models.llama import LlamaConfig, LlamaEmbedding, LlamaTransformer +from levanter.models.lm_model import LmConfig, LmHeadModel +from levanter.utils.py_utils import cached_classproperty + + +silence_transformer_nag() +from transformers import MistralConfig as HfMistralConfig # noqa: E402 +from transformers import PretrainedConfig as HfConfig # noqa: E402 + + +@LmConfig.register_subclass("mistral") +@dataclass(frozen=True) +class MistralConfig(LlamaConfig): + """Config for MistralModel + + Args: + seq_len (int, optional): maximum length of the input sequence. Defaults to 8192. + hidden_dim (int, optional): dimension of the hidden state. Defaults to 4096. + intermediate_dim (int, optional): dimension of the intermediate state. Defaults to 14336. + num_layers (int, optional): number of hidden layers in the Transformer encoder. Defaults to 32. + num_heads (int, optional): number of attention heads for each attention layer. Defaults to 32. + num_kv_heads (int, optional): number of attention heads for keys and values in each attention layer. + Setting to 1 means MQA. Setting to num_heads means MHA. Otherwise GQA. + Note that num_heads must be divisible by this number. Defaults to 8. + activation_function (str, optional): activation function for the hidden layer. Defaults to "silu". + sliding_window (int, optional): window size of sliding window attention. Defaults to 4096. + """ + + seq_len: int = 8192 + hidden_dim: int = 4096 + intermediate_dim: int = 14336 + num_layers: int = 32 + num_heads: int = 32 + num_kv_heads: int = 8 + activation_function: str = "silu" + initializer_range: float = 0.02 + layer_norm_epsilon: float = 1e-6 + sliding_window: int = 4096 + + # Attention-related config + upcast_attn: bool = False + use_flash_attention: bool = False + flash_attention_block_size: Optional[int] = None + + gradient_checkpointing: bool = True + gradient_checkpointing_block_size: int = 5 + + use_bias: bool = False + rope_scaling: Optional[dict] = None + + # Axis + Pos = property(lambda self: Axis(name="position", size=self.seq_len)) + KeyPos = property(lambda self: self.Pos.alias("key_position")) + Embed = property(lambda self: Axis(name="embed", size=self.hidden_dim)) + Heads = property(lambda self: Axis(name="heads", size=self.num_heads)) + KVHeads = property(lambda self: Axis(name="kv_heads", size=self.num_kv_heads)) + Layers = property(lambda self: Axis(name="layers", size=self.num_layers)) + Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_dim)) + HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads)) + + @cached_classproperty + def default_hf_checkpoint_converter(cls) -> HFCheckpointConverter["MistralConfig"]: # type: ignore + return HFCheckpointConverter( + cls, # type: ignore + "mistralai/Mistral-7B-v0.1", + trust_remote_code=True, + tokenizer="mistralai/Mistral-7B-v0.1", + HfConfigClass=HfMistralConfig, + ) + + @classmethod + def from_hf_config(cls, hf_config: HfConfig): + return MistralConfig( + seq_len=hf_config.max_position_embeddings, # this might be too big... + hidden_dim=hf_config.hidden_size, + intermediate_dim=hf_config.intermediate_size, + num_layers=hf_config.num_hidden_layers, + num_heads=hf_config.num_attention_heads, + num_kv_heads=hf_config.num_key_value_heads, + activation_function=hf_config.hidden_act, + initializer_range=hf_config.initializer_range, + layer_norm_epsilon=hf_config.rms_norm_eps, + sliding_window=hf_config.sliding_window, + ) + + def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfMistralConfig: + """Convert to HuggingFace's MistralConfig + + Args: + vocab_size (int, optional): Vocabulary size of the tokenizer. Defaults to 32000. + config_overrides (dict, optional): Overrides for the config. Defaults to None. + + Returns: + HfMistralConfig: HuggingFace's MistralConfig + """ + if config_overrides is None: + config_overrides = {} + + return HfMistralConfig( + max_position_embeddings=self.seq_len, + hidden_size=self.hidden_dim, + intermediate_size=self.intermediate_dim, + num_hidden_layers=self.num_layers, + num_attention_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + hidden_act=self.activation_function, + initializer_range=self.initializer_range, + rms_norm_eps=self.layer_norm_epsilon, + sliding_window=self.sliding_window, + vocab_size=vocab_size, + **config_overrides, + ) + + @property + def model_type(cls) -> Type["MistralLMHeadModel"]: + return MistralLMHeadModel + + +class MistralLMHeadModel(eqx.Module, LmHeadModel[MistralConfig], StateDictSerializationMixin): + transformer: LlamaTransformer + embeddings: LlamaEmbedding + lm_head: hnn.Linear + + @property + def config(self): + return self.transformer.config + + @property + def vocab_size(self) -> int: + return self.Vocab.size + + @property + def Vocab(self) -> Axis: + return self.embeddings.Vocab + + @classmethod + def init(cls, Vocab: Axis, config: MistralConfig, *, key) -> "MistralLMHeadModel": + k_t, k_emb = jrandom.split(key, 2) + transformer = LlamaTransformer.init(config, key=k_t) + embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb) + lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True) + return MistralLMHeadModel(transformer, embeddings, lm_head) + + def __call__( + self, + input_ids: NamedArray, + attn_mask: Optional[Union[NamedArray, AttentionMask]] = None, + *, + key=None, + ) -> NamedArray: + """ + Args: + input_ids (NamedArray): [batch, position] + Indices of input sequence tokens in the vocabulary. + attn_mask (Union[NamedArray, AttentionMask], optional): [batch, position] + Mask to avoid performing attention on the padding token indices of the encoder input. + The attn_mask from training pipeline may be an AttentionMask object instead of NamedArray + """ + k_t, k_head = maybe_rng_split(key, 2) + x = self.embeddings.embed(input_ids) + x = self.transformer(x, attn_mask=attn_mask, key=k_t) + lm_logits = self.lm_head(x, key=k_head) + return lm_logits + + def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[MistralConfig]": + new_Vocab = self.Vocab.resize(new_size) + k1, k2 = maybe_rng_split(key, 2) + new_embeddings = self.embeddings.resize_embeddings(new_size, key=k1) + new_lm_matrix = hax.tree_util.resize_axis(self.lm_head.weight, self.Vocab, new_size, key=k2) + new_lm_head = dataclasses.replace(self.lm_head, Out=new_Vocab, weight=new_lm_matrix) + + return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head) + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"transformer": "model", "embeddings": None} + + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + # unflatten the linear layers of HF state_dict to match the shape of MistralMlp + d = state_dict.copy() + d.update( + unflatten_linear_layers( + apply_prefix(prefix, "lm_head"), state_dict, self.lm_head, out_dims_first_in_dict=True + ) + ) + return super().from_state_dict(d, prefix) + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + my_dict: StateDict = {} + super().update_state_dict(my_dict, prefix=prefix) + + my_dict.update( + flatten_linear_layers(apply_prefix(prefix, "lm_head"), self.lm_head, out_dims_first_in_dict=True) + ) + + state_dict.update(my_dict) + return state_dict diff --git a/src/levanter/optim/__init__.py b/src/levanter/optim/__init__.py index 319ddf84d..7dec2ebb4 100644 --- a/src/levanter/optim/__init__.py +++ b/src/levanter/optim/__init__.py @@ -1,15 +1,6 @@ from .config import AdamConfig, OptimizerConfig -from .second_order import ( - AnySecondOrderTransformation, - HessianUpdateFn, - SecondOrderTransformation, - chain_second_order, - inject_hyperparams, -) -from .sophia import ( +from .sophia import ( # SophiaGConfig,; SophiaGObjective, ScaleBySophiaState, - SophiaGConfig, - SophiaGObjective, SophiaHConfig, scale_by_sophia_g, scale_by_sophia_h, diff --git a/src/levanter/optim/sophia.py b/src/levanter/optim/sophia.py index 9df275c29..8895942a2 100644 --- a/src/levanter/optim/sophia.py +++ b/src/levanter/optim/sophia.py @@ -1,8 +1,7 @@ import abc import functools -import typing from dataclasses import dataclass -from typing import Any, NamedTuple, Optional, TypeVar, runtime_checkable +from typing import Any, NamedTuple, Optional, TypeVar import equinox as eqx import jax @@ -12,9 +11,8 @@ from jax.random import PRNGKey from jaxtyping import PRNGKeyArray -import levanter.tracker +# import levanter.tracker from levanter.optim.config import HessianOptConfig, OptimizerConfig -from levanter.optim.second_order import SecondOrderTransformation, chain_second_order, inject_hyperparams from levanter.optim.util import hvp, tree_gaussian_like from levanter.utils.jax_utils import parameter_count, tree_filter_like @@ -36,59 +34,92 @@ class ScaleBySophiaState(NamedTuple): hess_key: PRNGKey -@runtime_checkable -class SophiaGObjective(typing.Protocol): - """ - Class for objective functions that can be used with Sophia-G - - Sophia-G is a second order optimizer that uses the Gauss-Newton-Bartlett approximation to the Hessian - to compute the second order update. This requires the objective function be of the form loss(logits(x)) - where logits(x) is the activation of the model for the given example x. This is the case for most models - that are trained with "typical" losses. - """ - - def logits(self, parameters: M, example: Ex, *args, **kwargs) -> Any: - """ - Returns the logits/activations of the model for the given example, - or just sufficient statistics for the example for non-categorical models. - """ - ... - - def sample(self, logits, example: Ex, *, key: PRNGKey) -> Ex: - """ - Samples a new example with the same shape as the original example, but with - the "labels" replaced with some sampled values - """ - ... - - def loss(self, logits, example: Ex): - """ - Just computes the loss, e.g. cross entropy. - - Should return the mean loss over the batch, not the sum. - - TODO: should we reconsider this? - """ - ... - - def __call__(self, parameters: M, example: Ex, *args, **kwargs): - """ - Just a convenience method for invoking the objective for "normal" training w/o sophia-g - """ - logits = self.logits(parameters, example, *args, **kwargs) - return self.loss(logits, example) - - def num_data_points(self, example: Ex) -> int: - """ - Returns the number of data points in the example. This should take into account the loss mask - or any other masking that might be applied to the example. - - By default, we just return 1, and you can just pull the term into the hyperparams of Sophia if you want. - - Returns: - The number of data points in the example - """ - return 1 +# @runtime_checkable +# class SophiaGObjective(typing.Protocol): +# """ +# Class for objective functions that can be used with Sophia-G +# +# Sophia-G is a second order optimizer that uses the Gauss-Newton-Bartlett approximation to the Hessian +# to compute the second order update. This requires the objective function be of the form loss(logits(x)) +# where logits(x) is the activation of the model for the given example x. This is the case for most models +# that are trained with "typical" losses. +# """ +# +# def logits(self, parameters: M, *args, **kwargs) -> Any: +# """ +# Returns the logits/activations of the model for the given example, +# or just sufficient statistics for the example for non-categorical models. +# """ +# ... +# +# def sample(self, logits, *example, key: PRNGKey, **kwargs) -> Ex: +# """ +# Samples a new example with the same shape as the original example, but with +# the "labels" replaced with some sampled values +# """ +# ... +# +# def loss(self, logits, *example: Ex, **kwargs) -> jnp.ndarray: +# """ +# Just computes the loss, e.g. cross entropy. +# +# Should return the mean loss over the batch, not the sum. +# +# TODO: should we reconsider this? +# """ +# ... +# +# def __call__(self, parameters: M, *args, **kwargs) -> jnp.ndarray: +# """ +# Just a convenience method for invoking the objective for "normal" training w/o sophia-g +# """ +# logits = self.logits(parameters, *args, **kwargs) +# return self.loss(logits, *args, **kwargs) +# +# def num_data_points(self, example: Ex) -> int: +# """ +# Returns the number of data points in the example. This should take into account the loss mask +# or any other masking that might be applied to the example. +# +# By default, we just return 1, and you can just pull the term into the hyperparams of Sophia if you want. +# +# Returns: +# The number of data points in the example +# """ +# return 1 +# +# +# def apply_partial(self, *args, **kwargs) -> "SophiaGObjective": +# """ +# Returns a new objective that is a partial application of the current objective, used for +# passing in the data points. +# """ +# +# +# +# class PartialSophiaG(SophiaGObjective): +# def __init__(self, objective: SophiaGObjective, *args, **kwargs): +# self.objective = objective +# self.args = args +# self.kwargs = kwargs +# +# def logits(self, parameters: M, *args, **kwargs) -> Any: +# return self.objective.logits(parameters, *self.args, *args, **self.kwargs, **kwargs) +# +# def sample(self, logits, *example, key: PRNGKey, **kwargs) -> Ex: +# return self.objective.sample(logits, *self.args, *example, key=key, **self.kwargs, **kwargs) +# +# def loss(self, logits, *example: Ex, **kwargs) -> jnp.ndarray: +# return self.objective.loss(logits, *self.args, *example, **self.kwargs, **kwargs) +# +# def __call__(self, parameters: M, *args, **kwargs) -> jnp.ndarray: +# return self.objective(parameters, *self.args, *args, **self.kwargs, **kwargs) +# +# def num_data_points(self, example: Ex) -> int: +# return self.objective.num_data_points(*self.args, example, **self.kwargs) +# +# def apply_partial(self, *args, **kwargs) -> SophiaGObjective: +# return PartialSophiaG(self.objective, *self.args, *args, **self.kwargs, **kwargs) @dataclass @@ -115,7 +146,7 @@ def compute_hessian( raise NotImplementedError def build(self, num_train_steps: int): - def _optimizer(learning_rate, gamma) -> SecondOrderTransformation: + def _optimizer(learning_rate, gamma) -> optax.GradientTransformation: components = [] key = jax.random.PRNGKey(self.rng_seed) @@ -140,7 +171,7 @@ def _optimizer(learning_rate, gamma) -> SecondOrderTransformation: # - learning rate for descent components.append(optax.scale(-learning_rate)) - optimizer = chain_second_order(*components) + optimizer = optax.chain(*components) return optimizer @@ -149,18 +180,19 @@ def _optimizer(learning_rate, gamma) -> SecondOrderTransformation: constant_gamma_schedule = optax.constant_schedule(self.gamma) # type: ignore # gamma_schedule = optax.join_schedules([constant_gamma_schedule, gamma_decay_schedule], [num_train_steps // 2]) - return inject_hyperparams(_optimizer)( + return optax.inject_hyperparams(_optimizer)( learning_rate=self.lr_scheduler(num_train_steps), gamma=constant_gamma_schedule ) -@OptimizerConfig.register_subclass("sophia-g") -@dataclass -class SophiaGConfig(BaseSophiaConfig): - gamma: float = GAMMA_SOPHIA_G - - def compute_hessian(self, fn, model, *batch, hess_key: PRNGKey, **batch_kwargs): - return stochastic_diag_gauss_newton(fn, model, *batch, **batch_kwargs, hess_key=hess_key) +# @OptimizerConfig.register_subclass("sophia-g") +# @dataclass +# class SophiaGConfig(BaseSophiaConfig): +# gamma: float = GAMMA_SOPHIA_G +# +# def compute_hessian(self, fn, model, *batch, hess_key: PRNGKey, **batch_kwargs): +# return stochastic_diag_gauss_newton(fn, model, *batch, **batch_kwargs, hess_key=hess_key) +# @OptimizerConfig.register_subclass("sophia-h") @@ -183,7 +215,7 @@ def sophia_h( clip_threshold: Optional[float] = 1.0, update_interval: int = 10, key: PRNGKey, -) -> SecondOrderTransformation: +) -> optax.GradientTransformation: """Sophia-H: https://arxiv.org/pdf/2305.14342.pdf Algorithm 1&3""" components = [] @@ -194,7 +226,7 @@ def sophia_h( components.append(optax.scale(-lr)) - return chain_second_order(*components) + return optax.chain(*components) def scale_by_sophia_h( @@ -231,7 +263,7 @@ def sophia_g( clip_threshold: Optional[float] = 1.0, update_interval: int = 10, key: PRNGKey, -) -> SecondOrderTransformation: +) -> optax.GradientTransformation: """Sophia-G: https://arxiv.org/pdf/2305.14342.pdf Algorithm 2&3""" components = [] @@ -242,7 +274,7 @@ def sophia_g( components.append(optax.scale(-lr)) - return chain_second_order(*components) + return optax.chain(*components) def scale_by_sophia_g( @@ -278,7 +310,7 @@ def _sophia_gradient_transform( clip_threshold: Optional[float], initial_key: PRNGKeyArray, mu_dtype: Optional[Any] = None, -) -> SecondOrderTransformation: +) -> optax.GradientTransformation: mu_dtype = jax.canonicalize_dtype(mu_dtype) if mu_dtype is not None else None def init_fn(params): @@ -288,7 +320,7 @@ def init_fn(params): count=jnp.zeros([], jnp.int32), hessian_count=jnp.zeros([], jnp.int32), mu=mu, h=h, hess_key=initial_key ) - def update_fn(updates, state, params=None): + def update_fn(updates, state, params=None, *, obj_fn, **kwargs): mu = update_moment(updates, state.mu, b1, 1) # nu = update_moment_per_elem_norm(updates, state.nu, b2, 2) mu_hat = bias_correction(mu, b1, state.count + 1) @@ -317,20 +349,22 @@ def update_fn(updates, state, params=None): stats["optim/unclipped_fraction"] = unclipped_count / parameter_count(updates) # this doesn't work well on CPU, so skip if cpu - if jax.lib.xla_bridge.get_backend().platform != "cpu": - levanter.tracker.jit_log_metrics(stats, step=state.count) + # if jax.lib.xla_bridge.get_backend().platform != "cpu": + # levanter.tracker.jit_log_metrics(stats, step=state.count) if mu_dtype is not None: mu = jax.tree_util.tree_map(lambda t: t.astype(mu_dtype), mu) - return updates, ScaleBySophiaState( + state = ScaleBySophiaState( count=state.count + 1, hessian_count=state.hessian_count, mu=mu, h=h_hat, hess_key=state.hess_key ) + state = update_hessian(state, params, obj_fn=obj_fn, **kwargs) + return updates, state - def update_hessian(state, fn, model, *batch, **batch_kwargs): + def update_hessian(state, params, *, obj_fn, **kwargs): def _do_update(): key, next_key = jax.random.split(state.hess_key) - new_hess = sophia_hess_fn(fn, model, *batch, hess_key=key, **batch_kwargs) + new_hess = sophia_hess_fn(obj_fn, params, hess_key=key, **kwargs) new_hess = tree_filter_like(state.h, new_hess) @@ -350,11 +384,11 @@ def _dont_update(): state.count, ) - return SecondOrderTransformation(init_fn, update_fn, update_hessian) + return optax.GradientTransformationExtraArgs(init_fn, update_fn) # use this for Sophia-G -def stochastic_diag_gauss_newton(fn: SophiaGObjective, model, example, *args, hess_key: PRNGKey, **kwargs): +def stochastic_diag_gauss_newton(fn, model, *args, hess_key: PRNGKey, **kwargs): """ Approximate the diagonal of the Hessian using an approximation to the Gauss Newton matrix. @@ -366,21 +400,22 @@ def stochastic_diag_gauss_newton(fn: SophiaGObjective, model, example, *args, he hess_key: key for sampling *args, **kwargs: passed to fn's logits """ - if not isinstance(fn, SophiaGObjective): - raise ValueError("objective must be a SophiaGObjective") + raise NotImplementedError("This is not implemented yet") + # if not isinstance(fn, SophiaGObjective): + # raise ValueError("objective must be a SophiaGObjective") # Step 3 - logits, model_backward = eqx.filter_vjp(lambda model: fn.logits(model, example, *args, **kwargs), model) + logits, model_backward = eqx.filter_vjp(lambda model: fn.logits(model, *args, **kwargs), model) # Step 4 - y_hat = fn.sample(logits, example, key=hess_key) + y_hat = fn.sample(logits, key=hess_key) # Step 5 grad_loss_logits = eqx.filter_grad(fn.loss)(logits, y_hat) pseudo_g = model_backward(grad_loss_logits)[0] # Step 6 - bs = fn.num_data_points(example) + bs = fn.num_data_points() h = jax.tree_util.tree_map(lambda x: x**2 * bs, pseudo_g) return h diff --git a/src/levanter/tracker/tracker_fns.py b/src/levanter/tracker/tracker_fns.py index 2ed2b1928..e3b6a1f71 100644 --- a/src/levanter/tracker/tracker_fns.py +++ b/src/levanter/tracker/tracker_fns.py @@ -50,7 +50,7 @@ def _no_throw_log_metrics(metrics: dict[str, Any], *, step: Optional[int], commi try: if _global_tracker is None: raise RuntimeError("No global tracker set") - _global_tracker.log(metrics, step=step) + _global_tracker.log(metrics, step=step, commit=False) except Exception: logger.exception("Error logging metrics") diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 0d428e812..14aa98327 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -52,7 +52,6 @@ from levanter.distributed import DistributedConfig, RayConfig from levanter.grad_accum import microbatched from levanter.logging import capture_time -from levanter.optim import SecondOrderTransformation from levanter.tracker import TrackerConfig from levanter.types import ComputeLossFunction, FilterSpec, ModuleComputeLoss from levanter.utils import cloud_utils @@ -193,6 +192,8 @@ def __init__( if add_default_hooks: self._add_default_hooks() + self._cmanagers = [] + @cached_property def loss_fn(self): """ @@ -258,6 +259,7 @@ def EvalBatch(self): return self.config.EvalBatch def __enter__(self): + this_managers = [ levanter.current_tracker(self.tracker), self.device_mesh, @@ -514,9 +516,9 @@ def _take_train_step(self, state: S, model, grads, *batch, **batch_kwargs) -> S: trainable_model = _partition_trainable_params(model, state.is_trainable)[0] updates, opt_state = self.optimizer.update(train_grads, state.opt_state, params=trainable_model) - # Sophia, e.g. - if isinstance(self.optimizer, SecondOrderTransformation): - opt_state = self.optimizer.update_hessian(opt_state, self.loss_fn, model, *batch, **batch_kwargs) + partial_fn = lambda model: self.loss_fn(model, *batch, **batch_kwargs) + + updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model, obj_fn=partial_fn) model = eqx.apply_updates(model, updates) return dataclasses.replace(state, _step=state._step + 1, model=model, opt_state=opt_state) @@ -635,6 +637,10 @@ def TrainBatch(self): def EvalBatch(self): return Axis("batch", self.eval_batch_size) + @property + def microbatch_size(self): + return self.per_device_parallelism * self.data_axis_size + def __post_init__(self): if self.wandb is not None: warnings.warn("wandb is deprecated. use tracker with type wandb instead", DeprecationWarning) diff --git a/src/levanter/utils/py_utils.py b/src/levanter/utils/py_utils.py index a181c8193..afc11c051 100644 --- a/src/levanter/utils/py_utils.py +++ b/src/levanter/utils/py_utils.py @@ -6,7 +6,7 @@ def logical_cpu_core_count(): """Returns the number of logical CPU cores available to the process.""" - num_cpus = os.getenv("SLURM_CPUS_PER_TASK", None) + num_cpus = os.getenv("SLURM_CPUS_ON_NODE", None) if num_cpus is not None: return int(num_cpus) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index e511cb11d..db54b2569 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -163,7 +163,6 @@ def test_checkpoint_simple(): initial_state, step=initial_state.step, checkpoint_path=tmpdir, - exist_ok=True, ) restored_state = load_checkpoint( rep_state, @@ -206,7 +205,7 @@ def loss_fn(model, data): assert_trees_not_close(state, rep_state) with tempfile.TemporaryDirectory() as tmpdir: - save_checkpoint(state, step=3, checkpoint_path=tmpdir, exist_ok=True) + save_checkpoint(state, step=3, checkpoint_path=tmpdir) restored_state = load_checkpoint(rep_state, checkpoint_path=tmpdir, discover_latest=False) assert_trees_all_close( diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 8d1f5aab0..467ac6cba 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -3,6 +3,7 @@ import equinox import jax.numpy as jnp import jax.random as jrandom +import pytest import haliax as hax import haliax.nn as hnn @@ -75,6 +76,40 @@ def d_attn(qkv, fn): assert jnp.allclose(hax_dv.array, fa_dv.array, atol=1e-4, rtol=1e-4) +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_grad_group_query_attention(num_kv_heads): + Batch = hax.Axis("batch", 2) + KVHeads = hax.Axis("kv_heads", num_kv_heads) + QHeadsPerGroup = hax.Axis("q_heads_per_group", 4 // num_kv_heads) + Key = hax.Axis("Key", 8) + QPos = hax.Axis("QPos", BLOCK_SIZE * 2) + KPos = hax.Axis("KPos", BLOCK_SIZE * 2) + + mask = hax.nn.attention.causal_mask(QPos, KPos) + + q = hax.random.normal(jrandom.PRNGKey(0), (Batch, KVHeads, QHeadsPerGroup, QPos, Key)) + k = hax.random.normal(jrandom.PRNGKey(1), (Batch, KVHeads, KPos, Key)) + v = hax.random.normal(jrandom.PRNGKey(2), (Batch, KVHeads, KPos, Key)) + + @equinox.filter_value_and_grad + def d_attn(qkv, fn): + q, k, v = qkv + x_out = fn(KPos, Key, q, k, v, mask=mask) + return (x_out * x_out).sum().scalar() + + hax_val, (hax_dq, hax_dk, hax_dv) = d_attn((q, k, v), hnn.attention.dot_product_attention) + fa_val, (fa_dq, fa_dk, fa_dv) = d_attn((q, k, v), functools.partial(flash_attention, QPos, inference=True)) + + assert jnp.allclose(hax_val, fa_val, atol=1e-4, rtol=1e-4) + assert hax_dq.axes == fa_dq.axes + assert hax_dk.axes == fa_dk.axes + assert hax_dv.axes == fa_dv.axes + + assert jnp.allclose(hax_dq.array, fa_dq.array, atol=1e-4, rtol=1e-4) + assert jnp.allclose(hax_dk.array, fa_dk.array, atol=1e-4, rtol=1e-4) + assert jnp.allclose(hax_dv.array, fa_dv.array, atol=1e-4, rtol=1e-4) + + def test_fa_dropout_does_something(): Key = hax.Axis("Key", 8) QPos = hax.Axis("QPos", BLOCK_SIZE * 2) diff --git a/tests/test_llama.py b/tests/test_llama.py index 7224a3ac1..1ace5c63c 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -1,5 +1,6 @@ import tempfile +import equinox as eqx import jax import numpy as np import pytest @@ -193,7 +194,7 @@ def test_llama_decoder_layer(num_kv_heads): state = llama_decoder_layer.to_state_dict() state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()} - hf_decoder_layer = HFLlamaDecoderLayer(llama_config.to_hf_config(32000)) + hf_decoder_layer = HFLlamaDecoderLayer(llama_config.to_hf_config(32000), layer_idx=0) hf_decoder_layer.load_state_dict(state, strict=True) x, mask = _get_random_inputs(llama_config) @@ -224,6 +225,25 @@ def test_llama_lm_head_model(num_kv_heads): assert out.array.shape == (Batch.size, Pos.size, Vocab.size) +@pytest.mark.parametrize("use_flash", [True, False]) +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_llama_lm_head_model_bwd(use_flash, num_kv_heads): + llama_config = _get_llama_config(use_flash=use_flash, num_kv_heads=num_kv_heads) + Batch = hax.Axis("batch", 2) + Vocab = hax.Axis("vocab", 1000) + Pos = llama_config.Pos + input_ids = hax.random.randint(random.PRNGKey(0), (Batch, Pos), 0, Vocab.size) + mask = hax.nn.attention.causal_mask(Pos, llama_config.KeyPos) + + llama_model = LlamaLMHeadModel.init(Vocab=Vocab, config=llama_config, key=random.PRNGKey(0)) + + def f(llama_model, input_ids, mask): + out = llama_model(input_ids, mask) + return hax.sum(out).scalar() + + _, grads = eqx.filter_value_and_grad(f)(llama_model, input_ids, mask) + + @skip_if_no_torch @pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) def test_llama_roundtrip(num_kv_heads): @@ -297,6 +317,7 @@ def _get_llama_config(use_flash=False, num_kv_heads=4) -> LlamaConfig: rope_scaling=rope_scaling, gradient_checkpointing=False, # disable for tests so debugging is easier use_flash_attention=use_flash, + flash_attention_block_size=8 if use_flash else None, ) diff --git a/tests/test_mistral.py b/tests/test_mistral.py new file mode 100644 index 000000000..76630849f --- /dev/null +++ b/tests/test_mistral.py @@ -0,0 +1,165 @@ +import tempfile + +import equinox as eqx +import jax +import numpy as np +import pytest +import transformers +from jax import random + +import haliax as hax + +from levanter.models.mistral import MistralConfig, MistralLMHeadModel +from test_utils import check_load_config, check_model_works_with_seqlen, parameterize_with_configs, skip_if_no_torch + + +@skip_if_no_torch +def test_mistral_config(): + # load HF config and convert to levanter config + hf_config = transformers.MistralConfig.from_pretrained("mistralai/Mistral-7B-v0.1") + mistral_config = MistralConfig.from_hf_config(hf_config) + + # convert back to HF config + config_overrides = { + "_name_or_path": hf_config._name_or_path, + "architectures": hf_config.architectures, + "torch_dtype": hf_config.torch_dtype, + } + new_hf_config = mistral_config.to_hf_config( + vocab_size=hf_config.vocab_size, + config_overrides=config_overrides, + ) + + # assert the content in new_hf_config is the same as hf_config + for k in new_hf_config.__dict__.keys(): + if k in ["_commit_hash", "transformers_version"]: + continue + assert getattr(new_hf_config, k) == getattr( + hf_config, k + ), f"{k} {getattr(new_hf_config, k)} != {getattr(hf_config, k)}" + + +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_mistral_lm_head_model(num_kv_heads): + mistral_config = _get_mistral_config(num_kv_heads=num_kv_heads) + Batch = hax.Axis("batch", 2) + Vocab = hax.Axis("vocab", 1000) + Pos = mistral_config.Pos + input_ids = hax.random.randint(random.PRNGKey(0), (Batch, Pos), 0, Vocab.size) + mask = hax.nn.attention.causal_mask(Pos, mistral_config.KeyPos) + + mistral_model = MistralLMHeadModel.init(Vocab=Vocab, config=mistral_config, key=random.PRNGKey(0)) + out = mistral_model(input_ids, mask) + assert out.array.shape == (Batch.size, Pos.size, Vocab.size) + + +@pytest.mark.parametrize("use_flash", [True, False]) +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_mistral_lm_head_model_bwd(use_flash, num_kv_heads): + llama_config = _get_mistral_config(use_flash=use_flash, num_kv_heads=num_kv_heads) + Batch = hax.Axis("batch", 2) + Vocab = hax.Axis("vocab", 1000) + Pos = llama_config.Pos + input_ids = hax.random.randint(random.PRNGKey(0), (Batch, Pos), 0, Vocab.size) + mask = hax.nn.attention.causal_mask(Pos, llama_config.KeyPos) + + llama_model = MistralLMHeadModel.init(Vocab=Vocab, config=llama_config, key=random.PRNGKey(0)) + + def f(llama_model, input_ids, mask): + out = llama_model(input_ids, mask) + return hax.sum(out).scalar() + + _, grads = eqx.filter_value_and_grad(f)(llama_model, input_ids, mask) + + +@skip_if_no_torch +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_mistral_roundtrip(num_kv_heads): + import torch + from transformers import AutoModelForCausalLM, MistralForCausalLM + + converter = MistralConfig.default_hf_checkpoint_converter + + config = MistralConfig( + seq_len=128, + hidden_dim=16, + num_heads=4, + num_kv_heads=num_kv_heads, + gradient_checkpointing=False, + ) + Vocab = hax.Axis("vocab", 1000) + hf_config = config.to_hf_config(Vocab.size) + + # Make input and attn_mask + input = hax.random.randint(random.PRNGKey(0), config.Pos, 0, Vocab.size) + attn_mask = hax.nn.attention.causal_mask(config.Pos, config.KeyPos) + input_torch = torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0) + + torch.random.manual_seed(0) + + torch_model = MistralForCausalLM(hf_config) + torch_model.eval() + + torch_out = torch_model(input_torch) + torch_out = torch_out.logits[0].detach().cpu().numpy() + torch_out = jax.nn.softmax(torch_out, axis=-1) + + with tempfile.TemporaryDirectory() as tmpdir: + torch_model.save_pretrained(f"{tmpdir}/torch_model") + + model = converter.load_pretrained( + MistralLMHeadModel, f"{tmpdir}/torch_model", resize_vocab_to_match_tokenizer=False + ) + + def compute(input): + model_output = model(input, attn_mask=attn_mask) + return hax.nn.softmax(model_output, axis=model.Vocab) + + compute = jax.jit(compute) + jax_out = compute(input).array + + assert torch_out.shape == jax_out.shape, f"{torch_out.shape} != {jax_out.shape}" + assert np.isclose(torch_out, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out} != {jax_out}" + + converter.save_pretrained(model, f"{tmpdir}/lev_model", save_reference_code=False) + torch_model2 = AutoModelForCausalLM.from_pretrained(f"{tmpdir}/lev_model") + torch_model2.eval() + + torch_out2 = torch_model2(input_torch) + torch_out2 = torch_out2.logits[0].detach().cpu().numpy() + torch_out2 = jax.nn.softmax(torch_out2, axis=-1) + assert torch_out2.shape == jax_out.shape, f"{torch_out2.shape} != {jax_out.shape}" + assert np.isclose(torch_out2, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out2} != {jax_out}" + + +def _get_mistral_config(use_flash=False, num_kv_heads=4) -> MistralConfig: + return MistralConfig( + seq_len=128, + hidden_dim=16, + num_heads=4, + num_kv_heads=num_kv_heads, + gradient_checkpointing=False, # disable for tests so debugging is easier + use_flash_attention=use_flash, + flash_attention_block_size=8 if use_flash else None, + ) + + +@parameterize_with_configs("mistral*.yaml") +def test_mistral_configs(config_file): + from levanter.main.train_lm import TrainLmConfig + + config_class = TrainLmConfig + + check_load_config(config_class, config_file) + + +@pytest.mark.parametrize("num_kv_heads", [1, 2]) +def test_pass_different_length_seq(num_kv_heads): + config = MistralConfig( + seq_len=32, + hidden_dim=16, + intermediate_dim=32, + num_heads=2, + num_kv_heads=num_kv_heads, + ) + check_model_works_with_seqlen(MistralLMHeadModel, config, 16) diff --git a/tests/test_sophia.py b/tests/test_sophia.py index 7e759c330..1ca3a7265 100644 --- a/tests/test_sophia.py +++ b/tests/test_sophia.py @@ -1,3 +1,4 @@ +import functools import os import equinox as eqx @@ -15,9 +16,17 @@ def test_sophia_h(): model = nn.Linear(4, 4, use_bias=False, key=key) data = np.load(f"{os.path.dirname(__file__)}/data/hero_data.npy").astype("float32") optimizer = levanter.optim.sophia.sophia_h( - lr=1, b1=0, b2=0.99, gamma=2, weight_decay=0.0, clip_threshold=1, key=key + lr=1, + b1=0, + b2=0.99, + gamma=2, + weight_decay=0.0, + clip_threshold=1, + key=key, + update_interval=1, ) model = jax.tree_util.tree_map(lambda x: jnp.ones_like(x), model) + zero_grad = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), model) opt_state = optimizer.init(model) @@ -25,10 +34,11 @@ def loss_fn(model, data): out = eqx.filter_vmap(model)(data) return jnp.mean(out**2) * 4 - jit_update = eqx.filter_jit(optimizer.update_hessian) + jit_update = eqx.filter_jit(optimizer.update) + obj_fn = functools.partial(loss_fn, data=data) for i in range(1000): - opt_state = jit_update(opt_state, loss_fn, model, data) + _, opt_state = jit_update(zero_grad, opt_state, params=model, obj_fn=obj_fn) # print('Test-estimated hessian: most coordinates should be approximately 2') # print('Estimated hessian:', opt_state[0].h.weight) @@ -37,7 +47,7 @@ def loss_fn(model, data): grad_loss_fn = eqx.filter_jit(eqx.filter_value_and_grad(loss_fn)) loss, grad = grad_loss_fn(model, data) - model_updates, opt_state = optimizer.update(grad, opt_state) + model_updates, opt_state = optimizer.update(grad, opt_state, params=model, obj_fn=obj_fn) model = eqx.apply_updates(model, model_updates) # loss should be 15.74834156036377 @@ -49,7 +59,7 @@ def loss_fn(model, data): # print("Test-loss: loss should shrink by approximately 75% after each iteration") for i in range(10): loss, grad = grad_loss_fn(model, data) - model_updates, opt_state = optimizer.update(grad, opt_state) + model_updates, opt_state = optimizer.update(grad, opt_state, params=model, obj_fn=obj_fn) model = eqx.apply_updates(model, model_updates) # print('Step:', i , "Loss:", loss.item()) diff --git a/tests/test_text.py b/tests/test_text.py index a9d407b44..70b2d26a7 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -1,12 +1,14 @@ import tempfile import jax.numpy as jnp +from transformers import AutoTokenizer import haliax as hax -from levanter.data.text import LMDatasetConfig +from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.models.lm_model import LmExample from levanter.models.loss import next_token_loss +from test_utils import skip_if_hf_model_not_accessible def test_dont_blow_up_without_validation_set(): @@ -39,3 +41,29 @@ def test_lm_example_handles_ignore_id(): no_ignore_loss = next_token_loss(Pos, Vocab, distr, tokens, loss_mask=ex_no_ignore.loss_mask) assert no_ignore_loss.item() >= ignored_loss.item() + 100 / Pos.size + + +def test_merge_split_encodings(): + tokenizer = AutoTokenizer.from_pretrained("gpt2") + # make this very short for testing + + lorem = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.""" + + short_batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=len(lorem) // 3) + # force this + short_batch_tokenizer._needs_long_sequence_workaround = True + + batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=50000) + batch = [lorem] + + short_out = short_batch_tokenizer(batch) + reg_out = batch_tokenizer(batch) + + assert short_out == reg_out + + +@skip_if_hf_model_not_accessible("meta-llama/Llama-2-7b-hf") +def test_llama_tokenizer_needs_long_sequence_workaround(): + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + batch_tokenizer = BatchTokenizer(tokenizer) + assert batch_tokenizer._needs_long_sequence_workaround diff --git a/tests/test_viz_lm.py b/tests/test_viz_lm.py index 25d5e8fb0..71d117055 100644 --- a/tests/test_viz_lm.py +++ b/tests/test_viz_lm.py @@ -18,7 +18,11 @@ def setup_module(module): ray_designated_cores = max(1, logical_cpu_core_count()) - ray.init("local", num_cpus=ray_designated_cores) + try: + ray.init("local", num_cpus=ray_designated_cores) + except AssertionError: + # don't get upset if ray is already running + pass def teardown_module(module): diff --git a/tests/test_weight_decay_mask.py b/tests/test_weight_decay_mask.py index c47231116..cc94c5749 100644 --- a/tests/test_weight_decay_mask.py +++ b/tests/test_weight_decay_mask.py @@ -18,8 +18,8 @@ def apply_weight_decay(tree): nodes = [] # apply on embedding - nodes.append(tree.embeddings.token_embeddings.array) - nodes.append(tree.embeddings.position_embeddings.array) + nodes.append(tree.embeddings.token_embeddings.weight.array) + nodes.append(tree.embeddings.position_embeddings.weight.array) # apply on attention nodes.append(tree.transformer.blocks.stacked.attn.c_attn.weight.array) @@ -49,8 +49,8 @@ def apply_weight_decay(tree): "attn.c_proj.weight", "mlp.c_fc.weight", "mlp.c_proj.weight", - "token_embeddings", - "position_embeddings", + "token_embeddings.weight", + "position_embeddings.weight", ] ) regex_config = AdamConfig(