diff --git a/.env.example b/.env.example
index 1191bae6..2ef89567 100644
--- a/.env.example
+++ b/.env.example
@@ -76,9 +76,9 @@ BOT_API_PORT=6002
 
 ################# Miscellaneous ##################
 
-# Google Gemini
-GEMINI_KEY=
-GEMINI_MODEL=gemini-1.5-flash
+# Groq
+GROQ_KEY=
+GROQ_MODEL=llama-3.3-70b-versatile
 
 # Bot list tokens
 TOPGG_TOKEN=
diff --git a/classes/bot.py b/classes/bot.py
index 76d2c262..14e5b4ed 100644
--- a/classes/bot.py
+++ b/classes/bot.py
@@ -8,13 +8,13 @@
 import aiohttp
 import aioredis
 import asyncpg
-import google.generativeai as genai
 import orjson
 from discord.ext import commands
 from discord.ext.commands.core import _CaseInsensitiveDict
 from discord.gateway import DiscordClientWebSocketResponse, DiscordWebSocket
 from discord.utils import parse_time
+from groq import AsyncGroq
 
 from classes.http import HTTPClient
 from classes.misc import Session, Status
@@ -215,6 +215,13 @@ async def on_http_request_end(self, _session, trace_config_ctx, params):
             }
         )
 
+    async def ai_generate(self, text):
+        completion = await self.ai.chat.completions.create(
+            messages=[{"role": "user", "content": text}],
+            model=self.config.GROQ_MODEL,
+        )
+        return completion.choices[0].message.content
+
     async def start(self, worker=True):
         trace_config = aiohttp.TraceConfig()
         trace_config.on_request_start.append(self.on_http_request_start)
@@ -257,9 +264,8 @@
             self.prom = Prometheus(self)
             await self.prom.start()
 
-        if self.config.GEMINI_KEY is not None:
-            genai.configure(api_key=self.config.GEMINI_KEY)
-            self.ai = genai.GenerativeModel(self.config.GEMINI_MODEL)
+        if self.config.GROQ_KEY is not None:
+            self.ai = AsyncGroq(api_key=self.config.GROQ_KEY)
 
         self._connection = State(
             id=self.id,
diff --git a/cogs/core.py b/cogs/core.py
index e16126e6..5832db83 100644
--- a/cogs/core.py
+++ b/cogs/core.py
@@ -58,6 +58,7 @@ async def aireply(self, ctx, *, instructions: str = None):
         data = await tools.get_data(self.bot, ctx.guild.id)
         history = await self.generate_history(ctx.channel)
+        truncated_history = "\n".join(history.splitlines()[-100:])
         prompt = (
             "You are a Discord moderator for a server. The following is the entire history of "
             "the conversation between staff and the user. Please fill in the suitable response "
@@ -65,11 +66,11 @@
             "as 'My response would be...'. Try to appear as supportive as possible.\nHere are "
             f"additional information you should consider (if any): {data[13]}\nHere are additional "
             f"instructions for your response (if any): {instructions}\n\nFull transcript: "
-            f"{history}.\n\nStaff response: "
+            f"{truncated_history}.\n\nStaff response: "
         )
 
         try:
-            response = await self.bot.ai.generate_content_async(prompt)
+            response = await self.bot.ai_generate(prompt)
         except Exception:
             await ctx.send(ErrorEmbed("Failed to generate a response."))
             return
@@ -235,12 +236,12 @@ async def close_channel(self, ctx, reason, anon: bool = False):
 
         if self.bot.ai is not None and data[7] == 1:
             try:
-                summary = await self.bot.ai.generate_content_async(
+                truncated_history = "\n".join(history.splitlines()[-100:])
+                summary = await self.bot.ai_generate(
                     "The following is the entire history of the conversation between staff and "
-                    "the user. Please summarise the entire interaction into 1 or 2 sentences, "
-                    "with at most 20 words. Only give 1 response option. Do not output "
-                    "additional text such as 'My response would be...'.\n\nFull transcript:\n"
-                    + history
+                    "the user. Please summarise the entire interaction into 1 or 2 sentences. "
+                    "Only give 1 response option. Do not output additional text such as 'Here "
+                    "is the summary...'.\n\nFull transcript:\n" + truncated_history
                 )
-                embed.add_field("AI Summary", summary.text)
+                embed.add_field("AI Summary", summary)
             except Exception:
diff --git a/docker/.env.example b/docker/.env.example
index cfad01a7..68f62255 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -26,6 +26,6 @@ BASE_URI=http://localhost:8000
 
 ################# Miscellaneous ##################
 
-# Google Gemini
-GEMINI_KEY=
-GEMINI_MODEL=gemini-1.5-flash
+# Groq
+GROQ_KEY=
+GROQ_MODEL=llama-3.3-70b-versatile
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index 9486cc9d..9c9e8f13 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -83,8 +83,8 @@ services:
       - BASE_URI=${BASE_URI}
       - BOT_API_HOST=0.0.0.0
       - BOT_API_PORT=6002
-      - GEMINI_KEY=${GEMINI_KEY}
-      - GEMINI_MODEL=${GEMINI_MODEL}
+      - GROQ_KEY=${GROQ_KEY}
+      - GROQ_MODEL=${GROQ_MODEL}
       - TOPGG_TOKEN=
       - DBOTS_TOKEN=
       - DBL_TOKEN=
diff --git a/requirements.txt b/requirements.txt
index 584bf2c0..9af64e7c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,7 @@ aioprometheus==23.3.0
 aioredis==1.3.1
 asyncpg==0.29.0
 dateparser==1.2.0
-google-generativeai==0.8.3
+groq==0.13.0
 orjson==3.9.13
 psutil==5.9.8
 python-dotenv==1.0.1
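
Net effect of the `classes/bot.py` changes: every Gemini-specific call is replaced by a single OpenAI-style chat-completions call against Groq, wrapped in the new `Bot.ai_generate` helper. For reference, a minimal standalone sketch of that call shape, assuming `GROQ_KEY`/`GROQ_MODEL` are exported as in the updated `.env.example` (the sample prompt is illustrative only):

```python
# Minimal sketch of the Groq call pattern that Bot.ai_generate wraps.
# Assumes GROQ_KEY is set in the environment; the model fallback mirrors
# the new .env.example default.
import asyncio
import os

from groq import AsyncGroq


async def main():
    client = AsyncGroq(api_key=os.environ["GROQ_KEY"])
    completion = await client.chat.completions.create(
        messages=[{"role": "user", "content": "Summarise this ticket in one sentence."}],
        model=os.environ.get("GROQ_MODEL", "llama-3.3-70b-versatile"),
    )
    # Groq returns an OpenAI-style completion object, and message.content is
    # already a plain string. Since ai_generate returns that string directly,
    # the close_channel embed takes `summary` as-is, not the Gemini-era
    # response object's `.text` attribute.
    print(completion.choices[0].message.content)


asyncio.run(main())
```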