feat: add rejection finetuning trainer (#554)
* feat: add rejection finetuning trainer

* style: satisfy flake

* fix(rft_trainer): broadcast scores to all ranks

* feat(rft_trainer): dedup & clip thresholds for quantized rewards

* config(rft_randomwalks): lower `total_steps`, keep 1 improve step

* fix(rft_trainer): handle prompt duplicates due to `drop_last=False`

* feat(examples): add `rft_sentiments` example

* style: satisfy black
maxreciprocate authored Oct 11, 2023
1 parent 78c7faa commit bcd237f
Showing 5 changed files with 359 additions and 1 deletion.
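
For orientation, the rejection finetuning loop this commit implements can be summarized by the sketch below. It is an illustration only: `generate`, `reward_fn`, and `finetune_on` are hypothetical placeholders for the trainer's generation, scoring, and supervised-finetuning machinery, and the defaults mirror `RFTConfig`.

import numpy as np


def rft_loop(prompts, generate, reward_fn, finetune_on,
             n_generations_per_prompt=32, start_percentile=0.7,
             end_percentile=0.95, n_improve_steps=4):
    # Growth step: sample several completions per prompt and score them once.
    pools = {p: [] for p in prompts}
    for p in prompts:
        for _ in range(n_generations_per_prompt):
            out = generate(p)
            pools[p].append((out, reward_fn(p + out)))

    # Improvement steps: raise the per-prompt score threshold linearly and
    # finetune on the deduplicated samples that clear it.
    delta = (end_percentile - start_percentile) / n_improve_steps
    for step in range(n_improve_steps):
        percentile = start_percentile + delta * step
        selected = set()
        for p, pool in pools.items():
            threshold = np.quantile([score for _, score in pool], percentile)
            selected.update((p, out) for out, score in pool if score >= threshold)
        finetune_on(sorted(selected))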
64 changes: 64 additions & 0 deletions examples/randomwalks/rft_randomwalks.py
@@ -0,0 +1,64 @@
import trlx
from examples.randomwalks import generate_random_walks
from trlx.data.default_configs import (
    ModelConfig,
    OptimizerConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
    TRLConfig,
)
from trlx.trainer.accelerate_rft_trainer import RFTConfig

default_config = TRLConfig(
    train=TrainConfig(
        seq_length=10,
        epochs=100,
        total_steps=1000,
        batch_size=100,
        checkpoint_interval=1000,
        eval_interval=100,
        pipeline="PromptPipeline",
        trainer="AccelerateRFTTrainer",
        checkpoint_dir="checkpoints/randomwalks",
    ),
    model=ModelConfig(model_path="CarperAI/randomwalks", num_layers_unfrozen=-1),
    tokenizer=TokenizerConfig(tokenizer_path="CarperAI/randomwalks", truncation_side="right"),
    optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=3.0e-4, betas=(0.9, 0.99), eps=1.0e-8, weight_decay=0)),
    scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=3.0e-4)),
    method=RFTConfig(
        name="RFTConfig",
        n_generations_per_prompt=100,
        start_percentile=0.9,
        end_percentile=0.95,
        n_improve_steps=1,
        gen_kwargs=dict(
            max_new_tokens=9,
            top_k=0,
            top_p=1.0,
            temperature=1.0,
            do_sample=True,
        ),
    ),
)


def main(hparams={}):
    config = TRLConfig.update(default_config, hparams)
    metric_fn, prompts, *_ = generate_random_walks(seed=config.train.seed)

    trlx.train(
        reward_fn=lambda samples, **kwargs: metric_fn(samples)["optimality"],
        prompts=prompts,
        eval_prompts=prompts,
        metric_fn=lambda samples, **kwargs: metric_fn(samples),
        config=config,
    )


if __name__ == "__main__":
    import json
    import sys

    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
    main(hparams)
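
As the `__main__` block above shows, the example accepts a JSON dict of hyperparameter overrides as its first command-line argument. A hypothetical programmatic equivalent is sketched below; the dotted-key override format is an assumption about what `TRLConfig.update` accepts, not something this diff establishes.

# Hypothetical override call (the key format is an assumption, see note above):
main({"train.total_steps": 500, "method.n_generations_per_prompt": 50})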
96 changes: 96 additions & 0 deletions examples/rft_sentiments.py
@@ -0,0 +1,96 @@
# This script trains a model to output positive reviews
# using rejection finetuning with a sentiment classifier reward function.
import json
import os
import sys
from typing import List

import torch
from datasets import load_dataset
from transformers import pipeline

import trlx
from trlx.data.default_configs import (
    ModelConfig,
    OptimizerConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
    TRLConfig,
)
from trlx.trainer.accelerate_rft_trainer import RFTConfig


def get_positive_score(scores):
    "Extract value associated with a positive sentiment from pipeline's output"
    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]


default_config = TRLConfig(
    train=TrainConfig(
        seq_length=1024,
        epochs=100,
        total_steps=1000,
        batch_size=32,
        checkpoint_interval=10000,
        eval_interval=100,
        pipeline="PromptPipeline",
        trainer="AccelerateRFTTrainer",
    ),
    model=ModelConfig(model_path="lvwerra/gpt2-imdb", num_layers_unfrozen=-1),
    tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"),
    optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=3e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)),
    scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=3e-5)),
    method=RFTConfig(
        name="RFTConfig",
        n_generations_per_prompt=4,
        start_percentile=0.9,
        end_percentile=0.95,
        n_improve_steps=1,
        gen_kwargs=dict(
            max_new_tokens=40,
            top_k=0,
            top_p=1.0,
            temperature=1.0,
            do_sample=True,
        ),
    ),
)


def main(hparams={}):
    config = TRLConfig.update(default_config, hparams)

    if torch.cuda.is_available():
        device = int(os.environ.get("LOCAL_RANK", 0))
    else:
        device = -1

    sentiment_fn = pipeline(
        "sentiment-analysis",
        "lvwerra/distilbert-imdb",
        top_k=2,
        truncation=True,
        batch_size=256,
        device=device,
    )

    def reward_fn(samples: List[str], **kwargs) -> List[float]:
        sentiments = list(map(get_positive_score, sentiment_fn(samples)))
        return sentiments

    # Take a few words from movie reviews as prompts
    imdb = load_dataset("imdb", split="train[:512]")
    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

    trlx.train(
        reward_fn=reward_fn,
        prompts=prompts,
        eval_prompts=["I don't know much about Hungarian underground"] * 256,
        config=config,
    )


if __name__ == "__main__":
    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
    main(hparams)
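
For reference, `get_positive_score` above assumes the sentiment pipeline returns both labels per sample when `top_k=2`; the tiny check below illustrates the expected shape (the scores are made up).

# Illustration of get_positive_score on a typical pipeline output (scores are made up):
example_output = [
    {"label": "NEGATIVE", "score": 0.12},
    {"label": "POSITIVE", "score": 0.88},
]
assert get_positive_score(example_output) == 0.88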
197 changes: 197 additions & 0 deletions trlx/trainer/accelerate_rft_trainer.py
@@ -0,0 +1,197 @@
import itertools
from collections import defaultdict
from dataclasses import dataclass

import numpy as np
import torch
import wandb
from tqdm import tqdm
from transformers import AutoModelForCausalLM, PretrainedConfig

from trlx.data.configs import TRLConfig
from trlx.data.method_configs import MethodConfig, register_method
from trlx.pipeline.offline_pipeline import PromptPipeline
from trlx.trainer import register_trainer
from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer


@dataclass
@register_method
class RFTConfig(MethodConfig):
    """
    Config for RFT training
    :param gen_kwargs: kwargs for generation
    :type gen_kwargs: Dict[str, Any]
    :param start_percentile: percentile of each prompt's score distribution used as the score threshold for the first improvement step
    :type start_percentile: float
    :param end_percentile: percentile of each prompt's score distribution used as the score threshold for the final improvement step
    :type end_percentile: float
    :param n_improve_steps: number of improvement steps per growth step, with the score threshold increasing linearly between them
    :type n_improve_steps: int
    :param n_generations_per_prompt: number of generations to sample per prompt on each growth step
    :type n_generations_per_prompt: int
    """

    gen_kwargs: dict
    start_percentile: float = 0.7
    end_percentile: float = 0.95
    n_improve_steps: int = 4
    n_generations_per_prompt: int = 32


@register_trainer
class AccelerateRFTTrainer(AccelerateRLTrainer):
    def __init__(self, config: TRLConfig, **kwargs):
        super().__init__(config, **kwargs)

        self.generate_kwargs = dict(
            config.method.gen_kwargs,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )

        self.generate_experience_kwargs = None

    def get_arch(self, config):
        from_fn = AutoModelForCausalLM.from_pretrained
        if issubclass(type(config.model.model_path), PretrainedConfig):
            from_fn = AutoModelForCausalLM.from_config

        model = from_fn(config.model.model_path)

        if config.model.peft_config is not None:
            # Initialize the peft adapter
            import peft

            peft_config = config.model.peft_config
            if not isinstance(peft_config, peft.PeftConfig):
                if isinstance(peft_config, dict):
                    peft_config = peft.get_peft_config(peft_config)
                else:
                    raise ValueError("`peft_config` should be an instance of `peft.PeftConfig` or a dict.")
            model = peft.get_peft_model(model, peft_config)
            if self.accelerator.is_main_process:
                model.print_trainable_parameters()

        return model

    def loss(self, batch):
        labels = batch.input_ids.clone()
        loss = self.model(input_ids=batch.input_ids, attention_mask=batch.attention_mask, labels=labels).loss
        stats = {"loss": loss.item()}

        return loss, stats

    def create_train_dataloader(self):
        return self.accelerator.prepare(self.store.create_loader(self.config.train.batch_size))

    def prepare_learning(self):
        self.epoch_count = 0
        self.iter_count = 0
        self.n_inner_epochs = 1
        # because the number of selected samples varies between improvement steps,
        # the total step count cannot be estimated, so it is copied from the config
        self.total_steps = self.config.train.total_steps

        self.generations_per_prompt = defaultdict(list)

        eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size)
        self.model, self.opt, self.eval_dataloader = self.accelerator.prepare(self.model, self.opt, eval_dataloader)

        self.make_experience()

    def add_prompt_pipeline(self, pipeline: PromptPipeline):
        """Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage"""
        prompt_dataloader = pipeline.create_loader(self.config.train.batch_size)
        self.prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader)

    def post_epoch_callback(self):
        self.make_experience()
        self.epoch_count += 1

    def make_experience(self):  # noqa:
        if self.epoch_count % self.config.method.n_improve_steps == 0:
            # generate `n_generations_per_prompt` samples for each prompt in the prompt_dataloader
            generations = []
            for batch in tqdm(self.prompt_dataloader, desc="Generating", disable=not self.accelerator.is_main_process):
                for _ in range(self.config.method.n_generations_per_prompt):
                    samples = self.generate(**batch)
                    str_samples, str_prompts, str_outputs = self.decode(batch.input_ids, samples, append_eos_token=True)
                    generations.extend({"prompt": p, "output": o} for p, o in zip(str_prompts, str_outputs))

            if torch.distributed.is_initialized():
                all_generations = [None for _ in range(torch.distributed.get_world_size())]
                torch.distributed.all_gather_object(all_generations, generations)
                generations = list(itertools.chain(*all_generations))

            # score the generations
            if self.accelerator.is_main_process:
                all_scores = self.reward_fn(
                    samples=[x["prompt"] + x["output"] for x in generations],
                    prompts=[x["prompt"] for x in generations],
                    outputs=[x["output"] for x in generations],
                )

                all_scores = torch.tensor(all_scores, device=self.accelerator.device)
            else:
                all_scores = torch.zeros(len(generations), device=self.accelerator.device)
            if torch.distributed.is_initialized():
                torch.distributed.broadcast(all_scores, src=0)
                scores = all_scores
            else:
                scores = all_scores

            for g, s in zip(generations, scores):
                self.generations_per_prompt[g["prompt"]].append({"output": g["output"], "score": s.item()})

        scores = [[x["score"] for x in self.generations_per_prompt[p]] for p in self.generations_per_prompt]

        percentile_delta = (
            self.config.method.end_percentile - self.config.method.start_percentile
        ) / self.config.method.n_improve_steps
        percentile = self.config.method.start_percentile + percentile_delta * (
            self.epoch_count % self.config.method.n_improve_steps
        )
        thresholds = np.array([np.quantile(np.array(scores), percentile) for scores in scores])
        # corner case for quantized rewards: don't include the min values, but don't exclude the max values
        thresholds = np.clip(thresholds, thresholds.min() + 1e-3, thresholds.max() - 1e-3)

        # filter out the generations with a score below the percentile per prompt
        samples_selected = []
        for prompt, threshold in zip(self.generations_per_prompt, thresholds):
            for x in self.generations_per_prompt[prompt]:
                if x["score"] >= threshold:
                    samples_selected.append([prompt, x["output"]])

        # deduplicate the samples
        samples_selected = list({tuple(x) for x in samples_selected})

        self.accelerator.log(
            {
                "scores_per_single_prompt": wandb.Histogram(scores[0]),
                "thresholds": wandb.Histogram(thresholds),
                "scores_mean": np.mean(np.hstack(scores)),
                "scores_dist": wandb.Histogram(np.hstack(scores)),
                "len_samples_selected": len(samples_selected),
                "samples_per_single_prompt": wandb.Table(
                    data=list(
                        zip(
                            [x[0] for x in samples_selected[:128]],
                            [x[1] for x in samples_selected[:128]],
                        )
                    ),
                    columns=["prompt", "output"],
                ),
            },
            step=self.iter_count,
        )

        if len(samples_selected):
            self.store = PromptPipeline(
                samples_selected, max_prompt_length=2048, tokenizer=self.tokenizer, add_special_tokens=True
            )
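
To make the threshold schedule in `make_experience` concrete, the standalone re-computation below traces how the per-prompt percentile advances across improvement steps using the `RFTConfig` defaults; it only reproduces the arithmetic above, not the trainer code.

# Percentile schedule with start_percentile=0.7, end_percentile=0.95, n_improve_steps=4
start_percentile, end_percentile, n_improve_steps = 0.7, 0.95, 4
percentile_delta = (end_percentile - start_percentile) / n_improve_steps
for epoch_count in range(n_improve_steps):
    percentile = start_percentile + percentile_delta * (epoch_count % n_improve_steps)
    print(epoch_count, round(percentile, 4))
# prints: 0 0.7, 1 0.7625, 2 0.825, 3 0.8875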
2 changes: 1 addition & 1 deletion trlx/trlx.py
@@ -87,7 +87,7 @@ def train(  # noqa: C901
     batch_size = config.train.batch_size * int(os.environ.get("WORLD_SIZE", 1))
     max_prompt_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"]
 
-    # Online training against a reward function (e.g. PPO)
+    # Online training against a reward function (e.g. PPO, RFT)
     if reward_fn:
         prompts = prompts or [trainer.tokenizer.bos_token] * batch_size

1 change: 1 addition & 0 deletions trlx/utils/loading.py
@@ -8,6 +8,7 @@
 from trlx.trainer import _TRAINERS, register_trainer
 from trlx.trainer.accelerate_ilql_trainer import AccelerateILQLTrainer
 from trlx.trainer.accelerate_ppo_trainer import AcceleratePPOTrainer
+from trlx.trainer.accelerate_rft_trainer import AccelerateRFTTrainer
 from trlx.trainer.accelerate_sft_trainer import AccelerateSFTTrainer
 
 try:
