Merge pull request #272 from lakshith-403/LoRA
LoRA minor fixes
vpj authored Aug 24, 2024
2 parents 9e1b357 + 9485eec commit 789c31a
Showing 2 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion labml_nn/lora/__init__.py
@@ -79,7 +79,7 @@ def __init__(self, in_features: int, out_features: int, bias: bool,
        # Matrix $A \in \mathbb{R}^{r \times k}$
        self.lora_a = nn.Parameter(torch.empty((r, in_features)))
        # Matrix $B \in \mathbb{R}^{d \times r}$, we keep $A$ and $B$ transposed
-        self.lora_b = nn.Parameter(torch.empty((outfeatures, r)))
+        self.lora_b = nn.Parameter(torch.empty((out_features, r)))

        with torch.no_grad():
            # Initialize $A$ similar to a weight matrix in a normal linear layer
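For context, here is a minimal sketch of the LoRA linear layer these matrices belong to. It assumes the standard LoRA formulation: a frozen pretrained `weight`/`bias`, an `alpha / r` scaling factor, and a forward pass that adds the low-rank update to the frozen projection. The `forward` and initialization shown here are illustrative, not copied verbatim from the repository.

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool,
                 r: int, alpha: float = None):
        super().__init__()
        # `alpha` defaults to `r`, so the scaling starts at 1
        if alpha is None:
            alpha = r
        # Frozen pretrained weight (and optional bias)
        self.weight = nn.Parameter(torch.empty((out_features, in_features)), requires_grad=False)
        self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False) if bias else None
        # Scaling factor $\frac{\alpha}{r}$
        self.scaling = alpha / r
        # Matrix $A \in \mathbb{R}^{r \times k}$
        self.lora_a = nn.Parameter(torch.empty((r, in_features)))
        # Matrix $B \in \mathbb{R}^{d \times r}$, kept transposed
        self.lora_b = nn.Parameter(torch.empty((out_features, r)))

        with torch.no_grad():
            # $A$ is initialized like a normal linear layer weight, $B$ to zero,
            # so training starts from the unmodified pretrained weights
            nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)
            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):
        # Frozen base projection $W_0 x + b_0$
        result = nn.functional.linear(x, self.weight, bias=self.bias)
        # Low-rank update $\frac{\alpha}{r} B A x$, computed as $x A^T B^T$ on batched inputs
        return result + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling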
14 changes: 9 additions & 5 deletions labml_nn/lora/experiment.py
@@ -81,14 +81,14 @@ def _load_pretrained_weights(self):

        # Mapping (`hf: ours`) of decoder layers
        for i in range(12):
-            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
-            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
+            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.attn_norm.weight'
+            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.attn_norm.bias'
            mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.qkv_projection.weight'
            mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.qkv_projection.bias'
            mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.output_projection.weight'
            mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.output_projection.bias'
-            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
-            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
+            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.ffn_norm.weight'
+            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.ffn_norm.bias'
            mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.linear_in.weight'
            mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.linear_in.bias'
            mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.linear_out.weight'
@@ -110,7 +110,11 @@ def _load_pretrained_weights(self):
            new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)

        # Load our model. We use `strict = False` because the state dict does not have LoRA weights
-        self.model.load_state_dict(new_state_dict, strict=False)
+        missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)
+
+        # Make sure that only the LoRA weights are missing
+        assert all('lora' in key for key in missing_keys)
+        assert not unexpected_keys

    def initialize(self):
        """
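For context, the loading pattern this change arrives at, condensed into one sketch. The helper name `load_gpt2_weights` and its arguments are illustrative, not the experiment's actual API; the body mirrors the diff above.

import torch
import torch.nn as nn


def load_gpt2_weights(model: nn.Module, hf_state_dict: dict, mapping: dict, transposed: list):
    # Rename Hugging Face parameter names (`hf`) to ours
    new_state_dict = {ours: hf_state_dict[hf] for hf, ours in mapping.items()}

    # GPT-2 checkpoints store these projections as Conv1D, i.e. `(in, out)`,
    # so transpose them to the `nn.Linear` layout `(out, in)`
    for layer in transposed:
        new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)

    # `strict=False` because the pretrained checkpoint has no LoRA weights
    missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)

    # Only the LoRA matrices may be missing, and nothing from the checkpoint should be left over
    assert all('lora' in key for key in missing_keys)
    assert not unexpected_keys

Capturing `missing_keys` and `unexpected_keys` from `load_state_dict` (instead of discarding them, as the old line did) is what makes the two assertions possible: they catch both a model parameter that was silently left uninitialized and a checkpoint tensor that matched nothing, e.g. due to a typo in the mapping.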
