Merge pull request #266 from lakshith-403/LoRA
vpj authored Jul 31, 2024
2 parents 89a3ae8 + bc32b50 commit 957ade6
Showing 5 changed files with 554 additions and 0 deletions.
130 changes: 130 additions & 0 deletions labml_nn/transformers/LoRA/GPT2.py
@@ -0,0 +1,130 @@
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from labml_nn.transformers.LoRA import Linear, Embedding

tokenizer = AutoTokenizer.from_pretrained("gpt2")

config = {
"layer_norm_epsilon": 1e-05,
"n_embd": 768,
"n_head": 12,
"n_layer": 12,
"n_positions": 1024,
"vocab_size": 50257,
"device": "cuda"
}


class FFN(nn.Module):
def __init__(self, dim):
super().__init__()
self.c_fc = Linear(config['n_embd'], dim, r=32, bias=True)
self.c_proj = Linear(dim, config['n_embd'], r=32, bias=True)
        self.act = nn.GELU(approximate='tanh')  # GPT-2 uses the tanh approximation of GELU ("gelu_new" in Hugging Face)

def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
return hidden_states


class MultiHeadAttention(nn.Module):
def __init__(self):
super().__init__()
self.embed_dim = config['n_embd']
self.num_heads = config['n_head']
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim

self.c_att = Linear(config['n_embd'], config['n_embd'] * 3, r=32, bias=True)
self.c_proj = Linear(config['n_embd'], config['n_embd'], r=32, bias=True)

def _split_heads(self, tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)

def forward(self, hidden_states):
batch_size, seq_length, _ = hidden_states.size()

query, key, value = self.c_att(hidden_states).split(self.split_size, dim=2)

query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=True, # for the triangular mask
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)

attn_output = self.c_proj(attn_output)

return attn_output


class Block(nn.Module):
def __init__(self):
super().__init__()
self.pre_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
self.attn = MultiHeadAttention()
self.post_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
self.ffn = FFN(config['n_embd'] * 4)

def forward(self, hidden_states):
residual = hidden_states
hidden_states = self.pre_norm(hidden_states)

attn_output = self.attn(hidden_states)

hidden_states = attn_output + residual
residual = hidden_states
hidden_states = self.post_norm(hidden_states)
feed_forward_output = self.ffn(hidden_states)
hidden_states = feed_forward_output + residual

return hidden_states


class GPTModel(nn.Module):
def __init__(self):
super().__init__()

self.token_embedding = Embedding(config['vocab_size'], config['n_embd'], r=32)
self.position_embedding = Embedding(config['n_positions'], config['n_embd'], r=32)

self.blocks = nn.ModuleList([Block() for _ in range(config['n_layer'])])

self.final_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])

self.lm_head = Linear(config['n_embd'], config['vocab_size'], r=32, bias=False)

    def forward(self, input_ids):
        batch_size, seq_length = input_ids.size()

        token_embeddings = self.token_embedding(input_ids)  # (B, T, C)
        position_ids = torch.arange(seq_length, device=config['device'])  # (T,)
        position_embeddings = self.position_embedding(position_ids)  # (T, C), broadcast across the batch

hidden_states = token_embeddings + position_embeddings

for block in self.blocks:
hidden_states = block(hidden_states)

hidden_states = self.final_norm(hidden_states)

logits = self.lm_head(hidden_states)

return logits
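
A quick sanity check for the model above (a sketch, not part of the commit): it assumes a CUDA device, since the config hard-codes device="cuda", and it only verifies tensor shapes, because the frozen weights stay uninitialized until a checkpoint such as transformed.pth from load_hf.py is loaded.

# Sketch: verify that a forward pass returns logits of shape (batch, seq, vocab).
import torch

from labml_nn.transformers.LoRA.GPT2 import GPTModel, config

model = GPTModel().to(config['device'])
input_ids = torch.randint(0, config['vocab_size'], (1, 8), device=config['device'])

with torch.no_grad():
    logits = model(input_ids)

assert logits.shape == (1, 8, config['vocab_size'])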
68 changes: 68 additions & 0 deletions labml_nn/transformers/LoRA/__init__.py
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn


class Linear(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
r: int,
alpha: int = None):
if alpha is None:
alpha = r
super().__init__()
        # The pretrained weight is kept frozen; only the low-rank factors below are trained
        self.weight = nn.Parameter(torch.empty((out_features, in_features)))
        self.weight.requires_grad = False

if bias:
self.bias = nn.Parameter(torch.empty(out_features))
self.bias.requires_grad = False
else:
self.bias = None

        # Low-rank update x @ A @ B, scaled by alpha / r
        self.scaling = alpha / r
        self.lora_a = nn.Parameter(torch.empty((in_features, r)))
        self.lora_b = nn.Parameter(torch.empty((r, out_features)))

        with torch.no_grad():
            # A gets the default nn.Linear initialization; B starts at zero so the update is initially zero
            nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)
            nn.init.zeros_(self.lora_b)

def forward(self, x: torch.Tensor):
result = nn.functional.linear(x, self.weight, bias=self.bias)

result += (x @ self.lora_a @ self.lora_b) * self.scaling

return result


class Embedding(nn.Module):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
r: int,
alpha: int = None,
):
if alpha is None:
alpha = r
super().__init__()

self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
self.weight.requires_grad = False

self.scaling = alpha / r
self.lora_a = nn.Parameter(torch.empty((num_embeddings, r)))
self.lora_b = nn.Parameter(torch.empty((r, embedding_dim)))

with torch.no_grad():
nn.init.normal_(self.lora_a)
nn.init.zeros_(self.lora_b)

def forward(self, x: torch.Tensor):
result = nn.functional.embedding(x, self.weight)
result += (nn.functional.embedding(x, self.lora_a) @ self.lora_b) * self.scaling

return result
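
As a small illustration (a sketch, not part of the commit) of how these adapters are meant to be trained: the full-rank weight and bias are frozen, so an optimizer built from the parameters that require gradients will only update the lora_a / lora_b factors.

# Sketch: only the low-rank LoRA factors are trainable.
from labml_nn.transformers.LoRA import Linear

layer = Linear(in_features=768, out_features=768, bias=True, r=32)
trainable = [name for name, p in layer.named_parameters() if p.requires_grad]
print(trainable)  # ['lora_a', 'lora_b'] -- the frozen weight and bias are excluded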
97 changes: 97 additions & 0 deletions labml_nn/transformers/LoRA/experiment.ipynb
@@ -0,0 +1,97 @@
{
"cells": [
{
"metadata": {},
"cell_type": "code",
"source": [
"from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
"import torch"
],
"id": "cffa3ec341b4905a",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
],
"id": "c2b0b7e18394ea9e",
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true
},
"source": [
"model = GPTModel()\n",
"\n",
"state_dict = torch.load('transformed.pth')\n",
"\n",
"missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)\n",
"if missing_keys:\n",
" print(f\"Missing keys: {missing_keys}\")\n",
"if unexpected_keys:\n",
" print(f\"Unexpected keys: {unexpected_keys}\")"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"prompt = \"hello how are you\"\n",
"tokenized = tokenizer(prompt, return_tensors=\"pt\")\n",
"tokenized['input_ids'] = tokenized['input_ids'].to('cuda')\n",
"model = model.to('cuda')\n",
"\n",
"with torch.no_grad():\n",
" model.eval()\n",
" res = model(tokenized['input_ids'])\n",
"\n",
"output_ids = torch.argmax(res, dim=-1)\n",
"for id in output_ids[0]:\n",
" print(tokenizer.decode(id))"
],
"id": "f4f7826ec3729b66",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "",
"id": "c12776360008a974",
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
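
The notebook decodes the model's top prediction at every input position. For an autoregressive continuation, a minimal greedy-decoding loop could look like the following sketch (not part of the commit); it assumes the model and tokenizer are already set up and moved to CUDA as in the cells above.

# Sketch: greedy decoding with the model and tokenizer from the notebook (both on CUDA).
input_ids = tokenizer("hello how are you", return_tensors="pt")['input_ids'].to('cuda')

model.eval()
with torch.no_grad():
    for _ in range(20):
        logits = model(input_ids)                                 # (1, T, vocab)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # pick the most likely next token
        input_ids = torch.cat([input_ids, next_id], dim=1)

print(tokenizer.decode(input_ids[0]))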
44 changes: 44 additions & 0 deletions labml_nn/transformers/LoRA/load_hf.py
@@ -0,0 +1,44 @@
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

state_dict = model.state_dict()

mapping = {
'transformer.wte.weight': 'token_embedding.weight',
'transformer.wpe.weight': 'position_embedding.weight',
'transformer.ln_f.weight': 'final_norm.weight',
'transformer.ln_f.bias': 'final_norm.bias',
'lm_head.weight': 'lm_head.weight'
}

for i in range(12):
mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.c_att.weight'
mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.c_att.bias'
mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.c_proj.weight'
mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.c_proj.bias'
mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.c_fc.weight'
mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.c_fc.bias'
mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'

new_state_dict = {}
for old_key, new_key in mapping.items():
if old_key in state_dict:
new_state_dict[new_key] = state_dict[old_key]

# Hugging Face GPT-2 stores these as Conv1D weights with shape (in_features, out_features);
# transpose them to (out_features, in_features) so they match the linear layers used here
convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
[f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
[f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
[f'blocks.{i}.attn.c_proj.weight' for i in range(12)])

for layer in convo_layers:
new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)

torch.save(new_state_dict, 'transformed.pth')
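
A quick check (a sketch, not part of the commit) that the converted checkpoint lines up with GPTModel: loading it with strict=False should report no unexpected keys, and the only missing keys should be the LoRA factors, which are initialized inside the model itself.

# Sketch: the converted state dict should cover every non-LoRA parameter of GPTModel.
import torch

from labml_nn.transformers.LoRA.GPT2 import GPTModel

model = GPTModel()
missing, unexpected = model.load_state_dict(torch.load('transformed.pth'), strict=False)

assert not unexpected
assert all('lora_' in key for key in missing)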
[Diff for the fifth changed file was not loaded on this page.]
