Skip to content

Commit

Permalink
Merge pull request #272 from whylabs/validator-options
Browse files Browse the repository at this point in the history
Split options out of validator into dedicated options class
  • Loading branch information
naddeoa authored Mar 24, 2024
2 parents ec0f067 + 93e9a92 commit 8b07401
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 45 deletions.
17 changes: 11 additions & 6 deletions langkit/validators/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,17 @@ def validate_result(self, df: pd.DataFrame) -> Optional[ValidationResult]:
return ValidationResult(failures)


@dataclass
class MultiColumnConstraintValidatorOptions:
constraints: List[ConstraintValidatorOptions]
operator: Literal["AND", "OR"] = "AND"
report_mode: Literal["ALL_FAILED_METRICS", "FIRST_FAILED_METRIC"] = "FIRST_FAILED_METRIC"


class MultiColumnConstraintValidator(Validator):
def __init__(
self,
constraints: List[ConstraintValidatorOptions],
operator: Literal["AND", "OR"] = "AND",
report_mode: Literal["ALL_FAILED_METRICS", "FIRST_FAILED_METRIC"] = "FIRST_FAILED_METRIC",
options: MultiColumnConstraintValidatorOptions,
):
"""
Expand All @@ -222,9 +227,9 @@ def __init__(
return a single validation result when there are failures, and that validation result will contain the
first failed metric. If "ALL_FAILED_METRICS", then this validator will return each validation failure.
"""
self._operator = operator
self._constraints = [ConstraintValidator(constraint) for constraint in constraints]
self._report_mode = report_mode
self._operator = options.operator
self._constraints = [ConstraintValidator(constraint) for constraint in options.constraints]
self._report_mode = options.report_mode

def get_target_metric_names(self) -> List[str]:
target_metrics: List[str] = []
Expand Down
11 changes: 9 additions & 2 deletions langkit/validators/library.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import List, Literal, Optional, Sequence, Union

from langkit.core.validation import Validator
from langkit.validators.comparison import ConstraintValidator, ConstraintValidatorOptions, MultiColumnConstraintValidator
from langkit.validators.comparison import (
ConstraintValidator,
ConstraintValidatorOptions,
MultiColumnConstraintValidator,
MultiColumnConstraintValidatorOptions,
)


class lib:
Expand Down Expand Up @@ -88,4 +93,6 @@ def multi_column_constraint(
operator: Literal["AND", "OR"] = "AND",
report_mode: Literal["ALL_FAILED_METRICS", "FIRST_FAILED_METRIC"] = "FIRST_FAILED_METRIC",
) -> Validator:
return MultiColumnConstraintValidator(constraints=constraints, operator=operator, report_mode=report_mode)
return MultiColumnConstraintValidator(
MultiColumnConstraintValidatorOptions(constraints=constraints, operator=operator, report_mode=report_mode)
)
91 changes: 54 additions & 37 deletions tests/langkit/validators/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from langkit.core.validation import ValidationFailure
from langkit.core.workflow import Workflow
from langkit.metrics.library import lib as metric_lib
from langkit.validators.comparison import ConstraintValidator, ConstraintValidatorOptions, MultiColumnConstraintValidator
from langkit.validators.comparison import (
ConstraintValidator,
ConstraintValidatorOptions,
MultiColumnConstraintValidator,
MultiColumnConstraintValidatorOptions,
)


def test_one_required():
Expand Down Expand Up @@ -149,10 +154,12 @@ def test_must_be_non_none():

