Skip to content

Commit

Permalink
INTPYTHON-459 Refactor handling of client into a pytest fixture (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
blink1073 authored Dec 19, 2024
1 parent 2e3ab1b commit 7ccf149
Show file tree
Hide file tree
Showing 13 changed files with 29 additions and 51 deletions.
12 changes: 12 additions & 0 deletions libs/langchain-mongodb/tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
from typing import List

import pytest
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.documents import Document
from pymongo import MongoClient


@pytest.fixture(scope="session")
Expand All @@ -11,3 +13,13 @@ def technical_report_pages() -> List[Document]:
loader = PyPDFLoader("https://arxiv.org/pdf/2303.08774.pdf")
pages = loader.load()
return pages


@pytest.fixture(scope="session")
def connection_string() -> str:
return os.environ["MONGODB_URI"]


@pytest.fixture(scope="session")
def client(connection_string: str) -> MongoClient:
return MongoClient(connection_string)
3 changes: 1 addition & 2 deletions libs/langchain-mongodb/tests/integration_tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@ def random_string() -> str:


@pytest.fixture(scope="module")
def collection() -> Collection:
def collection(client: MongoClient) -> Collection:
"""A Collection with both a Vector and a Full-text Search Index"""
client: MongoClient = MongoClient(CONN_STRING)
if COLLECTION not in client[DATABASE].list_collection_names():
clxn = client[DATABASE].create_collection(COLLECTION)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from ..utils import PatchedMongoDBAtlasVectorSearch

CONNECTION_STRING = os.environ.get("MONGODB_URI")
DB_NAME = "langchain_test_db"
COLLECTION_NAME = "langchain_test_chain_example"
INDEX_NAME = "langchain-test-chain-example-vector-index"
Expand All @@ -26,9 +25,8 @@


@pytest.fixture
def collection() -> Collection:
def collection(client: MongoClient) -> Collection:
"""A Collection with both a Vector and a Full-text Search Index"""
client: MongoClient = MongoClient(CONNECTION_STRING)
if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
import os

from langchain.memory import ConversationBufferMemory # type: ignore[import-not-found]
from langchain_core.messages import message_to_dict
Expand All @@ -9,11 +8,8 @@
DATABASE = "langchain_test_db"
COLLECTION = "langchain_test_chat"

# Replace these with your mongodb connection string
connection_string = os.environ.get("MONGODB_URI", "")


def test_memory_with_message_store() -> None:
def test_memory_with_message_store(connection_string: str) -> None:
"""Test the memory with a message store."""
# setup MongoDB as a message store
message_history = MongoDBChatMessageHistory(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
import os
from typing import List

from langchain_core.documents import Document
from pymongo import MongoClient

from langchain_mongodb.docstores import MongoDBDocStore

CONNECTION_STRING = os.environ.get("MONGODB_URI")
DB_NAME = "langchain_test_db"
COLLECTION_NAME = "langchain_test_docstore"


def test_docstore(technical_report_pages: List[Document]) -> None:
client: MongoClient = MongoClient(CONNECTION_STRING)
def test_docstore(client: MongoClient, technical_report_pages: List[Document]) -> None:
db = client[DB_NAME]
db.drop_collection(COLLECTION_NAME)
clxn = db[COLLECTION_NAME]
Expand Down
5 changes: 1 addition & 4 deletions libs/langchain-mongodb/tests/integration_tests/test_index.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from typing import Generator, List, Optional

import pytest
Expand All @@ -18,10 +17,8 @@


@pytest.fixture
def collection() -> Generator:
def collection(client: MongoClient) -> Generator:
"""Depending on uri, this could point to any type of cluster."""
uri = os.environ.get("MONGODB_URI")
client: MongoClient = MongoClient(uri)
if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from datetime import datetime
from unittest.mock import patch

Expand All @@ -8,25 +7,22 @@

from langchain_mongodb.indexes import MongoDBRecordManager

CONNECTION_STRING = os.environ["MONGODB_URI"]
DB_NAME = "langchain_test_db"
COLLECTION_NAME = "langchain_test_docstore"
NAMESPACE = f"{DB_NAME}.{COLLECTION_NAME}"


@pytest.fixture
def manager() -> MongoDBRecordManager:
def manager(client: MongoClient) -> MongoDBRecordManager:
"""Initialize the test MongoDB and yield the DocumentManager instance."""
client: MongoClient = MongoClient(CONNECTION_STRING)
collection = client[DB_NAME][COLLECTION_NAME]
document_manager = MongoDBRecordManager(collection=collection)
return document_manager


@pytest_asyncio.fixture
async def amanager() -> MongoDBRecordManager:
async def amanager(client: MongoClient) -> MongoDBRecordManager:
"""Initialize the test MongoDB and yield the DocumentManager instance."""
client: MongoClient = MongoClient(CONNECTION_STRING)
collection = client[DB_NAME][COLLECTION_NAME]
document_manager = MongoDBRecordManager(collection=collection)
return document_manager
Expand Down
7 changes: 1 addition & 6 deletions libs/langchain-mongodb/tests/integration_tests/test_mmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from __future__ import annotations

import os

import pytest # type: ignore[import-not-found]
from langchain_core.embeddings import Embeddings
from pymongo import MongoClient
Expand All @@ -15,17 +13,14 @@

from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch

CONNECTION_STRING = os.environ.get("MONGODB_URI")
DB_NAME = "langchain_test_db"
COLLECTION_NAME = "langchain_test_vectorstores"
INDEX_NAME = "langchain-test-index-vectorstores"
DIMENSIONS = 5


@pytest.fixture()
def collection() -> Collection:
client: MongoClient = MongoClient(CONNECTION_STRING)

def collection(client: MongoClient) -> Collection:
if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch

CONNECTION_STRING = os.environ.get("MONGODB_URI")
DB_NAME = "langchain_test_db"
COLLECTION_NAME = "langchain_test_parent_document_combined"
VECTOR_INDEX_NAME = "langchain-test-parent-document-vector-index"
Expand All @@ -41,11 +40,13 @@ def embedding_model() -> Embeddings:


def test_1clxn_retriever(
technical_report_pages: List[Document], embedding_model: Embeddings
connection_string: str,
technical_report_pages: List[Document],
embedding_model: Embeddings,
) -> None:
# Setup
client: MongoClient = MongoClient(
CONNECTION_STRING,
connection_string,
driver=DriverInfo(name="langchain", version=version("langchain-mongodb")),
)
db = client[DB_NAME]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from ..utils import PatchedMongoDBAtlasVectorSearch

CONNECTION_STRING = os.environ.get("MONGODB_URI")
DB_NAME = "langchain_test_db"
COLLECTION_NAME = "langchain_test_retrievers"
VECTOR_INDEX_NAME = "vector_index"
Expand Down Expand Up @@ -57,9 +56,8 @@ def embedding_openai() -> Embeddings:


@pytest.fixture(scope="module")
def collection() -> Collection:
def collection(client: MongoClient) -> Collection:
"""A Collection with both a Vector and a Full-text Search Index"""
client: MongoClient = MongoClient(CONNECTION_STRING)
if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import os
from typing import Any, Dict, List

import pytest # type: ignore[import-not-found]
Expand All @@ -17,17 +16,15 @@

from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch

CONNECTION_STRING = os.environ.get("MONGODB_URI")
DB_NAME = "langchain_test_db"
INDEX_NAME = "langchain-test-index-vectorstores"
COLLECTION_NAME = "langchain_test_vectorstores"
DIMENSIONS = 5


@pytest.fixture(scope="module")
def collection() -> Collection:
test_client: MongoClient = MongoClient(CONNECTION_STRING)
return test_client[DB_NAME][COLLECTION_NAME]
def collection(client: MongoClient) -> Collection:
return client[DB_NAME][COLLECTION_NAME]


@pytest.fixture(scope="module")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import os
from typing import List

import pytest # type: ignore[import-not-found]
Expand All @@ -17,17 +16,14 @@

from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch

CONNECTION_STRING = os.environ.get("MONGODB_URI")
DB_NAME = "langchain_test_db"
COLLECTION_NAME = "langchain_test_from_documents"
INDEX_NAME = "langchain-test-index-from-documents"
DIMENSIONS = 5


@pytest.fixture(scope="module")
def collection() -> Collection:
client: MongoClient = MongoClient(CONNECTION_STRING)

def collection(client: MongoClient) -> Collection:
if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import os
from typing import Dict, Generator, List

import pytest # type: ignore[import-not-found]
Expand All @@ -17,17 +16,14 @@

from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch

CONNECTION_STRING = os.environ.get("MONGODB_URI")
DB_NAME = "langchain_test_db"
COLLECTION_NAME = "langchain_test_from_texts"
INDEX_NAME = "langchain-test-index-from-texts"
DIMENSIONS = 5


@pytest.fixture(scope="module")
def collection() -> Collection:
client: MongoClient = MongoClient(CONNECTION_STRING)

def collection(client: MongoClient) -> Collection:
if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
else:
Expand Down

0 comments on commit 7ccf149

Please sign in to comment.