diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index be385ae67..5439cc38b 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -16,6 +16,7 @@ # TODO: Should be using a base class with the protocol supported. from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1 +from ..models.mixtral.mixtral import * def main(): @@ -52,7 +53,10 @@ def main(): llama_config = LlamaModelConfig(hp) llama_config.static_tables = False # Rely on the compiler for hoisting tables. llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged" - model = PagedLlamaModelV1(dataset.root_theta, llama_config) + if llama_config.hp.expert_count: + model = PagedMixtralModelV1(dataset.root_theta, llama_config) + else: + model = PagedLlamaModelV1(dataset.root_theta, llama_config) def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]): return { diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index 17f2bc9ad..47d281565 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -17,6 +17,7 @@ from ..types import * # TODO: Should be using a base class with the protocol supported. +from ..models.mixtral.mixtral import * from ..models.llama.llama import * from ..utils.debugging import trace_tensor from ..utils.tokenizer import InferenceTokenizer, load_tokenizer @@ -236,7 +237,11 @@ def main(): activation_dtype=activation_dtype, attention_dtype=activation_dtype, ) - model = PagedLlamaModelV1(dataset.root_theta, config) + + if config.hp.expert_count: + model = PagedMixtralModelV1(dataset.root_theta, config) + else: + model = PagedLlamaModelV1(dataset.root_theta, config) if args.save_intermediates_path: from ..utils.patching import SaveModuleResultTensorsPatch diff --git a/sharktank/sharktank/examples/validate_direct_mixtral_model.py b/sharktank/sharktank/examples/validate_direct_mixtral_model.py new file mode 100644 index 000000000..e139365f8 --- /dev/null +++ b/sharktank/sharktank/examples/validate_direct_mixtral_model.py @@ -0,0 +1,144 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
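Both example drivers above now select the model class from the GGUF hyperparameters; a minimal sketch of that dispatch as a shared helper (the helper name `instantiate_model` is illustrative and not part of this patch):

```python
# Illustrative helper for the expert_count-based dispatch used in
# export_paged_llm_v1.py and paged_llm_v1.py above; not part of the patch itself.
from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
from sharktank.models.mixtral.mixtral import PagedMixtralModelV1


def instantiate_model(dataset, llama_config: LlamaModelConfig):
    # A non-zero llama.expert_count in the GGUF properties marks an MoE checkpoint.
    if llama_config.hp.expert_count:
        return PagedMixtralModelV1(dataset.root_theta, llama_config)
    return PagedLlamaModelV1(dataset.root_theta, llama_config)
```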
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import sys + +import torch + +from sharktank.layers import * +from sharktank.types import * +from sharktank.models.mixtral.mixtral import * + + +def main(args: list[str]): + from ..utils import cli + + torch.no_grad().__enter__() + + parser = cli.create_parser() + cli.add_input_dataset_options(parser) + args = cli.parse(parser) + + dataset = cli.get_input_dataset(args) + hp = configs.LlamaHParams.from_gguf_props(dataset.properties) + llama_config = LlamaModelConfig(hp) + llama_config.kv_cache_type = "direct" + llama_config.activation_dtype = torch.float16 + model = PagedMixtralModelV1(dataset.root_theta, llama_config) + + # bs ("batch size") == 1 + cache_state = model.cache.allocate(bs=1) + + start_index = 0 + tokens = torch.tensor( + [ + [ + 1, + 1059, + 31871, + 1217, + 322, + 266, + 3682, + 6075, + 31902, + 13, + 31849, + 31871, + 0, + 0, + 0, + 0, + ] + + 48 * [0], + ] + ) + assert tokens.shape[1] % model.cache.block_seq_stride == 0 + seq_block_ids = torch.tensor( + [ + [127, 0, 0, 0], + ] + ) + + # Important: Do not use a sequence length of 0 for empty batch slots + # as it will cause softmax to nan due to a mask of all -inf. This then + # propagates and causes badness. + seq_lens = torch.tensor([12]) + + attention_mask = model.attention_mask( + model.input_mask(seq_lens, tokens.shape[1]), + ) + + print(f"Step {start_index}") + logits = model.prefill( + tokens, + attention_mask=attention_mask, + seq_block_ids=seq_block_ids, + cache_state=cache_state, + ) + # TODO: Normalize the output of extract_tokens_from_logits into tensor [bs, 1]. + tokens = torch.tensor(model.extract_tokens_from_logits(logits, seq_lens)).unsqueeze( + 1 + ) + print(f" : tokens = {tokens}") + + # Decode a step. + print("Decoding...") + print(tokens.shape, tokens) + start_positions = torch.tensor([12]) + seq_lens = seq_lens + 1 + decode_attention_mask = model.decode_attention_mask( + model.input_mask( + seq_lens, + seq_block_ids.shape[1] * model.cache.block_seq_stride, + ), + ) + logits = model.decode( + tokens, + attention_mask=decode_attention_mask, + start_positions=start_positions, + seq_block_ids=seq_block_ids, + cache_state=cache_state, + ) + tokens = torch.tensor(model.extract_tokens_from_logits(logits, [1])).unsqueeze(1) + print(f" : tokens = {tokens}") + + def save_prefill_module(model): + from iree.compiler.extras.fx_importer import FxImporter + from iree.compiler.ir import AsmState + + importer = FxImporter() + + print("Generating FX graph") + + class InferenceModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.add_module("prefill", model) + + def forward(self, tokens, attention_mask, seq_block_ids, *cache_state): + return self.prefill.prefill( + tokens, + attention_mask=attention_mask, + seq_block_ids=seq_block_ids, + cache_state=list(cache_state), + ) + + infmod = InferenceModule() + prog = torch.export.export( + infmod, (tokens, attention_mask, seq_block_ids) + tuple(cache_state) + ) + + print(f"FX prog:", prog) + importer.import_program(prog, func_name="prefill") + output_file = "/tmp/prefill.mlirbc" + print("Saving to:", output_file) + with open(output_file, "wb") as f: + importer.module_op.write_bytecode(f) + + +if __name__ == "__main__": + sys.exit(main(sys.argv[1:])) diff --git a/sharktank/sharktank/examples/validate_mixtral_ref_model.py b/sharktank/sharktank/examples/validate_mixtral_ref_model.py new file mode 100644 index 000000000..bad6c4382 --- /dev/null +++ 
b/sharktank/sharktank/examples/validate_mixtral_ref_model.py @@ -0,0 +1,48 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import sys + +import torch + +from sharktank.layers import * +from sharktank.types import * +from sharktank.models.mixtral.mixtral_ref import * + + +def main(args: list[str]): + from ..utils import cli + + torch.no_grad().__enter__() + + parser = cli.create_parser() + cli.add_input_dataset_options(parser) + args = cli.parse(parser) + + dataset = cli.get_input_dataset(args) + hp = configs.LlamaHParams.from_gguf_props(dataset.properties) + ref_llama_config = RefLlamaModelConfig(hp) + ref_llama_config.activation_dtype = torch.float16 + model = DirectCacheMixtralModelV1(dataset.root_theta, ref_llama_config) + + kv_cache = model.create_cache(bs=1) + start_index = 0 + next_tokens = [1, 1059, 31871, 1217, 322, 266, 3682, 6075, 31902, 13, 31849, 31871] + print(f"Step {start_index}") + tokens = model.forward( + torch.tensor([next_tokens]), start_index=start_index, local_kv_cache=kv_cache + ) + print(f" : tokens = {tokens}") + + # Decode a step. + print("Decoding...") + print(tokens.shape, tokens) + decode_token = model.forward(tokens, start_index=12, local_kv_cache=kv_cache) + print(f" : decode tokens = {decode_token}") + + +if __name__ == "__main__": + sys.exit(main(sys.argv[1:])) diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py index 25a00c80f..181544763 100644 --- a/sharktank/sharktank/layers/__init__.py +++ b/sharktank/sharktank/layers/__init__.py @@ -12,5 +12,10 @@ from .norm import RMSNormLayer from .rotary_embedding import RotaryEmbeddingLayer from .token_embedding import TokenEmbeddingLayer +from .llama_attention_block import LlamaAttentionBlock +from .paged_llama_attention_block import PagedLlamaAttentionBlock +from .ffn_block import FFN +from .ffn_moe_block import FFNMOE +from .mixture_of_experts_block import SparseMoeBlock from . 
import configs diff --git a/sharktank/sharktank/layers/base.py b/sharktank/sharktank/layers/base.py index 90c976f25..11a21f885 100644 --- a/sharktank/sharktank/layers/base.py +++ b/sharktank/sharktank/layers/base.py @@ -16,11 +16,8 @@ from ..utils import debugging __all__ = [ - "LinearLayer", - "RotaryEmbeddingLayer", - "RMSNormLayer", + "BaseLayer", "ThetaLayer", - "TokenEmbedding", ] diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index a9aad8088..ab3a582e4 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -19,9 +19,7 @@ import torch -__all__ = [ - "LlamaHParams", -] +__all__ = ["LlamaHParams"] @dataclass @@ -36,14 +34,21 @@ class LlamaHParams: block_count: int feed_forward_length: int rope_dimension_count: int + rope_freq_base: float attention_head_count: int attn_head_dim: int attention_layer_norm_rms_epsilon: float attention_head_count_kv: int + expert_count: int + expert_used_count: int @staticmethod def from_gguf_props(p: dict[str, Any]): + default_expert_count = 0 + default_expert_used_count = 0 + default_rope_freq_base = 10000.0 attention_head_count = _int_prop(p, "llama.attention.head_count") + return LlamaHParams( context_length=_int_prop(p, "llama.context_length"), embedding_length=_int_prop(p, "llama.embedding_length"), @@ -58,6 +63,15 @@ def from_gguf_props(p: dict[str, Any]): attention_head_count_kv=_optional_int_prop( p, "llama.attention.head_count_kv", attention_head_count ), + rope_freq_base=_optional_float_prop( + p, "llama.rope.freq_base", default_rope_freq_base + ), + expert_count=_optional_int_prop( + p, "llama.expert_count", default_expert_count + ), + expert_used_count=_optional_int_prop( + p, "llama.expert_used_count", default_expert_used_count + ), ) @@ -79,10 +93,16 @@ def _int_prop(p: dict[str, Any], name: str) -> int: raise KeyError(f"Property '{name}' not found (among keys {p.keys()})") +def _optional_float_prop(p: dict[str, Any], name: str, default_value: float) -> float: + value = p.get(name, default_value) + try: + return float(value) + except ValueError as e: + raise ValueError(f"Property '{name}' expected to be a float and was not") from e + + def _optional_int_prop(p: dict[str, Any], name: str, default_value: int) -> int: - value = p[name] - if value is None: - return default_value + value = p.get(name, default_value) try: return int(value) except ValueError as e: diff --git a/sharktank/sharktank/layers/ffn_block.py b/sharktank/sharktank/layers/ffn_block.py new file mode 100644 index 000000000..020528f71 --- /dev/null +++ b/sharktank/sharktank/layers/ffn_block.py @@ -0,0 +1,38 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
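A quick illustration of the new optional-property fallback in `from_gguf_props`; the property fragment below is made up, and the private helpers are imported directly only for demonstration:

```python
# Hypothetical GGUF property fragment; only the keys shown here are assumed.
from sharktank.layers.configs.llm_configs import (
    _optional_float_prop,
    _optional_int_prop,
)

props = {"llama.rope.freq_base": 1000000.0}  # e.g. a Mixtral-style header

# Present key: parsed and coerced to float.
assert _optional_float_prop(props, "llama.rope.freq_base", 10000.0) == 1000000.0
# Missing keys: previously a KeyError via p[name]; now the default is returned.
assert _optional_int_prop(props, "llama.expert_count", 0) == 0
assert _optional_int_prop(props, "llama.expert_used_count", 0) == 0
```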
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional + +import torch +import torch.nn.functional as F + +from .base import Theta, ThetaLayer +from .linear import LinearLayer + +__all__ = [ + "FFN", +] + + +class FFN(ThetaLayer): + def __init__( + self, + theta: Theta, + ): + super().__init__(theta) + + self.add_module("ffn_gate", LinearLayer(theta("ffn_gate"))) + self.add_module("ffn_up", LinearLayer(theta("ffn_up"))) + self.add_module("ffn_down", LinearLayer(theta("ffn_down"))) + + def forward( + self, + h: torch.Tensor, + ): + ffn_gate = F.silu(self.ffn_gate(h)) + ffn_up = self.ffn_up(h) + ffn_down = self.ffn_down(ffn_gate * ffn_up) + return ffn_down diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py new file mode 100644 index 000000000..d2f51d2d9 --- /dev/null +++ b/sharktank/sharktank/layers/ffn_moe_block.py @@ -0,0 +1,84 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional + +import torch +import torch.nn.functional as F + +from .base import ThetaLayer +from .linear import LinearLayer +from ..types import Theta, DefaultPrimitiveTensor + +__all__ = [ + "FFNMOE", +] + + +class FFNMOE(ThetaLayer): + def __init__( + self, + theta: Theta, + expert_idx: Optional[int] = None, + ): + + super().__init__(theta) + + if theta.optional_tensor("ffn_gate_exps") is not None: + merged_tensor = theta.tensor("ffn_gate_exps", "weight") + + expert_tensor = extract_ffn_layer( + merged_tensor=merged_tensor, + layer_name="ffn_gate", + expert_idx=expert_idx, + ) + + self.add_module("ffn_gate", LinearLayer(Theta({"weight": expert_tensor}))) + + merged_tensor = theta.tensor("ffn_up_exps", "weight") + + expert_tensor = extract_ffn_layer( + merged_tensor=merged_tensor, layer_name="ffn_up", expert_idx=expert_idx + ) + + self.add_module("ffn_up", LinearLayer(Theta({"weight": expert_tensor}))) + + merged_tensor = theta.tensor("ffn_down_exps", "weight") + + expert_tensor = extract_ffn_layer( + merged_tensor=merged_tensor, + layer_name="ffn_down", + expert_idx=expert_idx, + ) + + self.add_module("ffn_down", LinearLayer(Theta({"weight": expert_tensor}))) + + else: + self.add_module("ffn_gate", LinearLayer(theta("ffn_gate", expert_idx))) + self.add_module("ffn_up", LinearLayer(theta("ffn_up", expert_idx))) + self.add_module("ffn_down", LinearLayer(theta("ffn_down", expert_idx))) + + def forward( + self, + h: torch.Tensor, + ): + ffn_gate = F.silu(self.ffn_gate(h)) + ffn_up = self.ffn_up(h) + ffn_down = self.ffn_down(ffn_gate * ffn_up) + return ffn_down + + +def extract_ffn_layer( + merged_tensor: DefaultPrimitiveTensor, layer_name: str, expert_idx: int +): + # fetches the block_idx from merged_tensor_name. e.g. 
blk.0.ffn_gate_exps.weight + expert_layer_name = ( + f"blk.{merged_tensor.name.split('.')[1]}.{layer_name}.{expert_idx}.weight" + ) + expert_tensor = DefaultPrimitiveTensor( + name=expert_layer_name, data=merged_tensor.as_torch()[expert_idx] + ) + return expert_tensor diff --git a/sharktank/sharktank/layers/llama_attention_block.py b/sharktank/sharktank/layers/llama_attention_block.py new file mode 100644 index 000000000..7be8c7366 --- /dev/null +++ b/sharktank/sharktank/layers/llama_attention_block.py @@ -0,0 +1,133 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional + +import math + +import torch +import torch.nn.functional as F + +from .base import Theta, ThetaLayer +from .linear import LinearLayer +from .norm import RMSNormLayer +from .rotary_embedding import RotaryEmbeddingLayer + +__all__ = [ + "LlamaAttentionBlock", +] + + +class LlamaAttentionBlock(ThetaLayer): + """Implements a self attention layer in the style of Llama.""" + + def __init__( + self, + theta: Theta, + *, + head_count: int, + head_dim: int, + head_count_kv: int, + embedding: RotaryEmbeddingLayer, + rms_epsilon: float, + ): + super().__init__(theta) + self.add_module( + "attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon) + ) + self.add_module("attn_q", LinearLayer(theta("attn_q"))) + self.add_module("attn_k", LinearLayer(theta("attn_k"))) + self.add_module("attn_v", LinearLayer(theta("attn_v"))) + self.add_module("attn_output", LinearLayer(theta("attn_output"))) + + self.embedding = embedding + self.head_count = head_count + self.head_dim = head_dim + self.head_count_kv = head_count_kv + + def forward( + self, + h: torch.Tensor, + *, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + start_index: int, + attention_mask: Optional[torch.Tensor] = None, + ): + x = self.attn_norm(h) + + bs, q_len, feature_dim = x.shape + kv_seq_len = start_index + q_len + assert feature_dim == self.head_count * self.head_dim + + xq = self.attn_q(x) + xk = self.attn_k(x) + xv = self.attn_v(x) + + xq = xq.view(bs, q_len, self.head_count, self.head_dim) + xk = xk.view(bs, q_len, self.head_count_kv, self.head_dim) + xv = xv.view(bs, q_len, self.head_count_kv, self.head_dim) + + # Fast path to start_index based embedding lookup if available. + # Falls back to a slower position based index lookup. + if start_index is not None: + xq, xk = embedding.forward(xq=xq, xk=xk, start_index=start_index) + else: + xq, xk = embedding.apply_batched_mask( + xq=xq, xk=xk, mask=embedding_batch_mask + ) + + # Expand kv heads for GQA. + gqa_n_rep = self.head_count // self.head_count_kv + assert gqa_n_rep > 0 + if gqa_n_rep > 1: + + def repeat_kv(x: torch.Tensor) -> torch.Tensor: + bs, slen, n_kv_heads, head_dim = x.shape + return ( + x.unsqueeze(-2) + .expand(bs, slen, n_kv_heads, gqa_n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * gqa_n_rep, head_dim) + ) + + xk = repeat_kv(xk) + xv = repeat_kv(xv) + + # Update our positions in the cache. + cache_k[:bs, start_index:kv_seq_len] = xk + cache_v[:bs, start_index:kv_seq_len] = xv + + # Derive keys/values from the entirety of the available sequence. + keys = cache_k[:bs, :kv_seq_len] + values = cache_v[:bs, :kv_seq_len] + + # Tranpose into [bs, heads, sl, dim] + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + + # Flash attention. 
+ attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + + # Apply attention mask. + if attention_mask is not None: + expected_mask_shape = (bs, 1, q_len, kv_seq_len) + assert ( + attention_mask.shape == expected_mask_shape + ), f"Attention mask should be of size {expected_mask_shape}, but is {attention_mask.shape}" + attn_weights = attn_weights + attention_mask + + attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq) + attn_output = torch.matmul(attn_weights, values) # (bs, heads, slen, head_dim) + attn_output = attn_output.transpose(1, 2).reshape(bs, q_len, -1) + + # Project. + attn_output = self.attn_output(attn_output) + + # Remainder of the block. + h = h + attn_output + + return h diff --git a/sharktank/sharktank/layers/mixture_of_experts_block.py b/sharktank/sharktank/layers/mixture_of_experts_block.py new file mode 100644 index 000000000..5f6e592f9 --- /dev/null +++ b/sharktank/sharktank/layers/mixture_of_experts_block.py @@ -0,0 +1,115 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base import Theta, ThetaLayer +from .linear import LinearLayer +from .norm import RMSNormLayer +from .ffn_moe_block import FFNMOE + +__all__ = [ + "SparseMoeBlock", +] + + +class SparseMoeBlock(ThetaLayer): + """ + This implementation considers MoE operations as block-sparse + operations to support imbalanced token assignments to experts. + This enables the MoE to operate at a faster rate and in full capacity without any dropped tokens + (or reduced performance). 
+ """ + + def __init__( + self, + theta: Theta, + expert_count: int, + expert_used_count: int, + rms_epsilon: float, + ): + super().__init__(theta) + + # Add router gate + self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp"))) + + # Add FFN norm + self.add_module( + "ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon) + ) + + # Add FFN output norm + self.add_module( + "layer_output_norm", + RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon), + ) + + # Add expert_count x FFN + self.experts = nn.ModuleList( + [FFNMOE(theta, expert_idx=i) for i in range(expert_count)] + ) + + self.expert_count = expert_count + self.expert_used_count = expert_used_count + + def forward( + self, + h: torch.Tensor, + ): + ffn_input = self.ffn_norm(h) + batch_size, sequence_length, feature_dim = ffn_input.shape + ffn_input = ffn_input.view(-1, feature_dim) + + # For each token, the router calculates the router weights for all experts + # router_logits: (batch_size * sequence_length, expert_count) + router_logits = self.ffn_gate_inp(ffn_input) + router_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + + # Select top k experts from router weights + router_weights, top_k_experts = torch.topk( + router_weights, self.expert_used_count, dim=-1 + ) + router_weights /= router_weights.sum(dim=-1, keepdim=True) + router_weights = router_weights.to(ffn_input.dtype) + + moe_output = torch.zeros( + (batch_size * sequence_length, feature_dim), dtype=ffn_input.dtype + ) + + # Create an expert mask by one hot encoding the selected top k experts + # used to index which expert is to be invoked for each token + # expert_mask: (expert_count, expert_used_count, sequence_length) + expert_mask = F.one_hot(top_k_experts, num_classes=self.expert_count).permute( + 2, 1, 0 + ) + + # Iterate over all experts in the model + for expert_idx in range(self.expert_count): + expert_layer = self.experts[expert_idx] + top_k_expert_idx, token_idx = torch.where(expert_mask[expert_idx]) + + # Given the hidden states, index the tokens assigned to this expert + # and calculate the current expert's hidden state and weigh the + # output expert hidden states by the router weights, based on the + # appropriate tokens + current_expert_tokens = ffn_input[None, token_idx] + + current_expert = ( + expert_layer(current_expert_tokens) + * router_weights[token_idx, top_k_expert_idx, None] + ) + + current_expert = current_expert.reshape(-1, feature_dim) + + moe_output.index_add_(0, token_idx, current_expert.to(ffn_input.dtype)) + moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim) + + moe_output = self.layer_output_norm(moe_output) + return h + moe_output diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py new file mode 100644 index 000000000..1e5224e89 --- /dev/null +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -0,0 +1,261 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional + +import math + +import torch +import torch.nn.functional as F + +from .base import Theta, ThetaLayer +from .linear import LinearLayer +from .norm import RMSNormLayer +from .rotary_embedding import RotaryEmbeddingLayer +from .kv_cache import PagedKVCache + +__all__ = [ + "PagedLlamaAttentionBlock", +] + + +class PagedLlamaAttentionBlock(ThetaLayer): + """Implements a self attention layer in the style of Llama using a + paged cache.""" + + def __init__( + self, + theta: Theta, + *, + block_index: int, + cache: PagedKVCache, + head_count: int, + head_dim: int, + head_count_kv: int, + rms_epsilon: float, + use_hf: bool = False, + ): + super().__init__(theta) + self.add_module( + "attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon) + ) + self.add_module("attn_q", LinearLayer(theta("attn_q"))) + self.add_module("attn_k", LinearLayer(theta("attn_k"))) + self.add_module("attn_v", LinearLayer(theta("attn_v"))) + self.add_module("attn_output", LinearLayer(theta("attn_output"))) + + self.block_index = block_index + self.cache = cache + self.head_count = head_count + self.head_dim = head_dim + self.head_count_kv = head_count_kv + self.use_hf = use_hf + + def forward( + self, + h: torch.Tensor, + *, + embedding: RotaryEmbeddingLayer, + # [bs, batch_seq_len // block_seq_stride] + seq_block_ids: torch.Tensor, + start_index: Optional[int] = None, + start_positions: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + embedding_batch_mask: Optional[torch.Tensor] = None, + cache_state: list[torch.Tensor] = None, + xk_temp: Optional[torch.Tensor] = None, + xv_temp: Optional[torch.Tensor] = None, + ): + assert bool(start_index is not None) ^ bool(embedding_batch_mask is not None) + + x = self.attn_norm(h) + + bs, batch_seq_len, feature_dim = x.shape + assert feature_dim == self.head_count * self.head_dim + + xq = self.attn_q(x) + xk = self.attn_k(x) + xv = self.attn_v(x) + + xq = xq.view(bs, batch_seq_len, self.head_count, self.head_dim) + xk = xk.view(bs, batch_seq_len, self.head_count_kv, self.head_dim) + xv = xv.view(bs, batch_seq_len, self.head_count_kv, self.head_dim) + + # Fast path to start_index based embedding lookup if available. + # Falls back to a slower position based index lookup. + if start_index is not None: + xq, xk = embedding.forward(xq=xq, xk=xk, start_index=start_index) + else: + xq, xk = embedding.apply_batched_mask( + xq=xq, xk=xk, mask=embedding_batch_mask + ) + + # Full sequence length. + kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride + + if self.cache.is_paged: + xk, xv = self.transact_cache_paged( + xk_cache_update=xk, + xv_cache_update=xv, + seq_block_ids=seq_block_ids, + kv_seq_len=kv_seq_len, + start_positions=start_positions, + cache_state=cache_state, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) + elif self.cache.is_direct: + xk, xv = self.transact_cache_direct( + xk_cache_update=xk, + xv_cache_update=xv, + start_positions=start_positions, + kv_seq_len=kv_seq_len, + cache_state=cache_state, + ) + else: + raise NotImplementedError(f"Unsupported KV cache type: {type(self.cache)}") + + # Expand kv heads for GQA. 
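The grouped-query expansion performed just below (and identically in llama_attention_block.py earlier in this patch), as a standalone sketch with placeholder head counts:

```python
# Grouped-query attention: each of the head_count_kv K/V heads is repeated
# gqa_n_rep times so the K/V tensors line up with the query heads.
import torch

bs, slen, head_count, head_count_kv, head_dim = 1, 4, 32, 8, 128
gqa_n_rep = head_count // head_count_kv  # 4

xk = torch.randn(bs, slen, head_count_kv, head_dim)
xk_expanded = (
    xk.unsqueeze(-2)
    .expand(bs, slen, head_count_kv, gqa_n_rep, head_dim)
    .reshape(bs, slen, head_count_kv * gqa_n_rep, head_dim)
)
assert xk_expanded.shape == (bs, slen, head_count, head_dim)
```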
+ gqa_n_rep = self.head_count // self.head_count_kv + assert gqa_n_rep > 0 + if gqa_n_rep > 1: + + def repeat_kv(x: torch.Tensor) -> torch.Tensor: + bs, slen, n_kv_heads, head_dim = x.shape + return ( + x.unsqueeze(-2) + .expand(bs, slen, n_kv_heads, gqa_n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * gqa_n_rep, head_dim) + ) + + xk = repeat_kv(xk) + xv = repeat_kv(xv) + + # Transpose into [bs, heads, sl, dim] + xq = xq.transpose(1, 2) + keys = xk.transpose(1, 2) + values = xv.transpose(1, 2) + + # Flash attention. + attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + self.assert_not_nan(attn_weights) + + # Apply attention mask. + self.trace_tensor("attn_weights", attn_weights, values=False) + if attention_mask is not None: + # self.trace_tensor("attn_mask", attention_mask) + attn_weights = attn_weights + attention_mask + + attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq) + attn_output = torch.matmul(attn_weights, values) # (bs, heads, slen, head_dim) + attn_output = attn_output.transpose(1, 2).reshape(bs, batch_seq_len, -1) + + # Project. + attn_output = self.attn_output(attn_output) + + # Remainder of the block. + h = h + attn_output + + return h + + def transact_cache_direct( + self, + *, + cache_state: list[torch.Tensor], + xk_cache_update: torch.Tensor, + xv_cache_update: torch.Tensor, + kv_seq_len: int, + start_positions: Optional[torch.Tensor] = None, + ): + bs, batch_seq_len, _, _ = xk_cache_update.shape + cache_k = cache_state[self.block_index * 2] + cache_v = cache_state[self.block_index * 2 + 1] + + if start_positions is None: + # Prefill. Write the entire cache. + cache_k[:, :batch_seq_len] = xk_cache_update + cache_v[:, :batch_seq_len] = xv_cache_update + return xk_cache_update, xv_cache_update + else: + # Decode. Write a single timestep. + # TODO: This needs to be reworked with index ops. + assert xk_cache_update.shape[1] == 1 + assert xv_cache_update.shape[1] == 1 + max_start_pos = 0 + for row_index in range(bs): + row_start_pos = start_positions[row_index].item() + max_start_pos = max(row_start_pos, max_start_pos) + cache_k[row_index, row_start_pos] = xk_cache_update[row_index, 0] + cache_v[row_index, row_start_pos] = xv_cache_update[row_index, 0] + return cache_k[:, :kv_seq_len], cache_v[:, :kv_seq_len] + + def transact_cache_paged( + self, + *, + xk_cache_update: torch.Tensor, + xv_cache_update: torch.Tensor, + cache_state: list[torch.Tensor], + # [bs, batch_seq_len // block_seq_stride] + seq_block_ids: torch.Tensor, + kv_seq_len: int, + start_positions: Optional[torch.Tensor] = None, + xk_temp: Optional[torch.Tensor] = None, + xv_temp: Optional[torch.Tensor] = None, + ): + cache = self.cache.paged + # Manage the cache. + if start_positions is None: + # Prefill: Write the entire cache. + cache.write( + cache_state, + cache_partitions=[xk_cache_update, xv_cache_update], + transformer_block_index=self.block_index, + page_ids=seq_block_ids, + ) + return xk_cache_update, xv_cache_update + else: + # Decode at ragged start positions. + # We need to initialize/read the K/V from the cache for the whole + # sequence. Note that at this point, it is possible to fork and + # use a memory efficient attention kernel that can do indirect + # reads, skipping this materialization. This path is taken for + # a decode step. 
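Backing up to the direct-cache path: a sketch of the flat cache_state layout that transact_cache_direct above indexes into (shapes below are illustrative, not taken from a real configuration):

```python
# The direct KV cache, as indexed by transact_cache_direct: a flat list with two
# tensors per transformer block, K for block i at index 2*i and V at 2*i + 1.
import torch

transformer_blocks, bs, seq_len, kv_heads, head_dim = 4, 1, 64, 8, 128
cache_state = [
    torch.zeros(bs, seq_len, kv_heads, head_dim)
    for _ in range(transformer_blocks * 2)  # [k0, v0, k1, v1, ...]
]

block_index = 2
cache_k = cache_state[block_index * 2]      # K cache for block 2
cache_v = cache_state[block_index * 2 + 1]  # V cache for block 2
```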
+ assert xk_temp is not None and xv_temp is not None + assert xk_cache_update.shape[1] == 1 + assert xv_cache_update.shape[1] == 1 + assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride + + # Write our one updated cache row into the cache. + cache.write_timestep( + cache_state, + cache_partitions=[ + xk_cache_update, + xv_cache_update, + ], + transformer_block_index=self.block_index, + seq_positions=start_positions + 1, + page_ids=seq_block_ids, + ) + + # Restore from the cache. + cache.read( + cache_state, + read_into_partitions=[ + xk_temp[:, 0:kv_seq_len, ...], + xv_temp[:, 0:kv_seq_len, ...], + ], + transformer_block_index=self.block_index, + page_ids=seq_block_ids, + ) + + # For computation, we create a subview of the xk/xv tensors to have + # a sequence length covering the blocked size. This must include + # the newly added row (the caller is responsible for ensuring that + # every block has at least one row left). We'll compute on this + # ragged view and use an appropriate mask. + xk = xk_temp[:, 0:kv_seq_len, ...] + xv = xv_temp[:, 0:kv_seq_len, ...] + return xk, xv diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py index 2e47ca5c7..a5f0eed09 100644 --- a/sharktank/sharktank/layers/rotary_embedding.py +++ b/sharktank/sharktank/layers/rotary_embedding.py @@ -19,6 +19,7 @@ def __init__( *, rope_dimension_count: int, max_seqlen: int, + rope_freq_base: Optional[float] = None, device: Optional[torch.device] = None, use_hf: bool = False, static_tables: bool = True, @@ -31,6 +32,7 @@ def __init__( self.rope_dimension_count = rope_dimension_count self.max_seqlen = max_seqlen self.use_hf = use_hf + self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0 if static_tables: self.register_buffer( "static_rotary_embed_table", self._create_rotary_embed_table() @@ -182,12 +184,11 @@ def apply_batched_mask( def _create_rotary_embed_table( self, - theta_value: float = 10000.0, ): dim = self.rope_dimension_count max_seqlen = self.max_seqlen freqs = 1.0 / ( - theta_value + self.rope_freq_base ** (torch.arange(0, dim, 2, device=self.device)[: (dim // 2)].float() / dim) ) t = torch.arange(max_seqlen, device=freqs.device) diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 8fc66191a..8266872ad 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -137,6 +137,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): "attention_embedding", RotaryEmbeddingLayer( rope_dimension_count=hp.rope_dimension_count, + rope_freq_base=hp.rope_freq_base, max_seqlen=hp.context_length, device=self.device, use_hf=self.use_hf, @@ -153,7 +154,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): self.attn_blocks = nn.ModuleList( [ - PagedLlamaAttentionBlock( + AttentionFFNBlock( theta("blk", n), block_index=n, cache=self.cache, @@ -281,7 +282,7 @@ def decode( ################################################################################ -class PagedLlamaAttentionBlock(ThetaLayer): +class AttentionFFNBlock(ThetaLayer): """Implements a self attention layer in the style of Llama using a paged cache.""" @@ -298,28 +299,28 @@ def __init__( use_hf: bool = False, ): super().__init__(theta) - self.add_module( - "attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon) + "attn", + PagedLlamaAttentionBlock( + theta=theta, + block_index=block_index, + cache=cache, + head_count=head_count, + 
head_dim=head_dim, + head_count_kv=head_count_kv, + rms_epsilon=rms_epsilon, + use_hf=use_hf, + ), + ) + self.add_module( + "ffn", + FFN( + theta=theta, + ), ) - self.add_module("attn_q", LinearLayer(theta("attn_q"))) - self.add_module("attn_k", LinearLayer(theta("attn_k"))) - self.add_module("attn_v", LinearLayer(theta("attn_v"))) - self.add_module("attn_output", LinearLayer(theta("attn_output"))) self.add_module( "ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon) ) - self.add_module("ffn_gate", LinearLayer(theta("ffn_gate"))) - self.add_module("ffn_up", LinearLayer(theta("ffn_up"))) - self.add_module("ffn_down", LinearLayer(theta("ffn_down"))) - - self.block_index = block_index - self.cache = cache - assert isinstance(head_count, int) - self.head_count = head_count - self.head_dim = head_dim - self.head_count_kv = head_count_kv - self.use_hf = use_hf def forward( self, @@ -336,200 +337,21 @@ def forward( xk_temp: Optional[torch.Tensor] = None, xv_temp: Optional[torch.Tensor] = None, ): - assert bool(start_index is not None) ^ bool(embedding_batch_mask is not None) - - x = self.attn_norm(h) - - bs, batch_seq_len, feature_dim = x.shape - assert feature_dim == self.head_count * self.head_dim - - xq = self.attn_q(x) - xk = self.attn_k(x) - xv = self.attn_v(x) - - xq = xq.view(bs, batch_seq_len, self.head_count, self.head_dim) - xk = xk.view(bs, batch_seq_len, self.head_count_kv, self.head_dim) - xv = xv.view(bs, batch_seq_len, self.head_count_kv, self.head_dim) - - # Fast path to start_index based embedding lookup if available. - # Falls back to a slower position based index lookup. - if start_index is not None: - xq, xk = embedding.forward(xq=xq, xk=xk, start_index=start_index) - else: - xq, xk = embedding.apply_batched_mask( - xq=xq, xk=xk, mask=embedding_batch_mask - ) - - # Full sequence length. - kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride - - if self.cache.is_paged: - xk, xv = self.transact_cache_paged( - xk_cache_update=xk, - xv_cache_update=xv, - seq_block_ids=seq_block_ids, - kv_seq_len=kv_seq_len, - start_positions=start_positions, - cache_state=cache_state, - xk_temp=xk_temp, - xv_temp=xv_temp, - ) - elif self.cache.is_direct: - xk, xv = self.transact_cache_direct( - xk_cache_update=xk, - xv_cache_update=xv, - start_positions=start_positions, - kv_seq_len=kv_seq_len, - cache_state=cache_state, - ) - else: - raise NotImplementedError(f"Unsupported KV cache type: {type(self.cache)}") - - # Expand kv heads for GQA. - gqa_n_rep = self.head_count // self.head_count_kv - assert gqa_n_rep > 0 - if gqa_n_rep > 1: - - def repeat_kv(x: torch.Tensor) -> torch.Tensor: - bs, slen, n_kv_heads, head_dim = x.shape - return ( - x.unsqueeze(-2) - .expand(bs, slen, n_kv_heads, gqa_n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * gqa_n_rep, head_dim) - ) - - xk = repeat_kv(xk) - xv = repeat_kv(xv) - - # Transpose into [bs, heads, sl, dim] - xq = xq.transpose(1, 2) - keys = xk.transpose(1, 2) - values = xv.transpose(1, 2) - - # Flash attention. - attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) - self.assert_not_nan(attn_weights) - - # Apply attention mask. 
- self.trace_tensor("attn_weights", attn_weights, values=False) - if attention_mask is not None: - # self.trace_tensor("attn_mask", attention_mask) - attn_weights = attn_weights + attention_mask - - attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq) - attn_output = torch.matmul(attn_weights, values) # (bs, heads, slen, head_dim) - attn_output = attn_output.transpose(1, 2).reshape(bs, batch_seq_len, -1) - - # Project. - attn_output = self.attn_output(attn_output) - - # Remainder of the block. - h = h + attn_output - + h = self.attn( + h, + embedding=embedding, + seq_block_ids=seq_block_ids, + start_index=start_index, + start_positions=start_positions, + attention_mask=attention_mask, + embedding_batch_mask=embedding_batch_mask, + cache_state=cache_state, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) # Feed forward network. ffn_input = self.ffn_norm(h) - ffn_gate = F.silu(self.ffn_gate(ffn_input)) - ffn_up = self.ffn_up(ffn_input) - ffn_down = self.ffn_down(ffn_gate * ffn_up) + ffn_down = self.ffn(ffn_input) final_output = h + ffn_down return final_output - - def transact_cache_direct( - self, - *, - cache_state: list[torch.Tensor], - xk_cache_update: torch.Tensor, - xv_cache_update: torch.Tensor, - kv_seq_len: int, - start_positions: Optional[torch.Tensor] = None, - ): - bs, batch_seq_len, _, _ = xk_cache_update.shape - cache_k = cache_state[self.block_index * 2] - cache_v = cache_state[self.block_index * 2 + 1] - - if start_positions is None: - # Prefill. Write the entire cache. - cache_k[:, :batch_seq_len] = xk_cache_update - cache_v[:, :batch_seq_len] = xv_cache_update - return xk_cache_update, xv_cache_update - else: - # Decode. Write a single timestep. - # TODO: This needs to be reworked with index ops. - assert xk_cache_update.shape[1] == 1 - assert xv_cache_update.shape[1] == 1 - max_start_pos = 0 - for row_index in range(bs): - row_start_pos = start_positions[row_index].item() - max_start_pos = max(row_start_pos, max_start_pos) - cache_k[row_index, row_start_pos] = xk_cache_update[row_index, 0] - cache_v[row_index, row_start_pos] = xv_cache_update[row_index, 0] - return cache_k[:, :kv_seq_len], cache_v[:, :kv_seq_len] - - def transact_cache_paged( - self, - *, - xk_cache_update: torch.Tensor, - xv_cache_update: torch.Tensor, - cache_state: list[torch.Tensor], - # [bs, batch_seq_len // block_seq_stride] - seq_block_ids: torch.Tensor, - kv_seq_len: int, - start_positions: Optional[torch.Tensor] = None, - xk_temp: Optional[torch.Tensor] = None, - xv_temp: Optional[torch.Tensor] = None, - ): - cache = self.cache.paged - # Manage the cache. - if start_positions is None: - # Prefill: Write the entire cache. - cache.write( - cache_state, - cache_partitions=[xk_cache_update, xv_cache_update], - transformer_block_index=self.block_index, - page_ids=seq_block_ids, - ) - return xk_cache_update, xv_cache_update - else: - # Decode at ragged start positions. - # We need to initialize/read the K/V from the cache for the whole - # sequence. Note that at this point, it is possible to fork and - # use a memory efficient attention kernel that can do indirect - # reads, skipping this materialization. This path is taken for - # a decode step. - assert xk_temp is not None and xv_temp is not None - assert xk_cache_update.shape[1] == 1 - assert xv_cache_update.shape[1] == 1 - assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride - - # Write our one updated cache row into the cache. 
- cache.write_timestep( - cache_state, - cache_partitions=[ - xk_cache_update, - xv_cache_update, - ], - transformer_block_index=self.block_index, - seq_positions=start_positions, - page_ids=seq_block_ids, - ) - - # Restore from the cache. - cache.read( - cache_state, - read_into_partitions=[ - xk_temp[:, 0:kv_seq_len, ...], - xv_temp[:, 0:kv_seq_len, ...], - ], - transformer_block_index=self.block_index, - page_ids=seq_block_ids, - ) - - # For computation, we create a subview of the xk/xv tensors to have - # a sequence length covering the blocked size. This must include - # the newly added row (the caller is responsible for ensuring that - # every block has at least one row left). We'll compute on this - # ragged view and use an appropriate mask. - xk = xk_temp[:, 0:kv_seq_len, ...] - xv = xv_temp[:, 0:kv_seq_len, ...] - return xk, xv diff --git a/sharktank/sharktank/models/llama/llama_ref.py b/sharktank/sharktank/models/llama/llama_ref.py index c635565ca..74ed9e8e0 100644 --- a/sharktank/sharktank/models/llama/llama_ref.py +++ b/sharktank/sharktank/models/llama/llama_ref.py @@ -57,6 +57,7 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig): "attention_embedding", RotaryEmbeddingLayer( rope_dimension_count=hp.rope_dimension_count, + rope_freq_base=hp.rope_freq_base, max_seqlen=hp.context_length, ), ) diff --git a/sharktank/sharktank/models/llama/testing.py b/sharktank/sharktank/models/llama/testing.py index e1f428e93..b63fd5d07 100644 --- a/sharktank/sharktank/models/llama/testing.py +++ b/sharktank/sharktank/models/llama/testing.py @@ -14,7 +14,7 @@ # Range of torch.rand() is [0,1) # Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values -def make_rand_torch(shape, dtype): +def make_rand_torch(shape, dtype=torch.float32): return torch.rand(shape, dtype=dtype) * 2 - 1 @@ -54,3 +54,28 @@ def make_attention_block_theta( ), } ) + + +def make_moe_block_theta(feature_dim=1024, ffn_dim=6144, num_experts=8) -> Theta: + return Theta( + { + "blk.0.ffn_gate_inp.weight": DefaultPrimitiveTensor( + data=make_rand_torch((feature_dim, ffn_dim)) + ), + "blk.0.ffn_norm.weight": DefaultPrimitiveTensor( + data=make_rand_torch((ffn_dim)) + ), + "blk.0.layer_output_norm.weight": DefaultPrimitiveTensor( + data=make_rand_torch((ffn_dim)) + ), + "blk.0.ffn_gate_exps.weight": DefaultPrimitiveTensor( + data=make_rand_torch((8, feature_dim * num_experts, ffn_dim)) + ), + "blk.0.ffn_up_exps.weight": DefaultPrimitiveTensor( + data=make_rand_torch((8, feature_dim * num_experts, ffn_dim)) + ), + "blk.0.ffn_down_exps.weight": DefaultPrimitiveTensor( + data=make_rand_torch((8, ffn_dim, feature_dim * num_experts)) + ), + } + ) diff --git a/sharktank/sharktank/models/mixtral/mixtral.py b/sharktank/sharktank/models/mixtral/mixtral.py new file mode 100644 index 000000000..5a179e5b9 --- /dev/null +++ b/sharktank/sharktank/models/mixtral/mixtral.py @@ -0,0 +1,280 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
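The rope_freq_base now threaded through the Llama and reference models feeds the rotary table directly; a standalone sketch of the inverse-frequency formula from rotary_embedding.py (dimensions and base values below are placeholders; many Mixtral checkpoints ship a larger llama.rope.freq_base than the 10000.0 default):

```python
# Inverse rotary frequencies as computed in _create_rotary_embed_table, now
# parameterized by rope_freq_base instead of a hard-coded 10000.0.
import torch


def rotary_inv_freqs(dim: int, rope_freq_base: float = 10000.0) -> torch.Tensor:
    return 1.0 / (
        rope_freq_base ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim)
    )


freqs_default = rotary_inv_freqs(128, 10000.0)      # Llama-style default
freqs_large_base = rotary_inv_freqs(128, 1000000.0)  # larger base, longer context
print(freqs_default[:3], freqs_large_base[:3])
```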
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional + +from dataclasses import dataclass + +import torch +import torch.nn as nn + + +from ...layers import * +from ...types import Theta + +torch.set_printoptions(profile="full") + +__all__ = [ + "LlamaModelConfig", + "PagedMixtralModelV1", +] + +################################################################################ +# Config +################################################################################ + + +@dataclass +class LlamaModelConfig: + hp: configs.LlamaHParams + + # Block sequence stride for a paged KV cache. This must divide evenly + # into the context length. + block_seq_stride: int = 16 + + # Either "paged" or "direct". + kv_cache_type: str = "paged" + + # The device on which to place intermediate state. + device: Optional[torch.device] = None + + # Dtype to use for general FP activations not otherwise configured. + activation_dtype: torch.dtype = torch.float16 + + # Dtype to use for attention. + attention_dtype: torch.dtype = torch.float16 + + def create_kv_cache(self) -> BaseKVCache: + hp = self.hp + if self.kv_cache_type == "direct": + return DirectKVCache( + block_seq_stride=self.block_seq_stride, + transformer_block_count=hp.block_count, + attn_head_count=hp.attention_head_count_kv, + attn_head_dim=hp.attn_head_dim, + seq_length=hp.context_length, + device=self.device, + dtype=self.attention_dtype, + ) + elif self.kv_cache_type == "paged": + return PagedKVCache( + transformer_block_count=hp.block_count, + attn_head_count=hp.attention_head_count_kv, + attn_head_dim=hp.attn_head_dim, + cache_partition_count=2, # One for each of K/V. + block_seq_stride=self.block_seq_stride, + device=self.device, + dtype=self.attention_dtype, + ) + else: + raise NotImplementedError(f"kv_cache_type = {self.kv_cache_type}") + + +################################################################################ +# Models +################################################################################ + + +class PagedMixtralModelV1(BaseCausalLMModel): + """MixtralModel with a paged KV cache and supporting variable sequence + length batched inference. + + As both the caching and batching setup is complicated, this model variant + is modular, intending to be instantiated and used in an overall assembly + vs trying to providing one-stop methods that do everything. + + The inference procedure is typically: + + 1. Initialize the PagedKVCache state tensors. + 2. Generate an input mask given a vector of sequence lengths. + 3. Generate an attention mask from the input mask. + 4. Allocate a block mapping table. + 5. Invoke prefill() with a batch of sequences. + 6. Extract tokens from batched logits. + 7. Iteratively invoke decode() for as long as there are sequences needing + to be serviced. + + Various samplers and schedulers can be interleaved throughout. 
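A condensed sketch of the numbered procedure above, following the same calls as validate_direct_mixtral_model.py earlier in this patch (token ids, lengths and block ids are placeholders):

```python
# Hedged driver sketch; dataset loading mirrors the validation scripts above.
import torch
from sharktank.layers import configs
from sharktank.models.mixtral.mixtral import LlamaModelConfig, PagedMixtralModelV1
from sharktank.utils import cli

parser = cli.create_parser()
cli.add_input_dataset_options(parser)
args = cli.parse(parser)
dataset = cli.get_input_dataset(args)

hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
config = LlamaModelConfig(hp)
config.kv_cache_type = "direct"
model = PagedMixtralModelV1(dataset.root_theta, config)

cache_state = model.cache.allocate(bs=1)                    # 1. cache state
tokens = torch.tensor([[1, 1059, 31871, 1217] + 12 * [0]])  # padded to one block
seq_lens = torch.tensor([4])                                # non-zero to avoid NaN masks
seq_block_ids = torch.tensor([[0]])                         # 4. block mapping (placeholder id)
attention_mask = model.attention_mask(                      # 2./3. masks
    model.input_mask(seq_lens, tokens.shape[1])
)
logits = model.prefill(                                     # 5. prefill
    tokens,
    attention_mask=attention_mask,
    seq_block_ids=seq_block_ids,
    cache_state=cache_state,
)
next_tokens = model.extract_tokens_from_logits(logits, seq_lens)  # 6. sample
```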
+ """ + + def __init__(self, theta: Theta, config: LlamaModelConfig): + hp = config.hp + super().__init__( + theta, + context_length=config.hp.context_length, + device=config.device, + activation_dtype=config.activation_dtype, + attention_dtype=config.attention_dtype, + ) + self.config = config + self.hp = hp + self.cache = config.create_kv_cache() + self.activation_dtype = config.activation_dtype + self.add_module( + "token_embedding", + TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype), + ) + self.add_module( + "attention_embedding", + RotaryEmbeddingLayer( + rope_dimension_count=hp.rope_dimension_count, + rope_freq_base=hp.rope_freq_base, + max_seqlen=hp.context_length, + device=self.device, + ), + ) + self.add_module( + "output_norm", + RMSNormLayer( + theta("output_norm"), epsilon=self.hp.attention_layer_norm_rms_epsilon + ), + ) + self.add_module("output_lm_head", LinearLayer(theta("output"))) + + self.attn_blocks = nn.ModuleList() + + for n in range(hp.block_count): + self.attn_blocks.append( + PagedLlamaAttentionBlock( + theta("blk", n), + block_index=n, + cache=self.cache, + head_count=hp.attention_head_count, + head_dim=hp.attn_head_dim, + head_count_kv=hp.attention_head_count_kv, + rms_epsilon=hp.attention_layer_norm_rms_epsilon, + ) + ) + self.attn_blocks.append( + SparseMoeBlock( + theta("blk", n), + expert_count=hp.expert_count, + expert_used_count=hp.expert_used_count, + rms_epsilon=hp.attention_layer_norm_rms_epsilon, + ) + ) + + def prefill( + self, + # [bs, batch_seq_len] + tokens: torch.Tensor, + *, + # [1, 1, batch_seq_len, batch_seq_len] + attention_mask: torch.Tensor, + # [bs, batch_seq_len // block_seq_stride] + seq_block_ids: torch.Tensor, + cache_state: list[torch.Tensor], + ): + self._assert_device(tokens) + self._assert_device(attention_mask, dtype=self.activation_dtype) + self._assert_device(seq_block_ids) + self._assert_device(*cache_state, dtype=self.activation_dtype) + h = self.token_embedding(tokens) + self.trace_tensor("mixtral.token_embedding", h) + + # Iterate over attention blocks. + for block_idx, block in enumerate(self.attn_blocks): + if block_idx == 0: + self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h) + + if block.__class__.__name__ == "PagedLlamaAttentionBlock": + h = block( + h, + embedding=self.attention_embedding, + start_index=0, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + ) + self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) + elif block.__class__.__name__ == "SparseMoeBlock": + h = block( + h, + ) + self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) + + h = self.output_norm(h) + logits = self.output_lm_head(h) + return logits + + def decode( + self, + # [bs, 1] + tokens: torch.Tensor, + *, + # [bs, 1, 1, batch_seq_len] + attention_mask: torch.Tensor, + # [bs] of starting positions + start_positions: torch.Tensor, + # [bs, batch_seq_len // block_seq_stride] + seq_block_ids: torch.Tensor, + cache_state: list[torch.Tensor], + ): + self._assert_device(tokens) + self._assert_device(attention_mask, dtype=self.activation_dtype) + self._assert_device(start_positions) + self._assert_device(*cache_state, dtype=self.activation_dtype) + bs, _ = tokens.shape + # Precompute a position based mask for computing rope embeddings + # as it is the same for all blocks. 
+ embedding_batch_mask = self.attention_embedding.compute_batch_mask( + start_positions, batch_seq_len=1 + ) + self.trace_tensor("mixtral.embedding_batch_mask", embedding_batch_mask) + + # Allocate per-block temporary K/V tensors. These temporaries hold + # one block's K/V state for the maximum context length. + xk_temp = torch.empty( + [ + bs, + self.context_length, + self.hp.attention_head_count_kv, + self.hp.attn_head_dim, + ], + dtype=self.config.activation_dtype, + device=self.device, + ) + xv_temp = torch.empty( + [ + bs, + self.context_length, + self.hp.attention_head_count_kv, + self.hp.attn_head_dim, + ], + dtype=self.config.activation_dtype, + device=self.device, + ) + + h = self.token_embedding(tokens) + self.trace_tensor("mixtral.token_embedding", h) + + # Iterate over attention blocks. + for block_idx, block in enumerate(self.attn_blocks): + if block_idx == 0: + self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h) + + if block.__class__.__name__ == "PagedLlamaAttentionBlock": + h = block( + h, + start_positions=start_positions, + embedding=self.attention_embedding, + embedding_batch_mask=embedding_batch_mask, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) + self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) + elif block.__class__.__name__ == "SparseMoeBlock": + h = block( + h, + ) + self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) + + h = self.output_norm(h) + logits = self.output_lm_head(h) + return logits diff --git a/sharktank/sharktank/models/mixtral/mixtral_ref.py b/sharktank/sharktank/models/mixtral/mixtral_ref.py new file mode 100644 index 000000000..392f60a25 --- /dev/null +++ b/sharktank/sharktank/models/mixtral/mixtral_ref.py @@ -0,0 +1,164 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from ...layers import * +from ...types import Theta + +__all__ = [ + "RefLlamaModelConfig", + "DirectCacheMixtralModelV1", +] + + +################################################################################ +# Config +################################################################################ + + +@dataclass +class RefLlamaModelConfig: + hp: configs.LlamaHParams + + # Dtype to use for general FP activations not otherwise configured. 
+ activation_dtype: torch.dtype = torch.float16 + + +################################################################################ +# Models +################################################################################ + + +class DirectCacheMixtralModelV1(ThetaLayer): + """Simple Mixtral Model with a direct lookup KV cache for batch-1 inference.""" + + def __init__(self, theta: Theta, config: RefLlamaModelConfig): + super().__init__(theta) + hp = config.hp + self.config = config + self.hp = hp + self.activation_dtype = config.activation_dtype + self.add_module( + "token_embedding", + TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype), + ) + self.add_module( + "attention_embedding", + RotaryEmbeddingLayer( + rope_dimension_count=hp.rope_dimension_count, + rope_freq_base=hp.rope_freq_base, + max_seqlen=hp.context_length, + ), + ) + self.add_module( + "output_norm", + RMSNormLayer( + theta("output_norm"), epsilon=self.hp.attention_layer_norm_rms_epsilon + ), + ) + self.add_module("output_lm_head", LinearLayer(theta("output"))) + + self.attn_blocks = nn.ModuleList() + + for n in range(hp.block_count): + self.attn_blocks.append( + LlamaAttentionBlock( + theta("blk", n), + embedding=self.attention_embedding, + head_count=hp.attention_head_count, + head_dim=hp.rope_dimension_count, + head_count_kv=hp.attention_head_count_kv, + rms_epsilon=hp.attention_layer_norm_rms_epsilon, + ) + ) + self.attn_blocks.append( + SparseMoeBlock( + theta("blk", n), + expert_count=hp.expert_count, + expert_used_count=hp.expert_used_count, + rms_epsilon=hp.attention_layer_norm_rms_epsilon, + ) + ) + + def create_cache(self, bs: int) -> list[torch.Tensor]: + return [ + torch.empty( + ( + bs, + self.hp.context_length, + self.hp.attention_head_count, + self.hp.rope_dimension_count, + ), + dtype=self.activation_dtype, + ) + for _ in range(self.hp.block_count * 2) + ] + + def forward( + self, + tokens: torch.Tensor, + start_index: int, + *, + return_logits: bool = False, + local_kv_cache: list[torch.Tensor], + ): + bs, sl = tokens.shape + h = self.token_embedding(tokens) + dtype = h.dtype + self.trace_tensor("mixtral.token_embedding", h) + + # Compute attention mask. + attention_mask = None + if sl > 1: + # Use the smallest value like HF as opposed to -inf like original. + # A little bit easier for some systems. + attention_mask = torch.full( + (1, 1, sl, sl), torch.finfo(dtype).min, dtype=dtype + ) + attention_mask = torch.triu( + attention_mask, diagonal=start_index + 1 + ).type_as(h) + + # block_count = (total no. of blocks)/2, excluding MoE blocks + block_count = len(self.attn_blocks) // 2 + # print('local_kv_cache, #attn_blocks', len(local_kv_cache), block_count) + # Iterate over attention + MoE blocks. 
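The causal prefill mask constructed a few lines above, shown in isolation (sequence length and dtype are placeholders):

```python
# Causal mask as built in DirectCacheMixtralModelV1.forward: an upper triangle
# of the dtype's most negative finite value, added to the attention logits so
# masked positions vanish after softmax.
import torch

sl, start_index, dtype = 4, 0, torch.float16
attention_mask = torch.full((1, 1, sl, sl), torch.finfo(dtype).min, dtype=dtype)
attention_mask = torch.triu(attention_mask, diagonal=start_index + 1)
# Row i of attention_mask[0, 0] leaves positions 0..i untouched and masks i+1..sl-1.
print(attention_mask[0, 0])
```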
+ for block_idx, block in enumerate(self.attn_blocks): + # print("block_idx, block", block_idx, block) + if block_idx == 0: + self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h) + + if block.__class__.__name__ == "LlamaAttentionBlock": + attn_block_idx = block_idx // 2 + block_cache_k = local_kv_cache[attn_block_idx] + block_cache_v = local_kv_cache[block_count + attn_block_idx] + h = block( + h, + cache_k=block_cache_k, + cache_v=block_cache_v, + start_index=start_index, + attention_mask=attention_mask, + ) + self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) + elif block.__class__.__name__ == "SparseMoeBlock": + h = block( + h, + ) + self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) + + h = self.output_norm(h) + logits = self.output_lm_head(h) + + if return_logits: + return h + else: + last_step = logits[:, -1, :] + token = torch.argmax(last_step, keepdim=True, dim=1) + return token.to(tokens.dtype) diff --git a/sharktank/tests/models/llama/attention_test.py b/sharktank/tests/models/llama/attention_test.py index cf10c856f..bb5eb254d 100644 --- a/sharktank/tests/models/llama/attention_test.py +++ b/sharktank/tests/models/llama/attention_test.py @@ -11,7 +11,7 @@ from sharktank.models.llama.testing import * from sharktank.layers.rotary_embedding import RotaryEmbeddingLayer -from sharktank.models.llama.llama import PagedLlamaAttentionBlock, PagedKVCache +from sharktank.models.llama.llama import AttentionFFNBlock, PagedKVCache from sharktank import ops from transformers.models.llama.modeling_llama import ( @@ -36,6 +36,7 @@ def test(self): block_seq_stride = 1 rms_epsilon = 0.01 rope_dimension_count = 100 + rope_freq_base = 10000.0 max_seq_len = 2048 attention_block_theta = make_attention_block_theta( feature_dim=head_count * head_dim, ffn_dim=ffn_dim, dtype=torch.float32 @@ -49,7 +50,7 @@ def test(self): device="cpu", dtype=torch.float32, ) - attention_block = PagedLlamaAttentionBlock( + attention_block = AttentionFFNBlock( theta=attention_block_theta, block_index=block_index, cache=paged_kv_cache, @@ -61,6 +62,7 @@ def test(self): ) attention_embedding = RotaryEmbeddingLayer( rope_dimension_count=rope_dimension_count, + rope_freq_base=rope_freq_base, max_seqlen=max_seq_len, device="cpu", use_hf=True, diff --git a/sharktank/tests/models/llama/moe_block_test.py b/sharktank/tests/models/llama/moe_block_test.py new file mode 100644 index 000000000..e04ca11fd --- /dev/null +++ b/sharktank/tests/models/llama/moe_block_test.py @@ -0,0 +1,35 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest +from typing import List + +import torch +from shark_turbine.aot import * +from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch +from sharktank.layers.mixture_of_experts_block import SparseMoeBlock +from sharktank import ops + + +class SparseMoeBlockTest(unittest.TestCase): + @unittest.skip("Skip test until grok implementation") + def test(self): + model = SparseMoeBlock( + theta=make_moe_block_theta()("blk.0"), + expert_count=8, + expert_used_count=2, + rms_epsilon=1e-5, + ) + fxb = FxProgramsBuilder(model) + input = make_rand_torch((2, 16, 6144)) + + @fxb.export_program(name="moe_block", args=(input,)) + def _(model, input: torch.Tensor) -> torch.Tensor: + return model(input) + + +if __name__ == "__main__": + unittest.main()
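The new MoE block test is skipped for now; a hedged sketch of running it alongside the updated attention test once it is enabled (assumes pytest is installed in the development environment):

```python
# Run the new MoE block test and the updated attention test from the repo root.
import pytest

pytest.main(
    [
        "sharktank/tests/models/llama/moe_block_test.py",
        "sharktank/tests/models/llama/attention_test.py",
        "-v",
    ]
)
```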