From 181ad2290531f80f2d8b489ed3bf56383c14fa2b Mon Sep 17 00:00:00 2001 From: Dawnfz-Lenfeng <2912706234@qq.com> Date: Wed, 7 Aug 2024 22:41:54 +0800 Subject: [PATCH] update parser in benchmark --- benchmark/benchmark_runner.py | 9 +++++++-- benchmark/benchmark_serving.py | 22 ++++++++++++++++++---- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/benchmark/benchmark_runner.py b/benchmark/benchmark_runner.py index 4b1ad039f4..b4d5620abf 100644 --- a/benchmark/benchmark_runner.py +++ b/benchmark/benchmark_runner.py @@ -20,7 +20,7 @@ import logging from dataclasses import dataclass, field import time -from typing import List, Tuple +from typing import List, Optional, Tuple import numpy as np @@ -54,6 +54,7 @@ def __init__( model_uid: str, input_requests: List[Tuple[str, int, int]], stream: bool, + api_key: Optional[str]=None, ): self.api_url = api_url self.model_uid = model_uid @@ -61,6 +62,7 @@ def __init__( self.outputs: List[RequestOutput] = [] self.benchmark_time = None self.stream = stream + self.api_key = api_key async def run(self): await self.warm_up() @@ -105,6 +107,8 @@ async def send_request(self, request: tuple, warming_up: bool = False): } headers = {"User-Agent": "Benchmark Client"} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: output = RequestOutput(prompt_len=prompt_len) @@ -300,8 +304,9 @@ def __init__( input_requests: List[Tuple[str, int, int]], stream: bool, concurrency: int, + api_key: Optional[str]=None, ): - super().__init__(api_url, model_uid, input_requests, stream) + super().__init__(api_url, model_uid, input_requests, stream, api_key) self.concurrency = concurrency self.left = len(input_requests) diff --git a/benchmark/benchmark_serving.py b/benchmark/benchmark_serving.py index 11ee393b30..5db4d066f6 100644 --- a/benchmark/benchmark_serving.py +++ b/benchmark/benchmark_serving.py @@ -39,21 +39,32 @@ def __init__( request_rate: float, api_key: Optional[str] = None, ): - super().__init__(api_url, model_uid, input_requests, stream, concurrency, api_key) + super().__init__( + api_url, + model_uid, + input_requests, + stream, + concurrency, + api_key, + ) self.request_rate = request_rate self.queue = asyncio.Queue(concurrency or 100) self.left = len(input_requests) async def _run(self): tasks = [] - for req in iter(self.input_requests): - await self.queue.put(req) for _ in range(self.concurrency): tasks.append(asyncio.create_task(self.worker())) await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + async def warm_up(self, num_requests: int = 5): + logger.info(f"Enqueuing {len(self.input_requests)} requests.") + for req in iter(self.input_requests): + await self.queue.put(req) + await super().warm_up(num_requests) + async def worker(self): """ wait request dispatch by run(), and then send_request. @@ -130,7 +141,10 @@ def main(args: argparse.Namespace): "--prompt-len-limit", type=int, default=1024, help="Prompt length limitation." ) parser.add_argument( - "--api-key", type=str, default=None, help="Authorization api key", + "--api-key", + type=str, + default=None, + help="Authorization api key", ) parser.add_argument( "--concurrency",