SDP data preparation pipeline #37

Merged: 27 commits, Jun 27, 2024

Commits
5c01b51  added sdp pipeline (i-vainn, Jun 14, 2024)
0371406  removed redundant config (i-vainn, Jun 14, 2024)
f31e489  added arithmetic processing utils (i-vainn, Jun 14, 2024)
7e962d8  removed old scripts (i-vainn, Jun 14, 2024)
7a2ccf4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 14, 2024)
7cff70f  changed config structure (i-vainn, Jun 14, 2024)
9644c9a  formatting (i-vainn, Jun 14, 2024)
b9724d7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 14, 2024)
dbf0146  added dependency (i-vainn, Jun 14, 2024)
d8b86a4  Merge branch 'imoshkov/sdp_pipeline' of https://github.com/i-vainn/Ne… (i-vainn, Jun 14, 2024)
6d450fd  bug and formatting fixes (i-vainn, Jun 15, 2024)
7ab03a3  fixed tests (i-vainn, Jun 17, 2024)
6484d3a  minor changes (i-vainn, Jun 17, 2024)
53c95d8  simplified tests (i-vainn, Jun 17, 2024)
469331c  fixed tests (i-vainn, Jun 17, 2024)
055123c  added test data (i-vainn, Jun 17, 2024)
5ba6830  minor improvements (i-vainn, Jun 19, 2024)
060e0ef  minor fixes (i-vainn, Jun 19, 2024)
69eae1a  added metrics, added parallel read data, fixed deduplication bug (i-vainn, Jun 24, 2024)
dbdc72d  Merge branch 'main' into imoshkov/sdp_pipeline (i-vainn, Jun 24, 2024)
515b32a  fixed deduplication issue and improved processing speed (i-vainn, Jun 26, 2024)
75c745c  added sdp to setup and requirements (Jun 26, 2024)
59b96db  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 26, 2024)
756d750  read requirements list from requirements.txt (Jun 26, 2024)
ab0c0bb  read requirements list from requirements.txt (Jun 26, 2024)
5ed0948  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 26, 2024)
1262484  Merge branch 'main' into imoshkov/sdp_pipeline (Kipok, Jun 26, 2024)
313 changes: 313 additions & 0 deletions nemo_skills/finetuning/data_preparation_utils/filters.py
@@ -0,0 +1,313 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import re
import warnings
from itertools import chain
from math import isclose
from typing import List

import tqdm
from sdp.processors.base_processor import BaseParallelProcessor, DataEntry
from tqdm.contrib.concurrent import process_map

from nemo_skills.code_execution import CODE_OUTPUT_SEPARATORS, CODE_SEPARATORS
from nemo_skills.synthetic_arithmetic.solve_expression import merge_solution_steps, solve_expression
from nemo_skills.synthetic_arithmetic.utils import extract_expressions

LOG = logging.getLogger(__file__)

PATTERN_ANS = re.compile(r"\\boxed\{([^}]*)\}")
PATTERN_CODE = re.compile(CODE_SEPARATORS[0])
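# PATTERN_ANS captures the content inside a \boxed{...} answer; PATTERN_CODE compiles the opening
# code separator imported from nemo_skills.code_execution.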


class BaseFilter(BaseParallelProcessor):
    def __init__(self, **kwargs):
        if 'in_memory_chunksize' not in kwargs:
            kwargs['in_memory_chunksize'] = 100000000
        if 'chunksize' not in kwargs:
            kwargs['chunksize'] = 100000
        super().__init__(**kwargs)

    def finalize(self, metrics: List):
        LOG.info("Number of entries after processing: %d", self.number_of_entries)

        if not metrics:
            return

        if 'num_removed' in metrics[0]:
            num_removed_entries = sum(metric.get('num_removed', 0) for metric in metrics)
            LOG.info("Number of removed entries: %d", num_removed_entries)

        if 'num_modified' in metrics[0]:
            num_modified_entries = sum(metric.get('num_modified', 0) for metric in metrics)
            LOG.info("Number of modified entries: %d", num_modified_entries)


class DropMultiBoxed(BaseFilter):

    def __init__(self, solution_key: str = "generation", **kwargs):
        super().__init__(**kwargs)
        self.solution_key = solution_key

    def process_dataset_entry(self, data_entry) -> List:
        if len(PATTERN_ANS.findall(data_entry[self.solution_key])) > 1:
            return [DataEntry(data=None, metrics=dict(num_removed=1))]
        return [DataEntry(data=data_entry, metrics=dict(num_removed=0))]
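
# DropMultiBoxed, illustrative (hypothetical) example: a solution containing both \boxed{4} and
# \boxed{5} yields two PATTERN_ANS matches and is dropped; a single \boxed{...} answer is kept.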


class DropUselessCode(BaseFilter):

    def __init__(self, solution_key: str = "generation", **kwargs):
        super().__init__(**kwargs)
        self.solution_key = solution_key

    def process_dataset_entry(self, data_entry) -> List:
        ans_match = PATTERN_ANS.search(data_entry[self.solution_key])
        code_match = PATTERN_CODE.search(data_entry[self.solution_key])
        if not ans_match or not code_match or ans_match.start() > code_match.start():
            return [DataEntry(data=None, metrics=dict(num_removed=1))]

        return [DataEntry(data=data_entry, metrics=dict(num_removed=0))]


class DropBrokenCode(BaseFilter):
    def __init__(self, solution_key: str = "generation", **kwargs):
        super().__init__(**kwargs)
        self.solution_key = solution_key

    def process_dataset_entry(self, data_entry) -> List:
        generation = data_entry[self.solution_key]
        code_start_indices = [match.start() for match in re.finditer(CODE_SEPARATORS[0], generation)]
        code_end_indices = [match.start() for match in re.finditer(CODE_SEPARATORS[1], generation)]
        code_out_start_indices = [match.start() for match in re.finditer(CODE_OUTPUT_SEPARATORS[0], generation)]
        code_out_end_indices = [match.start() for match in re.finditer(CODE_OUTPUT_SEPARATORS[1], generation)]

        num_code_occs = set(
            [len(code_start_indices), len(code_end_indices), len(code_out_start_indices), len(code_out_end_indices)]
        )
        if len(num_code_occs) != 1:
            return [DataEntry(data=None, metrics=dict(num_removed=1))]

        if not len(code_end_indices):
            return [DataEntry(data=data_entry, metrics=dict(num_removed=0))]

        for code_start_idx, code_end_idx, code_out_start_idx, code_out_end_idx in zip(
            code_start_indices, code_end_indices, code_out_start_indices, code_out_end_indices
        ):
            if not (code_start_idx < code_end_idx < code_out_start_idx < code_out_end_idx):
                return [DataEntry(data=None, metrics=dict(num_removed=1))]

        return [DataEntry(data=data_entry, metrics=dict(num_removed=0))]
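
# DropBrokenCode keeps a solution only if the four separators occur an equal number of times and,
# for each block, code start < code end < output start < output end.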


class TrimSolutions(BaseFilter):

    def __init__(self, solution_key: str = "generation", **kwargs):
        super().__init__(**kwargs)
        self.solution_key = solution_key

    def process_dataset_entry(self, data_entry) -> List:
        output_lines = data_entry[self.solution_key].split("\n")

        stop_idx = 0
        for idx, soln_line in enumerate(output_lines):
            if PATTERN_ANS.findall(soln_line):
                stop_idx = idx
                break

        if stop_idx < len(output_lines) - 1 and (
            "\\end{align" in output_lines[stop_idx + 1]
            or "\]" in output_lines[stop_idx + 1]
            or "$$" in output_lines[stop_idx + 1]
        ):
            stop_idx = stop_idx + 1

        trimmed_output = "\n".join(output_lines[: stop_idx + 1])
        is_modified = trimmed_output != data_entry[self.solution_key]
        data_entry[self.solution_key] = trimmed_output

        return [DataEntry(data=data_entry, metrics=dict(num_modified=int(is_modified)))]
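
# TrimSolutions, illustrative (hypothetical) example: everything after the line containing
# \boxed{42} is cut, except that a directly following line with \end{align...}, \] or $$ is kept
# so the surrounding math block stays closed.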


class DropIncorrectArithmetic(BaseFilter):

    def __init__(self, solution_key: str = "generation", tolerance=1e-4, **kwargs):
        super().__init__(**kwargs)
        self.solution_key = solution_key
        self.tolerance = tolerance

    def process_dataset_entry(self, data_entry) -> List:
        for expression, _ in extract_expressions(data_entry[self.solution_key]):
            parts = expression.split("=")
            if len(parts) < 2:
                continue

            expr, ans = parts[0], parts[-1]

            try:
                solution_steps = solve_expression(expr)
                # ignore eval warnings
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=SyntaxWarning)
                    if not isclose(eval(solution_steps[-1]), eval(ans), rel_tol=self.tolerance):
                        return [DataEntry(data=None, metrics=dict(num_removed=1))]
            except KeyboardInterrupt:
                raise
            except:
                pass

        return [DataEntry(data=data_entry, metrics=dict(num_removed=0))]
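
# DropIncorrectArithmetic, illustrative (hypothetical) example: a solution containing
# "1 + 2 + 3 + 4 = 11" is dropped because the left-hand side evaluates to 10, which is not within
# rel_tol of 11; expressions that cannot be parsed or evaluated are ignored rather than dropped.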


class SplitArithmetic(BaseFilter):

    def __init__(self, solution_key: str = "generation", **kwargs):
        super().__init__(**kwargs)
        self.solution_key = solution_key

    def process_dataset_entry(self, data_entry) -> List:
        """
        Extends short arithmetic expressions into step-by-step solutions.
        For example `1 + 2 + 3 + 4 = 10` -> `1 + 2 + 3 + 4 = 3 + 3 + 4 = 6 + 4 = 10`.
        """
        text = data_entry[self.solution_key]
        new_text = []
        last_end = 0

        for expression, start in extract_expressions(text):
            end = start + len(expression)
            parts = expression.split("=")

            if len(parts) != 2:
                new_text.append(text[last_end:end])
                last_end = end
                continue
            expr, ans = parts

            try:
                solution_steps = solve_expression(expr)
            except:
                new_text.append(text[last_end:end])
                last_end = end
                continue

            solution = merge_solution_steps(solution_steps)

            try:
                # ignore eval warnings
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=SyntaxWarning)
                    if eval(solution_steps[-1]) == eval(ans):
                        new_text.append(text[last_end:start] + solution)
                    else:
                        new_text.append(text[last_end:end])

                last_end = end
            except KeyboardInterrupt:
                raise
            except:
                new_text.append(text[last_end:end])
                last_end = end

        new_text.append(text[last_end:])
        data_entry[self.solution_key] = "".join(new_text)
        is_modified = text != data_entry[self.solution_key]

        return [DataEntry(data=data_entry, metrics=dict(num_modified=int(is_modified)))]
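
# SplitArithmetic only expands expressions that split into exactly two parts around '=', and only
# when the expanded steps still evaluate to the stated answer; otherwise the original text span is
# copied through unchanged.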


class CodeTextFilter(BaseParallelProcessor):
    def __init__(self, filter_type, solution_key='generation', **kwargs):
        if 'in_memory_chunksize' not in kwargs:
            kwargs['in_memory_chunksize'] = 100000000
        if 'chunksize' not in kwargs:
            kwargs['chunksize'] = 100000
        super().__init__(**kwargs)
        self.text_filter_type = filter_type
        self.solution_key = solution_key

    def process_dataset_entry(self, grouped_samples: List):
        code_solns = []
        text_solns = []
        for sample in grouped_samples:
            if CODE_SEPARATORS[0] in sample[self.solution_key]:
                code_solns.append(sample)
            else:
                text_solns.append(sample)

        filtered_predictions = []
        if self.text_filter_type is None:
            filtered_predictions.extend(code_solns)
            filtered_predictions.extend(text_solns)
        elif self.text_filter_type == 'all':
            filtered_predictions.extend(code_solns)
        elif self.text_filter_type == 'majority_code':
            filtered_predictions.extend(code_solns)
            if len(code_solns) <= len(grouped_samples) // 2:
                filtered_predictions.extend(text_solns)
        elif self.text_filter_type == 'majority_text':
            if len(code_solns) > len(grouped_samples) // 2:
                filtered_predictions.extend(code_solns)
            else:
                filtered_predictions.extend(text_solns)
        elif self.text_filter_type == 'any_code':
            if code_solns:
                filtered_predictions.extend(code_solns)
            else:
                filtered_predictions.extend(text_solns)
        else:
            raise NotImplementedError(f"Filtering method {self.text_filter_type} not implemented")
        num_removed = len(grouped_samples) - len(filtered_predictions)

        return [DataEntry(data=filtered_predictions, metrics=dict(num_removed=num_removed))]
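
    # Illustrative (hypothetical) example: with 8 grouped samples of which 5 contain code,
    # 'majority_code' keeps only the 5 code solutions; with 3 of 8, code is not the majority,
    # so all 8 samples are kept.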

    def process(self):
        self.prepare()
        os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True)
        metrics = []

        with open(self.output_manifest_file, "wt", encoding="utf-8") as fout:
            for manifest_chunk in self._chunk_manifest():
                # this will unroll all inner lists
                data = chain(
                    *process_map(
                        self.process_dataset_entry,
                        manifest_chunk,
                        max_workers=self.max_workers,
                        chunksize=self.chunksize,
                    )
                )
                for data_entry in tqdm.tqdm(data):
                    metrics.append(data_entry.metrics)
                    if data_entry.data is None:
                        continue
                    json.dump(data_entry.data, fout, ensure_ascii=False)
                    self.number_of_entries += 1
                    fout.write("\n")

        self.finalize(metrics)

    def finalize(self, metrics: List):
        LOG.info("Number of entries after processing: %d", self.number_of_entries)

        if not metrics:
            return

        if 'num_removed' in metrics[0]:
            num_removed_entries = sum(metric.get('num_removed', 0) for metric in metrics)
            LOG.info("Number of removed entries: %d", num_removed_entries)