Skip to content

Commit

Permalink
Fixing test_train
Browse files Browse the repository at this point in the history
  • Loading branch information
shtoshni committed Nov 27, 2024
1 parent 298fb9f commit 4314007
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions tests/gpu-tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import yaml

sys.path.append(str(Path(__file__).absolute().parents[1]))
from nemo_skills.evaluation.metrics import compute_metrics
from nemo_skills.evaluation.metrics import ComputeMetrics
from nemo_skills.pipeline import wrap_arguments
from nemo_skills.pipeline.cli import eval, generate, train

Expand Down Expand Up @@ -112,10 +112,9 @@ def test_sft():
partition="interactive",
)

metrics = compute_metrics(
metrics = ComputeMetrics(benchmark='gsm8k').compute_metrics(
[f"/tmp/nemo-skills-tests/{model_type}/test-sft/evaluation/eval-results/gsm8k/output.jsonl"],
importlib.import_module('nemo_skills.dataset.gsm8k').METRICS_CLASS(),
)
)["greedy"]
# only checking the total, since model is tiny
assert metrics['num_entries'] == 10

Expand Down Expand Up @@ -173,10 +172,9 @@ def test_dpo():
partition="interactive",
)

metrics = compute_metrics(
metrics = ComputeMetrics(benchmark='gsm8k').compute_metrics(
[f"/tmp/nemo-skills-tests/{model_type}/test-dpo/evaluation/eval-results/gsm8k/output.jsonl"],
importlib.import_module('nemo_skills.dataset.gsm8k').METRICS_CLASS(),
)
)["greedy"]
# only checking the total, since model is tiny
assert metrics['num_entries'] == 10

Expand Down

0 comments on commit 4314007

Please sign in to comment.