From 6859e048da601fec181997a324e7b351fc997a33 Mon Sep 17 00:00:00 2001 From: lewtun Date: Mon, 23 Sep 2024 11:49:36 +0200 Subject: [PATCH] Fix PPO/RLOO examples (#2100) --- examples/scripts/ppo/ppo.py | 3 +-- examples/scripts/rloo/rloo.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index 541af12b6c..66e8272adf 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -50,7 +50,6 @@ --sft_model_path EleutherAI/pythia-1b-deduped \ --reward_model_path EleutherAI/pythia-1b-deduped \ --local_rollout_forward_batch_size 1 \ - --deepspeed3 \ --missing_eos_penalty 1.0 """ @@ -88,7 +87,7 @@ # Dataset ################ dataset = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness") - eval_samples = 20 + eval_samples = 100 train_dataset = dataset.select(range(len(dataset) - eval_samples)) eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset))) dataset_text_field = "prompt" diff --git a/examples/scripts/rloo/rloo.py b/examples/scripts/rloo/rloo.py index 74a52fe69d..cf5f028502 100644 --- a/examples/scripts/rloo/rloo.py +++ b/examples/scripts/rloo/rloo.py @@ -54,7 +54,6 @@ --sft_model_path EleutherAI/pythia-1b-deduped \ --reward_model_path EleutherAI/pythia-1b-deduped \ --local_rollout_forward_batch_size 1 \ - --deepspeed3 \ --missing_eos_penalty 1.0 """ @@ -89,7 +88,7 @@ # Dataset ################ dataset = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness") - eval_samples = 20 + eval_samples = 100 train_dataset = dataset.select(range(len(dataset) - eval_samples)) eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset))) dataset_text_field = "prompt"