From 0001d142c05eada60f80d0cf6889003ac7671009 Mon Sep 17 00:00:00 2001
From: Anthony Naddeo
Date: Sun, 7 Apr 2024 12:20:12 -0700
Subject: [PATCH] Update asset caching logic for toxicity model

---
 langkit/metrics/toxicity.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/langkit/metrics/toxicity.py b/langkit/metrics/toxicity.py
index b3edc97..bae0210 100644
--- a/langkit/metrics/toxicity.py
+++ b/langkit/metrics/toxicity.py
@@ -25,20 +25,28 @@ def __toxicity(pipeline: TextClassificationPipeline, max_length: int, text: List
 __model_path = "martin-ha/toxic-comment-model"
 
 
+def _cache_assets():
+    AutoModelForSequenceClassification.from_pretrained(__model_path)
+    AutoTokenizer.from_pretrained(__model_path)
+
+
 @lru_cache
 def _get_tokenizer() -> PreTrainedTokenizerBase:
-    return AutoTokenizer.from_pretrained(__model_path)
+    return AutoTokenizer.from_pretrained(__model_path, local_files_only=True)
 
 
 @lru_cache
 def _get_pipeline() -> TextClassificationPipeline:
     use_cuda = torch.cuda.is_available() and not bool(os.environ.get("LANGKIT_NO_CUDA", False))
-    model: PreTrainedTokenizerBase = AutoModelForSequenceClassification.from_pretrained(__model_path)
+    model: PreTrainedTokenizerBase = AutoModelForSequenceClassification.from_pretrained(__model_path, local_files_only=True)
     tokenizer = _get_tokenizer()
     return TextClassificationPipeline(model=model, tokenizer=tokenizer, device=0 if use_cuda else -1)
 
 
 def toxicity_metric(column_name: str) -> Metric:
+    def cache_assets():
+        _cache_assets()
+
     def init():
         _get_pipeline()
 
@@ -51,7 +59,9 @@ def udf(text: pd.DataFrame) -> SingleMetricResult:
         metrics = __toxicity(_pipeline, max_length, col)
         return SingleMetricResult(metrics=metrics)
 
-    return SingleMetric(name=f"{column_name}.toxicity.toxicity_score", input_names=[column_name], evaluate=udf, init=init)
+    return SingleMetric(
+        name=f"{column_name}.toxicity.toxicity_score", input_names=[column_name], evaluate=udf, init=init, cache_assets=cache_assets
+    )
 
 
 prompt_toxicity_metric = partial(toxicity_metric, "prompt")
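
A minimal usage sketch of the intended flow. This is illustrative only: it assumes the SingleMetric returned by toxicity_metric exposes the cache_assets and init callables it was constructed with; nothing outside the patch above confirms that API.

# Assumed workflow, not part of the patch: cache_assets() downloads the
# Hugging Face model and tokenizer into the local cache (e.g. during a
# container build step), after which init() constructs the pipeline with
# local_files_only=True and therefore needs no network access at runtime.
metric = toxicity_metric("prompt")
metric.cache_assets()  # one-time download of martin-ha/toxic-comment-model
metric.init()          # offline load; fails fast if assets were never cached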