From eeb973f09325f8871af9a1ac255f005d9768b2fc Mon Sep 17 00:00:00 2001 From: Gaetan LOPEZ Date: Tue, 1 Oct 2024 21:25:31 -0400 Subject: [PATCH 1/4] skeleton --- tests/test_utils.py | 3 +++ trl/trainer/utils.py | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index d23e18c841..b094d942f5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -169,3 +169,6 @@ def test_val_none(self): assert "my_model" in card_text assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text assert "My Trainer" in card_text + +class TestGetReward(unittest.TestCase): + pass \ No newline at end of file diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index cdec1f4d45..4da2655bc1 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1053,7 +1053,7 @@ def first_true_indices(bools: torch.Tensor, dtype=torch.long): def get_reward( - model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int + model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int, baseline_responses: torch.Tensor = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Computes the reward logits and the rewards for a given model and query responses. @@ -1067,6 +1067,8 @@ def get_reward( The token ID representing the pad token. context_length (`int`): The length of the context in the query responses. + baseline_responses (`torch.Tensor`): + The tensor containing the baseline responses for reward calibration. See section 4.1.1 of https://arxiv.org/pdf/2409.20370 for more information. Returns: tuple: From 20c2892f96c02397ba04aece99a38cca71f1cba2 Mon Sep 17 00:00:00 2001 From: Gaetan LOPEZ Date: Tue, 1 Oct 2024 22:07:20 -0400 Subject: [PATCH 2/4] calibrated reward fn --- tests/test_utils.py | 35 +++++++++++++++++++++--- trl/__init__.py | 4 +-- trl/trainer/utils.py | 63 +++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 93 insertions(+), 9 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index b094d942f5..471a07e53c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -15,12 +15,18 @@ import unittest import torch -from transformers import AutoTokenizer +from transformers import AutoTokenizer, AutoModelForSequenceClassification from transformers.testing_utils import require_peft from transformers.utils import is_peft_available from trl.trainer.model_config import ModelConfig -from trl.trainer.utils import decode_and_strip_padding, generate_model_card, get_peft_config, pad +from trl.trainer.utils import ( + decode_and_strip_padding, + generate_model_card, + get_peft_config, + pad, + get_calibrated_reward, +) if is_peft_available(): @@ -170,5 +176,26 @@ def test_val_none(self): assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text assert "My Trainer" in card_text -class TestGetReward(unittest.TestCase): - pass \ No newline at end of file + +class TestGetCalibratedReward(unittest.TestCase): + def setUp(self): + self.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + def test_basic_functionality(self): + textual_query_responses = ["The color of the sky is blue.", "The color of the sun is yellow."] + textual_baseline_responses = [ + "The color of the sky is dependent of 
the color of the sun.", + "The color of the sun is dependent of the color of the sky.", + ] + + query_responses = self.tokenizer(textual_query_responses, padding=True, return_tensors="pt")["input_ids"] + baseline_responses = self.tokenizer(textual_baseline_responses, padding=True, return_tensors="pt")["input_ids"] + + _, scores, _ = get_calibrated_reward( + self.model, query_responses, baseline_responses, self.tokenizer.pad_token_id, 5 + ) + + self.assertTrue(torch.all((scores >= 0) & (scores <= 1)).item(), "At least one element is not between 0 and 1") diff --git a/trl/__init__.py b/trl/__init__.py index 87ce9bfa63..38acf3373a 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -95,7 +95,7 @@ "XPOTrainer", ], "trainer.callbacks": ["RichProgressCallback", "SyncRefModelCallback"], - "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"], + "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "get_calibrated_reward"], } try: @@ -190,7 +190,7 @@ XPOTrainer, ) from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback - from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config + from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config, get_calibrated_reward try: if not is_diffusers_available(): diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 4da2655bc1..748815d435 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1053,7 +1053,10 @@ def first_true_indices(bools: torch.Tensor, dtype=torch.long): def get_reward( - model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int, baseline_responses: torch.Tensor = None + model: torch.nn.Module, + query_responses: torch.Tensor, + pad_token_id: int, + context_length: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Computes the reward logits and the rewards for a given model and query responses. @@ -1067,8 +1070,6 @@ def get_reward( The token ID representing the pad token. context_length (`int`): The length of the context in the query responses. - baseline_responses (`torch.Tensor`): - The tensor containing the baseline responses for reward calibration. See section 4.1.1 of https://arxiv.org/pdf/2409.20370 for more information. Returns: tuple: @@ -1104,6 +1105,62 @@ def get_reward( ) +def get_calibrated_reward( + model: torch.nn.Module, + query_responses: torch.Tensor, + baseline_responses: torch.Tensor, + pad_token_id: int, + context_length: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes the reward logits and the calibrated rewards for a given model, query responses and baseline responses. + Please refer to section 4.1.1 eqn (5) of the CGPO paper https://arxiv.org/pdf/2409.20370 + + Args: + model (`torch.nn.Module`): + The model used to compute the reward logits. + query_responses (`torch.Tensor`): + The tensor containing the query responses. + baseline_responses (`torch.Tensor`): + The tensor containing the baseline responses. + pad_token_id (`int`): + The token ID representing the pad token. + context_length (`int`): + The length of the context in the query responses. + baseline_responses (`torch.Tensor`): + The tensor containing the baseline responses for reward calibration. See section 4.1.1 of https://arxiv.org/pdf/2409.20370 for more information. + + Returns: + tuple: + - `reward_logits` (`torch.Tensor`): + The calibrated logits for the reward model. 
+ - `final_rewards` (`torch.Tensor`): + The final calibrated rewards for each query response. + - `sequence_lengths` (`torch.Tensor`): + The lengths of the sequences in the query responses. + """ + len_responses = query_responses.shape[0] + max_length = max(query_responses.shape[1], baseline_responses.shape[1]) + query_responses = pad_to_length(query_responses, max_length, pad_value=pad_token_id) + baseline_responses = pad_to_length(baseline_responses, max_length, pad_value=pad_token_id) + + concatenated_responses = torch.cat( + (query_responses, baseline_responses), + dim=0, + ) + + reward_logits, final_rewards, sequence_lengths = get_reward( + model, concatenated_responses, pad_token_id, context_length + ) + + reward_logits = reward_logits[:len_responses] - reward_logits[len_responses:] + # computes the calibrated rewards as done in eqn (5) of the CGPO paper: https://arxiv.org/pdf/2409.20370 + final_rewards = torch.nn.functional.sigmoid(final_rewards[:len_responses] - final_rewards[len_responses:]) + sequence_lengths = sequence_lengths[:len_responses] + + return reward_logits, final_rewards, sequence_lengths + + def forward( model: torch.nn.Module, query_responses: torch.Tensor, From f1a70e75d050df8276895bcc24b854c31ebc0abb Mon Sep 17 00:00:00 2001 From: Gaetan LOPEZ Date: Tue, 1 Oct 2024 22:14:01 -0400 Subject: [PATCH 3/4] fix small typo in doc --- trl/trainer/utils.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 748815d435..b4afa63b99 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1053,10 +1053,7 @@ def first_true_indices(bools: torch.Tensor, dtype=torch.long): def get_reward( - model: torch.nn.Module, - query_responses: torch.Tensor, - pad_token_id: int, - context_length: int, + model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Computes the reward logits and the rewards for a given model and query responses. @@ -1127,8 +1124,6 @@ def get_calibrated_reward( The token ID representing the pad token. context_length (`int`): The length of the context in the query responses. - baseline_responses (`torch.Tensor`): - The tensor containing the baseline responses for reward calibration. See section 4.1.1 of https://arxiv.org/pdf/2409.20370 for more information. Returns: tuple: From 37e5caafe84e48ea2a7601aa9cb138cc691d7ed5 Mon Sep 17 00:00:00 2001 From: Gaetan LOPEZ Date: Wed, 2 Oct 2024 08:19:55 -0400 Subject: [PATCH 4/4] fix typo in reward logits lacking sigmoid --- trl/trainer/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index b4afa63b99..7268aa65d9 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1148,7 +1148,7 @@ def get_calibrated_reward( model, concatenated_responses, pad_token_id, context_length ) - reward_logits = reward_logits[:len_responses] - reward_logits[len_responses:] + reward_logits = torch.nn.functional.sigmoid(reward_logits[:len_responses] - reward_logits[len_responses:]) # computes the calibrated rewards as done in eqn (5) of the CGPO paper: https://arxiv.org/pdf/2409.20370 final_rewards = torch.nn.functional.sigmoid(final_rewards[:len_responses] - final_rewards[len_responses:]) sequence_lengths = sequence_lengths[:len_responses]
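
Note on usage (editor's addition, not part of the patches): the sketch below mirrors TestGetCalibratedReward from PATCH 2/4 and shows how the new get_calibrated_reward helper is meant to be called. The dummy reward checkpoint and the context length of 5 come from that unit test; the response strings and variable names are illustrative assumptions rather than anything defined in this series.

    # Minimal sketch: score a candidate response against a baseline response
    # with the calibrated reward helper added in this patch series.
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    from trl.trainer.utils import get_calibrated_reward

    model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"  # checkpoint used in the unit test
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    # Both sequences share the same prompt; context_length (here 5, as in the test)
    # marks where the generated responses start.
    query_responses = tokenizer(["The color of the sky is blue."], return_tensors="pt")["input_ids"]
    baseline_responses = tokenizer(["The color of the sky is gray."], return_tensors="pt")["input_ids"]

    with torch.no_grad():
        _, calibrated_rewards, _ = get_calibrated_reward(
            model, query_responses, baseline_responses, tokenizer.pad_token_id, 5
        )

    # Eqn (5) of the CGPO paper: sigmoid(r(query) - r(baseline)), so each value lies
    # in (0, 1); values above 0.5 mean the reward model prefers the candidate response.
    print(calibrated_rewards)

Because the calibrated reward is a sigmoid of a score difference, it is bounded between 0 and 1 regardless of the raw reward scale, which is exactly the range that the assertion in TestGetCalibratedReward checks.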