From 7b85475e35b2f4b61aeefff28895c0ebada7f9b1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 Jun 2024 16:48:16 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../data_preparation_utils/filters.py | 2 +- .../data_preparation_utils/preprocessing.py | 11 +++- tests/test_data_preparation.py | 53 +++++++++++-------- 3 files changed, 42 insertions(+), 24 deletions(-) diff --git a/nemo_skills/finetuning/data_preparation_utils/filters.py b/nemo_skills/finetuning/data_preparation_utils/filters.py index fbaa3bc47..9bf0f8390 100644 --- a/nemo_skills/finetuning/data_preparation_utils/filters.py +++ b/nemo_skills/finetuning/data_preparation_utils/filters.py @@ -60,7 +60,7 @@ def process_dataset_entry(self, data_entry) -> List: class DropBrokenCode(BaseParallelProcessor): - def __init__( self, text_key: str = "generation", **kwargs): + def __init__(self, text_key: str = "generation", **kwargs): super().__init__(**kwargs) self.text_key = text_key diff --git a/nemo_skills/finetuning/data_preparation_utils/preprocessing.py b/nemo_skills/finetuning/data_preparation_utils/preprocessing.py index a87947350..6151fce0b 100644 --- a/nemo_skills/finetuning/data_preparation_utils/preprocessing.py +++ b/nemo_skills/finetuning/data_preparation_utils/preprocessing.py @@ -151,7 +151,14 @@ def process(self): class ShuffleAndDownsampleData(BaseProcessor): - def __init__(self, random_seed: int, do_shuffle: bool, num_samples: Optional[int] = None, sampling_method: str = 'random', **kwargs): + def __init__( + self, + random_seed: int, + do_shuffle: bool, + num_samples: Optional[int] = None, + sampling_method: str = 'random', + **kwargs, + ): super().__init__(**kwargs) self.sampling_method = sampling_method self.num_samples = num_samples @@ -170,7 +177,7 @@ def process(self): output_instances = list(chain(*groupped_samples)) if self.do_shuffle: random.shuffle(output_instances) - output_instances = output_instances[:self.num_samples] + output_instances = output_instances[: self.num_samples] elif self.sampling_method == "fair": soln_counter = 0 output_instances = [] diff --git a/tests/test_data_preparation.py b/tests/test_data_preparation.py index e6dc432ff..cfb4977cf 100644 --- a/tests/test_data_preparation.py +++ b/tests/test_data_preparation.py @@ -1,8 +1,9 @@ -import pytest -import subprocess -import os import hashlib +import os import shutil +import subprocess + +import pytest import requests @@ -33,37 +34,47 @@ def compute_md5(file_path): def test_code_files(): - subprocess.run([ - "python", - "nemo_skills/finetuning/prepare_sdp.py", - "--config-path", - "data_preparation_utils/sdp_configs", - "--config-name", - "test_config.yaml" - ], check=True) + subprocess.run( + [ + "python", + "nemo_skills/finetuning/prepare_sdp.py", + "--config-path", + "data_preparation_utils/sdp_configs", + "--config-name", + "test_config.yaml", + ], + check=True, + ) expected_md5 = "779c70a2d84d96997336bcd47b3e99f9" output_md5 = compute_md5("tests/data/processed_output.jsonl") - assert expected_md5 == output_md5, "MD5 hashes do not match, something is wrong with nemo_skills/finetuning/prepare_sdp.py" + assert ( + expected_md5 == output_md5 + ), "MD5 hashes do not match, something is wrong with nemo_skills/finetuning/prepare_sdp.py" def test_openmathinstruct(): download_data("train") download_data("validation") - subprocess.run([ - "python", - "nemo_skills/finetuning/prepare_sdp.py", - "--config-path", - "data_preparation_utils/sdp_configs", - "--config-name", - "openmathinstruct_config.yaml" - ], check=True) + subprocess.run( + [ + "python", + "nemo_skills/finetuning/prepare_sdp.py", + "--config-path", + "data_preparation_utils/sdp_configs", + "--config-name", + "openmathinstruct_config.yaml", + ], + check=True, + ) output_file = 'open-math-instruct-1/train_full_sft.jsonl' expected_md5 = "c105c1c3369e5cee569dcba74e7d4d61" output_md5 = compute_md5(output_file) - assert expected_md5 == output_md5, "MD5 hashes do not match, something is wrong with nemo_skills/finetuning/prepare_sdp.py" + assert ( + expected_md5 == output_md5 + ), "MD5 hashes do not match, something is wrong with nemo_skills/finetuning/prepare_sdp.py"