From 946e0311b5446c9653b0b2802ca61d77ff683291 Mon Sep 17 00:00:00 2001 From: Vaibhav Adlakha <32997732+vaibhavad@users.noreply.github.com> Date: Tue, 9 Apr 2024 12:46:15 -0400 Subject: [PATCH] Update tutorial.md --- docs/_pages/tutorial.md | 50 ++++++++++++----------------------------- 1 file changed, 14 insertions(+), 36 deletions(-) diff --git a/docs/_pages/tutorial.md b/docs/_pages/tutorial.md index 7bc0218..e2b85b2 100644 --- a/docs/_pages/tutorial.md +++ b/docs/_pages/tutorial.md @@ -1,54 +1,38 @@ ---- -title: "LLM2Vec Tutorial: Steps for transforming any decoder-only model into a text encoder" -permalink: /tutorial/ ---- +# LLM2Vec Tutorial: Steps for transforming any decoder-only model into a text encoder LLM2Vec consists of 3 simple steps to transform decoder-only LLMs into text encoders: 1) enabling bidirectional attention, 2) training with masked next token prediction, and 3) unsupervised contrastive learning. The model can be further fine-tuned with supervised data. Here, we provide a tutorial on how to use the LlaMA models. This tutorial will focus on the first two steps. After these steps, the model can be trained for unsupervised or supervised contrastive learning like any other encoder model. +In this tutorial, we will transform LlaMA models into text encoders, however, transforming Mistral will require similar steps. We will focus on modifying the flash attention implementation as it requires the least changes in the codebase, and the implementation is consistent across models and transformers versions. Our tutorial is based on transformers version 4.39.3. + ## 1) Enabling Bidirectional Attention -- add a conceptual figure here +TODO:add a conceptual figure here - A decoder-only causal LLM consists of multiple decoder layers, each of which has a self-attention mechanism. We start bottoms-up by first modifying the attention mechanism to be bidirectional. -In order to be able to use the bidirectional attentions with all sorts of attentions, we need to create new LLaMA attention classes: +HuggingFace implements three attention mechanisms for Llama and Mistral models - Eager, SDPA, and Flash Attention. Here, we only modify the flash attention implementation. In order to be able to use the bidirectional attention, we need to create new LLaMA flash attention class: ```python -class ModifiedLlamaAttention(LlamaAttention): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.is_causal = False # Initially `True` in transformers implementation - - class ModifiedLlamaFlashAttention2(LlamaFlashAttention2): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.is_causal = False # Initially `True` in transformers implementation - -class ModifiedLlamaSdpaAttention(LlamaSdpaAttention): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.is_causal = False # Initially `True` in transformers implementation - LLAMA_ATTENTION_CLASSES = { - "eager": ModifiedLlamaAttention, # Initially, `LlamaAttention' + "eager": LlamaAttention, "flash_attention_2": ModifiedLlamaFlashAttention2, # Initially, `LlamaFlashAttention2' - "sdpa": ModifiedLlamaSdpaAttention, # Initially, `LlamaSdpaAttention' + "sdpa": LlamaSdpaAttention, } ``` -For now, we have changed all sorts of attention classes to non-causal (i.e., bidirectional). Next, we need to modify the decoder layer to use these new attention classes. the `__init__` function is directly copied from the `transformers` implementation of `LlamaDecoderLayer`. As `LLAMA_ATTENTION_CLASSES` point to the new attention classes, the decoder layer will use bidirectional attentions. +We have changed flash attention to be non-causal (i.e., bidirectional). Next, we need to modify the decoder layer to use this new attention classes. the `__init__` function is directly copied from the `transformers` implementation of `LlamaDecoderLayer`. As `flash_attention_2` in `LLAMA_ATTENTION_CLASSES` points to the new flash attention class, the decoder layer will use bidirectional attention when initialized with `flash_attention_2`. ```python class ModifiedLlamaDecoderLayer(LlamaDecoderLayer): def __init__(self, config: LlamaConfig, layer_idx: int): - nn.Module.__init__(self) + nn.Module.__init__(self) # Initially, super().__init__() self.hidden_size = config.hidden_size self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) @@ -57,7 +41,7 @@ class ModifiedLlamaDecoderLayer(LlamaDecoderLayer): self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) ``` -Finally, we need to modify the model class to use the new decoder layer. We create a new model class `LlamaBiModel` that inherits from `LlamaModel` and uses the new `ModifiedLlamaDecoderLayer` in its `__init__` function. Everything else remains the same as the original implementation of `LlamaModel`. +Finally, we need to modify the main model class to use the new decoder layer. We create a new model class `LlamaBiModel` that inherits from `LlamaModel` and uses the new `ModifiedLlamaDecoderLayer` in its `__init__` function. Everything else remains the same as the original implementation of `LlamaModel`. ```python class LlamaBiModel(LlamaModel): ``` @@ -66,7 +50,7 @@ We first have to use the `ModifiedLlamaDecoderLayer` in our `LlamaBiModel` class ```python class LlamaBiModel(LlamaModel): def __init__(self, config): - LlamaPreTrainedModel.__init__(self, config) + LlamaPreTrainedModel.__init__(self, config) # Initially, super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -74,19 +58,13 @@ class LlamaBiModel(LlamaModel): self.layers = nn.ModuleList( [ModifiedLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] # Initially, `LlamaDecoderLayer(config, layer_idx)` ) - self._use_sdpa = config._attn_implementation == "sdpa" - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.gradient_checkpointing = False + self.post_init() ``` - -- talk about `from .attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_attention_mask` and the LlamaBiModel forward function. - - -This is not sufficient, as transformers models use specific attention mask generation functions, `_prepare_4d_attention_mask_for_sdpa` and `_prepare_4d_attention_mask`, in the `forward` call of the `LlamaModel`. We now want to manipulate these function... - + +That's it! We have successfully created a bidirectional LLaMA model. We can now use this model for training with masked next token prediction. ## 2) Masked Next Token Prediction (MNTP)