Skip to content

Commit

Permalink
Add support for customizing the sentence transformer
Browse files Browse the repository at this point in the history
For the metrics where this makes sense (embedding similarity metrics)
you can now customize the sentence transformer used.
  • Loading branch information
naddeoa committed Apr 15, 2024
1 parent d4c0ee3 commit 2236a88
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 57 deletions.
4 changes: 2 additions & 2 deletions langkit/metrics/injections.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ def _get_embeddings(version: str) -> "np.ndarray[Any, Any]":
return __process_embeddings(__download_embeddings(version))


def injections_metric(column_name: str, version: str = "v2", onnx: bool = True) -> Metric:
def injections_metric(column_name: str, version: str = "v2") -> Metric:
def cache_assets():
__download_embeddings(version)

def init():
_get_embeddings(version)

embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=column_name)
embedding_dep = EmbeddingContextDependency(embedding_choice="default", input_column=column_name)

def udf(text: pd.DataFrame, context: Context) -> SingleMetricResult:
if column_name not in text.columns:
Expand Down
10 changes: 6 additions & 4 deletions langkit/metrics/input_context_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
from langkit.core.context import Context
from langkit.core.metric import Metric, SingleMetric, SingleMetricResult
from langkit.metrics.embeddings_utils import compute_embedding_similarity_encoded
from langkit.transformer import EmbeddingContextDependency, RAGContextDependency
from langkit.transformer import EmbeddingChoiceArg, EmbeddingContextDependency, RAGContextDependency


def input_context_similarity(input_column_name: str = "prompt", context_column_name: str = "context", onnx: bool = True) -> Metric:
prompt_embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=input_column_name)
context_embedding_dep = RAGContextDependency(onnx=onnx, context_column_name=context_column_name)
def input_context_similarity(
input_column_name: str = "prompt", context_column_name: str = "context", embedding: EmbeddingChoiceArg = "default"
) -> Metric:
prompt_embedding_dep = EmbeddingContextDependency(embedding_choice=embedding, input_column=input_column_name)
context_embedding_dep = RAGContextDependency(embedding_choice=embedding, context_column_name=context_column_name)

def udf(text: pd.DataFrame, context: Context) -> SingleMetricResult:
prompt_embedding = prompt_embedding_dep.get_request_data(context)
Expand Down
10 changes: 6 additions & 4 deletions langkit/metrics/input_output_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from langkit.core.context import Context
from langkit.core.metric import Metric, SingleMetric, SingleMetricResult
from langkit.metrics.embeddings_utils import compute_embedding_similarity_encoded
from langkit.transformer import EmbeddingContextDependency
from langkit.transformer import EmbeddingChoiceArg, EmbeddingContextDependency


def input_output_similarity_metric(input_column_name: str = "prompt", output_column_name: str = "response", onnx: bool = True) -> Metric:
prompt_embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=input_column_name)
response_embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=output_column_name)
def input_output_similarity_metric(
input_column_name: str = "prompt", output_column_name: str = "response", embedding: EmbeddingChoiceArg = "default"
) -> Metric:
prompt_embedding_dep = EmbeddingContextDependency(embedding_choice=embedding, input_column=input_column_name)
response_embedding_dep = EmbeddingContextDependency(embedding_choice=embedding, input_column=output_column_name)

def udf(text: pd.DataFrame, context: Context) -> SingleMetricResult:
prompt_embedding = prompt_embedding_dep.get_request_data(context)
Expand Down
27 changes: 14 additions & 13 deletions langkit/metrics/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Optional

from langkit.core.metric import MetricCreator
from langkit.transformer import EmbeddingChoiceArg


class lib:
Expand Down Expand Up @@ -251,33 +252,33 @@ def __call__(self) -> MetricCreator:
]

@staticmethod
def injection(version: Optional[str] = None, onnx: bool = True) -> MetricCreator:
def injection(version: Optional[str] = None) -> MetricCreator:
"""
Analyze the input for injection themes. The injection score is a measure of how similar the input is
to known injection examples, where 0 indicates no similarity and 1 indicates a high similarity.
"""
from langkit.metrics.injections import prompt_injections_metric

if version:
return partial(prompt_injections_metric, onnx=onnx, version=version)
return partial(prompt_injections_metric, version=version)

return partial(prompt_injections_metric, onnx=onnx)
return partial(prompt_injections_metric)

@staticmethod
def jailbreak(onnx: bool = True) -> MetricCreator:
def jailbreak(embedding: EmbeddingChoiceArg = "default") -> MetricCreator:
"""
Analyze the input for jailbreak themes. The jailbreak score is a measure of how similar the input is
to known jailbreak examples, where 0 indicates no similarity and 1 indicates a high similarity.
"""
from langkit.metrics.themes.themes import prompt_jailbreak_similarity_metric

return partial(prompt_jailbreak_similarity_metric, onnx=onnx)
return partial(prompt_jailbreak_similarity_metric, embedding=embedding)

