From 8cd3f850f0a7864d9e004483fcf4bbea08d3e21a Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Mon, 2 Dec 2024 14:14:27 -0500 Subject: [PATCH] Implement PageAllocation as a handle into a PagedAttentionCache, allowing publishing and releasing an allocation via handle rather than cache (#608) Deinitialization looks wonky for now. Will test extensively to get deinit right once I merge #600 Closes #607 --- .../shortfin_apps/llm/components/__init__.py | 0 .../llm/components/kvcache/__init__.py | 0 .../kvcache/base_attention_cache.py | 86 +++++++--- .../shortfin_apps/llm/components/messages.py | 37 ++-- .../shortfin_apps/llm/components/service.py | 59 ++++--- .../kvcache/base_attention_cache_test.py | 159 ++++++++++++++++++ 6 files changed, 266 insertions(+), 75 deletions(-) create mode 100644 shortfin/python/shortfin_apps/llm/components/__init__.py create mode 100644 shortfin/python/shortfin_apps/llm/components/kvcache/__init__.py create mode 100644 shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py diff --git a/shortfin/python/shortfin_apps/llm/components/__init__.py b/shortfin/python/shortfin_apps/llm/components/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/__init__.py b/shortfin/python/shortfin_apps/llm/components/kvcache/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index 0007000bc..73134903c 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -8,9 +8,66 @@ Base class for kv caches. """ -from typing import List +from typing import List, Iterable, Protocol from .page_pool import PageInfo import math +from abc import ABC, abstractmethod +from .page_pool import PagePool + +# logging +import logging + +logger = logging.getLogger(__name__) + +# exception for when cache allocation failed +class CacheAllocationFailure(Exception): + pass + + +class PageAllocation(ABC): + """Abstract base class for page allocations in the cache.""" + + @property + @abstractmethod + def pages(self) -> List[PageInfo]: + """Returns the list of pages that were allocated.""" + pass + + @abstractmethod + def publish_pages(self, up_to_page_index) -> None: + """Makes pages[0:up_to_page_index] available to other requests.""" + pass + + @abstractmethod + def release_pages(self) -> None: + """Releases the allocation's reference to pages.""" + pass + + +class BasePageAttentionCacheAllocation(PageAllocation): + """Represents a page allocation in the cache.""" + + def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"): + self._pages = tuple(pages) + self._cache = cache + self._is_released = False + + @property + def pages(self) -> List[PageInfo]: + return list(self._pages) + + def publish_pages(self, up_to_page_index) -> None: + pass + + def release_pages(self) -> None: + if self._is_released: + logger.warning("Releasing already-released allocation") + return + self._cache.page_pool.release_pages(self._pages) + self._is_released = True + + def __rerp__(self) -> str: + return f"BasePageAttentionCacheAllocation(pages={self._pages}, cache={self._cache})" class BasePagedAttentionCache: @@ -33,13 +90,13 @@ class BasePagedAttentionCache: - Reference counting prevents eviction of in-use pages """ - def __init__(self, page_pool, tokens_per_page): + def __init__(self, page_pool: PagePool, tokens_per_page: int): self.page_pool = page_pool self.tokens_per_page = tokens_per_page def acquire_pages_for_tokens( self, tokens: List[int], extra_token_slots: int = 1 - ) -> tuple[list[PageInfo], int]: + ) -> PageAllocation: """ Given a list of tokens, return a list of pages and a start position to continue generation from. @@ -57,24 +114,7 @@ def acquire_pages_for_tokens( pages_needed = math.ceil(token_count / self.tokens_per_page) pages = self.page_pool.acquire_free_pages(pages_needed) - n_cached_tokens = 0 - - return pages, n_cached_tokens - - def publish_pages(self, tokens, pages) -> None: - """ - Given a list of tokens and pages containing KV corresponding to these tokens, make these pages available to other requests. - - Associates the tokens with the pages, and mark them as done writing. - - It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)]. - """ - - pass # the base implementation doesn't cache unfinished requests. + if pages is None: + raise CacheAllocationFailure() - def release_pages(self, tokens, pages): - """ - Decrement reference count for these pages. When reference count is zero, they will be elegible for eviction. - """ - # in the base implementation, the pages can be owned by 1 request max, so they can be instantly release - self.page_pool.release_pages(pages) + return BasePageAttentionCacheAllocation(pages, cache=self) diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index c3e6fe34b..c03900782 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -9,7 +9,7 @@ import shortfin as sf import shortfin.array as sfnp -from .kvcache.base_attention_cache import BasePagedAttentionCache +from .kvcache.base_attention_cache import BasePagedAttentionCache, PageAllocation from .kvcache.page_pool import PageInfo @@ -43,7 +43,7 @@ def __init__(self, phase: InferencePhase, input_token_ids: list[int]): # Cache pages that have been locked for this request. self._cache: BasePagedAttentionCache | None = None - self.locked_pages: list[PageInfo] | None = None + self.allocation: PageAllocation | None = None def reset(self, phase: InferencePhase): """Resets all per request state in preparation for an subsequent execution.""" @@ -52,35 +52,22 @@ def reset(self, phase: InferencePhase): self.return_all_logits = False self.return_host_array = True self.result_logits = None + self.allocation.release_pages() + self.allocation = None def cache_page_indices(self, max_len: int) -> list[int]: - if not self.locked_pages: + if not self.allocation: return [] - indices = [p.index for p in self.locked_pages] - if len(indices) > max_len: - return indices[0:max_len] + indices = [p.index for p in self.allocation.pages[:max_len]] return indices + def publish_allocated_pages(self, up_to_page_index: int): + assert self.allocation + self.allocation.publish_pages(up_to_page_index) + def free_cache_pages(self): - cache = self._cache - if cache: - pages = self.locked_pages - self._cache = None - self.locked_pages = None - cache.release_pages(self.input_token_ids, pages) - - def lock_initial_cache_pages( - self, cache: BasePagedAttentionCache, pages: list[PageInfo] - ): - assert not self._cache - self._cache = cache - self.locked_pages = pages - - def lock_new_cache_pages( - self, cache: BasePagedAttentionCache, pages: list[PageInfo] - ): - assert self._cache is cache - self.locked_pages.extend(pages) + if self.allocation: + self.allocation.release_pages() class StrobeMessage(sf.Message): diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index 8d3cc1424..2f942aec7 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -11,8 +11,12 @@ import shortfin as sf import shortfin.array as sfnp -from .kvcache.base_attention_cache import BasePagedAttentionCache -from .kvcache.page_pool import PagePoolConfig, PagePool +from .kvcache.base_attention_cache import ( + BasePagedAttentionCache, + CacheAllocationFailure, + PageAllocation, +) +from .kvcache.page_pool import PagePoolConfig, PagePool, PageInfo from .config_struct import ModelParams from .manager import SystemManager from .messages import InferenceExecRequest, InferencePhase, StrobeMessage @@ -229,16 +233,17 @@ def board_prefills(self, cache: BasePagedAttentionCache): len(prefill_request.input_token_ids) / self.page_seq_stride ) # allocate kv cache pages - pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens( - prefill_request.input_token_ids, - extra_token_slots=0, # prefill needs no extra kvcache slots to write to - ) - if pages is None: + try: + allocation = cache.acquire_pages_for_tokens( + prefill_request.input_token_ids, + extra_token_slots=0, # prefill needs no extra kvcache slots to write to + ) + except CacheAllocationFailure: logger.debug("Cannot fulfill request for %d pages", needed_pages) continue - else: - logger.debug("Allocated %d cache pages to request", len(pages)) - prefill_request.lock_initial_cache_pages(cache, pages) + logger.debug(f"Successfully acquired allocation: {allocation}") + prefill_request.free_cache_pages() + prefill_request.allocation = allocation # Can flight this request. exec_process.exec_requests.append(prefill_request) @@ -266,26 +271,20 @@ def board_decodes(self, cache: BasePagedAttentionCache): if len(exec_process.exec_requests) >= self.ideal_batch_size: break incoming_token_count = len(decode_request.input_token_ids) - needed_pages = math.ceil( - (decode_request.start_position + incoming_token_count) - / self.page_seq_stride - ) - if needed_pages > len(decode_request.locked_pages): - # allocate kv cache pages - pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens( + + try: + allocation = cache.acquire_pages_for_tokens( decode_request.input_token_ids, extra_token_slots=1, # need 1 extra slot to write result. ) - if pages is None: - logger.debug( - "Cannot fulfill decode request for %d pages", needed_pages - ) - continue - else: - logger.debug( - "Allocated %d cache pages to decode request", len(pages) - ) - decode_request.lock_new_cache_pages(cache, pages) + except CacheAllocationFailure: + logger.debug( + "Cannot fulfill request for %d tokens", + len(decode_request.input_token_ids), + ) + + decode_request.free_cache_pages() + decode_request.allocation = allocation # Can flight this request. exec_process.exec_requests.append(decode_request) @@ -438,6 +437,12 @@ async def run(self): # Invoke. Logits are of shape [bs, bsl, d]. (logits,) = await fn(*args, fiber=self.fiber) + # publish cache pages + for r in self.exec_requests: + total_tokens = r.start_position + len(r.input_token_ids) + number_of_complete_pages = total_tokens // seq_stride + r.publish_allocated_pages(number_of_complete_pages) + # Return results. for i in range(req_count): req = self.exec_requests[i] diff --git a/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py new file mode 100644 index 000000000..113da6912 --- /dev/null +++ b/shortfin/tests/apps/llm/components/kvcache/base_attention_cache_test.py @@ -0,0 +1,159 @@ +import pytest +import threading +import queue +import random +import time +from collections import defaultdict +from unittest.mock import Mock +from dataclasses import dataclass +from typing import List, Optional, Set + +from shortfin_apps.llm.components.kvcache.base_attention_cache import ( + BasePagedAttentionCache, + BasePageAttentionCacheAllocation, + CacheAllocationFailure, +) +from shortfin_apps.llm.components.kvcache.page_pool import PagePool, PageInfo + +TEST_PAGE_SIZE = 16 +TEST_POOL_CAPACITY = 10 + + +class MockPagePool(PagePool): + def __init__(self, total_pages: int): + self._queue = queue.Queue() + for i in range(total_pages): + page = PageInfo(index=i, pool=self, token_offset=0, token_count=0) + self._queue.put(page) + + def acquire_free_pages(self, count: int) -> List[PageInfo]: + try: + return [self._queue.get_nowait() for _ in range(count)] + except queue.Empty: + return None + + def release_pages(self, pages): + for page in pages: + self._queue.put(page) + + +@pytest.fixture +def page_pool(): + return MockPagePool(total_pages=TEST_POOL_CAPACITY) + + +@pytest.fixture +def cache(page_pool): + return BasePagedAttentionCache(page_pool=page_pool, tokens_per_page=TEST_PAGE_SIZE) + + +# fmt: off +@pytest.mark.parametrize( + "tokens,expected_pages,case_name", + [ # Tokens Pages Case Name + ([], 0, "empty_token_list"), + (list(range(TEST_PAGE_SIZE // 2)), 1, "partial_page"), + (list(range(TEST_PAGE_SIZE)), 1, "exact_page"), + (list(range(TEST_PAGE_SIZE + 1)), 2, "just_over_one_page"), + (list(range(TEST_PAGE_SIZE * 2)), 2, "multiple_exact_pages"), + (list(range(TEST_PAGE_SIZE * 2 + 1)), 3, "multiple_pages_with_remainder"), + (list(range(TEST_PAGE_SIZE * 3)), 3, "three_exact_pages"), + (list(range(1)), 1, "single_token"), + (list(range(TEST_PAGE_SIZE - 1)), 1, "almost_full_page"), + ], +) +# fmt: on +def test_allocation_sizes(cache, tokens, expected_pages, case_name): + allocation = cache.acquire_pages_for_tokens(tokens) + pages = allocation.pages + assert len(pages) == expected_pages, f"Failed for case: {case_name}" + allocation.release_pages() + + +# fmt: off +@pytest.mark.parametrize( + "num_workers,pages_per_worker,expect_failure,case_name", + [ # Workers Pages Failure Case name + (2, 1, False, "basic_concurrent"), # Basic concurrent access + (5, 1, False, "high_concurrency"), # Higher concurrency, single page + (3, 2, False, "multi_page"), # Multiple pages per worker + (2, 3, False, "more_pages"), # More pages than workers, within capacity + (TEST_POOL_CAPACITY, 1, False, "max_capacity"), # Max capacity single pages + (TEST_POOL_CAPACITY // 2, 2, False, "max_capacity_multi"), # Max capacity multiple pages + (4, 3, True , "exceeds_total"), # 12 pages needed, exceeds capacity + (TEST_POOL_CAPACITY + 1, 1, True , "exceeds_workers"), # More workers than capacity + (TEST_POOL_CAPACITY // 2, 3, True , "exceeds_with_multi"), # Exceeds capacity with multiple pages + ], +) +# fmt: on +def test_concurrent_page_allocation( + cache, + num_workers, + pages_per_worker, + expect_failure, + case_name, +): + allocated_pages = defaultdict(set) + errors = [] + allocations = [] + + def worker(worker_id: int): + try: + tokens = list(range(TEST_PAGE_SIZE * pages_per_worker)) + allocation = cache.acquire_pages_for_tokens(tokens) + allocations.append(allocation) + allocated_pages[worker_id] = {page.index for page in allocation.pages} + time.sleep(random.uniform(0.001, 0.01)) + except CacheAllocationFailure as e: + errors.append(e) + except Exception as e: + pytest.fail(f"Unexpected error: {e}") + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_workers)] + + for t in threads: + t.start() + for t in threads: + t.join() + + if expect_failure: + assert len(errors) > 0, "Expected at least one CacheAllocationFailure" + else: + assert not errors, f"Workers encountered errors: {errors}" + for worker_id, pages in allocated_pages.items(): + assert ( + len(pages) == pages_per_worker + ), f"Worker {worker_id} got {len(pages)} pages, expected {pages_per_worker}" + + all_pages = set() + for pages in allocated_pages.values(): + assert not ( + pages & all_pages + ), f"Found duplicate page allocation: {pages & all_pages}" + all_pages.update(pages) + + for allocation in allocations: + allocation.release_pages() + + +@pytest.mark.parametrize( + "total_pages_needed", + [ + TEST_POOL_CAPACITY + 1, # Just over capacity + TEST_POOL_CAPACITY * 2, # Double capacity + ], +) +def test_allocation_failure_when_exhausted(cache, total_pages_needed): + successful_allocations = [] + + try: + tokens = list(range(TEST_PAGE_SIZE * total_pages_needed)) + allocation = cache.acquire_pages_for_tokens(tokens) + successful_allocations.append(allocation) + except CacheAllocationFailure as e: + pass + else: + pytest.fail("Expected CacheAllocationFailure was not raised") + finally: + for alloc in successful_allocations: + alloc.release_pages()