Skip to content

Commit

Permalink
Add gemini
Browse files Browse the repository at this point in the history
  • Loading branch information
wongjingping committed Oct 30, 2024
1 parent f1cce2c commit 959cdc9
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 19 deletions.
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ jobs:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}

50 changes: 48 additions & 2 deletions defog_utils/utils_llm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from dataclasses import dataclass
import os
import time
from typing import Optional, List, Dict
from dataclasses import dataclass
from typing import Dict, List, Optional

import google.generativeai as genai
from anthropic import Anthropic
from openai import OpenAI
from together import Together

client_anthropic = Anthropic()
client_openai = OpenAI()
client_together = Together()
genai.configure(api_key=os.environ["GEMINI_API_KEY"])


@dataclass
Expand Down Expand Up @@ -130,3 +134,45 @@ def chat_together(
response.usage.prompt_tokens,
response.usage.completion_tokens,
)


def chat_gemini(
    messages: List[Dict[str, str]],
    model: str = "gemini-1.5-pro",
    max_tokens: int = 8192,
    temperature: float = 0.0,
    stop: Optional[List[str]] = None,
    json_mode: bool = False,
    seed: int = 0,
) -> Optional[LLMResponse]:
    """Send a chat conversation to a Gemini model and return an LLMResponse.

    Args:
        messages: OpenAI-style message dicts with "role" and "content" keys.
            An optional leading "system" message becomes the Gemini
            system instruction; the last message is sent as the prompt and
            everything in between is used as chat history.
        model: Gemini model name.
        max_tokens: Maximum number of output tokens to generate.
        temperature: Sampling temperature.
        stop: Optional stop sequences (defaults to none; avoids the
            mutable-default-argument pitfall of the previous `stop=[]`).
        json_mode: If True, request `application/json` output.
        seed: Accepted for interface parity with the other chat_* helpers;
            not supported by the current google-generativeai version.

    Returns:
        An LLMResponse with content, elapsed time, and token counts, or
        None if `messages` is empty or the model returned no candidates.
    """
    t = time.time()
    generation_config = {
        "temperature": temperature,
        "max_output_tokens": max_tokens,
        "response_mime_type": "application/json" if json_mode else "text/plain",
        "stop_sequences": stop or [],
        # "seed": seed, # seed is not supported in the current version
    }
    # Shallow-copy each message so we never mutate the caller's dicts below.
    messages = [dict(msg) for msg in messages]
    # Fix: previously `system_msg` was only assigned inside the `if` branch,
    # raising NameError whenever the first message was not a system message.
    system_msg = None
    if messages and messages[0]["role"] == "system":
        system_msg = messages[0]["content"]
        messages = messages[1:]
    if not messages:
        # Nothing to send; previously this raised IndexError on messages[-1].
        print("Empty messages")
        return None
    final_msg = messages[-1]["content"]
    messages = messages[:-1]
    # Gemini's chat history only accepts "user" and "model" roles; map any
    # other role (e.g. "assistant") to "model".
    for msg in messages:
        if msg["role"] != "user":
            msg["role"] = "model"
    client_gemini = genai.GenerativeModel(
        model,
        generation_config=generation_config,
        system_instruction=system_msg,
    )
    chat = client_gemini.start_chat(
        history=messages,
    )
    response = chat.send_message(final_msg)
    if len(response.candidates) == 0:
        print("Empty response")
        return None
    print(response.candidates[0].finish_reason)
    return LLMResponse(
        content=response.candidates[0].content.parts[0].text,
        time=round(time.time() - t, 3),
        input_tokens=response.usage_metadata.prompt_token_count,
        output_tokens=response.usage_metadata.candidates_token_count,
    )
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
anthropic==0.37.1
google-generativeai==0.8.3
numpy
openai==1.52.2
psycopg2-binary==2.9.9
Expand Down
51 changes: 34 additions & 17 deletions tests/test_utils_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ..defog_utils.utils_llm import (
LLMResponse,
chat_anthropic,
chat_gemini,
chat_openai,
chat_together,
)
Expand Down Expand Up @@ -50,21 +51,17 @@
]

acceptable_sql = [
"SELECT COUNT(*) FROM orders",
"SELECT COUNT(order_id) FROM orders",
"select count(*) from orders",
"select count(order_id) from orders",
"select count(*) as total_orders from orders",
"select count(order_id) as total_orders from orders",
]

