
Commit

add asynchronous chat function for llm calls
rishsriv committed Nov 15, 2024
1 parent ce17720 commit 7f058a5
Showing 2 changed files with 182 additions and 0 deletions.
137 changes: 137 additions & 0 deletions defog_utils/utils_llm.py
@@ -2,6 +2,7 @@
import time
from dataclasses import dataclass
from typing import Dict, List, Optional
import asyncio


@dataclass
@@ -57,6 +58,50 @@ def chat_anthropic(
    )


async def chat_anthropic_async(
    messages: List[Dict[str, str]],
    model: str = "claude-3-5-sonnet-20241022",
    max_completion_tokens: int = 8192,
    temperature: float = 0.0,
    stop: List[str] = [],
    json_mode: bool = False,
    seed: int = 0,
) -> Optional[LLMResponse]:
    """
    Returns the response from the Anthropic API, the time taken to generate the response,
    the number of input tokens used, and the number of output tokens used.
    Note that Anthropic does not have explicit JSON-mode API constraints, nor does it
    support a seed parameter.
    """
    from anthropic import AsyncAnthropic

    client_anthropic = AsyncAnthropic()
    t = time.time()
    if len(messages) >= 1 and messages[0].get("role") == "system":
        sys_msg = messages[0]["content"]
        messages = messages[1:]
    else:
        sys_msg = ""
    response = await client_anthropic.messages.create(
        system=sys_msg,
        messages=messages,
        model=model,
        max_tokens=max_completion_tokens,
        temperature=temperature,
        stop_sequences=stop,
    )
    if response.stop_reason == "max_tokens":
        print("Max tokens reached")
        return None
    if len(response.content) == 0:
        print("Empty response")
        return None
    return LLMResponse(
        response.content[0].text,
        round(time.time() - t, 3),
        response.usage.input_tokens,
        response.usage.output_tokens,
    )
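
A minimal usage sketch of the new Anthropic helper, assuming the defog_utils package is importable and ANTHROPIC_API_KEY is exported; the prompt below is illustrative only:

import asyncio

from defog_utils.utils_llm import chat_anthropic_async


async def main():
    # The leading system message is folded into Anthropic's `system` parameter
    # by chat_anthropic_async itself.
    resp = await chat_anthropic_async(
        messages=[
            {"role": "system", "content": "You are a terse assistant."},
            {"role": "user", "content": "Name one prime number."},
        ],
    )
    # resp is an LLMResponse dataclass, or None if the call was truncated or empty.
    print(resp)


asyncio.run(main())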


def chat_openai(
    messages: List[Dict[str, str]],
    model: str = "gpt-4o",
@@ -109,6 +154,58 @@ def chat_openai(
    )


async def chat_openai_async(
    messages: List[Dict[str, str]],
    model: str = "gpt-4o",
    max_completion_tokens: int = 16384,
    temperature: float = 0.0,
    stop: List[str] = [],
    json_mode: bool = False,
    seed: int = 0,
) -> Optional[LLMResponse]:
    """
    Returns the response from the OpenAI API, the time taken to generate the response,
    the number of input tokens used, and the number of output tokens used.
    We use max_completion_tokens instead of max_tokens to support the o1 models.
    """
    from openai import AsyncOpenAI

    client_openai = AsyncOpenAI()
    t = time.time()
    if model in ["o1-mini", "o1-preview", "o1"]:
        if messages[0].get("role") == "system":
            sys_msg = messages[0]["content"]
            messages = messages[1:]
            messages[0]["content"] = sys_msg + messages[0]["content"]
        response = await client_openai.chat.completions.create(
            messages=messages,
            model=model,
            max_completion_tokens=max_completion_tokens,
        )
    else:
        response = await client_openai.chat.completions.create(
            messages=messages,
            model=model,
            max_completion_tokens=max_completion_tokens,
            temperature=temperature,
            stop=stop,
            response_format={"type": "json_object"} if json_mode else None,
            seed=seed,
        )
    if len(response.choices) == 0:
        print("Empty response")
        return None
    if response.choices[0].finish_reason == "length":
        print("Max tokens reached")
        return None
    return LLMResponse(
        response.choices[0].message.content,
        round(time.time() - t, 3),
        response.usage.prompt_tokens,
        response.usage.completion_tokens,
        response.usage.completion_tokens_details,
    )
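
A sketch of a JSON-mode call against the new OpenAI helper, assuming OPENAI_API_KEY is exported; note that the o1 branch above deliberately omits temperature, stop, json_mode, and seed, so json_mode only affects non-o1 models:

import asyncio

from defog_utils.utils_llm import chat_openai_async


async def main():
    resp = await chat_openai_async(
        messages=[{"role": "user", "content": "Return {\"ok\": true} as a JSON object."}],
        model="gpt-4o",
        json_mode=True,  # sets response_format={"type": "json_object"}
    )
    # resp is an LLMResponse whose first field holds the JSON string, or None.
    print(resp)


asyncio.run(main())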


def chat_together(
    messages: List[Dict[str, str]],
    model: str = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
@@ -149,6 +246,46 @@ def chat_together(
    )


async def chat_together_async(
    messages: List[Dict[str, str]],
    model: str = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_completion_tokens: int = 4096,
    temperature: float = 0.0,
    stop: List[str] = [],
    json_mode: bool = False,
    seed: int = 0,
) -> Optional[LLMResponse]:
    """
    Returns the response from the Together API, the time taken to generate the response,
    the number of input tokens used, and the number of output tokens used.
    Together's max_tokens refers to the maximum number of completion tokens.
    Together does not have explicit JSON-mode API constraints.
    """
    from together import AsyncTogether

    client_together = AsyncTogether()
    t = time.time()
    response = await client_together.chat.completions.create(
        messages=messages,
        model=model,
        max_tokens=max_completion_tokens,
        temperature=temperature,
        stop=stop,
        seed=seed,
    )
    if len(response.choices) == 0:
        print("Empty response")
        return None
    if response.choices[0].finish_reason == "length":
        print("Max tokens reached")
        return None
    return LLMResponse(
        response.choices[0].message.content,
        round(time.time() - t, 3),
        response.usage.prompt_tokens,
        response.usage.completion_tokens,
    )
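
Because all three helpers are coroutines, the same prompt can be fanned out to several providers concurrently; a sketch, assuming ANTHROPIC_API_KEY, OPENAI_API_KEY, and TOGETHER_API_KEY are exported and the default models are acceptable:

import asyncio

from defog_utils.utils_llm import (
    chat_anthropic_async,
    chat_openai_async,
    chat_together_async,
)


async def main():
    messages = [{"role": "user", "content": "Say hello in one word."}]
    # The three requests run concurrently rather than back to back.
    results = await asyncio.gather(
        chat_anthropic_async(messages=messages),
        chat_openai_async(messages=messages),
        chat_together_async(messages=messages),
    )
    for result in results:
        print(result)


asyncio.run(main())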


def chat_gemini(
    messages: List[Dict[str, str]],
    model: str = "gemini-1.5-pro",
45 changes: 45 additions & 0 deletions defog_utils/utils_multi_llm.py
@@ -7,6 +7,9 @@
    chat_gemini,
    chat_openai,
    chat_together,
    chat_anthropic_async,
    chat_openai_async,
    chat_together_async,
)


@@ -29,6 +32,48 @@ def map_model_to_chat_fn(model: str) -> Callable:
    raise ValueError(f"Unknown model: {model}")


def map_model_to_chat_fn_async(model: str) -> Callable:
    """
    Returns the appropriate async chat function based on the model.
    """
    if model.startswith("claude"):
        return chat_anthropic_async
    if model.startswith("gemini"):
        raise ValueError("Gemini does not support async chat")
    if model.startswith("gpt") or model in ["o1", "o1-mini", "o1-preview"]:
        return chat_openai_async
    if (
        model.startswith("meta-llama")
        or model.startswith("mistralai")
        or model.startswith("Qwen")
    ):
        return chat_together_async
    raise ValueError(f"Unknown model: {model}")


async def chat_async(
    model,
    messages,
    max_completion_tokens=4096,
    temperature=0.0,
    stop=[],
    json_mode=False,
    seed=0,
) -> LLMResponse:
    """
    Returns the response from the LLM API for the single model that is passed in,
    dispatching to the matching provider-specific async chat function.
    """
    llm_function = map_model_to_chat_fn_async(model)
    return await llm_function(
        messages=messages,
        max_completion_tokens=max_completion_tokens,
        temperature=temperature,
        stop=stop,
        json_mode=json_mode,
        seed=seed,
    )
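
A sketch of the single-model dispatcher, assuming the corresponding API keys are exported; chat_async resolves the backend from the model-name prefix via map_model_to_chat_fn_async, so the call site stays identical across providers:

import asyncio

from defog_utils.utils_multi_llm import chat_async


async def main():
    for model in ["claude-3-5-sonnet-20241022", "gpt-4o"]:
        resp = await chat_async(
            model,
            [{"role": "user", "content": "Say hello in one word."}],
        )
        print(model, resp)


asyncio.run(main())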


def chat(
    models,
    messages,
