diff --git a/langkit/metrics/library.py b/langkit/metrics/library.py index cac7e40..2050bef 100644 --- a/langkit/metrics/library.py +++ b/langkit/metrics/library.py @@ -226,6 +226,12 @@ def credit_card_number() -> MetricCreator: return prompt_credit_card_number_regex_metric + @staticmethod + def url() -> MetricCreator: + from langkit.metrics.regexes.regexes import prompt_url_regex_metric + + return prompt_url_regex_metric + class similarity: """ These metrics are used to compare the response to various examples and use cosine similarity/embedding distances @@ -427,6 +433,12 @@ def credit_card_number() -> MetricCreator: return response_credit_card_number_regex_metric + @staticmethod + def url() -> MetricCreator: + from langkit.metrics.regexes.regexes import response_url_regex_metric + + return response_url_regex_metric + class sentiment: def __call__(self) -> MetricCreator: return self.sentiment_score() diff --git a/langkit/metrics/regexes/regexes.py b/langkit/metrics/regexes/regexes.py index 142dd9e..9ec758a 100644 --- a/langkit/metrics/regexes/regexes.py +++ b/langkit/metrics/regexes/regexes.py @@ -1,4 +1,5 @@ import os +import re from dataclasses import dataclass from functools import partial from typing import Any, Dict, List, Optional, Union @@ -6,11 +7,23 @@ import pandas as pd from langkit.core.metric import Metric, MetricCreator, SingleMetric, SingleMetricResult, UdfInput -from langkit.metrics.regexes.regex_loader import CompiledPatternGroups, load_patterns_file +from langkit.metrics.regexes.regex_loader import CompiledPattern, CompiledPatternGroups, load_patterns_file __current_module_path = os.path.dirname(__file__) __default_pattern_file = os.path.join(__current_module_path, "pattern_groups.json") __default_patterns: CompiledPatternGroups = load_patterns_file(__default_pattern_file) +__default_patterns["patterns"].append( + # This regex is too complicated to define in json + CompiledPattern( + name="url", + expressions=[ + re.compile( + r"""(?i)\b((?:https?:(?:/{1,3}|[a-z0-9%])|[a-z0-9.\-]+[.](?:com|net|org|edu|gov|mil|aero|asia|biz|cat|coop|info|int|jobs|mobi|museum|name|post|pro|tel|travel|xxx|ac|ad|ae|af|ag|ai|al|am|an|ao|aq|ar|as|at|au|aw|ax|az|ba|bb|bd|be|bf|bg|bh|bi|bj|bm|bn|bo|br|bs|bt|bv|bw|by|bz|ca|cc|cd|cf|cg|ch|ci|ck|cl|cm|cn|co|cr|cs|cu|cv|cx|cy|cz|dd|de|dj|dk|dm|do|dz|ec|ee|eg|eh|er|es|et|eu|fi|fj|fk|fm|fo|fr|ga|gb|gd|ge|gf|gg|gh|gi|gl|gm|gn|gp|gq|gr|gs|gt|gu|gw|gy|hk|hm|hn|hr|ht|hu|id|ie|il|im|in|io|iq|ir|is|it|je|jm|jo|jp|ke|kg|kh|ki|km|kn|kp|kr|kw|ky|kz|la|lb|lc|li|lk|lr|ls|lt|lu|lv|ly|ma|mc|md|me|mg|mh|mk|ml|mm|mn|mo|mp|mq|mr|ms|mt|mu|mv|mw|mx|my|mz|na|nc|ne|nf|ng|ni|nl|no|np|nr|nu|nz|om|pa|pe|pf|pg|ph|pk|pl|pm|pn|pr|ps|pt|pw|py|qa|re|ro|rs|ru|rw|sa|sb|sc|sd|se|sg|sh|si|sj|Ja|sk|sl|sm|sn|so|sr|ss|st|su|sv|sx|sy|sz|tc|td|tf|tg|th|tj|tk|tl|tm|tn|to|tp|tr|tt|tv|tw|tz|ua|ug|uk|us|uy|uz|va|vc|ve|vg|vi|vn|vu|wf|ws|ye|yt|yu|za|zm|zw)/)(?:[^\s()<>{}\[\]]+|\([^\s()]*?\([^\s()]+\)[^\s()]*?\)|\([^\s]+?\))+(?:\([^\s()]*?\([^\s()]+\)[^\s()]*?\)|\([^\s]+?\)|[^\s`!()\[\]{};:'".,<>?«»“”‘’])|(?:(? str: @@ -123,6 +136,7 @@ def udf(text: Union[pd.DataFrame, Dict[str, List[Any]]]) -> SingleMetricResult: prompt_phone_number_regex_metric = partial(__single_regex_module, "prompt", __default_patterns, "phone number") prompt_mailing_address_regex_metric = partial(__single_regex_module, "prompt", __default_patterns, "mailing address") prompt_email_address_regex_metric = partial(__single_regex_module, "prompt", __default_patterns, "email address") +prompt_url_regex_metric = partial(__single_regex_module, "prompt", __default_patterns, "url") prompt_regex_metric = [ prompt_ssn_regex_metric, @@ -130,6 +144,7 @@ def udf(text: Union[pd.DataFrame, Dict[str, List[Any]]]) -> SingleMetricResult: prompt_phone_number_regex_metric, prompt_mailing_address_regex_metric, prompt_email_address_regex_metric, + prompt_url_regex_metric, ] response_ssn_regex_metric = partial(__single_regex_module, "response", __default_patterns, "SSN") @@ -137,6 +152,7 @@ def udf(text: Union[pd.DataFrame, Dict[str, List[Any]]]) -> SingleMetricResult: response_phone_number_regex_metric = partial(__single_regex_module, "response", __default_patterns, "phone number") response_mailing_address_regex_metric = partial(__single_regex_module, "response", __default_patterns, "mailing address") response_email_address_regex_metric = partial(__single_regex_module, "response", __default_patterns, "email address") +response_url_regex_metric = partial(__single_regex_module, "response", __default_patterns, "url") response_regex_metric = [ response_ssn_regex_metric, @@ -144,6 +160,7 @@ def udf(text: Union[pd.DataFrame, Dict[str, List[Any]]]) -> SingleMetricResult: response_phone_number_regex_metric, response_mailing_address_regex_metric, response_email_address_regex_metric, + response_url_regex_metric, ] prompt_response_ssn_regex_module = [prompt_ssn_regex_metric, response_ssn_regex_metric] @@ -151,6 +168,7 @@ def udf(text: Union[pd.DataFrame, Dict[str, List[Any]]]) -> SingleMetricResult: prompt_response_phone_number_regex_module = [prompt_phone_number_regex_metric, response_phone_number_regex_metric] prompt_response_mailing_address_regex_module = [prompt_mailing_address_regex_metric, response_mailing_address_regex_metric] prompt_response_email_address_regex_module = [prompt_email_address_regex_metric, response_email_address_regex_metric] +prompt_response_url_regex_module = [prompt_url_regex_metric, response_url_regex_metric] def custom_regex_metric(column_name: str, file_or_patterns: Optional[Union[str, CompiledPatternGroups]] = None) -> MetricCreator: diff --git a/tests/langkit/metrics/test_regexes.py b/tests/langkit/metrics/test_regexes.py index 430a409..2303d4e 100644 --- a/tests/langkit/metrics/test_regexes.py +++ b/tests/langkit/metrics/test_regexes.py @@ -10,6 +10,7 @@ import whylogs as why from langkit.core.metric import WorkflowMetricConfig, WorkflowMetricConfigBuilder +from langkit.core.workflow import Workflow from langkit.metrics.regexes.regex_loader import CompiledPatternGroups, PatternGroups from langkit.metrics.regexes.regexes import ( get_custom_regex_frequent_items_for_column_module, @@ -27,12 +28,14 @@ prompt_response_phone_number_regex_module, prompt_response_ssn_regex_module, prompt_ssn_regex_metric, + prompt_url_regex_metric, response_credit_card_number_regex_metric, response_default_regexes_module, response_email_address_regex_metric, response_mailing_address_regex_metric, response_phone_number_regex_metric, response_ssn_regex_metric, + response_url_regex_metric, ) from langkit.metrics.whylogs_compat import create_whylogs_udf_schema from whylogs.core.metrics.metrics import FrequentItem @@ -76,6 +79,30 @@ def _log(item: Any, conf: WorkflowMetricConfig) -> pd.DataFrame: return why.log(item, schema=schema).view().to_pandas() # type: ignore +def test_prompt_regex_df_url(): + df = pd.DataFrame( + { + "prompt": [ + "Does this code look good? foo.strip()/10. My blog is at whylabs.ai/foo.html", + ], + "response": [ + "Yeah. Nice blog, mine is at whylabs.ai/bar.html", + ], + } + ) + + wf = Workflow(metrics=[prompt_url_regex_metric, response_url_regex_metric]) + result = wf.run(df) + + actual = result.metrics + + expected_columns = ["prompt.regex.url", "response.regex.url", "id"] + + assert list(actual.columns) == expected_columns + assert actual["prompt.regex.url"][0] == 1 + assert actual["response.regex.url"][0] == 1 + + def test_prompt_regex_df_ssn(): df = pd.DataFrame( {