-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactoring + fill majority answer script (#66)
Signed-off-by: Igor Gitman <[email protected]>
- Loading branch information
Showing
17 changed files
with
797 additions
and
432 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import json | ||
import logging | ||
import sys | ||
from collections import Counter | ||
from itertools import zip_longest | ||
from typing import Any | ||
|
||
import hydra | ||
from omegaconf import MISSING | ||
from tqdm import tqdm | ||
|
||
from nemo_skills.evaluation.metrics import MathEval, read_predictions | ||
from nemo_skills.utils import get_help_message, nested_dataclass, setup_logging, unroll_files | ||
|
||
LOG = logging.getLogger(__file__) | ||
|
||
|
||
@nested_dataclass | ||
class FillMajorityAnswerConfig: | ||
"""Top-level parameters for the script""" | ||
|
||
# list of files to use for majority voting. | ||
# Can specify multiple patterns separated by space | ||
# e.g. "path/to/file1.jsonl path/to/file2.jsonl" or with regex | ||
# "test_folder/output-rs*.jsonl" | ||
prediction_jsonl_files: Any = MISSING | ||
|
||
# if set to True will error if any responses/data is missing | ||
allow_incomplete: bool = False | ||
|
||
# minimum number of majority votes to use the answer. | ||
# -1 means use half of the votes, which is a good default value | ||
min_votes: int = -1 | ||
|
||
# will be used to fill up when not enough votes are available for the majority | ||
default_answer: str = "no_answer" | ||
|
||
# will not use any negative answers as this likely indicates bad problems | ||
# (at least for GSM8K domain). If running with other data, where negative answers | ||
# are common, should be set to False | ||
drop_negative_answers: bool = False | ||
|
||
# will not use any non-integer answers as this might indicates bad problems | ||
drop_noninteger_answers: bool = False | ||
|
||
def __post_init__(self): | ||
"""Building data_file from dataset/split_name if not provided directly.""" | ||
if isinstance(self.prediction_jsonl_files, str): | ||
self.prediction_jsonl_files = self.prediction_jsonl_files.split(" ") | ||
|
||
|
||
cs = hydra.core.config_store.ConfigStore.instance() | ||
cs.store(name="base_fill_majority_answer_conifg", node=FillMajorityAnswerConfig) | ||
|
||
|
||
@hydra.main(version_base=None, config_name="base_fill_majority_answer_conifg") | ||
def fill_majority_answer(cfg: FillMajorityAnswerConfig): | ||
cfg = FillMajorityAnswerConfig(_init_nested=True, **cfg) | ||
LOG.info("Config used: %s", cfg) | ||
|
||
file_handles = [open(file, "rt", encoding="utf-8") for file in unroll_files(cfg.prediction_jsonl_files)] | ||
if cfg.min_votes < 0: | ||
cfg.min_votes = len(file_handles) // 2 | ||
|
||
# currently majority is only defined for math evals | ||
evaluator = MathEval() | ||
|
||
majority_answers = [] | ||
all_predictions = [] | ||
retained_questions = 0 | ||
for idx, predictions in enumerate(tqdm(zip_longest(*file_handles))): | ||
data = read_predictions(predictions, evaluator, cfg.allow_incomplete) | ||
all_predictions.append(data) | ||
# TODO: currently majority does not take into account equivalent answers written in a different way | ||
valid_answers_and_results = [ | ||
(elem['predicted_answer'], elem['is_correct']) for elem in data if elem['predicted_answer'] is not None | ||
] | ||
majority_answers.append(cfg.default_answer) | ||
if len(valid_answers_and_results) == 0: | ||
continue | ||
(majority_answer, _), num_votes = Counter(valid_answers_and_results).most_common(1)[0] | ||
|
||
if num_votes <= cfg.min_votes: | ||
continue | ||
|
||
if cfg.drop_negative_answers or cfg.drop_noninteger_answers: | ||
try: | ||
majority_answer = float(majority_answer) | ||
except ValueError: | ||
continue | ||
|
||
if cfg.drop_negative_answers and majority_answer < 0: | ||
continue | ||
|
||
if cfg.drop_noninteger_answers and not majority_answer.is_integer(): | ||
continue | ||
|
||
majority_answers[-1] = majority_answer | ||
retained_questions += 1 | ||
|
||
LOG.info("Total questions: %d, retained questions: %d", len(all_predictions), retained_questions) | ||
|
||
for file_handle in file_handles: | ||
file_handle.close() | ||
|
||
# writing the majority answers back to the files | ||
file_handles = [open(file, "wt", encoding="utf-8") for file in unroll_files(cfg.prediction_jsonl_files)] | ||
for idx, predictions in enumerate(all_predictions): | ||
for lidx, handle in enumerate(file_handles): | ||
predictions[lidx]["expected_answer"] = majority_answers[idx] | ||
handle.write(json.dumps(predictions[lidx]) + "\n") | ||
|
||
for file_handle in file_handles: | ||
file_handle.close() | ||
|
||
|
||
HELP_MESSAGE = get_help_message(FillMajorityAnswerConfig) | ||
|
||
|
||
if __name__ == "__main__": | ||
if '--help' in sys.argv or '-h' in sys.argv: | ||
print(HELP_MESSAGE) | ||
else: | ||
setup_logging() | ||
fill_majority_answer() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import json | ||
import logging | ||
import shutil | ||
import subprocess | ||
from argparse import Namespace | ||
from pathlib import Path | ||
|
||
LOG = logging.getLogger(__file__) | ||
|
||
|
||
def math_eval(cfg): | ||
from nemo_skills.code_execution.sandbox import get_sandbox | ||
|
||
sandbox = get_sandbox(**cfg.sandbox) | ||
sandbox.batch_evaluate_results( | ||
prediction_jsonl_files=cfg.prediction_jsonl_files, | ||
**cfg.eval_config, | ||
) | ||
|
||
|
||
def code_eval(cfg): | ||
# TODO: need to move it to a separate docker (either our sandbox or separate srun) | ||
from evalplus.evaluate import evaluate | ||
from omegaconf import OmegaConf | ||
|
||
from nemo_skills.evaluation.code_utils import preprocess_code | ||
|
||
# processing each generation separately (TODO: evalplus can do it together, but need to figure out the format) | ||
for jsonl_file in cfg.prediction_jsonl_files: | ||
with open(jsonl_file) as f: | ||
samples = [preprocess_code(json.loads(line)) for line in f] | ||
# all changes will be done with a new key "completion", so it's ok to write to the same file | ||
with open(jsonl_file, "wt", encoding="utf-8") as f: | ||
for sample in samples: | ||
f.write(json.dumps(sample) + "\n") | ||
eval_config = { | ||
"samples": jsonl_file, | ||
"base_only": False, | ||
"parallel": None, | ||
"i_just_wanna_run": False, | ||
"test_details": False, | ||
"min_time_limit": 1, | ||
"gt_time_limit_factor": 4.0, | ||
"mini": False, | ||
"noextreme": False, | ||
"version": "default", | ||
} | ||
eval_config.update(OmegaConf.to_container(cfg.eval_config)) | ||
evaluate(Namespace(**eval_config)) | ||
with open(jsonl_file[:-6] + '_eval_results.json', 'rt', encoding="utf-8") as fin: | ||
evalplus_grades = json.load(fin) | ||
# adding is_correct key to allow compute_metrics to work | ||
with open(jsonl_file, "wt", encoding="utf-8") as f: | ||
for sample in samples: | ||
sample['is_correct'] = evalplus_grades['eval'][sample['task_id']][0]['base_status'] == "pass" | ||
sample['is_correct-plus'] = ( | ||
sample['is_correct'] and evalplus_grades['eval'][sample['task_id']][0]['plus_status'] == "pass" | ||
) | ||
f.write(json.dumps(sample) + "\n") | ||
|
||
# moving eval file as otherwise evalplus does not want to recompute metrics if it's present.. | ||
shutil.move(jsonl_file[:-6] + '_eval_results.json', jsonl_file[:-6] + '_eval_results-saved.json') | ||
|
||
|
||
def ifeval(cfg): | ||
for jsonl_file in cfg.prediction_jsonl_files: | ||
parent_dir = Path(jsonl_file).absolute().parent | ||
cmd = ( | ||
'cd /opt/benchmarks/google-research && python -m instruction_following_eval.evaluation_main ' | ||
f'--input_data={jsonl_file} ' | ||
f'--input_response_data={jsonl_file} ' | ||
f'--output_dir={parent_dir} ' | ||
) | ||
subprocess.run(cmd, shell=True, check=True) | ||
# fusing eval metrics back into the generation file | ||
with open(jsonl_file, "rt", encoding="utf-8") as f: | ||
samples = [json.loads(line) for line in f] | ||
|
||
with open(parent_dir / 'eval_results_loose.jsonl', 'rt', encoding="utf-8") as f: | ||
eval_results = [json.loads(line) for line in f] | ||
for sample, eval_result in zip(samples, eval_results): | ||
sample['loose_eval'] = eval_result | ||
|
||
with open(parent_dir / 'eval_results_strict.jsonl', 'rt', encoding="utf-8") as f: | ||
eval_results = [json.loads(line) for line in f] | ||
for sample, eval_result in zip(samples, eval_results): | ||
sample['strict_eval'] = eval_result | ||
|
||
with open(jsonl_file, "wt", encoding="utf-8") as f: | ||
for sample in samples: | ||
f.write(json.dumps(sample) + "\n") | ||
|
||
# removing metric files to avoid reusing them | ||
(parent_dir / 'eval_results_loose.jsonl').unlink() | ||
(parent_dir / 'eval_results_strict.jsonl').unlink() |
Oops, something went wrong.