ENH: Safe iterate stream of ggml model (#449)
Co-authored-by: Uranus <[email protected]>
Co-authored-by: aresnow <[email protected]>
3 people authored Sep 25, 2023
1 parent 15cd024 commit 5f2078b
Showing 6 changed files with 91 additions and 87 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/python.yaml
@@ -105,5 +105,8 @@ jobs:
         run: |
           pytest --timeout=1500 \
             -W ignore::PendingDeprecationWarning \
-            --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference
+            --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/tests/test_client.py
+          pytest --timeout=1500 \
+            -W ignore::PendingDeprecationWarning \
+            --cov-config=setup.cfg --cov-report=xml --cov=xinference --ignore xinference/tests/test_client.py xinference
         working-directory: .
1 change: 0 additions & 1 deletion setup.cfg
@@ -34,7 +34,6 @@ install_requires =
     pydantic<2
     fastapi
     uvicorn
-    sse_starlette
     huggingface-hub>=0.14.1,<1.0
     typing_extensions
     fsspec
42 changes: 24 additions & 18 deletions xinference/client.py
Expand Up @@ -182,14 +182,14 @@ def chat(


def streaming_response_iterator(
response_lines: Iterator[bytes],
response_chunk: Iterator[bytes],
) -> Iterator["CompletionChunk"]:
"""
Create an Iterator to handle the streaming type of generation.
Parameters
----------
response_lines: Iterator[bytes]
response_chunk: Iterator[bytes]
Generated lines by the Model Generator.
Returns
@@ -199,24 +199,25 @@
     """
 
-    for line in response_lines:
-        line = line.strip()
-        if line.startswith(b"data:"):
-            data = json.loads(line.decode("utf-8").replace("data: ", "", 1))
-            yield data
+    for chunk in response_chunk:
+        content = json.loads(chunk.decode("utf-8"))
+        error = content.get("error", None)
+        if error is not None:
+            raise Exception(str(error))
+        yield content
 
 
 # Duplicate code due to type hint issues
 def chat_streaming_response_iterator(
-    response_lines: Iterator[bytes],
+    response_chunk: Iterator[bytes],
 ) -> Iterator["ChatCompletionChunk"]:
     """
     Create an Iterator to handle the streaming type of generation.
     Parameters
     ----------
-    response_lines: Iterator[bytes]
-        Generated lines by the Model Generator.
+    response_chunk: Iterator[bytes]
+        Generated chunk by the Model Generator.
     Returns
     -------
@@ -225,11 +226,12 @@ def chat_streaming_response_iterator(
     """
 
-    for line in response_lines:
-        line = line.strip()
-        if line.startswith(b"data:"):
-            data = json.loads(line.decode("utf-8").replace("data: ", "", 1))
-            yield data
+    for chunk in response_chunk:
+        content = json.loads(chunk.decode("utf-8"))
+        error = content.get("error", None)
+        if error is not None:
+            raise Exception(str(error))
+        yield content
 
 
 class RESTfulModelHandle:
@@ -327,7 +329,7 @@ def generate(
         )
 
         if stream:
-            return streaming_response_iterator(response.iter_lines())
+            return streaming_response_iterator(response.iter_content(chunk_size=None))
 
         response_data = response.json()
         return response_data
@@ -405,7 +407,9 @@ def chat(
         )
 
         if stream:
-            return chat_streaming_response_iterator(response.iter_lines())
+            return chat_streaming_response_iterator(
+                response.iter_content(chunk_size=None)
+            )
 
         response_data = response.json()
         return response_data
@@ -469,7 +473,9 @@ def chat(
         )
 
         if stream:
-            return chat_streaming_response_iterator(response.iter_lines())
+            return chat_streaming_response_iterator(
+                response.iter_content(chunk_size=None)
+            )
 
         response_data = response.json()
         return response_data
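For reference, a minimal consumer-side sketch of the reworked streaming path (not part of this diff): the endpoint URL and model uid below are placeholders, not values taken from this commit.

# Hedged usage sketch: consuming the chunk-based stream from the client handle.
# The endpoint URL and model uid are illustrative assumptions.
from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
model = client.get_model(model_uid="my-model")

# With stream=True the handle returns the iterator built by
# streaming_response_iterator(): each element is an already-parsed JSON chunk,
# and a chunk carrying an "error" key is surfaced as a raised exception.
try:
    for chunk in model.generate(
        "AI is going to", generate_config={"stream": True, "max_tokens": 16}
    ):
        print(chunk["choices"][0]["text"], end="", flush=True)
except Exception as ex:
    # Mid-stream failures (e.g. a second, parallel ggml generation) land here.
    print(f"\nstream failed: {ex}")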
2 changes: 2 additions & 0 deletions xinference/core/model.py
@@ -118,6 +118,8 @@ def load(self):
 
     async def _wrap_generator(self, ret: Any):
         if inspect.isgenerator(ret) or inspect.isasyncgen(ret):
+            if self._lock is not None and self._generators:
+                raise Exception("Parallel generation is not supported by ggml.")
             generator_uid = str(uuid.uuid1())
             self._generators[generator_uid] = ret
 
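The two added lines above are the core of the fix: a ggml model keeps at most one live generator, so a second streaming request fails fast instead of interleaving with the first. A standalone sketch of that guard pattern follows (illustrative names only; the real ModelActor tracks its generators alongside an xoscar lock).

# Illustrative sketch of the single-generator guard; not the actual ModelActor code.
import uuid
from typing import Any, Dict


class GeneratorRegistry:
    def __init__(self, single_generator: bool = True):
        # single_generator=True mirrors the ggml case, where self._lock is not None.
        self._single = single_generator
        self._generators: Dict[str, Any] = {}

    def register(self, gen: Any) -> str:
        # Refuse a second concurrent stream while one is still being consumed.
        if self._single and self._generators:
            raise Exception("Parallel generation is not supported by ggml.")
        uid = str(uuid.uuid1())
        self._generators[uid] = gen
        return uid

    def finish(self, uid: str) -> None:
        # Once iteration completes, the slot frees up and streaming can start again.
        self._generators.pop(uid, None)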
86 changes: 26 additions & 60 deletions xinference/core/restful_api.py
@@ -19,19 +19,15 @@
 import sys
 import threading
 import warnings
-from functools import partial
 from typing import Any, Dict, List, Literal, Optional, Union
 
-import anyio
 import gradio as gr
 import xoscar as xo
-from anyio.streams.memory import MemoryObjectSendStream
 from fastapi import APIRouter, FastAPI, HTTPException, Request
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel, Field
-from sse_starlette.sse import EventSourceResponse
 from starlette.responses import RedirectResponse
 from typing_extensions import NotRequired, TypedDict
 from uvicorn import Config, Server
@@ -517,32 +513,17 @@ async def create_completion(self, request: Request, body: CreateCompletionReques
             raise HTTPException(status_code=500, detail=str(e))
 
         if body.stream:
-            # create a pair of memory object streams
-            send_chan, recv_chan = anyio.create_memory_object_stream(10)
-
-            async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-                async with inner_send_chan:
-                    try:
-                        iterator = await model.generate(body.prompt, kwargs)
-                        async for chunk in iterator:
-                            await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                            if await request.is_disconnected():
-                                raise anyio.get_cancelled_exc_class()()
-                    except anyio.get_cancelled_exc_class() as e:
-                        logger.warning("disconnected")
-                        with anyio.move_on_after(1, shield=True):
-                            logger.warning(
-                                f"Disconnected from client (via refresh/close) {request.client}"
-                            )
-                            await inner_send_chan.send(dict(closing=True))
-                        raise e
-                    except Exception as e:
-                        raise HTTPException(status_code=500, detail=str(e))
-
-            return EventSourceResponse(
-                recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-            )
 
+            async def stream_results():
+                try:
+                    iterator = await model.generate(body.prompt, kwargs)
+                    async for item in iterator:
+                        yield json.dumps(item)
+                except Exception as ex:
+                    logger.exception("Completion stream got an error: %s", ex)
+                    yield json.dumps({"error": str(ex)})
+
+            return StreamingResponse(stream_results())
         else:
             try:
                 return await model.generate(body.prompt, kwargs)
@@ -640,37 +621,22 @@ async def create_chat_completion(
         )
 
         if body.stream:
-            # create a pair of memory object streams
-            send_chan, recv_chan = anyio.create_memory_object_stream(10)
-
-            async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-                async with inner_send_chan:
-                    try:
-                        if is_chatglm_ggml:
-                            iterator = await model.chat(prompt, chat_history, kwargs)
-                        else:
-                            iterator = await model.chat(
-                                prompt, system_prompt, chat_history, kwargs
-                            )
-                        async for chunk in iterator:
-                            await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                            if await request.is_disconnected():
-                                raise anyio.get_cancelled_exc_class()()
-                    except anyio.get_cancelled_exc_class() as e:
-                        logger.warning("disconnected")
-                        with anyio.move_on_after(1, shield=True):
-                            logger.warning(
-                                f"Disconnected from client (via refresh/close) {request.client}"
-                            )
-                            await inner_send_chan.send(dict(closing=True))
-                        raise e
-                    except Exception as e:
-                        raise HTTPException(status_code=500, detail=str(e))
-
-            return EventSourceResponse(
-                recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-            )
 
+            async def stream_results():
+                try:
+                    if is_chatglm_ggml:
+                        iterator = await model.chat(prompt, chat_history, kwargs)
+                    else:
+                        iterator = await model.chat(
+                            prompt, system_prompt, chat_history, kwargs
+                        )
+                    async for item in iterator:
+                        yield json.dumps(item)
+                except Exception as ex:
+                    logger.exception("Chat completion stream got an error: %s", ex)
+                    yield json.dumps({"error": str(ex)})
+
+            return StreamingResponse(stream_results())
         else:
             try:
                 if is_chatglm_ggml:
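With EventSourceResponse gone, these endpoints no longer emit SSE "data:" frames; they stream one plain JSON object per chunk, and an in-stream failure arrives as an {"error": ...} object instead of a dropped connection. A hedged sketch of consuming that wire format directly over HTTP follows; the URL, route, and payload fields are assumptions rather than values from this diff, and the client handle above does the same parsing internally.

# Hedged sketch of the raw wire format after this change; endpoint URL, route,
# and payload fields are placeholders, not taken from this commit.
import json

import requests

resp = requests.post(
    "http://127.0.0.1:9997/v1/completions",
    json={"model": "my-model", "prompt": "AI is going to", "stream": True},
    stream=True,
)
for raw in resp.iter_content(chunk_size=None):
    chunk = json.loads(raw.decode("utf-8"))
    if chunk.get("error") is not None:
        # stream_results() serializes exceptions as {"error": "<message>"}.
        raise RuntimeError(chunk["error"])
    print(chunk["choices"][0]["text"], end="", flush=True)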
42 changes: 35 additions & 7 deletions xinference/tests/test_client.py
@@ -232,12 +232,33 @@ def test_RESTful_client(setup):
     completion = model.chat("What is the capital of France?")
     assert "content" in completion["choices"][0]["message"]
 
-    streaming_response = model.chat(
-        prompt="What is the capital of France?", generate_config={"stream": True}
-    )
+    def _check_stream():
+        streaming_response = model.chat(
+            prompt="What is the capital of France?",
+            generate_config={"stream": True, "max_tokens": 5},
+        )
+        for chunk in streaming_response:
+            assert "content" or "role" in chunk["choices"][0]["delta"]
+
+    _check_stream()
+
+    results = []
+    with ThreadPoolExecutor() as executor:
+        for _ in range(2):
+            r = executor.submit(_check_stream)
+            results.append(r)
+    # Parallel generation is not supported by ggml.
+    error_count = 0
+    for r in results:
+        try:
+            r.result()
+        except Exception as ex:
+            assert "Parallel generation" in str(ex)
+            error_count += 1
+    assert error_count == 1
 
-    for chunk in streaming_response:
-        assert "content" or "role" in chunk["choices"][0]["delta"]
+    # After iteration finish, we can iterate again.
+    _check_stream()
 
     client.terminate_model(model_uid=model_uid)
     assert len(client.list_models()) == 0
@@ -251,8 +272,9 @@ def test_RESTful_client(setup):
     assert len(client.list_models()) == 1
 
     # Test concurrent chat is OK.
+    model = client.get_model(model_uid=model_uid)
+
     def _check(stream=False):
-        model = client.get_model(model_uid=model_uid)
         completion = model.generate(
             "AI is going to", generate_config={"stream": stream, "max_tokens": 5}
         )
@@ -266,12 +288,18 @@ def _check(stream=False):
 
     for stream in [True, False]:
         results = []
+        error_count = 0
         with ThreadPoolExecutor() as executor:
             for _ in range(3):
                 r = executor.submit(_check, stream=stream)
                 results.append(r)
         for r in results:
-            r.result()
+            try:
+                r.result()
+            except Exception as ex:
+                assert "Parallel generation" in str(ex)
+                error_count += 1
+        assert error_count == (2 if stream else 0)
 
     client.terminate_model(model_uid=model_uid)
     assert len(client.list_models()) == 0
