-
Notifications
You must be signed in to change notification settings - Fork 69
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add rag context as input option + similarity to context metric
This adds a structure that can be passed into workflow runs to represent the context in a RAG system. It's a list of objects, each with optional metadata (the metadata is not currently used).
- Loading branch information
Showing
7 changed files
with
223 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import pandas as pd | ||
|
||
from langkit.core.context import Context | ||
from langkit.core.metric import Metric, SingleMetric, SingleMetricResult | ||
from langkit.metrics.embeddings_utils import compute_embedding_similarity_encoded | ||
from langkit.transformer import EmbeddingContextDependency, RAGContextDependency | ||
|
||
|
||
def input_context_similarity(input_column_name: str = "prompt", context_column_name: str = "context", onnx: bool = True) -> Metric:
    """Create a metric that compares an input column's embedding with the RAG context's embedding.

    Two context dependencies are registered: an ``EmbeddingContextDependency`` for
    ``input_column_name`` and a ``RAGContextDependency`` for ``context_column_name``.
    At evaluation time both are fetched from the request context and passed to
    ``compute_embedding_similarity_encoded``.

    Args:
        input_column_name: Column holding the text to compare (defaults to ``"prompt"``).
        context_column_name: Column holding the RAG context (defaults to ``"context"``).
        onnx: Forwarded unchanged to both embedding dependencies.

    Returns:
        A ``SingleMetric`` named ``"<input_column_name>.similarity.<context_column_name>"``.
    """
    input_dep = EmbeddingContextDependency(onnx=onnx, input_column=input_column_name)
    rag_dep = RAGContextDependency(onnx=onnx, context_column_name=context_column_name)

    def udf(text: pd.DataFrame, context: Context) -> SingleMetricResult:
        # The data frame itself is not read here; the embeddings come from the
        # dependencies' request data stored on the context.
        scores = compute_embedding_similarity_encoded(
            input_dep.get_request_data(context),
            rag_dep.get_request_data(context),
        )

        # Anything that is not already one-dimensional has its leading axis squeezed
        # before being converted to a list.
        if len(scores.shape) != 1:
            scores = scores.squeeze(dim=0)

        return SingleMetricResult(scores.tolist())  # type: ignore[reportUnknownVariableType]

    metric_name = f"{input_column_name}.similarity.{context_column_name}"
    return SingleMetric(
        name=metric_name,
        input_names=[input_column_name, context_column_name],
        evaluate=udf,
        context_dependencies=[input_dep, rag_dep],
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
from typing import List | ||
|
||
import pandas as pd | ||
import pytest | ||
|
||
from langkit.core.workflow import InputContext, Workflow, is_input_context | ||
from langkit.metrics.library import lib | ||
|
||
|
||
def test_similarity():
    """A single-row DataFrame with a two-entry context yields one similarity column plus ``id``."""
    rag_context: InputContext = {
        "entries": [
            {"content": "Some source 1", "metadata": {"source": "https://internal.com/foo"}},
            {"content": "Some source 2", "metadata": {"source": "https://internal.com/bar"}},
        ]
    }
    workflow = Workflow(metrics=[lib.prompt.similarity.context()])
    frame = pd.DataFrame({"prompt": ["Some source"], "context": [rag_context]})

    metrics = workflow.run(frame).metrics
    column_names: List[str] = metrics.columns.tolist()  # pyright: ignore[reportUnknownMemberType]

    assert column_names == ["prompt.similarity.context", "id"]
    assert metrics["prompt.similarity.context"][0] == pytest.approx(0.7447172999382019)  # pyright: ignore[reportUnknownMemberType]
|
||
|
||
def test_similarity_missing_context():
    """Without a ``context`` column the metric should not be run, leaving only ``id``."""
    workflow = Workflow(metrics=[lib.prompt.similarity.context()])
    frame = pd.DataFrame({"prompt": ["Some source"]})

    metrics = workflow.run(frame).metrics
    column_names: List[str] = metrics.columns.tolist()  # pyright: ignore[reportUnknownMemberType]

    assert column_names == ["id"]
|
||
|
||
def test_similarity_multiple():
    """Two rows sharing one context: the matching prompt scores high, the unrelated one below 0.2."""
    rag_context: InputContext = {
        "entries": [
            {"content": "Some source 1", "metadata": {"source": "https://internal.com/foo"}},
            {"content": "Some source 2", "metadata": {"source": "https://internal.com/bar"}},
        ]
    }
    workflow = Workflow(metrics=[lib.prompt.similarity.context()])
    frame = pd.DataFrame(
        {
            "prompt": ["Some source", "Nothing in common"],
            "context": [rag_context, rag_context],
        }
    )

    metrics = workflow.run(frame).metrics
    column_names: List[str] = metrics.columns.tolist()  # pyright: ignore[reportUnknownMemberType]

    assert column_names == ["prompt.similarity.context", "id"]
    assert metrics["prompt.similarity.context"][0] == pytest.approx(0.7447172999382019)  # pyright: ignore[reportUnknownMemberType]
    assert metrics["prompt.similarity.context"][1] < 0.2
|
||
|
||
def test_similarity_row():
    """Running the workflow on a plain dict row gives the same similarity as the DataFrame form."""
    workflow = Workflow(metrics=[lib.prompt.similarity.context()])
    rag_context: InputContext = {
        "entries": [
            {"content": "Some source 1", "metadata": {"source": "https://internal.com/foo"}},
            {"content": "Some source 2", "metadata": {"source": "https://internal.com/bar"}},
        ]
    }

    # Check the literal is recognised as an input context before it is used in a row.
    assert is_input_context(rag_context)

    row = {"prompt": "Some source", "context": rag_context}
    metrics = workflow.run(row).metrics
    column_names: List[str] = metrics.columns.tolist()  # pyright: ignore[reportUnknownMemberType]

    assert column_names == ["prompt.similarity.context", "id"]
    assert metrics["prompt.similarity.context"][0] == pytest.approx(0.7447172999382019)  # pyright: ignore[reportUnknownMemberType]