From 3c599eb05d076c932d436c3dabc6668e7a75345f Mon Sep 17 00:00:00 2001 From: Anthony Naddeo Date: Sun, 24 Mar 2024 18:55:50 -0700 Subject: [PATCH] Use tuple instead of list in validation options This enesures that this type can be serialized. --- langkit/validators/comparison.py | 8 +++---- langkit/validators/library.py | 6 +++--- tests/langkit/validators/test_comparison.py | 24 ++++++++++----------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/langkit/validators/comparison.py b/langkit/validators/comparison.py index 72e7527..dc699ad 100644 --- a/langkit/validators/comparison.py +++ b/langkit/validators/comparison.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, replace from functools import partial -from typing import Any, Callable, List, Literal, Optional, Sequence, Set, Union +from typing import Any, Callable, List, Literal, Optional, Sequence, Set, Tuple, Union import numpy as np import pandas as pd @@ -147,8 +147,8 @@ class ConstraintValidatorOptions: upper_threshold_inclusive: Optional[Union[float, int]] = None lower_threshold: Optional[Union[float, int]] = None lower_threshold_inclusive: Optional[Union[float, int]] = None - one_of: Optional[Sequence[Union[str, float, int]]] = None - none_of: Optional[Sequence[Union[str, float, int]]] = None + one_of: Optional[Tuple[Union[str, float, int], ...]] = None + none_of: Optional[Tuple[Union[str, float, int], ...]] = None must_be_non_none: Optional[bool] = None must_be_none: Optional[bool] = None @@ -208,7 +208,7 @@ def validate_result(self, df: pd.DataFrame) -> Optional[ValidationResult]: @dataclass class MultiColumnConstraintValidatorOptions: - constraints: List[ConstraintValidatorOptions] + constraints: Tuple[ConstraintValidatorOptions, ...] operator: Literal["AND", "OR"] = "AND" report_mode: Literal["ALL_FAILED_METRICS", "FIRST_FAILED_METRIC"] = "FIRST_FAILED_METRIC" diff --git a/langkit/validators/library.py b/langkit/validators/library.py index 7414644..fe08001 100644 --- a/langkit/validators/library.py +++ b/langkit/validators/library.py @@ -80,8 +80,8 @@ def constraint( upper_threshold_inclusive=upper_threshold_inclusive, lower_threshold=lower_threshold, lower_threshold_inclusive=lower_threshold_inclusive, - one_of=one_of, - none_of=none_of, + one_of=tuple(one_of) if one_of else None, + none_of=tuple(none_of) if none_of else None, must_be_non_none=must_be_non_none, must_be_none=must_be_none, ) @@ -94,5 +94,5 @@ def multi_column_constraint( report_mode: Literal["ALL_FAILED_METRICS", "FIRST_FAILED_METRIC"] = "FIRST_FAILED_METRIC", ) -> Validator: return MultiColumnConstraintValidator( - MultiColumnConstraintValidatorOptions(constraints=constraints, operator=operator, report_mode=report_mode) + MultiColumnConstraintValidatorOptions(constraints=tuple(constraints), operator=operator, report_mode=report_mode) ) diff --git a/tests/langkit/validators/test_comparison.py b/tests/langkit/validators/test_comparison.py index b7a3aa1..e7d869b 100644 --- a/tests/langkit/validators/test_comparison.py +++ b/tests/langkit/validators/test_comparison.py @@ -155,10 +155,10 @@ def test_must_be_non_none(): def test_multiple_contraint_first_failure(): validator = MultiColumnConstraintValidator( MultiColumnConstraintValidatorOptions( - [ + ( ConstraintValidatorOptions("prompt.stats.char_count", lower_threshold=5), ConstraintValidatorOptions("prompt.stats.token_count", lower_threshold=5), - ] + ) ) ) wf = Workflow( @@ -188,10 +188,10 @@ def test_multiple_contraint_first_failure(): def test_multiple_constriant_all_failure(): validator = MultiColumnConstraintValidator( MultiColumnConstraintValidatorOptions( - [ + ( ConstraintValidatorOptions("prompt.stats.char_count", lower_threshold=5), ConstraintValidatorOptions("prompt.stats.token_count", lower_threshold=5), - ], + ), report_mode="ALL_FAILED_METRICS", operator="AND", ) @@ -234,10 +234,10 @@ def test_multiple_constriant_all_failure(): def test_multiple_constriant_all_failure_or(): validator = MultiColumnConstraintValidator( MultiColumnConstraintValidatorOptions( - [ + ( ConstraintValidatorOptions("prompt.stats.char_count", lower_threshold=4), ConstraintValidatorOptions("prompt.stats.token_count", lower_threshold=4), - ], + ), report_mode="ALL_FAILED_METRICS", operator="OR", ) @@ -268,10 +268,10 @@ def test_multiple_constriant_all_failure_or(): def test_multiple_constriant_first_failure_or(): validator = MultiColumnConstraintValidator( MultiColumnConstraintValidatorOptions( - [ + ( ConstraintValidatorOptions("prompt.stats.char_count", lower_threshold=4), ConstraintValidatorOptions("prompt.stats.token_count", lower_threshold=4), - ], + ), report_mode="FIRST_FAILED_METRIC", operator="OR", ) @@ -302,11 +302,11 @@ def test_multiple_constriant_first_failure_or(): def test_multiple_constriant_first_failure_or_multiple_failures(): validator = MultiColumnConstraintValidator( MultiColumnConstraintValidatorOptions( - [ + ( ConstraintValidatorOptions("prompt.stats.char_count", upper_threshold=100), ConstraintValidatorOptions("prompt.stats.token_count", upper_threshold=1), ConstraintValidatorOptions("prompt.regex.email_address", upper_threshold=0), - ], + ), report_mode="FIRST_FAILED_METRIC", operator="OR", ) @@ -338,11 +338,11 @@ def test_multiple_constriant_first_failure_or_multiple_failures(): def test_multiple_constriant_first_failure_and_ordering(): validator = MultiColumnConstraintValidator( MultiColumnConstraintValidatorOptions( - [ + ( ConstraintValidatorOptions("prompt.regex.email_address", upper_threshold=0), ConstraintValidatorOptions("prompt.stats.char_count", upper_threshold=1), ConstraintValidatorOptions("prompt.stats.token_count", upper_threshold=1), - ], + ), report_mode="FIRST_FAILED_METRIC", operator="AND", )