From 66b2e2332b8279dc11d944fc77a32b2b58a023fa Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Fri, 6 Oct 2023 13:59:18 +0000 Subject: [PATCH 1/8] Create cache for fused modules --- awq/modules/fused/cache.py | 45 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 awq/modules/fused/cache.py diff --git a/awq/modules/fused/cache.py b/awq/modules/fused/cache.py new file mode 100644 index 00000000..c2f93a0a --- /dev/null +++ b/awq/modules/fused/cache.py @@ -0,0 +1,45 @@ +import torch + +class WindowedCache: + def __init__(self, cache_v_shape, cache_k_shape, device, attention_sinks=4): + """ + The window size is the same as the max_new_tokens. The window will + automatically roll once max_new_tokens is exceeded. + """ + self.attention_sinks = attention_sinks + + # [batch_size, n_kv_heads, max_seq_len, head_dim] + self.v = torch.zeros(cache_v_shape).to(device).half() + # [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor] + self.k = torch.zeros(cache_k_shape).to(device).half() + + def get_kv(self, batch_size, start_pos, seqlen, head_dim): + xv = self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous() + xk = self.k[:batch_size, :, :, : start_pos + seqlen, :].transpose(2, 3).contiguous() + xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous() + + return xv, xk + + def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen): + self.v[:batch_size, :, start_pos : start_pos + seqlen, :] = values_store + self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store + + def roll_kv(self, roll_len, start_pos): + """ + For example, with roll_len=3 and [A,B,C,D,E] we get [D,E,F,G,H] + With sink=1, roll_len=3, and [A,B,C,D,E] we get [A,E,F,G,H] + """ + # Roll only the necessary part of the cache to the left + self.v[:, :, self.attention_sinks:-roll_len+self.attention_sinks, :] = self.v[:, :, roll_len:, :] + self.k[:, :, :, self.attention_sinks:-roll_len+self.attention_sinks, :] = self.k[:, :, :, roll_len:, :] + + # Zero out the new part + self.v[:, :, -roll_len:, :] = 0 + self.k[:, :, :, -roll_len:, :] = 0 + + return start_pos - roll_len + + def to(self, device): + self.k = self.k.to(device) + self.v = self.v.to(device) + \ No newline at end of file From e46703d815411aa930eaa7996b21c89c7e8533c8 Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Fri, 6 Oct 2023 13:59:41 +0000 Subject: [PATCH 2/8] Use new WindowedCache --- awq/modules/fused/attn.py | 148 +++++++++++++++++--------------------- 1 file changed, 66 insertions(+), 82 deletions(-) diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py index ce7b431f..de1d7b06 100644 --- a/awq/modules/fused/attn.py +++ b/awq/modules/fused/attn.py @@ -2,8 +2,8 @@ import math import torch import torch.nn as nn -import awq_inference_engine from torch.nn import functional as F +from awq.modules.fused.cache import WindowedCache try: import ft_inference_engine @@ -25,11 +25,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] return freqs_cis.view(*shape) -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -): +def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor): xq_ = torch.view_as_complex( xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous() ) @@ -65,6 +61,49 @@ def build_alibi_bias( slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1) return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype) +def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim): + if attention_shapes is not None: + attention_shapes = attention_shapes + + elif n_kv_heads == 0: + attention_shapes = { + # following fastertransformer definition + "cache_v": (cache_batch_size, n_heads, max_seq_len, head_dim,), + # 8: pack 8 fp16 in FT, if fp32 then use 4 + "cache_k": (cache_batch_size, n_heads, head_dim // 8, max_seq_len, 8,), + "xqkv_view": (-1, n_heads, head_dim), + "xq_slice": lambda xqkv: xqkv[:, :, 0], + "xk_slice": lambda xqkv: xqkv[:, :, 1], + "xv_slice": lambda xqkv: xqkv[:, :, 2], + "xq_view": (n_heads, head_dim), + "xk_view": (n_heads, head_dim), + "xv_view": (n_heads, head_dim), + "xk_reshape": (n_heads, head_dim // 8, 8), + "single_xq_view": (n_heads, head_dim), + "single_xk_view": (n_heads, head_dim), + "single_xv_view": (n_heads, head_dim) + } + + else: + attention_shapes = { + # following fastertransformer definition + "cache_v": (cache_batch_size, n_kv_heads, max_seq_len, head_dim,), + # 8: pack 8 fp16 in FT, if fp32 then use 4 + "cache_k": (cache_batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8,), + "xqkv_view": (n_heads + n_kv_heads * 2, head_dim), + "xq_slice": lambda xqkv: xqkv[:, :, 0 : n_heads], + "xk_slice": lambda xqkv: xqkv[:, :, n_heads : (n_heads + n_kv_heads)], + "xv_slice": lambda xqkv: xqkv[:, :, -n_kv_heads :], + "xq_view": (n_heads, head_dim), + "xk_view": (n_kv_heads, head_dim), + "xv_view": (n_kv_heads, head_dim), + "xk_reshape": (n_kv_heads, head_dim // 8, 8), + "single_xq_view": (n_heads, head_dim), + "single_xk_view": (n_kv_heads, head_dim), + "single_xv_view": (n_kv_heads, head_dim) + } + + return attention_shapes class QuantAttentionFused(nn.Module): def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len, @@ -81,9 +120,15 @@ def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max self.use_alibi = use_alibi self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1")) self.max_seq_len = max_seq_len - self.attention_shapes = self._get_attention_shapes(attention_shapes, max_seq_len) - self.cache_v = ( torch.zeros(self.attention_shapes["cache_v"]).to(dev).half() ) - self.cache_k = ( torch.zeros(self.attention_shapes["cache_k"]).to(dev).half() ) + + # attention shapes for self attention + self.attention_shapes = get_attention_shapes( + attention_shapes, max_seq_len, self.cache_batch_size, n_heads, n_kv_heads, self.head_dim + ) + # cache store that rolls cache + self.cache = WindowedCache( + self.attention_shapes["cache_v"], self.attention_shapes["cache_k"], dev + ) if use_alibi: alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len) @@ -100,55 +145,7 @@ def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max self.alibi_slopes = None self.is_neox = True - def _get_attention_shapes(self, attention_shapes, max_seq_len): - if attention_shapes is not None: - attention_shapes = attention_shapes - - elif self.n_kv_heads == 0: - attention_shapes = { - # following fastertransformer definition - "cache_v": (self.cache_batch_size, self.n_heads, max_seq_len, self.head_dim,), - # 8: pack 8 fp16 in FT, if fp32 then use 4 - "cache_k": (self.cache_batch_size, self.n_heads, self.head_dim // 8, max_seq_len, 8,), - "xqkv_view": (-1, self.n_heads, self.head_dim), - "xq_slice": lambda xqkv: xqkv[:, :, 0], - "xk_slice": lambda xqkv: xqkv[:, :, 1], - "xv_slice": lambda xqkv: xqkv[:, :, 2], - "xq_view": (self.n_heads, self.head_dim), - "xk_view": (self.n_heads, self.head_dim), - "xv_view": (self.n_heads, self.head_dim), - "xk_reshape": (self.n_heads, self.head_dim // 8, 8), - "single_xq_view": (self.n_heads, self.head_dim), - "single_xk_view": (self.n_heads, self.head_dim), - "single_xv_view": (self.n_heads, self.head_dim) - } - - else: - attention_shapes = { - # following fastertransformer definition - "cache_v": (self.cache_batch_size, self.n_kv_heads, max_seq_len, self.head_dim,), - # 8: pack 8 fp16 in FT, if fp32 then use 4 - "cache_k": (self.cache_batch_size, self.n_kv_heads, self.head_dim // 8, max_seq_len, 8,), - "xqkv_view": (self.n_heads + self.n_kv_heads * 2, self.head_dim), - "xq_slice": lambda xqkv: xqkv[:, :, 0 : self.n_heads], - "xk_slice": lambda xqkv: xqkv[:, :, self.n_heads : (self.n_heads + self.n_kv_heads)], - "xv_slice": lambda xqkv: xqkv[:, :, -self.n_kv_heads :], - "xq_view": (self.n_heads, self.head_dim), - "xk_view": (self.n_kv_heads, self.head_dim), - "xv_view": (self.n_kv_heads, self.head_dim), - "xk_reshape": (self.n_kv_heads, self.head_dim // 8, 8), - "single_xq_view": (self.n_heads, self.head_dim), - "single_xk_view": (self.n_kv_heads, self.head_dim), - "single_xv_view": (self.n_kv_heads, self.head_dim) - } - - return attention_shapes - - def forward( - self, - hidden_states:torch.Tensor, past_key_value=None, attention_mask=None, position_ids=None, - output_attentions=False, use_cache=False, *args, **kwargs - ): + def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwargs): bsz, seqlen, _ = hidden_states.shape if bsz != self.cache_batch_size: raise RuntimeError( @@ -157,14 +154,8 @@ def forward( ) if self.start_pos > self.max_seq_len or self.start_pos + seqlen > self.max_seq_len: - # Roll cache to the left - roll_len = self.start_pos - self.cache_v = torch.roll(self.cache_v, shifts=-roll_len, dims=2) - self.cache_k = torch.roll(self.cache_k, shifts=-roll_len, dims=3) - # Zero out the new part - self.cache_v[:, :, -roll_len:, :] = 0 - self.cache_k[:, :, :, -roll_len:, :] = 0 - self.start_pos = 0 + excess_length = self.start_pos + seqlen - self.max_seq_len + self.start_pos = self.cache.roll_kv(excess_length, self.start_pos) xqkv = self.qkv_proj(hidden_states) xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"]) @@ -181,8 +172,7 @@ def forward( if not self.use_alibi: xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen]) - self.cache_k = self.cache_k.to(xq) - self.cache_v = self.cache_v.to(xq) + self.cache.to(xq) values_store = xv.transpose(2, 1) keys_store = ( @@ -191,13 +181,10 @@ def forward( .contiguous() ) - self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store - self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store + self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen) if seqlen == 1: - xv = self.cache_v[:bsz, :, : self.start_pos + seqlen, :].transpose(1, 2).contiguous() - xk = self.cache_k[:bsz, :, :, : self.start_pos + seqlen, :].transpose(2, 3).contiguous() - xk = xk.reshape(xk.shape[:-2] + (self.head_dim,)).transpose(1, 2).contiguous() + xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim) keys = xk values = xv @@ -229,8 +216,8 @@ def forward( xq, # query xk, # key xv, # value - self.cache_k, # key cache - self.cache_v, # value cache + self.cache.k, # key cache + self.cache.v, # value cache None, # length per sample self.alibi_slopes, # alibi slopes self.start_pos, # timestep @@ -241,11 +228,8 @@ def forward( attention_weight = attention_weight.reshape(bsz, 1, -1) attn_output = self.o_proj(attention_weight) - - if use_cache: - self.start_pos += seqlen - else: - self.start_pos = 0 + self.start_pos += seqlen - # past_key_value is replaced with cache_v, cache_k, returning None - return attn_output, attention_weight, None + # past_key_value is replaced with cache_v, cache_k, returning empty data + past_key_value = [torch.Tensor([ [ [[0]], [[0]], [[0]] ] ])] + return attn_output, attention_weight, past_key_value \ No newline at end of file From 204a3a12373c23fa652a7c2762ff65173d025b37 Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Fri, 6 Oct 2023 13:59:48 +0000 Subject: [PATCH 3/8] Update Mistral example --- examples/basic_generate.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/basic_generate.py b/examples/basic_generate.py index 1d76908c..89342524 100644 --- a/examples/basic_generate.py +++ b/examples/basic_generate.py @@ -4,7 +4,7 @@ quant_path = "TheBloke/Mistral-7B-OpenOrca-AWQ" # Load model -model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False, safetensors=True) +model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, safetensors=True) tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True) streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) @@ -16,8 +16,12 @@ {prompt}<|im_end|> <|im_start|>assistant""" +prompt = "You're standing on the surface of the Earth. "\ + "You walk one mile south, one mile west and one mile north. "\ + "You end up exactly where you started. Where are you?" + tokens = tokenizer( - prompt_template.format(prompt="Why is ice cream so good, yes so good?"), + prompt_template.format(prompt=prompt), return_tensors='pt' ).input_ids.cuda() @@ -26,4 +30,4 @@ tokens, streamer=streamer, max_new_tokens=512 -) +) \ No newline at end of file From 7a3d06d63f09c6d5ac0337dacce6211cac22842e Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Fri, 6 Oct 2023 14:16:56 +0000 Subject: [PATCH 4/8] Update comment --- awq/modules/fused/cache.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/awq/modules/fused/cache.py b/awq/modules/fused/cache.py index c2f93a0a..57816ac0 100644 --- a/awq/modules/fused/cache.py +++ b/awq/modules/fused/cache.py @@ -26,8 +26,10 @@ def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen): def roll_kv(self, roll_len, start_pos): """ - For example, with roll_len=3 and [A,B,C,D,E] we get [D,E,F,G,H] + With sink=0, roll_len=3, and [A,B,C,D,E] we get [D,E,F,G,H] With sink=1, roll_len=3, and [A,B,C,D,E] we get [A,E,F,G,H] + With sink=2, roll_len=3, and [A,B,C,D,E] we get [A,B,F,G,H] + With sink=3, roll_len=3, and [A,B,C,D,E] we get [A,B,C,G,H] """ # Roll only the necessary part of the cache to the left self.v[:, :, self.attention_sinks:-roll_len+self.attention_sinks, :] = self.v[:, :, roll_len:, :] From 428504e423d9986ac48867a2b4a09c5a50bca6c5 Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Fri, 6 Oct 2023 14:29:55 +0000 Subject: [PATCH 5/8] Create ALiBi module --- awq/modules/fused/attn.py | 63 ++++++++++++++++++++++----------------- 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py index de1d7b06..f8290f40 100644 --- a/awq/modules/fused/attn.py +++ b/awq/modules/fused/attn.py @@ -37,29 +37,38 @@ def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) -def gen_slopes(n_heads, alibi_bias_max=8): - _n_heads = 2 ** math.ceil(math.log2(n_heads)) - m = torch.arange(1, _n_heads + 1, dtype=torch.float32) - m = m.mul(alibi_bias_max / _n_heads) - slopes = 1.0 / torch.pow(2, m) - if _n_heads != n_heads: - slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] - return slopes.view(1, n_heads, 1, 1) - - -def build_alibi_bias( - n_heads, seq_len, full=False, alibi_bias_max=8, dtype=torch.float32 -): - alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len) - if full: - alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32).view( - 1, 1, seq_len, 1 - ) - alibi_bias = alibi_bias.abs().mul(-1) - slopes = gen_slopes(n_heads, alibi_bias_max) - alibi_bias = alibi_bias * slopes - slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1) - return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype) +class ALiBi(nn.Module): + def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8): + super(ALiBi, self).__init__() + + # Initialize ALiBi slopes and bias + slopes, bias = self.build_alibi_bias(n_heads, max_seq_len, alibi_bias_max=alibi_bias_max) + self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False) + self.bias = nn.Parameter(bias.float().to(device), requires_grad=False) + + @staticmethod + def gen_slopes(n_heads, alibi_bias_max=8): + _n_heads = 2 ** math.ceil(math.log2(n_heads)) + m = torch.arange(1, _n_heads + 1, dtype=torch.float32) + m = m.mul(alibi_bias_max / _n_heads) + slopes = 1.0 / torch.pow(2, m) + + if _n_heads != n_heads: + slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads] + + return slopes.view(1, n_heads, 1, 1) + + @staticmethod + def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32): + alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len) + slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max) + alibi_bias = alibi_bias * slopes + slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1) + return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype) + + def forward(self, scores, seqlen): + scores += self.bias[..., :seqlen] + return scores def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim): if attention_shapes is not None: @@ -131,9 +140,7 @@ def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max ) if use_alibi: - alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len) - self.alibi_slopes = alibi_slopes.float().to(dev) - self.alibi_bias = alibi_bias.float().to(dev) + self.alibi = ALiBi(n_heads, max_seq_len, dev) self.rotary_dim = 0 self.is_neox = False else: @@ -199,7 +206,7 @@ def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwar scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) if self.use_alibi: - scores += self.alibi_bias[..., :seqlen] + scores = self.alibi.forward(scores, seqlen) if attention_mask is not None: scores = scores + attention_mask # (bs, n_local_heads, slen, cache_len + slen) @@ -219,7 +226,7 @@ def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwar self.cache.k, # key cache self.cache.v, # value cache None, # length per sample - self.alibi_slopes, # alibi slopes + self.alibi.slopes, # alibi slopes self.start_pos, # timestep self.rotary_dim, # rotary embedding dimension 10000, # rotary embedding base From 306de68313d417abc11b1e28b52118166558da06 Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Fri, 6 Oct 2023 14:36:18 +0000 Subject: [PATCH 6/8] Move attention shapes to fused_utils --- awq/modules/fused/attn.py | 45 +-------------------------------------- awq/utils/fused_utils.py | 44 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 44 deletions(-) create mode 100644 awq/utils/fused_utils.py diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py index f8290f40..6bb721d5 100644 --- a/awq/modules/fused/attn.py +++ b/awq/modules/fused/attn.py @@ -4,6 +4,7 @@ import torch.nn as nn from torch.nn import functional as F from awq.modules.fused.cache import WindowedCache +from awq.utils.fused_utils import get_attention_shapes try: import ft_inference_engine @@ -70,50 +71,6 @@ def forward(self, scores, seqlen): scores += self.bias[..., :seqlen] return scores -def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim): - if attention_shapes is not None: - attention_shapes = attention_shapes - - elif n_kv_heads == 0: - attention_shapes = { - # following fastertransformer definition - "cache_v": (cache_batch_size, n_heads, max_seq_len, head_dim,), - # 8: pack 8 fp16 in FT, if fp32 then use 4 - "cache_k": (cache_batch_size, n_heads, head_dim // 8, max_seq_len, 8,), - "xqkv_view": (-1, n_heads, head_dim), - "xq_slice": lambda xqkv: xqkv[:, :, 0], - "xk_slice": lambda xqkv: xqkv[:, :, 1], - "xv_slice": lambda xqkv: xqkv[:, :, 2], - "xq_view": (n_heads, head_dim), - "xk_view": (n_heads, head_dim), - "xv_view": (n_heads, head_dim), - "xk_reshape": (n_heads, head_dim // 8, 8), - "single_xq_view": (n_heads, head_dim), - "single_xk_view": (n_heads, head_dim), - "single_xv_view": (n_heads, head_dim) - } - - else: - attention_shapes = { - # following fastertransformer definition - "cache_v": (cache_batch_size, n_kv_heads, max_seq_len, head_dim,), - # 8: pack 8 fp16 in FT, if fp32 then use 4 - "cache_k": (cache_batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8,), - "xqkv_view": (n_heads + n_kv_heads * 2, head_dim), - "xq_slice": lambda xqkv: xqkv[:, :, 0 : n_heads], - "xk_slice": lambda xqkv: xqkv[:, :, n_heads : (n_heads + n_kv_heads)], - "xv_slice": lambda xqkv: xqkv[:, :, -n_kv_heads :], - "xq_view": (n_heads, head_dim), - "xk_view": (n_kv_heads, head_dim), - "xv_view": (n_kv_heads, head_dim), - "xk_reshape": (n_kv_heads, head_dim // 8, 8), - "single_xq_view": (n_heads, head_dim), - "single_xk_view": (n_kv_heads, head_dim), - "single_xv_view": (n_kv_heads, head_dim) - } - - return attention_shapes - class QuantAttentionFused(nn.Module): def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len, use_alibi=False, attention_shapes=None): diff --git a/awq/utils/fused_utils.py b/awq/utils/fused_utils.py new file mode 100644 index 00000000..ff5e9ad0 --- /dev/null +++ b/awq/utils/fused_utils.py @@ -0,0 +1,44 @@ + +def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim): + if attention_shapes is not None: + attention_shapes = attention_shapes + + elif n_kv_heads == 0: + attention_shapes = { + # following fastertransformer definition + "cache_v": (cache_batch_size, n_heads, max_seq_len, head_dim,), + # 8: pack 8 fp16 in FT, if fp32 then use 4 + "cache_k": (cache_batch_size, n_heads, head_dim // 8, max_seq_len, 8,), + "xqkv_view": (-1, n_heads, head_dim), + "xq_slice": lambda xqkv: xqkv[:, :, 0], + "xk_slice": lambda xqkv: xqkv[:, :, 1], + "xv_slice": lambda xqkv: xqkv[:, :, 2], + "xq_view": (n_heads, head_dim), + "xk_view": (n_heads, head_dim), + "xv_view": (n_heads, head_dim), + "xk_reshape": (n_heads, head_dim // 8, 8), + "single_xq_view": (n_heads, head_dim), + "single_xk_view": (n_heads, head_dim), + "single_xv_view": (n_heads, head_dim) + } + + else: + attention_shapes = { + # following fastertransformer definition + "cache_v": (cache_batch_size, n_kv_heads, max_seq_len, head_dim,), + # 8: pack 8 fp16 in FT, if fp32 then use 4 + "cache_k": (cache_batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8,), + "xqkv_view": (n_heads + n_kv_heads * 2, head_dim), + "xq_slice": lambda xqkv: xqkv[:, :, 0 : n_heads], + "xk_slice": lambda xqkv: xqkv[:, :, n_heads : (n_heads + n_kv_heads)], + "xv_slice": lambda xqkv: xqkv[:, :, -n_kv_heads :], + "xq_view": (n_heads, head_dim), + "xk_view": (n_kv_heads, head_dim), + "xv_view": (n_kv_heads, head_dim), + "xk_reshape": (n_kv_heads, head_dim // 8, 8), + "single_xq_view": (n_heads, head_dim), + "single_xk_view": (n_kv_heads, head_dim), + "single_xv_view": (n_kv_heads, head_dim) + } + + return attention_shapes \ No newline at end of file From 69733d2c4122653c54266e2214d349db1ae50299 Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Fri, 6 Oct 2023 14:49:29 +0000 Subject: [PATCH 7/8] Create RoPE module --- awq/modules/fused/attn.py | 76 ++++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 32 deletions(-) diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py index 6bb721d5..03d5a40e 100644 --- a/awq/modules/fused/attn.py +++ b/awq/modules/fused/attn.py @@ -12,31 +12,45 @@ except: FT_INSTALLED = False -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) # type: ignore - freqs = torch.outer(t, freqs).float() # type: ignore - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): - ndim = x.ndim - assert 0 <= 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - -def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor): - xq_ = torch.view_as_complex( - xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous() - ) - xk_ = torch.view_as_complex( - xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous() - ) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_).to(xq_.device) - xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) +class RoPE(nn.Module): + def __init__(self, hidden_size, n_heads, max_seq_len, device): + super(RoPE, self).__init__() + + self.freqs_cis = nn.Parameter( + self.precompute_freqs_cis(hidden_size // n_heads, max_seq_len * 2).to(device), + requires_grad=False + ) + + @staticmethod + def precompute_freqs_cis(dim: int, end: int, theta=10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis + + @staticmethod + def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int): + xq_ = torch.view_as_complex( + xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous() + ) + xk_ = torch.view_as_complex( + xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous() + ) + freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] + freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device) + + xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3) + + return xq_out.type_as(xq), xk_out.type_as(xk) class ALiBi(nn.Module): def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8): @@ -101,12 +115,9 @@ def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max self.rotary_dim = 0 self.is_neox = False else: - self.freqs_cis = precompute_freqs_cis( - hidden_size // n_heads, - max_seq_len * 2, - ).to(dev) + self.alibi = None + self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev) self.rotary_dim = self.head_dim - self.alibi_slopes = None self.is_neox = True def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwargs): @@ -134,7 +145,7 @@ def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwar xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"]) if not self.use_alibi: - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen]) + xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen) self.cache.to(xq) @@ -176,6 +187,7 @@ def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwar xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"]) xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"]) + alibi_slopes = self.alibi.slopes if self.alibi is not None else None attention_weight = ft_inference_engine.single_query_attention( xq, # query xk, # key @@ -183,7 +195,7 @@ def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwar self.cache.k, # key cache self.cache.v, # value cache None, # length per sample - self.alibi.slopes, # alibi slopes + alibi_slopes, # alibi slopes self.start_pos, # timestep self.rotary_dim, # rotary embedding dimension 10000, # rotary embedding base From b13e2a859735530953b1df109a5cab471f46a852 Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Fri, 6 Oct 2023 15:18:52 +0000 Subject: [PATCH 8/8] Remove attention sinks --- awq/modules/fused/cache.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/awq/modules/fused/cache.py b/awq/modules/fused/cache.py index 57816ac0..dff47423 100644 --- a/awq/modules/fused/cache.py +++ b/awq/modules/fused/cache.py @@ -1,13 +1,11 @@ import torch class WindowedCache: - def __init__(self, cache_v_shape, cache_k_shape, device, attention_sinks=4): + def __init__(self, cache_v_shape, cache_k_shape, device): """ The window size is the same as the max_new_tokens. The window will automatically roll once max_new_tokens is exceeded. """ - self.attention_sinks = attention_sinks - # [batch_size, n_kv_heads, max_seq_len, head_dim] self.v = torch.zeros(cache_v_shape).to(device).half() # [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor] @@ -25,15 +23,9 @@ def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen): self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store def roll_kv(self, roll_len, start_pos): - """ - With sink=0, roll_len=3, and [A,B,C,D,E] we get [D,E,F,G,H] - With sink=1, roll_len=3, and [A,B,C,D,E] we get [A,E,F,G,H] - With sink=2, roll_len=3, and [A,B,C,D,E] we get [A,B,F,G,H] - With sink=3, roll_len=3, and [A,B,C,D,E] we get [A,B,C,G,H] - """ # Roll only the necessary part of the cache to the left - self.v[:, :, self.attention_sinks:-roll_len+self.attention_sinks, :] = self.v[:, :, roll_len:, :] - self.k[:, :, :, self.attention_sinks:-roll_len+self.attention_sinks, :] = self.k[:, :, :, roll_len:, :] + self.v[:, :, :-roll_len, :] = self.v[:, :, roll_len:, :] + self.k[:, :, :, :-roll_len, :] = self.k[:, :, :, roll_len:, :] # Zero out the new part self.v[:, :, -roll_len:, :] = 0