Skip to content

Commit

Permalink
force mypy into happiness
Browse files Browse the repository at this point in the history
  • Loading branch information
jkwatson committed Nov 26, 2024
1 parent 6a2eac3 commit 8ced596
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
22 changes: 14 additions & 8 deletions llm-service/app/ai/vector_stores/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
#

import os
from typing import Optional
from typing import Optional, Any

import qdrant_client
from llama_index.core.indices import VectorStoreIndex
Expand All @@ -47,8 +47,8 @@
)
from qdrant_client.http.models import CountResult, Record

from ...services import models
from .vector_store import VectorStore
from ...services import models


def new_qdrant_client() -> qdrant_client.QdrantClient:
Expand All @@ -60,20 +60,20 @@ def new_qdrant_client() -> qdrant_client.QdrantClient:
class QdrantVectorStore(VectorStore):
@staticmethod
def for_chunks(
data_source_id: int, client: Optional[qdrant_client.QdrantClient] = None
data_source_id: int, client: Optional[qdrant_client.QdrantClient] = None
) -> "QdrantVectorStore":
return QdrantVectorStore(table_name=f"index_{data_source_id}", client=client)

@staticmethod
def for_summaries(
data_source_id: int, client: Optional[qdrant_client.QdrantClient] = None
data_source_id: int, client: Optional[qdrant_client.QdrantClient] = None
) -> "QdrantVectorStore":
return QdrantVectorStore(
table_name=f"summary_index_{data_source_id}", client=client
)

def __init__(
self, table_name: str, client: Optional[qdrant_client.QdrantClient] = None
self, table_name: str, client: Optional[qdrant_client.QdrantClient] = None
):
self.client = client or new_qdrant_client()
self.table_name = table_name
Expand Down Expand Up @@ -106,7 +106,7 @@ def llama_vector_store(self) -> BasePydanticVectorStore:
vector_store = LlamaIndexQdrantVectorStore(self.table_name, self.client)
return vector_store

def visualize(self, user_query: Optional[str] = None):
def visualize(self, user_query: Optional[str] = None) -> list[tuple[tuple[float], str]]:
records: list[Record]
records, _ = self.client.scroll(self.table_name, limit=5000, with_vectors=True)

Expand All @@ -115,11 +115,17 @@ def visualize(self, user_query: Optional[str] = None):
user_query_vector = embedding_model.get_query_embedding(user_query)
records.append(Record(vector=user_query_vector, id="abc123", payload={"file_name": "USER_QUERY"}))

filenames = [record.payload.get("file_name") for record in records]
record: Record
filenames = []
for record in records:
payload: dict[str, Any] | None = record.payload
if payload:
filenames.append(payload.get("file_name"))

import umap
reducer = umap.UMAP()
embeddings = [record.vector for record in records]
reduced_embeddings = reducer.fit_transform(embeddings)

return [(tuple(x), filenames[i]) for i, x in enumerate(reduced_embeddings.tolist())]
# todo: figure out how to satisfy mypy on this line
return [(tuple(coordinate), filenames[i]) for i, coordinate in enumerate(reduced_embeddings.tolist())] # type: ignore
2 changes: 1 addition & 1 deletion llm-service/app/ai/vector_stores/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,5 @@ def exists(self) -> bool:
"""Does the vector store exist?"""

@abstractmethod
def visualize(self, user_query: Optional[str] = None) -> list[[tuple[float], str]]:
def visualize(self, user_query: Optional[str] = None) -> list[tuple[tuple[float], str]]:
"""get a 2-d visualization of the vectors in the store"""
4 changes: 2 additions & 2 deletions llm-service/app/routers/index/data_source/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def chunk_contents(self, chunk_id: str) -> ChunkContentsResponse:

@router.get("/visualize")
@exceptions.propagates
def visualize(self) -> list:
def visualize(self) -> list[tuple[tuple[float], str]]:
return self.chunks_vector_store.visualize()


Expand All @@ -112,7 +112,7 @@ class VisualizationRequest(BaseModel):

@router.post("/visualize")
@exceptions.propagates
def visualize_with_query(self, request: VisualizationRequest) -> list:
def visualize_with_query(self, request: VisualizationRequest) -> list[tuple[tuple[float], str]]:
return self.chunks_vector_store.visualize(request.user_query)


Expand Down

0 comments on commit 8ced596

Please sign in to comment.