From bc8f3d7ddba3c9a0352c8f92e6308f49282f21ac Mon Sep 17 00:00:00 2001
From: Anthony Naddeo
Date: Tue, 9 Apr 2024 18:17:21 -0700
Subject: [PATCH] Add RAG context as input option + similarity-to-context
 metric

This adds a structure that can be passed into workflow runs to represent the
context in a RAG system. It's a list of entries, each with content and
optional metadata (the metadata isn't used by any metric yet).
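A minimal usage sketch, mirroring the tests added in this patch (the exact
similarity score depends on the embedding model in use):

    from langkit.core.workflow import InputContext, Workflow
    from langkit.metrics.library import lib

    wf = Workflow(metrics=[lib.prompt.similarity.context()])

    context: InputContext = {
        "entries": [
            {"content": "Some source", "metadata": {"source": "https://internal.com/foo"}},
        ]
    }

    result = wf.run({"prompt": "Some question", "context": context})
    print(result.metrics["prompt.similarity.context"][0])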
---
 langkit/core/workflow.py                      | 53 ++++++++++--
 langkit/metrics/input_context_similarity.py   | 28 ++++++
 langkit/metrics/input_output_similarity.py    |  3 -
 langkit/metrics/library.py                    |  7 ++
 langkit/metrics/util.py                       |  7 --
 langkit/transformer.py                        | 53 +++++++++-
 .../metrics/test_input_context_similarity.py  | 89 +++++++++++++++++++
 7 files changed, 223 insertions(+), 17 deletions(-)
 create mode 100644 langkit/metrics/input_context_similarity.py
 create mode 100644 tests/langkit/metrics/test_input_context_similarity.py

diff --git a/langkit/core/workflow.py b/langkit/core/workflow.py
index ae79976..cba58f1 100644
--- a/langkit/core/workflow.py
+++ b/langkit/core/workflow.py
@@ -6,6 +6,7 @@
 from typing import Dict, List, Mapping, Optional, Set, Tuple, TypedDict, Union, cast, overload
 
 import pandas as pd
+from typing_extensions import NotRequired
 
 from langkit.core.context import Context
 from langkit.core.metric import (
@@ -22,14 +23,23 @@
     WorkflowMetricConfigBuilder,
 )
 from langkit.core.validation import ValidationResult, Validator
-from langkit.metrics.util import is_dict_with_strings
 
 logger = logging.getLogger(__name__)
 
 
+class InputContextItem(TypedDict):
+    content: str
+    metadata: NotRequired[Dict[str, str]]
+
+
+class InputContext(TypedDict):
+    entries: List[InputContextItem]
+
+
 class Row(TypedDict):
-    prompt: str
-    response: str
+    prompt: NotRequired[str]
+    response: NotRequired[str]
+    context: NotRequired[InputContext]
 
 
 @dataclass(frozen=True)
@@ -214,9 +224,9 @@ def run(self, data: Union[pd.DataFrame, Row, Dict[str, str]], options: Optional[
         init_end = time.perf_counter() - init_start
 
         if not isinstance(data, pd.DataFrame):
-            if not is_dict_with_strings(data):
-                raise ValueError("Input must be a pandas DataFrame or a dictionary with string keys and string values")
-            df = pd.DataFrame(data, index=[0])
+            if not is_dict_input(data):
+                raise ValueError("Input must be a pandas DataFrame or a dictionary with string keys and string or context values")
+            df = pd.DataFrame([data])
         else:
             df = data
 
@@ -351,3 +361,34 @@ def _validate_evaluate(self, input_df: pd.DataFrame, metric: Metric, metric_resu
         if isinstance(metric_result, MultiMetricResult):
             for result in metric_result.metrics:
                 assert len(input_df) == len(result)
+
+
+def is_input_context_item(variable: object) -> bool:
+    if not isinstance(variable, dict):
+        return False
+
+    variable = cast(InputContextItem, variable)
+    return "content" in variable and ("metadata" in variable or len(variable) == 1)
+
+
+def is_input_context(variable: object) -> bool:
+    if not isinstance(variable, dict):
+        return False
+    if "entries" not in variable:
+        return False
+
+    if not isinstance(variable["entries"], list):
+        return False
+
+    variable = cast(InputContext, variable)
+    if len(variable) != 1:
+        return False
+
+    return all(is_input_context_item(value) for value in variable["entries"])
+
+
+def is_dict_input(variable: object) -> bool:
+    if not isinstance(variable, dict):
+        return False
+    # Values must be strings or RAG context dicts
+    return all(isinstance(value, str) or is_input_context(value) for value in variable.values())  # type: ignore[reportUnknownMemberType]
diff --git a/langkit/metrics/input_context_similarity.py b/langkit/metrics/input_context_similarity.py
new file mode 100644
index 0000000..fa76354
--- /dev/null
+++ b/langkit/metrics/input_context_similarity.py
@@ -0,0 +1,28 @@
+import pandas as pd
+
+from langkit.core.context import Context
+from langkit.core.metric import Metric, SingleMetric, SingleMetricResult
+from langkit.metrics.embeddings_utils import compute_embedding_similarity_encoded
+from langkit.transformer import EmbeddingContextDependency, RAGContextDependency
+
+
+def input_context_similarity(input_column_name: str = "prompt", context_column_name: str = "context", onnx: bool = True) -> Metric:
+    prompt_embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=input_column_name)
+    context_embedding_dep = RAGContextDependency(onnx=onnx, context_column_name=context_column_name)
+
+    def udf(text: pd.DataFrame, context: Context) -> SingleMetricResult:
+        prompt_embedding = prompt_embedding_dep.get_request_data(context)
+        context_embedding = context_embedding_dep.get_request_data(context)
+        similarity = compute_embedding_similarity_encoded(prompt_embedding, context_embedding)
+
+        if len(similarity.shape) == 1:
+            return SingleMetricResult(similarity.tolist())  # type: ignore[reportUnknownVariableType]
+        else:
+            return SingleMetricResult(similarity.squeeze(dim=0).tolist())  # type: ignore[reportUnknownVariableType]
+
+    return SingleMetric(
+        name=f"{input_column_name}.similarity.{context_column_name}",
+        input_names=[input_column_name, context_column_name],
+        evaluate=udf,
+        context_dependencies=[prompt_embedding_dep, context_embedding_dep],
+    )
diff --git a/langkit/metrics/input_output_similarity.py b/langkit/metrics/input_output_similarity.py
index 3953ed0..11639c9 100644
--- a/langkit/metrics/input_output_similarity.py
+++ b/langkit/metrics/input_output_similarity.py
@@ -13,9 +13,6 @@ def input_output_similarity_metric(input_column_name: str = "prompt", output_col
     response_embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=output_column_name)
 
     def udf(text: pd.DataFrame, context: Context) -> SingleMetricResult:
-        # in_np = UdfInput(text).to_list(input_column_name)
-        # out_np = UdfInput(text).to_list(output_column_name)
-        # encoder = embedding_adapter(onnx)
         prompt_embedding = prompt_embedding_dep.get_request_data(context)
         response_embedding = response_embedding_dep.get_request_data(context)
         similarity = compute_embedding_similarity_encoded(prompt_embedding, response_embedding)
diff --git a/langkit/metrics/library.py b/langkit/metrics/library.py
index 1e6d2f8..c7f1a5e 100644
--- a/langkit/metrics/library.py
+++ b/langkit/metrics/library.py
@@ -24,6 +24,7 @@ def all(prompt: bool = True, response: bool = True) -> MetricCreator:
         prompt_sentiment_polarity,
         lib.prompt.toxicity(),
         prompt_response_input_output_similarity_metric,
+        lib.prompt.similarity.context(),
         prompt_injections_metric,
         prompt_jailbreak_similarity_metric,
         prompt_presidio_pii_metric,
@@ -271,6 +272,12 @@ def jailbreak(onnx: bool = True) -> MetricCreator:
 
             return partial(prompt_jailbreak_similarity_metric, onnx=onnx)
 
+        @staticmethod
+        def context(onnx: bool = True) -> MetricCreator:
+            from langkit.metrics.input_context_similarity import input_context_similarity
+
+            return partial(input_context_similarity, onnx=onnx)
+
     class sentiment:
         def __call__(self) -> MetricCreator:
             return self.sentiment_score()
diff --git a/langkit/metrics/util.py b/langkit/metrics/util.py
index 9c61d82..30172b4 100644
--- a/langkit/metrics/util.py
+++ b/langkit/metrics/util.py
@@ -35,13 +35,6 @@ def value(self, arg: In) -> Out:
         return self.__cache[arg]
 
 
-def is_dict_with_strings(variable: object) -> bool:
-    if not isinstance(variable, dict):
-        return False
-    # Check if all values in the dictionary are strings
-    return all(isinstance(value, str) for value in variable.values())  # type: ignore[reportUnknownMemberType]
-
-
 ReturnType = TypeVar("ReturnType")
 
 
diff --git a/langkit/transformer.py b/langkit/transformer.py
index 4130ae6..ceb0b97 100644
--- a/langkit/transformer.py
+++ b/langkit/transformer.py
@@ -1,12 +1,13 @@
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import Tuple
+from typing import List, Literal, Tuple
 
 import pandas as pd
 import torch
 from sentence_transformers import SentenceTransformer
 
 from langkit.core.context import Context, ContextDependency
+from langkit.core.workflow import InputContext
 from langkit.metrics.embeddings_types import EmbeddingEncoder, TransformerEmbeddingAdapter
 from langkit.onnx_encoder import OnnxSentenceTransformer, TransformerModel
 
@@ -62,3 +63,53 @@ def populate_request(self, context: Context, data: pd.DataFrame):
 
     def get_request_data(self, context: Context) -> torch.Tensor:
         return context.request_data[self.name()]
+
+
+@dataclass(frozen=True)
+class RAGContextDependency(ContextDependency[torch.Tensor]):
+    onnx: bool
+    strategy: Literal["combine"] = "combine"
+    """
+    The strategy for converting the context into embeddings.
+
+    - combine: Combine all the entries in the context into a single string and encode it.
+    """
+    context_column_name: str = "context"
+
+    def name(self) -> str:
+        return f"{self.context_column_name}.context?onnx={self.onnx}"
+
+    def cache_assets(self) -> None:
+        # TODO do only the downloading
+        embedding_adapter(onnx=self.onnx)
+
+    def init(self) -> None:
+        embedding_adapter(onnx=self.onnx)
+
+    def populate_request(self, context: Context, data: pd.DataFrame):
+        if self.context_column_name not in data.columns:
+            return
+
+        if self.name() in context.request_data:
+            return
+
+        rag_context = self._get_rag_context(data)
+
+        if self.strategy == "combine":
+            combined: List[str] = []
+            for row in rag_context:
+                row_string = "\n".join([it["content"] for it in row["entries"]])
+                combined.append(row_string)
+        else:
+            raise ValueError(f"Unknown context embedding strategy {self.strategy}")
+
+        encoder = embedding_adapter(onnx=self.onnx)
+        embedding = encoder.encode(tuple(combined))
+        context.request_data[self.name()] = embedding
+
+    def _get_rag_context(self, df: pd.DataFrame) -> List[InputContext]:
+        context_column: List[InputContext] = df[self.context_column_name].tolist()
+        return context_column
+
+    def get_request_data(self, context: Context) -> torch.Tensor:
+        return context.request_data[self.name()]
diff --git a/tests/langkit/metrics/test_input_context_similarity.py b/tests/langkit/metrics/test_input_context_similarity.py
new file mode 100644
index 0000000..1c26d00
--- /dev/null
+++ b/tests/langkit/metrics/test_input_context_similarity.py
@@ -0,0 +1,89 @@
+from typing import List
+
+import pandas as pd
+import pytest
+
+from langkit.core.workflow import InputContext, Workflow, is_input_context
+from langkit.metrics.library import lib
+
+
+def test_similarity():
+    wf = Workflow(metrics=[lib.prompt.similarity.context()])
+
+    context: InputContext = {
+        "entries": [
+            {"content": "Some source 1", "metadata": {"source": "https://internal.com/foo"}},
+            {"content": "Some source 2", "metadata": {"source": "https://internal.com/bar"}},
+        ]
+    }
+
+    df = pd.DataFrame({"prompt": ["Some source"], "context": [context]})
+
+    result = wf.run(df)
+
+    metrics = result.metrics
+
+    metric_names: List[str] = metrics.columns.tolist()  # pyright: ignore[reportUnknownMemberType]
+
+    assert metric_names == ["prompt.similarity.context", "id"]
+    assert metrics["prompt.similarity.context"][0] == pytest.approx(0.7447172999382019)  # pyright: ignore[reportUnknownMemberType]
+
+
+def test_similarity_missing_context():
+    # The metric should not be run in this case since the context is missing
+    wf = Workflow(metrics=[lib.prompt.similarity.context()])
+
+    df = pd.DataFrame({"prompt": ["Some source"]})
+
+    result = wf.run(df)
+
+    metrics = result.metrics
+
+    metric_names: List[str] = metrics.columns.tolist()  # pyright: ignore[reportUnknownMemberType]
+
+    assert metric_names == ["id"]
+
+
+def test_similarity_multiple():
+    wf = Workflow(metrics=[lib.prompt.similarity.context()])
+
+    context: InputContext = {
+        "entries": [
+            {"content": "Some source 1", "metadata": {"source": "https://internal.com/foo"}},
+            {"content": "Some source 2", "metadata": {"source": "https://internal.com/bar"}},
+        ]
+    }
+
+    df = pd.DataFrame({"prompt": ["Some source", "Nothing in common"], "context": [context, context]})
+
+    result = wf.run(df)
+
+    metrics = result.metrics
+
+    metric_names: List[str] = metrics.columns.tolist()  # pyright: ignore[reportUnknownMemberType]
+
+    assert metric_names == ["prompt.similarity.context", "id"]
+    assert metrics["prompt.similarity.context"][0] == pytest.approx(0.7447172999382019)  # pyright: ignore[reportUnknownMemberType]
+    assert metrics["prompt.similarity.context"][1] < 0.2
+
+
+def test_similarity_row():
+    wf = Workflow(metrics=[lib.prompt.similarity.context()])
+
+    context: InputContext = {
+        "entries": [
+            {"content": "Some source 1", "metadata": {"source": "https://internal.com/foo"}},
+            {"content": "Some source 2", "metadata": {"source": "https://internal.com/bar"}},
+        ]
+    }
+
+    assert is_input_context(context)
+
+    result = wf.run({"prompt": "Some source", "context": context})
+
+    metrics = result.metrics
+
+    metric_names: List[str] = metrics.columns.tolist()  # pyright: ignore[reportUnknownMemberType]
+
+    assert metric_names == ["prompt.similarity.context", "id"]
+    assert metrics["prompt.similarity.context"][0] == pytest.approx(0.7447172999382019)  # pyright: ignore[reportUnknownMemberType]