Skip to content

Commit

Permalink
Merge code from app and multirag
Browse files Browse the repository at this point in the history
Signed-off-by: Sanket <[email protected]>
  • Loading branch information
sanketsudake committed Sep 1, 2024
1 parent 15502c6 commit 8ae3284
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 96 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
__pycache__/**
insf_venv/**
*.pyc
.env/*
94 changes: 50 additions & 44 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import uuid
import datasets
import tempfile

from langchain_huggingface import HuggingFaceEndpointEmbeddings
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
Expand All @@ -19,6 +21,8 @@
from urllib3.exceptions import ProtocolError
from langchain.retrievers import ContextualCompressionRetriever
from transformers import AutoTokenizer
from unstructured.cleaners.core import clean_extra_whitespace, group_broken_paragraphs
from langchain_community.document_loaders import PyPDFLoader

from tools import get_tools
from tei_rerank import TEIRerank
Expand All @@ -29,10 +33,10 @@
import yaml
from yaml.loader import SafeLoader

from langchain.globals import set_verbose, set_debug
# from langchain.globals import set_verbose, set_debug

set_verbose(True)
set_debug(True)
# set_verbose(True)
# set_debug(True)

st.set_page_config(layout="wide", page_title="InSightful")

Expand Down Expand Up @@ -129,7 +133,8 @@ def setup_huggingface_embeddings():

@st.cache_resource
def load_prompt_and_system_ins(
template_file_path="templates/prompt_template.tmpl", template=None
template_file_path: str = "templates/prompt_template.tmpl",
template: str | None = None,
):
# prompt = hub.pull("hwchase17/react-chat")
prompt = PromptTemplate.from_file(template_file_path)
Expand All @@ -149,10 +154,11 @@ def load_prompt_and_system_ins(
return prompt, system_instructions


class RAG:
def __init__(self, collection_name, db_client):
self.collection_name = collection_name
class RAG(object):
def __init__(self, llm: ChatOpenAI, db_client, embedding_function):
self.llm = llm
self.db_client = db_client
self.embedding_function = embedding_function

@retry(
retry=retry_if_exception_type(ProtocolError),
Expand Down Expand Up @@ -182,14 +188,14 @@ def chunk_doc(self, pages, chunk_size=512, chunk_overlap=30):
print("Document chunked")
return chunks

def insert_embeddings(self, chunks, chroma_embedding_function, batch_size=32):
def insert_embeddings(self, chunks, collection_name, batch_size=32):
print(
"Inserting embeddings into collection: {collection_name}".format(
collection_name=self.collection_name
collection_name=collection_name
)
)
collection = self.db_client.get_or_create_collection(
self.collection_name, embedding_function=chroma_embedding_function
collection_name, embedding_function=self.embedding_function
)
for i in range(0, len(chunks), batch_size):
batch = chunks[i : i + batch_size]
Expand Down Expand Up @@ -219,44 +225,39 @@ def get_retriever(self, vector_store, use_reranker=False):
return retriever

def query_docs(
self, model, question, vector_store, prompt, chat_history, use_reranker=False
self, question, vector_store, prompt, chat_history, use_reranker=False
):
retriever = self.get_retriever(vector_store, use_reranker)
pass_question = lambda input: input["question"]
rag_chain = (
RunnablePassthrough.assign(context=pass_question | retriever | format_docs)
| prompt
| model
| self.llm
| StrOutputParser()
)

return rag_chain.stream({"question": question, "chat_history": chat_history})

def load_pdf(self, doc):
with tempfile.NamedTemporaryFile(
delete=False, suffix=os.path.splitext(doc.name)[1]
) as tmp:
tmp.write(doc.getvalue())
tmp_path = tmp.name
loader = PyPDFLoader(tmp_path)
documents = loader.load()
cleaned_pages = []
for doc in documents:
doc.page_content = clean_extra_whitespace(doc.page_content)
doc.page_content = group_broken_paragraphs(doc.page_content)
cleaned_pages.append(doc)
return cleaned_pages


def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)


def create_retriever(
name, description, client, chroma_embedding_function, embedding_svc, reranker=False
):
collection_name = "software-slacks"
rag = RAG(collection_name=collection_name, db_client=client)
pages = rag.load_documents("spencer/software_slacks", num_docs=100)
chunks = rag.chunk_doc(pages)
rag.insert_embeddings(chunks, chroma_embedding_function)
vector_store = Chroma(
embedding_function=embedding_svc,
collection_name=collection_name,
client=client,
)
retriever = rag.get_retriever(vector_store, use_reranker=reranker)

retriever = vector_store.as_retriever(
search_type="similarity", search_kwargs={"k": 10}
)
return create_retriever_tool(retriever, name, description)

@st.cache_resource
def setup_agent(_model, _prompt, _tools):
agent = create_react_agent(
Expand All @@ -280,17 +281,25 @@ def main():
model = setup_chat_endpoint()
embedder = setup_huggingface_embeddings()
use_reranker = os.getenv("USE_RERANKER", "False") == "True"

retriever_tool = create_retriever(
"slack_conversations_retriever",
"Useful for when you need to answer from Slack conversations.",
client,
chroma_embedding_function,
embedder,
reranker=use_reranker,
rag = RAG(llm=model, db_client=client, embedding_function=chroma_embedding_function)
collection_name = "software-slacks"
pages = rag.load_documents("spencer/software_slacks", num_docs=100)
chunks = rag.chunk_doc(pages)
rag.insert_embeddings(chunks, collection_name)
vector_store = Chroma(
embedding_function=embedder,
collection_name=collection_name,
client=client,
)
retriever = rag.get_retriever(vector_store, use_reranker=use_reranker)
_tools = get_tools()
_tools.append(retriever_tool)
_tools.append(
create_retriever_tool(
retriever,
"slack_conversations_retriever",
"Useful for when you need to answer from Slack conversations.",
)
)

agent_executor = setup_agent(model, prompt, _tools)

Expand Down Expand Up @@ -328,7 +337,4 @@ def main():


if __name__ == "__main__":
# authenticator = authenticate()
# if st.session_state['authentication_status']:
# authenticator.logout()
main()
75 changes: 24 additions & 51 deletions multi_tenant_rag.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import os
import logging
import tempfile
import os
import yaml
from yaml.loader import SafeLoader
import streamlit as st
import streamlit_authenticator as stauth
from streamlit_authenticator.utilities import RegisterError
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores.chroma import Chroma
from unstructured.cleaners.core import clean_extra_whitespace, group_broken_paragraphs
from langchain_chroma import Chroma

from tools import get_tools

from app import (
Expand All @@ -21,17 +20,19 @@
setup_agent,
)


SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"


logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
log: logging.Logger = logging.getLogger(__name__)


def configure_authenticator():
auth_config = os.getenv("AUTH_CONFIG_FILE_PATH", default=".streamlit/config.yaml")
print(f"auth_config: {auth_config}")
log.info(f"auth_config: {auth_config}")
with open(file=auth_config) as file:
config = yaml.load(file, Loader=SafeLoader)

Expand Down Expand Up @@ -67,49 +68,32 @@ def authenticate(op):
return authenticator


class MultiTenantRAG(RAG):
def __init__(self, user_id, collection_name, db_client):
self.user_id = user_id
super().__init__(collection_name, db_client)

def load_documents(self, doc):
with tempfile.NamedTemporaryFile(
delete=False, suffix=os.path.splitext(doc.name)[1]
) as tmp:
tmp.write(doc.getvalue())
tmp_path = tmp.name
loader = PyPDFLoader(tmp_path)
documents = loader.load()
cleaned_pages = []
for doc in documents:
doc.page_content = clean_extra_whitespace(doc.page_content)
doc.page_content = group_broken_paragraphs(doc.page_content)
cleaned_pages.append(doc)
return cleaned_pages


def main():
authenticator = authenticate("login")
if st.session_state["authentication_status"]:
st.sidebar.text(f"Welcome {st.session_state['username']}")
authenticator.logout(location="sidebar")
user_id = st.session_state["username"]
if not user_id:
st.error("Please login to continue")
return

use_reranker = st.sidebar.toggle("Use reranker", False)
use_tools = st.sidebar.toggle("Use tools", False)
uploaded_file = st.sidebar.file_uploader("Upload a document", type=["pdf"])
question = st.chat_input("Chat with your docs or apis")

llm = setup_chat_endpoint()

embedding_svc = setup_huggingface_embeddings()

chroma_embeddings = hf_embedding_server()

user_id = st.session_state["username"]

client = setup_chroma_client()

# Set up prompt template
template = """
Based on the retrieved context, respond with an accurate answer.
Be concise and always provide accurate, specific, and relevant information.
"""

template_file_path = "templates/multi_tenant_rag_prompt_template.tmpl"
if use_tools:
template_file_path = "templates/multi_tenant_rag_prompt_template_tools.tmpl"
Expand All @@ -118,6 +102,7 @@ def main():
template_file_path=template_file_path,
template=template,
)
log.info(f"prompt: {prompt} system_instructions: {system_instructions}")

chat_history = st.session_state.get(
"chat_history", [{"role": SYSTEM, "content": system_instructions.content}]
Expand All @@ -127,38 +112,31 @@ def main():
with st.chat_message(message["role"]):
st.markdown(message["content"])

if not user_id:
st.error("Please login to continue")
return

collection = client.get_or_create_collection(
f"user-collection-{user_id}", embedding_function=chroma_embeddings
)

logger = logging.getLogger(__name__)
logger.info(
log.info(
f"user_id: {user_id} use_reranker: {use_reranker} use_tools: {use_tools} question: {question}"
)
rag = MultiTenantRAG(user_id, collection.name, client)
rag = RAG(llm=llm, db_client=client, embedding_function=chroma_embeddings)

if use_tools:
tools = get_tools()
agent_executor = setup_agent(llm, prompt, tools)

# prompt = hub.pull("rlm/rag-prompt")

vectorstore = Chroma(
embedding_function=embedding_svc,
collection_name=collection.name,
client=client,
)

if uploaded_file:
document = rag.load_documents(uploaded_file)
document = rag.load_pdf(uploaded_file)
chunks = rag.chunk_doc(document)
rag.insert_embeddings(
chunks=chunks,
chroma_embedding_function=chroma_embeddings,
collection_name=collection.name,
batch_size=32,
)

Expand All @@ -174,10 +152,9 @@ def main():
)["output"]
with st.chat_message(ASSISTANT):
st.write(answer)
logger.info(f"answer: {answer}")
log.info(f"answer: {answer}")
else:
answer = rag.query_docs(
model=llm,
question=question,
vector_store=vectorstore,
prompt=prompt,
Expand All @@ -186,16 +163,12 @@ def main():
)
with st.chat_message(ASSISTANT):
answer = st.write_stream(answer)
logger.info(f"answer: {answer}")
log.info(f"answer: {answer}")

chat_history.append({"role": USER, "content": question})
chat_history.append({"role": ASSISTANT, "content": answer})
st.session_state["chat_history"] = chat_history


if __name__ == "__main__":
authenticator = authenticate("login")
if st.session_state["authentication_status"]:
st.sidebar.text(f"Welcome {st.session_state['username']}")
authenticator.logout(location="sidebar")
main()
main()
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
chromadb==0.5.3
datasets==2.20.0
langchain==0.2.12
langchain_chroma==0.1.2
langchain_chroma==0.1.3
langchain_community==0.2.11
langchain_core==0.2.28
langchain_huggingface==0.0.3
Expand Down

0 comments on commit 8ae3284

Please sign in to comment.