def test_multiple_contraint_first_failure():
validator = MultiColumnConstraintValidator(
[
ConstraintValidatorOptions("prompt.stats.char_count", lower_threshold=5),
ConstraintValidatorOptions("prompt.stats.token_count", lower_threshold=5),
]
MultiColumnConstraintValidatorOptions(
[
ConstraintValidatorOptions("prompt.stats.char_count", lower_threshold=5),
ConstraintValidatorOptions("prompt.stats.token_count", lower_threshold=5),
]
)
)
wf = Workflow(
metrics=[metric_lib.prompt.stats.char_count, metric_lib.prompt.stats.token_count],
Expand Down Expand Up @@ -180,12 +187,14 @@ def test_multiple_contraint_first_failure():

def test_multiple_constriant_all_failure():
validator = MultiColumnConstraintValidator(
[
ConstraintValidatorOptions("prompt.stats.char_count", lower_threshold=5),
ConstraintValidatorOptions("prompt.stats.token_count", lower_threshold=5),
],
report_mode="ALL_FAILED_METRICS",
operator="AND",
MultiColumnConstraintValidatorOptions(
[
ConstraintValidatorOptions("prompt.stats.char_count", lower_threshold=5),
ConstraintValidatorOptions("prompt.stats.token_count", lower_threshold=5),
],
report_mode="ALL_FAILED_METRICS",
operator="AND",
)
)
wf = Workflow(
metrics=[metric_lib.prompt.stats.char_count, metric_lib.prompt.stats.token_count],
Expand Down Expand Up @@ -224,12 +233,14 @@ def test_multiple_constriant_all_failure():

def test_multiple_constriant_all_failure_or():
validator = MultiColumnConstraintValidator(
[
ConstraintValidatorOptions("prompt.stats.char_count", lower_threshold=4),
ConstraintValidatorOptions("prompt.stats.token_count", lower_threshold=4),
],
report_mode="ALL_FAILED_METRICS",
operator="OR",
MultiColumnConstraintValidatorOptions(
[
ConstraintValidatorOptions("prompt.stats.char_count", lower_threshold=4),
ConstraintValidatorOptions("prompt.stats.token_count", lower_threshold=4),
],
report_mode="ALL_FAILED_METRICS",
operator="OR",
)
)
wf = Workflow(
metrics=[metric_lib.prompt.stats.char_count, metric_lib.prompt.stats.token_count],
Expand All @@ -256,12 +267,14 @@ def test_multiple_constriant_all_failure_or():

def test_multiple_constriant_first_failure_or():
validator = MultiColumnConstraintValidator(
[
ConstraintValidatorOptions("prompt.stats.char_count", lower_threshold=4),
ConstraintValidatorOptions("prompt.stats.token_count", lower_threshold=4),
],
report_mode="FIRST_FAILED_METRIC",
operator="OR",
MultiColumnConstraintValidatorOptions(
[
ConstraintValidatorOptions("prompt.stats.char_count", lower_threshold=4),
ConstraintValidatorOptions("prompt.stats.token_count", lower_threshold=4),
],
report_mode="FIRST_FAILED_METRIC",
operator="OR",
)
)
wf = Workflow(
metrics=[metric_lib.prompt.stats.char_count, metric_lib.prompt.stats.token_count],
Expand All @@ -288,13 +301,15 @@ def test_multiple_constriant_first_failure_or():

def test_multiple_constriant_first_failure_or_multiple_failures():
validator = MultiColumnConstraintValidator(
[
ConstraintValidatorOptions("prompt.stats.char_count", upper_threshold=100),
ConstraintValidatorOptions("prompt.stats.token_count", upper_threshold=1),
ConstraintValidatorOptions("prompt.regex.email_address", upper_threshold=0),
],
report_mode="FIRST_FAILED_METRIC",
operator="OR",
MultiColumnConstraintValidatorOptions(
[
ConstraintValidatorOptions("prompt.stats.char_count", upper_threshold=100),
ConstraintValidatorOptions("prompt.stats.token_count", upper_threshold=1),
ConstraintValidatorOptions("prompt.regex.email_address", upper_threshold=0),
],
report_mode="FIRST_FAILED_METRIC",
operator="OR",
)
)
wf = Workflow(
metrics=[metric_lib.prompt.stats.char_count, metric_lib.prompt.stats.token_count, metric_lib.prompt.regex.email_address],
Expand Down Expand Up @@ -322,13 +337,15 @@ def test_multiple_constriant_first_failure_or_multiple_failures():

def test_multiple_constriant_first_failure_and_ordering():
validator = MultiColumnConstraintValidator(
[
ConstraintValidatorOptions("prompt.regex.email_address", upper_threshold=0),
ConstraintValidatorOptions("prompt.stats.char_count", upper_threshold=1),
ConstraintValidatorOptions("prompt.stats.token_count", upper_threshold=1),
],
report_mode="FIRST_FAILED_METRIC",
operator="AND",
MultiColumnConstraintValidatorOptions(
[
ConstraintValidatorOptions("prompt.regex.email_address", upper_threshold=0),
ConstraintValidatorOptions("prompt.stats.char_count", upper_threshold=1),
ConstraintValidatorOptions("prompt.stats.token_count", upper_threshold=1),
],
report_mode="FIRST_FAILED_METRIC",
operator="AND",
)
)
wf = Workflow(
metrics=[metric_lib.prompt.stats.char_count, metric_lib.prompt.stats.token_count, metric_lib.prompt.regex.email_address],
Expand Down

0 comments on commit 8b07401

Please sign in to comment.