Skip to content

Commit

Permalink
initial PageAllocation handle implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
renxida committed Nov 26, 2024
1 parent ddc3091 commit 00e608d
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,91 @@
Base class for kv caches.
"""

from typing import List
from typing import List, Iterable, Protocol
from .page_pool import PageInfo
import math
from abc import ABC, abstractmethod
from .page_pool import PagePool

# logging
import logging

logger = logging.getLogger(__name__)

# Exception raised when a cache allocation request cannot be fulfilled.
class CacheAllocationFailure(Exception):
    """Raised when the cache cannot supply enough free pages for a request."""

    pass


class PageAllocation(ABC):
    """
    Abstract base class for page allocations in the cache.

    Represents a set of KV-cache pages held on behalf of one request.
    Subclasses only need to implement the core allocation methods.
    """

    @abstractmethod
    def get_page_list(self) -> List[PageInfo]:
        """Returns the list of pages that were allocated."""
        pass

    @abstractmethod
    def publish_pages(self, up_to_page_index: int) -> None:
        """
        Makes self.get_page_list()[0:up_to_page_index] available to other requests after writing is complete.
        Associates tokens with pages and marks them as ready for reading.
        """
        pass

    @abstractmethod
    def release_pages(self) -> None:
        """
        Releases the allocation's reference to pages.
        Pages become eligible for eviction when their reference count reaches zero.
        """
        pass


class BasePageAttentionCacheAllocation(PageAllocation):
    """
    Concrete PageAllocation backed by a BasePagedAttentionCache.

    Holds a fixed set of pages acquired from the cache's page pool and
    returns them to the pool exactly once via release_pages().
    """

    def __init__(self, pages: Iterable[PageInfo], cache: "BasePagedAttentionCache"):
        # Intended to be constructed only by the owning attention cache.
        self._pages = tuple(pages)
        self._cache = cache
        self._is_released = False

    def get_page_list(self) -> List[PageInfo]:
        # service.py expects a list, so convert from the internal tuple.
        return [*self._pages]

    def publish_pages(self, up_to_page_index) -> None:
        """
        Mark pages [0:up_to_page_index] as fully written and readable by
        other requests. The base implementation does not share pages
        between requests, so there is nothing to record here.
        """
        pass

    def release_pages(self) -> None:
        """
        Drop this allocation's reference to its pages so the pool may reuse
        them. Extra calls after the first only log a warning.
        """
        # Base implementation: pages are owned by at most one request, so
        # they can be returned to the pool immediately.
        if not self._is_released:
            self._cache.page_pool.release_pages(self._pages)
            self._is_released = True
        else:
            logger.warning("Releasing already-released allocation")

    def __repr__(self):
        return f"BasePageAttentionCacheAllocation(pages={self._pages}, cache={self._cache})"


class BasePagedAttentionCache:
Expand All @@ -33,13 +115,13 @@ class BasePagedAttentionCache:
- Reference counting prevents eviction of in-use pages
"""

def __init__(self, page_pool: PagePool, tokens_per_page: int):
    """
    Args:
        page_pool: pool from which KV-cache pages are acquired and released.
        tokens_per_page: number of token positions stored per cache page.
    """
    self.page_pool = page_pool
    self.tokens_per_page = tokens_per_page

def acquire_pages_for_tokens(
self, tokens: List[int], extra_token_slots: int = 1
) -> tuple[list[PageInfo], int]:
) -> PageAllocation:
"""
Given a list of tokens, return a list of pages and a start position to continue generation from.
Expand All @@ -57,24 +139,9 @@ def acquire_pages_for_tokens(
pages_needed = math.ceil(token_count / self.tokens_per_page)
pages = self.page_pool.acquire_free_pages(pages_needed)

n_cached_tokens = 0

return pages, n_cached_tokens

def publish_pages(self, tokens, pages) -> None:
"""
Given a list of tokens and pages containing KV corresponding to these tokens, make these pages available to other requests.
Associates the tokens with the pages, and mark them as done writing.
It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)].
"""
if pages is None:
raise CacheAllocationFailure()

pass # the base implementation doesn't cache unfinished requests.
n_cached_tokens = 0

def release_pages(self, tokens, pages):
"""
Decrement reference count for these pages. When reference count is zero, they will be elegible for eviction.
"""
# in the base implementation, the pages can be owned by 1 request max, so they can be instantly release
self.page_pool.release_pages(pages)
return BasePageAttentionCacheAllocation(pages, cache=self)
39 changes: 13 additions & 26 deletions shortfin/python/shortfin_apps/llm/components/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import shortfin as sf
import shortfin.array as sfnp

from .kvcache.base_attention_cache import BasePagedAttentionCache
from .kvcache.base_attention_cache import BasePagedAttentionCache, PageAllocation
from .kvcache.page_pool import PageInfo


Expand Down Expand Up @@ -43,7 +43,7 @@ def __init__(self, phase: InferencePhase, input_token_ids: list[int]):

# Cache pages that have been locked for this request.
self._cache: BasePagedAttentionCache | None = None
self.locked_pages: list[PageInfo] | None = None
self.allocation: PageAllocation | None = None

def reset(self, phase: InferencePhase):
"""Resets all per request state in preparation for an subsequent execution."""
Expand All @@ -52,35 +52,22 @@ def reset(self, phase: InferencePhase):
self.return_all_logits = False
self.return_host_array = True
self.result_logits = None
self.allocation.release_pages()
self.allocation = None

def cache_page_indices(self, max_len: int) -> list[int]:
    """
    Return up to max_len page-pool indices for this request's allocation.

    Returns an empty list when no allocation is currently held.
    """
    if not self.allocation:
        return []
    # Truncate via slicing; slicing past the end is a no-op.
    indices = [p.index for p in self.allocation.get_page_list()]
    return indices[:max_len]

def publish_allocated_pages(self, up_to_page_index: int):
    """Publish the first `up_to_page_index` pages of this request's allocation."""
    allocation = self.allocation
    # Caller must have acquired an allocation before publishing.
    assert allocation
    allocation.publish_pages(up_to_page_index)

