diff --git a/bot/on_message/bots/openai_bot.py b/bot/on_message/bots/openai_bot.py index e9c7003..39f1363 100644 --- a/bot/on_message/bots/openai_bot.py +++ b/bot/on_message/bots/openai_bot.py @@ -1,9 +1,10 @@ import contextlib from dataclasses import dataclass +import json import os from typing import Any, Optional import openai -# from bot.setup.bots import WeezerpediaAPI +from bot.on_message.bots.weezerpedia import WeezerpediaAPI from rich import print import random @@ -21,7 +22,7 @@ class PromptParams: class OpenAIBot: - def __init__(self, long_name, short_name, openai_sessions, weezerpedia_api): + def __init__(self, long_name: str, short_name: str, openai_sessions: list, weezerpedia_api: WeezerpediaAPI): self.long_name = long_name self.short_name = short_name self.openai_sessions = openai_sessions @@ -117,37 +118,50 @@ async def build_ai_response(self, message, system: str, adjective: str, num_mess reply = reply.replace("!", ".") return reply.strip() - def should_query_weezerpedia_api(self, last_three_messages): - decision_prompt = { - "role": "system", - "content": ( - f"The user has asked: '{last_three_messages}'. " - "If the question is asking for specific or detailed information that is not in your internal knowledge, " - "especially related to Weezerpedia, you **must** query the Weezerpedia API to provide accurate information. " - "Always prefer querying the API for detailed questions about Weezer. " - "If a query is needed, respond with 'API NEEDED:'. Otherwise, respond 'NO API NEEDED'." - ) - } - + def _get_response_or_weezerpedia_function_call_results(self, new_content: list[dict[str, str]], function_call: bool) -> str: try: - # Ask GPT to make the decision based on the new message - decision_response = openai.chat.completions.create( + completion = openai.chat.completions.create( temperature=0.7, max_tokens=100, model="gpt-4o", - messages=[decision_prompt], - ) + messages=new_content, + functions = [ + { + "name": "fetch_weezerpedia_data", + "description": "Queries Weezerpedia API for detailed information about Weezer-related topics. " \ + "Only call this if the most recent messages warrant it, and you have not already responded on a query.", + "parameters": { + "type": "object", + "properties": { + "query_term": { + "type": "string", + "description": "The specific Weezer-related topic to look up in Weezerpedia." + } + }, + "required": ["query_term"] + } + } + ], + function_call="auto" if function_call else "none" + ) + + response_text = completion.choices[0].message.content + choice = completion.choices[0].message + + if choice.function_call and function_call: + arguments = choice.function_call.arguments + + function_args = json.loads(arguments) + query_term = function_args.get("query_term") + + if query_term: + response_text = self.weezerpedia_api.get_search_result_knowledge(query_term, True)[0] - decision_text = decision_response.choices[0].message.content.strip( - ) - print(f"API decision: {decision_text}") - return decision_text except openai.APIError as e: - print(f"An error occurred during API decision: {e}") - return None + response_text = f"An error occurred: {e}" except Exception as e: - print(f"An error occurred: {e}") - return None + response_text = f"An error occurred: {e}" + return response_text def fetch_openai_completion(self, prompt_params: PromptParams, num_messages_lookback: int): system_message = {"role": "system", @@ -182,12 +196,6 @@ def fetch_openai_completion(self, prompt_params: PromptParams, num_messages_look # Replace the channel messages with the cleaned up content self.openai_sessions[prompt_params.channel_id] = new_content - # Add any context from Weezerpedia API if needed - # if weezerpedia_context := self.get_weezerpedia_context( - # prompt_params.user_prompt, messages_in_this_channel - # ): - # new_content.append(weezerpedia_context) - # Append the user's message to the session new_content.append( {"role": "user", "content": prompt_params.user_prompt}) @@ -200,56 +208,15 @@ def fetch_openai_completion(self, prompt_params: PromptParams, num_messages_look new_content = new_content[-num_messages_lookback:] new_content = [system_message] + new_content - try: - completion = openai.chat.completions.create( - temperature=1.0, - max_tokens=500, - model="gpt-4o", - messages=new_content, - ) - - response_text = completion.choices[0].message.content + response_text = self._get_response_or_weezerpedia_function_call_results(new_content, True) + function_call_content = [{"role": "user", "content": f"Incorporate the following Weezerpedia entry into your response, to the extent it is relevant. \n {response_text}"}] + response_text = self._get_response_or_weezerpedia_function_call_results(new_content + function_call_content, False) + new_content.append( + {"role": "assistant", "content": response_text} + ) - new_content.append( - {"role": "assistant", "content": response_text} - ) - except openai.APIError as e: - response_text = f"An error occurred: {e}" - except Exception as e: - response_text = f"An error occurred: {e}" return response_text - def get_weezerpedia_context(self, incoming_message_text, messages_in_this_channel) -> dict: - - # prepend the last 1 or 2 messages in this channel to the incoming message (if they exist) - if len(messages_in_this_channel) > 1: - last_message = messages_in_this_channel[-1]["content"] - incoming_message_text = f"{last_message}\n{incoming_message_text}" - if len(messages_in_this_channel) > 2: - penultimate_message = messages_in_this_channel[-2]["content"] - incoming_message_text = f"{penultimate_message}\n{incoming_message_text}" - - decision_text = self.should_query_weezerpedia_api(incoming_message_text - ) - - weezerpedia_context = None - if decision_text and decision_text.startswith("API NEEDED"): - - query_term = decision_text.split("API NEEDED:")[1].strip() - - # print(self.weezerpedia_api) - # print(self.weezerpedia_api.get_search_result_knowledge) - # print(self.weezerpedia_api.base_url) - - if wiki_content := self.weezerpedia_api.get_search_result_knowledge( - search_query=query_term - ): - weezerpedia_context = { - "role": "system", "content": f"API result for '{query_term}': {wiki_content}" - } - - return weezerpedia_context - def append_any_images(self, attachment_urls: list[str], content: list[dict[str, Any]]): for url in attachment_urls: if any(ext in url for ext in ['.jpg', '.jpeg', '.png', '.gif']): diff --git a/bot/on_message/bots/weezerpedia.py b/bot/on_message/bots/weezerpedia.py index 380d506..d8f408d 100644 --- a/bot/on_message/bots/weezerpedia.py +++ b/bot/on_message/bots/weezerpedia.py @@ -120,7 +120,7 @@ def preprocess_query(self, query): return ' '.join(key_terms) # Main method to get the knowledge to be used as context for GPT - def get_search_result_knowledge(self, search_query="Songs from the Black Hole"): + def get_search_result_knowledge(self, search_query, remove_urls): logging.info(f"Original query: {search_query}") # Step 1: Preprocess the query @@ -166,7 +166,7 @@ def get_search_result_knowledge(self, search_query="Songs from the Black Hole"): infobox = InfoboxGenerator(full_content, self) img_file = infobox.generate_infobox() - md_content = wiki_to_markdown(full_content) + md_content = wiki_to_markdown(full_content, remove_urls) if len(md_content) > 2000: md_content = md_content[:2000] diff --git a/bot/scripts/wiki_to_markdown.py b/bot/scripts/wiki_to_markdown.py index 4e3b04f..aaffa2b 100644 --- a/bot/scripts/wiki_to_markdown.py +++ b/bot/scripts/wiki_to_markdown.py @@ -2,7 +2,7 @@ import urllib.parse -def wiki_to_markdown(text, wiki_url_prefix='https://www.weezerpedia.com/wiki/'): +def wiki_to_markdown(text, remove_urls, wiki_url_prefix='https://www.weezerpedia.com/wiki/'): # Remove everything from "See also" or "References" downward text = re.split(r"==See also==", text)[0] @@ -50,6 +50,10 @@ def wiki_to_markdown(text, wiki_url_prefix='https://www.weezerpedia.com/wiki/'): # Trim any leading or trailing whitespaces text = text.strip() + if remove_urls: + url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+' + text = re.sub(url_pattern, '', text) + return text diff --git a/bot/slash_commands/commands.py b/bot/slash_commands/commands.py index b2e4601..605d257 100644 --- a/bot/slash_commands/commands.py +++ b/bot/slash_commands/commands.py @@ -34,7 +34,7 @@ async def predicate(interaction: discord.Interaction) -> bool: async def weezerpedia(interaction: discord.Interaction, search_term: str): # Your function to query the Weezerpedia API await interaction.response.defer() - result, img_file = weezerpedia_api.get_search_result_knowledge(search_term) + result, img_file = weezerpedia_api.get_search_result_knowledge(search_term, False) if result is None: await interaction.followup.send("No results found.") else: diff --git a/tests/test_weezerpedia.py b/tests/test_weezerpedia.py index 361163c..306b8f4 100644 --- a/tests/test_weezerpedia.py +++ b/tests/test_weezerpedia.py @@ -11,7 +11,7 @@ def test(): api = WeezerpediaAPI() knowledge, img = api.get_search_result_knowledge( - "bokkus") + "bokkus", False) print(knowledge) diff --git a/tests/unit_tests/test_wiki_to_markdown.py b/tests/unit_tests/test_wiki_to_markdown.py index f30b90f..ac509c3 100644 --- a/tests/unit_tests/test_wiki_to_markdown.py +++ b/tests/unit_tests/test_wiki_to_markdown.py @@ -2,7 +2,7 @@ def test_internal_wiki_link_only_label(): - assert wiki_to_markdown('[[Ecce Homo]]') == '[Ecce Homo](https://www.weezerpedia.com/wiki/Ecce%20Homo)' + assert wiki_to_markdown('[[Ecce Homo]]', False) == '[Ecce Homo](https://www.weezerpedia.com/wiki/Ecce%20Homo)' def test_internal_wiki_link_with_label_and_path(): - assert wiki_to_markdown('[[Blue Album|the Blue Album]]') == '[the Blue Album](https://www.weezerpedia.com/wiki/Blue%20Album)' + assert wiki_to_markdown('[[Blue Album|the Blue Album]]', False) == '[the Blue Album](https://www.weezerpedia.com/wiki/Blue%20Album)'