diff --git a/examples/architext.py b/examples/architext.py
index 592de2a2a..ca704be03 100644
--- a/examples/architext.py
+++ b/examples/architext.py
@@ -6,7 +6,7 @@ from trlx.data.configs import TRLConfig
 
 
-def reward_fn(samples):
+def reward_fn(samples, **kwargs):
     "Gives a negative count of rooms for each sample"
     return [-sample.count(":") for sample in samples]
 
 
diff --git a/examples/ilql_sentiments.py b/examples/ilql_sentiments.py
index 077bfe650..11b52d9f9 100644
--- a/examples/ilql_sentiments.py
+++ b/examples/ilql_sentiments.py
@@ -29,7 +29,7 @@ def main(hparams={}):
         device=0 if int(os.environ.get("LOCAL_RANK", 0)) == 0 else -1,
     )
 
-    def metric_fn(samples: List[str]) -> Dict[str, List[float]]:
+    def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]:
         sentiments = list(map(get_positive_score, sentiment_fn(samples)))
         return {"sentiments": sentiments}
 
diff --git a/examples/ppo_sentiments.py b/examples/ppo_sentiments.py
index 32ee98ff8..8b2504cfb 100644
--- a/examples/ppo_sentiments.py
+++ b/examples/ppo_sentiments.py
@@ -38,7 +38,7 @@ def main(hparams={}):
         device=device,
     )
 
-    def reward_fn(samples: List[str]) -> List[float]:
+    def reward_fn(samples: List[str], **kwargs) -> List[float]:
         sentiments = list(map(get_positive_score, sentiment_fn(samples)))
         return sentiments
 
diff --git a/examples/randomwalks/configs/ilql_randomwalks.yml b/examples/randomwalks/configs/ilql_randomwalks.yml
index e70caa02c..0642fcbe5 100644
--- a/examples/randomwalks/configs/ilql_randomwalks.yml
+++ b/examples/randomwalks/configs/ilql_randomwalks.yml
@@ -44,6 +44,6 @@ method:
   two_qs: true
   gen_kwargs:
     max_new_tokens: 9
-    top_k: 1
-    beta: 100
+    top_k: 10
+    beta: [0, 1, 100]
     temperature: 1.0
diff --git a/examples/randomwalks/ilql_randomwalks.py b/examples/randomwalks/ilql_randomwalks.py
index 67e880084..0aed97b67 100644
--- a/examples/randomwalks/ilql_randomwalks.py
+++ b/examples/randomwalks/ilql_randomwalks.py
@@ -23,8 +23,9 @@ def main(hparams={}):
         GPT2Config(n_layer=6, n_embd=144, vocab_size=23),
         dataset=(walks, rewards),
         eval_prompts=eval_prompts,
-        metric_fn=metric_fn,
+        metric_fn=lambda samples, **kwargs: metric_fn(samples),
         config=config,
+        stop_sequences=["|"],
     )
 
 
diff --git a/examples/randomwalks/ppo_randomwalks.py b/examples/randomwalks/ppo_randomwalks.py
index 62a38eea2..dbae71044 100644
--- a/examples/randomwalks/ppo_randomwalks.py
+++ b/examples/randomwalks/ppo_randomwalks.py
@@ -17,10 +17,10 @@ def main(hparams={}):
 
     trlx.train(
         "CarperAI/randomwalks",
-        reward_fn=lambda walks: metric_fn(walks)["optimality"],
+        reward_fn=lambda samples, **kwargs: metric_fn(samples)["optimality"],
         prompts=prompts,
         eval_prompts=prompts,
-        metric_fn=metric_fn,
+        metric_fn=lambda samples, **kwargs: metric_fn(samples),
         config=config,
     )
 
diff --git a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py
index 7ae25e60e..5b2621184 100755
--- a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py
+++ b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py
@@ -25,16 +25,14 @@ if __name__ == "__main__":
 
-    def reward_fn(samples: List[str]):
-        sep_token = tokenizer.sep_token
-        articles = [sample.split(sep_token)[0].strip() for sample in samples]
-        predicted_summaries = [sample.split(sep_token)[1].strip() for sample in samples]
-        labels = [prompt_label[sample] for sample in articles]
+    def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]):
+        original_summaries = [prompt_label[prompt.strip()] for prompt in prompts]
         scores = [
-            meteor.compute(predictions=[summary], references=[label])
-            for (summary, label) in zip(predicted_summaries, labels)
+            meteor.compute(predictions=[output.strip()], references=[original])[
+                "meteor"
+            ]
+            for (original, output) in zip(original_summaries, outputs)
         ]
-        scores = [score["meteor"] for score in scores]
         return scores
 
     dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir="data")
diff --git a/setup.cfg b/setup.cfg
index 3bc55caac..4a54f7747 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -19,6 +19,7 @@ install_requires =
     torchtyping
     transformers>=4.21.2
    tqdm
+    rich
     wandb>=0.13.5
     ray>=2.0.1
     tabulate>=0.9.0
diff --git a/trlx/orchestrator/ppo_orchestrator.py b/trlx/orchestrator/ppo_orchestrator.py
index 5729a89f6..a9f6ec5e4 100644
--- a/trlx/orchestrator/ppo_orchestrator.py
+++ b/trlx/orchestrator/ppo_orchestrator.py
@@ -1,8 +1,8 @@
 from time import time
-from typing import Callable, Optional
 
 import ray
 import torch
+import torch.nn.functional as F
 
 from trlx.data.accelerate_base_datatypes import PromptBatch
 from trlx.data.ppo_types import PPORLElement
@@ -24,8 +24,6 @@ def __init__(
         self,
         trainer: BaseRLTrainer,
         pipeline: BasePipeline,
-        reward_fn: Callable,
-        metric_fn: Optional[Callable] = None,
         chunk_size: int = 512,
     ):
         self.pipeline = pipeline
@@ -43,8 +41,6 @@ def __init__(
         self.ref_model.to(self.trainer.accelerator.device)
 
         self.trainer.orch = self
-        self.trainer.reward_fn = reward_fn
-        self.trainer.metric_fn = metric_fn
 
         self.running = RunningMoments()
         self.ref_mean = self.trainer.config.method.ref_mean
@@ -65,9 +61,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
         stats = {}
         clock = Clock()
         while len(ppo_rl_elements) < num_rollouts:
-            if self.trainer.accelerator.is_main_process:
-                print(f"Making experience {len(ppo_rl_elements)} / {num_rollouts}")
-
             # Get next batch in prompt dataset and refresh if exhausted
             try:
                 batch: PromptBatch = next(self.pipeline_iterator)
@@ -79,30 +72,38 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
             samples = self.trainer.generate(**batch)
             stats["time/exp_generate"] = time() - exp_generate_time
 
-            if self.trainer.config.model.model_arch_type == "seq2seq":
-                response_tensors = samples
-            else:
-                query_tensors = batch.input_ids
-                response_tensors = samples[:, query_tensors.shape[1] :]
-
-            texts = self.trainer.tokenizer.batch_decode(
-                samples, skip_special_tokens=True
+            query_tensors = batch.input_ids
+            device = samples.device
+            str_samples, str_prompts, str_outputs = self.trainer.decode(
+                query_tensors, samples
             )
 
-            if self.trainer.config.model.model_arch_type == "seq2seq":
-                articles = self.trainer.tokenizer.batch_decode(
-                    batch.input_ids, skip_special_tokens=True
+            # Convert trimmed samples back into tensors for another head pass
+            # This can be deferred by instead making the pass over the original samples,
+            # once the unbinding and truncating operations below are fixed
+            outputs = self.trainer.tokenizer(str_outputs).input_ids
+            outputs = list(map(torch.LongTensor, outputs))
+            maxsize = max(map(len, outputs))
+            outputs = [
+                F.pad(
+                    output,
+                    (0, maxsize - len(output)),
+                    value=self.trainer.tokenizer.pad_token_id,
                 )
-                sep_token = self.trainer.tokenizer.sep_token
-                texts = [
-                    f"{article}{sep_token}{response}"
-                    for article, response in zip(articles, texts)
-                ]
+                for output in outputs
+            ]
+            response_tensors = torch.vstack(outputs).to(device)
 
             exp_score_time = time()
+
             scores = torch.tensor(
-                self.score(texts), device=samples.device, dtype=torch.float
-            )
+                self.trainer.reward_fn(
+                    samples=str_samples,
+                    prompts=str_prompts,
+                    outputs=str_outputs,
+                ),
+                dtype=float,
+            ).to(device)
             stats["time/exp_score"] = time() - exp_score_time
 
             # store statistics of the initial rollout as reference
@@ -125,9 +126,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
 
             # Precompute logprobs, values
             if self.trainer.config.model.model_arch_type == "seq2seq":
-                response_tensors = response_tensors
-                attention_mask = batch.attention_mask.to(response_tensors.device)
-                query_tensors = batch.input_ids.to(response_tensors.device)
+                attention_mask = batch.attention_mask.to(device)
+                query_tensors = batch.input_ids.to(device)
                 with torch.no_grad():
                     outputs = self.trainer.model(
                         input_ids=query_tensors,
@@ -150,12 +150,12 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
                     ).logits
             else:
                 all_tokens = torch.cat(
-                    (query_tensors.to(response_tensors.device), response_tensors), dim=1
+                    (query_tensors.to(device), response_tensors), dim=1
                 )
                 attention_mask = (
                     all_tokens.not_equal(self.trainer.tokenizer.pad_token_id)
                     .long()
-                    .to(all_tokens.device)
+                    .to(device)
                 )
                 with torch.no_grad():
                     logits, *_, values = self.trainer.model(
@@ -175,7 +175,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
                         attention_mask=attention_mask,
                         return_dict=False,
                     )
-                ref_logits = ref_logits.to(self.trainer.accelerator.device)
+                ref_logits = ref_logits.to(device)
 
             if self.trainer.config.model.model_arch_type == "seq2seq":
                 logprobs = logprobs_from_logits(
diff --git a/trlx/ray_tune/wandb.py b/trlx/ray_tune/wandb.py
index fc4a69203..a97d940fc 100644
--- a/trlx/ray_tune/wandb.py
+++ b/trlx/ray_tune/wandb.py
@@ -7,6 +7,8 @@
 
 import wandb
 
+from trlx.utils import significant
+
 import wandb.apis.reports as wb  # isort: skip
 
 
@@ -39,10 +41,6 @@ def parse_result(result):
     return out
 
 
-def significant(x):
-    return round(x, 1 - int(math.floor(math.log10(x))))
-
-
 def log_trials(trial_path: str, project_name: str):
     trial_path = Path(trial_path)
     files = os.listdir(trial_path)
diff --git a/trlx/trainer/__init__.py b/trlx/trainer/__init__.py
index 3d454d064..2bec6bad1 100644
--- a/trlx/trainer/__init__.py
+++ b/trlx/trainer/__init__.py
@@ -37,10 +37,22 @@ def register_class(cls, name):
 
 @register_trainer
 class BaseRLTrainer:
-    def __init__(self, config: TRLConfig, train_mode=False):
+    def __init__(
+        self,
+        config: TRLConfig,
+        reward_fn=None,
+        metric_fn=None,
+        logit_mask=None,
+        stop_sequences=None,
+        train_mode=False,
+    ):
         self.store: BaseRolloutStore = None
         self.config = config
+        self.reward_fn = reward_fn
+        self.metric_fn = metric_fn
         self.train_mode = train_mode
+        self.logit_mask = logit_mask
+        self.stop_sequences = stop_sequences
 
     def push_to_store(self, data):
         self.store.push(data)
diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index 15c804ae6..87ae4dd78 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -1,24 +1,20 @@
-import importlib
 import json
 import os
 import sys
 from abc import abstractmethod
 from time import time
-from typing import Dict, Optional, Sequence, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Tuple, Union
 
+import ray
 import torch
 import torch.nn.functional as F
 from accelerate import Accelerator  # type: ignore
-from transformers import AutoTokenizer
-
-if importlib.util.find_spec("rich") is not None:
-    from tqdm.rich import tqdm
-else:
-    from tqdm import tqdm
-
-import ray
 from ray.air import session
 from ray.air.checkpoint import Checkpoint
+from rich.console import Console
+from rich.table import Table
+from tqdm import tqdm
+from transformers import AutoTokenizer
 
 from trlx.data.configs import TRLConfig
 from trlx.trainer import BaseRLTrainer, register_trainer
@@ -28,6 +24,8 @@
     get_git_tag,
     get_optimizer_class,
     get_scheduler_class,
+    print_rank_0,
+    significant,
 )
 from trlx.utils.modeling import (
     freeze_bottom_causal_layers,
@@ -43,8 +41,8 @@ class AccelerateRLTrainer(BaseRLTrainer):
     RL model trainer with an `accelerate` based backend
     """
 
-    def __init__(self, config, train_mode=True):
-        super().__init__(config, train_mode)
+    def __init__(self, config, **kwargs):
+        super().__init__(config, **kwargs)
         self.max_length = config.train.seq_length
         self.accelerator = Accelerator(log_with=config.train.trackers)
         if int(os.environ.get("WORLD_SIZE", 1)) > 1:
@@ -69,7 +67,12 @@ def __init__(self, config, train_mode=True):
             model_name = str(config.model.model_path).split()[0]
         else:
             model_name = config.model.model_path.split("/")[-1]
-        run_name = f"{script_name}/{model_name}"
+
+        branch = get_git_tag()[0]
+        run_name = (
+            "/".join([script_name, model_name, f"{self.accelerator.num_processes}gpus"])
+            + f":{branch}"
+        )
 
         if self.accelerator.is_main_process and not ray.is_initialized():
             config_dict = self.config.to_dict()
@@ -80,7 +83,7 @@ def __init__(self, config, train_mode=True):
             init_trackers_kwargs["wandb"] = {
                 "name": run_name,
                 "entity": self.config.train.entity_name,
-                "tags": [get_git_tag()],
+                "tags": ["/".join(get_git_tag())],
                 "mode": "disabled" if os.environ.get("debug", False) else "online",
             }
             self.accelerator.init_trackers(
@@ -168,6 +171,52 @@ def tokenize(self, text: Union[Sequence[str], Sequence[torch.LongTensor]]):
             add_special_tokens=False,
         )
 
+    def decode(
+        self,
+        prompts: List[torch.LongTensor],
+        samples: List[torch.LongTensor],
+        prompt_sizes: torch.LongTensor = None,
+    ) -> Tuple[List[str], List[str], List[str]]:
+        """
+        Decode tensor generations into lists of strings (`samples`: List[str], `prompts`: List[str], `outputs`: List[str])
+        """
+        if prompt_sizes is None:
+            # Assuming prompts were left-padded
+            prompt_sizes = [prompts.shape[1]] * len(prompts)
+
+        str_samples, str_prompts, str_outputs = [], [], []
+        for prompt, sample, prompt_size in zip(prompts, samples, prompt_sizes):
+            if self.config.model.model_arch_type == "seq2seq":
+                output_start_ix = 0
+            else:
+                output_start_ix = prompt_size
+
+            str_prompt = self.tokenizer.decode(
+                prompt[:prompt_size], skip_special_tokens=True
+            )
+            str_output = self.tokenizer.decode(
+                sample[output_start_ix:], skip_special_tokens=True
+            )
+
+            # Trim outputs up to `self.stop_sequences` if any are present
+            if self.stop_sequences:
+                for stop in self.stop_sequences:
+                    stop_ix = str_output.find(stop)
+                    if stop_ix >= 0:
+                        str_output = str_output[:stop_ix].rstrip()
+
+            str_prompts.append(str_prompt)
+            str_outputs.append(str_output)
+
+            if self.config.model.model_arch_type == "seq2seq":
+                sample = str_prompt + self.tokenizer.sep_token + str_output
+            else:
+                sample = str_prompt + str_output
+
+            str_samples.append(sample)
+
+        return str_samples, str_prompts, str_outputs
+
     def generate(self, input_ids, attention_mask=None, **kwargs):
         """Wraps hf's `generate` adding some specific method's defaults"""
         input_ids = input_ids.to(self.accelerator.device)
@@ -218,125 +267,157 @@ def add_eval_pipeline(self, eval_pipeline):
     def evaluate(self):  # noqa: C901
         """Samples model on `eval_prompts`, logs stats with `reward_fn` or `metric_fn` if provided"""
         stats = {}
-        all_samples = []
-        prompts_sizes = []
-        prompts_list = []
-        generate_time = time()
-        for prompts in self.eval_dataloader:
-            if isinstance(prompts, torch.Tensor):
-                samples = self.generate_eval(prompts)
-            else:
-                samples = self.generate_eval(**prompts)
-            if isinstance(samples, tuple):
-                samples, *_ = samples
-            if self.config.model.model_arch_type == "seq2seq":
-                pad_token = self.tokenizer.pad_token_id
-                all_samples.extend(samples[:, 1:])
+        table = []
+
+        # Do multiple evaluations over a single list in `gen_kwargs` if present
+        if self.generate_sweep_kwarg is not None:
+            gen_sweep_arg, gen_sweep_values = self.generate_sweep_kwarg
+        else:
+            gen_sweep_values = [None]
+
+        for gen_sweep_value in gen_sweep_values:
+            # A dedicated suffix for wandb logging
+            if gen_sweep_value is not None:
+                sweep_suffix = f"@{gen_sweep_arg}={gen_sweep_value}"
             else:
-                pad_token = self.tokenizer.eos_token_id if self.tokenizer else 0
+                sweep_suffix = ""
+
+            all_samples = []
+            all_prompts = []
+            prompt_sizes = []
+            generate_time = time()
+            for prompts in self.eval_dataloader:
+                if self.generate_sweep_kwarg:
+                    samples = self.generate_eval(
+                        **prompts, **{gen_sweep_arg: gen_sweep_value}
+                    )
+                else:
+                    samples = self.generate_eval(**prompts)
+
+                if self.config.model.model_arch_type == "seq2seq":
+                    samples = samples[:, 1:]
+
                 all_samples.append(
                     F.pad(
                         samples,
                         (0, self.max_length - samples.shape[1]),
-                        value=pad_token,
+                        value=self.tokenizer.pad_token_id,
                     )
                 )
-                sizes = torch.tensor(prompts.input_ids.shape[1]).repeat(
-                    len(prompts.input_ids)
-                )
-                prompts_sizes.append(sizes.to(samples.device))
-                prompts_list.extend(prompts.input_ids)
+                all_prompts.append(
+                    F.pad(
+                        prompts.input_ids,
+                        (0, self.max_length - prompts.input_ids.shape[1]),
+                        value=self.tokenizer.pad_token_id,
+                    ).to(samples.device)
+                )
+                prompt_sizes.append(
+                    torch.tensor(
+                        prompts.input_ids.shape[1], device=samples.device
+                    ).repeat(len(prompts.input_ids))
+                )
 
-        stats["time/generate"] = time() - generate_time
+            stats["time/generate"] = time() - generate_time
 
-        if self.config.model.model_arch_type == "seq2seq":
-            samples = all_samples
-        else:
             samples = self.accelerator.gather(torch.vstack(all_samples))
-            prompts_sizes = self.accelerator.gather(torch.hstack(prompts_sizes))
+            prompts = self.accelerator.gather(torch.vstack(all_prompts))
+            prompt_sizes = self.accelerator.gather(torch.hstack(prompt_sizes))
 
-        if self.accelerator.is_main_process:
-            if self.tokenizer:
-                prompts, responses = [], []
-                if self.config.model.model_arch_type == "seq2seq":
-                    prompts = prompts_list
-                    responses = all_samples
-                else:
-                    for sample, prompt_size in zip(samples, prompts_sizes):
-                        prompts.append(sample[:prompt_size])
-                        responses.append(sample[prompt_size:])
-                str_prompts = self.tokenizer.batch_decode(
-                    prompts, skip_special_tokens=True
-                )
-                str_responses = self.tokenizer.batch_decode(
-                    responses, skip_special_tokens=True
-                )
-            if self.config.model.model_arch_type == "seq2seq":
-                str_samples = str_responses
-            else:
-                str_samples = self.tokenizer.batch_decode(
-                    samples, skip_special_tokens=True
+            if self.accelerator.is_main_process:
+                str_samples, str_prompts, str_outputs = self.decode(
+                    prompts, samples, prompt_sizes
                 )
 
-            if isinstance(str_samples[0], str):
-                columns_data = [str_prompts, str_responses]
-            else:
-                columns_data = [samples.tolist()]
-            columns = ["prompt", "response"]
-
-            # in online setting, compute the reward for validation
-            if self.reward_fn:
-                if self.config.model.model_arch_type == "seq2seq":
-                    sep_token = self.tokenizer.sep_token
-                    texts = [
f"{article}{sep_token}{response}" - for article, response in zip(str_prompts, str_samples) - ] - rewards = torch.tensor(self.reward_fn(texts), dtype=torch.float) - else: + columns = ["prompt", "output"] + columns_data = [str_prompts, str_outputs] + + # in online setting, compute the reward for validation + if self.reward_fn: rewards = torch.tensor( - self.reward_fn(str_samples), dtype=torch.float + self.reward_fn( + samples=str_samples, + prompts=str_prompts, + outputs=str_outputs, + ), + dtype=float, ) + mean_reward = rewards.mean().item() + columns.append("reward") + if not isinstance(rewards, list): + rewards = rewards.tolist() + columns_data.append(rewards) + stats[f"reward/mean{sweep_suffix}"] = mean_reward + + # additionally log any other metrics + if self.metric_fn: + metric_time = time() + metrics = self.metric_fn(str_samples) + stats["time/metric"] = time() - metric_time + + mean_metrics = { + f"metrics/{k}{sweep_suffix}": torch.as_tensor(xs).mean(-1) + for k, xs in metrics.items() + } + + stats.update(mean_metrics) + + for metric, values in metrics.items(): + columns.append(metric) + if not isinstance(values, list): + values = values.tolist() + columns_data.append(values) + + # Prepend the sweep argument along with samples + if self.generate_sweep_kwarg: + columns.insert(0, gen_sweep_arg) + columns_data.insert(0, [gen_sweep_value] * len(samples)) + + table.append(list(zip(*columns_data))) + + # Log and display evaluation metrics + if self.accelerator.is_main_process: + rows = sum(list(map(list, zip(*table))), []) - mean_reward = rewards.mean() - columns.append("reward") - columns_data.append(rewards) - stats["reward/mean"] = mean_reward - print(f"{mean_reward=}") - - # additionally log any other metrics - if self.metric_fn: - metric_time = time() - metrics = self.metric_fn(str_samples) - stats["time/metric"] = time() - metric_time - - mean_metrics = { - f"metrics/{k}": torch.as_tensor(xs).mean(-1) - for k, xs in metrics.items() - } + # Add metrics/rewards to the table's title + table_title = f"Evaluation #{self.nth_evaluation}" + for k, x in stats.items(): + if k.startswith("reward") or k.startswith("metrics"): + table_title += f" {k}: {significant(x)}" - stats.update(mean_metrics) + rich_table = Table(*columns, title=table_title, show_lines=True) - for metric, values in metrics.items(): - columns.append(metric) - columns_data.append(values) + for ix in range(max(min(3, len(rows)), len(gen_sweep_values))): + rich_table.add_row(*[str(significant(x)) for x in rows[ix]]) - rows = list(zip(*columns_data)) - print(rows[0]) if not ray.is_initialized(): if "wandb" in self.config.train.trackers: import wandb - stats["samples"] = wandb.Table(columns=columns, rows=rows) + stats["samples"] = wandb.Table(columns, rows) + Console().print(rich_table) + + self.nth_evaluation += 1 return stats def learn(self): # noqa: C901 """ Samples batches from `self.store`, updates model and periodically evaluates it on `self.eval_dataloader` """ + self.generate_sweep_kwarg = None + for k, v in self.config.method.gen_kwargs.items(): + if isinstance(v, list): + if self.generate_sweep_kwarg is not None: + print_rank_0( + "Only a single sweep is allowed, {k} is going to be set to {v[0]}" + ) + self.generate_kwargs[k] = v[0] + else: + self.generate_sweep_kwarg = (k, v) self.prepare_learning() self.iter_count = 0 + self.nth_evaluation = 0 if ray.is_initialized(): checkpoint = session.get_checkpoint() @@ -386,7 +467,11 @@ def learn(self): # noqa: C901 results = self.evaluate() stats.update(results) - if 
-                    if self.config.train.save_best:
+                    # FIXME: seems not to work with DeepSpeed ZeRO and barriers don't seem to help
+                    if (
+                        self.config.train.save_best
+                        and int(os.environ.get("DEEPSPEED_ZERO_STAGE", -1)) == -1
+                    ):
                         if (
                             "reward/mean" in stats
                             and stats["reward/mean"] > best_reward
diff --git a/trlx/trainer/accelerate_ilql_trainer.py b/trlx/trainer/accelerate_ilql_trainer.py
index a5191697b..891f56863 100644
--- a/trlx/trainer/accelerate_ilql_trainer.py
+++ b/trlx/trainer/accelerate_ilql_trainer.py
@@ -12,17 +12,8 @@
 
 @register_trainer
 class AccelerateILQLTrainer(AccelerateRLTrainer):
-    def __init__(
-        self,
-        config: TRLConfig,
-        logit_mask=None,
-        metric_fn=None,
-        train_mode=True,
-    ):
-        super().__init__(config, train_mode)
-        self.logit_mask = logit_mask
-        self.metric_fn = metric_fn
-        self.reward_fn = None
+    def __init__(self, config: TRLConfig, **kwargs):
+        super().__init__(config, **kwargs)
 
         if not isinstance(config.method, ILQLConfig):
             raise ValueError("config.method must be ILQLConfig")
diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py
index 06593b2a2..4826905f7 100644
--- a/trlx/trainer/accelerate_ppo_trainer.py
+++ b/trlx/trainer/accelerate_ppo_trainer.py
@@ -22,8 +22,8 @@
 
 @register_trainer
 class AcceleratePPOTrainer(AccelerateRLTrainer):
-    def __init__(self, config):
-        super().__init__(config)
+    def __init__(self, config, **kwargs):
+        super().__init__(config, **kwargs)
 
         if config.train.rollout_logging_dir is not None:
             self.log_rollouts = True
diff --git a/trlx/trlx.py b/trlx/trlx.py
index c3c2f747b..6b8321f7c 100644
--- a/trlx/trlx.py
+++ b/trlx/trlx.py
@@ -1,5 +1,5 @@
 import os
-from typing import Callable, Iterable, List, Optional, Tuple
+from typing import Callable, Dict, Iterable, List, Optional, Tuple
 
 from trlx.data.configs import TRLConfig
 from trlx.utils import set_seed
@@ -8,13 +8,18 @@
 def train(
     model_path: Optional[str] = None,
-    reward_fn: Optional[Callable] = None,
+    reward_fn: Optional[
+        Callable[[List[str], List[str], List[str]], List[float]]
+    ] = None,
     dataset: Optional[Iterable[Tuple[str, float]]] = None,
     prompts: Optional[List[str]] = None,
     eval_prompts: Optional[List[str]] = None,
-    metric_fn: Optional[Callable] = None,
+    metric_fn: Optional[
+        Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]
+    ] = None,
     config: Optional[TRLConfig] = None,
     logit_mask: Optional[List[List[bool]]] = None,
+    stop_sequences: Optional[List[str]] = [],
 ):
     """
     Dispatches online or offline reinforcement training
 
     Args:
         model_path (Optional[str]): Path to either huggingface checkpoint or a local directory
-        reward_fn (List[str] -> List[float]): Function to rate batches of generated samples
+        reward_fn (Optional[Callable[[List[str], List[str], List[str]], List[float]]]):
+            Function to rate batches of generated samples. Its arguments are
+            (`samples`, `prompts`, `outputs`) and it returns a list of `rewards`, one for each sample
         dataset (List[Union[str, List[str]]], List[float]): Lists of samples and rewards
             for offline training. Samples consist of a variable number of prompts (questions,
             environment states etc.) and outputs which are meant to be optimized.
@@ -30,10 +37,17 @@
             Giving a single string `s` for the sample is a shorthand for (`tokenizer.bos_token`, `s`)
         prompts (List[str]): Prompts to sample off from during online training
         eval_prompts (List[str]): Prompts to periodically validate training on
-        metric_fn (Optional[Callable[List[str], List[float]]]): Function to compute statistics on validation samples
+        metric_fn (Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]]):
+            Function to compute statistics on batches of generated samples. Its arguments are the same
+            as in `reward_fn` (`samples`, `prompts`, `outputs`), but it returns a dictionary that maps
+            each metric's name to a list of numeric values, one for each sample in the batch
         config (Optional[TRLConfig]): TRL configuration object to override default settings
         logit_mask (Optional[List]): Bigram masking matrix
+        stop_sequences (Optional[List[str]]):
+            String sequences at which to truncate generations, both during experience collection and
+            evaluation. Generations will not contain the stop sequences and will also be right-stripped
     """
+
     if reward_fn is not None:
         if config is None:
             config = TRLConfig.load_yaml("configs/ppo_config.yml")
@@ -42,7 +56,12 @@
         if model_path:
             config.model.model_path = model_path
 
-        trainer = get_trainer(config.train.trainer)(config)
+        trainer = get_trainer(config.train.trainer)(
+            config=config,
+            reward_fn=reward_fn,
+            metric_fn=metric_fn,
+            stop_sequences=stop_sequences,
+        )
 
         batch_size = config.train.batch_size * int(os.environ.get("WORLD_SIZE", 1))
         prompts = prompts or [trainer.tokenizer.bos_token] * batch_size
@@ -57,7 +76,7 @@
             prompts, max_prompt_length, trainer.tokenizer
         )
         orch = get_orchestrator(config.train.orchestrator)(
-            trainer, pipeline, reward_fn=reward_fn, chunk_size=config.method.chunk_size
+            trainer, pipeline, chunk_size=config.method.chunk_size
         )
         orch.make_experience(config.method.num_rollouts)
@@ -83,8 +102,9 @@
         trainer = get_trainer(config.train.trainer)(
             config=config,
-            logit_mask=logit_mask,
             metric_fn=metric_fn,
+            logit_mask=logit_mask,
+            stop_sequences=stop_sequences,
         )
 
         batch_size = config.train.batch_size * int(os.environ.get("WORLD_SIZE", 1))
         max_prompt_length = (
diff --git a/trlx/utils/__init__.py b/trlx/utils/__init__.py
index 8f25faf58..1da2d8175 100644
--- a/trlx/utils/__init__.py
+++ b/trlx/utils/__init__.py
@@ -1,9 +1,11 @@
+import math
 import os
 import random
 import subprocess
 import time
 from dataclasses import is_dataclass
 from enum import Enum
+from numbers import Number
 from typing import Dict, Iterable
 
 import numpy as np
@@ -21,6 +23,19 @@ def print_rank_0(*message):
     print(*message)
 
 
+def significant(x: Number, ndigits=2) -> Number:
+    """
+    Round `x` to `ndigits` decimal places past its most significant digit
+    """
+    if isinstance(x, torch.Tensor):
+        x = x.item()
+
+    if not isinstance(x, Number) or x == 0:
+        return x
+
+    return round(x, ndigits - int(math.floor(math.log10(abs(x)))))
+
+
 def set_seed(seed: int):
     """
     Sets seeds across package dependencies for reproducibility.
@@ -246,4 +261,4 @@ def get_git_tag() -> str:
     """
     output = subprocess.check_output("git log --format='%h/%as' -n1".split())
     branch = subprocess.check_output("git rev-parse --abbrev-ref HEAD".split())
-    return f"{branch.decode()[:-1]}/{output.decode()[1:-2]}"
+    return branch.decode()[:-1], output.decode()[1:-2]
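
Usage note (not part of the diff): a minimal sketch of how user code is expected to adapt to the new (samples, prompts, outputs) callback signature and the stop_sequences argument introduced above. The model path, prompts, and reward logic below are illustrative placeholders only.

# Hypothetical usage sketch for the new trlx.train() interface; the model path,
# prompts, and scoring logic are placeholders, not taken from the diff.
from typing import Dict, List

import trlx


def reward_fn(samples: List[str], prompts: List[str], outputs: List[str], **kwargs) -> List[float]:
    # Score only the generated continuation, not the prompt
    return [float(len(output.split())) for output in outputs]


def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]:
    # Extra statistics logged alongside the reward during evaluation
    return {"length": [float(len(sample)) for sample in samples]}


trlx.train(
    "gpt2",  # placeholder model path
    reward_fn=reward_fn,
    metric_fn=metric_fn,
    prompts=["Hello, my name is"] * 64,      # placeholder prompts
    eval_prompts=["Once upon a time"] * 32,  # placeholder eval prompts
    stop_sequences=["\n"],  # generations are trimmed at the first newline and right-stripped
)

As a related design point from this diff, passing a list for a value in a method's gen_kwargs (for example beta: [0, 1, 100] in the ILQL random walks config) now triggers an evaluation sweep over those values, with each result logged under a suffix such as @beta=100.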