Use session scoped fixtures instead of static variables #599

Merged · 1 commit · Jul 22, 2024
10 changes: 7 additions & 3 deletions libs/colbert/pyproject.toml
@@ -17,10 +17,14 @@ torch = "2.2.1"
cassio = "~0.1.7"
pydantic = "^2.7.1"

# Workaround for https://github.com/pytorch/pytorch/pull/127921
# Remove when we upgrade to pytorch 2.4
setuptools = { version = ">=70", python = ">=3.12" }


[tool.poetry.group.test.dependencies]
ragstack-ai-tests-utils = { path = "../tests-utils", develop = true }
pytest-asyncio = "^0.23.6"

[tool.poetry.group.dev.dependencies]
setuptools = "70.0.0"

[tool.pytest.ini_options]
asyncio_mode = "auto"
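The added setuptools entry is conditional: Poetry only applies it on Python 3.12 and newer, and the comment marks it as temporary until the project moves to pytorch 2.4. As a minimal sketch of how that kind of environment marker evaluates at install time (a hypothetical check using the packaging library, not part of this PR):

# Hypothetical illustration (not part of this PR): evaluating the kind of
# environment marker that a constraint like `python = ">=3.12"` turns into.
from packaging.markers import Marker

marker = Marker('python_version >= "3.12"')
# True on a 3.12+ interpreter (the setuptools>=70 pin applies),
# False on older interpreters (the pin is skipped entirely).
print(marker.evaluate())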
36 changes: 16 additions & 20 deletions libs/colbert/tests/integration_tests/conftest.py
@@ -1,28 +1,24 @@
import pytest
from cassandra.cluster import Session
from ragstack_tests_utils import AstraDBTestStore, LocalCassandraTestStore

status = {
"local_cassandra_test_store": None,
"astradb_test_store": None,
}

@pytest.fixture(scope="session")
def cassandra() -> LocalCassandraTestStore:
store = LocalCassandraTestStore()
yield store
if store.docker_container:
store.docker_container.stop()

def get_local_cassandra_test_store():
if not status["local_cassandra_test_store"]:
status["local_cassandra_test_store"] = LocalCassandraTestStore()
return status["local_cassandra_test_store"]

@pytest.fixture(scope="session")
def astra_db() -> AstraDBTestStore:
return AstraDBTestStore()

def get_astradb_test_store():
if not status["astradb_test_store"]:
status["astradb_test_store"] = AstraDBTestStore()
return status["astradb_test_store"]


@pytest.hookimpl()
def pytest_sessionfinish():
if (
status["local_cassandra_test_store"]
and status["local_cassandra_test_store"].docker_container
):
status["local_cassandra_test_store"].docker_container.stop()
@pytest.fixture()
def session(request) -> Session:
test_store = request.getfixturevalue(request.param)
session = test_store.create_cassandra_session()
session.default_timeout = 180
return session
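The replacement pattern is pytest indirect parametrization: each parametrized value is the name of a session-scoped store fixture, request.param carries that name into the session fixture, and request.getfixturevalue resolves it lazily, so each store is created once per test session and torn down by its own fixture finalizer instead of module-level state plus a pytest_sessionfinish hook. A minimal, self-contained sketch of the same mechanism with placeholder fixtures (illustrative only, no Cassandra involved):

import pytest


@pytest.fixture(scope="session")
def fast_backend() -> str:
    return "fast"


@pytest.fixture(scope="session")
def slow_backend() -> str:
    return "slow"


@pytest.fixture()
def backend(request) -> str:
    # request.param holds the fixture name passed via indirect parametrization;
    # getfixturevalue resolves it on demand and caches it for the session.
    return request.getfixturevalue(request.param)


@pytest.mark.parametrize("backend", ["fast_backend", "slow_backend"], indirect=["backend"])
def test_backend(backend: str) -> None:
    assert backend in {"fast", "slow"}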
35 changes: 5 additions & 30 deletions libs/colbert/tests/integration_tests/test_database.py
@@ -1,27 +1,11 @@
import pytest
from cassandra.cluster import Session
from ragstack_colbert import CassandraDatabase, Chunk
from ragstack_tests_utils import TestData

from tests.integration_tests.conftest import (
get_astradb_test_store,
get_local_cassandra_test_store,
)


@pytest.fixture()
def cassandra():
return get_local_cassandra_test_store()


@pytest.fixture()
def astra_db():
return get_astradb_test_store()


@pytest.mark.parametrize("vector_store", ["cassandra", "astra_db"])
def test_database_sync(request, vector_store: str):
vector_store = request.getfixturevalue(vector_store)

@pytest.mark.parametrize("session", ["cassandra", "astra_db"], indirect=["session"])
def test_database_sync(session: Session):
doc_id = "earth_doc_id"

chunk_0 = Chunk(
@@ -40,9 +24,6 @@ def test_database_sync(request, vector_store: str):
embedding=TestData.renewable_energy_embedding(),
)

session = vector_store.create_cassandra_session()
session.default_timeout = 180

database = CassandraDatabase.from_session(
keyspace="default_keyspace",
table_name="test_database_sync",
@@ -61,11 +42,8 @@
assert result


@pytest.mark.parametrize("vector_store", ["cassandra", "astra_db"])
@pytest.mark.asyncio()
async def test_database_async(request, vector_store: str):
vector_store = request.getfixturevalue(vector_store)

@pytest.mark.parametrize("session", ["cassandra", "astra_db"], indirect=["session"])
async def test_database_async(session: Session):
doc_id = "earth_doc_id"

chunk_0 = Chunk(
@@ -84,9 +62,6 @@ async def test_database_async(request, vector_store: str):
embedding=TestData.renewable_energy_embedding(),
)

session = vector_store.create_cassandra_session()
session.default_timeout = 180

database = CassandraDatabase.from_session(
keyspace="default_keyspace",
table_name="test_database_async",
24 changes: 3 additions & 21 deletions libs/colbert/tests/integration_tests/test_embedding_retrieval.py
@@ -1,32 +1,17 @@
import logging

import pytest
from cassandra.cluster import Session
from ragstack_colbert import (
CassandraDatabase,
ColbertEmbeddingModel,
ColbertVectorStore,
)
from ragstack_tests_utils import TestData

from tests.integration_tests.conftest import (
get_astradb_test_store,
get_local_cassandra_test_store,
)


@pytest.fixture()
def cassandra():
return get_local_cassandra_test_store()


@pytest.fixture()
def astra_db():
return get_astradb_test_store()


@pytest.mark.parametrize("vector_store", ["cassandra", "astra_db"])
def test_embedding_cassandra_retriever(request, vector_store: str):
vector_store = request.getfixturevalue(vector_store)
@pytest.mark.parametrize("session", ["cassandra", "astra_db"], indirect=["session"])
def test_embedding_cassandra_retriever(session: Session):
narrative = TestData.marine_animals_text()

# Define the desired chunk size and overlap size
@@ -53,9 +38,6 @@ def chunk_texts(text, chunk_size, overlap_size):

doc_id = "marine_animals"

session = vector_store.create_cassandra_session()
session.default_timeout = 180

database = CassandraDatabase.from_session(
keyspace="default_keyspace",
table_name="test_embedding_cassandra_retriever",
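The chunk_texts(text, chunk_size, overlap_size) helper referenced above sits outside the visible hunks; a sliding-window splitter along these lines is what the surrounding comments describe, but treat this as a hypothetical sketch rather than the helper actually used in the test:

from typing import List


def chunk_texts(text: str, chunk_size: int, overlap_size: int) -> List[str]:
    # Hypothetical sketch: fixed-size windows that step forward by
    # chunk_size - overlap_size, so adjacent chunks share overlap_size characters.
    step = chunk_size - overlap_size
    return [text[i : i + chunk_size] for i in range(0, len(text), step)]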
1 change: 0 additions & 1 deletion libs/knowledge-store/pyproject.toml
@@ -20,7 +20,6 @@ mypy = "^1.10.0"
pytest-asyncio = "^0.23.6"
ipykernel = "^6.29.4"
testcontainers = "~3.7.1"
setuptools = "^70.0.0"
python-dotenv = "^1.0.1"

# Resolve numpy version for 3.8 to 3.12+
75 changes: 30 additions & 45 deletions libs/knowledge-store/tests/integration_tests/test_graph_store.py
@@ -1,6 +1,6 @@
import math
import secrets
from typing import Callable, Iterable, Iterator, List
from typing import Iterable, Iterator, List

import numpy as np
import pytest
@@ -89,26 +89,25 @@ def cassandra() -> Iterator[LocalCassandraTestStore]:


@pytest.fixture()
def graph_store_factory(
def graph_store(
cassandra: LocalCassandraTestStore,
) -> Iterator[Callable[[], GraphStore]]:
) -> Iterator[GraphStore]:
session = cassandra.create_cassandra_session()
session.set_keyspace(KEYSPACE)

embedding = SimpleEmbeddingModel()

def _make_graph_store() -> GraphStore:
name = secrets.token_hex(8)
name = secrets.token_hex(8)

node_table = f"nodes_{name}"
return GraphStore(
embedding,
session=session,
keyspace=KEYSPACE,
node_table=node_table,
)
node_table = f"nodes_{name}"
store = GraphStore(
embedding,
session=session,
keyspace=KEYSPACE,
node_table=node_table,
)

yield _make_graph_store
yield store

session.shutdown()

@@ -117,15 +116,7 @@ def _result_ids(nodes: Iterable[Node]) -> List[str]:
return [n.id for n in nodes if n.id is not None]


def test_graph_store_creation(graph_store_factory: Callable[[], GraphStore]) -> None:
"""Test that a graph store can be created.

This verifies the schema can be applied and the queries prepared.
"""
graph_store_factory()


def test_mmr_traversal(graph_store_factory: Callable[[], GraphStore]) -> None:
def test_mmr_traversal(graph_store: GraphStore) -> None:
"""
Test end to end construction and MMR search.
The embedding function used here ensures `texts` become
@@ -145,8 +136,6 @@ def test_mmr_traversal(graph_store_factory: Callable[[], GraphStore]) -> None:
Both v2 and v3 are reachable via edges from v0, so once it is
selected, those are both considered.
"""
gs = graph_store_factory()

v0 = Node(
id="v0",
text="-0.124",
@@ -166,32 +155,30 @@ def test_mmr_traversal(graph_store_factory: Callable[[], GraphStore]) -> None:
text="+1.0",
links={Link(direction="in", kind="explicit", tag="link")},
)
gs.add_nodes([v0, v1, v2, v3])
graph_store.add_nodes([v0, v1, v2, v3])

results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2)
results = graph_store.mmr_traversal_search("0.0", k=2, fetch_k=2)
assert _result_ids(results) == ["v0", "v2"]

# With max depth 0, no edges are traversed, so this doesn't reach v2 or v3.
# So it ends up picking "v1" even though it's similar to "v0".
results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
results = graph_store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
assert _result_ids(results) == ["v0", "v1"]

# With max depth 0 but higher `fetch_k`, we encounter v2
results = gs.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
results = graph_store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
assert _result_ids(results) == ["v0", "v2"]

# v0 score is .46, v2 score is 0.16 so it won't be chosen.
results = gs.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
results = graph_store.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
assert _result_ids(results) == ["v0"]

# with k=4 we should get all of the documents.
results = gs.mmr_traversal_search("0.0", k=4)
results = graph_store.mmr_traversal_search("0.0", k=4)
assert _result_ids(results) == ["v0", "v2", "v1", "v3"]


def test_write_retrieve_keywords(graph_store_factory: Callable[[], GraphStore]) -> None:
gs = graph_store_factory()

def test_write_retrieve_keywords(graph_store: GraphStore) -> None:
greetings = Node(
id="greetings",
text="Typical Greetings",
@@ -218,36 +205,34 @@ def test_write_retrieve_keywords(graph_store_factory: Callable[[], GraphStore])
},
)

gs.add_nodes([greetings, doc1, doc2])
graph_store.add_nodes([greetings, doc1, doc2])

# Doc2 is more similar, but World and Earth are similar enough that doc1 also shows
# up.
results = gs.similarity_search(text_to_embedding("Earth"), k=2)
results = graph_store.similarity_search(text_to_embedding("Earth"), k=2)
assert _result_ids(results) == ["doc2", "doc1"]

results = gs.similarity_search(text_to_embedding("Earth"), k=1)
results = graph_store.similarity_search(text_to_embedding("Earth"), k=1)
assert _result_ids(results) == ["doc2"]

results = gs.traversal_search("Earth", k=2, depth=0)
results = graph_store.traversal_search("Earth", k=2, depth=0)
assert _result_ids(results) == ["doc2", "doc1"]

results = gs.traversal_search("Earth", k=2, depth=1)
results = graph_store.traversal_search("Earth", k=2, depth=1)
assert _result_ids(results) == ["doc2", "doc1", "greetings"]

# K=1 only pulls in doc2 (Hello Earth)
results = gs.traversal_search("Earth", k=1, depth=0)
results = graph_store.traversal_search("Earth", k=1, depth=0)
assert _result_ids(results) == ["doc2"]

# K=1 only pulls in doc2 (Hello Earth). Depth=1 traverses to parent and via keyword
# edge.
results = gs.traversal_search("Earth", k=1, depth=1)
results = graph_store.traversal_search("Earth", k=1, depth=1)
assert set(_result_ids(results)) == {"doc2", "doc1", "greetings"}


def test_metadata(graph_store_factory: Callable[[], GraphStore]) -> None:
gs = graph_store_factory()

gs.add_nodes(
def test_metadata(graph_store: GraphStore) -> None:
graph_store.add_nodes(
[
Node(
id="a",
@@ -260,7 +245,7 @@ def test_metadata(graph_store_factory: Callable[[], GraphStore]) -> None:
)
]
)
results = list(gs.similarity_search(text_to_embedding("A")))
results = list(graph_store.similarity_search(text_to_embedding("A")))
assert len(results) == 1
assert results[0].id == "a"
assert results[0].metadata["other"] == "some other field"
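Read together, the hunks above collapse the graph_store_factory callable into a plain per-test fixture. Reassembled from the visible hunks (so a best-effort sketch of the merged file, with the module-level imports of GraphStore, LocalCassandraTestStore, SimpleEmbeddingModel and KEYSPACE left as in the original), the new fixture looks roughly like this:

@pytest.fixture()
def graph_store(
    cassandra: LocalCassandraTestStore,
) -> Iterator[GraphStore]:
    session = cassandra.create_cassandra_session()
    session.set_keyspace(KEYSPACE)

    embedding = SimpleEmbeddingModel()

    # One uniquely named node table per test, so repeated runs never
    # collide on schema.
    name = secrets.token_hex(8)
    node_table = f"nodes_{name}"
    store = GraphStore(
        embedding,
        session=session,
        keyspace=KEYSPACE,
        node_table=node_table,
    )

    yield store

    session.shutdown()

Each test then takes graph_store directly instead of calling a factory; the stand-alone test_graph_store_creation smoke test disappears in the same hunk, presumably because every test now exercises store creation through the fixture.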
5 changes: 2 additions & 3 deletions libs/langchain/pyproject.toml
@@ -42,6 +42,5 @@ pytest-asyncio = "^0.23.6"
keybert = "^0.8.5"
gliner = "^0.2.5"

[tool.poetry.group.dev.dependencies]
setuptools = "^70.0.0"

[tool.pytest.ini_options]
asyncio_mode = "auto"