update parser in benchmark
Dawnfz-Lenfeng committed Aug 7, 2024
1 parent 77f31a6 commit 181ad22
Showing 2 changed files with 25 additions and 6 deletions.
9 changes: 7 additions & 2 deletions benchmark/benchmark_runner.py
@@ -20,7 +20,7 @@
import logging
from dataclasses import dataclass, field
import time
from typing import List, Tuple
from typing import List, Optional, Tuple

import numpy as np

@@ -54,13 +54,15 @@ def __init__(
        model_uid: str,
        input_requests: List[Tuple[str, int, int]],
        stream: bool,
        api_key: Optional[str] = None,
    ):
        self.api_url = api_url
        self.model_uid = model_uid
        self.input_requests = input_requests
        self.outputs: List[RequestOutput] = []
        self.benchmark_time = None
        self.stream = stream
        self.api_key = api_key

    async def run(self):
        await self.warm_up()
@@ -105,6 +107,8 @@ async def send_request(self, request: tuple, warming_up: bool = False):
        }

        headers = {"User-Agent": "Benchmark Client"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
            output = RequestOutput(prompt_len=prompt_len)
@@ -300,8 +304,9 @@ def __init__(
        input_requests: List[Tuple[str, int, int]],
        stream: bool,
        concurrency: int,
        api_key: Optional[str] = None,
    ):
        super().__init__(api_url, model_uid, input_requests, stream)
        super().__init__(api_url, model_uid, input_requests, stream, api_key)
        self.concurrency = concurrency
        self.left = len(input_requests)

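The benchmark_runner.py changes above thread an optional api_key from the runner constructor into send_request, where it is attached as a plain Bearer token. Below is a minimal standalone sketch of the same pattern with aiohttp; the endpoint URL and payload are placeholders for illustration, not values taken from this commit.

import asyncio
from typing import Optional

import aiohttp


async def post_with_optional_auth(
    api_url: str, payload: dict, api_key: Optional[str] = None
) -> dict:
    """POST one request, adding an Authorization header only when a key is given."""
    headers = {"User-Agent": "Benchmark Client"}
    if api_key:
        # Same header shape as the benchmark: plain Bearer auth.
        headers["Authorization"] = f"Bearer {api_key}"
    async with aiohttp.ClientSession() as session:
        async with session.post(api_url, json=payload, headers=headers) as resp:
            return await resp.json()


# Hypothetical usage against a local OpenAI-compatible endpoint:
# asyncio.run(post_with_optional_auth(
#     "http://127.0.0.1:9997/v1/chat/completions",
#     {"model": "my-model", "messages": [{"role": "user", "content": "hi"}]},
#     api_key="sk-xxxxxxxx",
# ))
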
22 changes: 18 additions & 4 deletions benchmark/benchmark_serving.py
@@ -39,21 +39,32 @@ def __init__(
        request_rate: float,
        api_key: Optional[str] = None,
    ):
        super().__init__(api_url, model_uid, input_requests, stream, concurrency, api_key)
        super().__init__(
            api_url,
            model_uid,
            input_requests,
            stream,
            concurrency,
            api_key,
        )
        self.request_rate = request_rate
        self.queue = asyncio.Queue(concurrency or 100)
        self.left = len(input_requests)

    async def _run(self):
        tasks = []
        for req in iter(self.input_requests):
            await self.queue.put(req)

        for _ in range(self.concurrency):
            tasks.append(asyncio.create_task(self.worker()))

        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

    async def warm_up(self, num_requests: int = 5):
        logger.info(f"Enqueuing {len(self.input_requests)} requests.")
        for req in iter(self.input_requests):
            await self.queue.put(req)
        await super().warm_up(num_requests)

    async def worker(self):
        """
        Wait for requests dispatched by run(), then call send_request().
@@ -130,7 +141,10 @@ def main(args: argparse.Namespace):
        "--prompt-len-limit", type=int, default=1024, help="Prompt length limitation."
    )
    parser.add_argument(
        "--api-key", type=str, default=None, help="Authorization api key",
        "--api-key",
        type=str,
        default=None,
        help="Authorization api key",
    )
    parser.add_argument(
        "--concurrency",
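
In benchmark_serving.py the runner pre-fills an asyncio.Queue with every input request and then lets `concurrency` workers drain it. A self-contained sketch of that producer/worker pattern follows; the handler is a stand-in for the benchmark's send_request, and the unbounded queue is an assumption of the sketch, not lifted from the commit.

import asyncio
from typing import Awaitable, Callable, List, Tuple

Request = Tuple[str, int, int]  # matches the input_requests element type in the diff


async def run_worker_pool(
    requests: List[Request],
    concurrency: int,
    handle: Callable[[Request], Awaitable[None]],
) -> None:
    """Pre-fill a queue with every request, then let `concurrency` workers drain it."""
    queue: asyncio.Queue = asyncio.Queue()  # unbounded, so pre-filling never blocks
    for req in requests:
        queue.put_nowait(req)

    async def worker() -> None:
        # No await between empty() and get_nowait(), so this is safe on one event loop.
        while not queue.empty():
            req = queue.get_nowait()
            await handle(req)

    await asyncio.gather(*(worker() for _ in range(concurrency)))


# Hypothetical smoke test with a no-op handler:
# async def _noop(req: Request) -> None:
#     await asyncio.sleep(0)
#
# asyncio.run(run_worker_pool([("hi", 2, 16)] * 10, concurrency=3, handle=_noop))

With the new --api-key flag, the key flows from argparse through the runner constructors down to the Authorization header, so an authenticated run differs from an unauthenticated one only by the extra command-line option.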
