Skip to content

Commit

Permalink
Address PR comments from Sam Sucik
Browse files Browse the repository at this point in the history
- address PR comments

Co-authored-by: Sam Sucik <[email protected]>
  • Loading branch information
gkaretka and samsucik committed Oct 31, 2024
1 parent d48fad0 commit dd05258
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 80 deletions.
4 changes: 2 additions & 2 deletions prompterator/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class StructuredOutputImplementation(Enum):


@dataclasses.dataclass
class StructuredOutputData:
class StructuredOutputConfig:
enabled: bool
schema: str
method: StructuredOutputImplementation
Expand All @@ -34,7 +34,7 @@ class StructuredOutputData:
@dataclasses.dataclass
class ModelInputs:
inputs: Dict[str, Any]
structured_output_data: StructuredOutputData
structured_output_data: StructuredOutputConfig


class ModelProperties(BaseModel):
Expand Down
17 changes: 9 additions & 8 deletions prompterator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def set_up_dynamic_session_state_vars():

def run_prompt(progress_ui_area):
progress_bar = progress_ui_area.progress(0, text="generating texts")

system_prompt_template = st.session_state.system_prompt
user_prompt_template = st.session_state.user_prompt

Expand All @@ -126,7 +127,7 @@ def run_prompt(progress_ui_area):
st.session_state.selected_structured_output_method
)

structured_output_params = c.StructuredOutputData(
structured_output_params = c.StructuredOutputConfig(
structured_output_enabled, prompt_json_schema, selected_structured_output_method
)

Expand Down Expand Up @@ -427,13 +428,14 @@ def set_up_ui_generation():
label_visibility="collapsed",
key=c.PROMPT_COMMENT_KEY,
)

selected_model: c.ModelProperties = st.session_state.model
available_structured_output_settings = selected_model.supports_structured_output
available_structured_output_settings = (
selected_model.supported_structured_output_implementations
)

# Allow structured outputs only if the model allows other implementation
# than NONE, other implementations currently include FUNCTION_CALLING
# and RESPONSE_FORMAT. Models by default do not require this parameter to be set.
# than NONE. Models by default do not require this parameter to be set.
structured_output_available = (
len(set(available_structured_output_settings) - {c.StructuredOutputImplementation.NONE})
> 0
Expand Down Expand Up @@ -468,12 +470,11 @@ def set_up_ui_generation():
disabled=not model_supports_user_prompt,
)

if structured_output_available and structured_output_enabled:
if structured_output_enabled:
json_input = st.container()
json_input.text_area(
label="JSON Schema",
placeholder="Your JSON schema goes here",
value=c.DEFAULT_JSON_SCHEMA,
placeholder=c.DEFAULT_JSON_SCHEMA,
key="prompt_json_schema",
height=c.PROMPT_TEXT_AREA_HEIGHT,
)
Expand Down
43 changes: 43 additions & 0 deletions prompterator/model_specific_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import json


def build_function_calling_tooling(json_schema: str):
"""
@param json_schema: contains desired output schema in proper Json Schema format
@return: (tools, function name) where
- tools is list of tools (single function in this case) callable by OpenAI model
in function calling mode.
- function name is the name of the desired function to be called
"""
schema = json.loads(json_schema)
function = schema.copy()
function_name = function.pop("title")
description = (
function.pop("description")
if function.get("description", None) is not None
else function_name
)
tools = [
{
"type": "function",
"function": {
"name": function_name,
"description": description,
"parameters": function,
},
}
]

return tools, function_name


def build_response_format(json_schema: str):
"""
@param json_schema: contains desired output schema in proper Json Schema format
@return: dict with desired response format directly usable with OpenAI API
"""
json_schema = json.loads(json_schema)
schema = {"name": json_schema.pop("title"), "schema": json_schema, "strict": True}
response_format = {"type": "json_schema", "json_schema": schema}

return response_format
84 changes: 45 additions & 39 deletions prompterator/models/openai_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
import time

import openai
from prompterator.model_specific_utils import build_function_calling_tooling, build_response_format
from openai import AzureOpenAI, OpenAI

from prompterator.utils import build_function_calling_tooling, build_response_format

logger = logging.getLogger(__name__)

from prompterator.constants import ( # isort:skip
CONFIGURABLE_MODEL_PARAMETER_PROPERTIES,
ModelProperties,
PrompteratorLLM,
StructuredOutputImplementation as soi,
StructuredOutputData,
StructuredOutputConfig,
)


Expand Down Expand Up @@ -89,41 +88,48 @@ def __init__(self):

super().__init__()

@staticmethod
def enrich_model_params_of_function_calling(structured_output_data, model_params):
if structured_output_data.enabled:
if structured_output_data.method == soi.FUNCTION_CALLING:
model_params["tools"], function_name = build_function_calling_tooling(
structured_output_data.schema
)
model_params["tool_choice"] = {
"type": "function",
"function": {"name": function_name},
}
if structured_output_data.method == soi.RESPONSE_FORMAT:
model_params["response_format"] = build_response_format(
structured_output_data.schema
)
return model_params

