Commit
enable different model versions for topic model
felipe207 committed Mar 25, 2024
1 parent bf44778 commit 0d92dd4
Showing 2 changed files with 70 additions and 18 deletions.
51 changes: 34 additions & 17 deletions langkit/metrics/topic.py
@@ -1,13 +1,12 @@
from dataclasses import dataclass
from functools import partial
from functools import lru_cache, partial
from typing import Any, Dict, List, Optional

import pandas as pd
import torch
from transformers import Pipeline, pipeline # type: ignore

from langkit.core.metric import MetricCreator, MultiMetric, MultiMetricResult
from langkit.metrics.util import LazyInit

__default_topics = [
"medicine",
@@ -19,21 +18,35 @@
_hypothesis_template = "This example is about {}"


__classifier: LazyInit[Pipeline] = LazyInit(
lambda: pipeline(
@lru_cache(maxsize=None)
def _get_classifier(model_version: str) -> Pipeline:
return pipeline(
"zero-shot-classification",
model="MoritzLaurer/xtremedistil-l6-h256-zeroshot-v1.1-all-33",
model=model_version,
device="cuda" if torch.cuda.is_available() else "cpu",
)
)


MODEL_SMALL = "MoritzLaurer/xtremedistil-l6-h256-zeroshot-v1.1-all-33"
MODEL_BASE = "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33"
MODEL_LARGE = "MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33"


def __get_scores_per_label(
text: str, topics: List[str], hypothesis_template: str = _hypothesis_template, multi_label: bool = True
text: str,
topics: List[str],
hypothesis_template: str = _hypothesis_template,
multi_label: bool = True,
model_version: str = MODEL_SMALL,
) -> Optional[Dict[str, float]]:
if not text:
return None
result: Dict[str, [str, float]] = __classifier.value(text, topics, hypothesis_template=hypothesis_template, multi_label=multi_label) # type: ignore
    result: Dict[str, Any] = _get_classifier(model_version)(  # type: ignore
text,
topics,
hypothesis_template=hypothesis_template,
multi_label=multi_label, # type: ignore
)
scores_per_label: Dict[str, float] = {label: score for label, score in zip(result["labels"], result["scores"])} # type: ignore[reportUnknownVariableType]
return scores_per_label

@@ -45,15 +58,17 @@ def _sanitize_metric_name(topic: str) -> str:
return topic.replace(" ", "_").lower()


def topic_metric(input_name: str, topics: List[str], hypothesis_template: Optional[str] = None) -> MultiMetric:
def topic_metric(
input_name: str, topics: List[str], hypothesis_template: Optional[str] = None, model_version: str = MODEL_SMALL
) -> MultiMetric:
hypothesis_template = hypothesis_template or _hypothesis_template

def udf(text: pd.DataFrame) -> MultiMetricResult:
metrics: Dict[str, List[Optional[float]]] = {topic: [] for topic in topics}

def process_row(row: pd.DataFrame) -> Dict[str, List[Optional[float]]]:
value: Any = row[input_name] # type: ignore
scores = __get_scores_per_label(value, topics=topics, hypothesis_template=hypothesis_template) # pyright: ignore[reportUnknownArgumentType]
scores = __get_scores_per_label(value, topics=topics, hypothesis_template=hypothesis_template, model_version=model_version) # pyright: ignore[reportUnknownArgumentType]
for topic in topics:
metrics[topic].append(scores[topic] if scores else None)
return metrics
@@ -67,15 +82,15 @@ def process_row(row: pd.DataFrame) -> Dict[str, List[Optional[float]]]:
return MultiMetricResult(metrics=all_metrics)

def cache_assets():
__classifier.value
_get_classifier(model_version)

metric_names = [f"{input_name}.topics.{_sanitize_metric_name(topic)}" for topic in topics]
return MultiMetric(names=metric_names, input_name=input_name, evaluate=udf, cache_assets=cache_assets)


prompt_topic_module = partial(topic_metric, "prompt", __default_topics, _hypothesis_template)
response_topic_module = partial(topic_metric, "response", __default_topics, _hypothesis_template)
prompt_response_topic_module = [prompt_topic_module, response_topic_module, _hypothesis_template]
prompt_topic_module = partial(topic_metric, "prompt", __default_topics, _hypothesis_template, MODEL_SMALL)
response_topic_module = partial(topic_metric, "response", __default_topics, _hypothesis_template, MODEL_SMALL)
prompt_response_topic_module = [prompt_topic_module, response_topic_module, _hypothesis_template, MODEL_SMALL]


@dataclass
@@ -85,9 +100,11 @@ class CustomTopicModules:
prompt_response_topic_module: MetricCreator


def get_custom_topic_modules(topics: List[str], template: str = _hypothesis_template) -> CustomTopicModules:
prompt_topic_module = partial(topic_metric, "prompt", topics, template)
response_topic_module = partial(topic_metric, "response", topics, template)
def get_custom_topic_modules(
topics: List[str], template: str = _hypothesis_template, model_version: str = MODEL_SMALL
) -> CustomTopicModules:
prompt_topic_module = partial(topic_metric, "prompt", topics, template, model_version)
response_topic_module = partial(topic_metric, "response", topics, template, model_version)
return CustomTopicModules(
prompt_topic_module=prompt_topic_module,
response_topic_module=response_topic_module,
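For reference, a minimal usage sketch of the new model_version parameter. It relies only on names exercised in the tests below (topic_metric, get_custom_topic_modules, MODEL_BASE, WorkflowMetricConfigBuilder); how the resulting schema is evaluated afterwards is left to the caller's workflow setup and is not shown here.

from functools import partial

from langkit.core.metric import WorkflowMetricConfigBuilder
from langkit.metrics.topic import MODEL_BASE, get_custom_topic_modules, topic_metric

# Single-column topic metric pinned to the base model.
prompt_phishing = partial(topic_metric, "prompt", ["phishing"], model_version=MODEL_BASE)
prompt_schema = WorkflowMetricConfigBuilder().add(prompt_phishing).build()

# Prompt + response modules built in one call with the same model version.
modules = get_custom_topic_modules(["phishing"], model_version=MODEL_BASE)
schema = WorkflowMetricConfigBuilder().add(modules.prompt_response_topic_module).build()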
37 changes: 36 additions & 1 deletion tests/langkit/metrics/test_topic.py
@@ -1,4 +1,5 @@
# pyright: reportUnknownMemberType=none
from functools import partial
from typing import Any

import pandas as pd
@@ -7,7 +8,7 @@
from langkit.core.metric import WorkflowMetricConfig, WorkflowMetricConfigBuilder
from langkit.core.workflow import Workflow
from langkit.metrics.library import lib
from langkit.metrics.topic import get_custom_topic_modules, prompt_topic_module
from langkit.metrics.topic import MODEL_BASE, get_custom_topic_modules, prompt_topic_module, topic_metric
from langkit.metrics.whylogs_compat import create_whylogs_udf_schema

expected_metrics = [
@@ -81,6 +82,21 @@ def test_topic():
assert actual.index.tolist() == expected_columns


def test_topic_base_model():
df = pd.DataFrame(
{
"prompt": [
"http://get-free-money-now.xyz/bank/details",
],
}
)

custom_topic_module = partial(topic_metric, "prompt", ["phishing"], model_version=MODEL_BASE)
schema = WorkflowMetricConfigBuilder().add(custom_topic_module).build()
actual = _log(df, schema)
assert actual.loc["prompt.topics.phishing"]["distribution/mean"] > 0.80


def test_topic_empty_input():
df = pd.DataFrame(
{
@@ -243,6 +259,25 @@ def test_custom_topic():
assert actual.loc[column]["distribution/max"] >= 0.50


def test_custom_topics_base_model():
df = pd.DataFrame(
{
"prompt": [
"http://get-free-money-now.xyz/bank/details",
],
"response": [
"http://win-a-free-iphone-today.net",
],
}
)

custom_topic_modules = get_custom_topic_modules(["phishing"], model_version=MODEL_BASE)
schema = WorkflowMetricConfigBuilder().add(custom_topic_modules.prompt_response_topic_module).build()
actual = _log(df, schema)
assert actual.loc["prompt.topics.phishing"]["distribution/mean"] > 0.80
assert actual.loc["response.topics.phishing"]["distribution/mean"] > 0.80


def test_topic_name_sanitize():
df = pd.DataFrame(
{
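A caching note (illustration only, not part of the commit): because _get_classifier is wrapped in functools.lru_cache, each distinct model string loads its pipeline once, and later calls with the same model_version return the same instance, so cache_assets() and the per-row scoring share one loaded model per version.

from langkit.metrics.topic import MODEL_SMALL, _get_classifier  # internal helper, imported here purely for illustration

first = _get_classifier(MODEL_SMALL)
second = _get_classifier(MODEL_SMALL)
assert first is second  # lru_cache returns the cached Pipeline for the same model_version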
