From dded6444bc718386e2269c476a1e287e97cfcf81 Mon Sep 17 00:00:00 2001
From: Alexander Eichhorn
Date: Thu, 26 Oct 2023 22:38:11 +0200
Subject: [PATCH] added standard type-safe openai calls

---
 gpt_condom/openai/chat_completion.py | 180 +++++++++++++++++++++++++--
 gpt_condom/openai/views.py           |  77 +++++++++++-
 tests/test_openai.py                 |  16 ++-
 3 files changed, 257 insertions(+), 16 deletions(-)

diff --git a/gpt_condom/openai/chat_completion.py b/gpt_condom/openai/chat_completion.py
index 45a547c..eefa775 100644
--- a/gpt_condom/openai/chat_completion.py
+++ b/gpt_condom/openai/chat_completion.py
@@ -1,47 +1,201 @@
-from typing import Generic, TypeVar
+from typing import Any, AsyncGenerator, Awaitable, Generic, Literal, TypeVar, overload
 
 import openai
 import tiktoken
 
 from ..message_collection_builder import EncodedMessage, MessageCollectionFactory
 from ..prompt_definition.prompt_template import PromptTemplate, _Output
-from .views import OpenAIChatModel
+from .views import ChatCompletionChunk, ChatCompletionResult, EncodedFunction, FunctionCallBehavior, OpenAIChatModel
 
 # Prompt = TypeVar("Prompt", bound=PromptTemplate)
 
 
-# TODO: change to better name
 class OpenAIChatCompletion(openai.ChatCompletion):
+    @overload
     @classmethod
     async def acreate(
+        cls,
+        model: OpenAIChatModel,
+        messages: list[dict],
+        stream: Literal[True],
+        frequency_penalty: float | None = None,  # [-2, 2]
+        function_call: FunctionCallBehavior | None = None,
+        functions: list[EncodedFunction] = [],
+        logit_bias: dict[int, float] | None = None,  # [-100, 100]
+        stop: list[str] | None = None,
+        max_tokens: int = 1000,
+        n: int | None = None,
+        presence_penalty: float | None = None,  # [-2, 2]
+        temperature: float | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+    ) -> tuple[AsyncGenerator[ChatCompletionChunk, None]]:
+        ...
+
+    @overload
+    @classmethod
+    async def acreate(
+        cls,
+        model: OpenAIChatModel,
+        messages: list[dict],
+        stream: Literal[False] = False,
+        frequency_penalty: float | None = None,  # [-2, 2]
+        function_call: FunctionCallBehavior | None = None,
+        functions: list[EncodedFunction] = [],
+        logit_bias: dict[int, float] | None = None,  # [-100, 100]
+        stop: list[str] | None = None,
+        max_tokens: int = 1000,
+        n: int | None = None,
+        presence_penalty: float | None = None,  # [-2, 2]
+        temperature: float | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+    ) -> ChatCompletionResult:
+        ...
+
+    @classmethod
+    async def acreate(
+        cls,
+        model: OpenAIChatModel,
+        messages: list[dict],
+        stream: bool = False,
+        frequency_penalty: float | None = None,  # [-2, 2]
+        function_call: FunctionCallBehavior | None = None,
+        functions: list[EncodedFunction] = [],
+        logit_bias: dict[int, float] | None = None,  # [-100, 100]
+        stop: list[str] | None = None,
+        max_tokens: int = 1000,
+        n: int | None = None,
+        presence_penalty: float | None = None,  # [-2, 2]
+        temperature: float | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+    ) -> ChatCompletionResult | tuple[AsyncGenerator[ChatCompletionChunk, None]]:
+        kwargs = {
+            "model": model,
+            "messages": messages,
+            "max_tokens": max_tokens,
+            "stream": stream,
+        }
+
+        if frequency_penalty is not None:
+            kwargs["frequency_penalty"] = frequency_penalty
+
+        if functions:
+            kwargs["functions"] = functions
+
+        if function_call:
+            kwargs["function_call"] = function_call
+
+        if logit_bias:
+            kwargs["logit_bias"] = logit_bias
+
+        if stop:
+            kwargs["stop"] = stop
+
+        if n is not None:
+            kwargs["n"] = n
+
+        if presence_penalty is not None:
+            kwargs["presence_penalty"] = presence_penalty
+
+        if temperature is not None:
+            kwargs["temperature"] = temperature
+
+        if top_p is not None:
+            kwargs["top_p"] = top_p
+
+        if user is not None:
+            kwargs["user"] = user
+
+        if stream:
+
+            async def _process_stream(raw_stream: AsyncGenerator[list[dict] | dict, None]) -> AsyncGenerator[ChatCompletionChunk, None]:
+                async for chunk in raw_stream:
+                    if isinstance(chunk, dict):
+                        chunk = [chunk]
+
+                    for c in chunk:
+                        yield ChatCompletionChunk(**c)
+
+            raw_stream: AsyncGenerator[list[dict] | dict, None] = await openai.ChatCompletion.acreate(**kwargs)  # type: ignore
+            return (_process_stream(raw_stream),)  # wrapped in a tuple so the streaming overload keeps a distinct static return type
+
+        else:
+            raw_result: dict[str, Any] = await openai.ChatCompletion.acreate(**kwargs)  # type: ignore
+            return ChatCompletionResult(**raw_result)
+
+    @classmethod
+    async def generate_completion(
+        cls,
+        model: OpenAIChatModel,
+        messages: list[dict],
+        frequency_penalty: float | None = None,  # [-2, 2]
+        function_call: FunctionCallBehavior | None = None,
+        functions: list[EncodedFunction] = [],
+        logit_bias: dict[int, float] | None = None,  # [-100, 100]
+        stop: list[str] | None = None,
+        max_tokens: int = 1000,
+        n: int | None = None,
+        presence_penalty: float | None = None,  # [-2, 2]
+        temperature: float | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+    ) -> str:
+        result = await cls.acreate(
+            model=model,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            function_call=function_call,
+            functions=functions,
+            logit_bias=logit_bias,
+            stop=stop,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            temperature=temperature,
+            top_p=top_p,
+            user=user,
+        )
+
+        return result["choices"][0]["message"].get("content") or ""
+
+    @classmethod
+    async def generate_output(
         cls,
         model: OpenAIChatModel,
         prompt: PromptTemplate[_Output],
         max_output_tokens: int,
         max_input_tokens: int | None = None,
-        **kwargs,
+        frequency_penalty: float | None = None,  # [-2, 2]
+        n: int | None = None,
+        presence_penalty: float | None = None,  # [-2, 2]
+        temperature: float | None = None,
+        top_p: float | None = None,
     ) -> _Output:
         """
         Calls OpenAI Chat API, generates assistant response, and fits it into the output class
         """
-        kwargs["model"] = model
-        kwargs["stream"] = False
-
         max_prompt_length = cls.max_tokens_of_model(model) - max_output_tokens
 
         if max_input_tokens:
             max_prompt_length = min(max_prompt_length, max_input_tokens)
 
-        kwargs["messages"] = prompt.generate_messages(
+        messages = prompt.generate_messages(
             token_limit=max_prompt_length, token_counter=lambda messages: cls.num_tokens_from_messages(messages, model=model)
         )
 
-        kwargs["max_tokens"] = max_output_tokens
-
-        result = await openai.ChatCompletion.acreate(**kwargs)
-        message = result["choices"][0]["message"]  # type: ignore
-        completion = message.get("content", "")
+        completion = await cls.generate_completion(
+            model=model,
+            messages=messages,
+            max_tokens=max_output_tokens,
+            frequency_penalty=frequency_penalty,
+            n=n,
+            presence_penalty=presence_penalty,
+            temperature=temperature,
+            top_p=top_p,
+        )
 
         return prompt.Output.parse_response(completion)
diff --git a/gpt_condom/openai/views.py b/gpt_condom/openai/views.py
index 9ff46ad..675b0b4 100644
--- a/gpt_condom/openai/views.py
+++ b/gpt_condom/openai/views.py
@@ -1,4 +1,5 @@
-from typing import Literal
+from dataclasses import dataclass
+from typing import Literal, TypedDict
 
 OpenAIChatModel = Literal[
     "gpt-3.5-turbo",  # 3.5 turbo
@@ -13,3 +14,77 @@
     "gpt-4-32k-0314",
     "gpt-4-32k-0613",
 ]
+
+
+class FunctionCallForceBehavior(TypedDict):
+    name: str  # function name
+
+
+FunctionCallBehavior = Literal["auto", "none"] | FunctionCallForceBehavior
+
+
+EncodedFunction = dict[str, "EncodedFunction | str | list[str] | list[EncodedFunction] | None"]
+
+
+# region - Outputs
+
+
+ChatCompletionRole = Literal["function", "system", "user", "assistant"]
+
+
+class FunctionCall(TypedDict):
+    name: str
+    arguments: str
+
+
+class Message(TypedDict):
+    content: str | None
+    role: ChatCompletionRole
+    function_call: FunctionCall | None
+
+
+class Choice(TypedDict):
+    finish_reason: Literal["stop", "length", "function_call", "content_filter"]
+    index: int
+    message: Message
+
+
+class CompletionUsage(TypedDict):
+    completion_tokens: int
+    prompt_tokens: int
+    total_tokens: int
+
+
+class ChatCompletionResult(TypedDict):
+    id: str
+    model: str
+    object: str
+    created: int
+    choices: list[Choice]
+    usage: CompletionUsage | None
+
+
+# - Streaming
+
+
+@dataclass
+class ChatCompletionChunk:
+    @dataclass
+    class Choice:
+        class Delta:
+            content: str | None
+            function_call: FunctionCall | None
+            role: ChatCompletionRole | None
+
+        delta: Delta
+        finish_reason: Literal["stop", "length", "function_call", "content_filter", None]
+        index: int
+
+    id: str
+    model: str
+    choices: list[Choice]
+    created: int
+    object: str
+
+
+# endregion
diff --git a/tests/test_openai.py b/tests/test_openai.py
index cf813a5..dc274bd 100644
--- a/tests/test_openai.py
+++ b/tests/test_openai.py
@@ -41,7 +41,19 @@ def test_token_counter(self):
 
     @pytest.fixture
     def mock_openai_completion(self, mocker):
         async def async_mock(*args, **kwargs):
-            return {"choices": [{"message": {"content": "TITLE: This is a test completion\nCOUNT: 09"}}]}
+            return {
+                "id": "test",
+                "model": "gpt-3.5-turbo",
+                "object": "x",
+                "created": 123,
+                "choices": [
+                    {
+                        "finish_reason": "stop",
+                        "index": 1,
+                        "message": {"role": "assistant", "content": "TITLE: This is a test completion\nCOUNT: 09"},
+                    }
+                ],
+            }
 
         mocker.patch("gpt_condom.openai.chat_completion.openai.ChatCompletion.acreate", new=async_mock)
 
@@ -58,7 +70,7 @@ class Output(BaseLLMResponse):
             title: str
             count: int
 
-        result = await OpenAIChatCompletion.acreate(
+        result = await OpenAIChatCompletion.generate_output(
            model="gpt-3.5-turbo",
            prompt=FullExamplePrompt(),
            max_output_tokens=100,
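
The snippet below is a minimal, illustrative usage sketch of the API this patch introduces; it is not part of the patch itself. It assumes the wrapper is importable as gpt_condom.openai.chat_completion.OpenAIChatCompletion (the module path targeted by mocker.patch in the tests), that an OpenAI API key is configured via the OPENAI_API_KEY environment variable, and that the model name and messages are placeholders.

import asyncio

from gpt_condom.openai.chat_completion import OpenAIChatCompletion  # import path assumed from the tests


async def main() -> None:
    # Non-streaming: generate_completion returns the assistant's message text as a plain str.
    text = await OpenAIChatCompletion.generate_completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello in one word."}],
        max_tokens=20,
        temperature=0.2,
    )
    print(text)

    # Streaming: acreate(stream=True) resolves to the overload that returns a single-element
    # tuple wrapping an async generator of ChatCompletionChunk objects.
    (chunks,) = await OpenAIChatCompletion.acreate(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Count to three."}],
        stream=True,
        max_tokens=20,
    )
    async for chunk in chunks:
        print(chunk.choices)


asyncio.run(main())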