@staticmethod
def context(onnx: bool = True) -> MetricCreator:
def context(embedding: EmbeddingChoiceArg = "default") -> MetricCreator:
from langkit.metrics.input_context_similarity import input_context_similarity

return partial(input_context_similarity, onnx=onnx)
return partial(input_context_similarity, embedding=embedding)

class sentiment:
def __call__(self) -> MetricCreator:
Expand Down Expand Up @@ -494,30 +495,30 @@ def __call__(self) -> MetricCreator:
]

@staticmethod
def prompt(onnx: bool = True) -> MetricCreator:
def prompt(embedding: EmbeddingChoiceArg = "default") -> MetricCreator:
"""
Analyze the similarity between the input and the response. The output of this metric ranges from 0 to 1,
where 0 indicates no similarity and 1 indicates a high similarity.
"""
from langkit.metrics.input_output_similarity import prompt_response_input_output_similarity_metric

return partial(prompt_response_input_output_similarity_metric, onnx=onnx)
return partial(prompt_response_input_output_similarity_metric, embedding=embedding)

@staticmethod
def refusal(onnx: bool = True) -> MetricCreator:
def refusal(embedding: EmbeddingChoiceArg = "default") -> MetricCreator:
"""
Analyze the response for refusal themes. The refusal score is a measure of how similar the response is
to known refusal examples, where 0 indicates no similarity and 1 indicates a high similarity.
"""
from langkit.metrics.themes.themes import response_refusal_similarity_metric

return partial(response_refusal_similarity_metric, onnx=onnx)
return partial(response_refusal_similarity_metric, embedding=embedding)

@staticmethod
def context(onnx: bool = True) -> MetricCreator:
def context(embedding: EmbeddingChoiceArg = "default") -> MetricCreator:
from langkit.metrics.input_context_similarity import input_context_similarity

return partial(input_context_similarity, onnx=onnx, input_column_name="response")
return partial(input_context_similarity, embedding=embedding, input_column_name="response")

class topics:
def __init__(self, topics: List[str], hypothesis_template: Optional[str] = None, onnx: bool = True):
Expand Down
8 changes: 3 additions & 5 deletions langkit/metrics/themes/themes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from langkit.core.context import Context
from langkit.core.metric import Metric, SingleMetric, SingleMetricResult
from langkit.transformer import EmbeddingContextDependency, embedding_adapter
from langkit.transformer import EmbeddingChoiceArg, EmbeddingContextDependency, embedding_adapter

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,7 +60,7 @@ def _get_themes() -> Dict[str, torch.Tensor]:
return {group: torch.as_tensor(encoder.encode(tuple(themes))) for group, themes in theme_groups.items()}


def __themes_metric(column_name: str, themes_group: Literal["jailbreak", "refusal"], onnx: bool = True) -> Metric:
def __themes_metric(column_name: str, themes_group: Literal["jailbreak", "refusal"], embedding: EmbeddingChoiceArg = "default") -> Metric:
if themes_group == "refusal" and column_name == "prompt":
raise ValueError("Refusal themes are not applicable to prompt")

