Skip to content

Commit

Permalink
removed redundant logic
Browse files (browse the repository at this point in the history)
  • Loading branch information
i-vainn committed Jun 19, 2024
1 parent 5ba6830 commit b8b9a11
Showing 1 changed file with 6 additions and 19 deletions.
25 changes: 6 additions & 19 deletions nemo_skills/finetuning/data_preparation_utils/filters.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -31,20 +31,7 @@
PATTERN_CODE = re.compile(CODE_SEPARATORS[0])


class BaseFilter(BaseParallelProcessor):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def process_dataset_entry(self, data_entry) -> List:
raise NotImplementedError

def test(self):
cached_value, self.should_apply = self.should_apply, True
super().test()
self.should_apply = cached_value


class DropMultiBoxed(BaseFilter):
class DropMultiBoxed(BaseParallelProcessor):

def __init__(self, should_apply: bool = False, solution_key: str = "generation", **kwargs):
super().__init__(**kwargs)
Expand All @@ -60,7 +47,7 @@ def process_dataset_entry(self, data_entry) -> List:
return [DataEntry(data=data_entry)]


class DropUselessCode(BaseFilter):
class DropUselessCode(BaseParallelProcessor):

def __init__(self, should_apply: bool = False, solution_key: str = "generation", **kwargs):
super().__init__(**kwargs)
Expand All @@ -79,7 +66,7 @@ def process_dataset_entry(self, data_entry) -> List:
return [DataEntry(data=data_entry)]


class DropBrokenCode(BaseFilter):
class DropBrokenCode(BaseParallelProcessor):
def __init__(self, should_apply: bool = False, solution_key: str = "generation", **kwargs):
super().__init__(**kwargs)
self.solution_key = solution_key
Expand Down Expand Up @@ -113,7 +100,7 @@ def process_dataset_entry(self, data_entry) -> List:
return [DataEntry(data=data_entry)]


class TrimSolutions(BaseFilter):
class TrimSolutions(BaseParallelProcessor):

def __init__(self, should_apply: bool = False, solution_key: str = "generation", **kwargs):
super().__init__(**kwargs)
Expand Down Expand Up @@ -145,7 +132,7 @@ def process_dataset_entry(self, data_entry) -> List:
return [DataEntry(data=data_entry)]


class DropIncorrectArithmetic(BaseFilter):
class DropIncorrectArithmetic(BaseParallelProcessor):

def __init__(self, should_apply: bool = True, solution_key: str = "generation", tolerance=1e-4, **kwargs):
super().__init__(**kwargs)
Expand Down Expand Up @@ -176,7 +163,7 @@ def process_dataset_entry(self, data_entry: str) -> str:
return [DataEntry(data=data_entry)]


class SplitArithmetic(BaseFilter):
class SplitArithmetic(BaseParallelProcessor):

def __init__(self, should_apply: bool = True, solution_key: str = "generation", **kwargs):
super().__init__(**kwargs)
Expand Down

0 comments on commit b8b9a11

Please sign in to comment.