Skip to content

Commit

Permalink
Finalize qdrant retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
homanp committed Jan 16, 2024
1 parent a3999ed commit 51a911f
Show file tree
Hide file tree
Showing 7 changed files with 225 additions and 54 deletions.
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
API_BASE_URL=https://rag.superagent.sh
COHERE_API_KEY=
COHERE_API_KEY=
HUGGINGFACE_API_KEY=
12 changes: 9 additions & 3 deletions api/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@

@router.post("/ingest")
async def ingest(payload: RequestPayload) -> Dict:
    """Ingest the payload's files into the configured vector database.

    Pipeline: download/parse files into documents, split them into chunks,
    then embed and upsert the chunks. Returns a bare success flag; the
    documents themselves are not echoed back.
    """
    embedding_service = EmbeddingService(
        files=payload.files,
        index_name=payload.index_name,
        # vector_database carries the provider type + credentials dict.
        vector_credentials=payload.vector_database,
    )
    documents = await embedding_service.generate_documents()
    chunks = await embedding_service.generate_chunks(documents=documents)
    await embedding_service.generate_embeddings(nodes=chunks)
    return {"success": True}
17 changes: 17 additions & 0 deletions api/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Dict
from fastapi import APIRouter
from models.query import RequestPayload, ResponsePayload
from service.vector_database import get_vector_service, VectorService

router = APIRouter()


@router.post("/query", response_model=ResponsePayload)
async def query(payload: RequestPayload):
    """Search the vector store for the payload's input and return reranked hits."""
    service: VectorService = get_vector_service(
        index_name=payload.index_name, credentials=payload.vector_database
    )
    # Fetch a small candidate set, normalize it, then let the reranker order it.
    matches = await service.query(input=payload.input, top_k=4)
    candidate_docs = await service.convert_to_dict(points=matches)
    reranked = await service.rerank(query=payload.input, documents=candidate_docs)
    return {"success": True, "data": reranked}
14 changes: 14 additions & 0 deletions models/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from pydantic import BaseModel
from typing import List
from models.vector_database import VectorDatabase


class RequestPayload(BaseModel):
    """Request body for the query endpoint."""

    # Free-text query to embed and search with.
    input: str
    # Provider type plus credentials/config for the target vector store.
    vector_database: VectorDatabase
    # Index / collection name to search in.
    index_name: str


class ResponsePayload(BaseModel):
    """Response body for the query endpoint."""

    # True when the query completed without error.
    success: bool
    # Reranked result documents (shape defined by the vector service).
    data: List
3 changes: 2 additions & 1 deletion router.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from fastapi import APIRouter

from api import ingest
from api import ingest, query

# Root router; endpoint modules below attach their sub-routers to it.
router = APIRouter()
# Version prefix shared by every mounted route.
api_prefix = "/api/v1"

router.include_router(ingest.router, tags=["Ingest"], prefix=api_prefix)
router.include_router(query.router, tags=["Query"], prefix=api_prefix)
58 changes: 52 additions & 6 deletions service/embedding.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import requests
import asyncio

from typing import List
from fastapi import UploadFile
from typing import Any, List, Union
from tempfile import NamedTemporaryFile
from llama_index import Document, SimpleDirectoryReader
from llama_index.node_parser import SimpleNodeParser
from litellm import aembedding
from models.file import File
from decouple import config
from service.vector_database import get_vector_service


class EmbeddingService:
def __init__(self, files: List[File], index_name: str):
def __init__(self, files: List[File], index_name: str, vector_credentials: dict):
    """Hold the files to ingest plus the target index and its credentials."""
    self.files = files
    self.index_name = index_name
    # Passed through to get_vector_service when embeddings are upserted.
    self.vector_credentials = vector_credentials

def _get_datasource_suffix(self, type: str) -> str:
suffixes = {"TXT": ".txt", "PDF": ".pdf", "MARKDOWN": ".md"}
Expand All @@ -20,16 +24,58 @@ def _get_datasource_suffix(self, type: str) -> str:
except KeyError:
raise ValueError("Unsupported datasource type")

async def generate_documents(self):
async def generate_documents(self) -> List[Document]:
    """Download each file to a temp path and parse it into llama-index Documents.

    Each document is tagged with the originating file URL in its metadata so
    query results can be traced back to their source.
    """
    documents = []
    for file in self.files:
        suffix = self._get_datasource_suffix(file.type.value)
        # delete=True: the temp file is removed as soon as the reader is done.
        with NamedTemporaryFile(suffix=suffix, delete=True) as temp_file:
            response = requests.get(url=file.url)
            temp_file.write(response.content)
            temp_file.flush()  # ensure bytes hit disk before the reader opens it
            reader = SimpleDirectoryReader(input_files=[temp_file.name])
            docs = reader.load_data()
            for doc in docs:
                doc.metadata["file_url"] = file.url
            # extend (not append): keep a flat List[Document], not a list of lists.
            documents.extend(docs)
    return documents

