Skip to content

Commit

Permalink
Fix ColbertVectorStore as_retriever()
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Jul 24, 2024
1 parent a88325d commit abf4c85
Showing 1 changed file with 8 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar

from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
from ragstack_colbert import Chunk
from ragstack_colbert import ColbertVectorStore as RagstackColbertVectorStore
from ragstack_colbert.base_database import BaseDatabase as ColbertBaseDatabase
Expand All @@ -13,8 +12,6 @@
from ragstack_colbert.base_vector_store import BaseVectorStore as ColbertBaseVectorStore
from typing_extensions import override

from .colbert_retriever import ColbertRetriever

CVS = TypeVar("CVS", bound="ColbertVectorStore")


Expand Down Expand Up @@ -282,8 +279,11 @@ async def afrom_texts(
return instance

@override
def as_retriever(self, k: Optional[int] = 5, **kwargs: Any) -> BaseRetriever:
def as_retriever(self, k: Optional[int] = 5, **kwargs: Any) -> VectorStoreRetriever:
"""Return a VectorStoreRetriever initialized from this VectorStore."""
return ColbertRetriever(
retriever=self._vector_store.as_retriever(), k=k, **kwargs
)
search_kwargs = kwargs.pop("search_kwargs", {})
search_kwargs["k"] = k
search_type = kwargs.get("search_type", "similarity")
if search_type != "similarity":
raise ValueError(f"Unsupported search type: {search_type}")
return super().as_retriever(search_kwargs=search_kwargs, **kwargs)

0 comments on commit abf4c85

Please sign in to comment.