diff --git a/.bumpversion.cfg b/.bumpversion.cfg index b5d79d7..e9def7f 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.0.28.dev8 +current_version = 0.0.28.dev9 tag = False parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<build>\d+))? serialize = diff --git a/langkit/metrics/library.py b/langkit/metrics/library.py index c7f1a5e..77a7bc4 100644 --- a/langkit/metrics/library.py +++ b/langkit/metrics/library.py @@ -39,6 +39,7 @@ def all(prompt: bool = True, response: bool = True) -> MetricCreator: response_refusal_similarity_metric, response_presidio_pii_metric, lib.response.toxicity(), + lib.response.similarity.context(), lib.response.topics.medicine(), ] @@ -512,6 +513,12 @@ def refusal(onnx: bool = True) -> MetricCreator: return partial(response_refusal_similarity_metric, onnx=onnx) + @staticmethod + def context(onnx: bool = True) -> MetricCreator: + from langkit.metrics.input_context_similarity import input_context_similarity + + return partial(input_context_similarity, onnx=onnx, input_column_name="response") + class topics: def __init__(self, topics: List[str], hypothesis_template: Optional[str] = None, onnx: bool = True): self.topics = topics diff --git a/pyproject.toml b/pyproject.toml index a2673e2..a9c934d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langkit" -version = "0.0.28.dev8" +version = "0.0.28.dev9" description = "A language toolkit for monitoring LLM interactions" authors = ["WhyLabs.ai <support@whylabs.ai>"] homepage = "https://docs.whylabs.ai/docs/large-language-model-monitoring" diff --git a/tests/langkit/metrics/test_input_context_similarity.py b/tests/langkit/metrics/test_input_context_similarity.py index 1c26d00..ce97d13 100644 --- a/tests/langkit/metrics/test_input_context_similarity.py +++ b/tests/langkit/metrics/test_input_context_similarity.py @@ -29,6 +29,28 @@ def test_similarity(): assert metrics["prompt.similarity.context"][0] == pytest.approx(0.7447172999382019) 
# pyright: ignore[reportUnknownMemberType] +def test_similarity_response(): + wf = Workflow(metrics=[lib.response.similarity.context()]) + + context: InputContext = { + "entries": [ + {"content": "Some source 1", "metadata": {"source": "https://internal.com/foo"}}, + {"content": "Some source 2", "metadata": {"source": "https://internal.com/bar"}}, + ] + } + + df = pd.DataFrame({"response": ["Some source"], "context": [context]}) + + result = wf.run(df) + + metrics = result.metrics + + metric_names: List[str] = metrics.columns.tolist() # pyright: ignore[reportUnknownMemberType] + + assert metric_names == ["response.similarity.context", "id"] + assert metrics["response.similarity.context"][0] == pytest.approx(0.7447172999382019) # pyright: ignore[reportUnknownMemberType] + + def test_similarity_missing_context(): # The metric should not be run in this case since the context is missing wf = Workflow(metrics=[lib.prompt.similarity.context()])