generated from langchain-ai/integration-repo-template
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
INTPYTHON-272 Parent Document Retriever (#5)
- Loading branch information
1 parent
327b75f
commit 3eeb7fd
Showing
10 changed files
with
1,103 additions
and
412 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
from __future__ import annotations | ||
|
||
from importlib.metadata import version | ||
from typing import Any, Generator, Iterable, Iterator, List, Optional, Sequence, Union | ||
|
||
from langchain_core.documents import Document | ||
from langchain_core.stores import BaseStore | ||
from pymongo import MongoClient | ||
from pymongo.collection import Collection | ||
from pymongo.driver_info import DriverInfo | ||
|
||
from langchain_mongodb.utils import ( | ||
make_serializable, | ||
) | ||
|
||
# Default number of documents MongoDBDocStore.mset hands to a single
# insert_many call when the caller does not pass batch_size.
DEFAULT_INSERT_BATCH_SIZE = 100_000
|
||
|
||
class MongoDBDocStore(BaseStore):
    """MongoDB Collection providing the BaseStore interface.

    This is meant to be treated as a key-value store: [str, Document].

    In a MongoDB Collection, the field name ``_id`` is reserved for use as a
    primary key. Its value must be unique in the collection, is immutable,
    and may be of any type other than an array or regex. As this field is
    always indexed, it is the natural choice to hold keys.

    Each Document is stored flat: the key is written to ``_id``, the
    ``page_content`` to the field named by ``text_key``, and every metadata
    entry becomes a top-level field of the MongoDB document.

    Example stored document (with the default ``text_key``):
        {"_id": "foo", "page_content": "bar", "source": "baz.pdf"}
    """

    def __init__(self, collection: Collection, text_key: str = "page_content") -> None:
        """Wrap an existing Collection.

        Args:
            collection: The MongoDB Collection that holds the documents.
            text_key: Name of the field that stores each Document's text.
        """
        self.collection = collection
        self._text_key = text_key

    @classmethod
    def from_connection_string(
        cls,
        connection_string: str,
        namespace: str,
        **kwargs: Any,
    ) -> MongoDBDocStore:
        """Construct a Key-Value Store from a MongoDB connection URI.

        Args:
            connection_string: A valid MongoDB connection URI.
            namespace: A valid MongoDB namespace (in form f"{database}.{collection}").
                Only the first "." separates the two parts, so collection
                names that themselves contain dots are supported.
            **kwargs: Optional settings. ``text_key`` is forwarded to the
                constructor; any other entries are ignored.

        Returns:
            A new MongoDBDocStore instance.

        Raises:
            ValueError: If namespace is not of the form "database.collection".
        """
        client: MongoClient = MongoClient(
            connection_string,
            driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
        )
        # Split on the first dot only: database names cannot contain ".",
        # but collection names can.
        db_name, _, collection_name = namespace.partition(".")
        if not db_name or not collection_name:
            raise ValueError(
                f"namespace must be of the form 'database.collection', got {namespace!r}"
            )
        collection = client[db_name][collection_name]
        if "text_key" in kwargs:
            return cls(collection=collection, text_key=kwargs["text_key"])
        return cls(collection=collection)

    def mget(self, keys: Sequence[str]) -> list[Optional[Document]]:
        """Get the values associated with the given keys.

        If a key is not found in the store, the corresponding value will be
        None. As returning None is not the default find behavior, we form a
        dictionary of the hits and then walk the keys in their given order.

        Args:
            keys (Sequence[str]): A sequence of keys.

        Returns:
            List of values associated with the given keys, aligned with keys.
        """
        found_docs = {}
        # list(...) so that any Sequence type is sent to the server as an array.
        for res in self.collection.find({"_id": {"$in": list(keys)}}):
            text = res.pop(self._text_key)
            key = res.pop("_id")
            make_serializable(res)
            # Whatever remains after removing the key and text is metadata.
            found_docs[key] = Document(page_content=text, metadata=res)
        return [found_docs.get(key) for key in keys]

    def mset(
        self,
        key_value_pairs: Sequence[tuple[str, Document]],
        batch_size: int = DEFAULT_INSERT_BATCH_SIZE,
    ) -> None:
        """Set the values for the given keys.

        The pairs are written in consecutive batches through insert_many.
        An empty input is a no-op.

        Args:
            key_value_pairs: A sequence of key-value pairs.
            batch_size: Maximum number of documents per insert_many call.

        Raises:
            ValueError: If batch_size is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        pairs = list(key_value_pairs)
        if not pairs:
            # zip(*[]) yields nothing to unpack; there is simply nothing to store.
            return
        for start in range(0, len(pairs), batch_size):
            batch = pairs[start : start + batch_size]
            self.insert_many(
                texts=[doc.page_content for _, doc in batch],
                metadatas=[doc.metadata for _, doc in batch],
                ids=[key for key, _ in batch],
            )

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the given keys and their associated values.

        Args:
            keys (Sequence[str]): A sequence of keys to delete.
        """
        self.collection.delete_many({"_id": {"$in": list(keys)}})

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Get an iterator over keys that match the given prefix.

        Args:
            prefix (str): The literal prefix to match. When None or empty,
                every key in the collection is yielded.

        Yields:
            str: Each key (the ``_id`` field) that starts with prefix.
        """
        import re  # local import keeps this class self-contained

        # Escape the prefix so regex metacharacters in it (".", "*", "(", ...)
        # are matched literally rather than interpreted as a pattern.
        query: dict = {"_id": {"$regex": f"^{re.escape(prefix)}"}} if prefix else {}
        for document in self.collection.find(query, {"_id": 1}):
            yield document["_id"]

    def insert_many(
        self,
        texts: Union[List[str], Iterable[str]],
        metadatas: Union[List[dict], Generator[dict, Any, Any]],
        ids: List[str],
    ) -> None:
        """Bulk insert a single batch of texts, metadatas, and ids.

        insert_many in PyMongo does not overwrite existing documents.
        Instead, it attempts to insert each document as a new document.
        If a document with the same _id already exists in the collection,
        an error will be raised for that specific document.

        Args:
            texts: Text of each document, stored under the text key.
            metadatas: Metadata of each document, merged in as top-level fields.
            ids: Key of each document, stored as ``_id``.
        """
        # NOTE: metadata is unpacked last, so a metadata entry named "_id" or
        # equal to the text key takes precedence over the id/text given here.
        to_insert = [
            {"_id": i, self._text_key: t, **m} for i, t, m in zip(ids, texts, metadatas)
        ]
        if not to_insert:
            # Nothing to write; avoid handing an empty batch to the driver.
            return
        self.collection.insert_many(to_insert)  # type: ignore
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
154 changes: 154 additions & 0 deletions
154
libs/mongodb/langchain_mongodb/retrievers/parent_document.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
from __future__ import annotations | ||
|
||
from importlib.metadata import version | ||
from typing import Any, List | ||
|
||
import pymongo | ||
from langchain.retrievers.parent_document_retriever import ParentDocumentRetriever | ||
from langchain_core.callbacks import ( | ||
CallbackManagerForRetrieverRun, | ||
) | ||
from langchain_core.documents import Document | ||
from langchain_core.embeddings import Embeddings | ||
from langchain_text_splitters import TextSplitter | ||
from pymongo import MongoClient | ||
from pymongo.driver_info import DriverInfo | ||
|
||
from langchain_mongodb import MongoDBAtlasVectorSearch | ||
from langchain_mongodb.docstores import MongoDBDocStore | ||
from langchain_mongodb.pipelines import vector_search_stage | ||
from langchain_mongodb.utils import make_serializable | ||
|
||
|
||
class MongoDBAtlasParentDocumentRetriever(ParentDocumentRetriever):
    """MongoDB Atlas's ParentDocumentRetriever

    Uses ONE Collection for both Vector and Doc store.

    For details, see parent classes
    :class:`~langchain.retrievers.parent_document_retriever.ParentDocumentRetriever`
    and :class:`~langchain.retrievers.MultiVectorRetriever`.

    Examples:
        >>> from langchain_mongodb.retrievers.parent_document import (
        ...     MongoDBAtlasParentDocumentRetriever
        ... )
        >>> from langchain_text_splitters import RecursiveCharacterTextSplitter
        >>> from langchain_openai import OpenAIEmbeddings
        >>>
        >>> retriever = MongoDBAtlasParentDocumentRetriever.from_connection_string(
        ...     "mongodb+srv://<user>:<password>@<clustername>.mongodb.net",
        ...     OpenAIEmbeddings(model="text-embedding-3-large"),
        ...     RecursiveCharacterTextSplitter(chunk_size=400),
        ...     "example_database"
        ... )
        >>> retriever.add_documents([Document(...), ...])
        >>> resp = retriever.invoke("Langchain MongoDB Partnership Ecosystem")
        >>> print(resp)
        [Document(...), ...]
    """

    vectorstore: MongoDBAtlasVectorSearch
    """Vectorstore API to add, embed, and search through child documents"""

    docstore: MongoDBDocStore
    """Provides an API around the Collection to add the parent documents"""

    id_key: str = "doc_id"
    """Key stored in metadata pointing to parent document"""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the parent documents of the chunks most similar to query.

        The query is embedded, a vector search selects matching child
        documents, and a self-``$lookup`` on the same collection swaps each
        child for its parent. Parents are de-duplicated and returned ordered
        by the best score among their matching children.

        Args:
            query: Text to embed and search for.
            run_manager: Callback manager for this retriever run (not used here).

        Returns:
            One Document per distinct parent, most relevant first.
        """
        query_vector = self.vectorstore._embedding.embed_query(query)

        pipeline = [
            vector_search_stage(
                query_vector,
                self.vectorstore._embedding_key,
                self.vectorstore._index_name,
                **self.search_kwargs,  # See MongoDBAtlasVectorSearch
            ),
            {"$set": {"score": {"$meta": "vectorSearchScore"}}},
            # Drop the vector using the configured field name rather than a
            # hard-coded "embedding", so custom embedding keys are handled too.
            {"$project": {self.vectorstore._embedding_key: 0}},
            {  # Find corresponding parent doc
                "$lookup": {
                    "from": self.vectorstore.collection.name,
                    "localField": self.id_key,
                    "foreignField": "_id",
                    "as": "parent_context",
                    "pipeline": [
                        # Discard sub-documents
                        # NOTE(review): localField reads id_key at the top
                        # level while this filter checks metadata.<id_key>;
                        # confirm which layout child documents actually use.
                        {"$match": {f"metadata.{self.id_key}": {"$exists": False}}},
                    ],
                }
            },  # Remove duplicate parent docs and reformat
            {"$unwind": {"path": "$parent_context"}},
            {
                "$group": {
                    "_id": "$parent_context._id",
                    "uniqueDocument": {"$first": "$parent_context"},
                    # Keep the best child score so ranking survives grouping.
                    "score": {"$max": "$score"},
                }
            },
            # $group does not guarantee output order; restore relevance order.
            {"$sort": {"score": -1}},
            {"$replaceRoot": {"newRoot": "$uniqueDocument"}},
        ]
        # Execute
        cursor = self.vectorstore._collection.aggregate(pipeline)  # type: ignore[arg-type]
        docs = []
        # Format into Documents
        for res in cursor:
            text = res.pop(self.vectorstore._text_key)
            make_serializable(res)
            # Every remaining field of the parent document becomes metadata.
            docs.append(Document(page_content=text, metadata=res))
        return docs

    @classmethod
    def from_connection_string(
        cls,
        connection_string: str,
        embedding_model: Embeddings,
        child_splitter: TextSplitter,
        database_name: str,
        collection_name: str = "document_with_chunks",
        id_key: str = "doc_id",
        **kwargs: Any,
    ) -> MongoDBAtlasParentDocumentRetriever:
        """Construct Retriever using ONE Collection for both VectorStore and DocStore

        See parent classes
        :class:`~langchain.retrievers.parent_document_retriever.ParentDocumentRetriever`
        and :class:`~langchain.retrievers.MultiVectorRetriever` for further details.

        Args:
            connection_string: A valid MongoDB Atlas connection URI.
            embedding_model: The text embedding model to use for the vector store.
            child_splitter: Splits documents into chunks.
                If parent_splitter is given, the documents will have already been split.
            database_name: Name of database to connect to. Created if it does not exist.
            collection_name: Name of collection to use.
                It includes parent documents, sub-documents and their embeddings.
            id_key: Key used to identify parent documents.
            **kwargs: Additional keyword arguments. They are passed to both
                MongoDBAtlasVectorSearch and this class's constructor.
                See parent classes for more.

        Returns: A new MongoDBAtlasParentDocumentRetriever
        """
        client: MongoClient = MongoClient(
            connection_string,
            driver=DriverInfo(name="langchain", version=version("langchain-mongodb")),
        )
        collection = client[database_name][collection_name]
        vectorstore = MongoDBAtlasVectorSearch(
            collection=collection, embedding=embedding_model, **kwargs
        )

        # The doc store shares the vector store's collection.
        docstore = MongoDBDocStore(collection=collection)
        # Index the parent-id field that the $lookup in
        # _get_relevant_documents uses as its localField.
        docstore.collection.create_index([(id_key, pymongo.ASCENDING)])

        return cls(
            vectorstore=vectorstore,
            docstore=docstore,
            child_splitter=child_splitter,
            id_key=id_key,
            **kwargs,
        )
Oops, something went wrong.