Add QK LayerNorm to transformer
zqevans committed Jul 10, 2024
1 parent 00d35fa commit 6706967
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions stable_audio_tools/models/transformer.py
@@ -276,7 +276,7 @@ def __init__(
         dim_context = None,
         causal = False,
         zero_init_output=True,
-        qk_norm = False,
+        qk_norm: Literal['l2', 'ln', 'none'] = 'none',
         natten_kernel_size = None
     ):
         super().__init__()
@@ -302,6 +302,10 @@ def __init__(
 
         self.qk_norm = qk_norm
 
+        if self.qk_norm == "ln":
+            self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
+            self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
+
         # Using 1d neighborhood attention
         self.natten_kernel_size = natten_kernel_size
         if natten_kernel_size is not None:
@@ -416,9 +420,12 @@ def forward(
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
 
         # Normalize q and k for cosine sim attention
-        if self.qk_norm:
+        if self.qk_norm == "l2":
             q = F.normalize(q, dim=-1)
             k = F.normalize(k, dim=-1)
+        elif self.qk_norm == "ln":
+            q = self.q_norm(q)
+            k = self.k_norm(k)
 
         if rotary_pos_emb is not None and not has_context:
             freqs, _ = rotary_pos_emb
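
Note: for anyone reading this diff in isolation, the new 'ln' option passes queries and keys through a per-head LayerNorm before attention, while the existing 'l2' option keeps the previous cosine-similarity normalization. The standalone sketch below (not part of the commit; shapes and tensor names are illustrative) mirrors what the two branches compute.

# Minimal sketch of the 'ln' (QK LayerNorm) path added above, alongside the
# pre-existing 'l2' path. The surrounding attention code is assumed.
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, heads, seq_len, dim_heads = 2, 8, 16, 64  # illustrative shapes

# Per-head LayerNorm over the head dimension, matching the diff's settings.
q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)

q = torch.randn(batch, heads, seq_len, dim_heads)
k = torch.randn(batch, heads, seq_len, dim_heads)

# 'ln': normalize each query/key vector to zero mean and unit variance
# (with a learned scale/shift), which keeps the QK^T logits well-scaled.
q_ln, k_ln = q_norm(q), k_norm(k)

# 'l2': the earlier cosine-similarity option, shown for comparison.
q_l2, k_l2 = F.normalize(q, dim=-1), F.normalize(k, dim=-1)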
