Skip to content

Commit

Permalink
INTPYTHON-272 Parent Document Retriever (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
caseyclements authored Oct 31, 2024
1 parent 327b75f commit 3eeb7fd
Show file tree
Hide file tree
Showing 10 changed files with 1,103 additions and 412 deletions.
22 changes: 13 additions & 9 deletions libs/mongodb/langchain_mongodb/chat_message_histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@
DEFAULT_HISTORY_KEY = "History"

try:
from motor.motor_asyncio import AsyncIOMotorClient # type: ignore
from motor.motor_asyncio import (
AsyncIOMotorClient,
AsyncIOMotorCollection,
AsyncIOMotorDatabase,
)

_motor_available = True
except ImportError:
AsyncIOMotorClient = None
AsyncIOMotorClient = None # type: ignore
_motor_available = False
logger.warning(
"Motor library is not installed. Asynchronous methods will fall back to using "
Expand Down Expand Up @@ -135,13 +139,13 @@ def __init__(
self.collection = self.db[collection_name]

if _motor_available:
self.async_client = AsyncIOMotorClient(connection_string)
self.async_db = self.async_client[database_name]
self.async_collection = self.async_db[collection_name]
else:
self.async_client = None
self.async_db = None
self.async_collection = None
self.async_client: AsyncIOMotorClient = AsyncIOMotorClient(
connection_string
)
self.async_db: AsyncIOMotorDatabase = self.async_client[database_name]
self.async_collection: AsyncIOMotorCollection = self.async_db[
collection_name
]

if create_index:
index_kwargs = index_kwargs or {}
Expand Down
146 changes: 146 additions & 0 deletions libs/mongodb/langchain_mongodb/docstores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from __future__ import annotations

from importlib.metadata import version
from typing import Any, Generator, Iterable, Iterator, List, Optional, Sequence, Union

from langchain_core.documents import Document
from langchain_core.stores import BaseStore
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.driver_info import DriverInfo

from langchain_mongodb.utils import (
make_serializable,
)

DEFAULT_INSERT_BATCH_SIZE = 100_000


class MongoDBDocStore(BaseStore):
    """MongoDB Collection providing the BaseStore interface.

    This is meant to be treated as a key-value store: [str, Document].

    In a MongoDB Collection, the field name ``_id`` is reserved for use as a
    primary key. Its value must be unique in the collection, is immutable,
    and may be of any type other than an array or regex. As this field is
    always indexed, it is the natural choice to hold keys.

    The Document's text is stored in the field named by ``text_key`` and each
    metadata entry is stored next to it as a top-level field.

    Example stored document:
        {"_id": "foo", "page_content": "bar", "source": "baz"}
    """

    def __init__(self, collection: Collection, text_key: str = "page_content") -> None:
        """Wrap an existing collection.

        Args:
            collection: The MongoDB collection holding the key-value pairs.
            text_key: Name of the field in which a Document's text is stored.
        """
        self.collection = collection
        self._text_key = text_key

    @classmethod
    def from_connection_string(
        cls,
        connection_string: str,
        namespace: str,
        **kwargs: Any,
    ) -> MongoDBDocStore:
        """Construct a Key-Value Store from a MongoDB connection URI.

        Args:
            connection_string: A valid MongoDB connection URI.
            namespace: A valid MongoDB namespace (in form f"{database}.{collection}")
            **kwargs: Accepted for signature compatibility; currently unused.

        Returns:
            A new MongoDBDocStore instance.

        Raises:
            ValueError: If ``namespace`` does not contain a ``.`` separator.
        """
        client: MongoClient = MongoClient(
            connection_string,
            driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
        )
        # Only the first dot separates database from collection: database
        # names cannot contain dots, but collection names can.
        db_name, collection_name = namespace.split(".", 1)
        collection = client[db_name][collection_name]
        return cls(collection=collection)

    def mget(self, keys: Sequence[str]) -> list[Optional[Document]]:
        """Get the values associated with the given keys.

        If a key is not found in the store, the corresponding value will be None.
        As returning None is not the default find behavior, we form a dictionary
        and loop over the keys.

        Args:
            keys (Sequence[str]): A sequence of keys.

        Returns:
            List of values associated with the given keys, in the order of ``keys``.
        """
        found_docs = {}
        # list() so that any Sequence type is sent to the server as an array.
        for res in self.collection.find({"_id": {"$in": list(keys)}}):
            text = res.pop(self._text_key)
            key = res.pop("_id")
            # Whatever remains after removing the key and text is the metadata.
            make_serializable(res)
            found_docs[key] = Document(page_content=text, metadata=res)
        return [found_docs.get(key) for key in keys]

    def mset(
        self,
        key_value_pairs: Sequence[tuple[str, Document]],
        batch_size: int = DEFAULT_INSERT_BATCH_SIZE,
    ) -> None:
        """Set the values for the given keys.

        Documents are written with ``insert_many`` in batches of ``batch_size``.

        Args:
            key_value_pairs: A sequence of key-value pairs. An empty sequence
                is a no-op.
            batch_size: Maximum number of documents sent per insert.

        Raises:
            ValueError: If ``batch_size`` is not a positive integer.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        if not key_value_pairs:
            # zip(*[]) yields nothing to unpack, so bail out early.
            return
        keys, docs = zip(*key_value_pairs)
        for start in range(0, len(docs), batch_size):
            end = start + batch_size
            batch = docs[start:end]
            self.insert_many(
                texts=[doc.page_content for doc in batch],
                metadatas=[doc.metadata for doc in batch],
                ids=list(keys[start:end]),
            )

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the given keys and their associated values.

        Args:
            keys (Sequence[str]): A sequence of keys to delete.
        """
        self.collection.delete_many({"_id": {"$in": list(keys)}})

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Get an iterator over keys that match the given prefix.

        Args:
            prefix (str): The prefix to match. If None or empty, every key in
                the collection is yielded.

        Yields:
            The ``_id`` of each matching document.
        """
        import re

        # Escape the prefix so characters such as "." or "(" match literally
        # instead of being interpreted as regular-expression syntax.
        query = {"_id": {"$regex": f"^{re.escape(prefix)}"}} if prefix else {}
        for document in self.collection.find(query, {"_id": 1}):
            yield document["_id"]

    def insert_many(
        self,
        texts: Union[List[str], Iterable[str]],
        metadatas: Union[List[dict], Generator[dict, Any, Any]],
        ids: List[str],
    ) -> None:
        """Bulk insert a single batch of texts, metadatas, and ids.

        insert_many in PyMongo does not overwrite existing documents.
        Instead, it attempts to insert each document as a new document.
        If a document with the same _id already exists in the collection,
        an error will be raised. With PyMongo's default ordered insert,
        documents preceding the conflicting one are inserted and the
        remaining ones are not attempted.

        Args:
            texts: Text of each document, stored under the store's text key.
            metadatas: Metadata of each document, stored as top-level fields.
            ids: Key of each document, stored as ``_id``.
        """
        text_key = self._text_key
        to_insert = [
            {
                "_id": key,
                text_key: text,
                # Metadata must never clobber the key or the text field.
                **{k: v for k, v in meta.items() if k != "_id" and k != text_key},
            }
            for key, text, meta in zip(ids, texts, metadatas)
        ]
        if not to_insert:
            # PyMongo rejects an empty document list.
            return
        self.collection.insert_many(to_insert)  # type: ignore
4 changes: 4 additions & 0 deletions libs/mongodb/langchain_mongodb/retrievers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
MongoDBAtlasFullTextSearchRetriever,
)
from langchain_mongodb.retrievers.hybrid_search import MongoDBAtlasHybridSearchRetriever
from langchain_mongodb.retrievers.parent_document import (
MongoDBAtlasParentDocumentRetriever,
)

# Retriever classes exported as this package's public names.
__all__ = [
"MongoDBAtlasHybridSearchRetriever",
"MongoDBAtlasFullTextSearchRetriever",
"MongoDBAtlasParentDocumentRetriever",
]
154 changes: 154 additions & 0 deletions libs/mongodb/langchain_mongodb/retrievers/parent_document.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from __future__ import annotations

from importlib.metadata import version
from typing import Any, List

import pymongo
from langchain.retrievers.parent_document_retriever import ParentDocumentRetriever
from langchain_core.callbacks import (
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_text_splitters import TextSplitter
from pymongo import MongoClient
from pymongo.driver_info import DriverInfo

from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.docstores import MongoDBDocStore
from langchain_mongodb.pipelines import vector_search_stage
from langchain_mongodb.utils import make_serializable


class MongoDBAtlasParentDocumentRetriever(ParentDocumentRetriever):
    """MongoDB Atlas's ParentDocumentRetriever

    Uses ONE Collection for both Vector and Doc store.

    For details, see parent classes
    :class:`~langchain.retrievers.parent_document_retriever.ParentDocumentRetriever`
    and :class:`~langchain.retrievers.MultiVectorRetriever`.

    Examples:
        >>> from langchain_mongodb.retrievers.parent_document import (
        ...     MongoDBAtlasParentDocumentRetriever,
        ... )
        >>> from langchain_text_splitters import RecursiveCharacterTextSplitter
        >>> from langchain_openai import OpenAIEmbeddings
        >>>
        >>> retriever = MongoDBAtlasParentDocumentRetriever.from_connection_string(
        ...     "mongodb+srv://<user>:<password>@<clustername>.mongodb.net",
        ...     OpenAIEmbeddings(model="text-embedding-3-large"),
        ...     RecursiveCharacterTextSplitter(chunk_size=400),
        ...     "example_database",
        ... )
        >>> retriever.add_documents([Document(...), ...])
        >>> resp = retriever.invoke("Langchain MongoDB Partnership Ecosystem")
        >>> print(resp)
        [Document(...), ...]
    """

    vectorstore: MongoDBAtlasVectorSearch
    """Vectorstore API to add, embed, and search through child documents"""

    docstore: MongoDBDocStore
    """Provides an API around the Collection to add the parent documents"""

    id_key: str = "doc_id"
    """Key stored in metadata pointing to parent document"""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the parent documents of the chunks most similar to ``query``.

        Runs a single aggregation: vector search over the chunks, a
        ``$lookup`` from each chunk to its parent through ``id_key``, then
        de-duplication so that each parent is returned once.

        Args:
            query: Text to embed and search for.
            run_manager: Callback manager for this retriever run (unused here).

        Returns:
            One Document per distinct parent. The text comes from the
            vectorstore's text key; all remaining fields become metadata.
        """
        query_vector = self.vectorstore._embedding.embed_query(query)
        embedding_key = self.vectorstore._embedding_key

        pipeline = [
            vector_search_stage(
                query_vector,
                embedding_key,
                self.vectorstore._index_name,
                **self.search_kwargs,  # See MongoDBAtlasVectorSearch
            ),
            {"$set": {"score": {"$meta": "vectorSearchScore"}}},
            # Drop the vector stored under the configured embedding field;
            # it is not needed once the search stage has run.
            {"$project": {embedding_key: 0}},
            {  # Find corresponding parent doc
                "$lookup": {
                    "from": self.vectorstore.collection.name,
                    "localField": self.id_key,
                    "foreignField": "_id",
                    "as": "parent_context",
                    "pipeline": [
                        # Discard sub-documents
                        # NOTE(review): MongoDBDocStore.insert_many writes
                        # metadata as top-level fields, so this nested path
                        # may never exist - confirm the filter is effective.
                        {"$match": {f"metadata.{self.id_key}": {"$exists": False}}},
                    ],
                }
            },  # Remove duplicate parent docs and reformat
            {"$unwind": {"path": "$parent_context"}},
            {
                "$group": {
                    "_id": "$parent_context._id",
                    "uniqueDocument": {"$first": "$parent_context"},
                }
            },
            {"$replaceRoot": {"newRoot": "$uniqueDocument"}},
        ]
        # Execute
        cursor = self.vectorstore._collection.aggregate(pipeline)  # type: ignore[arg-type]
        docs = []
        # Format into Documents
        for res in cursor:
            text = res.pop(self.vectorstore._text_key)
            make_serializable(res)
            docs.append(Document(page_content=text, metadata=res))
        return docs

    @classmethod
    def from_connection_string(
        cls,
        connection_string: str,
        embedding_model: Embeddings,
        child_splitter: TextSplitter,
        database_name: str,
        collection_name: str = "document_with_chunks",
        id_key: str = "doc_id",
        **kwargs: Any,
    ) -> MongoDBAtlasParentDocumentRetriever:
        """Construct Retriever using ONE Collection for both VectorStore and DocStore

        See parent classes
        :class:`~langchain.retrievers.parent_document_retriever.ParentDocumentRetriever`
        and :class:`~langchain.retrievers.MultiVectorRetriever` for further details.

        Args:
            connection_string: A valid MongoDB Atlas connection URI.
            embedding_model: The text embedding model to use for the vector store.
            child_splitter: Splits documents into chunks.
                If parent_splitter is given, the documents will have already been split.
            database_name: Name of database to connect to. Created if it does not exist.
            collection_name: Name of collection to use.
                It includes parent documents, sub-documents and their embeddings.
            id_key: Key used to identify parent documents.
            **kwargs: Additional keyword arguments, forwarded both to
                MongoDBAtlasVectorSearch and to this class's constructor.
                See parent classes for more.

        Returns:
            A new MongoDBAtlasParentDocumentRetriever
        """
        client: MongoClient = MongoClient(
            connection_string,
            driver=DriverInfo(name="langchain", version=version("langchain-mongodb")),
        )
        collection = client[database_name][collection_name]
        vectorstore = MongoDBAtlasVectorSearch(
            collection=collection, embedding=embedding_model, **kwargs
        )

        # Parents must be written under the same text field that
        # _get_relevant_documents pops when reading them back.
        docstore = MongoDBDocStore(
            collection=collection, text_key=vectorstore._text_key
        )
        docstore.collection.create_index([(id_key, pymongo.ASCENDING)])

        return cls(
            vectorstore=vectorstore,
            docstore=docstore,
            child_splitter=child_splitter,
            id_key=id_key,
            **kwargs,
        )
Loading

0 comments on commit 3eeb7fd

Please sign in to comment.