diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml
index 52e9b5dc53..1201ca58bf 100644
--- a/.github/workflows/python.yaml
+++ b/.github/workflows/python.yaml
@@ -105,5 +105,8 @@ jobs:
         run: |
           pytest --timeout=1500 \
             -W ignore::PendingDeprecationWarning \
-            --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference
+            --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/tests/test_client.py
+          pytest --timeout=1500 \
+            -W ignore::PendingDeprecationWarning \
+            --cov-config=setup.cfg --cov-report=xml --cov=xinference --ignore xinference/tests/test_client.py xinference
         working-directory: .
diff --git a/setup.cfg b/setup.cfg
index c3ed24b27e..74e2148c70 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -34,7 +34,6 @@ install_requires =
     pydantic<2
     fastapi
     uvicorn
-    sse_starlette
     huggingface-hub>=0.14.1,<1.0
     typing_extensions
     fsspec
diff --git a/xinference/client.py b/xinference/client.py
index a69579f52d..04af40dcd8 100644
--- a/xinference/client.py
+++ b/xinference/client.py
@@ -182,14 +182,14 @@ def chat(


 def streaming_response_iterator(
-    response_lines: Iterator[bytes],
+    response_chunk: Iterator[bytes],
 ) -> Iterator["CompletionChunk"]:
     """
     Create an Iterator to handle the streaming type of generation.

     Parameters
     ----------
-    response_lines: Iterator[bytes]
+    response_chunk: Iterator[bytes]
         Generated lines by the Model Generator.

     Returns
@@ -199,24 +199,25 @@ def streaming_response_iterator(

     """

-    for line in response_lines:
-        line = line.strip()
-        if line.startswith(b"data:"):
-            data = json.loads(line.decode("utf-8").replace("data: ", "", 1))
-            yield data
+    for chunk in response_chunk:
+        content = json.loads(chunk.decode("utf-8"))
+        error = content.get("error", None)
+        if error is not None:
+            raise Exception(str(error))
+        yield content


 # Duplicate code due to type hint issues
 def chat_streaming_response_iterator(
-    response_lines: Iterator[bytes],
+    response_chunk: Iterator[bytes],
 ) -> Iterator["ChatCompletionChunk"]:
     """
     Create an Iterator to handle the streaming type of generation.

     Parameters
     ----------
-    response_lines: Iterator[bytes]
-        Generated lines by the Model Generator.
+    response_chunk: Iterator[bytes]
+        Generated chunk by the Model Generator.

     Returns
     -------
@@ -225,11 +226,12 @@ def chat_streaming_response_iterator(

     """

-    for line in response_lines:
-        line = line.strip()
-        if line.startswith(b"data:"):
-            data = json.loads(line.decode("utf-8").replace("data: ", "", 1))
-            yield data
+    for chunk in response_chunk:
+        content = json.loads(chunk.decode("utf-8"))
+        error = content.get("error", None)
+        if error is not None:
+            raise Exception(str(error))
+        yield content


 class RESTfulModelHandle:
@@ -327,7 +329,7 @@ def generate(
         )

         if stream:
-            return streaming_response_iterator(response.iter_lines())
+            return streaming_response_iterator(response.iter_content(chunk_size=None))

         response_data = response.json()
         return response_data
@@ -405,7 +407,9 @@ def chat(
         )

         if stream:
-            return chat_streaming_response_iterator(response.iter_lines())
+            return chat_streaming_response_iterator(
+                response.iter_content(chunk_size=None)
+            )

         response_data = response.json()
         return response_data
@@ -469,7 +473,9 @@ def chat(
         )

         if stream:
-            return chat_streaming_response_iterator(response.iter_lines())
+            return chat_streaming_response_iterator(
+                response.iter_content(chunk_size=None)
+            )

         response_data = response.json()
         return response_data
diff --git a/xinference/core/model.py b/xinference/core/model.py
index 165e916820..7d9d0b37b9 100644
--- a/xinference/core/model.py
+++ b/xinference/core/model.py
@@ -118,6 +118,8 @@ def load(self):

     async def _wrap_generator(self, ret: Any):
         if inspect.isgenerator(ret) or inspect.isasyncgen(ret):
+            if self._lock is not None and self._generators:
+                raise Exception("Parallel generation is not supported by ggml.")
             generator_uid = str(uuid.uuid1())
             self._generators[generator_uid] = ret

diff --git a/xinference/core/restful_api.py b/xinference/core/restful_api.py
index 9d2e55fdfa..a2569de5c0 100644
--- a/xinference/core/restful_api.py
+++ b/xinference/core/restful_api.py
@@ -19,19 +19,15 @@
 import sys
 import threading
 import warnings
-from functools import partial
 from typing import Any, Dict, List, Literal, Optional, Union

-import anyio
 import gradio as gr
 import xoscar as xo
-from anyio.streams.memory import MemoryObjectSendStream
 from fastapi import APIRouter, FastAPI, HTTPException, Request
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel, Field
-from sse_starlette.sse import EventSourceResponse
 from starlette.responses import RedirectResponse
 from typing_extensions import NotRequired, TypedDict
 from uvicorn import Config, Server
@@ -517,32 +513,17 @@ async def create_completion(self, request: Request, body: CreateCompletionReques
             raise HTTPException(status_code=500, detail=str(e))

         if body.stream:
-            # create a pair of memory object streams
-            send_chan, recv_chan = anyio.create_memory_object_stream(10)
-
-            async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-                async with inner_send_chan:
-                    try:
-                        iterator = await model.generate(body.prompt, kwargs)
-                        async for chunk in iterator:
-                            await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                            if await request.is_disconnected():
-                                raise anyio.get_cancelled_exc_class()()
-                    except anyio.get_cancelled_exc_class() as e:
-                        logger.warning("disconnected")
-                        with anyio.move_on_after(1, shield=True):
-                            logger.warning(
-                                f"Disconnected from client (via refresh/close) {request.client}"
-                            )
-                            await inner_send_chan.send(dict(closing=True))
-                        raise e
-                    except Exception as e:
-                        raise HTTPException(status_code=500, detail=str(e))
-
-            return EventSourceResponse(
-                recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-            )
+            async def stream_results():
+                try:
+                    iterator = await model.generate(body.prompt, kwargs)
+                    async for item in iterator:
+                        yield json.dumps(item)
+                except Exception as ex:
+                    logger.exception("Completion stream got an error: %s", ex)
+                    yield json.dumps({"error": str(ex)})
+
+            return StreamingResponse(stream_results())

         else:
             try:
                 return await model.generate(body.prompt, kwargs)
@@ -640,37 +621,22 @@ async def create_chat_completion(
             )

         if body.stream:
-            # create a pair of memory object streams
-            send_chan, recv_chan = anyio.create_memory_object_stream(10)
-
-            async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-                async with inner_send_chan:
-                    try:
-                        if is_chatglm_ggml:
-                            iterator = await model.chat(prompt, chat_history, kwargs)
-                        else:
-                            iterator = await model.chat(
-                                prompt, system_prompt, chat_history, kwargs
-                            )
-                        async for chunk in iterator:
-                            await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                            if await request.is_disconnected():
-                                raise anyio.get_cancelled_exc_class()()
-                    except anyio.get_cancelled_exc_class() as e:
-                        logger.warning("disconnected")
-                        with anyio.move_on_after(1, shield=True):
-                            logger.warning(
-                                f"Disconnected from client (via refresh/close) {request.client}"
-                            )
-                            await inner_send_chan.send(dict(closing=True))
-                        raise e
-                    except Exception as e:
-                        raise HTTPException(status_code=500, detail=str(e))
-
-            return EventSourceResponse(
-                recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-            )
+            async def stream_results():
+                try:
+                    if is_chatglm_ggml:
+                        iterator = await model.chat(prompt, chat_history, kwargs)
+                    else:
+                        iterator = await model.chat(
+                            prompt, system_prompt, chat_history, kwargs
+                        )
+                    async for item in iterator:
+                        yield json.dumps(item)
+                except Exception as ex:
+                    logger.exception("Chat completion stream got an error: %s", ex)
+                    yield json.dumps({"error": str(ex)})
+
+            return StreamingResponse(stream_results())

         else:
             try:
                 if is_chatglm_ggml:
diff --git a/xinference/tests/test_client.py b/xinference/tests/test_client.py
index 885d75e100..2ef6593b7d 100644
--- a/xinference/tests/test_client.py
+++ b/xinference/tests/test_client.py
@@ -232,12 +232,33 @@ def test_RESTful_client(setup):
     completion = model.chat("What is the capital of France?")
     assert "content" in completion["choices"][0]["message"]

-    streaming_response = model.chat(
-        prompt="What is the capital of France?", generate_config={"stream": True}
-    )
+    def _check_stream():
+        streaming_response = model.chat(
+            prompt="What is the capital of France?",
+            generate_config={"stream": True, "max_tokens": 5},
+        )
+        for chunk in streaming_response:
+            assert "content" or "role" in chunk["choices"][0]["delta"]
+
+    _check_stream()
+
+    results = []
+    with ThreadPoolExecutor() as executor:
+        for _ in range(2):
+            r = executor.submit(_check_stream)
+            results.append(r)
+    # Parallel generation is not supported by ggml.
+    error_count = 0
+    for r in results:
+        try:
+            r.result()
+        except Exception as ex:
+            assert "Parallel generation" in str(ex)
+            error_count += 1
+    assert error_count == 1

-    for chunk in streaming_response:
-        assert "content" or "role" in chunk["choices"][0]["delta"]
+    # After iteration finish, we can iterate again.
+    _check_stream()

     client.terminate_model(model_uid=model_uid)
     assert len(client.list_models()) == 0
@@ -251,8 +272,9 @@ def test_RESTful_client(setup):
     assert len(client.list_models()) == 1

     # Test concurrent chat is OK.
+    model = client.get_model(model_uid=model_uid)
+
     def _check(stream=False):
-        model = client.get_model(model_uid=model_uid)
         completion = model.generate(
             "AI is going to", generate_config={"stream": stream, "max_tokens": 5}
         )
@@ -266,12 +288,18 @@ def _check(stream=False):

     for stream in [True, False]:
         results = []
+        error_count = 0
         with ThreadPoolExecutor() as executor:
             for _ in range(3):
                 r = executor.submit(_check, stream=stream)
                 results.append(r)
         for r in results:
-            r.result()
+            try:
+                r.result()
+            except Exception as ex:
+                assert "Parallel generation" in str(ex)
+                error_count += 1
+        assert error_count == (2 if stream else 0)

     client.terminate_model(model_uid=model_uid)
     assert len(client.list_models()) == 0