-
Notifications
You must be signed in to change notification settings - Fork 470
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add rejection finetuning trainer (#554)
* feat: add rejection finetuning trainer * style: satisfy flake * fix(rft_trainer): broadcast scores to all ranks * feat(rft_trainer): dedup & clip thresholds for quantized rewards * config(rft_randomwalks): lower `total_steps`, keep 1 improve step * fix(rft_trainer): handle prompt duplicates, due to `drop_last=False` * feat(examples): add `rft_sentiments` example * style: satisfy black
- Loading branch information
1 parent
78c7faa
commit bcd237f
Showing
5 changed files
with
359 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import trlx | ||
from examples.randomwalks import generate_random_walks | ||
from trlx.data.default_configs import ( | ||
ModelConfig, | ||
OptimizerConfig, | ||
SchedulerConfig, | ||
TokenizerConfig, | ||
TrainConfig, | ||
TRLConfig, | ||
) | ||
from trlx.trainer.accelerate_rft_trainer import RFTConfig | ||
|
||
default_config = TRLConfig( | ||
train=TrainConfig( | ||
seq_length=10, | ||
epochs=100, | ||
total_steps=1000, | ||
batch_size=100, | ||
checkpoint_interval=1000, | ||
eval_interval=100, | ||
pipeline="PromptPipeline", | ||
trainer="AccelerateRFTTrainer", | ||
checkpoint_dir="checkpoints/randomwalks", | ||
), | ||
model=ModelConfig(model_path="CarperAI/randomwalks", num_layers_unfrozen=-1), | ||
tokenizer=TokenizerConfig(tokenizer_path="CarperAI/randomwalks", truncation_side="right"), | ||
optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=3.0e-4, betas=(0.9, 0.99), eps=1.0e-8, weight_decay=0)), | ||
scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=3.0e-4)), | ||
method=RFTConfig( | ||
name="RFTConfig", | ||
n_generations_per_prompt=100, | ||
start_percentile=0.9, | ||
end_percentile=0.95, | ||
n_improve_steps=1, | ||
gen_kwargs=dict( | ||
max_new_tokens=9, | ||
top_k=0, | ||
top_p=1.0, | ||
temperature=1.0, | ||
do_sample=True, | ||
), | ||
), | ||
) | ||
|
||
|
||
def main(hparams={}): | ||
config = TRLConfig.update(default_config, hparams) | ||
metric_fn, prompts, *_ = generate_random_walks(seed=config.train.seed) | ||
|
||
trlx.train( | ||
reward_fn=lambda samples, **kwargs: metric_fn(samples)["optimality"], | ||
prompts=prompts, | ||
eval_prompts=prompts, | ||
metric_fn=lambda samples, **kwargs: metric_fn(samples), | ||
config=config, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
import json | ||
import sys | ||
|
||
hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) | ||
main(hparams) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# This script trains a model to output positive reviews | ||
# using rejection finetuning with a sentiment classifier reward function. | ||
import json | ||
import os | ||
import sys | ||
from typing import List | ||
|
||
import torch | ||
from datasets import load_dataset | ||
from transformers import pipeline | ||
|
||
import trlx | ||
from trlx.data.default_configs import ( | ||
ModelConfig, | ||
OptimizerConfig, | ||
SchedulerConfig, | ||
TokenizerConfig, | ||
TrainConfig, | ||
TRLConfig, | ||
) | ||
from trlx.trainer.accelerate_rft_trainer import RFTConfig | ||
|
||
|
||
def get_positive_score(scores): | ||
"Extract value associated with a positive sentiment from pipeline's output" | ||
return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] | ||
|
||
|
||
default_config = TRLConfig( | ||
train=TrainConfig( | ||
seq_length=1024, | ||
epochs=100, | ||
total_steps=1000, | ||
batch_size=32, | ||
checkpoint_interval=10000, | ||
eval_interval=100, | ||
pipeline="PromptPipeline", | ||
trainer="AccelerateRFTTrainer", | ||
), | ||
model=ModelConfig(model_path="lvwerra/gpt2-imdb", num_layers_unfrozen=-1), | ||
tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), | ||
optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=3e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)), | ||
scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=3e-5)), | ||
method=RFTConfig( | ||
name="RFTConfig", | ||
n_generations_per_prompt=4, | ||
start_percentile=0.9, | ||
end_percentile=0.95, | ||
n_improve_steps=1, | ||
gen_kwargs=dict( | ||
max_new_tokens=40, | ||
top_k=0, | ||
top_p=1.0, | ||
temperature=1.0, | ||
do_sample=True, | ||
), | ||
), | ||
) | ||
|
||
|
||
def main(hparams={}): | ||
config = TRLConfig.update(default_config, hparams) | ||
|
||
if torch.cuda.is_available(): | ||
device = int(os.environ.get("LOCAL_RANK", 0)) | ||
else: | ||
device = -1 | ||
|
||
sentiment_fn = pipeline( | ||
"sentiment-analysis", | ||
"lvwerra/distilbert-imdb", | ||
top_k=2, | ||
truncation=True, | ||
batch_size=256, | ||
device=device, | ||
) | ||
|
||
def reward_fn(samples: List[str], **kwargs) -> List[float]: | ||
sentiments = list(map(get_positive_score, sentiment_fn(samples))) | ||
return sentiments | ||
|
||
# Take few words off of movies reviews as prompts | ||
imdb = load_dataset("imdb", split="train[:512]") | ||
prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] | ||
|
||
trlx.train( | ||
reward_fn=reward_fn, | ||
prompts=prompts, | ||
eval_prompts=["I don't know much about Hungarian underground"] * 256, | ||
config=config, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) | ||
main(hparams) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
import itertools | ||
from collections import defaultdict | ||
from dataclasses import dataclass | ||
|
||
import numpy as np | ||
import torch | ||
import wandb | ||
from tqdm import tqdm | ||
from transformers import AutoModelForCausalLM, PretrainedConfig | ||
|
||
from trlx.data.configs import TRLConfig | ||
from trlx.data.method_configs import MethodConfig, register_method | ||
from trlx.pipeline.offline_pipeline import PromptPipeline | ||
from trlx.trainer import register_trainer | ||
from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer | ||
|
||
|
||
@dataclass | ||
@register_method | ||
class RFTConfig(MethodConfig): | ||
""" | ||
Config for RFT training | ||
:param gen_kwargs: kwargs for generation | ||
:type gen_kwargs: Dict[str, Any] | ||
:param start_percentile: percentile for the starting score threshold for each prompt used for the first improvement step | ||
:type start_percentile: float | ||
:param end_percentile: percentile for the final score threshold for each prompt | ||
:type end_percentile: float | ||
:param n_improve_steps: the number of improvement steps for each growth step with linearly increasing score threshold | ||
:type n_improve_steps: int | ||
:param n_generations_per_prompt: number of generations to sample per each prompt per each growth step | ||
:type n_generations_per_prompt: int | ||
""" | ||
|
||
gen_kwargs: dict | ||
start_percentile: float = 0.7 | ||
end_percentile: float = 0.95 | ||
n_improve_steps: int = 4 | ||
n_generations_per_prompt: int = 32 | ||
|
||
|
||
@register_trainer | ||
class AccelerateRFTTrainer(AccelerateRLTrainer): | ||
def __init__(self, config: TRLConfig, **kwargs): | ||
super().__init__(config, **kwargs) | ||
|
||
self.generate_kwargs = dict( | ||
config.method.gen_kwargs, | ||
eos_token_id=self.tokenizer.eos_token_id, | ||
pad_token_id=self.tokenizer.pad_token_id, | ||
) | ||
|
||
self.generate_experience_kwargs = None | ||
|
||
def get_arch(self, config): | ||
from_fn = AutoModelForCausalLM.from_pretrained | ||
if issubclass(type(config.model.model_path), PretrainedConfig): | ||
from_fn = AutoModelForCausalLM.from_config | ||
|
||
model = from_fn(config.model.model_path) | ||
|
||
if config.model.peft_config is not None: | ||
# Initialize the peft adapter | ||
import peft | ||
|
||
peft_config = config.model.peft_config | ||
if not isinstance(peft_config, peft.PeftConfig): | ||
if isinstance(peft_config, dict): | ||
peft_config = peft.get_peft_config(peft_config) | ||
else: | ||
raise ValueError("`peft_config` should be an instance of `peft.PeftConfig` or a dict.") | ||
model = peft.get_peft_model(model, peft_config) | ||
if self.accelerator.is_main_process: | ||
model.print_trainable_parameters() | ||
|
||
return model | ||
|
||
def loss(self, batch): | ||
labels = batch.input_ids.clone() | ||
loss = self.model(input_ids=batch.input_ids, attention_mask=batch.attention_mask, labels=labels).loss | ||
stats = {"loss": loss.item()} | ||
|
||
return loss, stats | ||
|
||
def create_train_dataloader(self): | ||
return self.accelerator.prepare(self.store.create_loader(self.config.train.batch_size)) | ||
|
||
def prepare_learning(self): | ||
self.epoch_count = 0 | ||
self.iter_count = 0 | ||
self.n_inner_epochs = 1 | ||
# because of variable number of samples per each improvement steps | ||
# there is no way to get the estimate, so here it's just copied from the config | ||
self.total_steps = self.config.train.total_steps | ||
|
||
self.generations_per_prompt = defaultdict(list) | ||
|
||
eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size) | ||
self.model, self.opt, self.eval_dataloader = self.accelerator.prepare(self.model, self.opt, eval_dataloader) | ||
|
||
self.make_experience() | ||
|
||
def add_prompt_pipeline(self, pipeline: PromptPipeline): | ||
"""Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage""" | ||
prompt_dataloader = pipeline.create_loader(self.config.train.batch_size) | ||
self.prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader) | ||
|
||
def post_epoch_callback(self): | ||
self.make_experience() | ||
self.epoch_count += 1 | ||
|
||
def make_experience(self): # noqa: | ||
if self.epoch_count % self.config.method.n_improve_steps == 0: | ||
# generate n samples for each prompt in the prompt_dataloader | ||
generations = [] | ||
for batch in tqdm(self.prompt_dataloader, desc="Generating", disable=not self.accelerator.is_main_process): | ||
for _ in range(self.config.method.n_generations_per_prompt): | ||
samples = self.generate(**batch) | ||
str_samples, str_prompts, str_outputs = self.decode(batch.input_ids, samples, append_eos_token=True) | ||
generations.extend({"prompt": p, "output": o} for p, o in zip(str_prompts, str_outputs)) | ||
|
||
if torch.distributed.is_initialized(): | ||
all_generations = [None for _ in range(torch.distributed.get_world_size())] | ||
torch.distributed.all_gather_object(all_generations, generations) | ||
generations = list(itertools.chain(*all_generations)) | ||
|
||
# score the generations | ||
if self.accelerator.is_main_process: | ||
all_scores = self.reward_fn( | ||
samples=[x["prompt"] + x["output"] for x in generations], | ||
prompts=[x["prompt"] for x in generations], | ||
outputs=[x["output"] for x in generations], | ||
) | ||
|
||
all_scores = torch.tensor(all_scores, device=self.accelerator.device) | ||
else: | ||
all_scores = torch.zeros(len(generations), device=self.accelerator.device) | ||
if torch.distributed.is_initialized(): | ||
torch.distributed.broadcast(all_scores, src=0) | ||
scores = all_scores | ||
else: | ||
scores = all_scores | ||
|
||
for g, s in zip(generations, scores): | ||
self.generations_per_prompt[g["prompt"]].append({"output": g["output"], "score": s.item()}) | ||
|
||
scores = [[x["score"] for x in self.generations_per_prompt[p]] for p in self.generations_per_prompt] | ||
|
||
percentile_delta = ( | ||
self.config.method.end_percentile - self.config.method.start_percentile | ||
) / self.config.method.n_improve_steps | ||
percentile = self.config.method.start_percentile + percentile_delta * ( | ||
self.epoch_count % self.config.method.n_improve_steps | ||
) | ||
thresholds = np.array([np.quantile(np.array(scores), percentile) for scores in scores]) | ||
# corner case for quantized rewards: don't include the min values, but don't exclude the max values | ||
thresholds = np.clip(thresholds, thresholds.min() + 1e-3, thresholds.max() - 1e-3) | ||
|
||
# filter out the generations with a score below the percentile per prompt | ||
samples_selected = [] | ||
for prompt, threshold in zip(self.generations_per_prompt, thresholds): | ||
for x in self.generations_per_prompt[prompt]: | ||
if x["score"] >= threshold: | ||
samples_selected.append([prompt, x["output"]]) | ||
|
||
# deduplicate the samples | ||
samples_selected = list({tuple(x) for x in samples_selected}) | ||
|
||
self.accelerator.log( | ||
{ | ||
"scores_per_single_prompt": wandb.Histogram(scores[0]), | ||
"thresholds": wandb.Histogram(thresholds), | ||
"scores_mean": np.mean(np.hstack(scores)), | ||
"scores_dist": wandb.Histogram(np.hstack(scores)), | ||
"len_samples_selected": len(samples_selected), | ||
"samples_per_single_prompt": wandb.Table( | ||
data=list( | ||
zip( | ||
[x[0] for x in samples_selected[:128]], | ||
[x[1] for x in samples_selected[:128]], | ||
) | ||
), | ||
columns=["prompt", "output"], | ||
), | ||
}, | ||
step=self.iter_count, | ||
) | ||
|
||
if len(samples_selected): | ||
self.store = PromptPipeline( | ||
samples_selected, max_prompt_length=2048, tokenizer=self.tokenizer, add_special_tokens=True | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters