Skip to content

Commit

Permalink
feat: add scores to MMR results (#652)
Browse files Browse the repository at this point in the history
* feat: add scores to MMR results

* add tests

* adjust tolerance
  • Loading branch information
bjchambers authored Sep 30, 2024
1 parent bca3e4f commit a8a092a
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 6 deletions.
26 changes: 21 additions & 5 deletions libs/knowledge-store/ragstack_knowledge_store/_mmr_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def _emb_to_ndarray(embedding: list[float]) -> NDArray[np.float32]:
@dataclasses.dataclass
class _Candidate:
id: str
similarity: float
weighted_similarity: float
weighted_redundancy: float
score: float = dataclasses.field(init=False)
Expand Down Expand Up @@ -69,6 +70,13 @@ class MmrHelper:

selected_ids: list[str]
"""List of selected IDs (in selection order)."""

selected_mmr_scores: list[float]
"""List of MMR scores, each recorded at the time its document was selected."""

selected_similarity_scores: list[float]
"""List of similarity scores, one for each selected document."""

selected_embeddings: NDArray[np.float32]
"""(N, dim) ndarray with a row for each selected node."""

Expand Down Expand Up @@ -100,6 +108,8 @@ def __init__(
self.score_threshold = score_threshold

self.selected_ids = []
self.selected_similarity_scores = []
self.selected_mmr_scores = []

# List of selected embeddings (in selection order).
self.selected_embeddings = np.ndarray((k, self.dimensions), dtype=np.float32)
Expand All @@ -123,11 +133,11 @@ def _already_selected_embeddings(self) -> NDArray[np.float32]:
selected = len(self.selected_ids)
return np.vsplit(self.selected_embeddings, [selected])[0]

def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
def _pop_candidate(self, candidate_id: str) -> tuple[float, NDArray[np.float32]]:
"""Pop the candidate with the given ID.
Returns:
The embedding of the candidate.
The similarity score and embedding of the candidate.
"""
# Get the embedding for the id.
index = self.candidate_id_to_index.pop(candidate_id)
Expand All @@ -143,12 +153,15 @@ def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
# candidate_embeddings.
last_index = self.candidate_embeddings.shape[0] - 1

similarity = 0.0
if index == last_index:
# Already the last item. We don't need to swap.
self.candidates.pop()
similarity = self.candidates.pop().similarity
else:
self.candidate_embeddings[index] = self.candidate_embeddings[last_index]

similarity = self.candidates[index].similarity

old_last = self.candidates.pop()
self.candidates[index] = old_last
self.candidate_id_to_index[old_last.id] = index
Expand All @@ -157,7 +170,7 @@ def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
0
]

return embedding
return similarity, embedding

def pop_best(self) -> str | None:
"""Select and pop the best item being considered.
Expand All @@ -172,11 +185,13 @@ def pop_best(self) -> str | None:

# Get the selection and remove from candidates.
selected_id = self.best_id
selected_embedding = self._pop_candidate(selected_id)
selected_similarity, selected_embedding = self._pop_candidate(selected_id)

# Add the ID, scores, and embedding to the selected information.
selection_index = len(self.selected_ids)
self.selected_ids.append(selected_id)
self.selected_mmr_scores.append(self.best_score)
self.selected_similarity_scores.append(selected_similarity)
self.selected_embeddings[selection_index] = selected_embedding

# Reset the best score / best ID.
Expand Down Expand Up @@ -232,6 +247,7 @@ def add_candidates(self, candidates: dict[str, list[float]]) -> None:
max_redundancy = redundancy[index].max()
candidate = _Candidate(
id=candidate_id,
similarity=similarity[index][0],
weighted_similarity=self.lambda_mult * similarity[index][0],
weighted_redundancy=self.lambda_mult_complement * max_redundancy,
)
Expand Down
8 changes: 7 additions & 1 deletion libs/knowledge-store/ragstack_knowledge_store/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,13 @@ def fetch_initial_candidates() -> None:
depths[adjacent.target_content_id] = next_depth
helper.add_candidates(new_candidates)

return self._nodes_with_ids(helper.selected_ids)
nodes = self._nodes_with_ids(helper.selected_ids)
for node, similarity_score, mmr_score in zip(
nodes, helper.selected_similarity_scores, helper.selected_mmr_scores
):
node.metadata["similarity_score"] = similarity_score
node.metadata["mmr_score"] = mmr_score
return nodes

def traversal_search(
self,
Expand Down
5 changes: 5 additions & 0 deletions libs/knowledge-store/tests/unit_tests/test_mmr_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,8 @@ def test_mmr_helper_added_documetns() -> None:
}
)
assert helper.pop_best() == "v2"

assert math.isclose(helper.selected_similarity_scores[0], 0.9251, abs_tol=0.0001)
assert math.isclose(helper.selected_similarity_scores[1], 0.7071, abs_tol=0.0001)
assert math.isclose(helper.selected_mmr_scores[0], 0.4625, abs_tol=0.0001)
assert math.isclose(helper.selected_mmr_scores[1], 0.1608, abs_tol=0.0001)

0 comments on commit a8a092a

Please sign in to comment.