Update generation utilities (#172)
* feat(base_trainer): enable sweeping over a single `gen_kwargs` value

* refactor(base_trainer): rename relevant variables

* fix(base_trainer): initialize `gen_sweep_arg` regardless

* feat(base_trainer): change `reward_fn`'s signature to accept kwargs (see the sketch below)

* merge(base_trainer): refactor to reflect main

* feat(*_trainer): add `stop_word`

* refactor(base_trainer): remove `seq2seq` if-case

* refactor(base_trainer): clean up logging of samples

* fix(base_trainer): remove inconsistencies

* fix(ppo_orchestrator): consistent padding and gpu device

* feat(base_trainer): add `rich` as dependency

* chore(examples): update signatures

* fix(ppo_orchestrator): logprob gather indexing

* docs(trlx): update `train`'s signature

* fix(base_trainer): disable `save_best` when training with deepspeed

* merge(base): complete merge

* feat(base_trainer): rework `stop_word` -> `stop_sequences`

* docs(base_trainer): update `decode`'s signature

* chore(base_trainer): `print` -> `print_rank_0`

* feat(base_trainer): clean up table's output

* feat(base_trainer): add number of gpus to the run's name

* style(trlx): satisfy black

* style(wandb): satisfy isort
maxreciprocate authored Jan 13, 2023
1 parent 400dcfd commit 84dd156
Showing 16 changed files with 297 additions and 176 deletions.
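As a quick reference for the interface changes listed above, here is a minimal sketch pieced together from the examples updated in this commit. The base model, prompts, config path, and scoring logic are placeholders rather than part of the diff; the substantive points are that `reward_fn`/`metric_fn` now accept extra keyword arguments (the trainer passes `samples`, `prompts`, and `outputs`), and that `trlx.train` takes a `stop_sequences` argument.

from typing import Dict, List

import trlx
from trlx.data.configs import TRLConfig


def reward_fn(samples: List[str], prompts: List[str], outputs: List[str], **kwargs) -> List[float]:
    # Decoded prompts and outputs now arrive alongside the full samples
    return [float(len(output)) for output in outputs]  # placeholder scoring


def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]:
    return {"length": [float(len(sample)) for sample in samples]}


config = TRLConfig.load_yaml("configs/ppo_config.yml")  # illustrative config path

trlx.train(
    "gpt2",  # placeholder base model
    reward_fn=reward_fn,
    metric_fn=metric_fn,
    prompts=["The movie was"],       # placeholder prompts
    eval_prompts=["The movie was"],
    stop_sequences=["\n"],  # new in this commit; presumably truncates outputs at these strings
    config=config,
)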
2 changes: 1 addition & 1 deletion examples/architext.py
@@ -6,7 +6,7 @@
from trlx.data.configs import TRLConfig


def reward_fn(samples):
def reward_fn(samples, **kwargs):
"Gives a negative count of rooms for each sample"
return [-sample.count(":") for sample in samples]

2 changes: 1 addition & 1 deletion examples/ilql_sentiments.py
@@ -29,7 +29,7 @@ def main(hparams={}):
device=0 if int(os.environ.get("LOCAL_RANK", 0)) == 0 else -1,
)

def metric_fn(samples: List[str]) -> Dict[str, List[float]]:
def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]:
sentiments = list(map(get_positive_score, sentiment_fn(samples)))
return {"sentiments": sentiments}

2 changes: 1 addition & 1 deletion examples/ppo_sentiments.py
@@ -38,7 +38,7 @@ def main(hparams={}):
device=device,
)

def reward_fn(samples: List[str]) -> List[float]:
def reward_fn(samples: List[str], **kwargs) -> List[float]:
sentiments = list(map(get_positive_score, sentiment_fn(samples)))
return sentiments

4 changes: 2 additions & 2 deletions examples/randomwalks/configs/ilql_randomwalks.yml
@@ -44,6 +44,6 @@ method:
two_qs: true
gen_kwargs:
max_new_tokens: 9
top_k: 1
beta: 100
top_k: 10
beta: [0, 1, 100]
temperature: 1.0
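The list value for `beta` above ties into the first item in the commit message (sweeping over a single `gen_kwargs` value). The snippet below only illustrates how such a list could be unrolled into one generation setting per swept value, under the assumption that at most one entry may be a list; it is not the trainer's actual implementation.

gen_kwargs = {"max_new_tokens": 9, "top_k": 10, "beta": [0, 1, 100], "temperature": 1.0}

# Find the (single) swept argument, if any, and expand it into concrete settings
sweep_arg = next((k for k, v in gen_kwargs.items() if isinstance(v, list)), None)
if sweep_arg is None:
    sweep_settings = [gen_kwargs]
else:
    sweep_settings = [{**gen_kwargs, sweep_arg: v} for v in gen_kwargs[sweep_arg]]
# sweep_settings now holds one kwargs dict per swept value: beta=0, beta=1, beta=100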
3 changes: 2 additions & 1 deletion examples/randomwalks/ilql_randomwalks.py
@@ -23,8 +23,9 @@ def main(hparams={}):
GPT2Config(n_layer=6, n_embd=144, vocab_size=23),
dataset=(walks, rewards),
eval_prompts=eval_prompts,
metric_fn=metric_fn,
metric_fn=lambda samples, **kwargs: metric_fn(samples),
config=config,
stop_sequences=["|"],
)


4 changes: 2 additions & 2 deletions examples/randomwalks/ppo_randomwalks.py
@@ -17,10 +17,10 @@ def main(hparams={}):

trlx.train(
"CarperAI/randomwalks",
reward_fn=lambda walks: metric_fn(walks)["optimality"],
reward_fn=lambda samples, **kwargs: metric_fn(samples)["optimality"],
prompts=prompts,
eval_prompts=prompts,
metric_fn=metric_fn,
metric_fn=lambda samples, **kwargs: metric_fn(samples),
config=config,
)

14 changes: 6 additions & 8 deletions examples/summarize_daily_cnn/t5_summarize_daily_cnn.py
@@ -25,16 +25,14 @@

if __name__ == "__main__":

def reward_fn(samples: List[str]):
sep_token = tokenizer.sep_token
articles = [sample.split(sep_token)[0].strip() for sample in samples]
predicted_summaries = [sample.split(sep_token)[1].strip() for sample in samples]
labels = [prompt_label[sample] for sample in articles]
def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]):
original_summaries = [prompt_label[prompt.strip()] for prompt in prompts]
scores = [
meteor.compute(predictions=[summary], references=[label])
for (summary, label) in zip(predicted_summaries, labels)
meteor.compute(predictions=[output.strip()], references=[original])[
"meteor"
]
for (original, output) in zip(original_summaries, outputs)
]
scores = [score["meteor"] for score in scores]
return scores

dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir="data")
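For context on the scoring above: the METEOR metric from the `evaluate` package returns a dict keyed by "meteor", which is why the new code indexes into the result. A standalone sketch, assuming `evaluate` (and its nltk data) is installed:

import evaluate

meteor = evaluate.load("meteor")
result = meteor.compute(
    predictions=["a fast brown fox jumps over a lazy dog"],
    references=["the quick brown fox jumps over the lazy dog"],
)
print(result["meteor"])  # a float between 0 and 1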
1 change: 1 addition & 0 deletions setup.cfg
@@ -19,6 +19,7 @@ install_requires =
torchtyping
transformers>=4.21.2
tqdm
rich
wandb>=0.13.5
ray>=2.0.1
tabulate>=0.9.0
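`rich` is added here because the trainer now renders evaluation samples as console tables (see the "clean up table's output" item in the commit message). The snippet below is not the trainer's actual logging code, only a minimal sketch of the kind of table `rich` produces, with illustrative column names:

from rich.console import Console
from rich.table import Table

table = Table(title="evaluation samples")
table.add_column("prompt")
table.add_column("output")
table.add_column("reward")
table.add_row("The movie was", " completely wonderful", "0.98")
Console().print(table)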
64 changes: 32 additions & 32 deletions trlx/orchestrator/ppo_orchestrator.py
@@ -1,8 +1,8 @@
from time import time
from typing import Callable, Optional

import ray
import torch
import torch.nn.functional as F

from trlx.data.accelerate_base_datatypes import PromptBatch
from trlx.data.ppo_types import PPORLElement
@@ -24,8 +24,6 @@ def __init__(
self,
trainer: BaseRLTrainer,
pipeline: BasePipeline,
reward_fn: Callable,
metric_fn: Optional[Callable] = None,
chunk_size: int = 512,
):
self.pipeline = pipeline
@@ -43,8 +41,6 @@ def __init__(
self.ref_model.to(self.trainer.accelerator.device)

self.trainer.orch = self
self.trainer.reward_fn = reward_fn
self.trainer.metric_fn = metric_fn

self.running = RunningMoments()
self.ref_mean = self.trainer.config.method.ref_mean
@@ -65,9 +61,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
stats = {}
clock = Clock()
while len(ppo_rl_elements) < num_rollouts:
if self.trainer.accelerator.is_main_process:
print(f"Making experience {len(ppo_rl_elements)} / {num_rollouts}")

# Get next batch in prompt dataset and refresh if exhausted
try:
batch: PromptBatch = next(self.pipeline_iterator)
@@ -79,30 +72,38 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
samples = self.trainer.generate(**batch)
stats["time/exp_generate"] = time() - exp_generate_time

if self.trainer.config.model.model_arch_type == "seq2seq":
response_tensors = samples
else:
query_tensors = batch.input_ids
response_tensors = samples[:, query_tensors.shape[1] :]

texts = self.trainer.tokenizer.batch_decode(
samples, skip_special_tokens=True
query_tensors = batch.input_ids
device = samples.device
str_samples, str_prompts, str_outputs = self.trainer.decode(
query_tensors, samples
)

if self.trainer.config.model.model_arch_type == "seq2seq":
articles = self.trainer.tokenizer.batch_decode(
batch.input_ids, skip_special_tokens=True
# Convert trimmed samples back into tensors for another head pass
# This can be deferred by instead letting the pass be made over the original samples
# after the unbinding and truncating operations below are fixed
outputs = self.trainer.tokenizer(str_outputs).input_ids
outputs = list(map(torch.LongTensor, outputs))
maxsize = max(map(len, outputs))
outputs = [
F.pad(
output,
(0, maxsize - len(output)),
value=self.trainer.tokenizer.pad_token_id,
)
sep_token = self.trainer.tokenizer.sep_token
texts = [
f"{article}{sep_token}{response}"
for article, response in zip(articles, texts)
]
for output in outputs
]
response_tensors = torch.vstack(outputs).to(device)

exp_score_time = time()

scores = torch.tensor(
self.score(texts), device=samples.device, dtype=torch.float
)
self.trainer.reward_fn(
samples=str_samples,
prompts=str_prompts,
outputs=str_outputs,
),
dtype=float,
).to(device)
stats["time/exp_score"] = time() - exp_score_time

# store statistics of the initial rollout as reference
@@ -125,9 +126,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq

# Precompute logprobs, values
if self.trainer.config.model.model_arch_type == "seq2seq":
response_tensors = response_tensors
attention_mask = batch.attention_mask.to(response_tensors.device)
query_tensors = batch.input_ids.to(response_tensors.device)
attention_mask = batch.attention_mask.to(device)
query_tensors = batch.input_ids.to(device)
with torch.no_grad():
outputs = self.trainer.model(
input_ids=query_tensors,
@@ -150,12 +150,12 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
).logits
else:
all_tokens = torch.cat(
(query_tensors.to(response_tensors.device), response_tensors), dim=1
(query_tensors.to(device), response_tensors), dim=1
)
attention_mask = (
all_tokens.not_equal(self.trainer.tokenizer.pad_token_id)
.long()
.to(all_tokens.device)
.to(device)
)
with torch.no_grad():
logits, *_, values = self.trainer.model(
Expand All @@ -175,7 +175,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
attention_mask=attention_mask,
return_dict=False,
)
ref_logits = ref_logits.to(self.trainer.accelerator.device)
ref_logits = ref_logits.to(device)

if self.trainer.config.model.model_arch_type == "seq2seq":
logprobs = logprobs_from_logits(
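The new block in this file re-tokenizes the trimmed output strings and right-pads them back into a single batch tensor on the samples' device. A self-contained sketch of that padding pattern, with made-up token ids and a placeholder pad_token_id rather than the trainer's tokenizer:

import torch
import torch.nn.functional as F

pad_token_id = 0  # placeholder; the orchestrator uses the tokenizer's pad token
outputs = [torch.LongTensor(ids) for ids in [[5, 3, 9], [7], [2, 8]]]
maxsize = max(map(len, outputs))
outputs = [
    # (0, n) pads n positions on the right of the last (and only) dimension
    F.pad(output, (0, maxsize - len(output)), value=pad_token_id)
    for output in outputs
]
response_tensors = torch.vstack(outputs)  # shape: (3, maxsize)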
6 changes: 2 additions & 4 deletions trlx/ray_tune/wandb.py
@@ -7,6 +7,8 @@

import wandb

from trlx.utils import significant

import wandb.apis.reports as wb # isort: skip


@@ -39,10 +41,6 @@ def parse_result(result):
return out


def significant(x):
return round(x, 1 - int(math.floor(math.log10(x))))


def log_trials(trial_path: str, project_name: str):
trial_path = Path(trial_path)
files = os.listdir(trial_path)
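The `significant` helper removed here (and now imported from `trlx.utils`) rounds a positive number to two significant figures. A small standalone check of the removed implementation:

import math

def significant(x):
    # as removed above; assumes x > 0
    return round(x, 1 - int(math.floor(math.log10(x))))

print(significant(1234))    # 1200
print(significant(0.0456))  # 0.046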
14 changes: 13 additions & 1 deletion trlx/trainer/__init__.py
@@ -37,10 +37,22 @@ def register_class(cls, name):

@register_trainer
class BaseRLTrainer:
def __init__(self, config: TRLConfig, train_mode=False):
def __init__(
self,
config: TRLConfig,
reward_fn=None,
metric_fn=None,
logit_mask=None,
stop_sequences=None,
train_mode=False,
):
self.store: BaseRolloutStore = None
self.config = config
self.reward_fn = reward_fn
self.metric_fn = metric_fn
self.train_mode = train_mode
self.logit_mask = logit_mask
self.stop_sequences = stop_sequences

def push_to_store(self, data):
self.store.push(data)