---
title: "LLM2Vec Tutorial: Steps for transforming any decoder-only model into a text encoder"
permalink: /tutorial/
---
# LLM2Vec Tutorial: Steps for transforming any decoder-only model into a text encoder

LLM2Vec consists of three simple steps to transform decoder-only LLMs into text encoders: 1) enabling bidirectional attention, 2) training with masked next token prediction, and 3) unsupervised contrastive learning. The model can be further fine-tuned with supervised data. Here, we provide a tutorial on applying these steps to LLaMA models.

This tutorial will focus on the first two steps. After these steps, the model can be trained with unsupervised or supervised contrastive learning like any other encoder model.

In this tutorial, we will transform LLaMA models into text encoders; transforming Mistral models requires similar steps. We focus on modifying the flash attention implementation, as it requires the fewest changes to the codebase and is consistent across models and `transformers` versions. Our tutorial is based on `transformers` version 4.39.3.
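Since the Llama implementation changes across `transformers` releases, it is worth confirming the installed version before copying any code. A trivial check:

```python
import transformers

# The snippets in this tutorial assume transformers 4.39.3; other versions may
# organize the Llama attention classes differently.
print(transformers.__version__)  # expected: 4.39.3
```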

## 1) Enabling Bidirectional Attention

TODO: add a conceptual figure here

<!-- Will work for both Llama and Mistral -->
<!-- mention which transformer version is used for this -->

A decoder-only causal LLM consists of multiple decoder layers, each of which has a self-attention mechanism. We work bottom-up, first modifying the attention mechanism to be bidirectional.
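To make the goal concrete, the sketch below contrasts a causal attention mask with the bidirectional mask we want (the sequence length of 4 is an arbitrary illustration):

```python
import torch

seq_len = 4  # arbitrary example length

# Causal mask: position i may only attend to positions j <= i (lower triangular).
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Bidirectional mask: every position may attend to every other position.
bidirectional_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)

print(causal_mask)         # lower-triangular True/False pattern
print(bidirectional_mask)  # all True
```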

HuggingFace's `transformers` library implements three attention mechanisms for Llama and Mistral models: eager, SDPA, and Flash Attention 2. Here, we only modify the flash attention implementation. To be able to use bidirectional attention, we need to create a new LLaMA flash attention class:
```python
# Imports used throughout this tutorial (transformers 4.39.3)
import torch.nn as nn

from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaFlashAttention2,
    LlamaMLP,
    LlamaModel,
    LlamaPreTrainedModel,
    LlamaRMSNorm,
    LlamaSdpaAttention,
)


class ModifiedLlamaFlashAttention2(LlamaFlashAttention2):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_causal = False  # Initially `True` in the transformers implementation


LLAMA_ATTENTION_CLASSES = {
    "eager": LlamaAttention,
    "flash_attention_2": ModifiedLlamaFlashAttention2,  # Initially `LlamaFlashAttention2`
    "sdpa": LlamaSdpaAttention,
}
```
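As a quick sanity check, we can confirm that the registry now dispatches `flash_attention_2` to the modified class and that its `is_causal` flag is off. The small `LlamaConfig` below is an arbitrary toy configuration; actually running flash attention additionally requires the `flash-attn` package and a CUDA GPU.

```python
# Toy configuration (arbitrary small sizes), only used to instantiate the module.
config = LlamaConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)

attn = LLAMA_ATTENTION_CLASSES["flash_attention_2"](config=config, layer_idx=0)
print(type(attn).__name__)  # ModifiedLlamaFlashAttention2
print(attn.is_causal)       # False
```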
We have changed the flash attention implementation to be non-causal (i.e., bidirectional). Next, we need to modify the decoder layer to use this new attention class. The `__init__` function is copied directly from the `transformers` implementation of `LlamaDecoderLayer`. As `flash_attention_2` in `LLAMA_ATTENTION_CLASSES` points to the new flash attention class, the decoder layer will use bidirectional attention when initialized with `flash_attention_2`.
```python
class ModifiedLlamaDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        nn.Module.__init__(self)  # Initially, super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
```
Finally, we need to modify the main model class to use the new decoder layer. We create a new model class `LlamaBiModel` that inherits from `LlamaModel` and uses the new `ModifiedLlamaDecoderLayer` in its `__init__` function. Everything else remains the same as the original implementation of `LlamaModel`.
```python
class LlamaBiModel(LlamaModel):
```
We first have to use the `ModifiedLlamaDecoderLayer` in our `LlamaBiModel` class:
```python
class LlamaBiModel(LlamaModel):
    def __init__(self, config):
        LlamaPreTrainedModel.__init__(self, config)  # Initially, super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [ModifiedLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]  # Initially, `LlamaDecoderLayer(config, layer_idx)`
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False

        self.post_init()
```
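Since no parameter names have changed, pretrained LLaMA weights can be loaded into `LlamaBiModel` as usual. Below is a minimal loading sketch: the checkpoint name is only an example (any Llama checkpoint you have access to works), and flash attention 2 requires the `flash-attn` package, a CUDA GPU, and fp16/bf16 weights.

```python
import torch

model = LlamaBiModel.from_pretrained(
    "meta-llama/Llama-2-7b-hf",               # example checkpoint; gated, requires access
    torch_dtype=torch.bfloat16,               # flash attention 2 expects fp16/bf16
    attn_implementation="flash_attention_2",  # selects ModifiedLlamaFlashAttention2
).to("cuda")
```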

That's it! We have successfully created a bidirectional LLaMA model. We can now use this model for training with masked next token prediction.
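As an optional sanity check (building on the loading sketch above), bidirectionality can be verified directly: change only the last token of an input, and the representation of the first token should change, which is impossible under causal attention.

```python
import torch

# Two inputs that differ only in their final token (arbitrary token ids).
tokens = torch.tensor([[1, 306, 4966, 263]], device=model.device)
tokens_modified = torch.tensor([[1, 306, 4966, 1234]], device=model.device)

with torch.no_grad():
    h1 = model(tokens).last_hidden_state
    h2 = model(tokens_modified).last_hidden_state

# Under causal attention, the first position cannot see the last one, so its
# representation would stay the same; with bidirectional attention it changes.
print(torch.allclose(h1[0, 0], h2[0, 0]))  # expected: False for LlamaBiModel
```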


## 2) Masked Next Token Prediction (MNTP)
