BUG: request_limits does not work with streaming interfaces (#2571)
ChengjieLi28 authored Nov 25, 2024
1 parent f2b22bb commit ddbe211
Showing 2 changed files with 35 additions and 13 deletions.
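
Why the bug happens: the `request_limit` decorator in xinference/core/model.py decrements `self._serve_count` in a `finally` clause, but for streaming calls the wrapped method hands back a generator, and that `finally` runs as soon as the generator object is returned, not when the client finishes consuming the stream. The slot is therefore released while the stream is still being served, so the limit never bites for streaming interfaces. The fix keeps the slot when the result is a generator and lets the API layer call `decrease_serve_count()` once the stream ends. A minimal, self-contained sketch of that ordering (standard asyncio/inspect only, not xinference code):

    import asyncio
    import inspect

    async def fake_model_call():
        # Stands in for a model method that streams: it returns an async generator.
        async def gen():
            for i in range(3):
                yield i
        return gen()

    async def wrapped():
        ret = None
        try:
            ret = await fake_model_call()
            return ret
        finally:
            # The old code decremented the serve count here; note that this runs
            # when the generator object is returned, not when it is consumed.
            print("finally ran; isasyncgen(ret) =", inspect.isasyncgen(ret))

    async def main():
        stream = await wrapped()            # prints: finally ran; isasyncgen(ret) = True
        print([i async for i in stream])    # the stream is only consumed here -> [0, 1, 2]

    asyncio.run(main())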
xinference/api/restful_api.py: 14 changes (13 additions & 1 deletion)
@@ -1274,6 +1274,8 @@ async def stream_results():
                     # https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
                     yield dict(data=json.dumps({"error": str(ex)}))
                     return
+                finally:
+                    await model.decrease_serve_count()

             return EventSourceResponse(stream_results())
         else:
@@ -1540,8 +1542,16 @@ async def create_speech(
             **parsed_kwargs,
         )
         if body.stream:
+
+            async def stream_results():
+                try:
+                    async for item in out:
+                        yield item
+                finally:
+                    await model.decrease_serve_count()
+
             return EventSourceResponse(
-                media_type="application/octet-stream", content=out
+                media_type="application/octet-stream", content=stream_results()
             )
         else:
             return Response(media_type="application/octet-stream", content=out)
@@ -2072,6 +2082,8 @@ async def stream_results():
                     # https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
                     yield dict(data=json.dumps({"error": str(ex)}))
                     return
+                finally:
+                    await model.decrease_serve_count()

             return EventSourceResponse(stream_results())
         else:
xinference/core/model.py: 34 changes (22 additions & 12 deletions)
@@ -91,21 +91,26 @@ async def wrapped_func(self, *args, **kwargs):
         logger.debug(
             f"Request {fn.__name__}, current serve request count: {self._serve_count}, request limit: {self._request_limits} for the model {self.model_uid()}"
         )
-        if self._request_limits is not None:
-            if 1 + self._serve_count <= self._request_limits:
-                self._serve_count += 1
-            else:
-                raise RuntimeError(
-                    f"Rate limit reached for the model. Request limit {self._request_limits} for the model: {self.model_uid()}"
-                )
+        if 1 + self._serve_count <= self._request_limits:
+            self._serve_count += 1
+        else:
+            raise RuntimeError(
+                f"Rate limit reached for the model. Request limit {self._request_limits} for the model: {self.model_uid()}"
+            )
+        ret = None
         try:
             ret = await fn(self, *args, **kwargs)
         finally:
-            if self._request_limits is not None:
+            if ret is not None and (
+                inspect.isasyncgen(ret) or inspect.isgenerator(ret)
+            ):
+                # stream case, let client call model_ref to decrease self._serve_count
+                pass
+            else:
                 self._serve_count -= 1
-                logger.debug(
-                    f"After request {fn.__name__}, current serve request count: {self._serve_count} for the model {self.model_uid()}"
-                )
+            logger.debug(
+                f"After request {fn.__name__}, current serve request count: {self._serve_count} for the model {self.model_uid()}"
+            )
         return ret

     return wrapped_func
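
The stream detection above hinges on `inspect.isasyncgen` and `inspect.isgenerator`: calling a (possibly async) generator function returns a generator object that these predicates recognize, while ordinary non-streaming results fall through to the immediate decrement. A quick standalone illustration (example values only, not from the repo):

    import inspect

    def sync_stream():
        yield "chunk"

    async def async_stream():
        yield "chunk"

    print(inspect.isgenerator(sync_stream()))    # True: result of calling a generator function
    print(inspect.isasyncgen(async_stream()))    # True: result of calling an async generator function
    print(inspect.isasyncgen({"text": "done"}))  # False: a plain, non-streaming result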
@@ -215,7 +220,9 @@ def __init__(
         self._model_description = (
             model_description.to_dict() if model_description else {}
         )
-        self._request_limits = request_limits
+        self._request_limits = (
+            float("inf") if request_limits is None else request_limits
+        )
         self._pending_requests: asyncio.Queue = asyncio.Queue()
         self._handle_pending_requests_task = None
         self._lock = (
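
Swapping the `None` default for `float("inf")` is what allows the decorator above to drop its `is not None` guards: when no limit is configured, the admission check `1 + self._serve_count <= self._request_limits` is simply always true. A trivial check of that arithmetic:

    limit = float("inf")       # request_limits not configured
    print(1 + 12345 <= limit)  # True: never rejected
    print(1 + 3 <= 3)          # False: limit of 3 with 3 requests already in flight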
@@ -268,6 +275,9 @@ async def __post_create__(self):
     def __repr__(self) -> str:
         return f"ModelActor({self._replica_model_uid})"

+    def decrease_serve_count(self):
+        self._serve_count -= 1
+
     async def _record_completion_metrics(
         self, duration, completion_tokens, prompt_tokens
     ):