
[CGPO] Calibrated reward #2155

Closed
wants to merge 5 commits

Conversation

@gaetanlop (Contributor) commented Oct 2, 2024

What does this PR do?

Adds a get_calibrated_reward function as introduced in Meta's CGPO paper (https://arxiv.org/pdf/2409.20370); see equation 5 in section 4.1.1 for details.
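
For reference, a restatement of that equation in the notation used later in this thread, where s is the prompt, a the sampled completion, ā a baseline (ground-truth) completion, and σ the sigmoid:

R_calib(s, a) = σ(reward_model(s, a) − reward_model(s, ā))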

This PR should be part of a set of PRs to incorporate CGPO in trl.

Fixes #2156

Before submitting

  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

@kashif @lewtun

@gaetanlop changed the title from [CGBO] Calibrated reward to [CGPO] Calibrated reward on Oct 2, 2024
@kashif (Collaborator) commented Oct 2, 2024

The way I understood the calibrated reward: the scores from a reward model might not be comparable across different completions for a given prompt. So, for a prompt s, a completion a, and a baseline (ground-truth) completion ā, the calibrated reward should be: R_calib(s, a) = σ(reward_model(s, a) − reward_model(s, ā))

My implementation is:

    def _compute_calib_rewards(self, completions, prompts, ground_truth_completions):
        context_length = prompts["input_ids"].shape[1]
        with torch.no_grad():
            _, generated_scores, _ = get_reward(
                self.reward_model, completions["input_ids"], self.tokenizer.pad_token_id, context_length
            )

        # Compute scores for ground-truth completions
        ground_truth_input_ids = torch.cat([prompts["input_ids"], ground_truth_completions["input_ids"]], dim=1)
        _, ground_truth_scores, _ = get_reward(
            self.reward_model, ground_truth_input_ids, self.tokenizer.pad_token_id, context_length
        )

        if self.args.missing_eos_penalty is not None:
            completion_contain_eos = torch.any(completions["input_ids"] == self.tokenizer.eos_token_id, dim=-1)
            generated_scores[~completion_contain_eos] -= self.args.missing_eos_penalty
            ground_truth_contain_eos = torch.any(
                ground_truth_completions["input_ids"] == self.tokenizer.eos_token_id, dim=-1
            )
            ground_truth_scores[~ground_truth_contain_eos] -= self.args.missing_eos_penalty

        return F.sigmoid(generated_scores - ground_truth_scores)

@gaetanlop (Contributor, Author) commented:
Thanks for looking at it @kashif. Your code and mine are exactly the same except for the missing_eos_penalty part, which I did not put in the function to stay consistent with your get_reward function (we can handle missing_eos_penalty in the potential CGPOTrainer, as is done in the OnlineDPOTrainer). Apart from that, we have the same implementation. You are assuming that the ground-truth completion does not contain the prompt, while I am assuming it does.

Also, I am computing the reward for both a and ā in a single forward pass by concatenating them, while you are computing it separately for a and ā.

Example using your naming conventions and assuming both query_responses and baseline_responses contain the prompt:

batch_size = query_responses.shape[0]
concatenated_responses = torch.cat((query_responses, baseline_responses), dim=0)

reward_logits, final_rewards, sequence_lengths = get_reward(
    model, concatenated_responses, pad_token_id, context_length
)

generated_scores, ground_truth_scores = final_rewards.split(batch_size, dim=0)

final_rewards = F.sigmoid(generated_scores - ground_truth_scores)

For the return values, I am also returning all the calibrated logits and the sequence_lengths alongside what you are returning (the final calibrated reward), to be consistent with trl's get_reward function.
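
A minimal sketch of what such a helper could look like under these conventions (the name get_calibrated_reward and its exact signature are just illustrative; get_reward is trl's existing helper, and calibrating the per-token logits is left out of the sketch):

import torch
import torch.nn.functional as F

from trl.trainer.utils import get_reward  # trl's existing reward helper


def get_calibrated_reward(model, query_responses, baseline_responses, pad_token_id, context_length):
    # Illustrative sketch only. Both tensors are assumed to already contain
    # the prompt and to be padded to the same sequence length.
    batch_size = query_responses.shape[0]
    concatenated_responses = torch.cat((query_responses, baseline_responses), dim=0)

    # One forward pass through the reward model for generated + baseline sequences.
    reward_logits, final_rewards, sequence_lengths = get_reward(
        model, concatenated_responses, pad_token_id, context_length
    )

    generated_scores, ground_truth_scores = final_rewards.split(batch_size, dim=0)
    calibrated_rewards = F.sigmoid(generated_scores - ground_truth_scores)

    # Return the logits and sequence lengths alongside the calibrated reward,
    # mirroring the tuple returned by get_reward.
    return reward_logits, calibrated_rewards, sequence_lengths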

My implementation lacked a sigmoid function for the reward_logits though.

Am I correct?

@kashif
Copy link
Collaborator

kashif commented Oct 2, 2024

Ah right, you are right!

@kashif
Copy link
Collaborator

kashif commented Oct 2, 2024

So the reason I have the two passes split is because of padding: when I join the two different completions, I have to pad them together, while it's slightly easier to pad each completion separately. I was also worried that, by concatenating, the memory needs might be too much for largish reward models. But yes, it makes sense.
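
For illustration, the joint padding could look something like this minimal sketch (pad_and_concat is just a hypothetical helper; it assumes right-padded input_ids batches and a known pad_token_id):

import torch
import torch.nn.functional as F


def pad_and_concat(generated_ids, ground_truth_ids, pad_token_id):
    # Right-pad both batches to the longer of the two sequence lengths,
    # then stack them along the batch dimension for a single forward pass.
    max_len = max(generated_ids.shape[1], ground_truth_ids.shape[1])
    generated_ids = F.pad(generated_ids, (0, max_len - generated_ids.shape[1]), value=pad_token_id)
    ground_truth_ids = F.pad(ground_truth_ids, (0, max_len - ground_truth_ids.shape[1]), value=pad_token_id)
    return torch.cat((generated_ids, ground_truth_ids), dim=0)

The trade-off is then memory: the concatenated batch is twice the batch size going through the reward model.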

@gaetanlop
Copy link
Contributor Author

Yes @kashif, 100% agree, there are pros and cons for both methods. It also depends on the distributed training strategy you are using to train the model.

In any case, I checked the trl code base and it seems you adopted this concatenation method in the OnlineDPOTrainer:

_, scores, _ = get_reward(
    self.reward_model, prompt_completion_ids, self.tokenizer.pad_token_id, context_length
)
# Filter completion. Ensure that the sample contains stop_token_id
# Completions not passing that filter will receive a lower score.
contain_eos_token = torch.any(completion_ids == self.tokenizer.eos_token_id, dim=-1)
if self.args.missing_eos_penalty is not None:
    scores[~contain_eos_token] -= self.args.missing_eos_penalty
# Split the scores in 2 (the prompts of the first half are the same as the second half)
first_half, second_half = scores.split(num_examples)

I think we should keep it this way. Wdyt?

@gaetanlop closed this on Oct 6, 2024
@gaetanlop (Contributor, Author) commented:
Closing in favor of #2190
