Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

INTPYTHON-395 Test performance #4

Merged
merged 10 commits into from
Oct 30, 2024
96 changes: 31 additions & 65 deletions libs/mongodb/langchain_mongodb/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,11 @@
from typing import Any, Callable, Dict, List, Optional

from pymongo.collection import Collection
from pymongo.errors import OperationFailure
from pymongo.operations import SearchIndexModel

logger = logging.getLogger(__file__)


def _search_index_error_message() -> str:
return (
"Search index operations are not currently available on shared clusters, "
"such as MO. They require dedicated clusters >= M10. "
"You may still perform vector search. "
"You simply must set up indexes manually. Follow the instructions here: "
"https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/"
)


def _vector_search_index_definition(
dimensions: int,
path: str,
Expand Down Expand Up @@ -71,22 +60,19 @@ def create_vector_search_index(
"""
logger.info("Creating Search Index %s on %s", index_name, collection.name)

try:
result = collection.create_search_index(
SearchIndexModel(
definition=_vector_search_index_definition(
dimensions=dimensions,
path=path,
similarity=similarity,
filters=filters,
**kwargs,
),
name=index_name,
type="vectorSearch",
)
result = collection.create_search_index(
SearchIndexModel(
definition=_vector_search_index_definition(
dimensions=dimensions,
path=path,
similarity=similarity,
filters=filters,
**kwargs,
),
name=index_name,
type="vectorSearch",
)
except OperationFailure as e:
raise OperationFailure(_search_index_error_message()) from e
)

if wait_until_complete:
_wait_for_predicate(
Expand Down Expand Up @@ -114,12 +100,7 @@ def drop_vector_search_index(
logger.info(
"Dropping Search Index %s from Collection: %s", index_name, collection.name
)
try:
collection.drop_search_index(index_name)
except OperationFailure as e:
if "CommandNotSupported" in str(e):
raise OperationFailure(_search_index_error_message()) from e
# else this most likely means an ongoing drop request was made so skip
collection.drop_search_index(index_name)
if wait_until_complete:
_wait_for_predicate(
predicate=lambda: len(list(collection.list_search_indexes())) == 0,
Expand Down Expand Up @@ -155,24 +136,19 @@ def update_vector_search_index(
until search index is ready.
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
"""

logger.info(
"Updating Search Index %s from Collection: %s", index_name, collection.name
)
try:
collection.update_search_index(
name=index_name,
definition=_vector_search_index_definition(
dimensions=dimensions,
path=path,
similarity=similarity,
filters=filters,
**kwargs,
),
)
except OperationFailure as e:
raise OperationFailure(_search_index_error_message()) from e

collection.update_search_index(
name=index_name,
definition=_vector_search_index_definition(
dimensions=dimensions,
path=path,
similarity=similarity,
filters=filters,
**kwargs,
),
)
if wait_until_complete:
_wait_for_predicate(
predicate=lambda: _is_index_ready(collection, index_name),
Expand All @@ -193,12 +169,7 @@ def _is_index_ready(collection: Collection, index_name: str) -> bool:
Returns:
bool : True if the index is present and READY false otherwise
"""
try:
search_indexes = collection.list_search_indexes(index_name)
except OperationFailure as e:
raise OperationFailure(_search_index_error_message()) from e

for index in search_indexes:
for index in collection.list_search_indexes(index_name):
if index["type"] == "vectorSearch" and index["status"] == "READY":
return True
return False
Expand Down Expand Up @@ -248,19 +219,14 @@ def create_fulltext_search_index(
definition = {
"mappings": {"dynamic": False, "fields": {field: [{"type": "string"}]}}
}

try:
result = collection.create_search_index(
SearchIndexModel(
definition=definition,
name=index_name,
type="search",
**kwargs,
)
result = collection.create_search_index(
SearchIndexModel(
definition=definition,
name=index_name,
type="search",
**kwargs,
)
except OperationFailure as e:
raise OperationFailure(_search_index_error_message()) from e

)
if wait_until_complete:
_wait_for_predicate(
predicate=lambda: _is_index_ready(collection, index_name),
Expand Down
5 changes: 3 additions & 2 deletions libs/mongodb/langchain_mongodb/retrievers/full_text_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class MongoDBAtlasFullTextSearchRetriever(BaseRetriever):
"""Number of documents to return. Default is no limit"""
filter: Optional[Dict[str, Any]] = None
"""(Optional) List of MQL match expression comparing an indexed field"""
show_embeddings: float = False
"""If true, returned Document metadata will include vectors"""
include_scores: bool = True
"""If True, include scores that provide measure of relative relevance"""

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
Expand All @@ -45,6 +45,7 @@ def _get_relevant_documents(
index_name=self.search_index_name,
limit=self.top_k,
filter=self.filter,
include_scores=self.include_scores,
)

# Execution
Expand Down
17 changes: 15 additions & 2 deletions libs/mongodb/langchain_mongodb/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,14 @@ def __init__(
def embeddings(self) -> Embeddings:
return self._embedding

@property
def collection(self) -> Collection:
    """The pymongo ``Collection`` backing this vector store."""
    return self._collection

@collection.setter
def collection(self, value: Collection) -> None:
    # Rebind the backing collection; subsequent operations use ``value``.
    self._collection = value

def _select_relevance_score_fn(self) -> Callable[[float], float]:
scoring: dict[str, Callable] = {
"euclidean": self._euclidean_relevance_score_fn,
Expand Down Expand Up @@ -761,6 +769,7 @@ def create_vector_search_index(
dimensions: int,
filters: Optional[List[str]] = None,
update: bool = False,
wait_until_complete: Optional[float] = None,
) -> None:
"""Creates a MongoDB Atlas vectorSearch index for the VectorStore

Expand All @@ -774,8 +783,11 @@ def create_vector_search_index(
filters (Optional[List[Dict[str, str]]], optional): additional filters
for index definition.
Defaults to None.
update (bool, optional): Updates existing vectorSearch index.
Defaults to False.
update (Optional[bool]): Updates existing vectorSearch index.
Defaults to False.
wait_until_complete (Optional[float]): If given, a TimeoutError is raised
if search index is not ready after this number of seconds.
If not given, the default, operation will not wait.
"""
try:
self._collection.database.create_collection(self._collection.name)
Expand All @@ -793,4 +805,5 @@ def create_vector_search_index(
path=self._embedding_key,
similarity=self._relevance_score_fn,
filters=filters or [],
wait_until_complete=wait_until_complete,
) # type: ignore [operator]
6 changes: 1 addition & 5 deletions libs/mongodb/tests/integration_tests/test_chain_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import os
from time import sleep

import pytest # type: ignore[import-not-found]
from langchain_core.documents import Document
Expand All @@ -20,7 +19,7 @@
CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI")
DB_NAME = "langchain_test_db"
COLLECTION_NAME = "langchain_test_chain_example"
INDEX_NAME = "vector_index"
INDEX_NAME = "langchain-test-chain-example-vector-index"
DIMENSIONS = 1536
TIMEOUT = 60.0
INTERVAL = 0.5
Expand Down Expand Up @@ -88,9 +87,6 @@ def test_chain(
]
vectorstore.add_texts(texts)

# Give the index time to build (For CI)
sleep(TIMEOUT)

query = "In the United States, what city did I visit last?"
# One can do vector search on the vector store, using its various search types.
k = len(texts)
Expand Down
87 changes: 70 additions & 17 deletions libs/mongodb/tests/integration_tests/test_index.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Search index commands are only supported on Atlas Clusters >=M10"""

import os
from typing import Generator, List, Optional

import pytest
from pymongo import MongoClient
from pymongo.collection import Collection

from langchain_mongodb import index
from langchain_mongodb import MongoDBAtlasVectorSearch, index

from ..utils import ConsistentFakeEmbeddings

DB_NAME = "langchain_test_index_db"
COLLECTION_NAME = "test_index"
Expand All @@ -22,13 +22,13 @@ def collection() -> Generator:
"""Depending on uri, this could point to any type of cluster."""
uri = os.environ.get("MONGODB_ATLAS_URI")
client: MongoClient = MongoClient(uri)
client[DB_NAME].create_collection(COLLECTION_NAME)
clxn = client[DB_NAME][COLLECTION_NAME]
clxn.insert_one({"foo": "bar"})
yield clxn
clxn.drop()


def test_search_index_commands(collection: Collection) -> None:
def test_search_index_drop_add_delete_commands(collection: Collection) -> None:
index_name = VECTOR_INDEX_NAME
dimensions = DIMENSIONS
path = "embedding"
Expand Down Expand Up @@ -58,26 +58,79 @@ def test_search_index_commands(collection: Collection) -> None:
assert len(indexes) == 1
assert indexes[0]["name"] == index_name

new_similarity = "euclidean"
index.drop_vector_search_index(
collection, index_name, wait_until_complete=wait_until_complete
)

indexes = list(collection.list_search_indexes())
assert len(indexes) == 0


@pytest.mark.skip("collection.update_vector_search_index requires [CLOUDP-275518]")
def test_search_index_update_vector_search_index(collection: Collection) -> None:
index_name = "INDEX_TO_UPDATE"
similarity_orig = "cosine"
similarity_new = "euclidean"

# Create another index
index.create_vector_search_index(
collection=collection,
index_name=index_name,
dimensions=DIMENSIONS,
path="embedding",
similarity=similarity_orig,
wait_until_complete=TIMEOUT,
)

assert index._is_index_ready(collection, index_name)
indexes = list(collection.list_search_indexes())
assert len(indexes) == 1
assert indexes[0]["name"] == index_name
assert indexes[0]["latestDefinition"]["fields"][0]["similarity"] == similarity_orig

# Update the index and test new similarity
index.update_vector_search_index(
collection,
index_name,
DIMENSIONS,
"embedding",
new_similarity,
filters=[],
wait_until_complete=wait_until_complete,
collection=collection,
index_name=index_name,
dimensions=DIMENSIONS,
path="embedding",
similarity=similarity_new,
wait_until_complete=TIMEOUT,
)

assert index._is_index_ready(collection, index_name)
indexes = list(collection.list_search_indexes())
assert len(indexes) == 1
assert indexes[0]["name"] == index_name
assert indexes[0]["latestDefinition"]["fields"][0]["similarity"] == new_similarity
assert indexes[0]["latestDefinition"]["fields"][0]["similarity"] == similarity_new

index.drop_vector_search_index(
collection, index_name, wait_until_complete=wait_until_complete

def test_vectorstore_create_vector_search_index(collection: Collection) -> None:
"""Tests vectorstore wrapper around index command."""

# Set up using the index module's api
if len(list(collection.list_search_indexes())) != 0:
index.drop_vector_search_index(
collection, VECTOR_INDEX_NAME, wait_until_complete=TIMEOUT
)

# Test MongoDBAtlasVectorSearch's API
vectorstore = MongoDBAtlasVectorSearch(
collection=collection,
embedding=ConsistentFakeEmbeddings(),
index_name=VECTOR_INDEX_NAME,
)

vectorstore.create_vector_search_index(
dimensions=DIMENSIONS, wait_until_complete=TIMEOUT
)

assert index._is_index_ready(collection, VECTOR_INDEX_NAME)
indexes = list(collection.list_search_indexes())
assert len(indexes) == 0
assert len(indexes) == 1
assert indexes[0]["name"] == VECTOR_INDEX_NAME

# Tear down using the index module's api
index.drop_vector_search_index(
collection, VECTOR_INDEX_NAME, wait_until_complete=TIMEOUT
)
Loading
Loading