def free_cache_pages(self):
    """
    Release any KV-cache pages held by this request's allocation.

    Safe to call when no allocation is held; does nothing in that case.
    NOTE(review): the allocation reference is kept after release — the
    allocation itself guards against double release.
    """
    if self.allocation:
        self.allocation.release_pages()


class StrobeMessage(sf.Message):
Expand Down
59 changes: 32 additions & 27 deletions shortfin/python/shortfin_apps/llm/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
import shortfin as sf
import shortfin.array as sfnp

from .kvcache.base_attention_cache import BasePagedAttentionCache
from .kvcache.page_pool import PagePoolConfig, PagePool
from .kvcache.base_attention_cache import (
BasePagedAttentionCache,
CacheAllocationFailure,
PageAllocation,
)
from .kvcache.page_pool import PagePoolConfig, PagePool, PageInfo
from .config_struct import ModelParams
from .manager import SystemManager
from .messages import InferenceExecRequest, InferencePhase, StrobeMessage
Expand Down Expand Up @@ -229,16 +233,17 @@ def board_prefills(self, cache: BasePagedAttentionCache):
len(prefill_request.input_token_ids) / self.page_seq_stride
)
# allocate kv cache pages
pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens(
prefill_request.input_token_ids,
extra_token_slots=0, # prefill needs no extra kvcache slots to write to
)
if pages is None:
try:
allocation = cache.acquire_pages_for_tokens(
prefill_request.input_token_ids,
extra_token_slots=0, # prefill needs no extra kvcache slots to write to
)
except CacheAllocationFailure:
logger.debug("Cannot fulfill request for %d pages", needed_pages)
continue
else:
logger.debug("Allocated %d cache pages to request", len(pages))
prefill_request.lock_initial_cache_pages(cache, pages)
logger.debug(f"Successfully acquired allocation: {allocation}")
prefill_request.free_cache_pages()
prefill_request.allocation = allocation

# Can flight this request.
exec_process.exec_requests.append(prefill_request)
Expand Down Expand Up @@ -266,26 +271,20 @@ def board_decodes(self, cache: BasePagedAttentionCache):
if len(exec_process.exec_requests) >= self.ideal_batch_size:
break
incoming_token_count = len(decode_request.input_token_ids)
needed_pages = math.ceil(
(decode_request.start_position + incoming_token_count)
/ self.page_seq_stride
)
if needed_pages > len(decode_request.locked_pages):
# allocate kv cache pages
pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens(

try:
allocation = cache.acquire_pages_for_tokens(
decode_request.input_token_ids,
extra_token_slots=1, # need 1 extra slot to write result.
)
if pages is None:
logger.debug(
"Cannot fulfill decode request for %d pages", needed_pages
)
continue
else:
logger.debug(
"Allocated %d cache pages to decode request", len(pages)
)
decode_request.lock_new_cache_pages(cache, pages)
except CacheAllocationFailure:
logger.debug(
"Cannot fulfill request for %d tokens",
len(decode_request.input_token_ids),
)

decode_request.free_cache_pages()
decode_request.allocation = allocation

# Can flight this request.
exec_process.exec_requests.append(decode_request)
Expand Down Expand Up @@ -438,6 +437,12 @@ async def run(self):
# Invoke. Logits are of shape [bs, bsl, d].
(logits,) = await fn(*args, fiber=self.fiber)

# publish cache pages
for r in self.exec_requests:
total_tokens = r.start_position + len(r.input_token_ids)
number_of_complete_pages = total_tokens // seq_stride
r.publish_allocated_pages(number_of_complete_pages)

# Return results.
for i in range(req_count):
req = self.exec_requests[i]
Expand Down

0 comments on commit 00e608d

Please sign in to comment.