diff --git a/libs/mongodb/langchain_mongodb/index.py b/libs/mongodb/langchain_mongodb/index.py index 59cb6d3..fce71e7 100644 --- a/libs/mongodb/langchain_mongodb/index.py +++ b/libs/mongodb/langchain_mongodb/index.py @@ -5,22 +5,11 @@ from typing import Any, Callable, Dict, List, Optional from pymongo.collection import Collection -from pymongo.errors import OperationFailure from pymongo.operations import SearchIndexModel logger = logging.getLogger(__file__) -def _search_index_error_message() -> str: - return ( - "Search index operations are not currently available on shared clusters, " - "such as MO. They require dedicated clusters >= M10. " - "You may still perform vector search. " - "You simply must set up indexes manually. Follow the instructions here: " - "https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/" - ) - - def _vector_search_index_definition( dimensions: int, path: str, @@ -71,22 +60,19 @@ def create_vector_search_index( """ logger.info("Creating Search Index %s on %s", index_name, collection.name) - try: - result = collection.create_search_index( - SearchIndexModel( - definition=_vector_search_index_definition( - dimensions=dimensions, - path=path, - similarity=similarity, - filters=filters, - **kwargs, - ), - name=index_name, - type="vectorSearch", - ) + result = collection.create_search_index( + SearchIndexModel( + definition=_vector_search_index_definition( + dimensions=dimensions, + path=path, + similarity=similarity, + filters=filters, + **kwargs, + ), + name=index_name, + type="vectorSearch", ) - except OperationFailure as e: - raise OperationFailure(_search_index_error_message()) from e + ) if wait_until_complete: _wait_for_predicate( @@ -114,12 +100,7 @@ def drop_vector_search_index( logger.info( "Dropping Search Index %s from Collection: %s", index_name, collection.name ) - try: - collection.drop_search_index(index_name) - except OperationFailure as e: - if "CommandNotSupported" in str(e): - raise OperationFailure(_search_index_error_message()) from e - # else this most likely means an ongoing drop request was made so skip + collection.drop_search_index(index_name) if wait_until_complete: _wait_for_predicate( predicate=lambda: len(list(collection.list_search_indexes())) == 0, @@ -155,24 +136,19 @@ def update_vector_search_index( until search index is ready. kwargs: Keyword arguments supplying any additional options to SearchIndexModel. """ - logger.info( "Updating Search Index %s from Collection: %s", index_name, collection.name ) - try: - collection.update_search_index( - name=index_name, - definition=_vector_search_index_definition( - dimensions=dimensions, - path=path, - similarity=similarity, - filters=filters, - **kwargs, - ), - ) - except OperationFailure as e: - raise OperationFailure(_search_index_error_message()) from e - + collection.update_search_index( + name=index_name, + definition=_vector_search_index_definition( + dimensions=dimensions, + path=path, + similarity=similarity, + filters=filters, + **kwargs, + ), + ) if wait_until_complete: _wait_for_predicate( predicate=lambda: _is_index_ready(collection, index_name), @@ -193,12 +169,7 @@ def _is_index_ready(collection: Collection, index_name: str) -> bool: Returns: bool : True if the index is present and READY false otherwise """ - try: - search_indexes = collection.list_search_indexes(index_name) - except OperationFailure as e: - raise OperationFailure(_search_index_error_message()) from e - - for index in search_indexes: + for index in collection.list_search_indexes(index_name): if index["type"] == "vectorSearch" and index["status"] == "READY": return True return False @@ -248,19 +219,14 @@ def create_fulltext_search_index( definition = { "mappings": {"dynamic": False, "fields": {field: [{"type": "string"}]}} } - - try: - result = collection.create_search_index( - SearchIndexModel( - definition=definition, - name=index_name, - type="search", - **kwargs, - ) + result = collection.create_search_index( + SearchIndexModel( + definition=definition, + name=index_name, + type="search", + **kwargs, ) - except OperationFailure as e: - raise OperationFailure(_search_index_error_message()) from e - + ) if wait_until_complete: _wait_for_predicate( predicate=lambda: _is_index_ready(collection, index_name), diff --git a/libs/mongodb/langchain_mongodb/retrievers/full_text_search.py b/libs/mongodb/langchain_mongodb/retrievers/full_text_search.py index 52f12c4..ef596e4 100644 --- a/libs/mongodb/langchain_mongodb/retrievers/full_text_search.py +++ b/libs/mongodb/langchain_mongodb/retrievers/full_text_search.py @@ -24,8 +24,8 @@ class MongoDBAtlasFullTextSearchRetriever(BaseRetriever): """Number of documents to return. Default is no limit""" filter: Optional[Dict[str, Any]] = None """(Optional) List of MQL match expression comparing an indexed field""" - show_embeddings: float = False - """If true, returned Document metadata will include vectors""" + include_scores: bool = True + """If True, include scores that provide measure of relative relevance""" def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun @@ -45,6 +45,7 @@ def _get_relevant_documents( index_name=self.search_index_name, limit=self.top_k, filter=self.filter, + include_scores=self.include_scores, ) # Execution diff --git a/libs/mongodb/langchain_mongodb/vectorstores.py b/libs/mongodb/langchain_mongodb/vectorstores.py index 7860281..6578744 100644 --- a/libs/mongodb/langchain_mongodb/vectorstores.py +++ b/libs/mongodb/langchain_mongodb/vectorstores.py @@ -235,6 +235,14 @@ def __init__( def embeddings(self) -> Embeddings: return self._embedding + @property + def collection(self) -> Collection: + return self._collection + + @collection.setter + def collection(self, value: Collection) -> None: + self._collection = value + def _select_relevance_score_fn(self) -> Callable[[float], float]: scoring: dict[str, Callable] = { "euclidean": self._euclidean_relevance_score_fn, @@ -761,6 +769,7 @@ def create_vector_search_index( dimensions: int, filters: Optional[List[str]] = None, update: bool = False, + wait_until_complete: Optional[float] = None, ) -> None: """Creates a MongoDB Atlas vectorSearch index for the VectorStore @@ -774,8 +783,11 @@ def create_vector_search_index( filters (Optional[List[Dict[str, str]]], optional): additional filters for index definition. Defaults to None. - update (bool, optional): Updates existing vectorSearch index. - Defaults to False. + update (Optional[bool]): Updates existing vectorSearch index. + Defaults to False. + wait_until_complete (Optional[float]): If given, a TimeoutError is raised + if search index is not ready after this number of seconds. + If not given, the default, operation will not wait. """ try: self._collection.database.create_collection(self._collection.name) @@ -793,4 +805,5 @@ def create_vector_search_index( path=self._embedding_key, similarity=self._relevance_score_fn, filters=filters or [], + wait_until_complete=wait_until_complete, ) # type: ignore [operator] diff --git a/libs/mongodb/tests/integration_tests/test_chain_example.py b/libs/mongodb/tests/integration_tests/test_chain_example.py index d54fa6c..e75664a 100644 --- a/libs/mongodb/tests/integration_tests/test_chain_example.py +++ b/libs/mongodb/tests/integration_tests/test_chain_example.py @@ -3,7 +3,6 @@ from __future__ import annotations import os -from time import sleep import pytest # type: ignore[import-not-found] from langchain_core.documents import Document @@ -20,7 +19,7 @@ CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI") DB_NAME = "langchain_test_db" COLLECTION_NAME = "langchain_test_chain_example" -INDEX_NAME = "vector_index" +INDEX_NAME = "langchain-test-chain-example-vector-index" DIMENSIONS = 1536 TIMEOUT = 60.0 INTERVAL = 0.5 @@ -88,9 +87,6 @@ def test_chain( ] vectorstore.add_texts(texts) - # Give the index time to build (For CI) - sleep(TIMEOUT) - query = "In the United States, what city did I visit last?" # One can do vector search on the vector store, using its various search types. k = len(texts) diff --git a/libs/mongodb/tests/integration_tests/test_index.py b/libs/mongodb/tests/integration_tests/test_index.py index 000cbed..fe4021c 100644 --- a/libs/mongodb/tests/integration_tests/test_index.py +++ b/libs/mongodb/tests/integration_tests/test_index.py @@ -1,5 +1,3 @@ -"""Search index commands are only supported on Atlas Clusters >=M10""" - import os from typing import Generator, List, Optional @@ -7,7 +5,9 @@ from pymongo import MongoClient from pymongo.collection import Collection -from langchain_mongodb import index +from langchain_mongodb import MongoDBAtlasVectorSearch, index + +from ..utils import ConsistentFakeEmbeddings DB_NAME = "langchain_test_index_db" COLLECTION_NAME = "test_index" @@ -22,13 +22,13 @@ def collection() -> Generator: """Depending on uri, this could point to any type of cluster.""" uri = os.environ.get("MONGODB_ATLAS_URI") client: MongoClient = MongoClient(uri) + client[DB_NAME].create_collection(COLLECTION_NAME) clxn = client[DB_NAME][COLLECTION_NAME] - clxn.insert_one({"foo": "bar"}) yield clxn clxn.drop() -def test_search_index_commands(collection: Collection) -> None: +def test_search_index_drop_add_delete_commands(collection: Collection) -> None: index_name = VECTOR_INDEX_NAME dimensions = DIMENSIONS path = "embedding" @@ -58,26 +58,79 @@ def test_search_index_commands(collection: Collection) -> None: assert len(indexes) == 1 assert indexes[0]["name"] == index_name - new_similarity = "euclidean" + index.drop_vector_search_index( + collection, index_name, wait_until_complete=wait_until_complete + ) + + indexes = list(collection.list_search_indexes()) + assert len(indexes) == 0 + + +@pytest.mark.skip("collection.update_vector_search_index requires [CLOUDP-275518]") +def test_search_index_update_vector_search_index(collection: Collection) -> None: + index_name = "INDEX_TO_UPDATE" + similarity_orig = "cosine" + similarity_new = "euclidean" + + # Create another index + index.create_vector_search_index( + collection=collection, + index_name=index_name, + dimensions=DIMENSIONS, + path="embedding", + similarity=similarity_orig, + wait_until_complete=TIMEOUT, + ) + + assert index._is_index_ready(collection, index_name) + indexes = list(collection.list_search_indexes()) + assert len(indexes) == 1 + assert indexes[0]["name"] == index_name + assert indexes[0]["latestDefinition"]["fields"][0]["similarity"] == similarity_orig + + # Update the index and test new similarity index.update_vector_search_index( - collection, - index_name, - DIMENSIONS, - "embedding", - new_similarity, - filters=[], - wait_until_complete=wait_until_complete, + collection=collection, + index_name=index_name, + dimensions=DIMENSIONS, + path="embedding", + similarity=similarity_new, + wait_until_complete=TIMEOUT, ) assert index._is_index_ready(collection, index_name) indexes = list(collection.list_search_indexes()) assert len(indexes) == 1 assert indexes[0]["name"] == index_name - assert indexes[0]["latestDefinition"]["fields"][0]["similarity"] == new_similarity + assert indexes[0]["latestDefinition"]["fields"][0]["similarity"] == similarity_new - index.drop_vector_search_index( - collection, index_name, wait_until_complete=wait_until_complete + +def test_vectorstore_create_vector_search_index(collection: Collection) -> None: + """Tests vectorstore wrapper around index command.""" + + # Set up using the index module's api + if len(list(collection.list_search_indexes())) != 0: + index.drop_vector_search_index( + collection, VECTOR_INDEX_NAME, wait_until_complete=TIMEOUT + ) + + # Test MongoDBAtlasVectorSearch's API + vectorstore = MongoDBAtlasVectorSearch( + collection=collection, + embedding=ConsistentFakeEmbeddings(), + index_name=VECTOR_INDEX_NAME, ) + vectorstore.create_vector_search_index( + dimensions=DIMENSIONS, wait_until_complete=TIMEOUT + ) + + assert index._is_index_ready(collection, VECTOR_INDEX_NAME) indexes = list(collection.list_search_indexes()) - assert len(indexes) == 0 + assert len(indexes) == 1 + assert indexes[0]["name"] == VECTOR_INDEX_NAME + + # Tear down using the index module's api + index.drop_vector_search_index( + collection, VECTOR_INDEX_NAME, wait_until_complete=TIMEOUT + ) diff --git a/libs/mongodb/tests/integration_tests/test_mmr.py b/libs/mongodb/tests/integration_tests/test_mmr.py new file mode 100644 index 0000000..a2f011d --- /dev/null +++ b/libs/mongodb/tests/integration_tests/test_mmr.py @@ -0,0 +1,45 @@ +"""Test max_marginal_relevance_search.""" + +from __future__ import annotations + +import os + +import pytest # type: ignore[import-not-found] +from langchain_core.embeddings import Embeddings +from pymongo import MongoClient +from pymongo.collection import Collection + +from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch + +CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI") +DB_NAME = "langchain_test_db" +COLLECTION_NAME = "langchain_test_vectorstores" +INDEX_NAME = "langchain-test-index-vectorstores" +DIMENSIONS = 5 + + +@pytest.fixture() +def collection() -> Collection: + test_client: MongoClient = MongoClient(CONNECTION_STRING) + return test_client[DB_NAME][COLLECTION_NAME] + + +@pytest.fixture +def embeddings() -> Embeddings: + return ConsistentFakeEmbeddings(DIMENSIONS) + + +def test_mmr(embeddings: Embeddings, collection: Collection) -> None: + texts = ["foo", "foo", "fou", "foy"] + collection.delete_many({}) + vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts( + texts, + embedding=embeddings, + collection=collection, + index_name=INDEX_NAME, + ) + query = "foo" + output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1) + assert len(output) == len(texts) + assert output[0].page_content == "foo" + assert output[1].page_content != "foo" diff --git a/libs/mongodb/tests/integration_tests/test_retrievers.py b/libs/mongodb/tests/integration_tests/test_retrievers.py index 8a0cf34..1dd7626 100644 --- a/libs/mongodb/tests/integration_tests/test_retrievers.py +++ b/libs/mongodb/tests/integration_tests/test_retrievers.py @@ -1,6 +1,5 @@ import os -from time import sleep -from typing import List +from typing import Generator, List import pytest from langchain_core.documents import Document @@ -8,13 +7,17 @@ from pymongo import MongoClient from pymongo.collection import Collection -from langchain_mongodb import index +from langchain_mongodb import MongoDBAtlasVectorSearch +from langchain_mongodb.index import ( + create_fulltext_search_index, + create_vector_search_index, +) from langchain_mongodb.retrievers import ( MongoDBAtlasFullTextSearchRetriever, MongoDBAtlasHybridSearchRetriever, ) -from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch +from ..utils import PatchedMongoDBAtlasVectorSearch CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI") DB_NAME = "langchain_test_db" @@ -24,12 +27,12 @@ PAGE_CONTENT_FIELD = "text" SEARCH_INDEX_NAME = "text_index" -DIMENSIONS = 1536 +DIMENSIONS = 1536 # Meets OpenAI model TIMEOUT = 60.0 INTERVAL = 0.5 -@pytest.fixture +@pytest.fixture(scope="module") def example_documents() -> List[Document]: return [ Document(page_content="In 2023, I visited Paris"), @@ -39,20 +42,22 @@ def example_documents() -> List[Document]: ] -@pytest.fixture +@pytest.fixture(scope="module") def embedding_openai() -> Embeddings: from langchain_openai import OpenAIEmbeddings try: + from langchain_openai import OpenAIEmbeddings + return OpenAIEmbeddings( openai_api_key=os.environ["OPENAI_API_KEY"], # type: ignore # noqa model="text-embedding-3-small", ) except Exception: - return ConsistentFakeEmbeddings(DIMENSIONS) + pytest.fail("test_retrievers expects OPENAI_API_KEY in os.environ") -@pytest.fixture +@pytest.fixture(scope="module") def collection() -> Collection: """A Collection with both a Vector and a Full-text Search Index""" client: MongoClient = MongoClient(CONNECTION_STRING) @@ -64,7 +69,7 @@ def collection() -> Collection: clxn.delete_many({}) if not any([VECTOR_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): - index.create_vector_search_index( + create_vector_search_index( collection=clxn, index_name=VECTOR_INDEX_NAME, dimensions=DIMENSIONS, @@ -74,7 +79,7 @@ def collection() -> Collection: ) if not any([SEARCH_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): - index.create_fulltext_search_index( + create_fulltext_search_index( collection=clxn, index_name=SEARCH_INDEX_NAME, field=PAGE_CONTENT_FIELD, @@ -84,12 +89,13 @@ def collection() -> Collection: return clxn -def test_hybrid_retriever( - embedding_openai: Embeddings, +@pytest.fixture(scope="module") +def indexed_vectorstore( collection: Collection, example_documents: List[Document], -) -> None: - """Test basic usage of MongoDBAtlasHybridSearchRetriever""" + embedding_openai: Embeddings, +) -> Generator[MongoDBAtlasVectorSearch, None, None]: + """Return a VectorStore with example document embeddings indexed.""" vectorstore = PatchedMongoDBAtlasVectorSearch( collection=collection, @@ -100,10 +106,29 @@ def test_hybrid_retriever( vectorstore.add_documents(example_documents) - sleep(TIMEOUT) # Wait for documents to be sync'd + yield vectorstore + + vectorstore.collection.delete_many({}) + + +def test_vector_retriever(indexed_vectorstore: PatchedMongoDBAtlasVectorSearch) -> None: + """Test VectorStoreRetriever""" + retriever = indexed_vectorstore.as_retriever() + + query1 = "What was the latest city that I visited?" + results = retriever.invoke(query1) + assert len(results) == 4 + assert "Paris" in results[0].page_content + + query2 = "When was the last time I visited new orleans?" + results = retriever.invoke(query2) + assert "New Orleans" in results[0].page_content + +def test_hybrid_retriever(indexed_vectorstore: PatchedMongoDBAtlasVectorSearch) -> None: + """Test basic usage of MongoDBAtlasHybridSearchRetriever""" retriever = MongoDBAtlasHybridSearchRetriever( - vectorstore=vectorstore, + vectorstore=indexed_vectorstore, search_index_name=SEARCH_INDEX_NAME, top_k=3, ) @@ -119,20 +144,15 @@ def test_hybrid_retriever( def test_fulltext_retriever( - collection: Collection, - example_documents: List[Document], + indexed_vectorstore: PatchedMongoDBAtlasVectorSearch, ) -> None: - """Test result of performing fulltext search + """Test result of performing fulltext search. - Independent of the VectorStore, one adds documents - via MongoDB's Collection API + The Retriever is independent of the VectorStore. + We use it here only to get the Collection, which we know to be indexed. """ - # - collection.insert_many( - [{PAGE_CONTENT_FIELD: doc.page_content} for doc in example_documents] - ) - sleep(TIMEOUT) # Wait for documents to be sync'd + collection: Collection = indexed_vectorstore.collection retriever = MongoDBAtlasFullTextSearchRetriever( collection=collection, @@ -144,33 +164,3 @@ def test_fulltext_retriever( results = retriever.invoke(query) assert "New Orleans" in results[0].page_content assert "score" in results[0].metadata - - -def test_vector_retriever( - embedding_openai: Embeddings, - collection: Collection, - example_documents: List[Document], -) -> None: - """Test VectorStoreRetriever""" - - vectorstore = PatchedMongoDBAtlasVectorSearch( - collection=collection, - embedding=embedding_openai, - index_name=VECTOR_INDEX_NAME, - text_key=PAGE_CONTENT_FIELD, - ) - - vectorstore.add_documents(example_documents) - - sleep(TIMEOUT) # Wait for documents to be sync'd - - retriever = vectorstore.as_retriever() - - query1 = "What was the latest city that I visited?" - results = retriever.invoke(query1) - assert len(results) == 4 - assert "Paris" in results[0].page_content - - query2 = "When was the last time I visited new orleans?" - results = retriever.invoke(query2) - assert "New Orleans" in results[0].page_content diff --git a/libs/mongodb/tests/integration_tests/test_vectorstore_add_delete.py b/libs/mongodb/tests/integration_tests/test_vectorstore_add_delete.py new file mode 100644 index 0000000..3fb4220 --- /dev/null +++ b/libs/mongodb/tests/integration_tests/test_vectorstore_add_delete.py @@ -0,0 +1,211 @@ +"""Test MongoDB Atlas Vector Search functionality.""" + +from __future__ import annotations + +import os +from typing import Any, Dict, List + +import pytest # type: ignore[import-not-found] +from bson import ObjectId +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from pymongo import MongoClient +from pymongo.collection import Collection + +from langchain_mongodb import MongoDBAtlasVectorSearch +from langchain_mongodb.utils import oid_to_str + +from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch + +CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI") +DB_NAME = "langchain_test_db" +INDEX_NAME = "langchain-test-index-vectorstores" +COLLECTION_NAME = "langchain_test_vectorstores" +DIMENSIONS = 5 + + +@pytest.fixture(scope="module") +def collection() -> Collection: + test_client: MongoClient = MongoClient(CONNECTION_STRING) + return test_client[DB_NAME][COLLECTION_NAME] + + +@pytest.fixture(scope="module") +def texts() -> List[str]: + return [ + "Dogs are tough.", + "Cats have fluff.", + "What is a sandwich?", + "That fence is purple.", + ] + + +@pytest.fixture(scope="module") +def trivial_embeddings() -> Embeddings: + return ConsistentFakeEmbeddings(DIMENSIONS) + + +def test_delete( + trivial_embeddings: Embeddings, collection: Any, texts: List[str] +) -> None: + vectorstore = MongoDBAtlasVectorSearch( + collection=collection, + embedding=trivial_embeddings, + index_name="MATCHES_NOTHING", + ) + clxn: Collection = vectorstore.collection + clxn.delete_many({}) + assert clxn.count_documents({}) == 0 + ids = vectorstore.add_texts(texts) + assert clxn.count_documents({}) == len(texts) + + deleted = vectorstore.delete(ids[-2:]) + assert deleted + assert clxn.count_documents({}) == len(texts) - 2 + + new_ids = vectorstore.add_texts(["Pigs eat stuff", "Pigs eat sandwiches"]) + assert set(new_ids).intersection(set(ids)) == set() # new ids will be unique. + assert isinstance(new_ids, list) + assert all(isinstance(i, str) for i in new_ids) + assert len(new_ids) == 2 + assert clxn.count_documents({}) == 4 + + +def test_add_texts( + trivial_embeddings: Embeddings, + collection: Collection, + texts: List[str], +) -> None: + """Tests API of add_texts, focussing on id treatment + + Warning: This is slow because of the number of cases + """ + metadatas: List[Dict[str, Any]] = [ + {"a": 1}, + {"b": 1}, + {"c": 1}, + {"d": 1, "e": 2}, + ] + + vectorstore = PatchedMongoDBAtlasVectorSearch( + collection=collection, + embedding=trivial_embeddings, + index_name=INDEX_NAME, + ) + vectorstore.delete() + + # Case 1. Add texts without ids + provided_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas) + all_docs = list(vectorstore._collection.find({})) + assert all("_id" in doc for doc in all_docs) + docids = set(doc["_id"] for doc in all_docs) + assert all(isinstance(_id, ObjectId) for _id in docids) # + assert set(provided_ids) == set(oid_to_str(oid) for oid in docids) + + # Case 2: Test Document.metadata looks right. i.e. contains _id + search_res = vectorstore.similarity_search_with_score("sandwich", k=1) + doc, score = search_res[0] + assert "_id" in doc.metadata + + # Case 3: Add new ids that are 24-char hex strings + hex_ids = [oid_to_str(ObjectId()) for _ in range(2)] + hex_texts = ["Text for hex_id"] * len(hex_ids) + out_ids = vectorstore.add_texts(texts=hex_texts, ids=hex_ids) + assert set(out_ids) == set(hex_ids) + assert collection.count_documents({}) == len(texts) + len(hex_texts) + assert all( + isinstance(doc["_id"], ObjectId) for doc in vectorstore._collection.find({}) + ) + + # Case 4: Add new ids that cannot be cast to ObjectId + # - We can still index and search on them + str_ids = ["Sandwiches are beautiful,", "..sandwiches are fine."] + str_texts = str_ids # No reason for them to differ + out_ids = vectorstore.add_texts(texts=str_texts, ids=str_ids) + assert set(out_ids) == set(str_ids) + assert collection.count_documents({}) == 8 + res = vectorstore.similarity_search("sandwich", k=8) + assert any(str_ids[0] in doc.metadata["_id"] for doc in res) + + # Case 5: Test adding in multiple batches + batch_size = 2 + batch_ids = [oid_to_str(ObjectId()) for _ in range(2 * batch_size)] + batch_texts = [f"Text for batch text {i}" for i in range(2 * batch_size)] + out_ids = vectorstore.add_texts( + texts=batch_texts, ids=batch_ids, batch_size=batch_size + ) + assert set(out_ids) == set(batch_ids) + assert collection.count_documents({}) == 12 + + # Case 6: _ids in metadata + collection.delete_many({}) + # 6a. Unique _id in metadata, but ids=None + # Will be added as if ids kwarg provided + i = 0 + n = len(texts) + assert len(metadatas) == n + _ids = [str(i) for i in range(n)] + for md in metadatas: + md["_id"] = _ids[i] + i += 1 + returned_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas) + assert returned_ids == ["0", "1", "2", "3"] + assert set(d["_id"] for d in vectorstore._collection.find({})) == set(_ids) + + # 6b. Unique "id", not "_id", but ids=None + # New ids will be assigned + i = 1 + for md in metadatas: + md.pop("_id") + md["id"] = f"{1}" + i += 1 + returned_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas) + assert len(set(returned_ids).intersection(set(_ids))) == 0 + + +def test_add_documents( + collection: Collection, + trivial_embeddings: Embeddings, +) -> None: + """Tests add_documents. + + Note: Does not need indexes so no need to use patient patched vectorstore.""" + vectorstore = MongoDBAtlasVectorSearch( + collection=collection, + embedding=trivial_embeddings, + index_name="MATCHES_NOTHING", + ) + vectorstore.collection.delete_many({}) + # Case 1: No ids + n_docs = 10 + batch_size = 3 + docs = [ + Document(page_content=f"document {i}", metadata={"i": i}) for i in range(n_docs) + ] + result_ids = vectorstore.add_documents(docs, batch_size=batch_size) + assert len(result_ids) == n_docs + assert collection.count_documents({}) == n_docs + + # Case 2: ids + collection.delete_many({}) + n_docs = 10 + batch_size = 3 + docs = [ + Document(page_content=f"document {i}", metadata={"i": i}) for i in range(n_docs) + ] + ids = [str(i) for i in range(n_docs)] + result_ids = vectorstore.add_documents(docs, ids, batch_size=batch_size) + assert len(result_ids) == n_docs + assert set(ids) == set(collection.distinct("_id")) + + # Case 3: Single batch + collection.delete_many({}) + n_docs = 3 + batch_size = 10 + docs = [ + Document(page_content=f"document {i}", metadata={"i": i}) for i in range(n_docs) + ] + ids = [str(i) for i in range(n_docs)] + result_ids = vectorstore.add_documents(docs, ids, batch_size=batch_size) + assert len(result_ids) == n_docs + assert set(ids) == set(collection.distinct("_id")) diff --git a/libs/mongodb/tests/integration_tests/test_vectorstore_from_documents.py b/libs/mongodb/tests/integration_tests/test_vectorstore_from_documents.py new file mode 100644 index 0000000..0be4de4 --- /dev/null +++ b/libs/mongodb/tests/integration_tests/test_vectorstore_from_documents.py @@ -0,0 +1,84 @@ +"""Test MongoDBAtlasVectorSearch.from_documents.""" + +from __future__ import annotations + +import os +from typing import Generator, List + +import pytest # type: ignore[import-not-found] +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from pymongo import MongoClient +from pymongo.collection import Collection + +from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch + +CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI") +DB_NAME = "langchain_test_db" +COLLECTION_NAME = "langchain_test_from_documents" +INDEX_NAME = "langchain-test-index-from-documents" +DIMENSIONS = 5 + + +@pytest.fixture(scope="module") +def collection() -> Generator[Collection, None, None]: + test_client: MongoClient = MongoClient(CONNECTION_STRING) + clxn = test_client[DB_NAME][COLLECTION_NAME] + yield clxn + clxn.delete_many({}) + + +@pytest.fixture(scope="module") +def example_documents() -> List[Document]: + return [ + Document(page_content="Dogs are tough.", metadata={"a": 1}), + Document(page_content="Cats have fluff.", metadata={"b": 1}), + Document(page_content="What is a sandwich?", metadata={"c": 1}), + Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}), + ] + + +@pytest.fixture(scope="module") +def embeddings() -> Embeddings: + return ConsistentFakeEmbeddings(DIMENSIONS) + + +@pytest.fixture(scope="module") +def vectorstore( + collection: Collection, example_documents: List[Document], embeddings: Embeddings +) -> PatchedMongoDBAtlasVectorSearch: + """VectorStore created with a few documents and a trivial embedding model. + + Note: PatchedMongoDBAtlasVectorSearch is MongoDBAtlasVectorSearch in all + but one important feature. It waits until all documents are fully indexed + before returning control to the caller. + """ + vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents( + example_documents, + embedding=embeddings, + collection=collection, + index_name=INDEX_NAME, + ) + return vectorstore + + +def test_default_search( + vectorstore: PatchedMongoDBAtlasVectorSearch, example_documents: List[Document] +) -> None: + """Test end to end construction and search.""" + output = vectorstore.similarity_search("Sandwich", k=1) + assert len(output) == 1 + # Check for the presence of the metadata key + assert any( + [key.page_content == output[0].page_content for key in example_documents] + ) + # Assert no presence of embeddings in results + assert all(["embedding" not in key.metadata for key in output]) + + +def test_search_with_embeddings(vectorstore: PatchedMongoDBAtlasVectorSearch) -> None: + output = vectorstore.similarity_search("Sandwich", k=2, include_embeddings=True) + assert len(output) == 2 + + # Assert embeddings in results + assert all([key.metadata.get("embedding") for key in output]) diff --git a/libs/mongodb/tests/integration_tests/test_vectorstore_from_texts.py b/libs/mongodb/tests/integration_tests/test_vectorstore_from_texts.py new file mode 100644 index 0000000..84bb448 --- /dev/null +++ b/libs/mongodb/tests/integration_tests/test_vectorstore_from_texts.py @@ -0,0 +1,102 @@ +"""Test MongoDBAtlasVectorSearch.from_documents.""" + +from __future__ import annotations + +import os +from typing import Dict, Generator, List + +import pytest # type: ignore[import-not-found] +from langchain_core.embeddings import Embeddings +from pymongo import MongoClient +from pymongo.collection import Collection + +from langchain_mongodb import MongoDBAtlasVectorSearch + +from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch + +CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI") +DB_NAME = "langchain_test_db" +COLLECTION_NAME = "langchain_test_from_texts" +INDEX_NAME = "langchain-test-index-from-texts" +DIMENSIONS = 5 + + +@pytest.fixture(scope="module") +def collection() -> Collection: + test_client: MongoClient = MongoClient(CONNECTION_STRING) + return test_client[DB_NAME][COLLECTION_NAME] + + +@pytest.fixture(scope="module") +def texts() -> List[str]: + return [ + "Dogs are tough.", + "Cats have fluff.", + "What is a sandwich?", + "That fence is purple.", + ] + + +@pytest.fixture(scope="module") +def metadatas() -> List[Dict]: + return [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}] + + +@pytest.fixture(scope="module") +def embeddings() -> Embeddings: + return ConsistentFakeEmbeddings(DIMENSIONS) + + +@pytest.fixture(scope="module") +def vectorstore( + collection: Collection, + texts: List[str], + embeddings: Embeddings, + metadatas: List[dict], +) -> Generator[MongoDBAtlasVectorSearch, None, None]: + """VectorStore created with a few documents and a trivial embedding model. + + Note: PatchedMongoDBAtlasVectorSearch is MongoDBAtlasVectorSearch in all + but one important feature. It waits until all documents are fully indexed + before returning control to the caller. + """ + vectorstore_from_texts = PatchedMongoDBAtlasVectorSearch.from_texts( + texts=texts, + embedding=embeddings, + metadatas=metadatas, + collection=collection, + index_name=INDEX_NAME, + ) + yield vectorstore_from_texts + + vectorstore_from_texts.collection.delete_many({}) + + +def test_search_with_metadatas_and_pre_filter( + vectorstore: PatchedMongoDBAtlasVectorSearch, metadatas: List[Dict] +) -> None: + # Confirm the presence of metadata in output + output = vectorstore.similarity_search("Sandwich", k=1) + assert len(output) == 1 + metakeys = [list(d.keys())[0] for d in metadatas] + assert any([key in output[0].metadata for key in metakeys]) + + +def test_search_filters_all( + vectorstore: PatchedMongoDBAtlasVectorSearch, metadatas: List[Dict] +) -> None: + # Test filtering out + does_not_match_filter = vectorstore.similarity_search( + "Sandwich", k=1, pre_filter={"c": {"$lte": 0}} + ) + assert does_not_match_filter == [] + + +def test_search_pre_filter( + vectorstore: PatchedMongoDBAtlasVectorSearch, metadatas: List[Dict] +) -> None: + # Test filtering with expected output + matches_filter = vectorstore.similarity_search( + "Sandwich", k=3, pre_filter={"c": {"$gt": 0}} + ) + assert len(matches_filter) == 1 diff --git a/libs/mongodb/tests/integration_tests/test_vectorstores.py b/libs/mongodb/tests/integration_tests/test_vectorstores.py deleted file mode 100644 index 3033b20..0000000 --- a/libs/mongodb/tests/integration_tests/test_vectorstores.py +++ /dev/null @@ -1,473 +0,0 @@ -"""Test MongoDB Atlas Vector Search functionality.""" - -from __future__ import annotations - -import os -from time import monotonic, sleep -from typing import Any, Dict, List - -import pytest # type: ignore[import-not-found] -from bson import ObjectId -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings -from pymongo import MongoClient -from pymongo.collection import Collection -from pymongo.errors import OperationFailure - -from langchain_mongodb.index import drop_vector_search_index -from langchain_mongodb.utils import oid_to_str - -from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch - -INDEX_NAME = "langchain-test-index-vectorstores" -INDEX_CREATION_NAME = "langchain-test-index-vectorstores-create-test" -NAMESPACE = "langchain_test_db.langchain_test_vectorstores" -CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI") -DB_NAME, COLLECTION_NAME = NAMESPACE.split(".") -INDEX_COLLECTION_NAME = "langchain_test_vectorstores_index" -INDEX_DB_NAME = "langchain_test_index_db" -DIMENSIONS = 1536 -TIMEOUT = 120.0 -INTERVAL = 0.5 - - -@pytest.fixture -def example_documents() -> List[Document]: - return [ - Document(page_content="Dogs are tough.", metadata={"a": 1}), - Document(page_content="Cats have fluff.", metadata={"b": 1}), - Document(page_content="What is a sandwich?", metadata={"c": 1}), - Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}), - ] - - -def _await_index_deletion(coll: Collection, index_name: str) -> None: - start = monotonic() - try: - drop_vector_search_index(coll, index_name) - except OperationFailure: - # This most likely means an ongoing drop request was made so skip - pass - - while list(coll.list_search_indexes(name=index_name)): - if monotonic() - start > TIMEOUT: - raise TimeoutError(f"Index Name: {index_name} never dropped") - sleep(INTERVAL) - - -def get_collection( - database_name: str = DB_NAME, collection_name: str = COLLECTION_NAME -) -> Collection: - test_client: MongoClient = MongoClient(CONNECTION_STRING) - return test_client[database_name][collection_name] - - -@pytest.fixture() -def collection() -> Collection: - return get_collection() - - -@pytest.fixture -def texts() -> List[str]: - return [ - "Dogs are tough.", - "Cats have fluff.", - "What is a sandwich?", - "That fence is purple.", - ] - - -@pytest.fixture() -def index_collection() -> Collection: - return get_collection(INDEX_DB_NAME, INDEX_COLLECTION_NAME) - - -class TestMongoDBAtlasVectorSearch: - @classmethod - def setup_class(cls) -> None: - # insure the test collection is empty - collection = get_collection() - if collection.count_documents({}): - collection.delete_many({}) # type: ignore[index] - - @classmethod - def teardown_class(cls) -> None: - collection = get_collection() - # delete all the documents in the collection - collection.delete_many({}) # type: ignore[index] - - @pytest.fixture(autouse=True) - def setup(self) -> None: - collection = get_collection() - # delete all the documents in the collection - collection.delete_many({}) # type: ignore[index] - - # delete all indexes on index collection name - _await_index_deletion( - get_collection(INDEX_DB_NAME, INDEX_COLLECTION_NAME), INDEX_CREATION_NAME - ) - - @pytest.fixture - def embeddings(self) -> Embeddings: - try: - from langchain_openai import OpenAIEmbeddings - - return OpenAIEmbeddings( - openai_api_key=os.environ["OPENAI_API_KEY"], # type: ignore # noqa - model="text-embedding-3-small", - ) - except Exception: - return ConsistentFakeEmbeddings(DIMENSIONS) - - def test_from_documents( - self, - embeddings: Embeddings, - collection: Any, - example_documents: List[Document], - ) -> None: - """Test end to end construction and search.""" - vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents( - example_documents, - embedding=embeddings, - collection=collection, - index_name=INDEX_NAME, - ) - output = vectorstore.similarity_search("Sandwich", k=1) - assert len(output) == 1 - # Check for the presence of the metadata key - assert any( - [key.page_content == output[0].page_content for key in example_documents] - ) - - def test_from_documents_no_embedding_return( - self, - embeddings: Embeddings, - collection: Any, - example_documents: List[Document], - ) -> None: - """Test end to end construction and search.""" - vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents( - example_documents, - embedding=embeddings, - collection=collection, - index_name=INDEX_NAME, - ) - output = vectorstore.similarity_search("Sandwich", k=1) - assert len(output) == 1 - # Check for presence of embedding in each document - assert all(["embedding" not in key.metadata for key in output]) - # Check for the presence of the metadata key - assert any( - [key.page_content == output[0].page_content for key in example_documents] - ) - - def test_from_documents_embedding_return( - self, - embeddings: Embeddings, - collection: Any, - example_documents: List[Document], - ) -> None: - """Test end to end construction and search.""" - vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents( - example_documents, - embedding=embeddings, - collection=collection, - index_name=INDEX_NAME, - ) - output = vectorstore.similarity_search("Sandwich", k=1, include_embeddings=True) - assert len(output) == 1 - # Check for presence of embedding in each document - assert all([key.metadata.get("embedding") for key in output]) - # Check for the presence of the metadata key - assert any( - [key.page_content == output[0].page_content for key in example_documents] - ) - - def test_from_texts( - self, embeddings: Embeddings, collection: Collection, texts: List[str] - ) -> None: - vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts( - texts, - embedding=embeddings, - collection=collection, - index_name=INDEX_NAME, - ) - output = vectorstore.similarity_search("Sandwich", k=1) - assert len(output) == 1 - - def test_from_texts_with_metadatas( - self, - embeddings: Embeddings, - collection: Collection, - texts: List[str], - ) -> None: - metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}] - metakeys = ["a", "b", "c", "d", "e"] - vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts( - texts, - embedding=embeddings, - metadatas=metadatas, - collection=collection, - index_name=INDEX_NAME, - ) - output = vectorstore.similarity_search("Sandwich", k=1) - assert len(output) == 1 - # Check for the presence of the metadata key - assert any([key in output[0].metadata for key in metakeys]) - - def test_from_texts_with_metadatas_and_pre_filter( - self, embeddings: Embeddings, collection: Any, texts: List[str] - ) -> None: - metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}] - vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts( - texts, - embedding=embeddings, - metadatas=metadatas, - collection=collection, - index_name=INDEX_NAME, - ) - does_not_match_filter = vectorstore.similarity_search( - "Sandwich", k=1, pre_filter={"c": {"$lte": 0}} - ) - assert does_not_match_filter == [] - - matches_filter = vectorstore.similarity_search( - "Sandwich", k=3, pre_filter={"c": {"$gt": 0}} - ) - assert len(matches_filter) == 1 - - def test_mmr(self, embeddings: Embeddings, collection: Any) -> None: - texts = ["foo", "foo", "fou", "foy"] - vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts( - texts, - embedding=embeddings, - collection=collection, - index_name=INDEX_NAME, - ) - query = "foo" - output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1) - assert len(output) == len(texts) - assert output[0].page_content == "foo" - assert output[1].page_content != "foo" - - def test_retriever( - self, - embeddings: Embeddings, - collection: Any, - example_documents: List[Document], - ) -> None: - """Demonstrate usage and parity of VectorStore similarity_search - with Retriever.invoke.""" - vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents( - example_documents, - embedding=embeddings, - collection=collection, - index_name=INDEX_NAME, - ) - query = "sandwich" - - retriever_default_kwargs = vectorstore.as_retriever() - result_retriever = retriever_default_kwargs.invoke(query) - result_vectorstore = vectorstore.similarity_search(query) - assert all( - [ - result_retriever[i].page_content == result_vectorstore[i].page_content - for i in range(len(result_retriever)) - ] - ) - - def test_include_embeddings( - self, - embeddings: Embeddings, - collection: Any, - example_documents: List[Document], - ) -> None: - """Test explicitly passing vector kwarg matches default.""" - vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents( - documents=example_documents, - embedding=embeddings, - collection=collection, - index_name=INDEX_NAME, - ) - - output_with = vectorstore.similarity_search( - "Sandwich", include_embeddings=True, k=1 - ) - assert vectorstore._embedding_key in output_with[0].metadata - output_without = vectorstore.similarity_search("Sandwich", k=1) - assert vectorstore._embedding_key not in output_without[0].metadata - - def test_delete( - self, embeddings: Embeddings, collection: Any, texts: List[str] - ) -> None: - vectorstore = PatchedMongoDBAtlasVectorSearch( - collection=collection, - embedding=embeddings, - index_name=INDEX_NAME, - ) - clxn: Collection = vectorstore._collection - assert clxn.count_documents({}) == 0 - ids = vectorstore.add_texts(texts) - assert clxn.count_documents({}) == len(texts) - - deleted = vectorstore.delete(ids[-2:]) - assert deleted - assert clxn.count_documents({}) == len(texts) - 2 - - new_ids = vectorstore.add_texts(["Pigs eat stuff", "Pigs eat sandwiches"]) - assert set(new_ids).intersection(set(ids)) == set() # new ids will be unique. - assert isinstance(new_ids, list) - assert all(isinstance(i, str) for i in new_ids) - assert len(new_ids) == 2 - assert clxn.count_documents({}) == 4 - - def test_add_texts( - self, - embeddings: Embeddings, - collection: Collection, - texts: List[str], - ) -> None: - """Tests API of add_texts, focussing on id treatment - - Warning: This is slow because of the number of cases - """ - metadatas: List[Dict[str, Any]] = [ - {"a": 1}, - {"b": 1}, - {"c": 1}, - {"d": 1, "e": 2}, - ] - - vectorstore = PatchedMongoDBAtlasVectorSearch( - collection=collection, embedding=embeddings, index_name=INDEX_NAME - ) - - # Case 1. Add texts without ids - provided_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas) - all_docs = list(vectorstore._collection.find({})) - assert all("_id" in doc for doc in all_docs) - docids = set(doc["_id"] for doc in all_docs) - assert all(isinstance(_id, ObjectId) for _id in docids) # - assert set(provided_ids) == set(oid_to_str(oid) for oid in docids) - - # Case 2: Test Document.metadata looks right. i.e. contains _id - search_res = vectorstore.similarity_search_with_score("sandwich", k=1) - doc, score = search_res[0] - assert "_id" in doc.metadata - - # Case 3: Add new ids that are 24-char hex strings - hex_ids = [oid_to_str(ObjectId()) for _ in range(2)] - hex_texts = ["Text for hex_id"] * len(hex_ids) - out_ids = vectorstore.add_texts(texts=hex_texts, ids=hex_ids) - assert set(out_ids) == set(hex_ids) - assert collection.count_documents({}) == len(texts) + len(hex_texts) - assert all( - isinstance(doc["_id"], ObjectId) for doc in vectorstore._collection.find({}) - ) - - # Case 4: Add new ids that cannot be cast to ObjectId - # - We can still index and search on them - str_ids = ["Sandwiches are beautiful,", "..sandwiches are fine."] - str_texts = str_ids # No reason for them to differ - out_ids = vectorstore.add_texts(texts=str_texts, ids=str_ids) - assert set(out_ids) == set(str_ids) - assert collection.count_documents({}) == 8 - res = vectorstore.similarity_search("sandwich", k=8) - assert any(str_ids[0] in doc.metadata["_id"] for doc in res) - - # Case 5: Test adding in multiple batches - batch_size = 2 - batch_ids = [oid_to_str(ObjectId()) for _ in range(2 * batch_size)] - batch_texts = [f"Text for batch text {i}" for i in range(2 * batch_size)] - out_ids = vectorstore.add_texts( - texts=batch_texts, ids=batch_ids, batch_size=batch_size - ) - assert set(out_ids) == set(batch_ids) - assert collection.count_documents({}) == 12 - - # Case 6: _ids in metadata - collection.delete_many({}) - # 6a. Unique _id in metadata, but ids=None - # Will be added as if ids kwarg provided - i = 0 - n = len(texts) - assert len(metadatas) == n - _ids = [str(i) for i in range(n)] - for md in metadatas: - md["_id"] = _ids[i] - i += 1 - returned_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas) - assert returned_ids == ["0", "1", "2", "3"] - assert set(d["_id"] for d in vectorstore._collection.find({})) == set(_ids) - - # 6b. Unique "id", not "_id", but ids=None - # New ids will be assigned - i = 1 - for md in metadatas: - md.pop("_id") - md["id"] = f"{1}" - i += 1 - returned_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas) - assert len(set(returned_ids).intersection(set(_ids))) == 0 - - def test_add_documents( - self, - embeddings: Embeddings, - collection: Collection, - ) -> None: - """Tests add_documents.""" - vectorstore = PatchedMongoDBAtlasVectorSearch( - collection=collection, embedding=embeddings, index_name=INDEX_NAME - ) - - # Case 1: No ids - n_docs = 10 - batch_size = 3 - docs = [ - Document(page_content=f"document {i}", metadata={"i": i}) - for i in range(n_docs) - ] - result_ids = vectorstore.add_documents(docs, batch_size=batch_size) - assert len(result_ids) == n_docs - assert collection.count_documents({}) == n_docs - - # Case 2: ids - collection.delete_many({}) - n_docs = 10 - batch_size = 3 - docs = [ - Document(page_content=f"document {i}", metadata={"i": i}) - for i in range(n_docs) - ] - ids = [str(i) for i in range(n_docs)] - result_ids = vectorstore.add_documents(docs, ids, batch_size=batch_size) - assert len(result_ids) == n_docs - assert set(ids) == set(collection.distinct("_id")) - - # Case 3: Single batch - collection.delete_many({}) - n_docs = 3 - batch_size = 10 - docs = [ - Document(page_content=f"document {i}", metadata={"i": i}) - for i in range(n_docs) - ] - ids = [str(i) for i in range(n_docs)] - result_ids = vectorstore.add_documents(docs, ids, batch_size=batch_size) - assert len(result_ids) == n_docs - assert set(ids) == set(collection.distinct("_id")) - - def test_index_creation( - self, embeddings: Embeddings, index_collection: Any - ) -> None: - vectorstore = PatchedMongoDBAtlasVectorSearch( - index_collection, embedding=embeddings, index_name=INDEX_CREATION_NAME - ) - vectorstore.create_vector_search_index(dimensions=1536) - - def test_index_update(self, embeddings: Embeddings, index_collection: Any) -> None: - vectorstore = PatchedMongoDBAtlasVectorSearch( - index_collection, embedding=embeddings, index_name=INDEX_CREATION_NAME - ) - vectorstore.create_vector_search_index(dimensions=1536) - vectorstore.create_vector_search_index(dimensions=1536, update=True) diff --git a/libs/mongodb/tests/unit_tests/test_index.py b/libs/mongodb/tests/unit_tests/test_index.py index acb6cb5..2c73791 100644 --- a/libs/mongodb/tests/unit_tests/test_index.py +++ b/libs/mongodb/tests/unit_tests/test_index.py @@ -1,6 +1,3 @@ -"""Search index commands are only supported on Atlas Clusters >=M10""" - -import os from time import sleep import pytest @@ -10,18 +7,14 @@ from langchain_mongodb import index -DIMENSION = 10 -TIMEOUT = 10 +DIMENSION = 5 +TIMEOUT = 120 @pytest.fixture def collection() -> Collection: - """Depending on uri, this could point to any type of cluster. - - For unit tests, MONGODB_URI should be localhost, None, or Atlas cluster List: """Patched insert_texts that waits for data to be indexed before returning""" ids_inserted = super().bulk_embed_and_insert_texts(texts, metadatas, ids) - start = monotonic() - while len(ids_inserted) != len(self.similarity_search("sandwich")) and ( - monotonic() - start <= TIMEOUT - ): - sleep(INTERVAL) - return ids_inserted - - def create_vector_search_index( - self, - dimensions: int, - filters: Optional[List[str]] = None, - update: bool = False, - ) -> None: - result = super().create_vector_search_index( - dimensions=dimensions, filters=filters, update=update - ) + n_docs = self.collection.count_documents({}) start = monotonic() while monotonic() - start <= TIMEOUT: - if indexes := list( - self._collection.list_search_indexes(name=self._index_name) - ): - if indexes[0].get("status") == "READY": - return result - sleep(INTERVAL) + if len(self.similarity_search("sandwich", k=n_docs)) == n_docs: + return ids_inserted + else: + sleep(INTERVAL) + raise TimeoutError(f"Failed to embed, insert, and index texts in {TIMEOUT}s.") class ConsistentFakeEmbeddings(Embeddings):