diff --git a/config/llama2_nano.yaml b/config/llama2_nano.yaml index 6b6f8d93f..58415022e 100644 --- a/config/llama2_nano.yaml +++ b/config/llama2_nano.yaml @@ -9,6 +9,7 @@ model: type: llama hidden_dim: 32 num_heads: 4 + num_kv_heads: 4 num_layers: 2 trainer: tracker: diff --git a/docs/figures/finetune_func_cm_full_weight.png b/docs/figures/finetune_func_cm_full_weight.png new file mode 100644 index 000000000..9e04504e6 Binary files /dev/null and b/docs/figures/finetune_func_cm_full_weight.png differ diff --git a/docs/figures/finetune_func_cm_lora.png b/docs/figures/finetune_func_cm_lora.png new file mode 100644 index 000000000..753c479d0 Binary files /dev/null and b/docs/figures/finetune_func_cm_lora.png differ diff --git a/docs/tutorials/Fine-Tuning-Semantic-Parsing.md b/docs/tutorials/Fine-Tuning-Semantic-Parsing.md new file mode 100644 index 000000000..c704867e4 --- /dev/null +++ b/docs/tutorials/Fine-Tuning-Semantic-Parsing.md @@ -0,0 +1,427 @@ +# Fine-Tuning for Semantic Parsing + +Semantic parsing is a process that transforms a natural language sentence into a logical form. +This logical form represents the sentence's meaning in a way that computer programs can utilize to carry out tasks, respond to questions, or follow commands. + +For example, through semantic parsing, a chatbot can convert a question posed in natural language into a precise SQL query, which then retrieves the desired information from a database. +Similarly, a virtual assistant can interpret a user's spoken request and, by employing semantic parsing, translate it into a JSON object that triggers specific actions. +By bridging the gap between human language and machine-readable formats, semantic parsing empowers AI systems to perform tasks with both accuracy and autonomy. + +In this post, we'll guide you through fine-tuning a [Llama2 model](https://ai.meta.com/llama/) for semantic parsing with Levanter. + +## Example Task +Our example task is based on the [GEM ViGGO](https://huggingface.co/datasets/GEM/viggo) dataset. + +It translates conversational English queries on the topic of video games into a structural format. +Each example features a plain English input and a structured output that adheres to predefined rules for function naming conventions and attributes. + +The dataset has 5,103 training examples and 714 examples for evaluation. +Although smaller than [the Alpaca dataset](../Fine-Tuning.md#overview-of-alpaca), it provides a decent amount of data to effectively adapt the model to the task. + +Below are some examples from the dataset: + +In this example below, the user expresses an opinion on a game and describes the game with a list of attributes. +The chatbot should capture the intention (`give_opinion`), as well as the attributes that describe the game (name, release year, rating, has_multiplayer). + +``` +Query: I had fun playing Age of Empires II: The Age of Kings. I enjoy a lot of multiplayer games from 1999. +Expected Response: give_opinion(name[Age of Empires II: The Age of Kings], release_year[1999], rating[good], has_multiplayer[yes]) +``` + +In this example below, the query describes a game in a very detailed manner. +The chatbot should tell apart the intention of `inform` from `give_opinion` in the example above, and parse all the corresponding attributes. It is a typical example of "Semantic Parsing" and it is not easy to capture all the attributes with field and value correctly. + +``` +Query: BioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. +Expected Response: inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes]) +``` + +This is an example of a user asking for an action. The chatbot should capture the intention and attributes correctly, so that it can generate a backend call to perform the action. + +``` +Query: Is it the PC game, The Elder Scrolls Online, which was developed by ZeniMax Online Studios? +Expected Response: confirm(name[The Elder Scrolls Online], developer[ZeniMax Online Studios], platforms[PC]) +``` + +### Quick Test with Llama2-7B Chat Model + +If we test with the Llama2-7B chat model on these examples (see the full prompt in [Appendix A](#appendix-a-the-prompt-used-in-this-task)), +we can see it does not learn the task well: it struggles to generate the correct function names and hallucinates quite a few attributes that are not mentioned in the query; it also produces outputs with incorrect formats (`\n` before attribute names, `[E (`, etc). + +``` +Query: I had fun playing Age of Empires II: The Age of Kings. I enjoy a lot of multiplayer games from 1999. +Target: give_opinion(name[Age of Empires II: The Age of Kings], release_year[1999], rating[good], has_multiplayer[yes]) +Model: inform(name[Age of Empires II: The Age of Kings], release_year[1999], esrb[E (for Everyone)], genres[strategy], platforms[PC], has_multiplayer[yes]) + +Query: BioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. +Target: inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes]) +Model: inform(name[BioShock], release_year[2007], esrb[Rating Pending], genres[role-playing, action-adventure, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_mac_release[yes], has_linux_release[no]) + +Query: Is it the PC game, The Elder Scrolls Online, which was developed by ZeniMax Online Studios? +Target: confirm(name[The Elder Scrolls Online], developer[ZeniMax Online Studios], platforms[PC]) +Model: inform(name[The Elder Scrolls Online], release_year[2014],\nesrb[M (for Mature)], genres[action, adventure, role-playing],\nplatforms[PC], available_on_steam[yes], has_linux_release[no],\nhas_mac_release[yes]) +``` + +This highlights the need for fine-tuning: to enhance the model's understanding on this task and producing the intended output. + +## Fine-tuning with Levanter + +### Step 1: Prepare the Dataset + +Let's begin by preparing the dataset. + +Since our dataset is already in a clean, tabular format, minimal preprocessing effort is required. +Our main tasks are to rename the columns and convert the data into the JSONL format, which is compatible with Levanter. +For detailed instructions, refer to the documentation [Training on Your Data](../Training-On-Your-Data.md#dataset-preparation). + +Below is a code snippet for dataset preparation: + +```python +import json +import datasets + + +# load the dataset +train_dataset = datasets.load_dataset("GEM/viggo", split="train") + +# rename the columns +train_dataset = train_dataset.map( + lambda example: { + "instruction_field": PROMPT, + "input": example["target"], + "output": example["meaning_representation"], + } +) + +# save the dataset in JSONL format +with open("train.jsonl", "w") as f: + for example in train_dataset: + json.dump(example, f) + f.write("\n") +``` + +The `PROMPT` provides the model with instructions to enhance its understanding of the task at hand. +In our example, the prompt details the potential function names and attributes, aiding the model in generating the correct output. +We provide the full prompt in [Appendix A](#appendix-a-the-prompt-used-in-this-task). +While helpful, including a prompt is optional for fine-tuning. + +### Step 2: Fine-tune the Model + +Now, let's proceed to fine-tune the model with Levanter. + +For this example, we've explored both comprehensive full-weight fine-tuning, similar to the approach used in Alpaca, and the more resource-efficient LoRA fine-tuning. Detailed descriptions of both methods are available in the documentation: [Fine-Tuning](../Fine-Tuning.md) and [LoRA](../LoRA.md). + +Here's a brief comparison of the two: + +- **Full-weight fine-tuning**: it fine-tunes the entire model weights to better follow the instruction and examples in the training dataset. It is able to leverage the entire model capacity, but it is expensive and prone to overfitting. +- **LoRA fine-tuning**: it adapts the model to the task by adding a small number of parameters (0.1% to 1%) to the model, and train only those parameters. The new parameters are sufficient to capture the task-specific patterns and enable the model to generate the desired output. After training, we merge the new parameters into the original model to be used for inference. It is much more efficient than full-weight fine-tuning, and it is less prone to overfitting. + +Levanter provides good support for both methods. Therefore, we can easily try both methods and compare their results. + +#### Full-weight Fine-tuning + +We start with full-weight fine-tuning. Below is our configuration. Noteably: + +- The base model is `meta-llama/Llama-2-7b-hf`. It is set as the default value, so we don't need to specify it explicitly. +- Batch size: We set the batch size to 128, which is the maximum batch size that can fit into a single TPUv3-8. +- Learning rate: We compared the results with 3 epochs vs 2 epochs, and found that 2 epochs is sufficient to achieve the best results, while 3 epochs leads to overfitting. + +```yaml +data: "gs://levanter-data/fine-tuning/gem-viggo/GEM_viggo_train.jsonl" +data_cache_dir: "gs://levanter-data/tokenized/GEM_viggo_llama/" +trainer: + wandb: + project: "levanter" + tags: ["viggo", "llama", "full-weight"] + mp: p=f32,c=bfloat16 + train_batch_size: 128 + num_train_steps: 80 # 5103 examples / 128 batch size * 2 epochs + tensor_parallel_axes: ["mlp", "heads"] +optimizer: + learning_rate: 2E-5 +``` + +The detailed instruction to run the training job can be found in the [Fine-Tuning documentation](../Fine-Tuning.md). +Here is the command to run the training job on TPU: + +```bash +gcloud compute tpus tpu-vm ssh finetune-32 --zone us-east1-d --worker=all \ +--command="WANDB_API_KEY=${YOUR WANDB TOKEN HERE} \ +HUGGING_FACE_HUB_TOKEN=${YOUR HF TOKEN HERE} \ +bash levanter/infra/run.sh python \ +levanter/examples/alpaca/alpaca.py \ + --config_path gs:// \ + --hf_save_path gs://" +``` + +Given the small dataset and high efficiency of Levanter, the entire training job completed quickly in only 21 min on a single TPUv3-8. + +#### LoRA Fine-tuning + +Below is our configuration for LoRA fine-tuning. Note that it is very similar to the full-weight fine-tuning configuration, except for a few differences: + +- We added the `lora` section to specify the LoRA parameters. All of the parameters are set to the default values. +- We increased the number of steps by 1 more epoch. LoRA fine-tuning uses less parameters, so it regularizes better and we can train for more steps. +- We increased the learning rate to 3e-4, but we did not do very thorough hyperparameter tuning. We expect there might be a better learning rate. +- We found weight decay at 0.1 leads to better results than no weight decay, so we set it at 0.1. + +```yaml +data: "gs://levanter-data/fine-tuning/gem-viggo/GEM_viggo_train.jsonl" +data_cache_dir: "gs://levanter-data/tokenized/GEM_viggo_lora/" +trainer: + wandb: + project: "levanter" + tags: ["viggo", "llama", "lora"] + + mp: p=f32,c=bfloat16 + train_batch_size: 128 + num_train_steps: 120 # 5103 examples / 128 batch size * 3 epochs + tensor_parallel_axes: ["mlp", "heads"] +optimizer: + learning_rate: 3e-4 + weight_decay: 0.1 +lora: + r: 8 # rank of LoRA transform + alpha: 8.0 # scaling factor for LoRA transform + dropout: 0.0 # dropout probability for LoRA layers +``` + +Here is the command to run the training job on TPU: + +```bash +gcloud compute tpus tpu-vm ssh finetune-32 --zone us-east1-d --worker=all \ +--command="WANDB_API_KEY=${YOUR WANDB TOKEN HERE} \ +HUGGING_FACE_HUB_TOKEN=${YOUR HF TOKEN HERE} \ +bash levanter/infra/run.sh python \ +levanter/examples/alpaca-lora/alpaca_lora.py \ + --config_path gs:// \ + --hf_save_path $GS_BUCKET/llama2/ \ + --merged_hf_save_path gs://levanter-checkpoints/llama2/llama2_7b_viggo_lora" +``` + +Note that with `--merged_hf_save_path`, it will merge the trained LoRA parameters into the original model and save the new model. +To save the LoRA adaptors as separate weight file, use `--hf_save_path` instead. + +## Evaluation + +### Metrics +How do we accurately evaluate a model's performance in semantic parsing tasks? +Character-level accuracy falls short as it doesn't account for variations in the order of attributes and does not distinguish between function names and attributes. +Instead, we assess the model's ability to interpret instructions and parse semantic meaning from input queries by measuring more specific accuracies: + +- Function Name Accuracy: This metric confirms whether the extracted function name matches the expected one. +- Attribute Set Accuracy: This checks if the model identifies the correct set of attributes, regardless of their order. +- Attribute Value Accuracy: This evaluates the proportion of attributes for which the model has accurately predicted the corresponding values. +- Overall Accuracy: This is the simple average of the three metrics above. This is used as an aggregate metric to compare the overall performance of the model. + +Together, these metrics provide a comprehensive picture of the model's effectiveness in this task. + +The code snippet below shows how we extract the function name and attributes from the model's response and evaluate each accuracy metric. + +```python +def extract_function_and_attributes(response): + # Remove extra spaces and normalize the response + response = response.strip().lower() + # Extract the function name using regex + function_match = re.match(r"(\w+)\(", response) + function_name = function_match.group(1) if function_match else None + # Extract attributes and their values using regex + attributes = re.findall(r"(\w+)\[([^]]*)\]", response) + return function_name, dict(attributes) + + +def evaluate_chatbot_response(chatbot_response, labeled_response): + # Preprocess and extract data from responses + chatbot_function, chatbot_attributes = extract_function_and_attributes( + chatbot_response + ) + labeled_function, labeled_attributes = extract_function_and_attributes( + labeled_response + ) + + # Function Name Accuracy + function_name_accuracy = int(chatbot_function == labeled_function) + + # Attribute Set Accuracy + attribute_set_accuracy = int( + set(chatbot_attributes.keys()) == set(labeled_attributes.keys()) + ) + + # Attribute Value Accuracy + correct_values = sum( + chatbot_attributes.get(attr, None) == value + for attr, value in labeled_attributes.items() + ) + attribute_value_accuracy = ( + correct_values / len(labeled_attributes) if labeled_attributes else 1 + ) + + # Composite Metric (simple average for this example) + composite_score = ( + function_name_accuracy + attribute_set_accuracy + attribute_value_accuracy + ) / 3 + + return { + "function_name_accuracy": function_name_accuracy, + "attribute_set_accuracy": attribute_set_accuracy, + "attribute_value_accuracy": attribute_value_accuracy, + "composite_score": composite_score, + } +``` + +### Results + +We evaluated the fine-tuned models on a hold-out evaluation set of 714 examples and computed the metrics described above. +The results are shown in the table below. + +| Model \ Metric | Function Name Accuracy | Attribute Set Accuracy | Attribute Value Accuracy | Overall Accuracy | +|---------------------------|------------------------|------------------------|--------------------------|------------------| +| Llama2-7B Chat | 0.014 | 0.010 | 0.524 | 0.183 | +| Full-weight Fine-tuning | 0.577 | 0.822 | 0.942 | 0.780 | +| LoRA Fine-tuning | 0.517 | 0.845 | 0.881 | 0.748 | + + +There are a few highlights from the results: + +- The baseline Llama2-7B Chat model's performance is remarkably low at 0.183 overall accuracy. This is consistent with our earlier observation that it does not really understand how to perform the task. +- Fine-tuning methods, both full-weight and LoRA, substantially enhance the model's accuracy, achieving 0.780 and 0.748, respectively. Notably, LoRA fine-tuning, while training with less than 0.1% parameters, approaches the metrics of full-weight fine-tuning and achieves a better `Attribute Set Accuracy`. This highlights the efficiency of LoRA. +- Full-weight fine-tuning outperforms LoRA fine-tuning in the `Function Name Accuracy` and the `Attribute Value Accuracy` metrics. This shows the advantage of training the entire model weights on the task. Though further hyperparameter tuning might allow LoRA to close this gap. +- The higher accuracy in attribute set and value suggests that the attributes are more contextually driven and thus easier for the model to predict. In contrast, correctly identifying function names appears to be more challenging, indicating a need for deeper understanding of the task and reasoning capability. + +#### Confusion among Function Names + +To expand on the last point, we went deep into predictions of function name. +The confusion matrices below illustrate how the the full-weight fine-tuning and LoRA fine-tuning models perform in this area. + +![Confusion Matrix of Full-weight Fine-Tuning](../figures/finetune_func_cm_full_weight.png) +![Confusion Matrix of LoRA Fine-Tuning](../figures/finetune_func_cm_lora.png) + +- A notable area of confusion is within the `confirm` column, where both models frequently misclassify examples from the `inform`, `recommend`, and `suggest` categories as `confirm`. This may stem from these function names being too similar in the dataset, leading to ambiguity that the models struggle to differentiate. +- The full-weight fine-tuning model (depicted on the left) tends to make more errors in the "empty" or "other" columns, indicating instances where the model either fails to predict a function name or predicts one that doesn't exist within the predefined set. +- The LoRA fine-tuning model (shown on the right) exhibits more confusion between the predefined function classes; overall, it has fewer correct predictions along the diagonal of the matrix. + +To enhance the model's ability to accurately predict function names, we should consider consolidating function names that are closely related to reduce ambiguity. Additionally, further refinement of the hyperparameters could yield improvements. + +#### Examples +Here we are going to show some examples and compare predictions of different models. + +Below shows three examples of three different functions, both fine-tuning methods generate the exact same output as the target and avoid hallucinating attributes, which is seen in the Llama2-7B Chat model. + +``` +Query: Is it the PC game, The Elder Scrolls Online, which was developed by ZeniMax Online Studios? +Target: confirm(name[The Elder Scrolls Online], developer[ZeniMax Online Studios], platforms[PC]) +Llama2-7B Chat: inform(name[The Elder Scrolls Online], release_year[2014],\nesrb[M (for Mature)], genres[action, adventure, role-playing],\nplatforms[PC], available_on_steam[yes], has_linux_release[no],\nhas_mac_release[yes]) +Full fine-tune: confirm(name[The Elder Scrolls Online], developer[ZeniMax Online Studios], platforms[PC]) +LoRA fine-tune: confirm(name[The Elder Scrolls Online], developer[ZeniMax Online Studios], platforms[PC]) +``` + +``` +Query: A puzzle-shooter platformer game is Portal 2 from 2011. You can play it multiplayer on PlayStation and your PC, and it is available on Steam. +Target: inform(name[Portal 2], release_year[2011], genres[platformer, puzzle, shooter], has_multiplayer[yes], platforms[PlayStation, PC], available_on_steam[yes]) +Llama2-7B Chat: inform(name[Portal 2], release_year[2011], esrb[E 10+ (for Everyone 10 and Older)], genres[puzzle-shooter, platformer], platforms[PlayStation, PC], available_on_steam[yes]) +Full fine-tune: inform(name[Portal 2], release_year[2011], genres[platformer, puzzle, shooter], has_multiplayer[yes], platforms[PlayStation, PC], available_on_steam[yes]) +LoRA fine-tune: inform(name[Portal 2], release_year[2011], genres[platformer, puzzle, shooter], has_multiplayer[yes], platforms[PlayStation, PC], available_on_steam[yes]) +``` + +``` +Query: What is it about first person games on the PC that causes you to find them just so-so? +Target: request_explanation(rating[average], player_perspective[first person], platforms[PC]) +Llama2-7B Chat: inform(name[The Last of Us], release_year[2013], esrb[M (for Mature)], genres[action-adventure], platforms[PlayStation], available_on_steam[yes], has_mac_release[no]) +Full fine-tune: request_explanation(rating[average], player_perspective[first person], platforms[PC]) +LoRA fine-tune: request_explanation(rating[average], player_perspective[first person], platforms[PC]) +``` + +In the following example, both fine-tuning methods generate precisely the same attribute set as the target; the full-weight fine-tuning model correctly identifies the function name (`give_opinion`) while the LoRA fine-tuning model mispredicts it as `confirm`. + +``` +Query: I had fun playing Age of Empires II: The Age of Kings. I enjoy a lot of multiplayer games from 1999. +Target: give_opinion(name[Age of Empires II: The Age of Kings], release_year[1999], rating[good], has_multiplayer[yes]) +Llama2-7B Chat: inform(name[Age of Empires II: The Age of Kings], release_year[1999], esrb[E (for Everyone)], genres[strategy], platforms[PC], has_multiplayer[yes]) +Full fine-tune: give_opinion(name[Age of Empires II: The Age of Kings], release_year[1999], rating[good], has_multiplayer[yes]) +LoRA fine-tune: confirm(name[Age of Empires II: The Age of Kings], release_year[1999], has_multiplayer[yes]) +``` + +## Summary + +In this post, we showcase the process of fine-tuning a Llama2 model using Levanter for the task of semantic parsing. + +Semantic parsing is a critical step in translating user queries into machine-readable, structural format, which is essential for programmatic interactions with AI systems. +We chose the GEM ViGGO dataset to demonstrate this task. The out-of-box Llama2-7B Chat model struggles to follow the instruction and hallucinates at predicting attributes. By applying fine-tuning, both full-weight and LoRA, we significantly improved the model's ability to perform on this task. + +## Appendices + +### Appendix A: The Prompt Used in This Task + +``` +You are a helpful chatbot to assist users to parse queries in natural language into structural format. + +Given a target sentence construct the underlying meaning representation +of the input sentence as a single function with attributes and attribute +values. This function should describe the target string accurately and the +function must be one of the following ['inform', 'request', 'give_opinion', +'confirm', 'verify_attribute', 'suggest', 'request_explanation', +'recommend', 'request_attribute'] . + +The attributes must be one of the following: +['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating', +'genres', 'player_perspective', 'has_multiplayer', 'platforms', +'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] +The order your list the attributes within the function must follow the +order listed above. For example the 'name' attribute must always come +before the 'exp_release_date' attribute, and so forth. + +For each attribute, fill in the corresponding value of the attribute +within brackets. A couple of examples are below. Note: you are to output +the string after "Output: ". Do not include "Output: " in your answer. + +Example 1) +Sentence: Dirt: Showdown from 2012 is a sport racing game for the +PlayStation, Xbox, PC rated E 10+ (for Everyone 10 and Older). +It's not available on Steam, Linux, or Mac. +Output: inform(name[Dirt: Showdown], release_year[2012], +esrb[E 10+ (for Everyone 10 and Older)], genres[driving/racing, sport], +platforms[PlayStation, Xbox, PC], available_on_steam[no], +has_linux_release[no], has_mac_release[no]) + +Example 2) +Sentence: Were there even any terrible games in 2014? +Output: request(release_year[2014], specifier[terrible]) + +Example 3) +Sentence: Adventure games that combine platforming and puzzles +can be frustrating to play, but the side view perspective is +perfect for them. That's why I enjoyed playing Little Nightmares. +Output: give_opinion(name[Little Nightmares], rating[good], +genres[adventure, platformer, puzzle], player_perspective[side view]) + +Example 4) +Sentence: Since we're on the subject of games developed by Telltale +Games, I'm wondering, have you played The Wolf Among Us? +Output: recommend(name[The Wolf Among Us], developer[Telltale Games]) + +Example 5) +Sentence: Layers of Fear, the indie first person point-and-click adventure game? +Output: confirm(name[Layers of Fear], genres[adventure, indie, +point-and-click], player_perspective[first person]) + +Example 6) +Sentence: I bet you like it when you can play games on Steam, like +Worms: Reloaded, right? +Output: suggest(name[Worms: Reloaded], available_on_steam[yes]) + +Example 7) +Sentence: I recall you saying that you really enjoyed The Legend +of Zelda: Ocarina of Time. Are you typically a big fan of games +on Nintendo rated E (for Everyone)? +Output: verify_attribute(name[The Legend of Zelda: Ocarina of Time], +esrb[E (for Everyone)], rating[excellent], platforms[Nintendo]) + +Example 8) +Sentence: So what is it about the games that were released in 2005 +that you find so excellent? +Output: request_explanation(release_year[2005], rating[excellent]) + +Example 9) +Sentence: Do you think Mac is a better gaming platform than others? +Output: request_attribute(has_mac_release[]) +``` diff --git a/mkdocs.yml b/mkdocs.yml index a386310ac..28fdb9849 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -94,6 +94,7 @@ nav: - 'Tutorials': - "Fine-Tuning.md" - "LoRA.md" + - "tutorials/Fine-Tuning-Semantic-Parsing.md" - "Hardware-Agnostic-Training.md" - 'Developer Guide': - 'dev/Port-Models.md' diff --git a/src/levanter/lora.py b/src/levanter/lora.py index eb511a80d..3e0dee750 100644 --- a/src/levanter/lora.py +++ b/src/levanter/lora.py @@ -500,12 +500,10 @@ def to_hf_config(config: LoraConfig, base_model_name_or_path: Optional[str] = No return { "base_model_name_or_path": base_model_name_or_path, "bias": "none", # TODO: support bias - "enable_lora": None, "fan_in_fan_out": False, # TODO: support fan_in_fan_out "inference_mode": True, # TODO: support inference_mode "lora_alpha": config.alpha, "lora_dropout": 0.00, # TODO: support dropout - "merge_weights": False, "modules_to_save": None, # TODO: support modules_to_save? "peft_type": "LORA", "r": config.r, diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index e96f23c8a..17f6d04cb 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -47,6 +47,9 @@ class LlamaConfig(HFCompatConfig): intermediate_dim (int, optional): dimension of the intermediate state. Defaults to 11008. num_layers (int, optional): number of hidden layers in the Transformer encoder. Defaults to 32. num_heads (int, optional): number of attention heads for each attention layer. Defaults to 32. + num_kv_heads (int, optional): number of attention heads for keys and values in each attention layer. + Setting to 1 means MQA. Setting to num_heads means MHA. Otherwise GQA. + Note that num_heads must be divisible by this number. Defaults to 32. activation_function (str, optional): activation function for the hidden layer. Defaults to "silu". rope_scaling (Dict, optional): dict containing the scaling configuration for the Rotary Positional Embedding. """ @@ -56,6 +59,7 @@ class LlamaConfig(HFCompatConfig): intermediate_dim: int = 11008 num_layers: int = 32 num_heads: int = 32 + num_kv_heads: int = 32 activation_function: str = "silu" initializer_range: float = 0.02 layer_norm_epsilon: float = 1e-5 @@ -76,10 +80,16 @@ class LlamaConfig(HFCompatConfig): KeyPos = property(lambda self: self.Pos.alias("key_position")) Embed = property(lambda self: Axis(name="embed", size=self.hidden_dim)) Heads = property(lambda self: Axis(name="heads", size=self.num_heads)) + KVHeads = property(lambda self: Axis(name="kv_heads", size=self.num_kv_heads)) Layers = property(lambda self: Axis(name="layers", size=self.num_layers)) Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_dim)) HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads)) + def __post_init__(self): + assert ( + self.num_heads % self.num_kv_heads == 0 + ), f"num_heads={self.num_heads} not divisible by num_kv_heads={self.num_kv_heads}." + @cached_classproperty def default_hf_checkpoint_converter(cls) -> HFCheckpointConverter["LlamaConfig"]: # type: ignore return HFCheckpointConverter( @@ -98,6 +108,7 @@ def from_hf_config(cls, hf_config: HfConfig): intermediate_dim=hf_config.intermediate_size, num_layers=hf_config.num_hidden_layers, num_heads=hf_config.num_attention_heads, + num_kv_heads=hf_config.num_key_value_heads, activation_function=hf_config.hidden_act, initializer_range=hf_config.initializer_range, layer_norm_epsilon=hf_config.rms_norm_eps, @@ -123,6 +134,7 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) intermediate_size=self.intermediate_dim, num_hidden_layers=self.num_layers, num_attention_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, hidden_act=self.activation_function, initializer_range=self.initializer_range, rms_norm_eps=self.layer_norm_epsilon, @@ -264,10 +276,14 @@ class LlamaAttention(StateDictSerializationMixin, eqx.Module): def init(config: LlamaConfig, *, key) -> "LlamaAttention": use_bias = config.use_bias Embed = config.Embed + QHeadsPerGroup = hax.Axis("q_heads_per_group", config.num_heads // config.num_kv_heads) + k_q, k_k, k_v, k_o = jrandom.split(key, 4) - q_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_q, use_bias=use_bias) - k_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_k, use_bias=use_bias) - v_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_v, use_bias=use_bias) + q_proj = hnn.Linear.init( + In=Embed, Out=(config.KVHeads, QHeadsPerGroup, config.HeadSize), key=k_q, use_bias=use_bias + ) + k_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_k, use_bias=use_bias) + v_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_v, use_bias=use_bias) o_proj = hnn.Linear.init(In=(config.Heads, config.HeadSize), Out=Embed, key=k_o, use_bias=use_bias) rotary_emb = LlamaRotaryEmbedding(config.HeadSize, config.Pos) return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb) @@ -277,9 +293,9 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray], *, key=None) -> Na key_q, key_k, key_v, key_o = maybe_rng_split(key, 4) # reorder heads and position for better training throughput - q = self.q_proj(x, key=key_q).rearrange((..., "heads", "position", "head_size")) - k = self.k_proj(x, key=key_k).rearrange((..., "heads", "position", "head_size")) - v = self.v_proj(x, key=key_v).rearrange((..., "heads", "position", "head_size")) + q = self.q_proj(x, key=key_q).rearrange((..., "kv_heads", "q_heads_per_group", "position", "head_size")) + k = self.k_proj(x, key=key_k).rearrange((..., "kv_heads", "position", "head_size")) + v = self.v_proj(x, key=key_v).rearrange((..., "kv_heads", "position", "head_size")) cos, sin = self.rotary_emb(seq_len=x.axis_size("position")) @@ -305,6 +321,8 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray], *, key=None) -> Na flash_block_size=c.flash_attention_block_size, ) + attn_output = attn_output.flatten_axes(("kv_heads", "q_heads_per_group"), "heads") + if self.config.upcast_attn: attn_output = attn_output.astype(x.dtype) @@ -574,7 +592,7 @@ def _rotate_half(x: NamedArray) -> NamedArray: def _apply_rotary_pos_emb( - q: NamedArray, # [batch, position, heads, head_size] + q: NamedArray, # [batch, position, kv_heads, q_heads_per_group, head_size] k: NamedArray, # [batch, position, kv_heads, head_size] cos: NamedArray, # [position, head_size] sin: NamedArray, # [position, head_size] diff --git a/src/levanter/optim/config.py b/src/levanter/optim/config.py index 68a02e09e..ad8708c6a 100644 --- a/src/levanter/optim/config.py +++ b/src/levanter/optim/config.py @@ -1,12 +1,17 @@ import abc +import re import warnings from dataclasses import dataclass from typing import Optional import draccus +import equinox as eqx +import jax import optax from jax import numpy as jnp +from levanter.utils.jax_utils import leaf_key_paths + @dataclass class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC): @@ -20,6 +25,9 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC): cooldown: float = 0.0 """fraction of training steps to use as cooldown, or steps to use. 0.0 means no cooldown""" lr_schedule: str = "cosine" # constant, cosine, linear + weight_decay_modules: Optional[list[str] | str] = None + """A regex or a list of strings to identify where to mask weight. + For nano-GPT, this field can be set as `r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings"`""" @classmethod def default_choice_name(cls) -> Optional[str]: @@ -29,6 +37,28 @@ def default_choice_name(cls) -> Optional[str]: def build(self, num_train_steps: int): raise NotImplementedError + def build_weight_decay_mask(self): + if self.weight_decay_modules is None: + return None + else: + # mask based on regex or module path + def _apply_on(x, key_path): + if isinstance(self.weight_decay_modules, str): + compiled_regex = re.compile(self.weight_decay_modules) + return compiled_regex.match(key_path) is not None + else: + return any(key_path.__contains__(target) for target in self.weight_decay_modules) + + def mask_fn(model): + return jax.tree_util.tree_map( + _apply_on, + model, + leaf_key_paths(model, is_leaf=eqx.is_array), + is_leaf=eqx.is_array, + ) + + return mask_fn + def lr_scheduler(self, num_train_steps): warmup_steps = self._convert_warmup(num_train_steps) cooldown_steps = _convert_ratio_or_steps(self.cooldown, num_train_steps) @@ -118,8 +148,7 @@ def _optimizer(learning_rate): components.append(optax.scale_by_adam(self.beta1, self.beta2, self.epsilon)) if self.weight_decay > 0: - # TODO: add weight decay masking?? - components.append(optax.add_decayed_weights(self.weight_decay)) + components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask())) # - learning rate for descent components.append(optax.scale(-learning_rate)) diff --git a/src/levanter/optim/sophia.py b/src/levanter/optim/sophia.py index 6a26c1253..9df275c29 100644 --- a/src/levanter/optim/sophia.py +++ b/src/levanter/optim/sophia.py @@ -135,7 +135,7 @@ def _optimizer(learning_rate, gamma) -> SecondOrderTransformation: # Algorithm 3, step 11 (Note, this comes after clipping b/c it's not supposed to be clipped) # In the paper, it comes as a prior step, but doesn't get clipped if self.weight_decay > 0: - components.append(optax.add_decayed_weights(self.weight_decay)) + components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask())) # - learning rate for descent components.append(optax.scale(-learning_rate)) diff --git a/tests/test_llama.py b/tests/test_llama.py index 15a5ab452..7224a3ac1 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -131,11 +131,12 @@ def named_array_to_tensor(named_array): @skip_if_no_torch @pytest.mark.parametrize("use_flash", [True, False]) -def test_llama_attention(use_flash): +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_llama_attention(use_flash, num_kv_heads): import torch from transformers.models.llama.modeling_llama import LlamaAttention as HFLlamaAttention - config = _get_llama_config(use_flash=use_flash) + config = _get_llama_config(use_flash=use_flash, num_kv_heads=num_kv_heads) attention = LlamaAttention.init(config=config, key=random.PRNGKey(0)) @@ -181,11 +182,12 @@ def test_llama_rms_norm(): @skip_if_no_torch -def test_llama_decoder_layer(): +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_llama_decoder_layer(num_kv_heads): import torch from transformers.models.llama.modeling_llama import LlamaDecoderLayer as HFLlamaDecoderLayer - llama_config = _get_llama_config() + llama_config = _get_llama_config(num_kv_heads=num_kv_heads) key = random.PRNGKey(0) llama_decoder_layer = LlamaDecoderLayer.init(config=llama_config, key=key) @@ -208,8 +210,9 @@ def test_llama_decoder_layer(): ).all(), f"{hf_out[0]} != {out}" -def test_llama_lm_head_model(): - llama_config = _get_llama_config() +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_llama_lm_head_model(num_kv_heads): + llama_config = _get_llama_config(num_kv_heads=num_kv_heads) Batch = hax.Axis("batch", 2) Vocab = hax.Axis("vocab", 1000) Pos = llama_config.Pos @@ -222,7 +225,8 @@ def test_llama_lm_head_model(): @skip_if_no_torch -def test_llama_roundtrip(): +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_llama_roundtrip(num_kv_heads): import torch from transformers import AutoModelForCausalLM, LlamaForCausalLM @@ -232,6 +236,7 @@ def test_llama_roundtrip(): seq_len=128, hidden_dim=16, num_heads=4, + num_kv_heads=num_kv_heads, gradient_checkpointing=False, ) Vocab = hax.Axis("vocab", 1000) @@ -279,7 +284,7 @@ def compute(input): assert np.isclose(torch_out2, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out2} != {jax_out}" -def _get_llama_config(use_flash=False) -> LlamaConfig: +def _get_llama_config(use_flash=False, num_kv_heads=4) -> LlamaConfig: rope_scaling = { "type": "linear", "factor": 2.0, @@ -288,6 +293,7 @@ def _get_llama_config(use_flash=False) -> LlamaConfig: seq_len=128, hidden_dim=16, num_heads=4, + num_kv_heads=num_kv_heads, rope_scaling=rope_scaling, gradient_checkpointing=False, # disable for tests so debugging is easier use_flash_attention=use_flash, @@ -312,11 +318,13 @@ def test_llama_configs(config_file): check_load_config(config_class, config_file) -def test_pass_different_length_seq(): +@pytest.mark.parametrize("num_kv_heads", [1, 2]) +def test_pass_different_length_seq(num_kv_heads): config = LlamaConfig( seq_len=32, hidden_dim=16, intermediate_dim=32, num_heads=2, + num_kv_heads=num_kv_heads, ) check_model_works_with_seqlen(LlamaLMHeadModel, config, 16) diff --git a/tests/test_weight_decay_mask.py b/tests/test_weight_decay_mask.py new file mode 100644 index 000000000..c47231116 --- /dev/null +++ b/tests/test_weight_decay_mask.py @@ -0,0 +1,67 @@ +import equinox as eqx +import jax +import jax.random as jrandom + +import haliax as hax + +from levanter.models.gpt2 import Gpt2Config +from levanter.optim import AdamConfig + + +def test_weight_decay_masking(): + def tree_at_mask(params): + # let's mask all leaves as False + params = jax.tree_util.tree_map(lambda _: False, params) + + def apply_weight_decay(tree): + # there is no weight decay performed in LayerNorms and bias + nodes = [] + + # apply on embedding + nodes.append(tree.embeddings.token_embeddings.array) + nodes.append(tree.embeddings.position_embeddings.array) + + # apply on attention + nodes.append(tree.transformer.blocks.stacked.attn.c_attn.weight.array) + nodes.append(tree.transformer.blocks.stacked.attn.c_proj.weight.array) + + # apply on MLP + nodes.append(tree.transformer.blocks.stacked.mlp.c_fc.weight.array) + nodes.append(tree.transformer.blocks.stacked.mlp.c_proj.weight.array) + + return nodes + + # apply weight decay when necessary + params = eqx.tree_at( + where=apply_weight_decay, + pytree=params, + replace_fn=lambda _: True, + ) + + return params + + gpt_config = Gpt2Config() + Vocab = hax.Axis("vocab", 100) + model = gpt_config.build(Vocab, key=jrandom.PRNGKey(0)) + string_list_config = AdamConfig( + weight_decay_modules=[ + "attn.c_attn.weight", + "attn.c_proj.weight", + "mlp.c_fc.weight", + "mlp.c_proj.weight", + "token_embeddings", + "position_embeddings", + ] + ) + regex_config = AdamConfig( + weight_decay_modules=r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings", + ) + # masking using `equinox.tree_at` + true_mask = tree_at_mask(model) + # masking using list of module path + list_string_mask = string_list_config.build_weight_decay_mask()(model) + + regex_mask = regex_config.build_weight_decay_mask()(model) + + assert eqx.tree_equal(list_string_mask, true_mask) + assert eqx.tree_equal(regex_mask, true_mask)