@staticmethod
def process_response(structured_output_data, response_data):
if structured_output_data.enabled:
if structured_output_data.method == soi.FUNCTION_CALLING:
response_text = response_data.choices[0].message.tool_calls[0].function.arguments
elif structured_output_data.method == soi.RESPONSE_FORMAT:
response_text = response_data.choices[0].message.content
else:
response_text = response_data.choices[0].message.content
else:
response_text = response_data.choices[0].message.content
return response_text

def call(self, idx, input, **kwargs):
structured_output_data: StructuredOutputData = kwargs["structured_output"]
structured_output_data: StructuredOutputConfig = kwargs["structured_output"]
model_params = kwargs["model_params"]
try:
if structured_output_data.enabled:
if structured_output_data.method == soi.FUNCTION_CALLING:
model_params["tools"], function_name = build_function_calling_tooling(
structured_output_data.schema
)
model_params["tool_choice"] = {
"type": "function",
"function": {"name": function_name},
}
if structured_output_data.method == soi.RESPONSE_FORMAT:
model_params["response_format"] = build_response_format(
structured_output_data.schema
)

model_params = ChatGPTMixin.enrich_model_params_of_function_calling(
structured_output_data, model_params
)
response_data = self.client.chat.completions.create(
model=self.specific_model_name or self.name, messages=input, **model_params
)

response_text = None
if structured_output_data.enabled:
if structured_output_data.method == soi.FUNCTION_CALLING:
response_text = (
response_data.choices[0].message.tool_calls[0].function.arguments
)
elif structured_output_data.method == soi.RESPONSE_FORMAT:
response_text = response_data.choices[0].message.content
else:
response_text = response_data.choices[0].message.content
else:
response_text = response_data.choices[0].message.content

response_text = ChatGPTMixin.process_response(structured_output_data, response_data)
return {"response": response_text, "data": response_data, "idx": idx}
except openai.RateLimitError as e:
logger.error(
Expand Down Expand Up @@ -160,7 +166,7 @@ class GPT4o(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=1,
supports_structured_output=[
supported_structured_output_implementations=[
soi.NONE,
soi.FUNCTION_CALLING,
soi.RESPONSE_FORMAT,
Expand All @@ -176,7 +182,7 @@ class GPT4oAzure(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=6,
supports_structured_output=[
supported_structured_output_implementations=[
soi.NONE,
soi.FUNCTION_CALLING,
soi.RESPONSE_FORMAT,
Expand All @@ -194,7 +200,7 @@ class GPT4oMini(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=2,
supports_structured_output=[
supported_structured_output_implementations=[
soi.NONE,
soi.FUNCTION_CALLING,
soi.RESPONSE_FORMAT,
Expand All @@ -210,7 +216,7 @@ class GPT4oMiniAzure(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=7,
supports_structured_output=[
supported_structured_output_implementations=[
soi.NONE,
soi.FUNCTION_CALLING,
soi.RESPONSE_FORMAT,
Expand All @@ -228,7 +234,7 @@ class GPT35Turbo(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=3,
supports_structured_output=[
supported_structured_output_implementations=[
soi.NONE,
soi.FUNCTION_CALLING,
],
Expand All @@ -243,7 +249,7 @@ class GPT35TurboAzure(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=8,
supports_structured_output=[
supported_structured_output_implementations=[
soi.NONE,
soi.FUNCTION_CALLING,
],
Expand All @@ -260,7 +266,7 @@ class GPT4(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=4,
supports_structured_output=[
supported_structured_output_implementations=[
soi.NONE,
soi.FUNCTION_CALLING,
],
Expand All @@ -275,7 +281,7 @@ class GPT4Azure(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=9,
supports_structured_output=[
supported_structured_output_implementations=[
soi.NONE,
soi.FUNCTION_CALLING,
],
Expand Down
31 changes: 0 additions & 31 deletions prompterator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,37 +382,6 @@ def format_traceback_for_markdown(text):
return re.sub(r"\n", "\n\n", text)


def build_function_calling_tooling(json_schema):
schema = json.loads(json_schema)
function = schema.copy()
function_name = function.pop("title")
description = (
function.pop("description")
if function.get("description", None) is not None
else function_name
)
tools = [
{
"type": "function",
"function": {
"name": function_name,
"description": description,
"parameters": function,
},
}
]

return tools, function_name


def build_response_format(json_schema):
json_schema = json.loads(json_schema)
schema = {"name": json_schema.pop("title"), "schema": json_schema, "strict": True}
response_format = {"type": "json_schema", "json_schema": schema}

return response_format


def validate_json(text):
try:
json.loads(text)
Expand Down

0 comments on commit dd05258

Please sign in to comment.