added standard type-safe openai calls
alexeichhorn committed Oct 26, 2023
1 parent bd2981d commit dded644
Showing 3 changed files with 257 additions and 16 deletions.
180 changes: 167 additions & 13 deletions gpt_condom/openai/chat_completion.py
@@ -1,47 +1,201 @@
from typing import Any, AsyncGenerator, Awaitable, Generic, Literal, TypeVar, overload

import openai
import tiktoken

from ..message_collection_builder import EncodedMessage, MessageCollectionFactory
from ..prompt_definition.prompt_template import PromptTemplate, _Output
from .views import ChatCompletionChunk, ChatCompletionResult, EncodedFunction, FunctionCallBehavior, OpenAIChatModel

# Prompt = TypeVar("Prompt", bound=PromptTemplate)


# TODO: change to better name
class OpenAIChatCompletion(openai.ChatCompletion):
@overload
@classmethod
async def acreate(
cls,
model: OpenAIChatModel,
messages: list[dict],
stream: Literal[True],
frequency_penalty: float | None = None, # [-2, 2]
function_call: FunctionCallBehavior | None = None,
functions: list[EncodedFunction] = [],
logit_bias: dict[int, float] | None = None, # [-100, 100]
stop: list[str] | None = None,
max_tokens: int = 1000,
n: int | None = None,
presence_penalty: float | None = None, # [-2, 2]
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
) -> tuple[AsyncGenerator[ChatCompletionChunk, None]]:
...

@overload
@classmethod
async def acreate(
cls,
model: OpenAIChatModel,
messages: list[dict],
stream: Literal[False] = False,
frequency_penalty: float | None = None, # [-2, 2]
function_call: FunctionCallBehavior | None = None,
functions: list[EncodedFunction] = [],
logit_bias: dict[int, float] | None = None, # [-100, 100]
stop: list[str] | None = None,
max_tokens: int = 1000,
n: int | None = None,
presence_penalty: float | None = None, # [-2, 2]
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
) -> ChatCompletionResult:
...

@classmethod
async def acreate(
cls,
model: OpenAIChatModel,
messages: list[dict],
stream: bool = False,
frequency_penalty: float | None = None, # [-2, 2]
function_call: FunctionCallBehavior | None = None,
functions: list[EncodedFunction] = [],
logit_bias: dict[int, float] | None = None, # [-100, 100]
stop: list[str] | None = None,
max_tokens: int = 1000,
n: int | None = None,
presence_penalty: float | None = None, # [-2, 2]
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
) -> ChatCompletionResult | tuple[AsyncGenerator[ChatCompletionChunk, None]]:
kwargs = {
"model": model,
"messages": messages,
"max_tokens": max_tokens,
"stream": stream,
}

if frequency_penalty is not None:
kwargs["frequency_penalty"] = frequency_penalty

if functions:
kwargs["functions"] = functions

if function_call:
kwargs["function_call"] = function_call

if logit_bias:
kwargs["logit_bias"] = logit_bias

if stop:
kwargs["stop"] = stop

if n is not None:
kwargs["n"] = n

if presence_penalty is not None:
kwargs["presence_penalty"] = presence_penalty

if temperature is not None:
kwargs["temperature"] = temperature

if top_p is not None:
kwargs["top_p"] = top_p

if user is not None:
kwargs["user"] = user

if stream:

async def _process_stream(raw_stream: AsyncGenerator[list[dict] | dict, None]) -> AsyncGenerator[ChatCompletionChunk, None]:
async for chunk in raw_stream:
if isinstance(chunk, dict):
chunk = [chunk]

for c in chunk:
yield ChatCompletionChunk(**c)

raw_stream: AsyncGenerator[list[dict] | dict, None] = await openai.ChatCompletion.acreate(**kwargs) # type: ignore
            return (_process_stream(raw_stream),)  # wrapped in a 1-tuple so the stream overload has a distinct static return type

else:
raw_result: dict[str, Any] = await openai.ChatCompletion.acreate(**kwargs) # type: ignore
return ChatCompletionResult(**raw_result)

@classmethod
async def generate_completion(
cls,
model: OpenAIChatModel,
messages: list[dict],
frequency_penalty: float | None = None, # [-2, 2]
function_call: FunctionCallBehavior | None = None,
functions: list[EncodedFunction] = [],
logit_bias: dict[int, float] | None = None, # [-100, 100]
stop: list[str] | None = None,
max_tokens: int = 1000,
n: int | None = None,
presence_penalty: float | None = None, # [-2, 2]
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
) -> str:
result = await cls.acreate(
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
stop=stop,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
temperature=temperature,
top_p=top_p,
user=user,
)

return result["choices"][0]["message"].get("content") or ""

@classmethod
async def generate_output(
cls,
model: OpenAIChatModel,
prompt: PromptTemplate[_Output],
max_output_tokens: int,
max_input_tokens: int | None = None,
frequency_penalty: float | None = None, # [-2, 2]
n: int | None = None,
presence_penalty: float | None = None, # [-2, 2]
temperature: float | None = None,
top_p: float | None = None,
) -> _Output:
"""
        Calls the OpenAI chat API, generates the assistant response, and parses it into the prompt's output class.
"""

kwargs["model"] = model
kwargs["stream"] = False

max_prompt_length = cls.max_tokens_of_model(model) - max_output_tokens

if max_input_tokens:
max_prompt_length = min(max_prompt_length, max_input_tokens)

kwargs["messages"] = prompt.generate_messages(
messages = prompt.generate_messages(
token_limit=max_prompt_length, token_counter=lambda messages: cls.num_tokens_from_messages(messages, model=model)
)
kwargs["max_tokens"] = max_output_tokens

completion = await cls.generate_completion(
model=model,
messages=messages,
max_tokens=max_output_tokens,
frequency_penalty=frequency_penalty,
n=n,
presence_penalty=presence_penalty,
temperature=temperature,
top_p=top_p,
)

return prompt.Output.parse_response(completion)

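A minimal usage sketch of the new typed wrapper (hypothetical calling code, not part of this commit; assumes an OpenAI API key is configured). The Literal-typed stream parameter lets a static checker pick the right overload, and the streaming overload returns the generator wrapped in a 1-tuple so its static type stays distinct from the awaited result:

import asyncio

from gpt_condom.openai.chat_completion import OpenAIChatCompletion


async def demo() -> None:
    messages = [{"role": "user", "content": "Say hello"}]

    # stream=False (the default): the overload resolves to ChatCompletionResult, a TypedDict
    result = await OpenAIChatCompletion.acreate(model="gpt-3.5-turbo", messages=messages)
    print(result["choices"][0]["message"]["content"])

    # stream=True: the overload resolves to a 1-tuple holding an async generator of chunks
    (stream,) = await OpenAIChatCompletion.acreate(model="gpt-3.5-turbo", messages=messages, stream=True)
    async for chunk in stream:
        print(chunk)


asyncio.run(demo())
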
77 changes: 76 additions & 1 deletion gpt_condom/openai/views.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Literal, TypedDict

OpenAIChatModel = Literal[
"gpt-3.5-turbo", # 3.5 turbo
@@ -13,3 +14,77 @@
"gpt-4-32k-0314",
"gpt-4-32k-0613",
]


class FunctionCallForceBehavior(TypedDict):
name: str # function name


FunctionCallBehavior = Literal["auto", "none"] | FunctionCallForceBehavior


EncodedFunction = dict[str, "EncodedFunction | str | list[str] | list[EncodedFunction] | None"]


# region - Outputs


ChatCompletionRole = Literal["function", "system", "user", "assistant"]


class FunctionCall(TypedDict):
name: str
arguments: str


class Message(TypedDict):
content: str | None
role: ChatCompletionRole
function_call: FunctionCall | None


class Choice(TypedDict):
    finish_reason: Literal["stop", "length", "function_call", "content_filter"]
index: int
message: Message


class CompletionUsage(TypedDict):
completion_tokens: int
prompt_tokens: int
total_tokens: int


class ChatCompletionResult(TypedDict):
id: str
model: str
object: str
created: int
choices: list[Choice]
usage: CompletionUsage | None


# - Streaming


@dataclass
class ChatCompletionChunk:
@dataclass
class Choice:
        @dataclass
        class Delta:
content: str | None
function_call: FunctionCall | None
role: ChatCompletionRole | None

delta: Delta
finish_reason: Literal["stop", "length", "function_call", "content_filter", None]
index: int

id: str
model: str
choices: list[Choice]
created: int
object: str


# endregion
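
To illustrate the recursive EncodedFunction alias above, here is a hand-written schema sketch (a hypothetical function, not taken from this commit); every value must be another EncodedFunction, a string, a list of strings, a list of EncodedFunctions, or None:

from gpt_condom.openai.views import EncodedFunction, FunctionCallBehavior

# Hypothetical function schema that type-checks against EncodedFunction
get_weather: EncodedFunction = {
    "name": "get_weather",
    "description": "Look up the current weather for a city",
    "parameters": {
        "type": "object",
        "properties": {
            "city": {"type": "string"},
            "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
        },
        "required": ["city"],
    },
}

# "auto" and "none" are allowed, or a dict forcing one specific function
force_call: FunctionCallBehavior = {"name": "get_weather"}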
16 changes: 14 additions & 2 deletions tests/test_openai.py
@@ -41,7 +41,19 @@ def test_token_counter(self):
@pytest.fixture
def mock_openai_completion(self, mocker):
async def async_mock(*args, **kwargs):
return {"choices": [{"message": {"content": "TITLE: This is a test completion\nCOUNT: 09"}}]}
return {
"id": "test",
"model": "gpt-3.5-turbo",
"object": "x",
"created": 123,
"choices": [
{
"finish_reason": "stop",
"index": 1,
"message": {"role": "assistant", "content": "TITLE: This is a test completion\nCOUNT: 09"},
}
],
}

mocker.patch("gpt_condom.openai.chat_completion.openai.ChatCompletion.acreate", new=async_mock)

@@ -58,7 +70,7 @@ class Output(BaseLLMResponse):
title: str
count: int

result = await OpenAIChatCompletion.generate_output(
model="gpt-3.5-turbo",
prompt=FullExamplePrompt(),
max_output_tokens=100,
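The test listing above is cut off by the collapsed diff, but it shows enough to sketch an end-to-end call to generate_output. A hedged example (the import paths and the prompt class shape are assumptions inferred from this diff, not a documented API):

import asyncio

from gpt_condom import BaseLLMResponse, PromptTemplate  # import path assumed
from gpt_condom.openai.chat_completion import OpenAIChatCompletion


class TitlePrompt(PromptTemplate):  # message content omitted; follows the library's prompt conventions
    class Output(BaseLLMResponse):
        title: str
        count: int


async def demo() -> None:
    # generate_messages() is called internally with a token_limit derived from
    # max_tokens_of_model(model) - max_output_tokens (capped by max_input_tokens if given)
    output = await OpenAIChatCompletion.generate_output(
        model="gpt-3.5-turbo",
        prompt=TitlePrompt(),
        max_output_tokens=100,
    )
    print(output.title, output.count)  # parsed via prompt.Output.parse_response(...)


asyncio.run(demo())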