diff --git a/nemo_skills/finetuning/data_preparation_utils/filters.py b/nemo_skills/finetuning/data_preparation_utils/filters.py index 779a4e0b3..d22ddee88 100644 --- a/nemo_skills/finetuning/data_preparation_utils/filters.py +++ b/nemo_skills/finetuning/data_preparation_utils/filters.py @@ -31,20 +31,7 @@ PATTERN_CODE = re.compile(CODE_SEPARATORS[0]) -class BaseFilter(BaseParallelProcessor): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def process_dataset_entry(self, data_entry) -> List: - raise NotImplementedError - - def test(self): - cached_value, self.should_apply = self.should_apply, True - super().test() - self.should_apply = cached_value - - -class DropMultiBoxed(BaseFilter): +class DropMultiBoxed(BaseParallelProcessor): def __init__(self, should_apply: bool = False, solution_key: str = "generation", **kwargs): super().__init__(**kwargs) @@ -60,7 +47,7 @@ def process_dataset_entry(self, data_entry) -> List: return [DataEntry(data=data_entry)] -class DropUselessCode(BaseFilter): +class DropUselessCode(BaseParallelProcessor): def __init__(self, should_apply: bool = False, solution_key: str = "generation", **kwargs): super().__init__(**kwargs) @@ -79,7 +66,7 @@ def process_dataset_entry(self, data_entry) -> List: return [DataEntry(data=data_entry)] -class DropBrokenCode(BaseFilter): +class DropBrokenCode(BaseParallelProcessor): def __init__(self, should_apply: bool = False, solution_key: str = "generation", **kwargs): super().__init__(**kwargs) self.solution_key = solution_key @@ -113,7 +100,7 @@ def process_dataset_entry(self, data_entry) -> List: return [DataEntry(data=data_entry)] -class TrimSolutions(BaseFilter): +class TrimSolutions(BaseParallelProcessor): def __init__(self, should_apply: bool = False, solution_key: str = "generation", **kwargs): super().__init__(**kwargs) @@ -145,7 +132,7 @@ def process_dataset_entry(self, data_entry) -> List: return [DataEntry(data=data_entry)] -class DropIncorrectArithmetic(BaseFilter): +class DropIncorrectArithmetic(BaseParallelProcessor): def __init__(self, should_apply: bool = True, solution_key: str = "generation", tolerance=1e-4, **kwargs): super().__init__(**kwargs) @@ -176,7 +163,7 @@ def process_dataset_entry(self, data_entry: str) -> str: return [DataEntry(data=data_entry)] -class SplitArithmetic(BaseFilter): +class SplitArithmetic(BaseParallelProcessor): def __init__(self, should_apply: bool = True, solution_key: str = "generation", **kwargs): super().__init__(**kwargs)