From 926d48381915d85f8ae591830ad4332fb4768716 Mon Sep 17 00:00:00 2001 From: Cedar Date: Mon, 2 Dec 2024 13:59:32 -0800 Subject: [PATCH] rename cache_type to prefix_sharing_algorithm --- app_tests/integration_tests/llm/shortfin/conftest.py | 6 +++++- .../python/shortfin_apps/llm/components/config_struct.py | 2 +- shortfin/python/shortfin_apps/llm/components/service.py | 6 +++--- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/app_tests/integration_tests/llm/shortfin/conftest.py b/app_tests/integration_tests/llm/shortfin/conftest.py index 0d40119c7..9724485cd 100644 --- a/app_tests/integration_tests/llm/shortfin/conftest.py +++ b/app_tests/integration_tests/llm/shortfin/conftest.py @@ -83,7 +83,11 @@ def model_test_dir(request, tmp_path_factory): "prefill_batch_sizes": batch_sizes, "decode_batch_sizes": batch_sizes, "transformer_block_count": 26, - "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, + "paged_kv_cache": { + "block_seq_stride": 16, + "device_block_count": 256, + "prefix_sharing_algorithm": "none", + }, } logger.info(f"Saving edited config to: {edited_config_path}\n") logger.info(f"Config: {json.dumps(config, indent=2)}") diff --git a/shortfin/python/shortfin_apps/llm/components/config_struct.py b/shortfin/python/shortfin_apps/llm/components/config_struct.py index eb93f017e..7caed5d07 100644 --- a/shortfin/python/shortfin_apps/llm/components/config_struct.py +++ b/shortfin/python/shortfin_apps/llm/components/config_struct.py @@ -86,7 +86,7 @@ class PagedKVCacheParams: # Size of the cache on each device. device_block_count: int - cache_type: str = "base" # currently supporting base and trie + prefix_sharing_algorithm: str = "none" # currently supporting none and trie @dataclass_json(undefined=Undefined.RAISE) diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index 8b9c39c68..ed1be03db 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -68,19 +68,19 @@ def __init__( page_pool = PagePool( devices=self.main_fiber.devices_dict.values(), config=page_pool_config ) - if model_params.paged_kv_cache.cache_type == "trie": + if model_params.paged_kv_cache.prefix_sharing_algorithm == "trie": self.page_cache = TriePagedAttentionCache( page_pool=page_pool, tokens_per_page=model_params.paged_kv_cache.block_seq_stride, ) - elif model_params.paged_kv_cache.cache_type == "base": + elif model_params.paged_kv_cache.prefix_sharing_algorithm == "none": self.page_cache = BasePagedAttentionCache( page_pool=page_pool, tokens_per_page=model_params.paged_kv_cache.block_seq_stride, ) else: raise ValueError( - f"Unknown model_params.paged_kv_cache.cache_type {model_params.paged_kv_cache.cache_type}. Currently only supporting 'trie' and 'base'." + f"Unknown model_params.paged_kv_cache.prefix_sharing_algorithm {model_params.paged_kv_cache.prefix_sharing_algorithm}. Currently only supporting 'trie' and 'none'." ) self.program_isolation = PROG_ISOLATIONS[program_isolation]