Support structured output (#3732)
* Support structured output

* use ruff format

* add type checking for cookbook

* add the notebook to index.md

* fix the type error

* pass response_format explicitly

* remove casting

* ensure types are correct

* separate response_format arg

* fix type and resolve pyright errors

---------

Co-authored-by: Eric Zhu <[email protected]>
lordlinus and ekzhu authored Oct 13, 2024
1 parent 43ccc81 commit a106229
Showing 3 changed files with 250 additions and 18 deletions.
@@ -13,4 +13,5 @@ local-llms-ollama-litellm
 instrumenting
 topic-subscription-scenarios
 azure-container-code-executor
+structured-output-agent
 ```
@@ -0,0 +1,161 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Strcutured output using GPT-4o models\n",
"\n",
"This cookbook demonstrates how to obtain structured output using GPT-4o models. The OpenAI beta client SDK provides a parse helper that allows you to use your own Pydantic model, eliminating the need to define a JSON schema. This approach is recommended for supported models.\n",
"\n",
"Currently, this feature is supported for:\n",
"\n",
"- gpt-4o-mini on OpenAI\n",
"- gpt-4o-2024-08-06 on OpenAI\n",
"- gpt-4o-2024-08-06 on Azure"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's define a simple message type that carries explanation and output for a Math problem"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from pydantic import BaseModel\n",
"\n",
"\n",
"class MathReasoning(BaseModel):\n",
" class Step(BaseModel):\n",
" explanation: str\n",
" output: str\n",
"\n",
" steps: list[Step]\n",
" final_answer: str"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Set the environment variable\n",
"os.environ[\"AZURE_OPENAI_ENDPOINT\"] = \"https://YOUR_ENDPOINT_DETAILS.openai.azure.com/\"\n",
"os.environ[\"AZURE_OPENAI_API_KEY\"] = \"YOUR_API_KEY\"\n",
"os.environ[\"AZURE_OPENAI_DEPLOYMENT_NAME\"] = \"gpt-4o-2024-08-06\"\n",
"os.environ[\"AZURE_OPENAI_API_VERSION\"] = \"2024-08-01-preview\""
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"from typing import Optional\n",
"\n",
"from autogen_core.components.models import AzureOpenAIChatCompletionClient, UserMessage\n",
"\n",
"\n",
"# Function to get environment variable and ensure it is not None\n",
"def get_env_variable(name: str) -> str:\n",
" value = os.getenv(name)\n",
" if value is None:\n",
" raise ValueError(f\"Environment variable {name} is not set\")\n",
" return value\n",
"\n",
"\n",
"# Create the client with type-checked environment variables\n",
"client = AzureOpenAIChatCompletionClient(\n",
" model=get_env_variable(\"AZURE_OPENAI_DEPLOYMENT_NAME\"),\n",
" api_version=get_env_variable(\"AZURE_OPENAI_API_VERSION\"),\n",
" azure_endpoint=get_env_variable(\"AZURE_OPENAI_ENDPOINT\"),\n",
" api_key=get_env_variable(\"AZURE_OPENAI_API_KEY\"),\n",
" model_capabilities={\n",
" \"vision\": False,\n",
" \"function_calling\": True,\n",
" \"json_output\": True,\n",
" },\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'steps': [{'explanation': 'Start by aligning the numbers vertically.', 'output': '\\n 16\\n+ 32'}, {'explanation': 'Add the units digits: 6 + 2 = 8.', 'output': '\\n 16\\n+ 32\\n 8'}, {'explanation': 'Add the tens digits: 1 + 3 = 4.', 'output': '\\n 16\\n+ 32\\n 48'}], 'final_answer': '48'}\n"
]
},
{
"data": {
"text/plain": [
"MathReasoning(steps=[Step(explanation='Start by aligning the numbers vertically.', output='\\n 16\\n+ 32'), Step(explanation='Add the units digits: 6 + 2 = 8.', output='\\n 16\\n+ 32\\n 8'), Step(explanation='Add the tens digits: 1 + 3 = 4.', output='\\n 16\\n+ 32\\n 48')], final_answer='48')"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Define the user message\n",
"messages = [\n",
" UserMessage(content=\"What is 16 + 32?\", source=\"user\"),\n",
"]\n",
"\n",
"# Call the create method on the client, passing the messages and additional arguments\n",
"# The extra_create_args dictionary includes the response format as MathReasoning model we defined above\n",
"# Providing the response format and pydantic model will use the new parse method from beta SDK\n",
"response = await client.create(messages=messages, extra_create_args={\"response_format\": MathReasoning})\n",
"\n",
"# Ensure the response content is a valid JSON string before loading it\n",
"response_content: Optional[str] = response.content if isinstance(response.content, str) else None\n",
"if response_content is None:\n",
" raise ValueError(\"Response content is not a valid JSON string\")\n",
"\n",
"# Print the response content after loading it as JSON\n",
"print(json.loads(response_content))\n",
"\n",
"# Validate the response content with the MathReasoning model\n",
"MathReasoning.model_validate(json.loads(response_content))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
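
For readers following along, the notebook's `extra_create_args={"response_format": MathReasoning}` call ultimately rests on the OpenAI beta parse helper. Below is a minimal sketch of that helper used directly, assuming an `OPENAI_API_KEY` in the environment and access to `gpt-4o-2024-08-06`:

```python
# A minimal sketch of the OpenAI beta parse helper that the cookbook builds on.
# The model name and environment setup are assumptions for illustration.
import asyncio

from openai import AsyncOpenAI
from pydantic import BaseModel


class MathReasoning(BaseModel):
    class Step(BaseModel):
        explanation: str
        output: str

    steps: list[Step]
    final_answer: str


async def main() -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    completion = await client.beta.chat.completions.parse(
        model="gpt-4o-2024-08-06",
        messages=[{"role": "user", "content": "What is 16 + 32?"}],
        response_format=MathReasoning,
    )
    parsed = completion.choices[0].message.parsed  # MathReasoning | None
    if parsed is not None:
        print(parsed.final_answer)


asyncio.run(main())
```

Unlike the notebook's `json.loads` round trip, the parse helper hands back an already-validated `MathReasoning` instance on `message.parsed`.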
@@ -5,6 +5,7 @@
 import math
 import re
 import warnings
+from asyncio import Task
 from typing import (
     Any,
     AsyncGenerator,
@@ -14,13 +15,15 @@
     Optional,
     Sequence,
     Set,
+    Type,
     Union,
     cast,
 )
 
 import tiktoken
 from openai import AsyncAzureOpenAI, AsyncOpenAI
 from openai.types.chat import (
+    ChatCompletion,
     ChatCompletionAssistantMessageParam,
     ChatCompletionContentPartParam,
     ChatCompletionContentPartTextParam,
@@ -31,9 +34,13 @@
     ChatCompletionToolMessageParam,
     ChatCompletionToolParam,
     ChatCompletionUserMessageParam,
+    ParsedChatCompletion,
+    ParsedChoice,
     completion_create_params,
 )
+from openai.types.chat.chat_completion import Choice
 from openai.types.shared_params import FunctionDefinition, FunctionParameters
+from pydantic import BaseModel
 from typing_extensions import Unpack
 
 from ...application.logging import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
@@ -279,10 +286,10 @@ def convert_tools(
                 type="function",
                 function=FunctionDefinition(
                     name=tool_schema["name"],
-                    description=tool_schema["description"] if "description" in tool_schema else "",
-                    parameters=cast(FunctionParameters, tool_schema["parameters"])
-                    if "parameters" in tool_schema
-                    else {},
+                    description=(tool_schema["description"] if "description" in tool_schema else ""),
+                    parameters=(
+                        cast(FunctionParameters, tool_schema["parameters"]) if "parameters" in tool_schema else {}
+                    ),
                 ),
             )
         )
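
For context on the hunk above: `convert_tools` fills the openai SDK's typed dicts, and the ruff-formatted conditionals supply empty defaults for missing fields. A sketch of the resulting shape, using a hypothetical weather tool (the schema below is an assumption, not taken from the diff):

```python
# A sketch of the shape convert_tools produces; the tool name and JSON schema
# here are hypothetical examples, not values from the commit.
from openai.types.chat import ChatCompletionToolParam
from openai.types.shared_params import FunctionDefinition

tool: ChatCompletionToolParam = ChatCompletionToolParam(
    type="function",
    function=FunctionDefinition(
        name="get_weather",
        description="Look up the current weather for a city.",
        parameters={
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    ),
)
print(tool["function"]["name"])  # TypedDicts are plain dicts at runtime
```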
@@ -365,6 +372,24 @@ async def create(
         create_args = self._create_args.copy()
         create_args.update(extra_create_args)
 
+        # Declare use_beta_client
+        use_beta_client: bool = False
+        response_format_value: Optional[Type[BaseModel]] = None
+
+        if "response_format" in create_args:
+            value = create_args["response_format"]
+            # If value is a Pydantic model class, use the beta client
+            if isinstance(value, type) and issubclass(value, BaseModel):
+                response_format_value = value
+                use_beta_client = True
+            else:
+                # response_format_value is not a Pydantic model class
+                use_beta_client = False
+                response_format_value = None
+
+        # Remove 'response_format' from create_args to prevent passing it twice
+        create_args_no_response_format = {k: v for k, v in create_args.items() if k != "response_format"}
+
         # TODO: allow custom handling.
         # For now we raise an error if images are present and vision is not supported
         if self.capabilities["vision"] is False:
@@ -390,24 +415,69 @@
 
         if self.capabilities["function_calling"] is False and len(tools) > 0:
             raise ValueError("Model does not support function calling")
 
+        future: Union[Task[ParsedChatCompletion[BaseModel]], Task[ChatCompletion]]
         if len(tools) > 0:
             converted_tools = convert_tools(tools)
-            future = asyncio.ensure_future(
-                self._client.chat.completions.create(
-                    messages=oai_messages,
-                    stream=False,
-                    tools=converted_tools,
-                    **create_args,
-                )
-            )
+            if use_beta_client:
+                # Pass response_format_value if it's not None
+                if response_format_value is not None:
+                    future = asyncio.ensure_future(
+                        self._client.beta.chat.completions.parse(
+                            messages=oai_messages,
+                            tools=converted_tools,
+                            response_format=response_format_value,
+                            **create_args_no_response_format,
+                        )
+                    )
+                else:
+                    future = asyncio.ensure_future(
+                        self._client.beta.chat.completions.parse(
+                            messages=oai_messages,
+                            tools=converted_tools,
+                            **create_args_no_response_format,
+                        )
+                    )
+            else:
+                future = asyncio.ensure_future(
+                    self._client.chat.completions.create(
+                        messages=oai_messages,
+                        stream=False,
+                        tools=converted_tools,
+                        **create_args,
+                    )
+                )
         else:
-            future = asyncio.ensure_future(
-                self._client.chat.completions.create(messages=oai_messages, stream=False, **create_args)
-            )
+            if use_beta_client:
+                if response_format_value is not None:
+                    future = asyncio.ensure_future(
+                        self._client.beta.chat.completions.parse(
+                            messages=oai_messages,
+                            response_format=response_format_value,
+                            **create_args_no_response_format,
+                        )
+                    )
+                else:
+                    future = asyncio.ensure_future(
+                        self._client.beta.chat.completions.parse(
+                            messages=oai_messages,
+                            **create_args_no_response_format,
+                        )
+                    )
+            else:
+                future = asyncio.ensure_future(
+                    self._client.chat.completions.create(
+                        messages=oai_messages,
+                        stream=False,
+                        **create_args,
+                    )
+                )
 
         if cancellation_token is not None:
             cancellation_token.link_future(future)
-        result = await future
+        result: Union[ParsedChatCompletion[BaseModel], ChatCompletion] = await future
+        if use_beta_client:
+            result = cast(ParsedChatCompletion[Any], result)
 
         if result.usage is not None:
             logger.info(
                 LLMCallEvent(
@@ -430,7 +500,7 @@
             )
 
         # Limited to a single choice currently.
-        choice = result.choices[0]
+        choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], Choice] = result.choices[0]
         if choice.finish_reason == "function_call":
             raise ValueError("Function calls are not supported in this context")

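The core of the change is the dispatch in `create`: only when `response_format` is a Pydantic model class does the client route to `beta.chat.completions.parse`; everything else falls through to the regular `chat.completions.create` path. A standalone sketch of that check (the helper name `should_use_beta_client` is hypothetical):

```python
# A standalone sketch of the dispatch rule in the diff above;
# should_use_beta_client is a hypothetical name, the check itself mirrors the commit.
from pydantic import BaseModel


class MathReasoning(BaseModel):
    final_answer: str


def should_use_beta_client(response_format: object) -> bool:
    # Only a Pydantic model *class* routes to the beta parse client.
    return isinstance(response_format, type) and issubclass(response_format, BaseModel)


assert should_use_beta_client(MathReasoning) is True  # model class
assert should_use_beta_client(MathReasoning(final_answer="48")) is False  # instance
assert should_use_beta_client({"type": "json_object"}) is False  # plain dict
```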
