[CGPO] Calibrated reward #2155

Closed
wants to merge 5 commits
34 changes: 32 additions & 2 deletions tests/test_utils.py
@@ -15,12 +15,18 @@
import unittest

import torch
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl.trainer.model_config import ModelConfig
from trl.trainer.utils import decode_and_strip_padding, generate_model_card, get_peft_config, pad
from trl.trainer.utils import (
    decode_and_strip_padding,
    generate_model_card,
    get_calibrated_reward,
    get_peft_config,
    pad,
)


if is_peft_available():
@@ -169,3 +175,27 @@ def test_val_none(self):
assert "my_model" in card_text
assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text
assert "My Trainer" in card_text


class TestGetCalibratedReward(unittest.TestCase):
def setUp(self):
self.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
self.tokenizer.pad_token = self.tokenizer.eos_token

def test_basic_functionality(self):
textual_query_responses = ["The color of the sky is blue.", "The color of the sun is yellow."]
textual_baseline_responses = [
"The color of the sky is dependent of the color of the sun.",
"The color of the sun is dependent of the color of the sky.",
]

query_responses = self.tokenizer(textual_query_responses, padding=True, return_tensors="pt")["input_ids"]
baseline_responses = self.tokenizer(textual_baseline_responses, padding=True, return_tensors="pt")["input_ids"]

_, scores, _ = get_calibrated_reward(
self.model, query_responses, baseline_responses, self.tokenizer.pad_token_id, 5
)

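        # calibrated rewards are sigmoid outputs, so every score must lie between 0 and 1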
self.assertTrue(torch.all((scores >= 0) & (scores <= 1)).item(), "At least one element is not between 0 and 1")
4 changes: 2 additions & 2 deletions trl/__init__.py
@@ -95,7 +95,7 @@
"XPOTrainer",
],
"trainer.callbacks": ["RichProgressCallback", "SyncRefModelCallback"],
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"],
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "get_calibrated_reward"],
}

try:
@@ -190,7 +190,7 @@
XPOTrainer,
)
from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
    from .trainer.utils import get_calibrated_reward, get_kbit_device_map, get_peft_config, get_quantization_config

try:
if not is_diffusers_available():
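With both entries above, get_calibrated_reward should be importable from the package root as well as from its defining module. A minimal import check, assuming this branch is installed:

from trl import get_calibrated_reward  # resolved through trl/__init__.py
# equivalently:
# from trl.trainer.utils import get_calibrated_reward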
54 changes: 54 additions & 0 deletions trl/trainer/utils.py
@@ -1113,6 +1113,60 @@ def get_reward(
)


def get_calibrated_reward(
model: torch.nn.Module,
query_responses: torch.Tensor,
baseline_responses: torch.Tensor,
pad_token_id: int,
context_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
    Computes the reward logits and the calibrated rewards for a given model, a batch of query responses, and a batch of baseline responses.
    See section 4.1.1, eqn (5), of the CGPO paper: https://arxiv.org/pdf/2409.20370

Args:
model (`torch.nn.Module`):
The model used to compute the reward logits.
query_responses (`torch.Tensor`):
The tensor containing the query responses.
baseline_responses (`torch.Tensor`):
The tensor containing the baseline responses.
pad_token_id (`int`):
The token ID representing the pad token.
context_length (`int`):
The length of the context in the query responses.

Returns:
tuple:
- `reward_logits` (`torch.Tensor`):
The calibrated logits for the reward model.
- `final_rewards` (`torch.Tensor`):
The final calibrated rewards for each query response.
- `sequence_lengths` (`torch.Tensor`):
The lengths of the sequences in the query responses.
"""
len_responses = query_responses.shape[0]
max_length = max(query_responses.shape[1], baseline_responses.shape[1])
query_responses = pad_to_length(query_responses, max_length, pad_value=pad_token_id)
baseline_responses = pad_to_length(baseline_responses, max_length, pad_value=pad_token_id)

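    # score query and baseline responses with a single forward pass through the reward model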
concatenated_responses = torch.cat(
(query_responses, baseline_responses),
dim=0,
)

reward_logits, final_rewards, sequence_lengths = get_reward(
model, concatenated_responses, pad_token_id, context_length
)

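    # calibrate the per-token reward logits against the corresponding baseline logits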
reward_logits = torch.nn.functional.sigmoid(reward_logits[:len_responses] - reward_logits[len_responses:])
# computes the calibrated rewards as done in eqn (5) of the CGPO paper: https://arxiv.org/pdf/2409.20370
final_rewards = torch.nn.functional.sigmoid(final_rewards[:len_responses] - final_rewards[len_responses:])
sequence_lengths = sequence_lengths[:len_responses]

return reward_logits, final_rewards, sequence_lengths


def forward(
model: torch.nn.Module,
query_responses: torch.Tensor,
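For reference, the calibration that get_calibrated_reward applies to the end-of-sequence scores can be written as follows (a direct transcription of the code above, with R the raw reward-model score, x the query, y the sampled response, and y_b the baseline response for the same query):

\[
R^{\mathrm{cal}}(x, y) = \sigma\big( R(x, y) - R(x, y_b) \big)
\]

Since the sigmoid maps to (0, 1), the calibrated rewards are bounded, which is what the range check in the new test asserts.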