async def generate_chunks(
    self, documents: List[Document]
) -> List[Union[Document, None]]:
    """Split documents into ~350-token nodes with a 20-token overlap."""
    node_parser = SimpleNodeParser.from_defaults(chunk_size=350, chunk_overlap=20)
    return node_parser.get_nodes_from_documents(documents, show_progress=False)

async def generate_embeddings(
    self,
    nodes: List[Union[Document, None]],
) -> List[tuple[str, list, dict[str, Any]]]:
    """Embed every node concurrently, upsert the results, and return them.

    Each result is a (node_id, vectors, payload) tuple where the payload is
    the node metadata plus the raw text under "content". None nodes are
    skipped.
    """

    async def generate_embedding(node):
        if node is None:
            return None  # gather() yields None for skipped nodes; filtered below
        vectors = []
        embedding_object = await aembedding(
            model="huggingface/intfloat/multilingual-e5-large",
            input=node.text,
            api_key=config("HUGGINGFACE_API_KEY"),
        )
        for vector in embedding_object.data:
            if vector["object"] == "embedding":
                vectors.append(vector["embedding"])
        return (
            node.id_,
            vectors,
            {
                **node.metadata,
                "content": node.text,
            },
        )

    tasks = [generate_embedding(node) for node in nodes]
    results = await asyncio.gather(*tasks)
    # Filter once and reuse; the original recomputed this comprehension twice.
    embeddings = [e for e in results if e is not None]
    vector_service = get_vector_service(
        index_name=self.index_name, credentials=self.vector_credentials
    )
    await vector_service.upsert(embeddings=embeddings)

    return embeddings
172 changes: 129 additions & 43 deletions service/vector_database.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,62 @@
from abc import ABC, abstractmethod
from typing import Any, List

import pinecone

from abc import ABC, abstractmethod
from typing import Any, List, Type
from decouple import config
from numpy import ndarray
from litellm import embedding
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
from models.vector_database import VectorDatabase


class VectorService(ABC):
def __init__(self, index_name: str, dimension: int, filter_id: str = None):
def __init__(self, index_name: str, dimension: int, credentials: dict):
    """Capture the target index name, vector dimension, and provider credentials."""
    self.index_name = index_name
    self.dimension = dimension
    self.credentials = credentials

@abstractmethod
async def upsert():
    """Write embedding tuples into the backing store (provider-specific)."""
    pass

@abstractmethod
async def query():
    """Search the backing store (provider-specific)."""
    pass

@abstractmethod
async def convert_to_dict():
    """Normalize provider-native results into plain dicts for reranking."""
    pass

async def rerank(self, query: str, documents: list, top_n: int = 4):
    """Rerank *documents* against *query* via Cohere; return the top_n docs.

    Expects each document to be a dict with a "content" key; the original
    dicts (not just the text) are returned, in reranked order.
    """
    # Imported lazily so the dependency is only required when reranking is used.
    from cohere import Client

    api_key = config("COHERE_API_KEY")
    if not api_key:
        raise ValueError("API key for Cohere is not present.")
    client = Client(api_key=api_key)
    ranked = client.rerank(
        model="rerank-multilingual-v2.0",
        query=query,
        documents=[doc["content"] for doc in documents],
        top_n=top_n,
    ).results
    # r.index points back into the original documents list.
    return [documents[r.index] for r in ranked]


class PineconeVectorService(VectorService):
def __init__(self, index_name: str, dimension: int, filter_id: str = None):
def __init__(self, index_name: str, dimension: int, credentials: dict):
super().__init__(
index_name=index_name, dimension=dimension, filter_id=filter_id
index_name=index_name, dimension=dimension, credentials=credentials
)
pinecone.init(
api_key=config("PINECONE_API_KEY"),
environment=config("PINECONE_ENVIRONMENT"),
api_key=credentials["PINECONE_API_KEY"],
environment=credentials["PINECONE_ENVIRONMENT"],
)
# Create a new vector index if it doesn't
# exist dimensions should be passed in the arguments
Expand All @@ -42,53 +66,115 @@ def __init__(self, index_name: str, dimension: int, filter_id: str = None):
)
self.index = pinecone.Index(index_name=self.index_name)

def upsert(self, vectors: ndarray):
self.index.upsert(vectors=vectors, namespace=self.filter_id)
async def convert_to_dict(self, documents: list):
    """No-op for Pinecone (body is empty; returns None)."""
    pass

