From 228f5e6f6c9f105ba46b5102dc79ea79b118668c Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Thu, 26 Sep 2024 16:19:35 +0200 Subject: [PATCH] more updates --- .../ragstack_knowledge_store/graph_store.py | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index 55dceee2b..24cf8c326 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -142,7 +142,16 @@ def _row_to_node(row: Any) -> Node: links=links, ) +def _get_metadata_filter( + metadata: dict[str, Any] | None = None, + outgoing_link: Link | None = None, +) -> dict[str, Any]: + if outgoing_link is None: + return metadata + metadata_filter = {} if metadata is None else metadata.copy() + metadata_filter[_metadata_s_link_key(link=outgoing_link)] = _metadata_s_link_value() + return metadata_filter _CQL_IDENTIFIER_PATTERN = re.compile(r"[a-zA-Z][a-zA-Z0-9_]*") @@ -841,6 +850,9 @@ def _extract_where_clause_params( return params + + + def _get_search_cql_and_params( self, columns: str, @@ -849,15 +861,9 @@ def _get_search_cql_and_params( embedding: list[float] | None = None, outgoing_link: Link | None = None, ) -> tuple[PreparedStatement|SimpleStatement, tuple[Any, ...]]: - if outgoing_link is not None: - if metadata is None: - metadata = {} - else: - # don't add link search to original metadata dict - metadata = metadata.copy() - metadata[_metadata_s_link_key(link=outgoing_link)] = _metadata_s_link_value() + metadata_filter = _get_metadata_filter(metadata=metadata, outgoing_link=outgoing_link) - metadata_keys = list(metadata.keys()) if metadata else [] + metadata_keys = list(metadata_filter.keys()) if metadata else [] where_clause = self._extract_where_clause_cql(metadata_keys=metadata_keys) limit_clause = " LIMIT ?" if limit is not None else "" @@ -871,7 +877,7 @@ def _get_search_cql_and_params( limit_clause=limit_clause, ) - where_params = self._extract_where_clause_params(metadata=metadata or {}) + where_params = self._extract_where_clause_params(metadata=metadata_filter or {}) limit_params = [limit] if limit is not None else [] order_params = [embedding] if embedding is not None else [] @@ -886,3 +892,4 @@ def _get_search_cql_and_params( prepared_query.consistency_level = ConsistencyLevel.ONE self._prepared_query_cache[select_cql] = prepared_query return prepared_query, params +