Merge pull request #266 from lakshith-403/LoRA
Showing 5 changed files with 554 additions and 0 deletions.
@@ -0,0 +1,130 @@
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from labml_nn.transformers.LoRA import Linear, Embedding

tokenizer = AutoTokenizer.from_pretrained("gpt2")

config = {
    "layer_norm_epsilon": 1e-05,
    "n_embd": 768,
    "n_head": 12,
    "n_layer": 12,
    "n_positions": 1024,
    "vocab_size": 50257,
    "device": "cuda"
}


class FFN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.c_fc = Linear(config['n_embd'], dim, r=32, bias=True)
        self.c_proj = Linear(dim, config['n_embd'], r=32, bias=True)
        self.act = nn.functional.gelu  # exact gelu; the original GPT-2 uses the tanh approximation ("gelu_new")

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        return hidden_states


class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_dim = config['n_embd']
        self.num_heads = config['n_head']
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim

        self.c_att = Linear(config['n_embd'], config['n_embd'] * 3, r=32, bias=True)
        self.c_proj = Linear(config['n_embd'], config['n_embd'], r=32, bias=True)

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits the hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, hidden_states):
        batch_size, seq_length, _ = hidden_states.size()

        # Single projection to query, key, value; split along the channel dimension
        query, key, value = self.c_att(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=True,  # applies the causal (triangular) mask
        )

        # Merge heads back: (batch, head, seq, head_dim) -> (batch, seq, embed_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)

        attn_output = self.c_proj(attn_output)

        return attn_output


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.pre_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
        self.attn = MultiHeadAttention()
        self.post_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
        self.ffn = FFN(config['n_embd'] * 4)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.pre_norm(hidden_states)

        attn_output = self.attn(hidden_states)

        hidden_states = attn_output + residual
        residual = hidden_states
        hidden_states = self.post_norm(hidden_states)
        feed_forward_output = self.ffn(hidden_states)
        hidden_states = feed_forward_output + residual

        return hidden_states


class GPTModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.token_embedding = Embedding(config['vocab_size'], config['n_embd'], r=32)
        self.position_embedding = Embedding(config['n_positions'], config['n_embd'], r=32)

        self.blocks = nn.ModuleList([Block() for _ in range(config['n_layer'])])

        self.final_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])

        self.lm_head = Linear(config['n_embd'], config['vocab_size'], r=32, bias=False)

    def forward(self, input_ids):
        batch_size, seq_length = input_ids.size()

        token_embeddings = self.token_embedding(input_ids)  # (batch, seq, n_embd)
        position_ids = torch.arange(seq_length, device=config['device'])  # (seq,)
        position_embeddings = self.position_embedding(position_ids)  # (seq, n_embd), broadcast over batch

        hidden_states = token_embeddings + position_embeddings

        for block in self.blocks:
            hidden_states = block(hidden_states)

        hidden_states = self.final_norm(hidden_states)

        logits = self.lm_head(hidden_states)

        return logits
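
As a rough sketch of how this model would be fine-tuned with LoRA (not part of the commit; the optimizer, learning rate, and dummy batch below are assumptions), only the low-rank `lora_a` / `lora_b` factors need to be handed to the optimizer, since the base weights are created with `requires_grad=False`:

# Hypothetical LoRA fine-tuning step; hyperparameters and data are placeholders.
import torch
from labml_nn.transformers.LoRA.GPT2 import GPTModel

model = GPTModel().to('cuda')  # config['device'] is 'cuda'
model.load_state_dict(torch.load('transformed.pth'), strict=False)

# Select only the LoRA factors; this filter also skips the LayerNorm parameters,
# which this implementation otherwise leaves trainable.
lora_params = [p for n, p in model.named_parameters() if 'lora_' in n]
optimizer = torch.optim.AdamW(lora_params, lr=3e-4)  # assumed learning rate

input_ids = torch.randint(0, 50257, (1, 128), device='cuda')  # dummy batch
logits = model(input_ids)

# Shifted next-token prediction loss
loss = torch.nn.functional.cross_entropy(
    logits[:, :-1].reshape(-1, logits.size(-1)),
    input_ids[:, 1:].reshape(-1))
loss.backward()
optimizer.step()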
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn


class Linear(nn.Module):
    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool,
            r: int,
            alpha: int = None):
        if alpha is None:
            alpha = r
        super().__init__()
        # Frozen pretrained weight
        self.weight = nn.Parameter(torch.empty((out_features, in_features)))
        self.weight.requires_grad = False

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
            self.bias.requires_grad = False
        else:
            self.bias = None

        # Trainable low-rank update: scaling * (x @ lora_a @ lora_b)
        self.scaling = alpha / r
        self.lora_a = nn.Parameter(torch.empty((in_features, r)))
        self.lora_b = nn.Parameter(torch.empty((r, out_features)))

        with torch.no_grad():
            nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)
            nn.init.zeros_(self.lora_b)  # zero init, so the update starts at 0

    def forward(self, x: torch.Tensor):
        result = nn.functional.linear(x, self.weight, bias=self.bias)

        result += (x @ self.lora_a @ self.lora_b) * self.scaling

        return result


class Embedding(nn.Module):
    def __init__(
            self,
            num_embeddings: int,
            embedding_dim: int,
            r: int,
            alpha: int = None,
    ):
        if alpha is None:
            alpha = r
        super().__init__()

        # Frozen pretrained embedding table
        self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
        self.weight.requires_grad = False

        self.scaling = alpha / r
        self.lora_a = nn.Parameter(torch.empty((num_embeddings, r)))
        self.lora_b = nn.Parameter(torch.empty((r, embedding_dim)))

        with torch.no_grad():
            nn.init.normal_(self.lora_a)
            nn.init.zeros_(self.lora_b)  # zero init, so the update starts at 0

    def forward(self, x: torch.Tensor):
        result = nn.functional.embedding(x, self.weight)
        result += (nn.functional.embedding(x, self.lora_a) @ self.lora_b) * self.scaling

        return result
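
A small usage sketch of the `Linear` module above (not part of the commit; the random base weights stand in for pretrained ones): because `lora_b` starts at zero, the layer initially reproduces the frozen base projection, and only the rank-`r` factors are trainable.

import torch
from labml_nn.transformers.LoRA import Linear

layer = Linear(in_features=768, out_features=768, bias=True, r=32)

# The base weight/bias are uninitialized placeholders (torch.empty); in practice
# they are overwritten with pretrained GPT-2 weights via load_state_dict.
with torch.no_grad():
    layer.weight.normal_(std=0.02)
    layer.bias.zero_()

x = torch.randn(2, 10, 768)
base = torch.nn.functional.linear(x, layer.weight, layer.bias)
out = layer(x)

# lora_b is zero-initialized, so the low-rank branch contributes nothing yet.
print(torch.allclose(out, base))  # True

# Only the LoRA factors are trainable.
print([n for n, p in layer.named_parameters() if p.requires_grad])  # ['lora_a', 'lora_b']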
@@ -0,0 +1,97 @@
{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
    "import torch"
   ],
   "id": "cffa3ec341b4905a",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
   ],
   "id": "c2b0b7e18394ea9e",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "model = GPTModel()\n",
    "\n",
    "state_dict = torch.load('transformed.pth')\n",
    "\n",
    "missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)\n",
    "if missing_keys:\n",
    "    print(f\"Missing keys: {missing_keys}\")\n",
    "if unexpected_keys:\n",
    "    print(f\"Unexpected keys: {unexpected_keys}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "prompt = \"hello how are you\"\n",
    "tokenized = tokenizer(prompt, return_tensors=\"pt\")\n",
    "tokenized['input_ids'] = tokenized['input_ids'].to('cuda')\n",
    "model = model.to('cuda')\n",
    "\n",
    "with torch.no_grad():\n",
    "    model.eval()\n",
    "    res = model(tokenized['input_ids'])\n",
    "\n",
    "output_ids = torch.argmax(res, dim=-1)\n",
    "for id in output_ids[0]:\n",
    "    print(tokenizer.decode(id))"
   ],
   "id": "f4f7826ec3729b66",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "",
   "id": "c12776360008a974",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
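
One thing to note about the notebook's last code cell: taking the argmax over every position prints the model's next-token guess at each prompt position rather than generating a continuation. A greedy-decoding loop, sketched below as a possible extra cell (it reuses the notebook's `model`, `tokenizer`, and `tokenized` variables and is not part of the commit), would generate text instead:

# Hypothetical greedy-decoding cell, reusing the variables defined above.
ids = tokenized['input_ids']
with torch.no_grad():
    for _ in range(20):                   # generate 20 tokens
        logits = model(ids)
        next_id = logits[0, -1].argmax()  # greedy pick at the last position
        ids = torch.cat([ids, next_id.view(1, 1)], dim=1)
print(tokenizer.decode(ids[0]))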
@@ -0,0 +1,44 @@
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

state_dict = model.state_dict()

mapping = {
    'transformer.wte.weight': 'token_embedding.weight',
    'transformer.wpe.weight': 'position_embedding.weight',
    'transformer.ln_f.weight': 'final_norm.weight',
    'transformer.ln_f.bias': 'final_norm.bias',
    'lm_head.weight': 'lm_head.weight'
}

for i in range(12):
    mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
    mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
    mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.c_att.weight'
    mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.c_att.bias'
    mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.c_proj.weight'
    mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.c_proj.bias'
    mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
    mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
    mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.c_fc.weight'
    mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.c_fc.bias'
    mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
    mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'

new_state_dict = {}
for old_key, new_key in mapping.items():
    if old_key in state_dict:
        new_state_dict[new_key] = state_dict[old_key]

# Transpose the weight matrices of the Conv1D layers so they can be used with linear layers instead
convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
                [f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
                [f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
                [f'blocks.{i}.attn.c_proj.weight' for i in range(12)])

for layer in convo_layers:
    new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)

torch.save(new_state_dict, 'transformed.pth')
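
A sanity check one could run after the conversion (not included in the commit): since the `lora_b` factors are zero-initialized, the LoRA branches contribute nothing at first, so the converted model's logits should track Hugging Face GPT-2's closely; the remaining gap comes mainly from using exact `gelu` instead of GPT-2's tanh-approximated `gelu_new`.

# Hypothetical sanity check: compare the converted model against the original
# Hugging Face GPT-2 on a short prompt.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from labml_nn.transformers.LoRA.GPT2 import GPTModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
input_ids = tokenizer("hello how are you", return_tensors="pt")["input_ids"]

with torch.no_grad():
    hf_logits = AutoModelForCausalLM.from_pretrained("gpt2").eval()(input_ids).logits

model = GPTModel().eval().to('cuda')  # config['device'] is 'cuda'
model.load_state_dict(torch.load('transformed.pth'), strict=False)

with torch.no_grad():
    logits = model(input_ids.to('cuda')).cpu()

# The LoRA branches start at zero, so any residual difference comes mainly
# from the gelu vs. gelu_new activation.
print((logits - hf_logits).abs().max())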