Skip to content

Commit

Permalink
Add a refusal pattern matching metric
Browse files Browse the repository at this point in the history
  • Loading branch information
naddeoa committed Apr 18, 2024
1 parent c41654e commit 22f979f
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 1 deletion.
5 changes: 4 additions & 1 deletion langkit/asset_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ def _is_zip_file(file_path: str) -> bool:
@retry(stop=stop_after_attempt(3), wait=wait_exponential_jitter(max=5))
def _download_asset(asset_id: str, tag: str = "0"):
asset_path = _get_asset_path(asset_id, tag)
response: GetAssetResponse = cast(GetAssetResponse, assets_api.get_asset(asset_id))
try:
response: GetAssetResponse = cast(GetAssetResponse, assets_api.get_asset(asset_id))
except whylabs_client.ApiException as e:
raise ValueError(f"Failed to download asset {asset_id} with tag {tag}: {e}")
url = cast(str, response.download_url)
os.makedirs(os.path.dirname(asset_path.zip_path), exist_ok=True)
r = requests.get(url, stream=True)
Expand Down
6 changes: 6 additions & 0 deletions langkit/metrics/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,12 @@ def __call__(self) -> MetricCreator:

return response_regex_metric

@staticmethod
def refusal() -> MetricCreator:
from langkit.metrics.regexes.regexes import response_refusal_regex_metric

return response_refusal_regex_metric

@staticmethod
def ssn() -> MetricCreator:
from langkit.metrics.regexes.regexes import response_ssn_regex_metric
Expand Down
24 changes: 24 additions & 0 deletions langkit/metrics/regexes/regexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,28 @@
)


def generate_refusal_regex(phrases: List[str]):
# Replace specific escaped tokens that should be variable (like spaces and apostrophes)
adjusted_phrases = [phrase.replace("\\ ", r"\\s+").replace("['’]", "['’]") for phrase in phrases]

# Join all phrases into a single regex pattern with word boundaries and a case-insensitive flag
pattern = r"\b(" + "|".join(adjusted_phrases) + r")\b"
print(pattern)
return re.compile(pattern, re.IGNORECASE)


# expressions=[re.compile(r"\b(I['’]m\s+sorry|I\s+can['’]t)\b", re.IGNORECASE)],

__default_patterns["patterns"].append(
# This regex is too complicated to define in json
CompiledPattern(
name="refusal",
expressions=[generate_refusal_regex(["I'm sorry", "I can't", "I cannot", "I can not", "I'm unable", "I am unable"])],
substitutions=None,
)
)


def __sanitize_name_for_metric(pattern_name: str) -> str:
return pattern_name.replace(" ", "_").lower()

Expand Down Expand Up @@ -148,6 +170,7 @@ def udf(text: Union[pd.DataFrame, Dict[str, List[Any]]]) -> SingleMetricResult:
]

response_ssn_regex_metric = partial(__single_regex_module, "response", __default_patterns, "SSN")
response_refusal_regex_metric = partial(__single_regex_module, "response", __default_patterns, "refusal")
response_credit_card_number_regex_metric = partial(__single_regex_module, "response", __default_patterns, "credit card number")
response_phone_number_regex_metric = partial(__single_regex_module, "response", __default_patterns, "phone number")
response_mailing_address_regex_metric = partial(__single_regex_module, "response", __default_patterns, "mailing address")
Expand All @@ -156,6 +179,7 @@ def udf(text: Union[pd.DataFrame, Dict[str, List[Any]]]) -> SingleMetricResult:

response_regex_metric = [
response_ssn_regex_metric,
response_refusal_regex_metric,
response_credit_card_number_regex_metric,
response_phone_number_regex_metric,
response_mailing_address_regex_metric,
Expand Down
21 changes: 21 additions & 0 deletions tests/langkit/metrics/test_regexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import whylogs as why
from langkit.core.metric import WorkflowMetricConfig, WorkflowMetricConfigBuilder
from langkit.core.workflow import Workflow
from langkit.metrics.library import lib
from langkit.metrics.regexes.regex_loader import CompiledPatternGroups, PatternGroups
from langkit.metrics.regexes.regexes import (
get_custom_regex_frequent_items_for_column_module,
Expand Down Expand Up @@ -103,6 +104,26 @@ def test_prompt_regex_df_url():
assert actual["response.regex.url"][0] == 1


def test_response_regex_refusal():
df = pd.DataFrame(
{
"response": [
"I'm sorry, I can't answer that",
],
}
)

wf = Workflow(metrics=[lib.response.regex.refusal()])
result = wf.run(df)

actual = result.metrics

expected_columns = ["prompt.regex.refusal", "id"]

assert list(actual.columns) == expected_columns
assert actual["prompt.regex.refusal"][0] == 1


def test_prompt_regex_df_ssn():
df = pd.DataFrame(
{
Expand Down

0 comments on commit 22f979f

Please sign in to comment.