Merge pull request #6 from alexeichhorn/feature/detailed-errors
Inject details in LLMExceptions
alexeichhorn authored May 2, 2024
2 parents 0cb4b5d + 75c7e7f commit e68df7a
Showing 8 changed files with 146 additions and 30 deletions.
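With this change, every LLMException raised by the library carries the system prompt, the user prompt, and the raw model completion that led to the failure, so callers can inspect exactly what was sent and received. A minimal caller-side sketch of the new fields (the prompt class, its text, and the API key below are placeholders for illustration, not part of this commit):

from typegpt import BaseLLMResponse, PromptTemplate
from typegpt.exceptions import LLMException
from typegpt.openai import TypeOpenAI


class ExamplePrompt(PromptTemplate):
    class Output(BaseLLMResponse):
        title: str
        count: int

    def system_prompt(self) -> str:
        return "Extract a title and a count."

    def user_prompt(self) -> str:
        return "Some input text"


client = TypeOpenAI(api_key="sk-...")

try:
    result = client.chat.completions.generate_output(
        model="gpt-3.5-turbo-0613",
        prompt=ExamplePrompt(),
        output_type=ExamplePrompt.Output,
        max_output_tokens=100,
    )
except LLMException as e:
    # New in this commit: the prompts and raw completion are attached to the
    # exception. Each field may be None, e.g. raw_completion stays None when
    # the token limit is exceeded before any request is made.
    print(e.system_prompt)
    print(e.user_prompt)
    print(e.raw_completion)
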
81 changes: 79 additions & 2 deletions tests/test_openai.py
@@ -14,7 +14,7 @@
from openai.types.chat.chat_completion_message import ChatCompletionMessage

from typegpt import BaseLLMResponse, LLMArrayOutput, LLMOutput, PromptTemplate
from typegpt.exceptions import LLMTokenLimitExceeded
from typegpt.exceptions import LLMOutputFieldWrongType, LLMTokenLimitExceeded
from typegpt.openai import AsyncTypeAzureOpenAI, AsyncTypeOpenAI, OpenAIChatModel, TypeAzureOpenAI, TypeOpenAI


@@ -411,13 +411,16 @@ class Output(BaseLLMResponse):

non_reducing_prompt_1000 = NonAutomaticReducingPrompt(1000)

with pytest.raises(LLMTokenLimitExceeded):
with pytest.raises(LLMTokenLimitExceeded) as exc:
result = await client.chat.completions.generate_output(
model="gpt-3.5-turbo-0613",
prompt=non_reducing_prompt_1000,
max_output_tokens=100,
)

assert exc.value.system_prompt == "This is a random system prompt"
assert exc.value.raw_completion is None

class ReducingTestPrompt(PromptTemplate):
def __init__(self, number: int):
self.lines = [f"This is line {i}" for i in range(number)]
@@ -490,3 +493,77 @@ class Output(BaseLLMResponse):

assert result.title == "This is a test completion"
assert result.count == 9

# region: - Exceptions

def test_exception_injection_sync(self, mock_openai_completion_sync):
class ExamplePrompt(PromptTemplate):
class Output(BaseLLMResponse):
title: int # wrong type
count: int

def system_prompt(self) -> str:
return "This is a random system prompt"

def user_prompt(self) -> str:
return "This is a random user prompt"

client = TypeOpenAI(api_key="mock")

with pytest.raises(LLMOutputFieldWrongType) as exc:
result = client.chat.completions.generate_output(
model="gpt-3.5-turbo-0613",
prompt=ExamplePrompt(),
output_type=ExamplePrompt.Output,
max_output_tokens=100,
)

assert exc.value.system_prompt and exc.value.system_prompt.startswith("This is a random system prompt") # + format instruction
assert exc.value.user_prompt == "This is a random user prompt"
assert exc.value.raw_completion == "TITLE: This is a test completion\nCOUNT: 09"

@pytest.mark.asyncio
async def test_exception_injection_async(self, mock_openai_completion):
class ExamplePrompt(PromptTemplate):
class Output(BaseLLMResponse):
title: int # wrong type
count: int

def system_prompt(self) -> str:
return "This is a random system prompt"

def user_prompt(self) -> str:
return "This is a random user prompt"

client = AsyncTypeOpenAI(api_key="mock")

with pytest.raises(LLMOutputFieldWrongType) as exc:
result = await client.chat.completions.generate_output(
model="gpt-3.5-turbo-0613",
prompt=ExamplePrompt(),
output_type=ExamplePrompt.Output,
max_output_tokens=100,
)

assert exc.value.system_prompt and exc.value.system_prompt.startswith("This is a random system prompt") # + format instruction
assert exc.value.user_prompt == "This is a random user prompt"
assert exc.value.raw_completion == "TITLE: This is a test completion\nCOUNT: 09"

# - Azure

azure_client = AsyncTypeAzureOpenAI(api_key="mock", azure_endpoint="mock", api_version="mock")

with pytest.raises(LLMOutputFieldWrongType) as exc2:
result = await azure_client.chat.completions.generate_output(
model="gpt-3.5-turbo-0613",
prompt=ExamplePrompt(),
output_type=ExamplePrompt.Output,
max_output_tokens=100,
)

assert exc2.value.system_prompt and exc2.value.system_prompt.startswith("This is a random system prompt") # + format instruction
assert exc2.value.user_prompt == "This is a random user prompt"
assert exc2.value.raw_completion == "TITLE: This is a test completion\nCOUNT: 09"


# endregion: - Exceptions
18 changes: 15 additions & 3 deletions tests/test_parser.py
@@ -144,17 +144,25 @@ def test_parse_multiline_multiple_output(self):
TEXT: L1
"""

with pytest.raises(LLMOutputFieldMissing):
with pytest.raises(LLMOutputFieldMissing) as exc1:
self.MultilineMultipleTestOutput.parse_response(completion_output_4)

assert exc1.value.system_prompt is None
assert exc1.value.user_prompt is None
assert exc1.value.raw_completion == completion_output_4

completion_output = """
TEXT: L1
VALUE: 8xz
"""

with pytest.raises(LLMOutputFieldWrongType):
with pytest.raises(LLMOutputFieldWrongType) as exc2:
self.MultilineMultipleTestOutput.parse_response(completion_output)

assert exc2.value.system_prompt is None
assert exc2.value.user_prompt is None
assert exc2.value.raw_completion == completion_output

# endregion
# region - 4

@@ -182,9 +190,13 @@ def test_parse_multiline_array_output(self):
APPLE 1: L1
"""

with pytest.raises(LLMOutputFieldInvalidLength):
with pytest.raises(LLMOutputFieldInvalidLength) as exc:
self.MultilineArrayTestOutput.parse_response(completion_output_2)

assert exc.value.system_prompt is None
assert exc.value.user_prompt is None
assert exc.value.raw_completion == completion_output_2

completion_output_3 = """
APPLE 1: L1
APPLE 2: L2
8 changes: 6 additions & 2 deletions typegpt/base.py
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar

from .exceptions import LLMOutputFieldInvalidLength, LLMOutputFieldMissing, LLMOutputFieldWrongType
from .exceptions import LLMException, LLMOutputFieldInvalidLength, LLMOutputFieldMissing, LLMOutputFieldWrongType
from .fields import ClassPlaceholder, LLMArrayElementOutputInfo, LLMArrayOutputInfo, LLMFieldInfo, LLMOutputInfo
from .meta import LLMArrayElementMeta, LLMBaseMeta
from .parser import Parser
@@ -163,7 +163,11 @@ def _set_raw_completion(self, completion: str):

@classmethod
def parse_response(cls: type[_Self], response: str) -> _Self:
return Parser(cls).parse(response)
try:
return Parser(cls).parse(response)
except LLMException as e:
e.raw_completion = response
raise e


# -
22 changes: 11 additions & 11 deletions typegpt/exceptions.py
@@ -1,22 +1,22 @@
class LLMException(Exception):
...

def __init__(self, message: str, system_prompt: str | None = None, user_prompt: str | None = None, raw_completion: str | None = None):
super().__init__(message)
self.system_prompt = system_prompt
self.user_prompt = user_prompt
self.raw_completion = raw_completion

class LLMTokenLimitExceeded(LLMException):
...

class LLMTokenLimitExceeded(LLMException): ...

class LLMParseException(LLMException):
...

class LLMParseException(LLMException): ...

class LLMOutputFieldMissing(LLMParseException):
...

class LLMOutputFieldMissing(LLMParseException): ...

class LLMOutputFieldWrongType(LLMParseException):
...

class LLMOutputFieldWrongType(LLMParseException): ...

class LLMOutputFieldInvalidLength(LLMParseException):
...

class LLMOutputFieldInvalidLength(LLMParseException): ...
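
Since the detail fields now live on the LLMException base class, every subclass accepts them as optional keyword arguments and defaults them to None. A small sketch of the new constructor (the message and prompt strings are illustrative only):

from typegpt.exceptions import LLMTokenLimitExceeded

err = LLMTokenLimitExceeded(
    "Prompt can't be reduced to fit within the token limit (4096)",
    system_prompt="This is a random system prompt",
    user_prompt="This is a random user prompt",
)

assert err.system_prompt is not None
assert err.raw_completion is None  # not provided, so it stays None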
6 changes: 5 additions & 1 deletion typegpt/message_collection_builder.py
Expand Up @@ -71,7 +71,11 @@ def generate_messages(self, token_limit: int):
self.prompt = prompt # update the prompt if successful
return generated_messages

raise LLMTokenLimitExceeded(f"Prompt can't be reduced to fit within the token limit ({token_limit})")
raise LLMTokenLimitExceeded(
f"Prompt can't be reduced to fit within the token limit ({token_limit})",
system_prompt=prompt.system_prompt(),
user_prompt=prompt.user_prompt(),
)


if TYPE_CHECKING:
13 changes: 8 additions & 5 deletions typegpt/openai/_async/chat_completion.py
Expand Up @@ -12,7 +12,7 @@
)

from ...base import BaseLLMResponse
from ...exceptions import LLMParseException
from ...exceptions import LLMException, LLMParseException
from ...prompt_definition.prompt_template import PromptTemplate
from ...utils.internal_types import _UseDefault, _UseDefaultType
from ..base_chat_completion import BaseChatCompletions
@@ -107,8 +107,7 @@ async def generate_output(
top_p: float | NotGiven = NOT_GIVEN,
timeout: float | None | NotGiven = NOT_GIVEN,
retry_on_parse_error: int = 0,
) -> _Output:
...
) -> _Output: ...

@overload
async def generate_output(
@@ -126,8 +125,7 @@ async def generate_output(
top_p: float | NotGiven = NOT_GIVEN,
timeout: float | None | NotGiven = NOT_GIVEN,
retry_on_parse_error: int = 0,
) -> BaseLLMResponse:
...
) -> BaseLLMResponse: ...

async def generate_output(
self,
@@ -208,4 +206,9 @@ async def generate_output(
retry_on_parse_error=retry_on_parse_error - 1,
)
else:
self._inject_exception_details(e, messages, completion)
raise e

except LLMException as e:
self._inject_exception_details(e, messages, completion)
raise e
13 changes: 8 additions & 5 deletions typegpt/openai/_sync/chat_completion.py
Expand Up @@ -12,7 +12,7 @@
)

from ...base import BaseLLMResponse
from ...exceptions import LLMParseException
from ...exceptions import LLMException, LLMParseException
from ...prompt_definition.prompt_template import PromptTemplate
from ...utils.internal_types import _UseDefault, _UseDefaultType
from ..base_chat_completion import BaseChatCompletions
@@ -106,8 +106,7 @@ def generate_output(
top_p: float | NotGiven = NOT_GIVEN,
timeout: float | None | NotGiven = NOT_GIVEN,
retry_on_parse_error: int = 0,
) -> _Output:
...
) -> _Output: ...

@overload
def generate_output(
@@ -125,8 +124,7 @@ def generate_output(
top_p: float | NotGiven = NOT_GIVEN,
timeout: float | None | NotGiven = NOT_GIVEN,
retry_on_parse_error: int = 0,
) -> BaseLLMResponse:
...
) -> BaseLLMResponse: ...

def generate_output(
self,
@@ -207,4 +205,9 @@ def generate_output(
retry_on_parse_error=retry_on_parse_error - 1,
)
else:
self._inject_exception_details(e, messages, completion)
raise e

except LLMException as e:
self._inject_exception_details(e, messages, completion)
raise e
15 changes: 14 additions & 1 deletion typegpt/openai/base_chat_completion.py
@@ -1,6 +1,9 @@
import tiktoken

from typegpt.exceptions import LLMException

from ..message_collection_builder import EncodedMessage
from .views import OpenAIChatModel
import tiktoken


class BaseChatCompletions:
@@ -79,3 +82,13 @@ def num_tokens_from_messages(cls, messages: list[EncodedMessage], model: OpenAIC
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens

# - Exception Handling

def _inject_exception_details(self, e: LLMException, messages: list[EncodedMessage], raw_completion: str):
system_prompt = next((m["content"] for m in messages if m["role"] == "system"), None)
user_prompt = next((m["content"] for m in messages if m["role"] == "user"), None)

e.system_prompt = system_prompt
e.user_prompt = user_prompt
e.raw_completion = raw_completion
