From 2236a886868eca53b1ef61e8653a20b1f372b398 Mon Sep 17 00:00:00 2001
From: Anthony Naddeo
Date: Mon, 15 Apr 2024 15:46:39 -0700
Subject: [PATCH] Add support for customizing the sentence transformer

For the metrics where this makes sense (embedding similarity metrics), you
can now customize the sentence transformer used.
---
 langkit/metrics/injections.py               |  4 +-
 langkit/metrics/input_context_similarity.py | 10 ++-
 langkit/metrics/input_output_similarity.py  | 10 ++-
 langkit/metrics/library.py                  | 27 +++---
 langkit/metrics/themes/themes.py            |  8 +-
 langkit/transformer.py                      | 99 +++++++++++++++------
 6 files changed, 101 insertions(+), 57 deletions(-)

diff --git a/langkit/metrics/injections.py b/langkit/metrics/injections.py
index c0e81f9..e8a4495 100644
--- a/langkit/metrics/injections.py
+++ b/langkit/metrics/injections.py
@@ -65,14 +65,14 @@ def _get_embeddings(version: str) -> "np.ndarray[Any, Any]":
     return __process_embeddings(__download_embeddings(version))
 
 
-def injections_metric(column_name: str, version: str = "v2", onnx: bool = True) -> Metric:
+def injections_metric(column_name: str, version: str = "v2") -> Metric:
     def cache_assets():
         __download_embeddings(version)
 
     def init():
         _get_embeddings(version)
 
-    embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=column_name)
+    embedding_dep = EmbeddingContextDependency(embedding_choice="default", input_column=column_name)
 
     def udf(text: pd.DataFrame, context: Context) -> SingleMetricResult:
         if column_name not in text.columns:
diff --git a/langkit/metrics/input_context_similarity.py b/langkit/metrics/input_context_similarity.py
index fa76354..a5a6a83 100644
--- a/langkit/metrics/input_context_similarity.py
+++ b/langkit/metrics/input_context_similarity.py
@@ -3,12 +3,14 @@
 from langkit.core.context import Context
 from langkit.core.metric import Metric, SingleMetric, SingleMetricResult
 from langkit.metrics.embeddings_utils import compute_embedding_similarity_encoded
-from langkit.transformer import EmbeddingContextDependency, RAGContextDependency
+from langkit.transformer import EmbeddingChoiceArg, EmbeddingContextDependency, RAGContextDependency
 
 
-def input_context_similarity(input_column_name: str = "prompt", context_column_name: str = "context", onnx: bool = True) -> Metric:
-    prompt_embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=input_column_name)
-    context_embedding_dep = RAGContextDependency(onnx=onnx, context_column_name=context_column_name)
+def input_context_similarity(
+    input_column_name: str = "prompt", context_column_name: str = "context", embedding: EmbeddingChoiceArg = "default"
+) -> Metric:
+    prompt_embedding_dep = EmbeddingContextDependency(embedding_choice=embedding, input_column=input_column_name)
+    context_embedding_dep = RAGContextDependency(embedding_choice=embedding, context_column_name=context_column_name)
 
     def udf(text: pd.DataFrame, context: Context) -> SingleMetricResult:
         prompt_embedding = prompt_embedding_dep.get_request_data(context)
diff --git a/langkit/metrics/input_output_similarity.py b/langkit/metrics/input_output_similarity.py
index 11639c9..2af37a6 100644
--- a/langkit/metrics/input_output_similarity.py
+++ b/langkit/metrics/input_output_similarity.py
@@ -5,12 +5,14 @@
 from langkit.core.context import Context
 from langkit.core.metric import Metric, SingleMetric, SingleMetricResult
 from langkit.metrics.embeddings_utils import compute_embedding_similarity_encoded
-from langkit.transformer import EmbeddingContextDependency
+from langkit.transformer import EmbeddingChoiceArg, EmbeddingContextDependency
 
 
-def input_output_similarity_metric(input_column_name: str = "prompt", output_column_name: str = "response", onnx: bool = True) -> Metric:
-    prompt_embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=input_column_name)
-    response_embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=output_column_name)
+def input_output_similarity_metric(
+    input_column_name: str = "prompt", output_column_name: str = "response", embedding: EmbeddingChoiceArg = "default"
+) -> Metric:
+    prompt_embedding_dep = EmbeddingContextDependency(embedding_choice=embedding, input_column=input_column_name)
+    response_embedding_dep = EmbeddingContextDependency(embedding_choice=embedding, input_column=output_column_name)
 
     def udf(text: pd.DataFrame, context: Context) -> SingleMetricResult:
         prompt_embedding = prompt_embedding_dep.get_request_data(context)
diff --git a/langkit/metrics/library.py b/langkit/metrics/library.py
index 77a7bc4..24369f7 100644
--- a/langkit/metrics/library.py
+++ b/langkit/metrics/library.py
@@ -2,6 +2,7 @@
 from typing import List, Optional
 
 from langkit.core.metric import MetricCreator
+from langkit.transformer import EmbeddingChoiceArg
 
 
 class lib:
@@ -251,7 +252,7 @@ def __call__(self) -> MetricCreator:
             ]
 
         @staticmethod
-        def injection(version: Optional[str] = None, onnx: bool = True) -> MetricCreator:
+        def injection(version: Optional[str] = None) -> MetricCreator:
             """
             Analyze the input for injection themes. The injection score is a measure of how similar the input is to known
             injection examples, where 0 indicates no similarity and 1 indicates a high similarity.
@@ -259,25 +260,25 @@ def injection(version: Optional[str] = None, onnx: bool = True) -> MetricCreator:
             from langkit.metrics.injections import prompt_injections_metric
 
             if version:
-                return partial(prompt_injections_metric, onnx=onnx, version=version)
+                return partial(prompt_injections_metric, version=version)
 
-            return partial(prompt_injections_metric, onnx=onnx)
+            return partial(prompt_injections_metric)
 
         @staticmethod
-        def jailbreak(onnx: bool = True) -> MetricCreator:
+        def jailbreak(embedding: EmbeddingChoiceArg = "default") -> MetricCreator:
             """
             Analyze the input for jailbreak themes. The jailbreak score is a measure of how similar the input is to known
             jailbreak examples, where 0 indicates no similarity and 1 indicates a high similarity.
             """
             from langkit.metrics.themes.themes import prompt_jailbreak_similarity_metric
 
-            return partial(prompt_jailbreak_similarity_metric, onnx=onnx)
+            return partial(prompt_jailbreak_similarity_metric, embedding=embedding)
 
         @staticmethod
-        def context(onnx: bool = True) -> MetricCreator:
+        def context(embedding: EmbeddingChoiceArg = "default") -> MetricCreator:
             from langkit.metrics.input_context_similarity import input_context_similarity
 
-            return partial(input_context_similarity, onnx=onnx)
+            return partial(input_context_similarity, embedding=embedding)
 
     class sentiment:
         def __call__(self) -> MetricCreator:
@@ -494,30 +495,30 @@ def __call__(self) -> MetricCreator:
             ]
 
         @staticmethod
-        def prompt(onnx: bool = True) -> MetricCreator:
+        def prompt(embedding: EmbeddingChoiceArg = "default") -> MetricCreator:
             """
             Analyze the similarity between the input and the response. The output of this metric ranges from 0 to 1, where 0
             indicates no similarity and 1 indicates a high similarity.
""" from langkit.metrics.input_output_similarity import prompt_response_input_output_similarity_metric - return partial(prompt_response_input_output_similarity_metric, onnx=onnx) + return partial(prompt_response_input_output_similarity_metric, embedding=embedding) @staticmethod - def refusal(onnx: bool = True) -> MetricCreator: + def refusal(embedding: EmbeddingChoiceArg = "default") -> MetricCreator: """ Analyze the response for refusal themes. The refusal score is a measure of how similar the response is to known refusal examples, where 0 indicates no similarity and 1 indicates a high similarity. """ from langkit.metrics.themes.themes import response_refusal_similarity_metric - return partial(response_refusal_similarity_metric, onnx=onnx) + return partial(response_refusal_similarity_metric, embedding=embedding) @staticmethod - def context(onnx: bool = True) -> MetricCreator: + def context(embedding: EmbeddingChoiceArg = "default") -> MetricCreator: from langkit.metrics.input_context_similarity import input_context_similarity - return partial(input_context_similarity, onnx=onnx, input_column_name="response") + return partial(input_context_similarity, embedding=embedding, input_column_name="response") class topics: def __init__(self, topics: List[str], hypothesis_template: Optional[str] = None, onnx: bool = True): diff --git a/langkit/metrics/themes/themes.py b/langkit/metrics/themes/themes.py index 6610a89..76626c3 100644 --- a/langkit/metrics/themes/themes.py +++ b/langkit/metrics/themes/themes.py @@ -10,7 +10,7 @@ from langkit.core.context import Context from langkit.core.metric import Metric, SingleMetric, SingleMetricResult -from langkit.transformer import EmbeddingContextDependency, embedding_adapter +from langkit.transformer import EmbeddingChoiceArg, EmbeddingContextDependency, embedding_adapter logger = logging.getLogger(__name__) @@ -60,7 +60,7 @@ def _get_themes() -> Dict[str, torch.Tensor]: return {group: torch.as_tensor(encoder.encode(tuple(themes))) for group, themes in theme_groups.items()} -def __themes_metric(column_name: str, themes_group: Literal["jailbreak", "refusal"], onnx: bool = True) -> Metric: +def __themes_metric(column_name: str, themes_group: Literal["jailbreak", "refusal"], embedding: EmbeddingChoiceArg = "default") -> Metric: if themes_group == "refusal" and column_name == "prompt": raise ValueError("Refusal themes are not applicable to prompt") @@ -70,12 +70,10 @@ def __themes_metric(column_name: str, themes_group: Literal["jailbreak", "refusa def init(): _get_themes() - embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=column_name) + embedding_dep = EmbeddingContextDependency(embedding_choice=embedding, input_column=column_name) def udf(text: pd.DataFrame, context: Context) -> SingleMetricResult: theme = _get_themes()[themes_group] # (n_theme_examples, embedding_dim) - # text_list: List[str] = text[column_name].tolist() - # encoded_text = encoder.encode(tuple(text_list)) # (n_input_rows, embedding_dim) encoded_text = embedding_dep.get_request_data(context) similarities = F.cosine_similarity(encoded_text.unsqueeze(1), theme.unsqueeze(0), dim=2) # (n_input_rows, n_theme_examples) max_similarities = similarities.max(dim=1)[0] # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] (n_input_rows,) diff --git a/langkit/transformer.py b/langkit/transformer.py index ceb0b97..c80b8f6 100644 --- a/langkit/transformer.py +++ b/langkit/transformer.py @@ -1,6 +1,7 @@ +from abc import ABC, abstractmethod from dataclasses import dataclass 
 from functools import lru_cache
-from typing import List, Literal, Tuple
+from typing import List, Literal, Union
 
 import pandas as pd
 import torch
@@ -12,43 +13,74 @@
 from langkit.onnx_encoder import OnnxSentenceTransformer, TransformerModel
 
 
-def _sentence_transformer(
-    name_revision: Tuple[str, str] = ("all-MiniLM-L6-v2", "44eb4044493a3c34bc6d7faae1a71ec76665ebc6"),
-) -> SentenceTransformer:
-    """
-    Returns a SentenceTransformer model instance.
+class EmbeddingChoice(ABC):
+    @abstractmethod
+    def get_encoder(self) -> EmbeddingEncoder:
+        raise NotImplementedError()
 
-    The intent of this function is to cache the SentenceTransformer instance to avoid
-    multple instances being created all over langkit, and have a single place that
-    can be used to change the transformer name for the metrics that default to the same one.
-    """
-    transformer_name, revision = name_revision
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    return SentenceTransformer(transformer_name, revision=revision, device=device)
 
+class SentenceTransformerChoice(EmbeddingChoice):
+    def __init__(self, name: str, revision: str):
+        self.name = name
+        self.revision = revision
 
-@lru_cache
-def embedding_adapter(onnx: bool = True) -> EmbeddingEncoder:
-    if onnx:
+    def get_encoder(self) -> EmbeddingEncoder:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        return TransformerEmbeddingAdapter(SentenceTransformer(self.name, revision=self.revision, device=device))
+
+
+class DefaultChoice(SentenceTransformerChoice):
+    def __init__(self):
+        super().__init__("all-MiniLM-L6-v2", "44eb4044493a3c34bc6d7faae1a71ec76665ebc6")
+
+
+class OnnxChoice(EmbeddingChoice):
+    def get_encoder(self) -> EmbeddingEncoder:
         return OnnxSentenceTransformer(TransformerModel.AllMiniLM)
+
+
+@dataclass(frozen=True)
+class SentenceTransformerTarget:
+    name: str
+    revision: str
+
+
+EmbeddingChoiceArg = Union[Literal["default"], Literal["onnx"], SentenceTransformerTarget]
+
+
+@lru_cache
+def embedding_adapter(choice: EmbeddingChoiceArg = "default") -> EmbeddingEncoder:
+    if choice == "default":
+        return DefaultChoice().get_encoder()
+    elif choice == "onnx":
+        return OnnxChoice().get_encoder()
     else:
-        return TransformerEmbeddingAdapter(_sentence_transformer())
+        return SentenceTransformerChoice(choice.name, choice.revision).get_encoder()
 
 
 @dataclass(frozen=True)
 class EmbeddingContextDependency(ContextDependency[torch.Tensor]):
-    onnx: bool
+    embedding_choice: EmbeddingChoiceArg
     input_column: str
 
     def name(self) -> str:
-        return f"{self.input_column}.embedding?onnx={self.onnx}"
+        if self.embedding_choice == "default":
+            choice_str = "default"
+        elif self.embedding_choice == "onnx":
+            choice_str = "onnx"
+        else:
+            choice_str = f"{self.embedding_choice.name}-{self.embedding_choice.revision}"
+
+        return f"{self.input_column}.embedding?type={choice_str}"
+
+    def _get_encoder(self) -> EmbeddingEncoder:
+        return embedding_adapter(choice=self.embedding_choice)
 
     def cache_assets(self) -> None:
-        # TODO do only the downloading
-        embedding_adapter(onnx=self.onnx)
+        self._get_encoder()
 
     def init(self) -> None:
-        embedding_adapter(onnx=self.onnx)
+        self._get_encoder()
 
     def populate_request(self, context: Context, data: pd.DataFrame):
         if self.input_column not in data.columns:
@@ -57,7 +89,7 @@ def populate_request(self, context: Context, data: pd.DataFrame):
         if self.name() in context.request_data:
             return
 
-        encoder = embedding_adapter(onnx=self.onnx)
+        encoder = self._get_encoder()
         embedding = encoder.encode(tuple(data[self.input_column]))  # pyright: ignore[reportUnknownArgumentType]
         context.request_data[self.name()] = embedding
 
@@ -67,7 +99,7 @@ def get_request_data(self, context: Context) -> torch.Tensor:
 
 @dataclass(frozen=True)
 class RAGContextDependency(ContextDependency[torch.Tensor]):
-    onnx: bool
+    embedding_choice: EmbeddingChoiceArg
     strategy: Literal["combine"] = "combine"
     """
     The strategy for converting the context into embeddings.
@@ -77,14 +109,23 @@ class RAGContextDependency(ContextDependency[torch.Tensor]):
     context_column_name: str = "context"
 
     def name(self) -> str:
-        return f"{self.context_column_name}.context?onnx={self.onnx}"
+        if self.embedding_choice == "default":
+            choice_str = "default"
+        elif self.embedding_choice == "onnx":
+            choice_str = "onnx"
+        else:
+            choice_str = f"{self.embedding_choice.name}-{self.embedding_choice.revision}"
+
+        return f"{self.context_column_name}.context?type={choice_str}&strategy={self.strategy}"
+
+    def _get_encoder(self) -> EmbeddingEncoder:
+        return embedding_adapter(choice=self.embedding_choice)
 
     def cache_assets(self) -> None:
-        # TODO do only the downloading
-        embedding_adapter(onnx=self.onnx)
+        self._get_encoder()
 
     def init(self) -> None:
-        embedding_adapter(onnx=self.onnx)
+        self._get_encoder()
 
     def populate_request(self, context: Context, data: pd.DataFrame):
         if self.context_column_name not in data.columns:
@@ -104,7 +145,7 @@ def populate_request(self, context: Context, data: pd.DataFrame):
         else:
             raise ValueError(f"Unknown context embedding strategy {self.strategy}")
 
-        encoder = embedding_adapter(onnx=self.onnx)
+        encoder = self._get_encoder()
         embedding = encoder.encode(tuple(combined))
         context.request_data[self.name()] = embedding
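
A usage sketch for review purposes (not part of the patch): the functions,
types, and the "default"/"onnx" choices below come from this diff, while the
model name and revision shown are placeholder assumptions, not values
shipped with langkit.

    from langkit.metrics.input_output_similarity import input_output_similarity_metric
    from langkit.transformer import SentenceTransformerTarget

    # Default behavior: the bundled all-MiniLM-L6-v2 encoder, same as before this change.
    metric = input_output_similarity_metric()

    # Custom sentence transformer, pinned to a revision (placeholder values).
    custom = SentenceTransformerTarget(name="all-mpnet-base-v2", revision="<pinned-commit-sha>")
    metric = input_output_similarity_metric(embedding=custom)

Because SentenceTransformerTarget is a frozen (hashable) dataclass, it works
as an @lru_cache key in embedding_adapter, so each distinct name/revision
pair is loaded once per process; it also feeds the choice string that
EmbeddingContextDependency.name() and RAGContextDependency.name() use to key
cached embeddings per encoder.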