diff --git a/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py
index 27bfddfa2..1a2633b0e 100644
--- a/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py
+++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py
@@ -68,35 +68,35 @@ def write_config(request, pre_process_model):
     batch_sizes = request.param["batch_sizes"]
     prefix_sharing_algorithm = request.param["prefix_sharing_algorithm"]
 
-    logger.info("Writing config file..." + start_log_group("Writing config file"))
-
+    # Construct the new config filename
     config_path = (
         pre_process_model
         / f"{'_'.join(str(bs) for bs in batch_sizes)}_{prefix_sharing_algorithm}.json"
     )
-    config = {
-        "module_name": "module",
-        "module_abi_version": 1,
-        "max_seq_len": 131072,
-        "attn_head_count": 8,
-        "attn_head_dim": 128,
-        "prefill_batch_sizes": batch_sizes,
-        "decode_batch_sizes": batch_sizes,
-        "transformer_block_count": 32,
-        "paged_kv_cache": {
-            "block_seq_stride": 16,
-            "device_block_count": 256,
-            "prefix_sharing_algorithm": prefix_sharing_algorithm,
-        },
-    }
+    # Read the base config file
+    base_config_path = pre_process_model / "config.json"
+    with open(base_config_path, "r") as f:
+        config = json.load(f)
+
+    # Override specific fields
+    config.update(
+        {
+            "prefill_batch_sizes": batch_sizes,
+            "decode_batch_sizes": batch_sizes,
+            "paged_kv_cache": {
+                **config.get(
+                    "paged_kv_cache", {}
+                ),  # Preserve other paged_kv_cache settings
+                "prefix_sharing_algorithm": prefix_sharing_algorithm,
+            },
+        }
+    )
 
     logger.info(f"Saving edited config to: {config_path}\n")
     logger.info(f"Config: {json.dumps(config, indent=2)}")
     with open(config_path, "w") as f:
         json.dump(config, f)
-
-    logger.info("Config file successfully written" + end_log_group())
 
     yield config_path
diff --git a/app_tests/integration_tests/llm/sglang/conftest.py b/app_tests/integration_tests/llm/sglang/conftest.py
index 8543708da..cc79fc365 100644
--- a/app_tests/integration_tests/llm/sglang/conftest.py
+++ b/app_tests/integration_tests/llm/sglang/conftest.py
@@ -64,21 +64,6 @@ def pre_process_model(request, tmp_path_factory):
         device_settings,
     )
 
-    config = {
-        "module_name": "module",
-        "module_abi_version": 1,
-        "max_seq_len": 131072,
-        "attn_head_count": 8,
-        "attn_head_dim": 128,
-        "prefill_batch_sizes": [1, 4],
-        "decode_batch_sizes": [1, 4],
-        "transformer_block_count": 32,
-        "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
-    }
-    config_path = tmp_dir / "config.json"
-    with open(config_path, "w") as f:
-        json.dump(config, f)
-
     return tmp_dir
diff --git a/app_tests/integration_tests/llm/shortfin/conftest.py b/app_tests/integration_tests/llm/shortfin/conftest.py
index 42e541506..55c9e8bdc 100644
--- a/app_tests/integration_tests/llm/shortfin/conftest.py
+++ b/app_tests/integration_tests/llm/shortfin/conftest.py
@@ -87,26 +87,30 @@ def write_config(request, model_test_dir):
     batch_sizes = request.param["batch_sizes"]
     prefix_sharing_algorithm = request.param["prefix_sharing_algorithm"]
 
+    # Construct the new config filename
     config_path = (
         model_test_dir
         / f"{'_'.join(str(bs) for bs in batch_sizes)}_{prefix_sharing_algorithm}.json"
     )
-    config = {
-        "module_name": "module",
-        "module_abi_version": 1,
-        "max_seq_len": 2048,
-        "attn_head_count": 32,
-        "attn_head_dim": 100,
-        "prefill_batch_sizes": batch_sizes,
-        "decode_batch_sizes": batch_sizes,
-        "transformer_block_count": 26,
-        "paged_kv_cache": {
-            "block_seq_stride": 16,
-            "device_block_count": 256,
-            "prefix_sharing_algorithm": prefix_sharing_algorithm,
-        },
-    }
+    # Read the base config file
+    base_config_path = model_test_dir / "config.json"
+    with open(base_config_path, "r") as f:
+        config = json.load(f)
+
+    # Override specific fields
+    config.update(
+        {
+            "prefill_batch_sizes": batch_sizes,
+            "decode_batch_sizes": batch_sizes,
+            "paged_kv_cache": {
+                **config.get(
+                    "paged_kv_cache", {}
+                ),  # Preserve other paged_kv_cache settings
+                "prefix_sharing_algorithm": prefix_sharing_algorithm,
+            },
+        }
+    )
 
     logger.info(f"Saving edited config to: {config_path}\n")
     logger.info(f"Config: {json.dumps(config, indent=2)}")
     with open(config_path, "w") as f:
diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 6dd9785c3..900c1a9ae 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -7,6 +7,7 @@
 """Export support for the PagedLLMV1 protocol of models."""
 
 import json
+from typing import Any, Dict
 
 import torch
 from iree.turbine.aot import *
@@ -86,17 +87,29 @@ def main():
     else:
         model = PagedLlamaModelV1(dataset.root_theta, llama_config)
 
-    def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):
+    def generate_params_json(
+        hp: LlamaHParams, prefill_bs: list[int], decode_bs: list[int]
+    ) -> Dict[str, Any]:
+        """
+        Generate config.json for shortfin.
+
+
+        For shortfin, we only write attention_head_count_kv because that's all shortfin needs.
+        Note that this is different from hp.attn_head_count when grouped attention shares kvcache between heads.
+        """
         return {
             "module_name": "module",
             "module_abi_version": 1,
             "max_seq_len": hp.context_length,
-            "attn_head_count": hp.attention_head_count,
             "attn_head_dim": hp.attn_head_dim,
             "prefill_batch_sizes": prefill_bs,
             "decode_batch_sizes": decode_bs,
             "transformer_block_count": hp.block_count,
-            "block_seq_stride": llama_config.block_seq_stride,
+            "paged_kv_cache": {
+                "attention_head_count_kv": hp.attention_head_count_kv,
+                "block_seq_stride": llama_config.block_seq_stride,
+                "device_block_count": 256,  # so that this makes its way into the config file & can be edited.
+            },
         }
 
     # Unrolling cache updates by batch row makes dynamo sad without an
diff --git a/shortfin/python/shortfin_apps/llm/components/config_struct.py b/shortfin/python/shortfin_apps/llm/components/config_struct.py
index 7caed5d07..8fefa0a12 100644
--- a/shortfin/python/shortfin_apps/llm/components/config_struct.py
+++ b/shortfin/python/shortfin_apps/llm/components/config_struct.py
@@ -11,20 +11,20 @@ In a typical transformer model, the KV cache is organized similar to (mapped to
 our parameter names below):
     k = tensor.empty(transformer_block_count, batch_size, seq,
-                     attn_head_count, attn_head_dim)
+                     attn_head_count_kv, attn_head_dim)
     v = ...
 
 For context, a popular model has parameters of:
     attn_dtype_size = 2  # (fp16)
     max_seq_len = 2048
     transformer_block_count = 32
-    attn_head_count = 32
+    attn_head_count_kv = 32
     attn_head_dim = 128  # (dim / head_count)
 
 If paging, then we primarily care about the organization of a single block, where
 a block represents a single position in the sequence for a single item in the batch.
 Therefore, it will be organized like:
 
-    block = torch.empty(transformer_block_count, 2, attn_head_count, attn_head_dim)
+    block = torch.empty(transformer_block_count, 2, attn_head_count_kv, attn_head_dim)
 
 In this scenario, we declare that one block holds the KV cache for all transformer
 block layers because it reduces the accounting. As such, for the above example,
@@ -80,10 +80,15 @@ def _decode_dtype(name: str) -> sfnp.DType:
 class PagedKVCacheParams:
     """Parameters for the paged KV cache."""
 
-    # Position stride per attention block
+    # Tokens per page.
     block_seq_stride: int
 
+    # Number of attention heads per block. This can be different from the model's
+    # attention head count due to sharing.
+    attention_head_count_kv: int
+
     # Size of the cache on each device.
+    # Default: 256
     device_block_count: int
 
     prefix_sharing_algorithm: str = "none"  # currently supporting none and trie
@@ -92,19 +97,23 @@
 @dataclass_json(undefined=Undefined.RAISE)
 @dataclass
 class ModelParams:
-    """Parameters for a specific compiled model, sufficient to do cache planning and
-    invocations."""
+    """
+    Parameters for a specific compiled model, sufficient to do cache planning and
+    invocations.
+
+    Compatibility should be maintained with function generate_params_json in
+
+    sharktank/sharktank/examples/export_paged_llm_v1.py
+    """
 
     # Maximum length of a sequence including prompt and output.
     max_seq_len: int
 
-    # Number of transformer blocks.
+    # Number of transformer layers (aka attention blocks / transformer blocks).
     transformer_block_count: int
 
-    # Number of attention heads per block.
-    attn_head_count: int
-
-    # Dimensionality of each attention head
+    # Dimensionality of each attention head. This is the dimensionality of the
+    # key and value vectors. AKA rope_dimension_count from the GGUF props.
     attn_head_dim: int
 
     # Batch sizes that the prefill stage is compiled for. These are expected to be
@@ -159,7 +168,7 @@ def paged_kv_unit_size_elements(self) -> int:
         size = 1
         size *= self.transformer_block_count
         size *= 2  # K and V cache line
-        size *= self.attn_head_count
+        size *= self.paged_kv_cache.attention_head_count_kv
         size *= self.attn_head_dim
         return size
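
Illustrative note (not part of the patch above): the sketch below shows how a config.json produced by generate_params_json could be read back and the per-position KV cache unit size recomputed from the nested paged_kv_cache section, mirroring ModelParams.paged_kv_unit_size_elements. The file path and the helper name kv_unit_size_elements are assumptions for illustration only.

    # Hypothetical sketch: recompute the per-position KV cache unit size from the
    # JSON written by generate_params_json (field names taken from the diff above).
    import json

    def kv_unit_size_elements(config: dict) -> int:
        size = 1
        size *= config["transformer_block_count"]
        size *= 2  # K and V cache lines
        size *= config["paged_kv_cache"]["attention_head_count_kv"]
        size *= config["attn_head_dim"]
        return size

    with open("config.json", "r") as f:  # assumed location of the exported config
        cfg = json.load(f)

    # One page covers block_seq_stride token positions, so a full page holds:
    page_elements = kv_unit_size_elements(cfg) * cfg["paged_kv_cache"]["block_seq_stride"]
    print(page_elements)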