Commit 0d1b32a: Use session scoped fixtures instead of static variables

cbornet committed Jul 22, 2024 (parent: 1012c3e)
Showing 10 changed files with 105 additions and 210 deletions.
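For context on the pattern this commit adopts: a pytest fixture declared with scope="session" is created once per test run, shared by every test that requests it, and torn down after the last test, so shared resources no longer need module-level state or a pytest_sessionfinish hook. A minimal sketch of the idea, using a hypothetical FakeStore stand-in rather than the repository's real test stores:

import pytest


class FakeStore:
    """Hypothetical stand-in for a store such as LocalCassandraTestStore."""

    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        self.closed = True


@pytest.fixture(scope="session")
def store() -> FakeStore:
    s = FakeStore()  # built once for the whole test session
    yield s          # every test requesting `store` reuses this instance
    s.close()        # teardown runs once, after the final test

Compared with the removed status dictionary and pytest_sessionfinish hook in conftest.py below, this keeps creation and cleanup of the shared resource in one place.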
libs/colbert/pyproject.toml (7 additions, 3 deletions)
@@ -17,10 +17,14 @@ torch = "2.2.1"
cassio = "~0.1.7"
pydantic = "^2.7.1"

# Workaround for https://github.com/pytorch/pytorch/pull/127921
# Remove when we upgrade to pytorch 2.4
setuptools = { version = "^71.1.0", python = ">=3.12" }


[tool.poetry.group.test.dependencies]
ragstack-ai-tests-utils = { path = "../tests-utils", develop = true }
pytest-asyncio = "^0.23.6"

[tool.poetry.group.dev.dependencies]
setuptools = "70.0.0"

[tool.pytest.ini_options]
asyncio_mode = "auto"
libs/colbert/tests/integration_tests/conftest.py (16 additions, 20 deletions)
@@ -1,28 +1,24 @@
import pytest
from cassandra.cluster import Session
from ragstack_tests_utils import AstraDBTestStore, LocalCassandraTestStore

status = {
"local_cassandra_test_store": None,
"astradb_test_store": None,
}

@pytest.fixture(scope="session")
def cassandra() -> LocalCassandraTestStore:
store = LocalCassandraTestStore()
yield store
if store.docker_container:
store.docker_container.stop()

def get_local_cassandra_test_store():
if not status["local_cassandra_test_store"]:
status["local_cassandra_test_store"] = LocalCassandraTestStore()
return status["local_cassandra_test_store"]

@pytest.fixture(scope="session")
def astra_db() -> AstraDBTestStore:
return AstraDBTestStore()

def get_astradb_test_store():
if not status["astradb_test_store"]:
status["astradb_test_store"] = AstraDBTestStore()
return status["astradb_test_store"]


@pytest.hookimpl()
def pytest_sessionfinish():
if (
status["local_cassandra_test_store"]
and status["local_cassandra_test_store"].docker_container
):
status["local_cassandra_test_store"].docker_container.stop()
@pytest.fixture()
def session(request) -> Session:
test_store = request.getfixturevalue(request.param)
session = test_store.create_cassandra_session()
session.default_timeout = 180
return session
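The new session fixture pairs with pytest's indirect parametrization: when a test is marked @pytest.mark.parametrize("session", ["cassandra", "astra_db"], indirect=["session"]), each parameter value arrives in the fixture as request.param, and request.getfixturevalue() resolves that name to the matching session-scoped store fixture. A hedged usage sketch; the test body and query are illustrative, not taken from the repository:

import pytest
from cassandra.cluster import Session


@pytest.mark.parametrize("session", ["cassandra", "astra_db"], indirect=["session"])
def test_runs_against_both_backends(session: Session) -> None:
    # `session` is already connected to whichever backend the parameter named,
    # with the 180-second default timeout set by the fixture above.
    row = session.execute("SELECT release_version FROM system.local").one()
    assert row is not None

This is what lets the session setup (create_cassandra_session, default_timeout) live in one fixture instead of being repeated in every test, as the diffs below show.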
libs/colbert/tests/integration_tests/test_database.py (5 additions, 30 deletions)
@@ -1,27 +1,11 @@
import pytest
from cassandra.cluster import Session
from ragstack_colbert import CassandraDatabase, Chunk
from ragstack_tests_utils import TestData

from tests.integration_tests.conftest import (
get_astradb_test_store,
get_local_cassandra_test_store,
)


@pytest.fixture()
def cassandra():
return get_local_cassandra_test_store()


@pytest.fixture()
def astra_db():
return get_astradb_test_store()


@pytest.mark.parametrize("vector_store", ["cassandra", "astra_db"])
def test_database_sync(request, vector_store: str):
vector_store = request.getfixturevalue(vector_store)

@pytest.mark.parametrize("session", ["cassandra", "astra_db"], indirect=["session"])
def test_database_sync(session: Session):
doc_id = "earth_doc_id"

chunk_0 = Chunk(
@@ -40,9 +24,6 @@ def test_database_sync(request, vector_store: str):
embedding=TestData.renewable_energy_embedding(),
)

session = vector_store.create_cassandra_session()
session.default_timeout = 180

database = CassandraDatabase.from_session(
keyspace="default_keyspace",
table_name="test_database_sync",
@@ -61,11 +42,8 @@ def test_database_sync(request, vector_store: str):
assert result


@pytest.mark.parametrize("vector_store", ["cassandra", "astra_db"])
@pytest.mark.asyncio()
async def test_database_async(request, vector_store: str):
vector_store = request.getfixturevalue(vector_store)

@pytest.mark.parametrize("session", ["cassandra", "astra_db"], indirect=["session"])
async def test_database_async(session: Session):
doc_id = "earth_doc_id"

chunk_0 = Chunk(
Expand All @@ -84,9 +62,6 @@ async def test_database_async(request, vector_store: str):
embedding=TestData.renewable_energy_embedding(),
)

session = vector_store.create_cassandra_session()
session.default_timeout = 180

database = CassandraDatabase.from_session(
keyspace="default_keyspace",
table_name="test_database_async",
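One detail worth noting in this file: the explicit @pytest.mark.asyncio() marker is dropped from test_database_async because the same commit sets asyncio_mode = "auto" under [tool.pytest.ini_options], which makes pytest-asyncio collect every async def test automatically. A minimal self-contained sketch of that behaviour; fetch_chunk and test_async_lookup are hypothetical, not part of the repository:

import asyncio


async def fetch_chunk(doc_id: str) -> str:
    # Hypothetical coroutine standing in for a real asynchronous database call.
    await asyncio.sleep(0)
    return doc_id


# With asyncio_mode = "auto", pytest-asyncio treats any `async def test_*`
# function as an asyncio test; no per-test marker is required.
async def test_async_lookup() -> None:
    assert await fetch_chunk("earth_doc_id") == "earth_doc_id"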
libs/colbert/tests/integration_tests/test_embedding_retrieval.py (3 additions, 21 deletions)
@@ -1,32 +1,17 @@
import logging

import pytest
from cassandra.cluster import Session
from ragstack_colbert import (
CassandraDatabase,
ColbertEmbeddingModel,
ColbertVectorStore,
)
from ragstack_tests_utils import TestData

from tests.integration_tests.conftest import (
get_astradb_test_store,
get_local_cassandra_test_store,
)


@pytest.fixture()
def cassandra():
return get_local_cassandra_test_store()


@pytest.fixture()
def astra_db():
return get_astradb_test_store()


@pytest.mark.parametrize("vector_store", ["cassandra", "astra_db"])
def test_embedding_cassandra_retriever(request, vector_store: str):
vector_store = request.getfixturevalue(vector_store)
@pytest.mark.parametrize("session", ["cassandra", "astra_db"], indirect=["session"])
def test_embedding_cassandra_retriever(session: Session):
narrative = TestData.marine_animals_text()

# Define the desired chunk size and overlap size
@@ -53,9 +38,6 @@ def chunk_texts(text, chunk_size, overlap_size):

doc_id = "marine_animals"

session = vector_store.create_cassandra_session()
session.default_timeout = 180

database = CassandraDatabase.from_session(
keyspace="default_keyspace",
table_name="test_embedding_cassandra_retriever",
libs/knowledge-store/tests/integration_tests/test_graph_store.py (30 additions, 45 deletions)
@@ -1,6 +1,6 @@
import math
import secrets
from typing import Callable, Iterable, Iterator, List
from typing import Iterable, Iterator, List

import numpy as np
import pytest
@@ -89,26 +89,25 @@ def cassandra() -> Iterator[LocalCassandraTestStore]:


@pytest.fixture()
def graph_store_factory(
def graph_store(
cassandra: LocalCassandraTestStore,
) -> Iterator[Callable[[], GraphStore]]:
) -> Iterator[GraphStore]:
session = cassandra.create_cassandra_session()
session.set_keyspace(KEYSPACE)

embedding = SimpleEmbeddingModel()

def _make_graph_store() -> GraphStore:
name = secrets.token_hex(8)
name = secrets.token_hex(8)

node_table = f"nodes_{name}"
return GraphStore(
embedding,
session=session,
keyspace=KEYSPACE,
node_table=node_table,
)
node_table = f"nodes_{name}"
store = GraphStore(
embedding,
session=session,
keyspace=KEYSPACE,
node_table=node_table,
)

yield _make_graph_store
yield store

session.shutdown()

@@ -117,15 +116,7 @@ def _result_ids(nodes: Iterable[Node]) -> List[str]:
return [n.id for n in nodes if n.id is not None]


def test_graph_store_creation(graph_store_factory: Callable[[], GraphStore]) -> None:
"""Test that a graph store can be created.
This verifies the schema can be applied and the queries prepared.
"""
graph_store_factory()


def test_mmr_traversal(graph_store_factory: Callable[[], GraphStore]) -> None:
def test_mmr_traversal(graph_store: GraphStore) -> None:
"""
Test end to end construction and MMR search.
The embedding function used here ensures `texts` become
@@ -145,8 +136,6 @@ def test_mmr_traversal(graph_store_factory: Callable[[], GraphStore]) -> None:
Both v2 and v3 are reachable via edges from v0, so once it is
selected, those are both considered.
"""
gs = graph_store_factory()

v0 = Node(
id="v0",
text="-0.124",
@@ -166,32 +155,30 @@ def test_mmr_traversal(graph_store_factory: Callable[[], GraphStore]) -> None:
text="+1.0",
links={Link(direction="in", kind="explicit", tag="link")},
)
gs.add_nodes([v0, v1, v2, v3])
graph_store.add_nodes([v0, v1, v2, v3])

results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2)
results = graph_store.mmr_traversal_search("0.0", k=2, fetch_k=2)
assert _result_ids(results) == ["v0", "v2"]

# With max depth 0, no edges are traversed, so this doesn't reach v2 or v3.
# So it ends up picking "v1" even though it's similar to "v0".
results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
results = graph_store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
assert _result_ids(results) == ["v0", "v1"]

# With max depth 0 but higher `fetch_k`, we encounter v2
results = gs.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
results = graph_store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
assert _result_ids(results) == ["v0", "v2"]

# v0 score is .46, v2 score is 0.16 so it won't be chosen.
results = gs.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
results = graph_store.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
assert _result_ids(results) == ["v0"]

# with k=4 we should get all of the documents.
results = gs.mmr_traversal_search("0.0", k=4)
results = graph_store.mmr_traversal_search("0.0", k=4)
assert _result_ids(results) == ["v0", "v2", "v1", "v3"]


def test_write_retrieve_keywords(graph_store_factory: Callable[[], GraphStore]) -> None:
gs = graph_store_factory()

def test_write_retrieve_keywords(graph_store: GraphStore) -> None:
greetings = Node(
id="greetings",
text="Typical Greetings",
@@ -218,36 +205,34 @@ def test_write_retrieve_keywords(graph_store_factory: Callable[[], GraphStore])
},
)

gs.add_nodes([greetings, doc1, doc2])
graph_store.add_nodes([greetings, doc1, doc2])

# Doc2 is more similar, but World and Earth are similar enough that doc1 also shows
# up.
results = gs.similarity_search(text_to_embedding("Earth"), k=2)
results = graph_store.similarity_search(text_to_embedding("Earth"), k=2)
assert _result_ids(results) == ["doc2", "doc1"]

results = gs.similarity_search(text_to_embedding("Earth"), k=1)
results = graph_store.similarity_search(text_to_embedding("Earth"), k=1)
assert _result_ids(results) == ["doc2"]

results = gs.traversal_search("Earth", k=2, depth=0)
results = graph_store.traversal_search("Earth", k=2, depth=0)
assert _result_ids(results) == ["doc2", "doc1"]

results = gs.traversal_search("Earth", k=2, depth=1)
results = graph_store.traversal_search("Earth", k=2, depth=1)
assert _result_ids(results) == ["doc2", "doc1", "greetings"]

# K=1 only pulls in doc2 (Hello Earth)
results = gs.traversal_search("Earth", k=1, depth=0)
results = graph_store.traversal_search("Earth", k=1, depth=0)
assert _result_ids(results) == ["doc2"]

# K=1 only pulls in doc2 (Hello Earth). Depth=1 traverses to parent and via keyword
# edge.
results = gs.traversal_search("Earth", k=1, depth=1)
results = graph_store.traversal_search("Earth", k=1, depth=1)
assert set(_result_ids(results)) == {"doc2", "doc1", "greetings"}


def test_metadata(graph_store_factory: Callable[[], GraphStore]) -> None:
gs = graph_store_factory()

gs.add_nodes(
def test_metadata(graph_store: GraphStore) -> None:
graph_store.add_nodes(
[
Node(
id="a",
@@ -260,7 +245,7 @@ def test_metadata(graph_store_factory: Callable[[], GraphStore]) -> None:
)
]
)
results = list(gs.similarity_search(text_to_embedding("A")))
results = list(graph_store.similarity_search(text_to_embedding("A")))
assert len(results) == 1
assert results[0].id == "a"
assert results[0].metadata["other"] == "some other field"
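The knowledge-store file applies the same idea at a different level: graph_store_factory becomes a plain graph_store fixture, so every test receives its own GraphStore bound to a randomly named table while the Cassandra test store it builds on comes from a separate fixture. A minimal sketch of layering a cheap per-test fixture on top of a longer-lived one, with purely illustrative names (backend, store) rather than the repository's:

import secrets

import pytest


@pytest.fixture(scope="session")
def backend() -> str:
    # Stands in for an expensive shared resource created once per run.
    return "local-cassandra"


@pytest.fixture()
def store(backend: str) -> dict:
    # Cheap per-test object; a fresh, randomly named table keeps tests
    # isolated even though they share the same backend.
    return {"backend": backend, "table": f"nodes_{secrets.token_hex(8)}"}


def test_each_test_gets_its_own_table(store: dict) -> None:
    assert store["table"].startswith("nodes_")

The removal of test_graph_store_creation fits this pattern: since the fixture now builds a GraphStore for every test, schema creation and query preparation are exercised on every run anyway.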
libs/langchain/pyproject.toml (2 additions, 3 deletions)
@@ -42,6 +42,5 @@ pytest-asyncio = "^0.23.6"
keybert = "^0.8.5"
gliner = "^0.2.5"

[tool.poetry.group.dev.dependencies]
setuptools = "^70.0.0"

[tool.pytest.ini_options]
asyncio_mode = "auto"
(The remaining 4 changed files are not rendered in this view.)