From c824681a28b4105014ef547eb5298e0fee8df595 Mon Sep 17 00:00:00 2001
From: Anthony Naddeo
Date: Sun, 7 Apr 2024 12:20:12 -0700
Subject: [PATCH] Update asset caching logic for toxicity model

---
 langkit/asset_downloader.py       |  2 +-
 langkit/metrics/toxicity.py       | 21 +++++++++++++++++----
 langkit/metrics/whylogs_compat.py |  2 ++
 3 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/langkit/asset_downloader.py b/langkit/asset_downloader.py
index 150dc30..de0206a 100644
--- a/langkit/asset_downloader.py
+++ b/langkit/asset_downloader.py
@@ -65,7 +65,7 @@ def _is_zip_file(file_path: str) -> bool:
         return False
 
 
-@retry(stop=stop_after_attempt(3), wait=wait_exponential_jitter(max=5))
+# @retry(stop=stop_after_attempt(3), wait=wait_exponential_jitter(max=5))
 def _download_asset(asset_id: str, tag: str = "0"):
     asset_path = _get_asset_path(asset_id, tag)
     response: GetAssetResponse = cast(GetAssetResponse, assets_api.get_asset(asset_id))
diff --git a/langkit/metrics/toxicity.py b/langkit/metrics/toxicity.py
index b3edc97..36b9d9f 100644
--- a/langkit/metrics/toxicity.py
+++ b/langkit/metrics/toxicity.py
@@ -22,23 +22,34 @@ def __toxicity(pipeline: TextClassificationPipeline, max_length: int, text: List
     return [result["score"] if result["label"] == "toxic" else 1.0 - result["score"] for result in results]  # type: ignore
 
 
-__model_path = "martin-ha/toxic-comment-model"
+_model_path = "martin-ha/toxic-comment-model"
+_revision = "9842c08b35a4687e7b211187d676986c8c96256d"
+
+
+def _cache_assets():
+    AutoModelForSequenceClassification.from_pretrained(_model_path, revision=_revision)
+    AutoTokenizer.from_pretrained(_model_path, revision=_revision)
 
 
 @lru_cache
 def _get_tokenizer() -> PreTrainedTokenizerBase:
-    return AutoTokenizer.from_pretrained(__model_path)
+    return AutoTokenizer.from_pretrained(_model_path, local_files_only=True, revision=_revision)
 
 
 @lru_cache
 def _get_pipeline() -> TextClassificationPipeline:
     use_cuda = torch.cuda.is_available() and not bool(os.environ.get("LANGKIT_NO_CUDA", False))
-    model: PreTrainedTokenizerBase = AutoModelForSequenceClassification.from_pretrained(__model_path)
+    model: PreTrainedTokenizerBase = AutoModelForSequenceClassification.from_pretrained(
+        _model_path, local_files_only=True, revision=_revision
+    )
     tokenizer = _get_tokenizer()
     return TextClassificationPipeline(model=model, tokenizer=tokenizer, device=0 if use_cuda else -1)
 
 
 def toxicity_metric(column_name: str) -> Metric:
+    def cache_assets():
+        _cache_assets()
+
     def init():
         _get_pipeline()
 
@@ -51,7 +62,9 @@ def udf(text: pd.DataFrame) -> SingleMetricResult:
         metrics = __toxicity(_pipeline, max_length, col)
         return SingleMetricResult(metrics=metrics)
 
-    return SingleMetric(name=f"{column_name}.toxicity.toxicity_score", input_names=[column_name], evaluate=udf, init=init)
+    return SingleMetric(
+        name=f"{column_name}.toxicity.toxicity_score", input_names=[column_name], evaluate=udf, init=init, cache_assets=cache_assets
+    )
 
 
 prompt_toxicity_metric = partial(toxicity_metric, "prompt")
diff --git a/langkit/metrics/whylogs_compat.py b/langkit/metrics/whylogs_compat.py
index ba1a526..932b6a2 100644
--- a/langkit/metrics/whylogs_compat.py
+++ b/langkit/metrics/whylogs_compat.py
@@ -105,6 +105,8 @@ def to_udf_schema_args(metric: Metric) -> List[UdfSchemaArgs]:
 
 
 def create_whylogs_udf_schema(eval_conf: WorkflowMetricConfig) -> UdfSchema:
     for metric in eval_conf.metrics:
+        if metric.cache_assets:
+            metric.cache_assets()
         if metric.init:
             metric.init()
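
The caching change in langkit/metrics/toxicity.py boils down to a two-step pattern: download the model once at a pinned revision, then always load from the local Hugging Face cache with local_files_only=True so inference never reaches the network or silently picks up a newer revision. Below is a minimal standalone sketch of that pattern; the helper names warm_cache and load_offline are illustrative and not part of langkit or transformers.

    # Sketch of the cache-then-load-offline pattern the patch relies on.
    # warm_cache/load_offline are illustrative names, not part of langkit's API.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    MODEL = "martin-ha/toxic-comment-model"
    REVISION = "9842c08b35a4687e7b211187d676986c8c96256d"

    def warm_cache() -> None:
        # One-time download of the pinned revision into the local Hugging Face cache.
        AutoModelForSequenceClassification.from_pretrained(MODEL, revision=REVISION)
        AutoTokenizer.from_pretrained(MODEL, revision=REVISION)

    def load_offline():
        # Later loads resolve strictly from the local cache: no network access,
        # and no chance of silently pulling a different revision.
        model = AutoModelForSequenceClassification.from_pretrained(MODEL, local_files_only=True, revision=REVISION)
        tokenizer = AutoTokenizer.from_pretrained(MODEL, local_files_only=True, revision=REVISION)
        return model, tokenizer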
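
With cache_assets threaded through SingleMetric, asset downloads can also be triggered ahead of time (for example during an image build) rather than on the first logged row. The following call sequence is illustrative only; it assumes the SingleMetric returned by toxicity_metric exposes the optional cache_assets and init callables in the same way create_whylogs_udf_schema consumes them in this patch.

    # Illustrative pre-warm step, mirroring the order used in create_whylogs_udf_schema.
    from langkit.metrics.toxicity import prompt_toxicity_metric

    metric = prompt_toxicity_metric()
    if metric.cache_assets:
        metric.cache_assets()  # download the pinned model revision into the local cache
    if metric.init:
        metric.init()  # build the classification pipeline from local files only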