diff --git a/tests/test_utils.py b/tests/test_utils.py
index d23e18c841..471a07e53c 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -15,12 +15,18 @@
 import unittest

 import torch
-from transformers import AutoTokenizer
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from transformers.testing_utils import require_peft
 from transformers.utils import is_peft_available

 from trl.trainer.model_config import ModelConfig
-from trl.trainer.utils import decode_and_strip_padding, generate_model_card, get_peft_config, pad
+from trl.trainer.utils import (
+    decode_and_strip_padding,
+    generate_model_card,
+    get_calibrated_reward,
+    get_peft_config,
+    pad,
+)


 if is_peft_available():
@@ -169,3 +175,27 @@ def test_val_none(self):
         assert "my_model" in card_text
         assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text
         assert "My Trainer" in card_text
+
+
+class TestGetCalibratedReward(unittest.TestCase):
+    def setUp(self):
+        self.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
+        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+
+    def test_basic_functionality(self):
+        textual_query_responses = ["The color of the sky is blue.", "The color of the sun is yellow."]
+        textual_baseline_responses = [
+            "The color of the sky is dependent of the color of the sun.",
+            "The color of the sun is dependent of the color of the sky.",
+        ]
+
+        query_responses = self.tokenizer(textual_query_responses, padding=True, return_tensors="pt")["input_ids"]
+        baseline_responses = self.tokenizer(textual_baseline_responses, padding=True, return_tensors="pt")["input_ids"]
+
+        _, scores, _ = get_calibrated_reward(
+            self.model, query_responses, baseline_responses, self.tokenizer.pad_token_id, 5
+        )
+
+        self.assertTrue(torch.all((scores >= 0) & (scores <= 1)).item(), "At least one element is not between 0 and 1")
diff --git a/trl/__init__.py b/trl/__init__.py
index 87ce9bfa63..38acf3373a 100644
--- a/trl/__init__.py
+++ b/trl/__init__.py
@@ -95,7 +95,7 @@
         "XPOTrainer",
     ],
     "trainer.callbacks": ["RichProgressCallback", "SyncRefModelCallback"],
-    "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"],
+    "trainer.utils": ["get_calibrated_reward", "get_kbit_device_map", "get_peft_config", "get_quantization_config"],
 }

 try:
@@ -190,7 +190,7 @@
         XPOTrainer,
     )
     from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback
-    from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
+    from .trainer.utils import get_calibrated_reward, get_kbit_device_map, get_peft_config, get_quantization_config

     try:
         if not is_diffusers_available():
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index f1ed37f93f..913ef9b55b 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -1113,6 +1113,60 @@ def get_reward(
     )


+def get_calibrated_reward(
+    model: torch.nn.Module,
+    query_responses: torch.Tensor,
+    baseline_responses: torch.Tensor,
+    pad_token_id: int,
+    context_length: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Computes the calibrated reward logits and rewards for a given model, query responses, and baseline responses.
+    Please refer to Section 4.1.1, Eq. (5) of the CGPO paper: https://arxiv.org/pdf/2409.20370
+
+    Args:
+        model (`torch.nn.Module`):
+            The model used to compute the reward logits.
+        query_responses (`torch.Tensor`):
+            The tensor containing the query responses.
+        baseline_responses (`torch.Tensor`):
+            The tensor containing the baseline responses.
+        pad_token_id (`int`):
+            The token ID representing the pad token.
+        context_length (`int`):
+            The length of the context in the query responses.
+
+    Returns:
+        tuple:
+            - `reward_logits` (`torch.Tensor`):
+                The calibrated logits from the reward model.
+            - `final_rewards` (`torch.Tensor`):
+                The final calibrated rewards for each query response.
+            - `sequence_lengths` (`torch.Tensor`):
+                The lengths of the sequences in the query responses.
+    """
+    len_responses = query_responses.shape[0]
+    max_length = max(query_responses.shape[1], baseline_responses.shape[1])
+    query_responses = pad_to_length(query_responses, max_length, pad_value=pad_token_id)
+    baseline_responses = pad_to_length(baseline_responses, max_length, pad_value=pad_token_id)
+
+    concatenated_responses = torch.cat(
+        (query_responses, baseline_responses),
+        dim=0,
+    )
+
+    reward_logits, final_rewards, sequence_lengths = get_reward(
+        model, concatenated_responses, pad_token_id, context_length
+    )
+
+    # calibrate the logits and rewards as in Eq. (5) of the CGPO paper: https://arxiv.org/pdf/2409.20370
+    reward_logits = torch.nn.functional.sigmoid(reward_logits[:len_responses] - reward_logits[len_responses:])
+    final_rewards = torch.nn.functional.sigmoid(final_rewards[:len_responses] - final_rewards[len_responses:])
+    sequence_lengths = sequence_lengths[:len_responses]
+
+    return reward_logits, final_rewards, sequence_lengths
+
+
 def forward(
     model: torch.nn.Module,
     query_responses: torch.Tensor,
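For reference, here is a minimal sketch (outside the diff) of the Eq. (5) calibration that `get_calibrated_reward` applies: the calibrated reward is the sigmoid of the difference between the raw reward of a query response and that of its paired baseline response, so every score lies strictly between 0 and 1, which is what the new test asserts. The raw reward values below are made up purely for illustration.

```python
import torch

# Hypothetical raw scalar rewards r(x, y) for two query responses and their paired baselines.
raw_query_rewards = torch.tensor([1.3, -0.4])
raw_baseline_rewards = torch.tensor([0.9, 0.2])

# Eq. (5): calibrated reward = sigmoid(r(x, y) - r(x, y_baseline)), always strictly in (0, 1).
calibrated_rewards = torch.sigmoid(raw_query_rewards - raw_baseline_rewards)
print(calibrated_rewards)  # ≈ tensor([0.5987, 0.3543])
```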