raise exception in case of error, instead of returning None
rishsriv committed Dec 13, 2024
1 parent e78ec14 commit bd4f742
Showing 1 changed file with 20 additions and 30 deletions.
defog_utils/utils_llm.py: 50 changes (20 additions & 30 deletions)
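
For callers of this module, the practical effect of the diff below is that a failed call no longer returns None; it raises an exception instead. A minimal caller-side sketch of that difference follows. Note the hedge: the exact signature of chat_openai and the field names on LLMResponse are not shown in these hunks, so the messages keyword and the final print are illustrative assumptions only.

from defog_utils.utils_llm import chat_openai

messages = [{"role": "user", "content": "Summarize ACID in one sentence."}]

# Before this commit: failures returned None, so callers had to None-check.
# resp = chat_openai(messages=messages)
# if resp is None:
#     ...  # handle the failure

# After this commit: max-token truncation and empty responses raise Exception,
# so callers catch the exception instead of checking for None.
try:
    resp = chat_openai(messages=messages)
    print(resp)
except Exception as e:
    print(f"LLM call failed: {e}")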
@@ -21,7 +21,7 @@ def chat_anthropic(
     json_mode: bool = False,
     response_format=None,
     seed: int = 0,
-) -> Optional[LLMResponse]:
+) -> LLMResponse:
     """
     Returns the response from the Anthropic API, the time taken to generate the response, the number of input tokens used, and the number of output tokens used.
     Note that anthropic doesn't have explicit json mode api constraints, nor does it have a seed parameter.
@@ -44,11 +44,9 @@ def chat_anthropic(
         stop_sequences=stop,
     )
     if response.stop_reason == "max_tokens":
-        print("Max tokens reached")
-        return None
+        raise Exception("Max tokens reached")
     if len(response.content) == 0:
-        print("Empty response")
-        return None
+        raise Exception("Max tokens reached")
     return LLMResponse(
         response.content[0].text,
         round(time.time() - t, 3),
@@ -69,7 +67,7 @@ async def chat_anthropic_async(
     store=True,
     metadata=None,
     timeout=100,
-) -> Optional[LLMResponse]:
+) -> LLMResponse:
     """
     Returns the response from the Anthropic API, the time taken to generate the response, the number of input tokens used, and the number of output tokens used.
     Note that anthropic doesn't have explicit json mode api constraints, nor does it have a seed parameter.
@@ -92,11 +90,9 @@ async def chat_anthropic_async(
         stop_sequences=stop,
     )
     if response.stop_reason == "max_tokens":
-        print("Max tokens reached")
-        return None
+        raise Exception("Max tokens reached")
     if len(response.content) == 0:
-        print("Empty response")
-        return None
+        raise Exception("Max tokens reached")
     return LLMResponse(
         response.content[0].text,
         round(time.time() - t, 3),
@@ -114,7 +110,7 @@ def chat_openai(
     json_mode: bool = False,
     response_format=None,
     seed: int = 0,
-) -> Optional[LLMResponse]:
+) -> LLMResponse:
     """
     Returns the response from the OpenAI API, the time taken to generate the response, the number of input tokens used, and the number of output tokens used.
     We use max_completion_tokens here, instead of using max_tokens. This is to support o1 models.
@@ -155,11 +151,9 @@ def chat_openai(
         seed=seed,
     )
     if response.choices[0].finish_reason == "length":
-        print("Max tokens reached")
-        return None
+        raise Exception("Max tokens reached")
     if len(response.choices) == 0:
-        print("Empty response")
-        return None
+        raise Exception("Max tokens reached")
 
     if response_format and model not in ["o1-mini", "o1-preview", "o1"]:
         content = response.choices[0].message.parsed
@@ -187,7 +181,7 @@ async def chat_openai_async(
     store=True,
     metadata=None,
     timeout=100,
-) -> Optional[LLMResponse]:
+) -> LLMResponse:
     """
     Returns the response from the OpenAI API, the time taken to generate the response, the number of input tokens used, and the number of output tokens used.
     We use max_completion_tokens here, instead of using max_tokens. This is to support o1 models.
@@ -241,10 +235,10 @@ async def chat_openai_async(
 
     if response.choices[0].finish_reason == "length":
         print("Max tokens reached")
-        return None
+        raise Exception("Max tokens reached")
     if len(response.choices) == 0:
         print("Empty response")
-        return None
+        raise Exception("No response")
     return LLMResponse(
         content,
         round(time.time() - t, 3),
@@ -263,7 +257,7 @@ def chat_together(
     json_mode: bool = False,
     response_format=None,
     seed: int = 0,
-) -> Optional[LLMResponse]:
+) -> LLMResponse:
     """
     Returns the response from the Together API, the time taken to generate the response, the number of input tokens used, and the number of output tokens used.
     Together's max_tokens refers to the maximum completion tokens.
@@ -282,11 +276,9 @@ def chat_together(
         seed=seed,
     )
     if response.choices[0].finish_reason == "length":
-        print("Max tokens reached")
-        return None
+        raise Exception("Max tokens reached")
     if len(response.choices) == 0:
-        print("Empty response")
-        return None
+        raise Exception("Max tokens reached")
     return LLMResponse(
         response.choices[0].message.content,
         round(time.time() - t, 3),
@@ -307,7 +299,7 @@ async def chat_together_async(
     store=True,
     metadata=None,
     timeout=100,
-) -> Optional[LLMResponse]:
+) -> LLMResponse:
     """
     Returns the response from the Together API, the time taken to generate the response, the number of input tokens used, and the number of output tokens used.
     Together's max_tokens refers to the maximum completion tokens.
@@ -326,11 +318,9 @@ async def chat_together_async(
         seed=seed,
    )
     if response.choices[0].finish_reason == "length":
-        print("Max tokens reached")
-        return None
+        raise Exception("Max tokens reached")
     if len(response.choices) == 0:
-        print("Empty response")
-        return None
+        raise Exception("Max tokens reached")
     return LLMResponse(
         response.choices[0].message.content,
         round(time.time() - t, 3),
@@ -350,7 +340,7 @@ def chat_gemini(
     seed: int = 0,
     store=True,
     metadata=None,
-) -> Optional[LLMResponse]:
+) -> LLMResponse:
     from google import genai
     from google.genai import types
 
@@ -409,7 +399,7 @@ async def chat_gemini_async(
     store=True,
     metadata=None,
     timeout=100, # does not have timeout method
-) -> Optional[LLMResponse]:
+) -> LLMResponse:
     from google import genai
     from google.genai import types
 
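If downstream code still depends on the old None-on-failure behavior, a thin wrapper can translate the new exceptions back into an Optional result. This is only a sketch, under the assumptions that the failure paths raise plain Exception (as the hunks above show) and that LLMResponse is importable from defog_utils.utils_llm; the wrapper name is made up for illustration and is not part of the library.

from typing import Optional

from defog_utils.utils_llm import LLMResponse, chat_together

def chat_together_or_none(*args, **kwargs) -> Optional[LLMResponse]:
    # Hypothetical compatibility shim: restores the pre-commit behavior of
    # returning None when the underlying call fails.
    try:
        return chat_together(*args, **kwargs)
    except Exception as e:
        print(f"chat_together failed: {e}")
        return None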
