From 22f979f657b246b790edb92bd52fa772b1636954 Mon Sep 17 00:00:00 2001 From: Anthony Naddeo Date: Thu, 18 Apr 2024 13:28:43 -0700 Subject: [PATCH] Add a refusal pattern matching metric --- langkit/asset_downloader.py | 5 ++++- langkit/metrics/library.py | 6 ++++++ langkit/metrics/regexes/regexes.py | 24 ++++++++++++++++++++++++ tests/langkit/metrics/test_regexes.py | 21 +++++++++++++++++++++ 4 files changed, 55 insertions(+), 1 deletion(-) diff --git a/langkit/asset_downloader.py b/langkit/asset_downloader.py index 150dc30..34a4dbc 100644 --- a/langkit/asset_downloader.py +++ b/langkit/asset_downloader.py @@ -68,7 +68,10 @@ def _is_zip_file(file_path: str) -> bool: @retry(stop=stop_after_attempt(3), wait=wait_exponential_jitter(max=5)) def _download_asset(asset_id: str, tag: str = "0"): asset_path = _get_asset_path(asset_id, tag) - response: GetAssetResponse = cast(GetAssetResponse, assets_api.get_asset(asset_id)) + try: + response: GetAssetResponse = cast(GetAssetResponse, assets_api.get_asset(asset_id)) + except whylabs_client.ApiException as e: + raise ValueError(f"Failed to download asset {asset_id} with tag {tag}: {e}") url = cast(str, response.download_url) os.makedirs(os.path.dirname(asset_path.zip_path), exist_ok=True) r = requests.get(url, stream=True) diff --git a/langkit/metrics/library.py b/langkit/metrics/library.py index 24369f7..f9799da 100644 --- a/langkit/metrics/library.py +++ b/langkit/metrics/library.py @@ -432,6 +432,12 @@ def __call__(self) -> MetricCreator: return response_regex_metric + @staticmethod + def refusal() -> MetricCreator: + from langkit.metrics.regexes.regexes import response_refusal_regex_metric + + return response_refusal_regex_metric + @staticmethod def ssn() -> MetricCreator: from langkit.metrics.regexes.regexes import response_ssn_regex_metric diff --git a/langkit/metrics/regexes/regexes.py b/langkit/metrics/regexes/regexes.py index ae915ed..a51ba90 100644 --- a/langkit/metrics/regexes/regexes.py +++ b/langkit/metrics/regexes/regexes.py @@ -26,6 +26,28 @@ ) +def generate_refusal_regex(phrases: List[str]): + # Replace specific escaped tokens that should be variable (like spaces and apostrophes) + adjusted_phrases = [phrase.replace("\\ ", r"\\s+").replace("['’]", "['’]") for phrase in phrases] + + # Join all phrases into a single regex pattern with word boundaries and a case-insensitive flag + pattern = r"\b(" + "|".join(adjusted_phrases) + r")\b" + print(pattern) + return re.compile(pattern, re.IGNORECASE) + + +# expressions=[re.compile(r"\b(I['’]m\s+sorry|I\s+can['’]t)\b", re.IGNORECASE)], + +__default_patterns["patterns"].append( + # This regex is too complicated to define in json + CompiledPattern( + name="refusal", + expressions=[generate_refusal_regex(["I'm sorry", "I can't", "I cannot", "I can not", "I'm unable", "I am unable"])], + substitutions=None, + ) +) + + def __sanitize_name_for_metric(pattern_name: str) -> str: return pattern_name.replace(" ", "_").lower() @@ -148,6 +170,7 @@ def udf(text: Union[pd.DataFrame, Dict[str, List[Any]]]) -> SingleMetricResult: ] response_ssn_regex_metric = partial(__single_regex_module, "response", __default_patterns, "SSN") +response_refusal_regex_metric = partial(__single_regex_module, "response", __default_patterns, "refusal") response_credit_card_number_regex_metric = partial(__single_regex_module, "response", __default_patterns, "credit card number") response_phone_number_regex_metric = partial(__single_regex_module, "response", __default_patterns, "phone number") response_mailing_address_regex_metric = partial(__single_regex_module, "response", __default_patterns, "mailing address") @@ -156,6 +179,7 @@ def udf(text: Union[pd.DataFrame, Dict[str, List[Any]]]) -> SingleMetricResult: response_regex_metric = [ response_ssn_regex_metric, + response_refusal_regex_metric, response_credit_card_number_regex_metric, response_phone_number_regex_metric, response_mailing_address_regex_metric, diff --git a/tests/langkit/metrics/test_regexes.py b/tests/langkit/metrics/test_regexes.py index 2303d4e..bae968a 100644 --- a/tests/langkit/metrics/test_regexes.py +++ b/tests/langkit/metrics/test_regexes.py @@ -11,6 +11,7 @@ import whylogs as why from langkit.core.metric import WorkflowMetricConfig, WorkflowMetricConfigBuilder from langkit.core.workflow import Workflow +from langkit.metrics.library import lib from langkit.metrics.regexes.regex_loader import CompiledPatternGroups, PatternGroups from langkit.metrics.regexes.regexes import ( get_custom_regex_frequent_items_for_column_module, @@ -103,6 +104,26 @@ def test_prompt_regex_df_url(): assert actual["response.regex.url"][0] == 1 +def test_response_regex_refusal(): + df = pd.DataFrame( + { + "response": [ + "I'm sorry, I can't answer that", + ], + } + ) + + wf = Workflow(metrics=[lib.response.regex.refusal()]) + result = wf.run(df) + + actual = result.metrics + + expected_columns = ["prompt.regex.refusal", "id"] + + assert list(actual.columns) == expected_columns + assert actual["prompt.regex.refusal"][0] == 1 + + def test_prompt_regex_df_ssn(): df = pd.DataFrame( {