Merge pull request #303 from whylabs/toxicity-revision

Add toxicity model and version options

naddeoa authored May 1, 2024
2 parents 10d010d + 14235d4 commit 45ece8e
Showing 4 changed files with 52 additions and 35 deletions.
13 changes: 10 additions & 3 deletions langkit/metrics/library.py
@@ -116,19 +116,26 @@ def __call__(self) -> MetricCreator:
                 return self.toxicity_score()

             @staticmethod
-            def toxicity_score(onnx: bool = True) -> MetricCreator:
+            def toxicity_score(
+                onnx: bool = True, onnx_tag: Optional[str] = None, hf_model: Optional[str] = None, hf_model_revision: Optional[str] = None
+            ) -> MetricCreator:
                 """
                 Analyze the input for toxicity. The output of this metric ranges from 0 to 1, where 0 indicates
                 non-toxic and 1 indicates toxic.
+
+                :param onnx: Whether to use the ONNX model for toxicity analysis. This is mutually exclusive with the model options.
+                :param hf_model: The Hugging Face model to use for toxicity analysis. Defaults to martin-ha/toxic-comment-model.
+                :param hf_model_revision: The revision of the Hugging Face model to use. This default can change between releases, so you
+                    can specify the revision to lock it to a specific version.
                 """
                 if onnx:
                     from langkit.metrics.toxicity_onnx import prompt_toxicity_metric

-                    return prompt_toxicity_metric
+                    return partial(prompt_toxicity_metric, tag=onnx_tag)
                 else:
                     from langkit.metrics.toxicity import prompt_toxicity_metric

-                    return prompt_toxicity_metric
+                    return partial(prompt_toxicity_metric, hf_model=hf_model, hf_model_revision=hf_model_revision)

         class stats:
             def __call__(self) -> MetricCreator:
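For reference, this is how a caller might use the options added above — a minimal usage sketch, not part of the commit. It assumes the Workflow and metrics_lib imports that appear in the test file further down, and the row passed to wf.run is illustrative; result.metrics is treated as a pandas DataFrame, as the test's column check suggests:

from langkit.core.workflow import Workflow
from langkit.metrics.library import lib as metrics_lib

# Pin the Hugging Face model and revision explicitly so the metric does not
# change when the library's default revision moves between releases.
wf = Workflow(
    metrics=[
        metrics_lib.prompt.toxicity.toxicity_score(
            onnx=False,
            hf_model="martin-ha/toxic-comment-model",
            hf_model_revision="9842c08b35a4687e7b211187d676986c8c96256d",
        )
    ]
)
result = wf.run({"prompt": "some user input"})  # illustrative row shape
print(result.metrics["prompt.toxicity.toxicity_score"])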
35 changes: 17 additions & 18 deletions langkit/metrics/toxicity.py
@@ -3,7 +3,7 @@
 # pyright: reportUnknownLambdaType=none
 import os
 from functools import lru_cache, partial
-from typing import List, cast
+from typing import List, Optional, cast

 import pandas as pd
 import torch
@@ -22,40 +22,39 @@ def __toxicity(pipeline: TextClassificationPipeline, max_length: int, text: List
     return [result["score"] if result["label"] == "toxic" else 1.0 - result["score"] for result in results]  # type: ignore


-_model_path = "martin-ha/toxic-comment-model"
-_revision = "9842c08b35a4687e7b211187d676986c8c96256d"
-
-
-def _cache_assets():
-    AutoModelForSequenceClassification.from_pretrained(_model_path, revision=_revision)
-    AutoTokenizer.from_pretrained(_model_path, revision=_revision)
+def _cache_assets(model_path: str, revision: str):
+    AutoModelForSequenceClassification.from_pretrained(model_path, revision=revision)
+    AutoTokenizer.from_pretrained(model_path, revision=revision)


 @lru_cache
-def _get_tokenizer() -> PreTrainedTokenizerBase:
-    return AutoTokenizer.from_pretrained(_model_path, local_files_only=True, revision=_revision)
+def _get_tokenizer(model_path: str, revision: str) -> PreTrainedTokenizerBase:
+    return AutoTokenizer.from_pretrained(model_path, local_files_only=True, revision=revision)


 @lru_cache
-def _get_pipeline() -> TextClassificationPipeline:
+def _get_pipeline(model_path: str, revision: str) -> TextClassificationPipeline:
     use_cuda = torch.cuda.is_available() and not bool(os.environ.get("LANGKIT_NO_CUDA", False))
     model: PreTrainedTokenizerBase = AutoModelForSequenceClassification.from_pretrained(
-        _model_path, local_files_only=True, revision=_revision
+        model_path, local_files_only=True, revision=revision
     )
-    tokenizer = _get_tokenizer()
+    tokenizer = _get_tokenizer(model_path, revision)
     return TextClassificationPipeline(model=model, tokenizer=tokenizer, device=0 if use_cuda else -1)


-def toxicity_metric(column_name: str) -> Metric:
+def toxicity_metric(column_name: str, hf_model: Optional[str] = None, hf_model_revision: Optional[str] = None) -> Metric:
+    model_path = "martin-ha/toxic-comment-model" if hf_model is None else hf_model
+    revision = "9842c08b35a4687e7b211187d676986c8c96256d" if hf_model_revision is None else hf_model_revision
+
     def cache_assets():
-        _cache_assets()
+        _cache_assets(model_path, revision)

     def init():
-        _get_pipeline()
+        _get_pipeline(model_path, revision)

     def udf(text: pd.DataFrame) -> SingleMetricResult:
-        _tokenizer = _get_tokenizer()
-        _pipeline = _get_pipeline()
+        _tokenizer = _get_tokenizer(model_path, revision)
+        _pipeline = _get_pipeline(model_path, revision)

         col = list(UdfInput(text).iter_column_rows(column_name))
         max_length = cast(int, _tokenizer.model_max_length)
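One consequence of the refactor above is worth spelling out: _get_tokenizer and _get_pipeline keep their @lru_cache decorators but now take model_path and revision as arguments, so each distinct (model, revision) pair is loaded once and cached separately, instead of there being a single module-level model. A minimal sketch of that caching behaviour, using hypothetical names unrelated to langkit:

from functools import lru_cache

@lru_cache
def load_model(model_path: str, revision: str) -> str:
    # Stand-in for the expensive AutoTokenizer/AutoModel load; runs once per argument key.
    print(f"loading {model_path}@{revision}")
    return f"{model_path}@{revision}"

load_model("martin-ha/toxic-comment-model", "rev-a")  # loads
load_model("martin-ha/toxic-comment-model", "rev-a")  # cache hit, no reload
load_model("martin-ha/toxic-comment-model", "rev-b")  # new key, loads again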
28 changes: 14 additions & 14 deletions langkit/metrics/toxicity_onnx.py
@@ -3,7 +3,7 @@
 # pyright: reportUnknownLambdaType=none
 import os
 from functools import lru_cache, partial
-from typing import List, cast
+from typing import List, Optional, cast

 import numpy as np
 import onnxruntime
@@ -36,35 +36,35 @@ def __toxicity(tokenizer: PreTrainedTokenizerBase, session: onnxruntime.Inferenc
     return [result["score"] if result["label"] == "toxic" else 1.0 - result["score"] for result in results]  # type: ignore


-def _download_assets():
-    name, tag = TransformerModel.ToxicCommentModel.value
-    return get_asset(name, tag)
+def _download_assets(tag: Optional[str]):
+    name, default_tag = TransformerModel.ToxicCommentModel.value
+    return get_asset(name, tag or default_tag)


 @lru_cache
-def _get_tokenizer() -> PreTrainedTokenizerBase:
-    return AutoTokenizer.from_pretrained(_download_assets())
+def _get_tokenizer(tag: Optional[str]) -> PreTrainedTokenizerBase:
+    return AutoTokenizer.from_pretrained(_download_assets(tag))


 @lru_cache
-def _get_session() -> onnxruntime.InferenceSession:
-    downloaded_path = _download_assets()
+def _get_session(tag: Optional[str]) -> onnxruntime.InferenceSession:
+    downloaded_path = _download_assets(tag)
     onnx_model_path = os.path.join(downloaded_path, "model.onnx")
     print(f"Loading ONNX model from {onnx_model_path}")
     return onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])


-def toxicity_metric(column_name: str) -> Metric:
+def toxicity_metric(column_name: str, tag: Optional[str] = None) -> Metric:
     def cache_assets():
-        _download_assets()
+        _download_assets(tag)

     def init():
-        _get_session()
-        _get_tokenizer()
+        _get_session(tag)
+        _get_tokenizer(tag)

     def udf(text: pd.DataFrame) -> SingleMetricResult:
-        _tokenizer = _get_tokenizer()
-        _session = _get_session()
+        _tokenizer = _get_tokenizer(tag)
+        _session = _get_session(tag)

         col = list(UdfInput(text).iter_column_rows(column_name))
         max_length = cast(int, _tokenizer.model_max_length)
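The ONNX path wires the new option through differently: library.py forwards onnx_tag as this module's tag parameter, and _download_assets falls back to the model's default tag when tag is None. A hedged usage sketch — the tag value here is purely illustrative, not a known published asset tag, and the row shape is assumed:

from langkit.core.workflow import Workflow
from langkit.metrics.library import lib as metrics_lib

# onnx=True is the default; onnx_tag selects a specific published version of the
# ONNX toxic-comment asset instead of the library's built-in default tag.
wf = Workflow(
    metrics=[metrics_lib.prompt.toxicity.toxicity_score(onnx=True, onnx_tag="example-tag")]
)
result = wf.run({"prompt": "some user input"})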
11 changes: 11 additions & 0 deletions tests/langkit/metrics/test_toxicity.py
@@ -5,6 +5,8 @@

 import whylogs as why
 from langkit.core.metric import WorkflowMetricConfig, WorkflowMetricConfigBuilder
+from langkit.core.workflow import Workflow
+from langkit.metrics.library import lib as metrics_lib
 from langkit.metrics.toxicity import prompt_response_toxicity_module, prompt_toxicity_metric, response_toxicity_metric
 from langkit.metrics.whylogs_compat import create_whylogs_udf_schema

@@ -81,6 +83,15 @@ def test_prompt_toxicity_row_non_toxic():
     assert actual["distribution/max"]["prompt.toxicity.toxicity_score"] < 0.1


+def test_prompt_toxicity_version():
+    wf = Workflow(metrics=[metrics_lib.prompt.toxicity.toxicity_score(hf_model_revision="f1c3aa41130e8baeee31c3ea5d14598a0d3385e5")])
+    result = wf.run(row)
+
+    expected_columns = ["prompt.toxicity.toxicity_score", "id"]
+
+    assert list(result.metrics.columns) == expected_columns
+
+
 def test_prompt_toxicity_df_non_toxic():
     schema = WorkflowMetricConfigBuilder().add(prompt_toxicity_metric).build()
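The new test pins hf_model_revision and checks only the resulting metric columns. A companion test for the ONNX tag option could follow the same shape — a sketch only; it is not part of this commit, and onnx_tag="example-tag" is a hypothetical value:

def test_prompt_toxicity_onnx_tag():
    # Hypothetical companion to test_prompt_toxicity_version, exercising the ONNX path.
    wf = Workflow(metrics=[metrics_lib.prompt.toxicity.toxicity_score(onnx=True, onnx_tag="example-tag")])
    result = wf.run(row)

    expected_columns = ["prompt.toxicity.toxicity_score", "id"]

    assert list(result.metrics.columns) == expected_columns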
