switch to new gemini sdk
rishsriv committed Dec 12, 2024
1 parent b91a9d5 commit 9f32609
Showing 3 changed files with 74 additions and 31 deletions.
100 changes: 71 additions & 29 deletions defog_utils/utils_llm.py
@@ -339,53 +339,95 @@ async def chat_together_async(

 def chat_gemini(
     messages: List[Dict[str, str]],
-    model: str = "gemini-1.5-pro",
+    model: str = "gemini-2.0-flash-exp",
     max_completion_tokens: int = 8192,
     temperature: float = 0.0,
     stop: List[str] = [],
     json_mode: bool = False,
     response_format=None,
     seed: int = 0,
     store=True,
     metadata=None,
 ) -> Optional[LLMResponse]:
-    import google.generativeai as genai
+    from google import genai
+    from google.genai import types
 
-    genai.configure(api_key=os.environ["GEMINI_API_KEY"])
+    client = genai.Client(
+        api_key=os.getenv("GOOGLE_GENERATIVEAI_API_KEY"),
+    )
     t = time.time()
-    generation_config = {
-        "temperature": temperature,
-        "max_output_tokens": max_completion_tokens,
-        "response_mime_type": "application/json" if json_mode else "text/plain",
-        "stop_sequences": stop,
-        # "seed": seed, # seed is not supported in the current version
-    }
     if messages[0]["role"] == "system":
         system_msg = messages[0]["content"]
         messages = messages[1:]
     else:
         system_msg = None
-    final_msg = messages[-1]["content"]
-    messages = messages[:-1]
-    for msg in messages:
-        if msg["role"] != "user":
-            msg["role"] = "model"
-    client_gemini = genai.GenerativeModel(
-        model, generation_config=generation_config, system_instruction=system_msg
+
+    message = "\n".join([i["content"] for i in messages])
+
+    generation_config = types.GenerateContentConfig(
+        temperature=temperature,
+        system_instruction=system_msg,
+        max_output_tokens=max_completion_tokens,
+        stop_sequences=stop,
     )
-    chat = client_gemini.start_chat(
-        history=messages,
+
+    response = client.models.generate_content(
+        model=model,
+        contents=message,
+        config=generation_config,
     )
-    response = chat.send_message(final_msg)
-    if len(response.candidates) == 0:
-        print("Empty response")
-        return None
-    if (
-        response.candidates[0].finish_reason.value != 1
-    ):  # 1 is the finish reason for STOP
-        print("Max tokens reached")
-        return None
+    content = response.text
     return LLMResponse(
-        content=response.candidates[0].content.parts[0].text,
+        content=content,
         time=round(time.time() - t, 3),
         input_tokens=response.usage_metadata.prompt_token_count,
         output_tokens=response.usage_metadata.candidates_token_count,
     )
+
+
+async def chat_gemini_async(
+    messages: List[Dict[str, str]],
+    model: str = "gemini-2.0-flash-exp",
+    max_completion_tokens: int = 8192,
+    temperature: float = 0.0,
+    stop: List[str] = [],
+    json_mode: bool = False,
+    response_format=None,
+    seed: int = 0,
+    store=True,
+    metadata=None,
+) -> Optional[LLMResponse]:
+    from google import genai
+    from google.genai import types
+
+    client = genai.Client(
+        api_key=os.getenv("GOOGLE_GENERATIVEAI_API_KEY"),
+    )
+    t = time.time()
+    if messages[0]["role"] == "system":
+        system_msg = messages[0]["content"]
+        messages = messages[1:]
+    else:
+        system_msg = None
+
+    message = "\n".join([i["content"] for i in messages])
+
+    generation_config = types.GenerateContentConfig(
+        temperature=temperature,
+        system_instruction=system_msg,
+        max_output_tokens=max_completion_tokens,
+        stop_sequences=stop,
+    )
+
+    # the async surface of the google-genai client lives under client.aio
+    response = await client.aio.models.generate_content(
+        model=model,
+        contents=message,
+        config=generation_config,
+    )
+    content = response.text
+    return LLMResponse(
+        content=content,
+        time=round(time.time() - t, 3),
+        input_tokens=response.usage_metadata.prompt_token_count,
+        output_tokens=response.usage_metadata.candidates_token_count,
+    )
3 changes: 2 additions & 1 deletion defog_utils/utils_multi_llm.py
@@ -8,6 +8,7 @@
     chat_openai,
     chat_together,
     chat_anthropic_async,
+    chat_gemini_async,
     chat_openai_async,
     chat_together_async,
 )
@@ -39,7 +40,7 @@ def map_model_to_chat_fn_async(model: str) -> Callable:
if model.startswith("claude"):
return chat_anthropic_async
if model.startswith("gemini"):
raise ValueError("Gemini does not support async chat")
return chat_gemini_async
if model.startswith("gpt") or model.startswith("o1"):
return chat_openai_async
if (
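
With this routing change, gemini models dispatch through the async path like the other providers. A hypothetical call site (the ask() wrapper below is illustrative, not part of the repo; the message format follows utils_llm.py):

from typing import Dict, List

from defog_utils.utils_multi_llm import map_model_to_chat_fn_async

async def ask(model: str, messages: List[Dict[str, str]]):
    # "gemini-..." model names now resolve to chat_gemini_async instead of raising
    chat_fn = map_model_to_chat_fn_async(model)
    return await chat_fn(messages=messages, model=model)

# e.g. await ask("gemini-2.0-flash-exp", [{"role": "user", "content": "hello"}])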
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
 anthropic==0.40.0
-google-generativeai==0.8.3
+google-genai==0.2.1
 numpy
 openai==1.57.2
 psycopg2-binary==2.9.9
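
Because the two SDKs ship different top-level packages (google.generativeai vs. google.genai), existing environments presumably need the old dependency removed as well as the new one installed, e.g. pip uninstall -y google-generativeai followed by pip install google-genai==0.2.1.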
