[New Feature Request] Add KTO #590

Open
1485840691-eng opened this issue Jan 27, 2024 · 0 comments
Labels
feature request New feature or request

🚀 The feature, motivation, and pitch

https://github.com/ContextualAI/HALOs/tree/main
This is a new human alignment method (Kahneman-Tversky Optimization, KTO). Unlike DPO and PPO, it does not depend on pairwise comparison data but works on pointwise evaluation data, i.e. individual completions labeled as desirable or undesirable. According to the paper, KTO outperforms DPO and PPO on a few public benchmarks.
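
To make the data requirement concrete, here is a rough illustration of the two formats. The field names are purely for illustration and are not a schema from this repo or the paper:

```python
# Illustrative only: field names are hypothetical, not tied to any particular repo's schema.

# DPO/PPO-style pairwise comparison: two completions for the same prompt, ranked against each other.
pairwise_example = {
    "prompt": "Explain photosynthesis.",
    "chosen": "Photosynthesis converts light energy into chemical energy ...",
    "rejected": "Photosynthesis is when plants eat sunlight.",
}

# KTO-style pointwise feedback: a single completion with a binary desirable/undesirable label.
pointwise_examples = [
    {"prompt": "Explain photosynthesis.",
     "completion": "Photosynthesis converts light energy into chemical energy ...",
     "label": True},   # desirable
    {"prompt": "Write a polite reply.",
     "completion": "No.",
     "label": False},  # undesirable
]
```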

TRL has also incorporated the KTO loss into its DPO trainer, and there is an open PR to add a KTOTrainer for end-to-end KTO training: huggingface/trl#1181
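
For reference, below is a minimal, untested sketch of how the KTO-style loss can already be exercised through TRL's DPOTrainer. It assumes a TRL version from around this time where DPOTrainer accepts `beta` and `loss_type="kto_pair"` directly and expects paired `prompt`/`chosen`/`rejected` columns; the exact arguments may differ between releases:

```python
# Minimal sketch, not tested; assumes a TRL release that exposes loss_type="kto_pair" on DPOTrainer.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
ref_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Tiny toy dataset in the paired format DPOTrainer expects.
train_dataset = Dataset.from_dict({
    "prompt": ["Explain photosynthesis."],
    "chosen": ["Photosynthesis converts light energy into chemical energy ..."],
    "rejected": ["Photosynthesis is when plants eat sunlight."],
})

trainer = DPOTrainer(
    model,
    ref_model,
    args=TrainingArguments(output_dir="kto-pair-test", per_device_train_batch_size=1),
    beta=0.1,
    loss_type="kto_pair",  # KTO-style loss applied to paired preference data
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
```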

Given its promise, I would suggest supporting it in this repo.
If nobody is working on this yet, I would like to take it on.

If there are any concerns, please let me know.

The KTO loss could be sketched roughly as follows (adapted from the HALOs repo's SimpleKTOTrainer; the loss body is filled in from the formulas in the docstring, so treat it as a sketch rather than the exact implementation):

```python
class SimpleKTOTrainer(UnpairedPreferenceTrainer):
    """A simple version of KTO meant to introduce you to the HALOs repo."""

    def loss(self,
             policy_chosen_logps: torch.FloatTensor,
             policy_rejected_logps: torch.FloatTensor,
             reference_chosen_logps: torch.FloatTensor,
             reference_rejected_logps: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the Kahneman-Tversky loss for a batch of policy and reference model log probabilities.

        For each batch of n/2 chosen examples and n/2 rejected examples (belonging to n different inputs):

        If generation y ~ p_chosen, where x' are the examples with rejected generations, the 'chosen' loss is:
            L(x, y) := 1 - sigmoid(beta * ((log p_policy(y|x) - log p_reference(y|x)) - KL(p_policy(y_rejected|x') || p_reference(y_rejected|x'))))
        If generation y ~ p_rejected, where x' are the examples with chosen generations, the 'rejected' loss is:
            L(x, y) := 1 - sigmoid(beta * (KL(p_policy(y_chosen|x') || p_reference(y_chosen|x')) - (log p_policy(y|x) - log p_reference(y|x))))
        """
        beta = self.config.loss.beta  # KTO temperature (assumed to come from the trainer config)

        # Batch-level estimates of the two KL terms, clamped at zero.
        chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
        rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)

        chosen_logratios = policy_chosen_logps - reference_chosen_logps
        rejected_logratios = policy_rejected_logps - reference_rejected_logps

        losses = torch.cat((1 - torch.sigmoid(beta * (chosen_logratios - rejected_KL)),
                            1 - torch.sigmoid(beta * (chosen_KL - rejected_logratios))), 0)

        chosen_rewards = beta * chosen_logratios.detach()
        rejected_rewards = beta * rejected_logratios.detach()

        return losses, chosen_rewards, rejected_rewards
```
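
Note that the KL terms in this sketch are coarse batch-level estimates (the mean log-ratio over the mismatched examples, clamped at zero) rather than exact KL divergences; the reference implementation in the HALOs repo should be treated as the source of truth.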

Alternatives

No response

Additional context

No response
