TEI Reranker adapter and app.py implementation
Cleaned up code

Update README.md

Update README.md

Revert "Reranker fork"

Update README.md

Refined prompt and instruction

TEI Reranker adapter and app.py implementation

Cleaned up code

chore: Update system message template in app.py
AIWithShrey authored and sanketsudake committed Jul 30, 2024
1 parent fc0a2c8 commit c79b590
Showing 3 changed files with 172 additions and 53 deletions.
12 changes: 4 additions & 8 deletions README.md
@@ -1,18 +1,14 @@
# InSightful
# InSightful-rerank

The AI assistant for tech communities.
Enhancing the original InSightful with a reranker.

## Features

- [✔️] **Conversation Analysis**: InSightful can analyze and provide insights on the topics being discussed in a tech community.
- [✔️] **Community Health Analysis**: InSightful can analyze the engagement, sentiment, and more of a tech community.
- [✔️] **Search Stack Overflow**: InSightful can search Stack Overflow for relevant questions and answers.
- [✔️] **Browse The Web**: InSightful can browse the web for relevant information on community topics.
Everything InSightful can do, but better. By applying Advanced RAG techniques with a reranker, we significantly improve the quality of the context retrieved from the vector store.

## Overview of workflow

![RAG-FC](https://github.com/user-attachments/assets/456b8dfa-58c9-4894-b720-f662cffded2f)

![RAG-FC-Rerank](https://github.com/user-attachments/assets/f56de040-05e8-4307-be70-16929a72bafb)

## Prerequisites

145 changes: 100 additions & 45 deletions app.py
@@ -14,11 +14,17 @@
from langchain_core.messages import SystemMessage
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_chroma import Chroma
import chromadb
from chromadb.config import Settings
from chromadb.utils.embedding_functions import HuggingFaceEmbeddingServer

from langchain.schema import Document
from langchain.retrievers import ContextualCompressionRetriever
from tei_rerank import TEIRerank


st.set_page_config(layout="wide", page_title="InSightful")

# Set up Chroma DB client
@@ -28,7 +34,9 @@ def setup_chroma_client():
host=os.getenv("VECTORDB_HOST", "localhost"),
port=os.getenv("VECTORDB_PORT", "8000"),
),
settings=Settings(allow_reset=True),
settings=Settings(allow_reset=True,
anonymized_telemetry=False)

)
return client

@@ -67,42 +75,58 @@ def setup_huggingface_embeddings():
return embedder

def load_prompt_and_system_ins():
prompt = hub.pull("hwchase17/react-chat")
# Set up prompt template
template = """
You are InSightful, a virtual assistant designed to help users with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model,
you are able to generate human-like text based on the input you receive, allowing you to engage in
natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
#prompt = hub.pull("hwchase17/react-chat")
prompt = PromptTemplate.from_template("""
InSightful is a bot developed by InfraCloud Technologies.
Always provide accurate and informative responses to a wide range of questions.
InSightful is used to assist technical communities online on platforms such as Slack, Reddit and Discord.
You can assess the health of a conversation from the engagement and understand the sentiment of the conversations on Slack.
InSightful can answer questions from conversations amongst community members and can also search StackOverflow for technical questions.
You can assess if people are generally interested or disinterested in the conversation, and you can also determine if the conversation is positive or negative.
InSightful can also conduct its own research on the web to find answers to questions.
You do not answer questions about personal information, such as social security numbers,
credit card numbers, or other sensitive information. You also do not provide medical, legal, or financial advice.
InSightful is designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. InSightful is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
InSightful is constantly learning and improving, and its capabilities are constantly evolving. It is able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions. Additionally, InSightful is able to generate its own text based on the input it receives, allowing it to engage in discussions and provide explanations and descriptions on a wide range of topics.
Overall, InSightful is a powerful tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether you need help with a specific question or just want to have a conversation about a particular topic, InSightful is here to assist.
TOOLS:
------
You will not respond to any questions that are inappropriate or offensive. You are friendly, helpful,
and you are here to assist users with any questions they may have.
InSightful has access to the following tools:
Keep your answers clear and concise, and provide as much information as possible to help users understand the topic.
{tools}
Use your best judgement and only use any tool if you absolutely need to.
To use a tool, please use the following format:
Tools provide you with more context and up-to-date information. Use them to your advantage.
```
Thought: Do I need to use a tool? Yes
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
```
For the tavily_search_results_json tool, make sure the Action Input is a string derived from the new input.
When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
Use the tools for any current information. Make sure to use the tools to verify your answers as well.
```
Thought: Do I need to use a tool? No
Final Answer: [your response here]
```
Thought: {thought}
Action: {action}
Action Input: {action_input}
Observation: {observation}
Begin!
If you are ready with an answer use the format:
Thought: Do I have to use a tool? No
Final Answer: {observation}
Previous conversation history:
{chat_history}
New input: {input}
{agent_scratchpad}
""")

# Set up prompt template
template = """
Based on the retrieved context, respond with an accurate answer. Use the provided tools to support your response.
"""

system_instructions = SystemMessage(
@@ -119,9 +143,9 @@ def __init__(self, llm, embeddings, collection_name, db_client):
self.collection_name = collection_name
self.db_client = db_client

def load_documents(self, doc):
def load_documents(self, doc, num_docs=250):
documents = []
for data in datasets.load_dataset(doc, split="train[:500]").to_list():
for data in datasets.load_dataset(doc, split=f"train[:{num_docs}]").to_list():
documents.append(
Document(
page_content=data["text"],
Expand All @@ -140,16 +164,21 @@ def chunk_doc(self, pages, chunk_size=512, chunk_overlap=30):
print("Document chunked")
return chunks

def insert_embeddings(self, chunks, chroma_embedding_function, embedder):
def insert_embeddings(self, chunks, chroma_embedding_function, embedder, batch_size=32):
collection = self.db_client.get_or_create_collection(
self.collection_name, embedding_function=chroma_embedding_function
)
for chunk in chunks:
for i in range(0, len(chunks), batch_size):
batch = chunks[i:i + batch_size]
chunk_ids = [str(uuid.uuid1()) for _ in batch]
metadatas = [chunk.metadata for chunk in batch]
documents = [chunk.page_content for chunk in batch]

collection.add(
ids=[str(uuid.uuid1())],
metadatas=chunk.metadata,
documents=chunk.page_content,
)
ids=chunk_ids,
metadatas=metadatas,
documents=documents
)
db = Chroma(
embedding_function=embedder,
collection_name=self.collection_name,
@@ -174,35 +203,61 @@ def query_docs(self, model, question, vector_store, prompt):
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)

def create_retriever(name, model, description, client, chroma_embedding_function, embedder):
#def create_retriever(name, model, description, client, chroma_embedding_function, embedder):
# rag = RAG(llm=model, embeddings=embedder, collection_name="Slack", db_client=client)
# pages = rag.load_documents("spencer/software_slacks")
# chunks = rag.chunk_doc(pages)
# vector_store = rag.insert_embeddings(chunks, chroma_embedding_function, embedder)
# retriever = vector_store.as_retriever(
# search_type="similarity", search_kwargs={"k": 10}
# )
# info_retriever = create_retriever_tool(retriever, name, description)
# return info_retriever

def create_reranker_retriever(name, model, description, client, chroma_embedding_function, embedder):
rag = RAG(llm=model, embeddings=embedder, collection_name="Slack", db_client=client)
pages = rag.load_documents("spencer/software_slacks")
pages = rag.load_documents("spencer/software_slacks", num_docs=100)
chunks = rag.chunk_doc(pages)
vector_store = rag.insert_embeddings(chunks, chroma_embedding_function, embedder)
compressor = TEIRerank(url="http://{host}:{port}".format(host=os.getenv("RERANKER_HOST", "localhost"),
port=os.getenv("RERANKER_PORT", "8082")),
top_n=10,
batch_size=16)
retriever = vector_store.as_retriever(
search_type="similarity", search_kwargs={"k": 10}
search_type="similarity", search_kwargs={"k": 100}
)
info_retriever = create_retriever_tool(retriever, name, description)
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=retriever
)
info_retriever = create_retriever_tool(compression_retriever, name, description)
return info_retriever

@st.cache_resource
def setup_tools(_model, _client, _chroma_embedding_function, _embedder):
stackexchange_wrapper = StackExchangeAPIWrapper(max_results=3)
stackexchange_tool = StackExchangeTool(api_wrapper=stackexchange_wrapper)

web_search_tool = TavilySearchResults(max_results=5,
search_depth = "advanced",
include_answer=True)

retriever = create_retriever(
name="Slack conversations retriever",
web_search_tool = TavilySearchResults(max_results=10,
handle_tool_error=True)

#retriever = create_retriever(
# name="Slack conversations retriever",
# model=_model,
# description="Retrieves conversations from Slack for context.",
# client=_client,
# chroma_embedding_function=_chroma_embedding_function,
# embedder=_embedder,
#)
reranker_retriever = create_reranker_retriever(
name="slack_conversations_retriever",
model=_model,
description="Retrieves conversations from Slack for context.",
description="Useful for when you need to answer from Slack conversations.",
client=_client,
chroma_embedding_function=_chroma_embedding_function,
embedder=_embedder,
)
return [web_search_tool, stackexchange_tool, retriever]

return [web_search_tool, stackexchange_tool, reranker_retriever]

def setup_agent(model, prompt, client, chroma_embedding_function, embedder):
tools = setup_tools(model, client, chroma_embedding_function, embedder)
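For orientation, here is a minimal sketch of the over-retrieve-then-rerank wiring that `create_reranker_retriever` sets up in this diff: fetch a wide candidate pool from Chroma (k=100), then let the TEI reranker compress it to the ten best chunks. It assumes `vector_store` is the populated Chroma store returned by `RAG.insert_embeddings`; the host/port defaults and the query string are illustrative, not part of the commit.

```python
import os

from langchain.retrievers import ContextualCompressionRetriever
from tei_rerank import TEIRerank

# Reranker client; address defaults mirror the diff but are assumptions here.
compressor = TEIRerank(
    url="http://{host}:{port}".format(
        host=os.getenv("RERANKER_HOST", "localhost"),
        port=os.getenv("RERANKER_PORT", "8082"),
    ),
    top_n=10,       # keep only the 10 best-scoring chunks
    batch_size=16,  # rerank candidates 16 at a time
)

# Over-retrieve so the reranker has a wide pool to choose from...
base_retriever = vector_store.as_retriever(
    search_type="similarity", search_kwargs={"k": 100}
)

# ...then compress the 100 candidates down to top_n before the agent sees them.
reranked_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=base_retriever
)
docs = reranked_retriever.invoke("What did the community say about deployments?")
```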
68 changes: 68 additions & 0 deletions tei_rerank.py
@@ -0,0 +1,68 @@
from typing import Dict, Optional, Sequence, List
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.pydantic_v1 import Extra
import requests

DEFAULT_TOP_N = 3
DEFAULT_BATCH_SIZE = 32

class TEIRerank(BaseDocumentCompressor):
"""Document compressor using a custom rerank service."""

url: str
"""URL of the custom rerank service."""
top_n: int = DEFAULT_TOP_N
"""Number of documents to return."""
batch_size: int = DEFAULT_BATCH_SIZE
"""Batch size to use for reranking."""

class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid

def rerank(self, query: str, texts: List[str]) -> List[Dict]:
url = f"{self.url}/rerank"
print(f"URL: {url}")
request_body = {"query": query, "texts": texts, "truncate": True, "batch_size": self.batch_size}
print(f"Request Body: {request_body}")
response = requests.post(url, json=request_body)
print(f"Response Status Code: {response.status_code}")
if response.status_code != 200:
print(f"Response Content: {response.content}")
raise RuntimeError(f"Failed to rerank documents, detail: {response}")
print(f"Response JSON: {response.json()}")
return response.json()

def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
print("compress_documents called")
if not documents:
print("No documents to compress")
return []

texts = [doc.page_content for doc in documents]
batches = [texts[i:i + self.batch_size] for i in range(0, len(texts), self.batch_size)]
all_results = []

for batch_num, batch in enumerate(batches):
    results = self.rerank(query=query, texts=batch)
    # The service scores each batch independently, so the returned
    # indices are relative to the submitted batch; offset them so they
    # index back into the full document list.
    offset = batch_num * self.batch_size
    for result in results:
        result["index"] = int(result["index"]) + offset
    all_results.extend(results)

# Sort results based on scores and select top_n
all_results = sorted(all_results, key=lambda x: x["score"], reverse=True)[:self.top_n]

final_results = []
for result in all_results:
index = int(result["index"])
metadata = documents[index].metadata.copy()
metadata["relevance_score"] = result["score"]
final_results.append(
Document(page_content=documents[index].page_content, metadata=metadata)
)

return final_results

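The adapter above talks to a Text Embeddings Inference (TEI) style `/rerank` endpoint. Below is a hedged sketch of the HTTP contract it relies on, with request and response shapes inferred from the adapter code; the address, query, texts, and scores are invented for illustration.

```python
import requests

# POST the query plus a batch of candidate texts to the rerank endpoint.
resp = requests.post(
    "http://localhost:8082/rerank",  # assumed reranker address
    json={
        "query": "How do I rotate Slack tokens?",
        "texts": ["chunk one ...", "chunk two ..."],
        "truncate": True,  # let the server truncate over-long inputs
    },
)
resp.raise_for_status()

# The adapter expects one {"index": ..., "score": ...} entry per input text,
# where "index" is the position of the text in the submitted batch, e.g.:
#   [{"index": 1, "score": 0.93}, {"index": 0, "score": 0.12}]
for entry in resp.json():
    print(entry["index"], entry["score"])
```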