From b13e2a859735530953b1df109a5cab471f46a852 Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Fri, 6 Oct 2023 15:18:52 +0000
Subject: [PATCH] Remove attention sinks

---
 awq/modules/fused/cache.py | 14 +++-----------
 1 file changed, 3 insertions(+), 11 deletions(-)

diff --git a/awq/modules/fused/cache.py b/awq/modules/fused/cache.py
index 57816ac0..dff47423 100644
--- a/awq/modules/fused/cache.py
+++ b/awq/modules/fused/cache.py
@@ -1,13 +1,11 @@
 import torch
 
 class WindowedCache:
-    def __init__(self, cache_v_shape, cache_k_shape, device, attention_sinks=4):
+    def __init__(self, cache_v_shape, cache_k_shape, device):
         """
         The window size is the same as the max_new_tokens. The window will
         automatically roll once max_new_tokens is exceeded.
         """
-        self.attention_sinks = attention_sinks
-
         # [batch_size, n_kv_heads, max_seq_len, head_dim]
         self.v = torch.zeros(cache_v_shape).to(device).half()
         # [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor]
@@ -25,15 +23,9 @@ def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen):
         self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store
 
     def roll_kv(self, roll_len, start_pos):
-        """
-        With sink=0, roll_len=3, and [A,B,C,D,E] we get [D,E,F,G,H]
-        With sink=1, roll_len=3, and [A,B,C,D,E] we get [A,E,F,G,H]
-        With sink=2, roll_len=3, and [A,B,C,D,E] we get [A,B,F,G,H]
-        With sink=3, roll_len=3, and [A,B,C,D,E] we get [A,B,C,G,H]
-        """
         # Roll only the necessary part of the cache to the left
-        self.v[:, :, self.attention_sinks:-roll_len+self.attention_sinks, :] = self.v[:, :, roll_len:, :]
-        self.k[:, :, :, self.attention_sinks:-roll_len+self.attention_sinks, :] = self.k[:, :, :, roll_len:, :]
+        self.v[:, :, :-roll_len, :] = self.v[:, :, roll_len:, :]
+        self.k[:, :, :, :-roll_len, :] = self.k[:, :, :, roll_len:, :]
 
         # Zero out the new part
         self.v[:, :, -roll_len:, :] = 0