diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index 5632bfd7a..4cd3ca6ab 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -13,10 +13,11 @@ Sequence, Set, Tuple, + Union, cast, ) -from cassandra.cluster import ConsistencyLevel, Session +from cassandra.cluster import ConsistencyLevel, PreparedStatement, Session from cassio.config import check_resolve_keyspace, check_resolve_session from ._mmr_helper import MmrHelper @@ -29,6 +30,12 @@ CONTENT_ID = "content_id" +CONTENT_COLUMNS = "content_id, kind, text_content, links_blob, metadata_blob" + +SELECT_CQL_TEMPLATE = ( + "SELECT {columns} FROM {table_name}{where_clause}{order_clause}{limit_clause};" +) + @dataclass class Node: @@ -52,6 +59,26 @@ class SetupMode(Enum): OFF = 3 +class MetadataIndexingMode(Enum): + """Mode used to index metadata.""" + + DEFAULT_TO_UNSEARCHABLE = 1 + DEFAULT_TO_SEARCHABLE = 2 + + +MetadataIndexingType = Union[Tuple[str, Iterable[str]], str] +MetadataIndexingPolicy = Tuple[MetadataIndexingMode, Set[str]] + + +def _is_metadata_field_indexed(field_name: str, policy: MetadataIndexingPolicy) -> bool: + p_mode, p_fields = policy + if p_mode == MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE: + return field_name in p_fields + if p_mode == MetadataIndexingMode.DEFAULT_TO_SEARCHABLE: + return field_name not in p_fields + raise ValueError(f"Unexpected metadata indexing mode {p_mode}") + + def _serialize_metadata(md: Dict[str, Any]) -> str: if isinstance(md.get("links"), Set): md = md.copy() @@ -132,6 +159,7 @@ def __init__( session: Optional[Session] = None, keyspace: Optional[str] = None, setup_mode: SetupMode = SetupMode.SYNC, + metadata_indexing: MetadataIndexingType = "all", ): if targets_table: logger.warning( @@ -152,6 +180,11 @@ def __init__( self._node_table = node_table self._session = session self._keyspace = keyspace + self._prepared_query_cache: Dict[str, PreparedStatement] = {} + + self._metadata_indexing_policy = self._normalize_metadata_indexing_policy( + metadata_indexing=metadata_indexing, + ) if setup_mode == SetupMode.SYNC: self._apply_schema() @@ -166,41 +199,19 @@ def __init__( f""" INSERT INTO {keyspace}.{node_table} ( content_id, kind, text_content, text_embedding, link_to_tags, - link_from_tags, metadata_blob, links_blob - ) VALUES (?, '{Kind.passage}', ?, ?, ?, ?, ?, ?) + link_from_tags, links_blob, metadata_blob, metadata_s + ) VALUES (?, '{Kind.passage}', ?, ?, ?, ?, ?, ?, ?) """ # noqa: S608 ) self._query_by_id = session.prepare( f""" - SELECT content_id, kind, text_content, metadata_blob, links_blob + SELECT {CONTENT_COLUMNS} FROM {keyspace}.{node_table} WHERE content_id = ? """ # noqa: S608 ) - self._query_by_embedding = session.prepare( - f""" - SELECT content_id, kind, text_content, metadata_blob, links_blob - FROM {keyspace}.{node_table} - ORDER BY text_embedding ANN OF ? - LIMIT ? - """ # noqa: S608 - ) - self._query_by_embedding.consistency_level = ConsistencyLevel.ONE - - self._query_ids_and_link_to_tags_by_embedding = session.prepare( - f""" - SELECT content_id, link_to_tags - FROM {keyspace}.{node_table} - ORDER BY text_embedding ANN OF ? - LIMIT ? - """ # noqa: S608 - ) - self._query_ids_and_link_to_tags_by_embedding.consistency_level = ( - ConsistencyLevel.ONE - ) - self._query_ids_and_link_to_tags_by_id = session.prepare( f""" SELECT content_id, link_to_tags @@ -209,18 +220,6 @@ def __init__( """ # noqa: S608 ) - self._query_ids_and_embedding_by_embedding = session.prepare( - f""" - SELECT content_id, text_embedding, link_to_tags - FROM {keyspace}.{node_table} - ORDER BY text_embedding ANN OF ? - LIMIT ? - """ # noqa: S608 - ) - self._query_ids_and_embedding_by_embedding.consistency_level = ( - ConsistencyLevel.ONE - ) - self._query_source_tags_by_id = session.prepare( f""" SELECT link_to_tags @@ -229,33 +228,15 @@ def __init__( """ # noqa: S608 ) - self._query_targets_embeddings_by_kind_and_tag_and_embedding = session.prepare( - f""" - SELECT - content_id AS target_content_id, - text_embedding AS target_text_embedding, - link_to_tags AS target_link_to_tags - FROM {keyspace}.{node_table} - WHERE link_from_tags CONTAINS (?, ?) - ORDER BY text_embedding ANN of ? - LIMIT ? - """ - ) - - self._query_targets_by_kind_and_value = session.prepare( - f""" - SELECT - content_id AS target_content_id - FROM {keyspace}.{node_table} - WHERE link_from_tags CONTAINS (?, ?) - """ - ) + def table_name(self) -> str: + """Returns the fully qualified table name.""" + return f"{self._keyspace}.{self._node_table}" def _apply_schema(self) -> None: """Apply the schema to the database.""" embedding_dim = len(self._embedding.embed_query("Test Query")) self._session.execute(f""" - CREATE TABLE IF NOT EXISTS {self._keyspace}.{self._node_table} ( + CREATE TABLE IF NOT EXISTS {self.table_name()} ( content_id TEXT, kind TEXT, text_content TEXT, @@ -263,8 +244,9 @@ def _apply_schema(self) -> None: link_to_tags SET>, link_from_tags SET>, - metadata_blob TEXT, links_blob TEXT, + metadata_blob TEXT, + metadata_s MAP, PRIMARY KEY (content_id) ) @@ -273,13 +255,19 @@ def _apply_schema(self) -> None: # Index on text_embedding (for similarity search) self._session.execute(f""" CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_text_embedding_index - ON {self._keyspace}.{self._node_table}(text_embedding) + ON {self.table_name()}(text_embedding) USING 'StorageAttachedIndex'; """) self._session.execute(f""" CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_link_from_tags - ON {self._keyspace}.{self._node_table}(link_from_tags) + ON {self.table_name()}(link_from_tags) + USING 'StorageAttachedIndex'; + """) + + self._session.execute(f""" + CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_metadata_s_index + ON {self.table_name()}(ENTRIES(metadata_s)) USING 'StorageAttachedIndex'; """) @@ -321,6 +309,12 @@ def add_nodes( if tag.direction in {"out", "bidir"}: link_to_tags.add((tag.kind, tag.tag)) + metadata_s = { + k: self._coerce_string(v) + for k, v in metadata.items() + if _is_metadata_field_indexed(k, self._metadata_indexing_policy) + } + metadata_blob = _serialize_metadata(metadata) links_blob = _serialize_links(links) cq.execute( @@ -331,8 +325,9 @@ def add_nodes( text_embedding, link_to_tags, link_from_tags, - metadata_blob, links_blob, + metadata_blob, + metadata_s, ), ) @@ -345,7 +340,7 @@ def _nodes_with_ids( results: Dict[str, Optional[Node]] = {} with self._concurrent_queries() as cq: - def add_nodes(rows: Iterable[Any]) -> None: + def node_callback(rows: Iterable[Any]) -> None: # Should always be exactly one row here. We don't need to check # 1. The query is for a `ID == ?` query on the primary key. # 2. If it doesn't exist, the `get_result` method below will @@ -358,7 +353,7 @@ def add_nodes(rows: Iterable[Any]) -> None: # Mark this node ID as being fetched. results[node_id] = None cq.execute( - self._query_by_id, parameters=(node_id,), callback=add_nodes + self._query_by_id, parameters=(node_id,), callback=node_callback ) def get_result(node_id: str) -> Node: @@ -378,6 +373,7 @@ def mmr_traversal_search( adjacent_k: int = 10, lambda_mult: float = 0.5, score_threshold: float = float("-inf"), + metadata_filter: Dict[str, Any] = {}, # noqa: B006 ) -> Iterable[Node]: """Retrieve documents from this graph store using MMR-traversal. @@ -403,6 +399,7 @@ def mmr_traversal_search( diversity and 1 to minimum diversity. Defaults to 0.5. score_threshold: Only documents with a score greater than or equal this threshold will be chosen. Defaults to -infinity. + metadata_filter: Optional metadata to filter the results. """ query_embedding = self._embedding.embed_query(query) helper = MmrHelper( @@ -417,10 +414,22 @@ def mmr_traversal_search( # Fetch the initial candidates and add them to the helper and # outgoing_tags. + initial_candidates_query = self._get_search_cql( + has_limit=True, + columns="content_id, text_embedding, link_to_tags", + metadata_keys=list(metadata_filter.keys()), + has_embedding=True, + ) + def fetch_initial_candidates() -> None: + params = self._get_search_params( + limit=fetch_k, + metadata=metadata_filter, + embedding=query_embedding, + ) + fetched = self._session.execute( - self._query_ids_and_embedding_by_embedding, - (query_embedding, fetch_k), + query=initial_candidates_query, parameters=params ) candidates = {} for row in fetched: @@ -460,6 +469,7 @@ def fetch_initial_candidates() -> None: link_to_tags, query_embedding=query_embedding, k_per_tag=adjacent_k, + metadata_filter=metadata_filter, ) # Record the link_to_tags as visited. @@ -493,7 +503,12 @@ def fetch_initial_candidates() -> None: return self._nodes_with_ids(helper.selected_ids) def traversal_search( - self, query: str, *, k: int = 4, depth: int = 1 + self, + query: str, + *, + k: int = 4, + depth: int = 1, + metadata_filter: Dict[str, Any] = {}, # noqa: B006 ) -> Iterable[Node]: """Retrieve documents from this knowledge store. @@ -506,6 +521,7 @@ def traversal_search( k: The number of Documents to return from the initial vector search. Defaults to 4. depth: The maximum depth of edges to traverse. Defaults to 1. + metadata_filter: Optional metadata to filter the results. Returns: Collection of retrieved documents. @@ -521,6 +537,19 @@ def traversal_search( # # ... + traversal_query = self._get_search_cql( + columns="content_id, link_to_tags", + has_limit=True, + metadata_keys=list(metadata_filter.keys()), + has_embedding=True, + ) + + visit_nodes_query = self._get_search_cql( + columns="content_id AS target_content_id", + has_link_from_tags=True, + metadata_keys=list(metadata_filter.keys()), + ) + with self._concurrent_queries() as cq: # Map from visited ID to depth visited_ids: Dict[str, int] = {} @@ -563,12 +592,12 @@ def visit_nodes(d: int, nodes: Sequence[Any]) -> None: # If there are new tags to visit at the next depth, query for the # node IDs. for kind, value in outgoing_tags: + params = self._get_search_params( + link_from_tags=(kind, value), metadata=metadata_filter + ) cq.execute( - self._query_targets_by_kind_and_value, - parameters=( - kind, - value, - ), + query=visit_nodes_query, + parameters=params, callback=lambda rows, d=d: visit_targets(d, rows), ) @@ -591,9 +620,15 @@ def visit_targets(d: int, targets: Sequence[Any]) -> None: ) query_embedding = self._embedding.embed_query(query) + params = self._get_search_params( + limit=k, + metadata=metadata_filter, + embedding=query_embedding, + ) + cq.execute( - self._query_ids_and_link_to_tags_by_embedding, - parameters=(query_embedding, k), + traversal_query, + parameters=params, callback=lambda nodes: visit_nodes(0, nodes), ) @@ -603,11 +638,31 @@ def similarity_search( self, embedding: List[float], k: int = 4, + metadata_filter: Dict[str, Any] = {}, # noqa: B006 + ) -> Iterable[Node]: + """Retrieve nodes similar to the given embedding, optionally filtered by metadata.""" # noqa: E501 + query, params = self._get_search_cql_and_params( + embedding=embedding, limit=k, metadata=metadata_filter + ) + + for row in self._session.execute(query, params): + yield _row_to_node(row) + + def metadata_search( + self, + metadata: Dict[str, Any] = {}, # noqa: B006 + n: int = 5, ) -> Iterable[Node]: - """Retrieve nodes similar to the given embedding.""" - for row in self._session.execute(self._query_by_embedding, (embedding, k)): + """Retrieve nodes based on their metadata.""" + query, params = self._get_search_cql_and_params(metadata=metadata, limit=n) + + for row in self._session.execute(query, params): yield _row_to_node(row) + def get_node(self, content_id: str) -> Node: + """Get a node by its id.""" + return self._nodes_with_ids(ids=[content_id])[0] + def _get_outgoing_tags( self, source_ids: Iterable[str], @@ -636,6 +691,7 @@ def _get_adjacent( tags: Set[Tuple[str, str]], query_embedding: List[float], k_per_tag: Optional[int] = None, + metadata_filter: Dict[str, Any] = {}, # noqa: B006 ) -> Iterable[_Edge]: """Return the target nodes with incoming links from any of the given tags. @@ -643,12 +699,27 @@ def _get_adjacent( tags: The tags to look for links *from*. query_embedding: The query embedding. Used to rank target nodes. k_per_tag: The number of target nodes to fetch for each outgoing tag. + metadata_filter: Optional metadata to filter the results. Returns: List of adjacent edges. """ targets: Dict[str, _Edge] = {} + columns = """ + content_id AS target_content_id, + text_embedding AS target_text_embedding, + link_to_tags AS target_link_to_tags + """ + + adjacent_query = self._get_search_cql( + has_limit=True, + columns=columns, + metadata_keys=list(metadata_filter.keys()), + has_embedding=True, + has_link_from_tags=True, + ) + def add_targets(rows: Iterable[Any]) -> None: # TODO: Figure out how to use the "kind" on the edge. # This is tricky, since we currently issue one query for anything @@ -664,17 +735,193 @@ def add_targets(rows: Iterable[Any]) -> None: with self._concurrent_queries() as cq: for kind, value in tags: + params = self._get_search_params( + limit=k_per_tag or 10, + metadata=metadata_filter, + embedding=query_embedding, + link_from_tags=(kind, value), + ) + cq.execute( - self._query_targets_embeddings_by_kind_and_tag_and_embedding, - parameters=( - kind, - value, - query_embedding, - k_per_tag or 10, - ), + query=adjacent_query, + parameters=params, callback=add_targets, ) # TODO: Consider a combined limit based on the similarity and/or # predicated MMR score? return targets.values() + + @staticmethod + def _normalize_metadata_indexing_policy( + metadata_indexing: Union[Tuple[str, Iterable[str]], str], + ) -> MetadataIndexingPolicy: + mode: MetadataIndexingMode + fields: Set[str] + # metadata indexing policy normalization: + if isinstance(metadata_indexing, str): + if metadata_indexing.lower() == "all": + mode, fields = (MetadataIndexingMode.DEFAULT_TO_SEARCHABLE, set()) + elif metadata_indexing.lower() == "none": + mode, fields = (MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE, set()) + else: + raise ValueError( + f"Unsupported metadata_indexing value '{metadata_indexing}'" + ) + else: + if len(metadata_indexing) != 2: # noqa: PLR2004 + raise ValueError( + f"Unsupported metadata_indexing value '{metadata_indexing}'." + ) + # it's a 2-tuple (mode, fields) still to normalize + _mode, _field_spec = metadata_indexing + fields = {_field_spec} if isinstance(_field_spec, str) else set(_field_spec) + if _mode.lower() in { + "default_to_unsearchable", + "allowlist", + "allow", + "allow_list", + }: + mode = MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE + elif _mode.lower() in { + "default_to_searchable", + "denylist", + "deny", + "deny_list", + }: + mode = MetadataIndexingMode.DEFAULT_TO_SEARCHABLE + else: + raise ValueError( + f"Unsupported metadata indexing mode specification '{_mode}'" + ) + return (mode, fields) + + @staticmethod + def _coerce_string(value: Any) -> str: + if isinstance(value, str): + return value + if isinstance(value, bool): + # bool MUST come before int in this chain of ifs! + return json.dumps(value) + if isinstance(value, int): + # we don't want to store '1' and '1.0' differently + # for the sake of metadata-filtered retrieval: + return json.dumps(float(value)) + if isinstance(value, float) or value is None: + return json.dumps(value) + # when all else fails ... + return str(value) + + def _extract_where_clause_cql( + self, + metadata_keys: List[str] = [], # noqa: B006 + has_link_from_tags: bool = False, + ) -> str: + wc_blocks: List[str] = [] + + if has_link_from_tags: + wc_blocks.append("link_from_tags CONTAINS (?, ?)") + + for key in sorted(metadata_keys): + if _is_metadata_field_indexed(key, self._metadata_indexing_policy): + wc_blocks.append(f"metadata_s['{key}'] = ?") + else: + raise ValueError( + "Non-indexed metadata fields cannot be used in queries." + ) + + if len(wc_blocks) == 0: + return "" + + return " WHERE " + " AND ".join(wc_blocks) + + def _extract_where_clause_params( + self, + metadata: Dict[str, Any], + link_from_tags: Optional[Tuple[str, str]] = None, + ) -> List[Any]: + params: List[Any] = [] + + if link_from_tags is not None: + params.append(link_from_tags[0]) + params.append(link_from_tags[1]) + + for key, value in sorted(metadata.items()): + if _is_metadata_field_indexed(key, self._metadata_indexing_policy): + params.append(self._coerce_string(value=value)) + else: + raise ValueError( + "Non-indexed metadata fields cannot be used in queries." + ) + + return params + + def _get_search_cql( + self, + has_limit: bool = False, + columns: Optional[str] = CONTENT_COLUMNS, + metadata_keys: List[str] = [], # noqa: B006 + has_embedding: bool = False, + has_link_from_tags: bool = False, + ) -> PreparedStatement: + where_clause = self._extract_where_clause_cql( + metadata_keys=metadata_keys, has_link_from_tags=has_link_from_tags + ) + limit_clause = " LIMIT ?" if has_limit else "" + order_clause = " ORDER BY text_embedding ANN OF ?" if has_embedding else "" + + select_cql = SELECT_CQL_TEMPLATE.format( + columns=columns, + table_name=self.table_name(), + where_clause=where_clause, + order_clause=order_clause, + limit_clause=limit_clause, + ) + + if select_cql in self._prepared_query_cache: + return self._prepared_query_cache[select_cql] + + prepared_query = self._session.prepare(select_cql) + prepared_query.consistency_level = ConsistencyLevel.ONE + self._prepared_query_cache[select_cql] = prepared_query + + return prepared_query + + def _get_search_params( + self, + limit: Optional[int] = None, + metadata: Dict[str, Any] = {}, # noqa: B006 + embedding: Optional[List[float]] = None, + link_from_tags: Optional[Tuple[str, str]] = None, + ) -> Tuple[PreparedStatement, Tuple[Any, ...]]: + where_params = self._extract_where_clause_params( + metadata=metadata, link_from_tags=link_from_tags + ) + + limit_params = [limit] if limit is not None else [] + order_params = [embedding] if embedding is not None else [] + + return tuple(list(where_params) + order_params + limit_params) + + def _get_search_cql_and_params( + self, + limit: Optional[int] = None, + columns: Optional[str] = CONTENT_COLUMNS, + metadata: Dict[str, Any] = {}, # noqa: B006 + embedding: Optional[List[float]] = None, + link_from_tags: Optional[Tuple[str, str]] = None, + ) -> Tuple[PreparedStatement, Tuple[Any, ...]]: + query = self._get_search_cql( + has_limit=limit is not None, + columns=columns, + metadata_keys=list(metadata.keys()), + has_embedding=embedding is not None, + has_link_from_tags=link_from_tags is not None, + ) + params = self._get_search_params( + limit=limit, + metadata=metadata, + embedding=embedding, + link_from_tags=link_from_tags, + ) + return query, params diff --git a/libs/knowledge-store/tests/integration_tests/test_graph_store.py b/libs/knowledge-store/tests/integration_tests/test_graph_store.py index 69b15c3b9..b69a51c39 100644 --- a/libs/knowledge-store/tests/integration_tests/test_graph_store.py +++ b/libs/knowledge-store/tests/integration_tests/test_graph_store.py @@ -1,12 +1,12 @@ import math import secrets -from typing import Iterable, Iterator, List +from typing import Callable, Iterable, Iterator, List import numpy as np import pytest from dotenv import load_dotenv from ragstack_knowledge_store import EmbeddingModel -from ragstack_knowledge_store.graph_store import GraphStore, Node +from ragstack_knowledge_store.graph_store import GraphStore, MetadataIndexingType, Node from ragstack_knowledge_store.links import Link from ragstack_tests_utils import LocalCassandraTestStore @@ -89,26 +89,29 @@ def cassandra() -> Iterator[LocalCassandraTestStore]: @pytest.fixture() -def graph_store( +def graph_store_factory( cassandra: LocalCassandraTestStore, -) -> Iterator[GraphStore]: +) -> Iterator[Callable[[], GraphStore]]: session = cassandra.create_cassandra_session() session.set_keyspace(KEYSPACE) embedding = SimpleEmbeddingModel() - name = secrets.token_hex(8) - - node_table = f"nodes_{name}" - store = GraphStore( - embedding, - session=session, - keyspace=KEYSPACE, - node_table=node_table, - ) - - yield store - + def _make_graph_store( + metadata_indexing: MetadataIndexingType = "all", + ) -> GraphStore: + name = secrets.token_hex(8) + + node_table = f"nodes_{name}" + return GraphStore( + embedding, + session=session, + keyspace=KEYSPACE, + node_table=node_table, + metadata_indexing=metadata_indexing, + ) + + yield _make_graph_store session.shutdown() @@ -116,7 +119,9 @@ def _result_ids(nodes: Iterable[Node]) -> List[str]: return [n.id for n in nodes if n.id is not None] -def test_mmr_traversal(graph_store: GraphStore) -> None: +def test_mmr_traversal( + graph_store_factory: Callable[[MetadataIndexingType], GraphStore], +) -> None: """ Test end to end construction and MMR search. The embedding function used here ensures `texts` become @@ -136,55 +141,69 @@ def test_mmr_traversal(graph_store: GraphStore) -> None: Both v2 and v3 are reachable via edges from v0, so once it is selected, those are both considered. """ + v0 = Node( id="v0", text="-0.124", links={Link(direction="out", kind="explicit", tag="link")}, + metadata={"even": True}, ) v1 = Node( id="v1", text="+0.127", + metadata={"even": False}, ) v2 = Node( id="v2", text="+0.25", links={Link(direction="in", kind="explicit", tag="link")}, + metadata={"even": True}, ) v3 = Node( id="v3", text="+1.0", links={Link(direction="in", kind="explicit", tag="link")}, + metadata={"even": False}, ) - graph_store.add_nodes([v0, v1, v2, v3]) - results = graph_store.mmr_traversal_search("0.0", k=2, fetch_k=2) + gs = graph_store_factory("all") + gs.add_nodes([v0, v1, v2, v3]) + + results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2) assert _result_ids(results) == ["v0", "v2"] # With max depth 0, no edges are traversed, so this doesn't reach v2 or v3. # So it ends up picking "v1" even though it's similar to "v0". - results = graph_store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0) + results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0) assert _result_ids(results) == ["v0", "v1"] # With max depth 0 but higher `fetch_k`, we encounter v2 - results = graph_store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0) + results = gs.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0) assert _result_ids(results) == ["v0", "v2"] # v0 score is .46, v2 score is 0.16 so it won't be chosen. - results = graph_store.mmr_traversal_search("0.0", k=2, score_threshold=0.2) + results = gs.mmr_traversal_search("0.0", k=2, score_threshold=0.2) assert _result_ids(results) == ["v0"] # with k=4 we should get all of the documents. - results = graph_store.mmr_traversal_search("0.0", k=4) + results = gs.mmr_traversal_search("0.0", k=4) assert _result_ids(results) == ["v0", "v2", "v1", "v3"] + # with metadata_filter even=True we should only get the `even` documents. + results = gs.mmr_traversal_search("0.0", k=4, metadata_filter={"even": True}) + assert _result_ids(results) == ["v0", "v2"] + -def test_write_retrieve_keywords(graph_store: GraphStore) -> None: +def test_write_retrieve_keywords( + graph_store_factory: Callable[[MetadataIndexingType], GraphStore], +) -> None: greetings = Node( id="greetings", text="Typical Greetings", links={ Link(direction="in", kind="parent", tag="parent"), }, + metadata={"Hello": False, "Greeting": "typical"}, ) doc1 = Node( id="doc1", @@ -194,6 +213,7 @@ def test_write_retrieve_keywords(graph_store: GraphStore) -> None: Link(direction="bidir", kind="kw", tag="greeting"), Link(direction="bidir", kind="kw", tag="world"), }, + metadata={"Hello": True, "Greeting": "world"}, ) doc2 = Node( id="doc2", @@ -203,36 +223,53 @@ def test_write_retrieve_keywords(graph_store: GraphStore) -> None: Link(direction="bidir", kind="kw", tag="greeting"), Link(direction="bidir", kind="kw", tag="earth"), }, + metadata={"Hello": True, "Greeting": "earth"}, ) - graph_store.add_nodes([greetings, doc1, doc2]) + gs = graph_store_factory("all") + gs.add_nodes([greetings, doc1, doc2]) # Doc2 is more similar, but World and Earth are similar enough that doc1 also shows # up. - results = graph_store.similarity_search(text_to_embedding("Earth"), k=2) + results = gs.similarity_search(text_to_embedding("Earth"), k=2) assert _result_ids(results) == ["doc2", "doc1"] - results = graph_store.similarity_search(text_to_embedding("Earth"), k=1) + results = gs.similarity_search(text_to_embedding("Earth"), k=1) assert _result_ids(results) == ["doc2"] - results = graph_store.traversal_search("Earth", k=2, depth=0) + # with metadata filter + results = gs.similarity_search( + text_to_embedding("Earth"), k=1, metadata_filter={"Greeting": "world"} + ) + assert _result_ids(results) == ["doc1"] + + results = gs.traversal_search("Earth", k=2, depth=0) assert _result_ids(results) == ["doc2", "doc1"] - results = graph_store.traversal_search("Earth", k=2, depth=1) + results = gs.traversal_search("Earth", k=2, depth=1) assert _result_ids(results) == ["doc2", "doc1", "greetings"] + # with metadata filter + results = gs.traversal_search( + "Earth", k=2, depth=1, metadata_filter={"Hello": True} + ) + assert _result_ids(results) == ["doc2", "doc1"] + # K=1 only pulls in doc2 (Hello Earth) - results = graph_store.traversal_search("Earth", k=1, depth=0) + results = gs.traversal_search("Earth", k=1, depth=0) assert _result_ids(results) == ["doc2"] # K=1 only pulls in doc2 (Hello Earth). Depth=1 traverses to parent and via keyword # edge. - results = graph_store.traversal_search("Earth", k=1, depth=1) + results = gs.traversal_search("Earth", k=1, depth=1) assert set(_result_ids(results)) == {"doc2", "doc1", "greetings"} -def test_metadata(graph_store: GraphStore) -> None: - graph_store.add_nodes( +def test_metadata( + graph_store_factory: Callable[[MetadataIndexingType], GraphStore], +) -> None: + gs = graph_store_factory("all") + gs.add_nodes( [ Node( id="a", @@ -245,7 +282,7 @@ def test_metadata(graph_store: GraphStore) -> None: ) ] ) - results = list(graph_store.similarity_search(text_to_embedding("A"))) + results = list(gs.similarity_search(text_to_embedding("A"))) assert len(results) == 1 assert results[0].id == "a" assert results[0].metadata["other"] == "some other field" @@ -253,3 +290,92 @@ def test_metadata(graph_store: GraphStore) -> None: Link(direction="in", kind="hyperlink", tag="http://a"), Link(direction="bidir", kind="other", tag="foo"), } + + +def test_graph_store_metadata( + graph_store_factory: Callable[[MetadataIndexingType], GraphStore], +) -> None: + gs = graph_store_factory("all") + + gs.add_nodes([Node(text="bb1", id="row1")]) + gotten1 = gs.get_node(content_id="row1") + assert gotten1 == Node(text="bb1", id="row1", metadata={}) + + gs.add_nodes([Node(text="bb2", id="row2", metadata={})]) + gotten2 = gs.get_node(content_id="row2") + assert gotten2 == Node(text="bb2", id="row2", metadata={}) + + md3 = {"a": 1, "b": "Bee", "c": True} + gs.add_nodes([Node(text="bb3", id="row3", metadata=md3)]) + gotten3 = gs.get_node(content_id="row3") + assert gotten3 == Node(text="bb3", id="row3", metadata=md3) + + md4 = {"c1": True, "c2": True, "c3": True} + gs.add_nodes([Node(text="bb4", id="row4", metadata=md4)]) + gotten4 = gs.get_node(content_id="row4") + assert gotten4 == Node(text="bb4", id="row4", metadata=md4) + + # metadata searches: + md_gotten3a = list(gs.metadata_search(metadata={"a": 1}))[0] # noqa: RUF015 + assert md_gotten3a == gotten3 + md_gotten3b = list(gs.metadata_search(metadata={"b": "Bee", "c": True}))[0] # noqa: RUF015 + assert md_gotten3b == gotten3 + md_gotten4 = list(gs.metadata_search(metadata={"c1": True, "c3": True}))[0] # noqa: RUF015 + assert md_gotten4 == gotten4 + + # 'search' proper + gs.add_nodes( + [ + Node(text="ta", id="twin_a", metadata={"twin": True, "index": 0}), + Node(text="tb", id="twin_b", metadata={"twin": True, "index": 1}), + ] + ) + md_twins_gotten = sorted( + gs.metadata_search(metadata={"twin": True}), + key=lambda res: int(float(res.metadata["index"])), + ) + expected = [ + Node(text="ta", id="twin_a", metadata={"twin": True, "index": 0}), + Node(text="tb", id="twin_b", metadata={"twin": True, "index": 1}), + ] + assert md_twins_gotten == expected + assert list(gs.metadata_search(metadata={"fake": True})) == [] + + +def test_graph_store_metadata_routing( + graph_store_factory: Callable[[MetadataIndexingType], GraphStore], +) -> None: + test_md = {"mds": "string", "mdn": 255, "mdb": True} + + gs_all = graph_store_factory("all") + gs_all.add_nodes([Node(id="row1", text="bb1", metadata=test_md)]) + gotten_all = list(gs_all.metadata_search(metadata={"mds": "string"}))[0] # noqa: RUF015 + assert gotten_all.metadata == test_md + gs_none = graph_store_factory("none") + gs_none.add_nodes([Node(id="row1", text="bb1", metadata=test_md)]) + with pytest.raises(ValueError): # noqa: PT011 + # querying on non-indexed metadata fields: + list(gs_none.metadata_search(metadata={"mds": "string"})) + gotten_none = gs_none.get_node(content_id="row1") + assert gotten_none is not None + assert gotten_none.metadata == test_md + test_md_allowdeny = { + "mdas": "MDAS", + "mdds": "MDDS", + "mdan": 255, + "mddn": 127, + "mdab": True, + "mddb": True, + } + gs_allow = graph_store_factory(("allow", {"mdas", "mdan", "mdab"})) + gs_allow.add_nodes([Node(id="row1", text="bb1", metadata=test_md_allowdeny)]) + with pytest.raises(ValueError): # noqa: PT011 + list(gs_allow.metadata_search(metadata={"mdds": "MDDS"})) + gotten_allow = list(gs_allow.metadata_search(metadata={"mdas": "MDAS"}))[0] # noqa: RUF015 + assert gotten_allow.metadata == test_md_allowdeny + gs_deny = graph_store_factory(("deny", {"mdds", "mddn", "mddb"})) + gs_deny.add_nodes([Node(id="row1", text="bb1", metadata=test_md_allowdeny)]) + with pytest.raises(ValueError): # noqa: PT011 + list(gs_deny.metadata_search(metadata={"mdds": "MDDS"})) + gotten_deny = list(gs_deny.metadata_search(metadata={"mdas": "MDAS"}))[0] # noqa: RUF015 + assert gotten_deny.metadata == test_md_allowdeny diff --git a/libs/knowledge-store/tests/unit_tests/test_cql_generation.py b/libs/knowledge-store/tests/unit_tests/test_cql_generation.py new file mode 100644 index 000000000..1178df52f --- /dev/null +++ b/libs/knowledge-store/tests/unit_tests/test_cql_generation.py @@ -0,0 +1,176 @@ +from ragstack_knowledge_store.graph_store import GraphStore, MetadataIndexingMode + + +class FakePreparedStatement: + query_string: str + + def __init__(self, query: str) -> None: + self.query_string = query + + +class FakeSession: + def prepare(self, query: str) -> FakePreparedStatement: + return FakePreparedStatement(query=query) + + +def _normalize_whitespace(s: str) -> str: + return " ".join(s.split()) + + +def test_cql_generation() -> None: + gs = object.__new__(GraphStore) + + gs._keyspace = "test_keyspace" # noqa: SLF001 + gs._node_table = "test_table" # noqa: SLF001 + gs._session = FakeSession() # noqa: SLF001 + gs._prepared_query_cache = {} # noqa: SLF001 + + query, values = gs._get_search_cql_and_params(limit=2, embedding=[0, 1]) # noqa: SLF001 + assert _normalize_whitespace(query.query_string) == _normalize_whitespace(""" + SELECT content_id, kind, text_content, links_blob, metadata_blob + FROM test_keyspace.test_table + ORDER BY text_embedding ANN OF ? + LIMIT ?; + """) + assert values == ([0, 1], 2) + + query, values = gs._get_search_cql_and_params( # noqa: SLF001 + limit=2, embedding=[0, 1], columns="content_id, link_to_tags" + ) + assert _normalize_whitespace(query.query_string) == _normalize_whitespace(""" + SELECT content_id, link_to_tags + FROM test_keyspace.test_table + ORDER BY text_embedding ANN OF ? + LIMIT ?; + """) + assert values == ([0, 1], 2) + + query, values = gs._get_search_cql_and_params( # noqa: SLF001 + limit=2, embedding=[0, 1], columns="content_id, text_embedding, link_to_tags" + ) + assert _normalize_whitespace(query.query_string) == _normalize_whitespace(""" + SELECT content_id, text_embedding, link_to_tags + FROM test_keyspace.test_table + ORDER BY text_embedding ANN OF ? + LIMIT ?; + """) + assert values == ([0, 1], 2) + + query, values = gs._get_search_cql_and_params( # noqa: SLF001 + columns="content_id AS target_content_id", link_from_tags=("link", "tag") + ) + assert _normalize_whitespace(query.query_string) == _normalize_whitespace(""" + SELECT content_id AS target_content_id + FROM test_keyspace.test_table + WHERE link_from_tags CONTAINS (?, ?); + """) + assert values == ("link", "tag") + + columns = """ + content_id AS target_content_id, + text_embedding AS target_text_embedding, + link_to_tags AS target_link_to_tags + """ + query, values = gs._get_search_cql_and_params( # noqa: SLF001 + limit=2, embedding=[0, 1], columns=columns, link_from_tags=("link", "tag") + ) + assert _normalize_whitespace(query.query_string) == _normalize_whitespace(""" + SELECT + content_id AS target_content_id, + text_embedding AS target_text_embedding, + link_to_tags AS target_link_to_tags + FROM test_keyspace.test_table + WHERE link_from_tags CONTAINS (?, ?) + ORDER BY text_embedding ANN OF ? + LIMIT ?; + """) + assert values == ("link", "tag", [0, 1], 2) + + +def test_cql_generation_with_metadata() -> None: + gs = object.__new__(GraphStore) + + gs._keyspace = "test_keyspace" # noqa: SLF001 + gs._node_table = "test_table" # noqa: SLF001 + gs._session = FakeSession() # noqa: SLF001 + gs._prepared_query_cache = {} # noqa: SLF001 + gs._metadata_indexing_policy = (MetadataIndexingMode.DEFAULT_TO_SEARCHABLE, set()) # noqa: SLF001 + + query, values = gs._get_search_cql_and_params( # noqa: SLF001 + limit=2, embedding=[0, 1], metadata={"one": True, "two": 2} + ) + assert _normalize_whitespace(query.query_string) == _normalize_whitespace(""" + SELECT content_id, kind, text_content, links_blob, metadata_blob + FROM test_keyspace.test_table + WHERE metadata_s['one'] = ? AND metadata_s['two'] = ? + ORDER BY text_embedding ANN OF ? + LIMIT ?; + """) + assert values == ("true", "2.0", [0, 1], 2) + + query, values = gs._get_search_cql_and_params( # noqa: SLF001 + limit=2, + embedding=[0, 1], + columns="content_id, link_to_tags", + metadata={"three": "four"}, + ) + assert _normalize_whitespace(query.query_string) == _normalize_whitespace(""" + SELECT content_id, link_to_tags + FROM test_keyspace.test_table + WHERE metadata_s['three'] = ? + ORDER BY text_embedding ANN OF ? + LIMIT ?; + """) + assert values == ("four", [0, 1], 2) + + query, values = gs._get_search_cql_and_params( # noqa: SLF001 + limit=2, + embedding=[0, 1], + columns="content_id, text_embedding, link_to_tags", + metadata={"test": False}, + ) + assert _normalize_whitespace(query.query_string) == _normalize_whitespace(""" + SELECT content_id, text_embedding, link_to_tags + FROM test_keyspace.test_table + WHERE metadata_s['test'] = ? + ORDER BY text_embedding ANN OF ? + LIMIT ?; + """) + assert values == ("false", [0, 1], 2) + + query, values = gs._get_search_cql_and_params( # noqa: SLF001 + columns="content_id AS target_content_id", + link_from_tags=("link", "tag"), + metadata={"one": True, "two": 2}, + ) + assert _normalize_whitespace(query.query_string) == _normalize_whitespace(""" + SELECT content_id AS target_content_id + FROM test_keyspace.test_table + WHERE link_from_tags CONTAINS (?, ?) AND metadata_s['one'] = ? + AND metadata_s['two'] = ?; + """) + assert values == ("link", "tag", "true", "2.0") + + columns = """ + content_id AS target_content_id, + text_embedding AS target_text_embedding, + link_to_tags AS target_link_to_tags + """ + query, values = gs._get_search_cql_and_params( # noqa: SLF001 + limit=2, + embedding=[0, 1], + columns=columns, + link_from_tags=("link", "tag"), + metadata={"five": "3.0"}, + ) + assert _normalize_whitespace(query.query_string) == _normalize_whitespace(""" + SELECT + content_id AS target_content_id, + text_embedding AS target_text_embedding, + link_to_tags AS target_link_to_tags + FROM test_keyspace.test_table + WHERE link_from_tags CONTAINS (?, ?) AND metadata_s['five'] = ? + ORDER BY text_embedding ANN OF ? + LIMIT ?; + """) + assert values == ("link", "tag", "3.0", [0, 1], 2) diff --git a/libs/knowledge-store/tests/unit_tests/test_metadata_policy_normalization.py b/libs/knowledge-store/tests/unit_tests/test_metadata_policy_normalization.py new file mode 100644 index 000000000..5377fdbc9 --- /dev/null +++ b/libs/knowledge-store/tests/unit_tests/test_metadata_policy_normalization.py @@ -0,0 +1,26 @@ +""" +Normalization of metadata policy specification options +""" + +from ragstack_knowledge_store.graph_store import GraphStore, MetadataIndexingMode + + +class TestNormalizeMetadataPolicy: + def test_normalize_metadata_policy(self) -> None: + mdp1 = GraphStore._normalize_metadata_indexing_policy("all") # noqa: SLF001 + assert mdp1 == (MetadataIndexingMode.DEFAULT_TO_SEARCHABLE, set()) + mdp2 = GraphStore._normalize_metadata_indexing_policy("none") # noqa: SLF001 + assert mdp2 == (MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE, set()) + mdp3 = GraphStore._normalize_metadata_indexing_policy( # noqa: SLF001 + ("default_to_Unsearchable", ["x", "y"]), + ) + assert mdp3 == (MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE, {"x", "y"}) + mdp4 = GraphStore._normalize_metadata_indexing_policy( # noqa: SLF001 + ("DenyList", ["z"]), + ) + assert mdp4 == (MetadataIndexingMode.DEFAULT_TO_SEARCHABLE, {"z"}) + # s + mdp5 = GraphStore._normalize_metadata_indexing_policy( # noqa: SLF001 + ("deny_LIST", "singlefield") + ) + assert mdp5 == (MetadataIndexingMode.DEFAULT_TO_SEARCHABLE, {"singlefield"}) diff --git a/libs/knowledge-store/tests/unit_tests/test_metadata_string_coercion.py b/libs/knowledge-store/tests/unit_tests/test_metadata_string_coercion.py new file mode 100644 index 000000000..968794813 --- /dev/null +++ b/libs/knowledge-store/tests/unit_tests/test_metadata_string_coercion.py @@ -0,0 +1,30 @@ +""" +Stringification of everything in the simple metadata handling +""" + +from ragstack_knowledge_store.graph_store import GraphStore + + +class TestMetadataStringCoercion: + def test_metadata_string_coercion(self) -> None: + md_dict = { + "integer": 1, + "float": 2.0, + "boolean": True, + "null": None, + "string": "letter E", + "something": RuntimeError("You cannot do this!"), + } + + stringified = {k: GraphStore._coerce_string(v) for k, v in md_dict.items()} # noqa: SLF001 + + expected = { + "integer": "1.0", + "float": "2.0", + "boolean": "true", + "null": "null", + "string": "letter E", + "something": str(RuntimeError("You cannot do this!")), + } + + assert stringified == expected