
Commit

Remove attention sinks
casper-hansen committed Oct 6, 2023
1 parent 69733d2 · commit b13e2a8
Showing 1 changed file with 3 additions and 11 deletions.
awq/modules/fused/cache.py (14 changes: 3 additions & 11 deletions)
@@ -1,13 +1,11 @@
 import torch
 
 class WindowedCache:
-    def __init__(self, cache_v_shape, cache_k_shape, device, attention_sinks=4):
+    def __init__(self, cache_v_shape, cache_k_shape, device):
         """
         The window size is the same as the max_new_tokens. The window will
         automatically roll once max_new_tokens is exceeded.
         """
-        self.attention_sinks = attention_sinks
-
         # [batch_size, n_kv_heads, max_seq_len, head_dim]
         self.v = torch.zeros(cache_v_shape).to(device).half()
         # [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor]
@@ -25,15 +23,9 @@ def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen):
         self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store
 
     def roll_kv(self, roll_len, start_pos):
-        """
-        With sink=0, roll_len=3, and [A,B,C,D,E] we get [D,E,F,G,H]
-        With sink=1, roll_len=3, and [A,B,C,D,E] we get [A,E,F,G,H]
-        With sink=2, roll_len=3, and [A,B,C,D,E] we get [A,B,F,G,H]
-        With sink=3, roll_len=3, and [A,B,C,D,E] we get [A,B,C,G,H]
-        """
         # Roll only the necessary part of the cache to the left
-        self.v[:, :, self.attention_sinks:-roll_len+self.attention_sinks, :] = self.v[:, :, roll_len:, :]
-        self.k[:, :, :, self.attention_sinks:-roll_len+self.attention_sinks, :] = self.k[:, :, :, roll_len:, :]
+        self.v[:, :, :-roll_len, :] = self.v[:, :, roll_len:, :]
+        self.k[:, :, :, :-roll_len, :] = self.k[:, :, :, roll_len:, :]
 
         # Zero out the new part
         self.v[:, :, -roll_len:, :] = 0
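
For context, the removed version kept the first self.attention_sinks cache positions in place and shifted only the remainder, while the simplified version shifts the entire window. Below is a rough sketch of the new roll behavior on a toy value cache; the shapes are made up for illustration and are not the real cache_v_shape used by WindowedCache.

import torch

# Toy cache: [batch_size, n_kv_heads, max_seq_len, head_dim]
batch_size, n_kv_heads, max_seq_len, head_dim = 1, 1, 8, 2
v = torch.arange(max_seq_len, dtype=torch.float16).reshape(1, 1, max_seq_len, 1)
v = v.expand(batch_size, n_kv_heads, max_seq_len, head_dim).clone()

roll_len = 3
# Roll only the necessary part of the cache to the left (no sink tokens kept);
# clone() avoids copying from an overlapping view in this standalone sketch.
v[:, :, :-roll_len, :] = v[:, :, roll_len:, :].clone()
# Zero out the new part
v[:, :, -roll_len:, :] = 0

print(v[0, 0, :, 0])  # tensor([3., 4., 5., 6., 7., 0., 0., 0.], dtype=torch.float16)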
