FEAT: support guided decoding for vllm async engine (#2391)
Signed-off-by: wxiwnd <[email protected]>
wxiwnd authored Nov 28, 2024
1 parent 0d4cb9c commit 23f09f9
Showing 3 changed files with 133 additions and 4 deletions.
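For context, the feature added here can be exercised through the OpenAI-compatible endpoint by passing guided_* fields in the request body. A minimal sketch, assuming a running Xinference server on the default port and a vLLM-backed model launched under the placeholder name "my-vllm-model":

import openai

client = openai.OpenAI(api_key="not-needed", base_url="http://127.0.0.1:9997/v1")

# extra_body merges extra keys into the raw request body, where the new
# extract_guided_params helper (restful_api.py below) picks them up.
completion = client.chat.completions.create(
    model="my-vllm-model",  # placeholder model uid
    messages=[{"role": "user", "content": "Is Pluto a planet? Answer yes or no."}],
    extra_body={"guided_choice": ["yes", "no"]},
)
print(completion.choices[0].message.content)  # output constrained to "yes" or "no"
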
24 changes: 22 additions & 2 deletions xinference/_compat.py
@@ -60,6 +60,10 @@
ChatCompletionStreamOptionsParam,
)
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
from openai.types.shared_params.response_format_json_object import (
ResponseFormatJSONObject,
)
from openai.types.shared_params.response_format_text import ResponseFormatText

OpenAIChatCompletionStreamOptionsParam = create_model_from_typeddict(
ChatCompletionStreamOptionsParam
@@ -70,6 +74,23 @@
)


class JSONSchema(BaseModel):
name: str
description: Optional[str] = None
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
strict: Optional[bool] = None


class ResponseFormatJSONSchema(BaseModel):
json_schema: JSONSchema
type: Literal["json_schema"]


ResponseFormat = Union[
ResponseFormatText, ResponseFormatJSONObject, ResponseFormatJSONSchema
]


class CreateChatCompletionOpenAI(BaseModel):
"""
Comes from source code: https://github.com/openai/openai-python/blob/main/src/openai/types/chat/completion_create_params.py
@@ -84,8 +105,7 @@ class CreateChatCompletionOpenAI(BaseModel):
n: Optional[int]
parallel_tool_calls: Optional[bool]
presence_penalty: Optional[float]
- # we do not support this
- # response_format: ResponseFormat
response_format: Optional[ResponseFormat]
seed: Optional[int]
service_tier: Optional[Literal["auto", "default"]]
stop: Union[Optional[str], List[str]]
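A standalone sketch of a payload the new ResponseFormatJSONSchema model accepts, mirroring the definitions above without importing xinference (assumes pydantic is installed):

from typing import Dict, Literal, Optional
from pydantic import BaseModel, Field

class JSONSchema(BaseModel):
    name: str
    description: Optional[str] = None
    # aliased, presumably because a field literally named "schema" would
    # clash with pydantic's BaseModel.schema()
    schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
    strict: Optional[bool] = None

class ResponseFormatJSONSchema(BaseModel):
    json_schema: JSONSchema
    type: Literal["json_schema"]

rf = ResponseFormatJSONSchema(
    type="json_schema",
    json_schema={"name": "person", "schema": {"type": "object"}},
)
print(rf.json_schema.schema_)  # {'type': 'object'}, populated via the alias
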
28 changes: 28 additions & 0 deletions xinference/api/restful_api.py
@@ -1229,6 +1229,9 @@ async def create_completion(self, request: Request) -> Response:
raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
kwargs = body.dict(exclude_unset=True, exclude=exclude)

# guided_decoding params
kwargs.update(self.extract_guided_params(raw_body=raw_body))

# TODO: Decide if this default value override is necessary #1061
if body.max_tokens is None:
kwargs["max_tokens"] = max_tokens_field.default
@@ -1971,9 +1974,13 @@ async def create_chat_completion(self, request: Request) -> Response:
"logit_bias_type",
"user",
}

raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
kwargs = body.dict(exclude_unset=True, exclude=exclude)

# guided_decoding params
kwargs.update(self.extract_guided_params(raw_body=raw_body))

# TODO: Decide if this default value override is necessary #1061
if body.max_tokens is None:
kwargs["max_tokens"] = max_tokens_field.default
@@ -2336,6 +2343,27 @@ async def abort_cluster(self) -> JSONResponse:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

@staticmethod
def extract_guided_params(raw_body: dict) -> dict:
kwargs = {}
if raw_body.get("guided_json") is not None:
kwargs["guided_json"] = raw_body.get("guided_json")
if raw_body.get("guided_regex") is not None:
kwargs["guided_regex"] = raw_body.get("guided_regex")
if raw_body.get("guided_choice") is not None:
kwargs["guided_choice"] = raw_body.get("guided_choice")
if raw_body.get("guided_grammar") is not None:
kwargs["guided_grammar"] = raw_body.get("guided_grammar")
if raw_body.get("guided_json_object") is not None:
kwargs["guided_json_object"] = raw_body.get("guided_json_object")
if raw_body.get("guided_decoding_backend") is not None:
kwargs["guided_decoding_backend"] = raw_body.get("guided_decoding_backend")
if raw_body.get("guided_whitespace_pattern") is not None:
kwargs["guided_whitespace_pattern"] = raw_body.get(
"guided_whitespace_pattern"
)
return kwargs
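
Illustratively, for a raw request body like the following, the helper forwards only the guided keys that are present and non-None:

raw_body = {
    "model": "my-vllm-model",  # placeholder
    "messages": [{"role": "user", "content": "Answer yes or no."}],
    "guided_choice": ["yes", "no"],
    "guided_decoding_backend": "outlines",
}
# extract_guided_params(raw_body) returns:
# {"guided_choice": ["yes", "no"], "guided_decoding_backend": "outlines"}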


def run(
supervisor_address: str,
85 changes: 83 additions & 2 deletions xinference/model/llm/vllm/core.py
@@ -69,6 +69,7 @@ class VLLMModelConfig(TypedDict, total=False):
quantization: Optional[str]
max_model_len: Optional[int]
limit_mm_per_prompt: Optional[Dict[str, int]]
guided_decoding_backend: Optional[str]


class VLLMGenerateConfig(TypedDict, total=False):
@@ -85,6 +86,14 @@ class VLLMGenerateConfig(TypedDict, total=False):
stop: Optional[Union[str, List[str]]]
stream: bool # non-sampling param, should not be passed to the engine.
stream_options: Optional[Union[dict, None]]
response_format: Optional[dict]
guided_json: Optional[Union[str, dict]]
guided_regex: Optional[str]
guided_choice: Optional[List[str]]
guided_grammar: Optional[str]
guided_json_object: Optional[bool]
guided_decoding_backend: Optional[str]
guided_whitespace_pattern: Optional[str]


try:
@@ -314,6 +323,7 @@ def _sanitize_model_config(
model_config.setdefault("max_num_seqs", 256)
model_config.setdefault("quantization", None)
model_config.setdefault("max_model_len", None)
model_config.setdefault("guided_decoding_backend", "outlines")

return model_config

@@ -325,6 +335,22 @@ def _sanitize_generate_config(
generate_config = {}

sanitized = VLLMGenerateConfig()

response_format = generate_config.pop("response_format", None)
guided_decoding_backend = generate_config.get("guided_decoding_backend", None)
guided_json_object = None
guided_json = None

if response_format is not None:
if response_format.get("type") == "json_object":
guided_json_object = True
elif response_format.get("type") == "json_schema":
json_schema = response_format.get("json_schema")
assert json_schema is not None
guided_json = json_schema.get("json_schema")
if guided_decoding_backend is None:
guided_decoding_backend = "outlines"

sanitized.setdefault("lora_name", generate_config.get("lora_name", None))
sanitized.setdefault("n", generate_config.get("n", 1))
sanitized.setdefault("best_of", generate_config.get("best_of", None))
@@ -346,6 +372,28 @@
sanitized.setdefault(
"stream_options", generate_config.get("stream_options", None)
)
sanitized.setdefault(
"guided_json", generate_config.get("guided_json", guided_json)
)
sanitized.setdefault("guided_regex", generate_config.get("guided_regex", None))
sanitized.setdefault(
"guided_choice", generate_config.get("guided_choice", None)
)
sanitized.setdefault(
"guided_grammar", generate_config.get("guided_grammar", None)
)
sanitized.setdefault(
"guided_whitespace_pattern",
generate_config.get("guided_whitespace_pattern", None),
)
sanitized.setdefault(
"guided_json_object",
generate_config.get("guided_json_object", guided_json_object),
)
sanitized.setdefault(
"guided_decoding_backend",
generate_config.get("guided_decoding_backend", guided_decoding_backend),
)

return sanitized
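
To trace the response_format handling above, a sketch of the two branches (illustrative inputs, not an exact transcript of the sanitized dict):

# "json_object" only flips the flag; the backend falls back to "outlines"
# when unset:
generate_config = {"response_format": {"type": "json_object"}}
# -> guided_json_object=True

# "json_schema" carries a schema; note the code above reads it from the
# "json_schema" key of the inner object, not from its "schema" key:
generate_config = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "person", "json_schema": {"type": "object"}},
    }
}
# -> guided_json={"type": "object"}, guided_decoding_backend="outlines"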

@@ -483,13 +531,46 @@ async def async_generate(
if isinstance(stream_options, dict)
else False
)
- sampling_params = SamplingParams(**sanitized_generate_config)

if VLLM_INSTALLED and vllm.__version__ >= "0.6.3":
# guided decoding only available for vllm >= 0.6.3
from vllm.sampling_params import GuidedDecodingParams

guided_options = GuidedDecodingParams.from_optional(
json=sanitized_generate_config.pop("guided_json", None),
regex=sanitized_generate_config.pop("guided_regex", None),
choice=sanitized_generate_config.pop("guided_choice", None),
grammar=sanitized_generate_config.pop("guided_grammar", None),
json_object=sanitized_generate_config.pop("guided_json_object", None),
backend=sanitized_generate_config.pop("guided_decoding_backend", None),
whitespace_pattern=sanitized_generate_config.pop(
"guided_whitespace_pattern", None
),
)

sampling_params = SamplingParams(
guided_decoding=guided_options, **sanitized_generate_config
)
else:
# ignore generate configs
sanitized_generate_config.pop("guided_json", None)
sanitized_generate_config.pop("guided_regex", None)
sanitized_generate_config.pop("guided_choice", None)
sanitized_generate_config.pop("guided_grammar", None)
sanitized_generate_config.pop("guided_json_object", None)
sanitized_generate_config.pop("guided_decoding_backend", None)
sanitized_generate_config.pop("guided_whitespace_pattern", None)
sampling_params = SamplingParams(**sanitized_generate_config)

if not request_id:
request_id = str(uuid.uuid1())

assert self._engine is not None
results_generator = self._engine.generate(
- prompt, sampling_params, request_id, lora_request=lora_request
prompt,
sampling_params,
request_id,
lora_request,
)

async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
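
For reference, a sketch of the vLLM-side API exercised above (vLLM >= 0.6.3, assuming vllm is installed; values are illustrative):

from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# Mirrors the construction in async_generate: every argument to
# from_optional is optional, matching unset guided_* request params.
guided = GuidedDecodingParams.from_optional(
    choice=["yes", "no"],  # e.g. a guided_choice request param
    backend="outlines",    # matches the default from _sanitize_model_config
)
sampling_params = SamplingParams(max_tokens=32, guided_decoding=guided)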