async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]):
    """Upsert (id, vector, metadata) tuples into the Pinecone index."""
    self.index.upsert(vectors=embeddings)

def query(self, queries: List[ndarray], top_k: int, include_metadata: bool = True):
async def query(
    self, queries: List[ndarray], top_k: int, include_metadata: bool = True
):
    """Query the Pinecone index and return the matches of the first query."""
    results = self.index.query(
        queries=queries,
        top_k=top_k,
        include_metadata=include_metadata,
        # NOTE: the old namespace=self.filter_id argument was dropped —
        # __init__ no longer sets filter_id, so it would raise AttributeError.
    )
    return results["results"][0]["matches"]

def rerank(self, query: str, documents: Any, top_n: int = 3):
from cohere import Client

api_key = config("COHERE_API_KEY")
if not api_key:
raise ValueError("API key for Cohere is not present.")
cohere_client = Client(api_key=api_key)
docs = [
(
f"{doc['metadata']['content']}\n\n"
f"page number: {doc['metadata']['page_label']}"
class QdrantService(VectorService):
def __init__(self, index_name: str, dimension: int, credentials: dict):
    """Connect to Qdrant and lazily create the named-vector collection.

    credentials must contain "host" and "api_key" keys. The collection uses a
    single named vector "content" with cosine distance.
    """
    super().__init__(
        index_name=index_name, dimension=dimension, credentials=credentials
    )
    self.client = QdrantClient(
        url=credentials["host"], api_key=credentials["api_key"], https=True
    )
    collections = self.client.get_collections()
    if index_name not in [c.name for c in collections.collections]:
        self.client.create_collection(
            collection_name=self.index_name,
            vectors_config={
                "content": rest.VectorParams(
                    # Was hard-coded to 1024; use the dimension argument so
                    # other embedding sizes work (default stays 1024).
                    size=dimension,
                    distance=rest.Distance.COSINE,
                )
            },
            # indexing_threshold=0 defers HNSW indexing during bulk loads.
            optimizers_config=rest.OptimizersConfigDiff(
                indexing_threshold=0,
            ),
        )
for doc in documents

async def convert_to_dict(self, points: List[rest.PointStruct]):
    """Flatten Qdrant points into plain dicts for the reranker.

    Pulls "content", "page_label", and "file_url" out of each point's
    payload; missing keys become None.
    """
    docs = [
        {
            "content": point.payload.get("content"),
            "page_label": point.payload.get("page_label"),
            "file_url": point.payload.get("file_url"),
        }
        for point in points
    ]
    return docs

async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]):
    """Write (id, vector, payload) tuples into the Qdrant collection.

    The vector is stored under the named vector "content" to match the
    collection's vectors_config.
    """
    points = [
        rest.PointStruct(
            id=point_id,
            vector={"content": vector},
            payload={**payload},
        )
        for point_id, vector, payload in embeddings
    ]
    # wait=True blocks until the write is acknowledged by the server.
    # (A leftover debug print of the collection's vector count was removed.)
    self.client.upsert(collection_name=self.index_name, wait=True, points=points)

async def query(self, input: str, top_k: int):
    """Embed *input* and search the collection's "content" vectors."""
    query_vectors = []
    embedding_response = embedding(
        model="huggingface/intfloat/multilingual-e5-large",
        input=input,
        api_key=config("HUGGINGFACE_API_KEY"),
    )
    for item in embedding_response.data:
        if item["object"] == "embedding":
            query_vectors.append(item["embedding"])
    # NOTE(review): query_vectors is a *list* of vectors (normally length 1);
    # confirm Qdrant accepts it here rather than a single flat vector.
    return self.client.search(
        collection_name=self.index_name,
        query_vector=("content", query_vectors),
        limit=top_k,
        with_payload=True,
    )


def get_vector_service(
    index_name: str, credentials: VectorDatabase, dimension: int = 1024
) -> VectorService:
    """Instantiate the vector-store service matching credentials.type.

    Raises ValueError for an unknown provider. (Return annotation fixed:
    this returns an instance, not the class itself.)
    """
    services = {
        "pinecone": PineconeVectorService,
        "qdrant": QdrantService,
        # Add other providers here
        # e.g "weaviate": WeaviateVectorService,
    }
    service = services.get(credentials.type.value)
    if service is None:
        raise ValueError(f"Unsupported provider: {credentials.type.value}")
    return service(
        index_name=index_name,
        dimension=dimension,
        credentials=dict(credentials.config),
    )

0 comments on commit 51a911f

Please sign in to comment.