Expand All @@ -70,12 +70,10 @@ def __themes_metric(column_name: str, themes_group: Literal["jailbreak", "refusa
def init():
_get_themes()

embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=column_name)
embedding_dep = EmbeddingContextDependency(embedding_choice=embedding, input_column=column_name)

def udf(text: pd.DataFrame, context: Context) -> SingleMetricResult:
theme = _get_themes()[themes_group] # (n_theme_examples, embedding_dim)
# text_list: List[str] = text[column_name].tolist()
# encoded_text = encoder.encode(tuple(text_list)) # (n_input_rows, embedding_dim)
encoded_text = embedding_dep.get_request_data(context)
similarities = F.cosine_similarity(encoded_text.unsqueeze(1), theme.unsqueeze(0), dim=2) # (n_input_rows, n_theme_examples)
max_similarities = similarities.max(dim=1)[0] # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] (n_input_rows,)
Expand Down
99 changes: 70 additions & 29 deletions langkit/transformer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Literal, Tuple
from typing import List, Literal, Union

import pandas as pd
import torch
Expand All @@ -12,43 +13,74 @@
from langkit.onnx_encoder import OnnxSentenceTransformer, TransformerModel


def _sentence_transformer(
name_revision: Tuple[str, str] = ("all-MiniLM-L6-v2", "44eb4044493a3c34bc6d7faae1a71ec76665ebc6"),
) -> SentenceTransformer:
"""
Returns a SentenceTransformer model instance.
class EmbeddingChoice(ABC):
@abstractmethod
def get_encoder(self) -> EmbeddingEncoder:
raise NotImplementedError()

The intent of this function is to cache the SentenceTransformer instance to avoid
multple instances being created all over langkit, and have a single place that
can be used to change the transformer name for the metrics that default to the same one.
"""
transformer_name, revision = name_revision
device = "cuda" if torch.cuda.is_available() else "cpu"
return SentenceTransformer(transformer_name, revision=revision, device=device)

class SentenceTransformerChoice(EmbeddingChoice):
def __init__(self, name: str, revision: str):
self.name = name
self.revision = revision

@lru_cache
def embedding_adapter(onnx: bool = True) -> EmbeddingEncoder:
if onnx:
def get_encoder(self) -> EmbeddingEncoder:
device = "cuda" if torch.cuda.is_available() else "cpu"
return TransformerEmbeddingAdapter(SentenceTransformer(self.name, revision=self.revision, device=device))


class DefaultChoice(SentenceTransformerChoice):
def __init__(self):
super().__init__("all-MiniLM-L6-v2", "44eb4044493a3c34bc6d7faae1a71ec76665ebc6")


class OnnxChoice(EmbeddingChoice):
def get_encoder(self) -> EmbeddingEncoder:
return OnnxSentenceTransformer(TransformerModel.AllMiniLM)


@dataclass(frozen=True)
class SentenceTransformerTarget:
name: str
revision: str


EmbeddingChoiceArg = Union[Literal["default"], Literal["onnx"], SentenceTransformerTarget]


@lru_cache
def embedding_adapter(choice: EmbeddingChoiceArg = "default") -> EmbeddingEncoder:
if choice == "default":
return DefaultChoice().get_encoder()
elif choice == "onnx":
return OnnxChoice().get_encoder()
else:
return TransformerEmbeddingAdapter(_sentence_transformer())
return SentenceTransformerChoice(choice.name, choice.revision).get_encoder()


@dataclass(frozen=True)
class EmbeddingContextDependency(ContextDependency[torch.Tensor]):
onnx: bool
embedding_choice: EmbeddingChoiceArg
input_column: str

def name(self) -> str:
return f"{self.input_column}.embedding?onnx={self.onnx}"
if self.embedding_choice == "default":
choice_str = "default"
elif self.embedding_choice == "onnx":
choice_str = "onnx"
else:
choice_str = f"{self.embedding_choice.name}-{self.embedding_choice.revision}"

return f"{self.input_column}.embedding?type={choice_str}"

def _get_encoder(self) -> EmbeddingEncoder:
return embedding_adapter(choice=self.embedding_choice)

def cache_assets(self) -> None:
# TODO do only the downloading
embedding_adapter(onnx=self.onnx)
self._get_encoder()

def init(self) -> None:
embedding_adapter(onnx=self.onnx)
self._get_encoder()

def populate_request(self, context: Context, data: pd.DataFrame):
if self.input_column not in data.columns:
Expand All @@ -57,7 +89,7 @@ def populate_request(self, context: Context, data: pd.DataFrame):
if self.name() in context.request_data:
return

encoder = embedding_adapter(onnx=self.onnx)
encoder = self._get_encoder()
embedding = encoder.encode(tuple(data[self.input_column])) # pyright: ignore[reportUnknownArgumentType]
context.request_data[self.name()] = embedding

Expand All @@ -67,7 +99,7 @@ def get_request_data(self, context: Context) -> torch.Tensor:

@dataclass(frozen=True)
class RAGContextDependency(ContextDependency[torch.Tensor]):
onnx: bool
embedding_choice: EmbeddingChoiceArg
strategy: Literal["combine"] = "combine"
"""
The strategy for converting the context into embeddings.
Expand All @@ -77,14 +109,23 @@ class RAGContextDependency(ContextDependency[torch.Tensor]):
context_column_name: str = "context"

def name(self) -> str:
return f"{self.context_column_name}.context?onnx={self.onnx}"
if self.embedding_choice == "default":
choice_str = "default"
elif self.embedding_choice == "onnx":
choice_str = "onnx"
else:
choice_str = f"{self.embedding_choice.name}-{self.embedding_choice.revision}"

return f"{self.context_column_name}.context?type={choice_str}&strategy={self.strategy}"

def _get_encoder(self) -> EmbeddingEncoder:
return embedding_adapter(choice=self.embedding_choice)

def cache_assets(self) -> None:
# TODO do only the downloading
embedding_adapter(onnx=self.onnx)
self._get_encoder()

def init(self) -> None:
embedding_adapter(onnx=self.onnx)
self._get_encoder()

def populate_request(self, context: Context, data: pd.DataFrame):
if self.context_column_name not in data.columns:
Expand All @@ -104,7 +145,7 @@ def populate_request(self, context: Context, data: pd.DataFrame):
else:
raise ValueError(f"Unknown context embedding strategy {self.strategy}")

encoder = embedding_adapter(onnx=self.onnx)
encoder = self._get_encoder()
embedding = encoder.encode(tuple(combined))
context.request_data[self.name()] = embedding

Expand Down

0 comments on commit 2236a88

Please sign in to comment.