feat: mmr traversal starting in neighborhood #634

Merged · 6 commits · Aug 1, 2024
Changes from 3 commits
102 changes: 72 additions & 30 deletions libs/knowledge-store/ragstack_knowledge_store/graph_store.py
@@ -217,14 +217,6 @@ def __init__(
""" # noqa: S608
)

self._query_source_tags_by_id = session.prepare(
f"""
SELECT link_to_tags
FROM {keyspace}.{node_table}
WHERE content_id = ?
""" # noqa: S608
)

def table_name(self) -> str:
"""Returns the fully qualified table name."""
return f"{self._keyspace}.{self._node_table}"
@@ -364,6 +356,7 @@ def mmr_traversal_search(
self,
query: str,
*,
neighborhood: Sequence[str] | None = None,
k: int = 4,
depth: int = 2,
fetch_k: int = 100,
@@ -384,6 +377,10 @@

Args:
query: The query string to search for.
neighborhood: Optional list of documents to use as the initial neighborhood
to search. If provided, the adjacent nodes to the neighborhood will be
used as the initial candidates, rather than performing a generic vector
search.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of initial Documents to fetch via similarity.
Defaults to 100.
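
As context for reviewers, a minimal usage sketch of the new keyword argument. The store setup, query string, and seed ids are illustrative assumptions; the call shape follows the signature above and the integration test further down.

def search_from_neighborhood(store, seeds):
    """Sketch: seed the MMR traversal from `seeds` instead of a generic vector search.

    `store` is assumed to be an already-populated GraphStore and `seeds` a list of
    content ids previously written to it (both are illustrative assumptions).
    """
    # Only nodes adjacent to the seed documents become the initial candidates,
    # so anything unreachable from `seeds` cannot appear in the results.
    return list(store.mmr_traversal_search("query text", k=4, neighborhood=seeds))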
@@ -406,19 +403,12 @@
score_threshold=score_threshold,
)

# For each unvisited node, stores the outgoing tags.
# For each unselected node, stores the outgoing tags.
outgoing_tags: dict[str, set[tuple[str, str]]] = {}

# Fetch the initial candidates and add them to the helper and
# outgoing_tags.
columns = "content_id, text_embedding, link_to_tags"
initial_candidates_query = self._get_search_cql(
has_limit=True,
columns=columns,
metadata_keys=list(metadata_filter.keys()),
has_embedding=True,
)

adjacent_query = self._get_search_cql(
has_limit=True,
columns=columns,
@@ -427,7 +417,16 @@
has_link_from_tags=True,
)

visited_tags: set[tuple[str, str]] = set()

def fetch_initial_candidates() -> None:
initial_candidates_query = self._get_search_cql(
has_limit=True,
columns=columns,
metadata_keys=list(metadata_filter.keys()),
has_embedding=True,
)

params = self._get_search_params(
limit=fetch_k,
metadata=metadata_filter,
@@ -443,12 +442,45 @@ def fetch_initial_candidates() -> None:
outgoing_tags[row.content_id] = set(row.link_to_tags or [])
helper.add_candidates(candidates)

fetch_initial_candidates()
def fetch_neighborhood(neighborhood: Sequence[str]) -> None:
# Put the neighborhood into the outgoing tags, to avoid adding it
# to the candidate set in the future.
outgoing_tags.update({content_id: set() for content_id in neighborhood})

# Initialize the visited_tags with the set of tags outgoing from the
# neighborhood. This prevents re-visiting them.
visited_tags = self._get_outgoing_tags(neighborhood)

# Call `self._get_adjacent` to fetch the candidates.
adjacents = self._get_adjacent(
visited_tags,
adjacent_query=adjacent_query,
query_embedding=query_embedding,
k_per_tag=adjacent_k,
metadata_filter=metadata_filter,
)

# Add the adjacent nodes as new candidates.
new_candidates = {}
for adjacent in adjacents:
if adjacent.target_content_id not in outgoing_tags:
outgoing_tags[adjacent.target_content_id] = (
adjacent.target_link_to_tags
)

new_candidates[adjacent.target_content_id] = (
adjacent.target_text_embedding
)
helper.add_candidates(new_candidates)

if neighborhood is None:
fetch_initial_candidates()
else:
fetch_neighborhood(neighborhood)

# Tracks the depth of each candidate.
depths = {candidate_id: 0 for candidate_id in helper.candidate_ids()}
visited_tags: set[tuple[str, str]] = set()

# Select the best item, K times.
for _ in range(k):
selected_id = helper.pop_best()

@@ -683,12 +715,15 @@ def _get_outgoing_tags(

def add_sources(rows: Iterable[Any]) -> None:
for row in rows:
tags.update(row.link_to_tags)
if row.link_to_tags:
tags.update(row.link_to_tags)

with self._concurrent_queries() as cq:
for source_id in source_ids:
cq.execute(
self._query_source_tags_by_id, (source_id,), callback=add_sources
self._query_ids_and_link_to_tags_by_id,
(source_id,),
callback=add_sources,
)

return tags
@@ -699,7 +734,7 @@ def _get_adjacent(
adjacent_query: PreparedStatement,
query_embedding: list[float],
k_per_tag: int | None = None,
metadata_filter: dict[str, Any] = {}, # noqa: B006
metadata_filter: dict[str, Any] | None = None,
) -> Iterable[_Edge]:
"""Return the target nodes with incoming links from any of the given tags.

@@ -809,15 +844,19 @@ def _coerce_string(value: Any) -> str:

def _extract_where_clause_cql(
self,
metadata_keys: list[str] = [], # noqa: B006
has_id: bool = False,
metadata_keys: list[str] | None = None,
has_link_from_tags: bool = False,
) -> str:
wc_blocks: list[str] = []

if has_id:
wc_blocks.append("content_id == ?")

if has_link_from_tags:
wc_blocks.append("link_from_tags CONTAINS (?, ?)")

for key in sorted(metadata_keys):
for key in sorted(metadata_keys or []):
if _is_metadata_field_indexed(key, self._metadata_indexing_policy):
wc_blocks.append(f"metadata_s['{key}'] = ?")
else:
@@ -855,12 +894,15 @@ def _get_search_cql(
self,
has_limit: bool = False,
columns: str | None = CONTENT_COLUMNS,
metadata_keys: list[str] = [], # noqa: B006
metadata_keys: list[str] | None = None,
has_id: bool = False,
has_embedding: bool = False,
has_link_from_tags: bool = False,
) -> PreparedStatement:
where_clause = self._extract_where_clause_cql(
metadata_keys=metadata_keys, has_link_from_tags=has_link_from_tags
has_id=has_id,
metadata_keys=metadata_keys,
has_link_from_tags=has_link_from_tags,
)
limit_clause = " LIMIT ?" if has_limit else ""
order_clause = " ORDER BY text_embedding ANN OF ?" if has_embedding else ""
@@ -885,12 +927,12 @@ def _get_search_params(
def _get_search_params(
self,
limit: int | None = None,
metadata: dict[str, Any] = {}, # noqa: B006
metadata: dict[str, Any] | None = None,
embedding: list[float] | None = None,
link_from_tags: tuple[str, str] | None = None,
) -> tuple[PreparedStatement, tuple[Any, ...]]:
where_params = self._extract_where_clause_params(
metadata=metadata, link_from_tags=link_from_tags
metadata=metadata or {}, link_from_tags=link_from_tags
)

limit_params = [limit] if limit is not None else []
@@ -902,14 +944,14 @@ def _get_search_cql_and_params(
self,
limit: int | None = None,
columns: str | None = CONTENT_COLUMNS,
metadata: dict[str, Any] = {}, # noqa: B006
metadata: dict[str, Any] | None = None,
embedding: list[float] | None = None,
link_from_tags: tuple[str, str] | None = None,
) -> tuple[PreparedStatement, tuple[Any, ...]]:
query = self._get_search_cql(
has_limit=limit is not None,
columns=columns,
metadata_keys=list(metadata.keys()),
metadata_keys=list(metadata.keys()) if metadata else None,
has_embedding=embedding is not None,
has_link_from_tags=link_from_tags is not None,
)
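
Stepping back from the diff, the neighborhood-seeding branch added to mmr_traversal_search amounts to the following standalone sketch. The two injected callables stand in for the store's private _get_outgoing_tags and _get_adjacent helpers, so this summarizes the control flow rather than reproducing the implementation.

from typing import Any, Callable, Iterable, Sequence


def seed_from_neighborhood(
    neighborhood: Sequence[str],
    get_outgoing_tags: Callable[[Sequence[str]], set],
    get_adjacent: Callable[[set], Iterable[Any]],
) -> tuple[dict, set, dict]:
    """Return (outgoing_tags, visited_tags, initial_candidates) for the seed ids."""
    # Record the seeds with empty tag sets so they are never re-added as candidates.
    outgoing_tags = {content_id: set() for content_id in neighborhood}

    # Tags leaving the neighborhood count as visited, so the traversal does not
    # immediately walk back into the seed documents.
    visited_tags = get_outgoing_tags(neighborhood)

    # Nodes adjacent to those tags become the initial MMR candidates.
    candidates = {}
    for edge in get_adjacent(visited_tags):
        outgoing_tags.setdefault(edge.target_content_id, edge.target_link_to_tags)
        candidates[edge.target_content_id] = edge.target_text_embedding
    return outgoing_tags, visited_tags, candidates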
11 changes: 11 additions & 0 deletions libs/knowledge-store/tests/integration_tests/test_graph_store.py
@@ -195,6 +195,17 @@ def test_mmr_traversal(
results = gs.mmr_traversal_search("0.0", k=4, metadata_filter={"even": True})
assert _result_ids(results) == ["v0", "v2"]

# with neighborhood=[v0], we should start traversal there. This means the
# initial candidates are `v2` and `v3`. `v1` is unreachable and is not
# included.
results = gs.mmr_traversal_search("0.0", k=4, neighborhood=["v0"])
assert _result_ids(results) == ["v2", "v3"]

# with neighborhood=[v1], we should start traversal there.
# there are no adjacent nodes, so there are no results.
results = gs.mmr_traversal_search("0.0", k=4, neighborhood=["v1"])
assert _result_ids(results) == []


def test_write_retrieve_keywords(
graph_store_factory: Callable[[MetadataIndexingType], GraphStore],
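
The two new assertions can be reproduced with a small self-contained toy. The tag layout below is an assumption chosen to be consistent with the expected results, not a copy of the real fixture in test_graph_store.py.

# Toy model (not the store implementation): outgoing and incoming link tags per node.
link_to_tags = {
    "v0": {("link", "a")},
    "v1": set(),
    "v2": {("link", "b")},
    "v3": {("link", "b")},
}
link_from_tags = {
    "v0": set(),
    "v1": set(),
    "v2": {("link", "a")},
    "v3": {("link", "a")},
}


def initial_candidates(neighborhood):
    # Tags outgoing from the seed documents are treated as already visited.
    visited_tags = set()
    for node in neighborhood:
        visited_tags |= link_to_tags[node]
    # Nodes reachable through those tags become the initial MMR candidates.
    return sorted(
        node
        for node, incoming in link_from_tags.items()
        if node not in neighborhood and incoming & visited_tags
    )


print(initial_candidates(["v0"]))  # ['v2', 'v3'] -- matches the first new assertion
print(initial_candidates(["v1"]))  # []           -- matches the second new assertion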