Add RAG context as input option + similarity-to-context metric
This adds a structure that can be passed into workflow runs to represent
the context in a RAG system. It's a list of objects, each with optional
metadata (not currently used).
naddeoa committed Apr 10, 2024
1 parent 5f5485f commit bc8f3d7
Showing 7 changed files with 223 additions and 16 deletions.
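
For orientation, the structure described in the commit message looks like this; a minimal sketch using the InputContext shape added in workflow.py below:

# A list of entries, each a content string with optional metadata.
context = {
    "entries": [
        {"content": "Some retrieved passage"},
        {"content": "Another passage", "metadata": {"source": "https://internal.com/foo"}},
    ]
}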
51 changes: 46 additions & 5 deletions langkit/core/workflow.py
@@ -6,6 +6,7 @@
from typing import Dict, List, Mapping, Optional, Set, Tuple, TypedDict, Union, cast, overload

import pandas as pd
from typing_extensions import NotRequired

from langkit.core.context import Context
from langkit.core.metric import (
@@ -22,14 +23,23 @@
WorkflowMetricConfigBuilder,
)
from langkit.core.validation import ValidationResult, Validator
from langkit.metrics.util import is_dict_with_strings

logger = logging.getLogger(__name__)


class InputContextItem(TypedDict):
content: str
metadata: NotRequired[Dict[str, str]]


class InputContext(TypedDict):
entries: List[InputContextItem]


class Row(TypedDict):
prompt: str
response: str
prompt: NotRequired[str]
response: NotRequired[str]
context: NotRequired[InputContext]


@dataclass(frozen=True)
@@ -214,9 +224,9 @@ def run(self, data: Union[pd.DataFrame, Row, Dict[str, str]], options: Optional[
init_end = time.perf_counter() - init_start

if not isinstance(data, pd.DataFrame):
-            if not is_dict_with_strings(data):
+            if not is_dict_input(data):
                raise ValueError("Input must be a pandas DataFrame or a dictionary with string keys and string values")
-            df = pd.DataFrame(data, index=[0])
+            df = pd.DataFrame([data])
else:
df = data

@@ -351,3 +361,34 @@ def _validate_evaluate(self, input_df: pd.DataFrame, metric: Metric, metric_resu
if isinstance(metric_result, MultiMetricResult):
for result in metric_result.metrics:
assert len(input_df) == len(result)


def is_input_context_item(variable: object) -> bool:
if not isinstance(variable, dict):
return False

variable = cast(InputContextItem, variable)
return "content" in variable and ("metadata" in variable or len(variable) == 1)


def is_input_context(variable: object) -> bool:
if not isinstance(variable, dict):
return False
if "entries" not in variable:
return False

if not isinstance(variable["entries"], list):
return False

variable = cast(InputContext, variable)
if len(variable) != 1:
return False

return all(is_input_context_item(value) for value in variable["entries"])


def is_dict_input(variable: object) -> bool:
if not isinstance(variable, dict):
return False
    # Check that every value in the dictionary is a string or an InputContext
return all(isinstance(value, str) or is_input_context(value) for value in variable.values()) # type: ignore[reportUnknownMemberType]
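
For reference, a minimal sketch of what these type guards accept, using only names defined in this file (InputContext, Row, is_dict_input):

from langkit.core.workflow import InputContext, Row, is_dict_input

context: InputContext = {
    "entries": [
        {"content": "Some source", "metadata": {"source": "https://internal.com/foo"}},
    ]
}

# prompt and response are NotRequired now, and a context value is allowed.
row: Row = {"prompt": "Some question", "context": context}

assert is_dict_input(row)            # str and InputContext values both pass
assert not is_dict_input({"x": 42})  # any other value type is rejected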
28 changes: 28 additions & 0 deletions langkit/metrics/input_context_similarity.py
@@ -0,0 +1,28 @@
import pandas as pd

from langkit.core.context import Context
from langkit.core.metric import Metric, SingleMetric, SingleMetricResult
from langkit.metrics.embeddings_utils import compute_embedding_similarity_encoded
from langkit.transformer import EmbeddingContextDependency, RAGContextDependency


def input_context_similarity(input_column_name: str = "prompt", context_column_name: str = "context", onnx: bool = True) -> Metric:
prompt_embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=input_column_name)
context_embedding_dep = RAGContextDependency(onnx=onnx, context_column_name=context_column_name)

def udf(text: pd.DataFrame, context: Context) -> SingleMetricResult:
prompt_embedding = prompt_embedding_dep.get_request_data(context)
context_embedding = context_embedding_dep.get_request_data(context)
similarity = compute_embedding_similarity_encoded(prompt_embedding, context_embedding)

if len(similarity.shape) == 1:
return SingleMetricResult(similarity.tolist()) # type: ignore[reportUnknownVariableType]
else:
return SingleMetricResult(similarity.squeeze(dim=0).tolist()) # type: ignore[reportUnknownVariableType]

return SingleMetric(
name=f"{input_column_name}.similarity.{context_column_name}",
input_names=[input_column_name, context_column_name],
evaluate=udf,
context_dependencies=[prompt_embedding_dep, context_embedding_dep],
)
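
As a usage sketch: registering the metric directly with partial mirrors what lib.prompt.similarity.context() does in library.py below; with the default column names the resulting metric is "prompt.similarity.context".

from functools import partial

from langkit.core.workflow import Workflow
from langkit.metrics.input_context_similarity import input_context_similarity

# Equivalent to lib.prompt.similarity.context(); the defaults target the
# "prompt" and "context" columns.
wf = Workflow(metrics=[partial(input_context_similarity, onnx=True)])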
3 changes: 0 additions & 3 deletions langkit/metrics/input_output_similarity.py
@@ -13,9 +13,6 @@ def input_output_similarity_metric(input_column_name: str = "prompt", output_col
response_embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=output_column_name)

def udf(text: pd.DataFrame, context: Context) -> SingleMetricResult:
-        # in_np = UdfInput(text).to_list(input_column_name)
-        # out_np = UdfInput(text).to_list(output_column_name)
-        # encoder = embedding_adapter(onnx)
prompt_embedding = prompt_embedding_dep.get_request_data(context)
response_embedding = response_embedding_dep.get_request_data(context)
similarity = compute_embedding_similarity_encoded(prompt_embedding, response_embedding)
7 changes: 7 additions & 0 deletions langkit/metrics/library.py
@@ -24,6 +24,7 @@ def all(prompt: bool = True, response: bool = True) -> MetricCreator:
prompt_sentiment_polarity,
lib.prompt.toxicity(),
prompt_response_input_output_similarity_metric,
lib.prompt.similarity.context(),
prompt_injections_metric,
prompt_jailbreak_similarity_metric,
prompt_presidio_pii_metric,
@@ -271,6 +272,12 @@ def jailbreak(onnx: bool = True) -> MetricCreator:

return partial(prompt_jailbreak_similarity_metric, onnx=onnx)

@staticmethod
def context(onnx: bool = True) -> MetricCreator:
from langkit.metrics.input_context_similarity import input_context_similarity

return partial(input_context_similarity, onnx=onnx)

class sentiment:
def __call__(self) -> MetricCreator:
return self.sentiment_score()
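
A sketch of reaching the new metric through the library tree; it combines with other creators from this file, and is also pulled in by the all() preset above:

from langkit.core.workflow import Workflow
from langkit.metrics.library import lib

wf = Workflow(metrics=[lib.prompt.toxicity(), lib.prompt.similarity.context(onnx=True)])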
7 changes: 0 additions & 7 deletions langkit/metrics/util.py
@@ -35,13 +35,6 @@ def value(self, arg: In) -> Out:
return self.__cache[arg]


-def is_dict_with_strings(variable: object) -> bool:
-    if not isinstance(variable, dict):
-        return False
-    # Check if all values in the dictionary are strings
-    return all(isinstance(value, str) for value in variable.values())  # type: ignore[reportUnknownMemberType]


ReturnType = TypeVar("ReturnType")


54 changes: 53 additions & 1 deletion langkit/transformer.py
@@ -1,12 +1,13 @@
from dataclasses import dataclass
from functools import lru_cache
-from typing import Tuple
+from typing import List, Literal, Tuple

import pandas as pd
import torch
from sentence_transformers import SentenceTransformer

from langkit.core.context import Context, ContextDependency
from langkit.core.workflow import InputContext
from langkit.metrics.embeddings_types import EmbeddingEncoder, TransformerEmbeddingAdapter
from langkit.onnx_encoder import OnnxSentenceTransformer, TransformerModel

@@ -62,3 +63,54 @@ def populate_request(self, context: Context, data: pd.DataFrame):

def get_request_data(self, context: Context) -> torch.Tensor:
return context.request_data[self.name()]


@dataclass(frozen=True)
class RAGContextDependency(ContextDependency[torch.Tensor]):
onnx: bool
strategy: Literal["combine"] = "combine"
"""
The strategy for converting the context into embeddings.
- combine: Combine all the entries in the context into a single string and encode it.
"""
context_column_name: str = "context"

def name(self) -> str:
return f"{self.context_column_name}.context?onnx={self.onnx}"

def cache_assets(self) -> None:
        # TODO: do only the downloading
embedding_adapter(onnx=self.onnx)

def init(self) -> None:
embedding_adapter(onnx=self.onnx)

def populate_request(self, context: Context, data: pd.DataFrame):
if self.context_column_name not in data.columns:
return

if self.name() in context.request_data:
return

rag_context = self._get_rag_context(data)

if self.strategy == "combine":
combined: List[str] = []
for row in rag_context:
row_string = "\n".join([it["content"] for it in row["entries"]])
combined.append(row_string)
else:
raise ValueError(f"Unknown context embedding strategy {self.strategy}")

encoder = embedding_adapter(onnx=self.onnx)
embedding = encoder.encode(tuple(combined))
context.request_data[self.name()] = embedding

def _get_rag_context(self, df: pd.DataFrame) -> List[InputContext]:
context_column: List[InputContext] = df[self.context_column_name].tolist()
return context_column

def get_request_data(self, context: Context) -> torch.Tensor:
return context.request_data[self.name()]
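
A worked sketch of the "combine" strategy on a single row's context, mirroring the join above: entries are concatenated with newlines into one string per row, and that single string is what gets embedded.

context = {
    "entries": [
        {"content": "Some source 1"},
        {"content": "Some source 2"},
    ]
}
combined = "\n".join(it["content"] for it in context["entries"])
assert combined == "Some source 1\nSome source 2"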
89 changes: 89 additions & 0 deletions tests/langkit/metrics/test_input_context_similarity.py
@@ -0,0 +1,89 @@
from typing import List

import pandas as pd
import pytest

from langkit.core.workflow import InputContext, Workflow, is_input_context
from langkit.metrics.library import lib


def test_similarity():
wf = Workflow(metrics=[lib.prompt.similarity.context()])

context: InputContext = {
"entries": [
{"content": "Some source 1", "metadata": {"source": "https://internal.com/foo"}},
{"content": "Some source 2", "metadata": {"source": "https://internal.com/bar"}},
]
}

df = pd.DataFrame({"prompt": ["Some source"], "context": [context]})

result = wf.run(df)

metrics = result.metrics

metric_names: List[str] = metrics.columns.tolist() # pyright: ignore[reportUnknownMemberType]

assert metric_names == ["prompt.similarity.context", "id"]
assert metrics["prompt.similarity.context"][0] == pytest.approx(0.7447172999382019) # pyright: ignore[reportUnknownMemberType]


def test_similarity_missing_context():
# The metric should not be run in this case since the context is missing
wf = Workflow(metrics=[lib.prompt.similarity.context()])

df = pd.DataFrame({"prompt": ["Some source"]})

result = wf.run(df)

metrics = result.metrics

metric_names: List[str] = metrics.columns.tolist() # pyright: ignore[reportUnknownMemberType]

assert metric_names == ["id"]


def test_similarity_multiple():
wf = Workflow(metrics=[lib.prompt.similarity.context()])

context: InputContext = {
"entries": [
{"content": "Some source 1", "metadata": {"source": "https://internal.com/foo"}},
{"content": "Some source 2", "metadata": {"source": "https://internal.com/bar"}},
]
}

df = pd.DataFrame({"prompt": ["Some source", "Nothing in common"], "context": [context, context]})

result = wf.run(df)

metrics = result.metrics

metric_names: List[str] = metrics.columns.tolist() # pyright: ignore[reportUnknownMemberType]

assert metric_names == ["prompt.similarity.context", "id"]
assert metrics["prompt.similarity.context"][0] == pytest.approx(0.7447172999382019) # pyright: ignore[reportUnknownMemberType]
assert metrics["prompt.similarity.context"][1] < 0.2


def test_similarity_row():
wf = Workflow(metrics=[lib.prompt.similarity.context()])

context: InputContext = {
"entries": [
{"content": "Some source 1", "metadata": {"source": "https://internal.com/foo"}},
{"content": "Some source 2", "metadata": {"source": "https://internal.com/bar"}},
]
}

assert is_input_context(context)

result = wf.run({"prompt": "Some source", "context": context})

metrics = result.metrics

metric_names: List[str] = metrics.columns.tolist() # pyright: ignore[reportUnknownMemberType]

assert metric_names == ["prompt.similarity.context", "id"]
assert metrics["prompt.similarity.context"][0] == pytest.approx(0.7447172999382019) # pyright: ignore[reportUnknownMemberType]
