diff --git a/README.md b/README.md index 27d541b..f43a782 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,14 @@ -# InSightful +# InSightful-rerank -The AI assistant for tech communities. +Enhancing the original InSightful with a reranker. ## Features -- [✔️] **Conversation Analysis**: InSightful can analyze and provide insights on the topics being discussed in a tech community. -- [✔️] **Community Health Analysis**: InSightful can analyze the engagement, sentiment, and more of a tech community. -- [✔️] **Search Stack Overflow**: InSightful can search Stack Overflow for relevant questions and answers. -- [✔️] **Browse The Web**: InSightful can browse the web for relevant information on community topics. +Everything InSightful can do but better. By utilizing and exploiting the methods of Advanced RAG using a reranker, we significantly improve the quality of retrieved context from the vector store. ## Overview of workflow -![RAG-FC](https://github.com/user-attachments/assets/456b8dfa-58c9-4894-b720-f662cffded2f) - +![RAG-FC-Rerank](https://github.com/user-attachments/assets/f56de040-05e8-4307-be70-16929a72bafb) ## Prerequisites diff --git a/app.py b/app.py index d674d8f..ad801c1 100644 --- a/app.py +++ b/app.py @@ -20,6 +20,11 @@ from chromadb.config import Settings from chromadb.utils.embedding_functions import HuggingFaceEmbeddingServer +from langchain.schema import Document +from langchain.retrievers import ContextualCompressionRetriever +from tei_rerank import TEIRerank + + st.set_page_config(layout="wide", page_title="InSightful") # Set up Chroma DB client @@ -29,7 +34,9 @@ def setup_chroma_client(): host=os.getenv("VECTORDB_HOST", "localhost"), port=os.getenv("VECTORDB_PORT", "8000"), ), - settings=Settings(allow_reset=True), + settings=Settings(allow_reset=True, + anonymized_telemetry=False) + ) return client @@ -157,16 +164,21 @@ def chunk_doc(self, pages, chunk_size=512, chunk_overlap=30): print("Document chunked") return chunks - def 
insert_embeddings(self, chunks, chroma_embedding_function, embedder): + def insert_embeddings(self, chunks, chroma_embedding_function, embedder, batch_size=32): collection = self.db_client.get_or_create_collection( self.collection_name, embedding_function=chroma_embedding_function ) - for chunk in chunks: + for i in range(0, len(chunks), batch_size): + batch = chunks[i:i + batch_size] + chunk_ids = [str(uuid.uuid1()) for _ in batch] + metadatas = [chunk.metadata for chunk in batch] + documents = [chunk.page_content for chunk in batch] + collection.add( - ids=[str(uuid.uuid1())], - metadatas=chunk.metadata, - documents=chunk.page_content, - ) + ids=chunk_ids, + metadatas=metadatas, + documents=documents + ) db = Chroma( embedding_function=embedder, collection_name=self.collection_name, @@ -191,15 +203,33 @@ def query_docs(self, model, question, vector_store, prompt): def format_docs(docs): return "\n\n".join(doc.page_content for doc in docs) -def create_retriever(name, model, description, client, chroma_embedding_function, embedder): +#def create_retriever(name, model, description, client, chroma_embedding_function, embedder): +# rag = RAG(llm=model, embeddings=embedder, collection_name="Slack", db_client=client) +# pages = rag.load_documents("spencer/software_slacks") +# chunks = rag.chunk_doc(pages) +# vector_store = rag.insert_embeddings(chunks, chroma_embedding_function, embedder) +# retriever = vector_store.as_retriever( +# search_type="similarity", search_kwargs={"k": 10} +# ) +# info_retriever = create_retriever_tool(retriever, name, description) +# return info_retriever + +def create_reranker_retriever(name, model, description, client, chroma_embedding_function, embedder): rag = RAG(llm=model, embeddings=embedder, collection_name="Slack", db_client=client) pages = rag.load_documents("spencer/software_slacks", num_docs=100) chunks = rag.chunk_doc(pages) vector_store = rag.insert_embeddings(chunks, chroma_embedding_function, embedder) + compressor = 
TEIRerank(url="http://{host}:{port}".format(host=os.getenv("RERANKER_HOST", "localhost"), + port=os.getenv("RERANKER_PORT", "8082")), + top_n=10, + batch_size=16) retriever = vector_store.as_retriever( - search_type="similarity", search_kwargs={"k": 10} + search_type="similarity", search_kwargs={"k": 100} + ) + compression_retriever = ContextualCompressionRetriever( + base_compressor=compressor, base_retriever=retriever ) - info_retriever = create_retriever_tool(retriever, name, description) + info_retriever = create_retriever_tool(compression_retriever, name, description) return info_retriever @st.cache_resource @@ -210,15 +240,24 @@ def setup_tools(_model, _client, _chroma_embedding_function, _embedder): web_search_tool = TavilySearchResults(max_results=10, handle_tool_error=True) - retriever = create_retriever( - name="slack_retriever", + #retriever = create_retriever( + # name="Slack conversations retriever", + # model=_model, + # description="Retrieves conversations from Slack for context.", + # client=_client, + # chroma_embedding_function=_chroma_embedding_function, + # embedder=_embedder, + #) + reranker_retriever = create_reranker_retriever( + name="slack_conversations_retriever", model=_model, - description="Retrieves conversations from Slack for context.", + description="Useful for when you need to answer from Slack conversations.", client=_client, chroma_embedding_function=_chroma_embedding_function, embedder=_embedder, ) - return [web_search_tool, stackexchange_tool, retriever] + + return [web_search_tool, stackexchange_tool, reranker_retriever] def setup_agent(model, prompt, client, chroma_embedding_function, embedder): tools = setup_tools(model, client, chroma_embedding_function, embedder) diff --git a/tei_rerank.py b/tei_rerank.py new file mode 100644 index 0000000..c65debf --- /dev/null +++ b/tei_rerank.py @@ -0,0 +1,68 @@ +from typing import Dict, Optional, Sequence, List +from langchain_core.callbacks.manager import Callbacks +from 
import logging
from typing import Dict, List, Optional, Sequence

import requests
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.pydantic_v1 import Extra

logger = logging.getLogger(__name__)

DEFAULT_TOP_N = 3
DEFAULT_BATCH_SIZE = 32
# Upper bound, in seconds, on a single rerank HTTP call.
DEFAULT_TIMEOUT = 60.0


class TEIRerank(BaseDocumentCompressor):
    """Document compressor that reranks documents through an HTTP rerank service.

    Documents are sent to ``{url}/rerank`` in batches of ``batch_size`` texts,
    the scores of all batches are merged, and the ``top_n`` highest-scoring
    documents are returned with the score stored in their metadata under
    ``"relevance_score"``.
    """

    url: str
    """Base URL of the rerank service (``/rerank`` is appended)."""
    top_n: int = DEFAULT_TOP_N
    """Number of documents to return."""
    batch_size: int = DEFAULT_BATCH_SIZE
    """Maximum number of texts sent to the service in one request."""
    timeout: Optional[float] = DEFAULT_TIMEOUT
    """Seconds to wait for one request; ``None`` waits indefinitely."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def rerank(self, query: str, texts: List[str]) -> List[Dict]:
        """Score one batch of texts against ``query``.

        Args:
            query: The search query the texts are ranked against.
            texts: The texts of a single batch.

        Returns:
            The decoded JSON body of the service response. ``compress_documents``
            reads the ``"index"`` and ``"score"`` keys of each entry.

        Raises:
            RuntimeError: If the service answers with a status other than 200.
            requests.RequestException: On connection errors or when
                ``timeout`` elapses.
        """
        url = f"{self.url}/rerank"
        # NOTE(review): "batch_size" is forwarded exactly as before; confirm
        # that the deployed service actually reads this field.
        request_body = {
            "query": query,
            "texts": texts,
            "truncate": True,
            "batch_size": self.batch_size,
        }
        # Log sizes only: the payload holds full document texts.
        logger.debug("Rerank request to %s with %d texts", url, len(texts))
        response = requests.post(url, json=request_body, timeout=self.timeout)
        logger.debug("Rerank response status: %s", response.status_code)
        if response.status_code != 200:
            raise RuntimeError(
                "Failed to rerank documents, detail: "
                f"{response.status_code} {response.text}"
            )
        return response.json()

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Return the ``top_n`` documents most relevant to ``query``.

        Args:
            documents: Candidate documents, typically from a base retriever.
            query: The query to rank the documents against.
            callbacks: Accepted for interface compatibility; not used.

        Returns:
            New ``Document`` objects, best first, each carrying a copy of the
            original metadata plus ``"relevance_score"``. Empty if
            ``documents`` is empty.
        """
        if not documents:
            logger.debug("No documents to compress")
            return []

        texts = [doc.page_content for doc in documents]

        # (score, position in `documents`) for every scored text.
        scored = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start:start + self.batch_size]
            for result in self.rerank(query=query, texts=batch):
                # The service only sees this batch, so the index it reports is
                # a position inside the batch. Shift it by the batch offset to
                # address the right entry of `documents`.
                scored.append((result["score"], start + int(result["index"])))

        # Stable sort keeps the original order between equal scores.
        # Assumes scores from separate requests are on a comparable scale —
        # TODO confirm for the deployed model.
        scored.sort(key=lambda pair: pair[0], reverse=True)

        final_results = []
        for score, index in scored[:self.top_n]:
            source = documents[index]
            # Copy so the caller's documents are not mutated.
            metadata = source.metadata.copy()
            metadata["relevance_score"] = score
            final_results.append(
                Document(page_content=source.page_content, metadata=metadata)
            )

        return final_results