acceptable_sql_from_json = set(
[
"SELECT COUNT(order_id) as total_orders FROM orders;",
"SELECT COUNT(*) AS total_orders FROM orders;",
"SELECT COUNT(order_id) FROM orders;",
"SELECT COUNT(*) FROM orders;",
]
)
class TestChatClients(unittest.TestCase):

    def check_sql(self, sql: str):
        # Normalize the model's SQL (strip trailing semicolons/newlines, lowercase)
        # before comparing against the accepted answers in `acceptable_sql`.
        self.assertIn(sql.strip(";\n").lower(), acceptable_sql)
class TestChatClients(unittest.TestCase):
def test_chat_anthropic(self):
response = chat_anthropic(
messages,
Expand All @@ -75,15 +72,15 @@ def test_chat_anthropic(self):
)
print(response)
self.assertIsInstance(response, LLMResponse)
self.assertIn(response.content, acceptable_sql)
self.check_sql(response.content)
self.assertEqual(response.input_tokens, 90) # 90 input tokens
self.assertTrue(response.output_tokens < 10) # output tokens should be < 10

def test_chat_openai(self):
response = chat_openai(messages, model="gpt-4o-mini", stop=[";"], seed=0)
print(response)
self.assertIsInstance(response, LLMResponse)
self.assertIn(response.content, acceptable_sql)
self.check_sql(response.content)
self.assertEqual(response.input_tokens, 83)
self.assertTrue(response.output_tokens < 10) # output tokens should be < 10

Expand All @@ -96,10 +93,18 @@ def test_chat_together(self):
)
print(response)
self.assertIsInstance(response, LLMResponse)
self.assertIn(response.content, acceptable_sql)
self.check_sql(response.content)
self.assertEqual(response.input_tokens, 108)
self.assertTrue(response.output_tokens < 10) # output tokens should be < 10

def test_chat_gemini(self):
    # Live-API smoke test: requires GEMINI_API_KEY in the environment
    # (configured at import time in defog_utils.utils_llm).
    response = chat_gemini(messages, model="gemini-1.5-flash", stop=[";"], seed=0)
    print(response)
    self.assertIsInstance(response, LLMResponse)
    self.check_sql(response.content)
    # NOTE(review): exact input token count (86) is tokenizer/model specific —
    # this will break if the shared `messages` fixture or the model changes.
    self.assertEqual(response.input_tokens, 86)
    self.assertTrue(response.output_tokens < 10)

def test_chat_json_anthropic(self):
response = chat_anthropic(
messages_json,
Expand All @@ -111,7 +116,7 @@ def test_chat_json_anthropic(self):
print(response)
self.assertIsInstance(response, LLMResponse)
resp_dict = json.loads(response.content)
self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
self.check_sql(resp_dict["sql"])
self.assertIsInstance(resp_dict["reasoning"], str)
self.assertIsInstance(response.input_tokens, int)
self.assertIsInstance(response.output_tokens, int)
Expand All @@ -123,7 +128,7 @@ def test_chat_json_openai(self):
print(response)
self.assertIsInstance(response, LLMResponse)
resp_dict = json.loads(response.content)
self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
self.check_sql(resp_dict["sql"])
self.assertIsInstance(resp_dict["reasoning"], str)
self.assertIsInstance(response.input_tokens, int)
self.assertIsInstance(response.output_tokens, int)
Expand All @@ -139,7 +144,19 @@ def test_chat_json_together(self):
self.assertIsInstance(response, LLMResponse)
raw_output = response.content
resp_dict = json.loads(raw_output)
self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
self.check_sql(resp_dict["sql"])
self.assertIsInstance(resp_dict["reasoning"], str)
self.assertIsInstance(response.input_tokens, int)
self.assertIsInstance(response.output_tokens, int)

def test_chat_json_gemini(self):
    # Live-API test of Gemini's JSON mode: the response content must be a
    # JSON object containing "sql" and "reasoning" keys.
    response = chat_gemini(
        messages_json, model="gemini-1.5-flash", seed=0, json_mode=True
    )
    print(response)
    self.assertIsInstance(response, LLMResponse)
    resp_dict = json.loads(response.content)
    self.check_sql(resp_dict["sql"])
    self.assertIsInstance(resp_dict["reasoning"], str)
    self.assertIsInstance(response.input_tokens, int)
    self.assertIsInstance(response.output_tokens, int)

0 comments on commit 959cdc9

Please sign in to comment.