From 23f09f981cd33103ff714285f5d13acb9f69b953 Mon Sep 17 00:00:00 2001
From: wxiwnd <40122078+wxiwnd@users.noreply.github.com>
Date: Thu, 28 Nov 2024 20:02:00 +0800
Subject: [PATCH] FEAT: support guided decoding for vllm async engine (#2391)

Signed-off-by: wxiwnd
---
 xinference/_compat.py             | 24 ++++++++-
 xinference/api/restful_api.py     | 28 ++++++++++
 xinference/model/llm/vllm/core.py | 85 ++++++++++++++++++++++++++++++-
 3 files changed, 133 insertions(+), 4 deletions(-)

diff --git a/xinference/_compat.py b/xinference/_compat.py
index fbb3572a59..1a781bc7ef 100644
--- a/xinference/_compat.py
+++ b/xinference/_compat.py
@@ -60,6 +60,10 @@
     ChatCompletionStreamOptionsParam,
 )
 from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
+from openai.types.shared_params.response_format_json_object import (
+    ResponseFormatJSONObject,
+)
+from openai.types.shared_params.response_format_text import ResponseFormatText

 OpenAIChatCompletionStreamOptionsParam = create_model_from_typeddict(
     ChatCompletionStreamOptionsParam
@@ -70,6 +74,23 @@
 )


+class JSONSchema(BaseModel):
+    name: str
+    description: Optional[str] = None
+    schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
+    strict: Optional[bool] = None
+
+
+class ResponseFormatJSONSchema(BaseModel):
+    json_schema: JSONSchema
+    type: Literal["json_schema"]
+
+
+ResponseFormat = Union[
+    ResponseFormatText, ResponseFormatJSONObject, ResponseFormatJSONSchema
+]
+
+
 class CreateChatCompletionOpenAI(BaseModel):
     """
     Comes from source code: https://github.com/openai/openai-python/blob/main/src/openai/types/chat/completion_create_params.py
@@ -84,8 +105,7 @@ class CreateChatCompletionOpenAI(BaseModel):
     n: Optional[int]
     parallel_tool_calls: Optional[bool]
     presence_penalty: Optional[float]
-    # we do not support this
-    # response_format: ResponseFormat
+    response_format: Optional[ResponseFormat]
     seed: Optional[int]
     service_tier: Optional[Literal["auto", "default"]]
     stop: Union[Optional[str], List[str]]
diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py
index cfc191db55..c1189a7911 100644
--- a/xinference/api/restful_api.py
+++ b/xinference/api/restful_api.py
@@ -1229,6 +1229,9 @@ async def create_completion(self, request: Request) -> Response:
         raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
         kwargs = body.dict(exclude_unset=True, exclude=exclude)

+        # guided_decoding params
+        kwargs.update(self.extract_guided_params(raw_body=raw_body))
+
         # TODO: Decide if this default value override is necessary #1061
         if body.max_tokens is None:
             kwargs["max_tokens"] = max_tokens_field.default
@@ -1971,9 +1974,13 @@ async def create_chat_completion(self, request: Request) -> Response:
             "logit_bias_type",
             "user",
         }
+
         raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
         kwargs = body.dict(exclude_unset=True, exclude=exclude)

+        # guided_decoding params
+        kwargs.update(self.extract_guided_params(raw_body=raw_body))
+
         # TODO: Decide if this default value override is necessary #1061
         if body.max_tokens is None:
             kwargs["max_tokens"] = max_tokens_field.default
@@ -2336,6 +2343,27 @@ async def abort_cluster(self) -> JSONResponse:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))

+    @staticmethod
+    def extract_guided_params(raw_body: dict) -> dict:
+        kwargs = {}
+        if raw_body.get("guided_json") is not None:
+            kwargs["guided_json"] = raw_body.get("guided_json")
+        if raw_body.get("guided_regex") is not None:
+            kwargs["guided_regex"] = raw_body.get("guided_regex")
+        if raw_body.get("guided_choice") is not None:
+            kwargs["guided_choice"] = raw_body.get("guided_choice")
+        if raw_body.get("guided_grammar") is not None:
+            kwargs["guided_grammar"] = raw_body.get("guided_grammar")
+        if raw_body.get("guided_json_object") is not None:
+            kwargs["guided_json_object"] = raw_body.get("guided_json_object")
+        if raw_body.get("guided_decoding_backend") is not None:
+            kwargs["guided_decoding_backend"] = raw_body.get("guided_decoding_backend")
+        if raw_body.get("guided_whitespace_pattern") is not None:
+            kwargs["guided_whitespace_pattern"] = raw_body.get(
+                "guided_whitespace_pattern"
+            )
+        return kwargs
+

 def run(
     supervisor_address: str,
diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py
index 20e7a16c53..89e14ae496 100644
--- a/xinference/model/llm/vllm/core.py
+++ b/xinference/model/llm/vllm/core.py
@@ -69,6 +69,7 @@ class VLLMModelConfig(TypedDict, total=False):
     quantization: Optional[str]
     max_model_len: Optional[int]
     limit_mm_per_prompt: Optional[Dict[str, int]]
+    guided_decoding_backend: Optional[str]


 class VLLMGenerateConfig(TypedDict, total=False):
@@ -85,6 +86,14 @@ class VLLMGenerateConfig(TypedDict, total=False):
     stop: Optional[Union[str, List[str]]]
     stream: bool  # non-sampling param, should not be passed to the engine.
     stream_options: Optional[Union[dict, None]]
+    response_format: Optional[dict]
+    guided_json: Optional[Union[str, dict]]
+    guided_regex: Optional[str]
+    guided_choice: Optional[List[str]]
+    guided_grammar: Optional[str]
+    guided_json_object: Optional[bool]
+    guided_decoding_backend: Optional[str]
+    guided_whitespace_pattern: Optional[str]


 try:
@@ -314,6 +323,7 @@ def _sanitize_model_config(
         model_config.setdefault("max_num_seqs", 256)
         model_config.setdefault("quantization", None)
         model_config.setdefault("max_model_len", None)
+        model_config.setdefault("guided_decoding_backend", "outlines")

         return model_config

@@ -325,6 +335,22 @@
             generate_config = {}

         sanitized = VLLMGenerateConfig()
+
+        response_format = generate_config.pop("response_format", None)
+        guided_decoding_backend = generate_config.get("guided_decoding_backend", None)
+        guided_json_object = None
+        guided_json = None
+
+        if response_format is not None:
+            if response_format.get("type") == "json_object":
+                guided_json_object = True
+            elif response_format.get("type") == "json_schema":
+                json_schema = response_format.get("json_schema")
+                assert json_schema is not None
+                guided_json = json_schema.get("json_schema")
+                if guided_decoding_backend is None:
+                    guided_decoding_backend = "outlines"
+
         sanitized.setdefault("lora_name", generate_config.get("lora_name", None))
         sanitized.setdefault("n", generate_config.get("n", 1))
         sanitized.setdefault("best_of", generate_config.get("best_of", None))
@@ -346,6 +372,28 @@
         sanitized.setdefault(
             "stream_options", generate_config.get("stream_options", None)
         )
+        sanitized.setdefault(
+            "guided_json", generate_config.get("guided_json", guided_json)
+        )
+        sanitized.setdefault("guided_regex", generate_config.get("guided_regex", None))
+        sanitized.setdefault(
+            "guided_choice", generate_config.get("guided_choice", None)
+        )
+        sanitized.setdefault(
+            "guided_grammar", generate_config.get("guided_grammar", None)
+        )
+        sanitized.setdefault(
+            "guided_whitespace_pattern",
+            generate_config.get("guided_whitespace_pattern", None),
+        )
+        sanitized.setdefault(
+            "guided_json_object",
+            generate_config.get("guided_json_object", guided_json_object),
+        )
+        sanitized.setdefault(
+            "guided_decoding_backend",
+            generate_config.get("guided_decoding_backend", guided_decoding_backend),
+        )

         return sanitized

@@ -483,13 +531,46 @@ async def async_generate(
             if isinstance(stream_options, dict)
             else False
         )
-        sampling_params = SamplingParams(**sanitized_generate_config)
+
+        if VLLM_INSTALLED and vllm.__version__ >= "0.6.3":
+            # guided decoding only available for vllm >= 0.6.3
+            from vllm.sampling_params import GuidedDecodingParams
+
+            guided_options = GuidedDecodingParams.from_optional(
+                json=sanitized_generate_config.pop("guided_json", None),
+                regex=sanitized_generate_config.pop("guided_regex", None),
+                choice=sanitized_generate_config.pop("guided_choice", None),
+                grammar=sanitized_generate_config.pop("guided_grammar", None),
+                json_object=sanitized_generate_config.pop("guided_json_object", None),
+                backend=sanitized_generate_config.pop("guided_decoding_backend", None),
+                whitespace_pattern=sanitized_generate_config.pop(
+                    "guided_whitespace_pattern", None
+                ),
+            )
+
+            sampling_params = SamplingParams(
+                guided_decoding=guided_options, **sanitized_generate_config
+            )
+        else:
+            # ignore generate configs
+            sanitized_generate_config.pop("guided_json", None)
+            sanitized_generate_config.pop("guided_regex", None)
+            sanitized_generate_config.pop("guided_choice", None)
+            sanitized_generate_config.pop("guided_grammar", None)
+            sanitized_generate_config.pop("guided_json_object", None)
+            sanitized_generate_config.pop("guided_decoding_backend", None)
+            sanitized_generate_config.pop("guided_whitespace_pattern", None)
+            sampling_params = SamplingParams(**sanitized_generate_config)
+
         if not request_id:
             request_id = str(uuid.uuid1())

         assert self._engine is not None
         results_generator = self._engine.generate(
-            prompt, sampling_params, request_id, lora_request=lora_request
+            prompt,
+            sampling_params,
+            request_id,
+            lora_request,
         )

         async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
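
---

For reference, below is a minimal client-side sketch (not part of the patch) of how the new parameters could be exercised through the OpenAI-compatible endpoint: `response_format` is mapped by `_sanitize_generate_config` to vLLM's guided decoding, while the `guided_*` fields are forwarded by `extract_guided_params` when sent via `extra_body`. The server address, model name, and launch setup are placeholder assumptions.

```python
from openai import OpenAI

# Assumed local Xinference endpoint and model name; adjust for your deployment.
client = OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="not-needed")

# OpenAI-style response_format: translated to guided_json_object for the vLLM engine.
resp = client.chat.completions.create(
    model="qwen2.5-instruct",
    messages=[{"role": "user", "content": "Describe a cat as a JSON object."}],
    response_format={"type": "json_object"},
)
print(resp.choices[0].message.content)

# vLLM-style guided_* fields: passed through the raw body via extra_body and
# picked up by extract_guided_params() on the RESTful API side.
resp = client.chat.completions.create(
    model="qwen2.5-instruct",
    messages=[{"role": "user", "content": "Is the sky blue? Answer with one word."}],
    extra_body={
        "guided_choice": ["yes", "no"],
        "guided_decoding_backend": "outlines",
    },
)
print(resp.choices[0].message.content)
```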