Skip to content

Commit

Permalink
Update asset caching logic for toxicity model
Browse files — browse the repository at this point in the history
  • Loading branch information
naddeoa committed Apr 7, 2024
1 parent 4e85e39 commit c824681
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
2 changes: 1 addition & 1 deletion langkit/asset_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _is_zip_file(file_path: str) -> bool:
return False


@retry(stop=stop_after_attempt(3), wait=wait_exponential_jitter(max=5))
# @retry(stop=stop_after_attempt(3), wait=wait_exponential_jitter(max=5))
def _download_asset(asset_id: str, tag: str = "0"):
asset_path = _get_asset_path(asset_id, tag)
response: GetAssetResponse = cast(GetAssetResponse, assets_api.get_asset(asset_id))
Expand Down
21 changes: 17 additions & 4 deletions langkit/metrics/toxicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,34 @@ def __toxicity(pipeline: TextClassificationPipeline, max_length: int, text: List
return [result["score"] if result["label"] == "toxic" else 1.0 - result["score"] for result in results] # type: ignore


# Hugging Face Hub model id for the toxicity classifier.
# NOTE(review): __model_path looks like a stale duplicate of _model_path
# (same value, older name) — consolidate once every reference is confirmed
# to use _model_path.
__model_path = "martin-ha/toxic-comment-model"
_model_path = "martin-ha/toxic-comment-model"
# Pinned commit of the model repository so downloads and cached loads are
# reproducible across runs.
_revision = "9842c08b35a4687e7b211187d676986c8c96256d"


def _cache_assets():
    """Eagerly download and cache the pinned toxicity model and tokenizer.

    Fetches both artifacts at the pinned ``_revision`` so that subsequent
    ``local_files_only`` loads can be served from the local cache.
    """
    for fetch in (AutoModelForSequenceClassification.from_pretrained, AutoTokenizer.from_pretrained):
        fetch(_model_path, revision=_revision)


@lru_cache
def _get_tokenizer() -> PreTrainedTokenizerBase:
    """Return the memoized tokenizer for the pinned toxicity model.

    Loads strictly from the local cache (``local_files_only=True``) at the
    pinned ``_revision``; call ``_cache_assets()`` first to populate the
    cache. ``@lru_cache`` ensures repeated calls share one instance.
    """
    # Fix: removed a stacked duplicate `return` (diff residue) that loaded
    # from the hub via the stale ``__model_path`` name without a pinned
    # revision, leaving the intended offline load below unreachable.
    return AutoTokenizer.from_pretrained(_model_path, local_files_only=True, revision=_revision)


@lru_cache
def _get_pipeline() -> TextClassificationPipeline:
    """Build (once, via ``@lru_cache``) the toxicity classification pipeline.

    Runs on GPU device 0 when CUDA is available, unless the
    ``LANGKIT_NO_CUDA`` environment variable is set; otherwise CPU (-1).
    Model weights are loaded strictly from the local cache at the pinned
    revision, so ``_cache_assets()`` must have populated it.
    """
    # NOTE: any non-empty value of LANGKIT_NO_CUDA — even "0" — disables CUDA,
    # because bool() of a non-empty string is True.
    use_cuda = torch.cuda.is_available() and not bool(os.environ.get("LANGKIT_NO_CUDA", False))
    # Fix: removed a dead duplicate assignment (diff residue) that downloaded
    # the model from the hub via the stale ``__model_path`` name before being
    # immediately overwritten; also dropped the incorrect
    # PreTrainedTokenizerBase annotation — this is the classification model,
    # not a tokenizer.
    model = AutoModelForSequenceClassification.from_pretrained(
        _model_path, local_files_only=True, revision=_revision
    )
    tokenizer = _get_tokenizer()
    return TextClassificationPipeline(model=model, tokenizer=tokenizer, device=0 if use_cuda else -1)


def toxicity_metric(column_name: str) -> Metric:
def cache_assets():
_cache_assets()

def init():
_get_pipeline()

Expand All @@ -51,7 +62,9 @@ def udf(text: pd.DataFrame) -> SingleMetricResult:
metrics = __toxicity(_pipeline, max_length, col)
return SingleMetricResult(metrics=metrics)

return SingleMetric(name=f"{column_name}.toxicity.toxicity_score", input_names=[column_name], evaluate=udf, init=init)
return SingleMetric(
name=f"{column_name}.toxicity.toxicity_score", input_names=[column_name], evaluate=udf, init=init, cache_assets=cache_assets
)


prompt_toxicity_metric = partial(toxicity_metric, "prompt")
Expand Down
2 changes: 2 additions & 0 deletions langkit/metrics/whylogs_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def to_udf_schema_args(metric: Metric) -> List[UdfSchemaArgs]:

def create_whylogs_udf_schema(eval_conf: WorkflowMetricConfig) -> UdfSchema:
for metric in eval_conf.metrics:
if metric.cache_assets:
metric.cache_assets()
if metric.init:
metric.init()

Expand Down

0 comments on commit c824681

Please sign in to comment.