Commit 4d9e044

feat: support response_format

Signed-off-by: wxiwnd <[email protected]>
wxiwnd committed Oct 22, 2024
1 parent b2a255a commit 4d9e044
Showing 3 changed files with 34 additions and 3 deletions.
7 changes: 5 additions & 2 deletions xinference/_compat.py
@@ -60,6 +60,7 @@
     ChatCompletionStreamOptionsParam,
 )
 from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
+from openai.types.chat.completion_create_params import ResponseFormat

 OpenAIChatCompletionStreamOptionsParam = create_model_from_typeddict(
     ChatCompletionStreamOptionsParam
@@ -84,8 +85,10 @@ class CreateChatCompletionOpenAI(BaseModel):
     n: Optional[int]
     parallel_tool_calls: Optional[bool]
     presence_penalty: Optional[float]
-    # we do not support this
-    # response_format: ResponseFormat
+    # FIXME: schema replica error in Pydantic
+    # source: ResponseFormatJSONSchema in ResponseFormat
+    # use alias as a workaround
+    _response_format: Optional[ResponseFormat] = Field(alias="response_format")
     seed: Optional[int]
     service_tier: Optional[Literal["auto", "default"]]
     stop: Union[Optional[str], List[str]]
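For readers unfamiliar with the alias workaround above: Pydantic's Field(alias=...) lets a model accept the wire-level key "response_format" while storing it under a differently named attribute, which sidesteps the schema conflict noted in the FIXME. A minimal standalone sketch — the model and field names below are illustrative, not the actual CreateChatCompletionOpenAI definition:

    from typing import Optional
    from pydantic import BaseModel, Field

    class ChatRequestSketch(BaseModel):
        # Accept the OpenAI wire name "response_format" but store it under
        # another attribute name, mirroring the alias trick in _compat.py.
        resp_format: Optional[dict] = Field(default=None, alias="response_format")

    req = ChatRequestSketch(**{"response_format": {"type": "json_object"}})
    print(req.resp_format)  # -> {'type': 'json_object'}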
1 change: 1 addition & 0 deletions xinference/api/restful_api.py
@@ -1849,6 +1849,7 @@ async def create_chat_completion(self, request: Request) -> Response:
             "logit_bias",
             "logit_bias_type",
             "user",
+            "response_format",  # only support completion
         }

         raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
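With response_format excluded from the raw kwargs and handled explicitly, a client can request structured output through the standard OpenAI-compatible field. A hedged usage sketch — the endpoint, port, and model name are placeholders, not values taken from this commit:

    from openai import OpenAI

    # Assumed local Xinference endpoint and model name; adjust to your deployment.
    client = OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="not-needed")

    completion = client.chat.completions.create(
        model="my-chat-model",
        messages=[{"role": "user", "content": "Reply with a JSON object with keys a and b."}],
        response_format={"type": "json_object"},
    )
    print(completion.choices[0].message.content)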
29 changes: 28 additions & 1 deletion xinference/model/llm/vllm/core.py
@@ -86,6 +86,7 @@ class VLLMGenerateConfig(TypedDict, total=False):
     stop: Optional[Union[str, List[str]]]
     stream: bool  # non-sampling param, should not be passed to the engine.
     stream_options: Optional[Union[dict, None]]
+    response_format: Optional[dict]
     guided_json: Optional[Union[str, dict]]
     guided_regex: Optional[str]
     guided_choice: Optional[List[str]]
@@ -333,6 +334,22 @@ def _sanitize_generate_config(
             generate_config = {}

         sanitized = VLLMGenerateConfig()
+
+        response_format = generate_config.pop("response_format", None)
+        guided_decoding_backend = generate_config.get("guided_decoding_backend", None)
+        guided_json_object = None
+        guided_json = None
+
+        if response_format is not None:
+            if response_format.type == "json_object":
+                guided_json_object = True
+            elif response_format.type == "json_schema":
+                json_schema = response_format.json_schema
+                assert json_schema is not None
+                guided_json = json_schema.json_schema
+                if guided_decoding_backend is None:
+                    guided_decoding_backend = "outlines"
+
         sanitized.setdefault("lora_name", generate_config.get("lora_name", None))
         sanitized.setdefault("n", generate_config.get("n", 1))
         sanitized.setdefault("best_of", generate_config.get("best_of", None))
@@ -354,7 +371,9 @@ def _sanitize_generate_config(
         sanitized.setdefault(
             "stream_options", generate_config.get("stream_options", None)
         )
-        sanitized.setdefault("guided_json", generate_config.get("guided_json", None))
+        sanitized.setdefault(
+            "guided_json", generate_config.get("guided_json", guided_json)
+        )
         sanitized.setdefault("guided_regex", generate_config.get("guided_regex", None))
         sanitized.setdefault(
             "guided_choice", generate_config.get("guided_choice", None)
@@ -366,6 +385,14 @@
             "guided_whitespace_pattern",
             generate_config.get("guided_whitespace_pattern", None),
         )
+        sanitized.setdefault(
+            "guided_json_object",
+            generate_config.get("guided_json_object", guided_json_object),
+        )
+        sanitized.setdefault(
+            "guided_decoding_backend",
+            generate_config.get("guided_decoding_backend", guided_decoding_backend),
+        )

         return sanitized

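In summary, the sanitization above maps the OpenAI-style response_format onto vLLM's guided-decoding parameters: {"type": "json_object"} enables guided_json_object, while {"type": "json_schema", ...} forwards the schema as guided_json and defaults guided_decoding_backend to "outlines"; explicit guided_* values in generate_config still win via setdefault. A dict-based standalone sketch of that mapping — simplified from the attribute access in the real code, and the nested key layout is an assumption:

    from typing import Any, Dict, Optional, Tuple

    def map_response_format(
        response_format: Optional[Dict[str, Any]],
    ) -> Tuple[Optional[bool], Optional[Any], Optional[str]]:
        """Return (guided_json_object, guided_json, guided_decoding_backend)."""
        guided_json_object = None
        guided_json = None
        backend = None
        if response_format is not None:
            if response_format.get("type") == "json_object":
                guided_json_object = True
            elif response_format.get("type") == "json_schema":
                json_schema = response_format.get("json_schema")
                assert json_schema is not None
                # The commit reads json_schema.json_schema; here we assume the
                # schema sits under a nested key of the same name.
                guided_json = json_schema.get("json_schema", json_schema)
                backend = "outlines"
        return guided_json_object, guided_json, backend

    # Example: a JSON-object request turns on guided_json_object only.
    print(map_response_format({"type": "json_object"}))  # (True, None, None)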
