Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: weaviate vector database feature added for openai #106

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/docs/examples/openai-bot.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@ from typing import List
# Load your OpenAI API key
OpenAI.api_key = ""

# Settings for using the Weaviate vector database
OpenAI.weaviate_host = "http://23.345.138.42:8080"
OpenAI.weaviate_data_class = "Documents"
# If you are using a free OpenAI key, cap the Weaviate response length to avoid token-limit errors
OpenAI.max_weaviate_res_length = 1000
# Prompt for GPT-3.5 Turbo
SYSTEM_PROMPT = """You are chatting with an AI. There are no specific prefixes for responses, so you can ask or talk about anything you like.
The AI will respond in a natural, conversational manner. Feel free to start the conversation with any question or topic, and let's have a
pleasant chat!
Use the information provided by vector database here : {vector_database_response}.
"""

@bot()
Expand Down
12 changes: 10 additions & 2 deletions examples/openai-bot/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,18 @@
# Load your OpenAI API key
OpenAI.api_key = ""

# Optional: set these only if you want to use a vector database.
# Textbase currently supports Weaviate.
OpenAI.weaviate_host = "http://12.345.138.42:8080"
OpenAI.weaviate_data_class = "Documents"
# If you are using a free OpenAI key, cap the Weaviate response length to avoid token-limit errors
OpenAI.max_weaviate_res_length = 1000


# Prompt for GPT-3.5 Turbo
SYSTEM_PROMPT = """You are chatting with an AI. There are no specific prefixes for responses, so you can ask or talk about anything you like.
The AI will respond in a natural, conversational manner. Feel free to start the conversation with any question or topic, and let's have a
pleasant chat!
pleasant chat!. Use the information provided by vector database here : {vector_database_response}.
"""

@bot()
Expand Down Expand Up @@ -41,4 +49,4 @@ def on_message(message_history: List[Message], state: dict = None):
return {
"status_code": 200,
"response": response
}
}
4 changes: 3 additions & 1 deletion textbase/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .bot import bot
from .message import Message
from .message import Message
from .message import Content
from .vector_database import WeaviateClass
18 changes: 13 additions & 5 deletions textbase/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import time
import typing
import traceback

from textbase import Message
from textbase import Message, WeaviateClass

# Return list of values of content.
def get_contents(message: Message, data_type: str):
Expand All @@ -28,7 +27,10 @@ def extract_content_values(message: Message):

class OpenAI:
api_key = None

weaviate_host = None
weaviate_auth_key = None
weaviate_data_class = None
max_weaviate_res_length = None
@classmethod
def generate(
cls,
Expand All @@ -40,7 +42,6 @@ def generate(
):
assert cls.api_key is not None, "OpenAI API key is not set."
openai.api_key = cls.api_key

filtered_messages = []

for message in message_history:
Expand All @@ -49,6 +50,14 @@ def generate(
if contents:
filtered_messages.extend(contents)

weaviate_response = None
# if weaviate_host provided get response from weaviate
if cls.weaviate_host :
weaviate_response = WeaviateClass.search_in_weaviate(cls.api_key,cls.weaviate_host,cls.weaviate_auth_key,cls.weaviate_data_class,message_history[-1],cls.max_weaviate_res_length,"X-OpenAI-Api-Key")
# Todo: support for other vector database

# append the vector databases result in system prompt for better answers
system_prompt= system_prompt.format(vector_database_response = weaviate_response)
response = openai.ChatCompletion.create(
model=model,
messages=[
Expand All @@ -61,7 +70,6 @@ def generate(
temperature=temperature,
max_tokens=max_tokens,
)

return response["choices"][0]["message"]["content"]

class HuggingFace:
Expand Down
50 changes: 50 additions & 0 deletions textbase/vector_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import weaviate
from textbase import Message,Content
import json
from langchain.vectorstores import Weaviate
from langchain.embeddings.openai import OpenAIEmbeddings

class WeaviateClass:
    """Helper for looking up context documents in a Weaviate vector database."""

    @classmethod
    def search_in_weaviate(
        cls,
        api_key: str,
        host: str,
        auth_key: str,
        weaviate_data_class,
        message_query: Message,
        max_weaviate_res_length: int,
        model_header_key: str,
    ):
        """Run a similarity search for the user's query and return the hits as a string.

        Args:
            api_key: OpenAI API key; used for the embeddings and also sent to
                Weaviate under the ``model_header_key`` header.
            host: URL of the Weaviate instance.
            auth_key: Weaviate API key, or ``None``/empty for an
                unauthenticated instance. A ready-made weaviate credentials
                object is passed through unchanged.
            weaviate_data_class: Name of the Weaviate class (index) to search.
            message_query: Message whose first content item holds the query
                text under the ``'value'`` key.
            max_weaviate_res_length: Maximum number of characters to return;
                a falsy value disables truncation.
            model_header_key: Name of the HTTP header that carries ``api_key``.

        Returns:
            A JSON-encoded list of the matching documents' ``page_content``,
            cut to ``max_weaviate_res_length`` characters when that limit is
            set and exceeded. A truncated result may no longer be valid JSON.
        """
        # weaviate.Client expects a credentials object rather than a bare
        # string, so wrap a plain API key before handing it over.
        if isinstance(auth_key, str):
            auth_credentials = weaviate.AuthApiKey(api_key=auth_key) if auth_key else None
        else:
            auth_credentials = auth_key

        weaviate_client = weaviate.Client(
            url=host,
            auth_client_secret=auth_credentials,
            additional_headers={
                model_header_key: api_key,
            },
        )

        # take out user query
        user_query = message_query['content'][0]['value']

        embeddings = OpenAIEmbeddings(openai_api_key=api_key)
        weaviate_vectorstore = Weaviate(
            weaviate_client,
            weaviate_data_class,
            "text",
            embedding=embeddings,
        )

        documents = weaviate_vectorstore.similarity_search(user_query)
        response_string = json.dumps([document.page_content for document in documents])

        # If the model reports a token-limit error, the user can set
        # max_weaviate_res_length to cap how much of the result is returned.
        if max_weaviate_res_length and len(response_string) > max_weaviate_res_length:
            response_string = response_string[:max_weaviate_res_length]
        return response_string