prettify the other part too
renxida committed Dec 2, 2024
1 parent 0c2d58d commit 8126e38
Showing 1 changed file with 33 additions and 25 deletions.
@@ -47,43 +47,51 @@ def cache(page_pool):
     return BasePagedAttentionCache(page_pool=page_pool, tokens_per_page=TEST_PAGE_SIZE)
 
 
+# fmt: off
 @pytest.mark.parametrize(
-    "tokens,expected_pages,test_name",
-    [
-        ([], 0, "empty_token_list"),
-        (list(range(TEST_PAGE_SIZE // 2)), 1, "partial_page"),
-        (list(range(TEST_PAGE_SIZE)), 1, "exact_page"),
-        (list(range(TEST_PAGE_SIZE + 1)), 2, "just_over_one_page"),
-        (list(range(TEST_PAGE_SIZE * 2)), 2, "multiple_exact_pages"),
-        (list(range(TEST_PAGE_SIZE * 2 + 1)), 3, "multiple_pages_with_remainder"),
-        (list(range(TEST_PAGE_SIZE * 3)), 3, "three_exact_pages"),
-        (list(range(1)), 1, "single_token"),
-        (list(range(TEST_PAGE_SIZE - 1)), 1, "almost_full_page"),
-    ],
+    "tokens,expected_pages,test_name",
+    [   # Tokens                                 Pages  Name
+        ([],                                     0,     "empty_token_list"),
+        (list(range(TEST_PAGE_SIZE // 2)),       1,     "partial_page"),
+        (list(range(TEST_PAGE_SIZE)),            1,     "exact_page"),
+        (list(range(TEST_PAGE_SIZE + 1)),        2,     "just_over_one_page"),
+        (list(range(TEST_PAGE_SIZE * 2)),        2,     "multiple_exact_pages"),
+        (list(range(TEST_PAGE_SIZE * 2 + 1)),    3,     "multiple_pages_with_remainder"),
+        (list(range(TEST_PAGE_SIZE * 3)),        3,     "three_exact_pages"),
+        (list(range(1)),                         1,     "single_token"),
+        (list(range(TEST_PAGE_SIZE - 1)),        1,     "almost_full_page"),
+    ],
 )
+# fmt: on
 def test_allocation_sizes(cache, tokens, expected_pages, test_name):
     allocation = cache.acquire_pages_for_tokens(tokens)
     pages = allocation.pages
     assert len(pages) == expected_pages, f"Failed for case: {test_name}"
     allocation.release_pages()
 
 
+# fmt: off
 @pytest.mark.parametrize(
-    "num_workers,pages_per_worker,expect_failure",
-    [
-        (2, 1, False),  # Basic concurrent access
-        (5, 1, False),  # Higher concurrency, single page
-        (3, 2, False),  # Multiple pages per worker
-        (2, 3, False),  # More pages than workers, but within capacity
-        (TEST_POOL_CAPACITY, 1, False),  # Max capacity single pages
-        (TEST_POOL_CAPACITY // 2, 2, False),  # Max capacity multiple pages
-        (4, 3, True),  # 12 pages needed, exceeds capacity
-        (TEST_POOL_CAPACITY + 1, 1, True),  # More workers than capacity
-        (TEST_POOL_CAPACITY // 2, 3, True),  # Exceeds capacity with multiple pages
-    ],
+    "num_workers,pages_per_worker,expect_failure,case_name",
+    [   # Workers                  Pages  Failure  Case name
+        (2,                        1,     False,   "basic_concurrent"),     # Basic concurrent access
+        (5,                        1,     False,   "high_concurrency"),     # Higher concurrency, single page
+        (3,                        2,     False,   "multi_page"),           # Multiple pages per worker
+        (2,                        3,     False,   "more_pages"),           # More pages than workers, within capacity
+        (TEST_POOL_CAPACITY,       1,     False,   "max_capacity"),         # Max capacity single pages
+        (TEST_POOL_CAPACITY // 2,  2,     False,   "max_capacity_multi"),   # Max capacity multiple pages
+        (4,                        3,     True,    "exceeds_total"),        # 12 pages needed, exceeds capacity
+        (TEST_POOL_CAPACITY + 1,   1,     True,    "exceeds_workers"),      # More workers than capacity
+        (TEST_POOL_CAPACITY // 2,  3,     True,    "exceeds_with_multi"),   # Exceeds capacity with multiple pages
+    ],
 )
+# fmt: on
 def test_concurrent_page_allocation(
-    cache, num_workers, pages_per_worker, expect_failure
+    cache,
+    num_workers,
+    pages_per_worker,
+    expect_failure,
+    case_name,
 ):
     allocated_pages = defaultdict(set)
     errors = []
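Note (not part of the commit): the expected_pages column in the first table is consistent with rounding the token count up to whole pages, i.e. ceil(len(tokens) / TEST_PAGE_SIZE). A quick way to recompute the expected values under that assumption; the TEST_PAGE_SIZE value below is a placeholder, the real constant is defined earlier in the test file:

import math

TEST_PAGE_SIZE = 16  # placeholder value; the test file defines the real constant

def expected_page_count(tokens):
    # any remainder still occupies one more whole page
    return math.ceil(len(tokens) / TEST_PAGE_SIZE)

assert expected_page_count([]) == 0                        # "empty_token_list"
assert expected_page_count(range(TEST_PAGE_SIZE + 1)) == 2  # "just_over_one_page"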

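The remainder of test_concurrent_page_allocation is not shown in this diff; only its first two lines (allocated_pages and errors) are visible. For orientation only, a concurrent allocation test of this shape could be structured roughly as below. This is a sketch under assumptions, not the repository's code: the per-worker token count, the use of a thread pool, identifying pages by object identity, and catching a generic Exception on allocation failure are all guesses.

from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor

def run_concurrent_allocation(cache, num_workers, pages_per_worker, expect_failure):
    # cache and TEST_PAGE_SIZE come from the test module's fixtures/constants.
    allocated_pages = defaultdict(set)  # worker id -> identities of pages it received
    allocations = []                    # hold allocations so pages stay claimed while we check overlap
    errors = []

    def worker(worker_id):
        try:
            tokens = list(range(pages_per_worker * TEST_PAGE_SIZE))
            allocation = cache.acquire_pages_for_tokens(tokens)
            allocations.append(allocation)
            allocated_pages[worker_id] = {id(page) for page in allocation.pages}
        except Exception as exc:        # assumption: a failed acquisition raises
            errors.append(exc)

    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        for _ in pool.map(worker, range(num_workers)):
            pass                        # drain the iterator so every worker runs

    if expect_failure:
        assert errors, "expected at least one worker's allocation to fail"
    else:
        assert not errors, f"unexpected allocation errors: {errors}"
        page_ids = [p for pages in allocated_pages.values() for p in pages]
        assert len(page_ids) == len(set(page_ids)), "a page was handed to two workers"

    for allocation in allocations:
        allocation.release_pages()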