Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 14, 2024
1 parent 8f7984e commit 7b85475
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 24 deletions.
2 changes: 1 addition & 1 deletion nemo_skills/finetuning/data_preparation_utils/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def process_dataset_entry(self, data_entry) -> List:


class DropBrokenCode(BaseParallelProcessor):
def __init__( self, text_key: str = "generation", **kwargs):
def __init__(self, text_key: str = "generation", **kwargs):
super().__init__(**kwargs)
self.text_key = text_key

Expand Down
11 changes: 9 additions & 2 deletions nemo_skills/finetuning/data_preparation_utils/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,14 @@ def process(self):


class ShuffleAndDownsampleData(BaseProcessor):
def __init__(self, random_seed: int, do_shuffle: bool, num_samples: Optional[int] = None, sampling_method: str = 'random', **kwargs):
def __init__(
self,
random_seed: int,
do_shuffle: bool,
num_samples: Optional[int] = None,
sampling_method: str = 'random',
**kwargs,
):
super().__init__(**kwargs)
self.sampling_method = sampling_method
self.num_samples = num_samples
Expand All @@ -170,7 +177,7 @@ def process(self):
output_instances = list(chain(*groupped_samples))
if self.do_shuffle:
random.shuffle(output_instances)
output_instances = output_instances[:self.num_samples]
output_instances = output_instances[: self.num_samples]
elif self.sampling_method == "fair":
soln_counter = 0
output_instances = []
Expand Down
53 changes: 32 additions & 21 deletions tests/test_data_preparation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pytest
import subprocess
import os
import hashlib
import os
import shutil
import subprocess

import pytest
import requests


Expand Down Expand Up @@ -33,37 +34,47 @@ def compute_md5(file_path):


def test_code_files():
subprocess.run([
"python",
"nemo_skills/finetuning/prepare_sdp.py",
"--config-path",
"data_preparation_utils/sdp_configs",
"--config-name",
"test_config.yaml"
], check=True)
subprocess.run(
[
"python",
"nemo_skills/finetuning/prepare_sdp.py",
"--config-path",
"data_preparation_utils/sdp_configs",
"--config-name",
"test_config.yaml",
],
check=True,
)

expected_md5 = "779c70a2d84d96997336bcd47b3e99f9"
output_md5 = compute_md5("tests/data/processed_output.jsonl")

assert expected_md5 == output_md5, "MD5 hashes do not match, something is wrong with nemo_skills/finetuning/prepare_sdp.py"
assert (
expected_md5 == output_md5
), "MD5 hashes do not match, something is wrong with nemo_skills/finetuning/prepare_sdp.py"


def test_openmathinstruct():
download_data("train")
download_data("validation")

subprocess.run([
"python",
"nemo_skills/finetuning/prepare_sdp.py",
"--config-path",
"data_preparation_utils/sdp_configs",
"--config-name",
"openmathinstruct_config.yaml"
], check=True)
subprocess.run(
[
"python",
"nemo_skills/finetuning/prepare_sdp.py",
"--config-path",
"data_preparation_utils/sdp_configs",
"--config-name",
"openmathinstruct_config.yaml",
],
check=True,
)

output_file = 'open-math-instruct-1/train_full_sft.jsonl'

expected_md5 = "c105c1c3369e5cee569dcba74e7d4d61"
output_md5 = compute_md5(output_file)

assert expected_md5 == output_md5, "MD5 hashes do not match, something is wrong with nemo_skills/finetuning/prepare_sdp.py"
assert (
expected_md5 == output_md5
), "MD5 hashes do not match, something is wrong with nemo_skills/finetuning/prepare_sdp.py"

0 comments on commit 7b85475

Please sign in to comment.