Skip to content

Commit

Permalink
avoid magic numbers
Browse files (browse the repository at this point in the history)
  • Loading branch information
renxida committed Dec 2, 2024
1 parent 8f3f049 commit abf2cb2
Showing 1 changed file with 15 additions and 12 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,9 +14,12 @@
)
from shortfin_apps.llm.components.kvcache.page_pool import PagePool, PageInfo

TEST_PAGE_SIZE = 16
TEST_POOL_CAPACITY = 10


class MockPagePool(PagePool):
def __init__(self, total_pages: int = 100):
def __init__(self, total_pages: int):
self._queue = queue.Queue()

for i in range(total_pages):
Expand All @@ -36,33 +39,33 @@ def release_pages(self, pages):

@pytest.fixture
def page_pool():
return MockPagePool(total_pages=10)
return MockPagePool(total_pages=TEST_POOL_CAPACITY)


@pytest.fixture
def cache(page_pool):
return BasePagedAttentionCache(page_pool=page_pool, tokens_per_page=16)
return BasePagedAttentionCache(page_pool=page_pool, tokens_per_page=TEST_PAGE_SIZE)


@pytest.fixture
def page_pool():
return MockPagePool(total_pages=10)
return MockPagePool(total_pages=TEST_POOL_CAPACITY)


@pytest.fixture
def cache(page_pool):
"""Create cache with 16 tokens per page"""
return BasePagedAttentionCache(page_pool=page_pool, tokens_per_page=16)
"""Create cache with TEST_PAGE_SIZE tokens per page"""
return BasePagedAttentionCache(page_pool=page_pool, tokens_per_page=TEST_PAGE_SIZE)


def test_allocation_sizes(cache):
test_cases = [
([], 0), # Empty token list
(list(range(8)), 1), # Partial page
(list(range(16)), 1), # Exact page
(list(range(17)), 2), # Just over one page
(list(range(32)), 2), # Multiple exact pages
(list(range(33)), 3), # Multiple pages with remainder
(list(range(TEST_PAGE_SIZE // 2)), 1), # Partial page
(list(range(TEST_PAGE_SIZE)), 1), # Exact page
(list(range(TEST_PAGE_SIZE + 1)), 2), # Just over one page
(list(range(TEST_PAGE_SIZE * 2)), 2), # Multiple exact pages
(list(range(TEST_PAGE_SIZE * 2 + 1)), 3), # Multiple pages with remainder
]

for tokens, expected_pages in test_cases:
Expand All @@ -74,7 +77,7 @@ def test_allocation_sizes(cache):

def test_concurrent_access(cache):
def worker(results: List):
allocation = cache.acquire_pages_for_tokens(list(range(16)))
allocation = cache.acquire_pages_for_tokens(list(range(TEST_PAGE_SIZE)))
results.append(len(allocation.pages))
allocation.release_pages()

Expand Down

0 comments on commit abf2cb2

Please sign in to comment.