diff --git a/docs/docs/examples/openai-bot.md b/docs/docs/examples/openai-bot.md
index a7df7b7c..b981eff3 100644
--- a/docs/docs/examples/openai-bot.md
+++ b/docs/docs/examples/openai-bot.md
@@ -14,10 +14,16 @@ from typing import List
 # Load your OpenAI API key
 OpenAI.api_key = ""
+# For using the Weaviate vector database
+OpenAI.weaviate_host = "http://23.345.138.42:8080"
+OpenAI.weaviate_data_class = "Documents"
+# If you are using a free OpenAI key, set a limit to avoid rate-limit errors
+OpenAI.max_weaviate_res_length = 1000
 
 # Prompt for GPT-3.5 Turbo
 SYSTEM_PROMPT = """You are chatting with an AI. There are no specific prefixes for responses, so you can ask or talk about anything you like.
 The AI will respond in a natural, conversational manner. Feel free to start the conversation with any question or topic, and let's have a
 pleasant chat!
+Use the information provided by the vector database here: {vector_database_response}.
 """
 
 @bot()
diff --git a/examples/openai-bot/main.py b/examples/openai-bot/main.py
index ebc25837..fd62c7a1 100644
--- a/examples/openai-bot/main.py
+++ b/examples/openai-bot/main.py
@@ -5,10 +5,18 @@
 # Load your OpenAI API key
 OpenAI.api_key = ""
 
+# Optional: set these if you want to use a vector database.
+# Currently Textbase supports Weaviate.
+OpenAI.weaviate_host = "http://12.345.138.42:8080"
+OpenAI.weaviate_data_class = "Documents"
+# If you are using a free OpenAI key, set a limit to avoid rate-limit errors
+OpenAI.max_weaviate_res_length = 1000
+
+
 # Prompt for GPT-3.5 Turbo
 SYSTEM_PROMPT = """You are chatting with an AI. There are no specific prefixes for responses, so you can ask or talk about anything you like.
 The AI will respond in a natural, conversational manner. Feel free to start the conversation with any question or topic, and let's have a
-pleasant chat!
+pleasant chat! Use the information provided by the vector database here: {vector_database_response}.
 """
 
 @bot()
@@ -41,4 +49,4 @@ def on_message(message_history: List[Message], state: dict = None):
     return {
         "status_code": 200,
         "response": response
-    }
\ No newline at end of file
+    }
diff --git a/textbase/__init__.py b/textbase/__init__.py
index a19f84ae..5e923b6a 100644
--- a/textbase/__init__.py
+++ b/textbase/__init__.py
@@ -1,2 +1,4 @@
 from .bot import bot
-from .message import Message
+from .message import Message
+from .message import Content
+from .vector_database import WeaviateClass
\ No newline at end of file
diff --git a/textbase/models.py b/textbase/models.py
index 814ed533..e3184736 100644
--- a/textbase/models.py
+++ b/textbase/models.py
@@ -4,8 +4,7 @@
 import time
 import typing
 import traceback
-
-from textbase import Message
+from textbase import Message, WeaviateClass
 
 # Return list of values of content.
 def get_contents(message: Message, data_type: str):
@@ -28,7 +27,10 @@ def extract_content_values(message: Message):
 
 class OpenAI:
     api_key = None
-
+    weaviate_host = None
+    weaviate_auth_key = None
+    weaviate_data_class = None
+    max_weaviate_res_length = None
     @classmethod
     def generate(
         cls,
@@ -40,7 +42,6 @@
     ):
         assert cls.api_key is not None, "OpenAI API key is not set."
         openai.api_key = cls.api_key
-
         filtered_messages = []
 
         for message in message_history:
@@ -49,6 +50,14 @@
             if contents:
                 filtered_messages.extend(contents)
 
+        weaviate_response = None
+        # If a Weaviate host is configured, fetch relevant context from it
+        if cls.weaviate_host:
+            weaviate_response = WeaviateClass.search_in_weaviate(cls.api_key, cls.weaviate_host, cls.weaviate_auth_key, cls.weaviate_data_class, message_history[-1], cls.max_weaviate_res_length, "X-OpenAI-Api-Key")
+            # TODO: add support for other vector databases
+
+            # Inject the vector database results into the system prompt for better answers
+            system_prompt = system_prompt.format(vector_database_response=weaviate_response)
         response = openai.ChatCompletion.create(
             model=model,
             messages=[
@@ -61,7 +70,6 @@
             temperature=temperature,
             max_tokens=max_tokens,
         )
-
         return response["choices"][0]["message"]["content"]
 
 class HuggingFace:
diff --git a/textbase/vector_database.py b/textbase/vector_database.py
new file mode 100644
index 00000000..2b6c32c8
--- /dev/null
+++ b/textbase/vector_database.py
@@ -0,0 +1,50 @@
+import weaviate
+from textbase import Message, Content
+import json
+from langchain.vectorstores import Weaviate
+from langchain.embeddings.openai import OpenAIEmbeddings
+
+class WeaviateClass:
+    @classmethod
+    def search_in_weaviate(
+        cls,
+        api_key: str,
+        host: str,
+        auth_key: str,
+        weaviate_data_class,
+        message_query: Message,
+        max_weaviate_res_length: int,
+        model_header_key: str,
+    ):
+        weaviate_client = weaviate.Client(
+            url=host,
+            auth_client_secret=auth_key,
+            additional_headers={
+                model_header_key: api_key,
+            }
+        )
+        # Extract the user query from the latest message
+        user_query = message_query['content'][0]['value']
+        embeddings = OpenAIEmbeddings(
+            openai_api_key=api_key
+        )
+
+        weaviate_vectorstore = Weaviate(
+            weaviate_client,
+            weaviate_data_class,
+            "text",
+            embedding=embeddings
+        )
+        weaviate_response = weaviate_vectorstore.similarity_search(user_query)
+        messages = []
+        for response in weaviate_response:
+            messages.append(response.page_content)
+        response_string = json.dumps(messages)
+
+        # If token-limit errors occur, max_weaviate_res_length caps how much retrieved text is returned
+        if max_weaviate_res_length and len(response_string) > max_weaviate_res_length:
+            response_string = response_string[:max_weaviate_res_length]
+        return response_string
+
+
\ No newline at end of file
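
Usage note: the new search_in_weaviate() helper assumes the configured data class (for example "Documents") already contains objects with a "text" property, since that is the property similarity_search() reads back. A minimal seeding sketch under that assumption is shown below; the host URL, API key, and sample texts are placeholders and are not part of the diff above.

import weaviate
from langchain.vectorstores import Weaviate
from langchain.embeddings.openai import OpenAIEmbeddings

# Hypothetical seeding sketch: host URL, API key, and texts are placeholders.
# The OpenAI key is passed in the same "X-OpenAI-Api-Key" header that models.py forwards.
client = weaviate.Client(
    url="http://localhost:8080",
    additional_headers={"X-OpenAI-Api-Key": "sk-..."},
)

embeddings = OpenAIEmbeddings(openai_api_key="sk-...")

# Same constructor arguments that vector_database.py uses: client, class name,
# and the "text" property that similarity_search() later returns as page_content.
store = Weaviate(client, "Documents", "text", embedding=embeddings)

# add_texts() embeds each string and stores it as an object of the "Documents" class.
store.add_texts([
    "Textbase is a framework for building chatbots.",
    "Weaviate is an open-source vector database.",
])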