TriePagedAttentionCache (#632)
feat: Add TriePagedAttentionCache with initial implementation

Added TriePagedAttentionCache as an optional prefix sharing algorithm,
selectable via:
`config["paged_kv_cache"]["prefix_sharing_algorithm"] = "trie"`
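For example, a model config that opts into the trie cache could look like this (illustrative sketch; the `paged_kv_cache` keys mirror the test config edited in `conftest.py` below, and other keys are omitted):

```python
# Minimal sketch of a server config selecting the trie cache.
config = {
    "transformer_block_count": 26,
    "paged_kv_cache": {
        "block_seq_stride": 16,       # tokens held per KV-cache page
        "device_block_count": 256,    # pages allocated on each device
        "prefix_sharing_algorithm": "trie",  # "none" (default) or "trie"
    },
}
```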

Current Status:
- Basic implementation and unit tests complete
- Integration test cases for both the Base and Trie implementations, with
  the trie cases xfailed pending cache allocation improvements
- BasePagedAttentionCache remains the default

Next Steps:
To achieve full functionality, we need to support cache re-allocations
to extend the associated tokens & pages.
renxida authored Dec 4, 2024
1 parent ddbabba commit de4d2fe
Showing 10 changed files with 943 additions and 40 deletions.
7 changes: 6 additions & 1 deletion app_tests/integration_tests/llm/shortfin/conftest.py
@@ -51,6 +51,7 @@ def model_test_dir(request, tmp_path_factory):
     tokenizer_id = request.param["tokenizer_id"]
     settings = request.param["settings"]
     batch_sizes = request.param["batch_sizes"]
+    prefix_sharing_algorithm = request.param["prefix_sharing_algorithm"]

     tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test")
     hf_home = os.environ.get("HF_HOME", None)
@@ -83,7 +84,11 @@ def model_test_dir(request, tmp_path_factory):
"prefill_batch_sizes": batch_sizes,
"decode_batch_sizes": batch_sizes,
"transformer_block_count": 26,
"paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
"paged_kv_cache": {
"block_seq_stride": 16,
"device_block_count": 256,
"prefix_sharing_algorithm": prefix_sharing_algorithm,
},
}
logger.info(f"Saving edited config to: {edited_config_path}\n")
logger.info(f"Config: {json.dumps(config, indent=2)}")
16 changes: 14 additions & 2 deletions app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py
@@ -65,16 +65,28 @@ def do_generate(prompt, port):
 @pytest.mark.parametrize(
     "model_test_dir,llm_server",
     [
-        (
+        pytest.param(
             {
                 "repo_id": "SlyEcho/open_llama_3b_v2_gguf",
                 "model_file": "open-llama-3b-v2-f16.gguf",
                 "tokenizer_id": "openlm-research/open_llama_3b_v2",
                 "settings": CPU_SETTINGS,
                 "batch_sizes": [1, 4],
+                "prefix_sharing_algorithm": "trie",
             },
             {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS},
-        )
+        ),
+        pytest.param(
+            {
+                "repo_id": "SlyEcho/open_llama_3b_v2_gguf",
+                "model_file": "open-llama-3b-v2-f16.gguf",
+                "tokenizer_id": "openlm-research/open_llama_3b_v2",
+                "settings": CPU_SETTINGS,
+                "batch_sizes": [1, 4],
+                "prefix_sharing_algorithm": "none",
+            },
+            {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS},
+        ),
     ],
     indirect=True,
 )
2 changes: 2 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/config_struct.py
@@ -86,6 +86,8 @@ class PagedKVCacheParams:
     # Size of the cache on each device.
     device_block_count: int

+    prefix_sharing_algorithm: str = "none"  # currently supported: "none" and "trie"
+

 @dataclass_json(undefined=Undefined.RAISE)
 @dataclass
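A plausible way for the service to consume this new field is to branch on it when constructing the cache (hypothetical sketch; only the class names and the `prefix_sharing_algorithm` field come from this commit, and the constructor arguments are assumptions):

```python
def build_cache(params, page_pool):
    # `params` is a PagedKVCacheParams; the cache classes are assumed to be
    # imported from this repo's kvcache components.
    if params.prefix_sharing_algorithm == "trie":
        return TriePagedAttentionCache(
            page_pool=page_pool, tokens_per_page=params.block_seq_stride
        )
    return BasePagedAttentionCache(
        page_pool=page_pool, tokens_per_page=params.block_seq_stride
    )
```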
(file path not rendered in this view; this hunk modifies the PageAllocation / BasePagedAttentionCache component)
@@ -34,17 +34,28 @@ def pages(self) -> List[PageInfo]:
         pass

     @abstractmethod
-    def publish_pages(self, up_to_page_index) -> None:
-        """Makes pages[0:up_to_page_index] available to other requests."""
+    def publish_pages_for_tokens(
+        self, tokens, *, publish_incomplete_page=False
+    ) -> None:
+        """
+        Makes pages available to other requests. For details, reference the derived class in trie_attention_cache.py.
+        """
         pass

     @abstractmethod
     def release_pages(self) -> None:
         """Releases the allocation's reference to pages."""
         pass

+    @abstractmethod
+    def extend_allocation(self, tokens, *, extra_token_slots=0) -> None:
+        """
+        Extends the allocation to include additional tokens. For details, reference the derived class in trie_attention_cache.py.
+        """
+        pass


-class BasePageAttentionCacheAllocation(PageAllocation):
+class BasePagedAttentionCacheAllocation(PageAllocation):
     """Represents a page allocation in the cache."""

     def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"):
@@ -56,18 +67,33 @@ def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"):
     def pages(self) -> List[PageInfo]:
         return list(self._pages)

-    def publish_pages(self, up_to_page_index) -> None:
+    def publish_pages_for_tokens(
+        self, tokens, *, publish_incomplete_page=False
+    ) -> None:
         pass

     def release_pages(self) -> None:
         if self._is_released:
             logger.warning("Releasing already-released allocation")
             return
-        self._cache.page_pool.release_pages(self._pages)
+        self._cache.page_pool.free_pages(self._pages)
         self._is_released = True

+    def extend_allocation(self, tokens, *, extra_token_slots=0) -> None:
+        # assert old tokens are a prefix of incoming tokens
+        # if we don't have enough pages to hold the tokens, we need to allocate more pages
+        token_count = len(tokens) + extra_token_slots
+        pages_needed = math.ceil(token_count / self._cache.tokens_per_page)
+        if pages_needed > len(self._pages):
+            new_pages = self._cache.page_pool.acquire_free_pages(
+                pages_needed - len(self._pages)
+            )
+            if new_pages is None:
+                raise CacheAllocationFailure()
+            self._pages += tuple(new_pages)

     def __repr__(self) -> str:
-        return f"BasePageAttentionCacheAllocation(pages={self._pages}, cache={self._cache})"
+        return f"BasePagedAttentionCacheAllocation(pages={self._pages}, cache={self._cache})"


class BasePagedAttentionCache:
@@ -117,4 +143,4 @@ def acquire_pages_for_tokens(
         if pages is None:
             raise CacheAllocationFailure()

-        return BasePageAttentionCacheAllocation(pages, cache=self)
+        return BasePagedAttentionCacheAllocation(pages, cache=self)
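Taken together, the new allocation API follows a simple lifecycle (hedged usage sketch; the full `acquire_pages_for_tokens` signature is not shown in this view, and `cache` is assumed to be a constructed BasePagedAttentionCache):

```python
tokens = list(range(40))                       # 40 prompt tokens (illustrative)
alloc = cache.acquire_pages_for_tokens(tokens)  # raises CacheAllocationFailure if the pool is exhausted
# With block_seq_stride=16, 40 tokens + 1 extra slot need ceil(41/16) = 3 pages,
# which is exactly what extend_allocation computes above:
alloc.extend_allocation(tokens, extra_token_slots=1)
alloc.publish_pages_for_tokens(tokens)          # no-op in the base cache; the trie cache shares prefixes here
alloc.release_pages()                           # hands the pages back to the PagePool
```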
(file path not rendered in this view; this hunk modifies the PagePool component)
@@ -86,7 +86,7 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig
             for i in range(self.config.alloc_page_count)
         ]

-        self.attn_page_free = list(self.attn_page_entries)
+        self.available_pages = list(self.attn_page_entries)

         # Initialize a page table on each device.
         page_table_shape = [
@@ -108,14 +108,14 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig

def acquire_free_pages(self, count: int) -> list[PageInfo] | None:
with self._lock:
available = len(self.attn_page_free)
available = len(self.available_pages)
if count > available:
return None
return [self.attn_page_free.pop() for _ in range(count)]
return [self.available_pages.pop() for _ in range(count)]

def release_pages(self, pages: list[PageInfo]):
def free_pages(self, pages: list[PageInfo]):
with self._lock:
self.attn_page_free.extend(pages)
self.available_pages.extend(pages)

def copy_page(self, src_page: PageInfo) -> PageInfo:
"""
@@ -148,7 +148,7 @@ def copy_page(self, src_page: PageInfo) -> PageInfo:

def __repr__(self):
# No need to lock for repr (list is internally synchronized).
free_pages = len(self.attn_page_free)
free_pages = len(self.available_pages)
total_pages = len(self.attn_page_entries)
return (
f"PagePool({total_pages - free_pages}/{total_pages} pages in use: "
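The renamed PagePool methods pair up as an acquire/free contract (short sketch using only the calls shown in this diff; pool construction and the KV writes are elided, and `CacheAllocationFailure` is assumed imported):

```python
pages = pool.acquire_free_pages(4)   # returns None if fewer than 4 pages are free
if pages is None:
    raise CacheAllocationFailure()   # mirrors how acquire_pages_for_tokens reacts above
try:
    pass                             # ... write KV entries into the pages ...
finally:
    pool.free_pages(pages)           # formerly release_pages(); appends back to available_pages
```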
(The remaining changed files, including the TriePagedAttentionCache implementation and its unit tests, did not render in this view.)