🐾 Process-supervised RM Trainer #2127
base: main
Conversation
This is awesome, @gaetanlop! Would you like some early feedback on the PR, or would you prefer I wait a bit until it's more polished?
Hey @lewtun, thank you for the message. Currently, only a couple of files are more or less ready. Implementing a PRM seems to be pretty straightforward: it looks like a token classification task where only the prediction for the last token of each step gets assigned a label, and all other tokens are ignored during loss calculation. If the dataset isn't pre-tokenized, I assume it should contain the following columns: the prompt, the list of completion steps, and the corresponding per-step labels.
Are you aware of an HF dataset we could use to train PRMs for the example file? Also, how can I add a new subset to the `trl-internal-testing/zen` dataset? Thanks again for your time!
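Concretely, that labeling scheme could look something like the sketch below (hypothetical code, not the PR's actual implementation; `-100` is assumed as the ignore index, as used by PyTorch's cross-entropy loss, and all names are illustrative):

```python
# Hypothetical sketch of per-step labeling for a PRM (illustrative names only).
IGNORE_INDEX = -100  # positions with this label are skipped by the loss

def build_prm_labels(prompt_ids, steps_ids, step_labels):
    """prompt_ids: token ids of the prompt.
    steps_ids: one list of token ids per reasoning step.
    step_labels: one 0/1 correctness label per step."""
    input_ids = list(prompt_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids)  # never compute loss on the prompt
    for ids, label in zip(steps_ids, step_labels):
        input_ids += ids
        # Only the last token of each step carries the step's label.
        labels += [IGNORE_INDEX] * (len(ids) - 1) + [label]
    return input_ids, labels
```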
PR ready for review. I have changed the naming conventions that I used before. Tests: I created a dummy dataset, but we should add a subset to `trl-internal-testing/zen` as done in other scripts.
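For context, one row of such a subset might look like the following (a hedged sketch; the column names are assumptions rather than the dataset's confirmed schema):

```python
# Illustrative stepwise-supervision example; column names are assumptions.
example = {
    "prompt": "Which is larger, 9.8 or 9.11?",
    "completions": [
        "9.11 has more digits after the decimal point, so it looks larger.",  # flawed step
        "Comparing the tenths digit, 8 > 1, so 9.8 is larger.",               # correct step
    ],
    "labels": [False, True],  # one correctness label per step
}
```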
Thank you for the very clean PR @gaetanlop - this looks great! I've left some minor suggestions regarding the structure, but aside from that and a smallish dataset in the right format (so we can sanity-check that the accuracy goes up, the loss goes down, etc.), I think this is quite close to being ready.
Thanks for looking at this @lewtun. Seems like …
@qgallouedec yes of course, as you prefer. I was following the implementation done in the …
Should we add the separator token between the prompt and the first step? If you don't (like the current code), you get something like:

```python
prompt = "This is my prompt."
completions = ["This is my first step.", "This is my second step."]
separator = "\n"
# Processing here: the separator is inserted between steps but not after the prompt
result == "This is my prompt.This is my first step.\nThis is my second step."
#                            ^💀 no separator between the prompt and the first step
```
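One possible fix, sketched below under the assumption that the processing is a plain string join (the actual `tokenize_row` logic may differ), is to insert the separator after the prompt as well:

```python
prompt = "This is my prompt."
completions = ["This is my first step.", "This is my second step."]
separator = "\n"

# Insert the separator after the prompt too, not only between steps.
result = prompt + separator + separator.join(completions)
assert result == "This is my prompt.\nThis is my first step.\nThis is my second step."
```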
I still need to …
First trained model: https://huggingface.co/qgallouedec/Qwen2-0.5B-Reward
Refactor `tokenize_row`
@qgallouedec Thank you for the refactoring work on the `tokenize_row` function. I have made some adjustments to ensure proper handling of special tokens. I also refined the label creation process and updated `tokenize_row` to support truncation based on both `max_length` and `max_completion_ids`, and added some tests to confirm that the updated function behaves as intended. I also ran some experiments: the model gets 99.8% accuracy after just a few steps... it might just be predicting True all the time, so I will need to double-check.
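One quick way to double-check that suspicion (a hypothetical snippet; `step_labels` stands in for the flattened per-step labels of the evaluation split) is to compare the reported accuracy against the majority-class baseline:

```python
# If the eval accuracy is close to the majority-class frequency, the model
# may simply be predicting the majority label rather than rating steps.
step_labels = [True, True, False, True, True, True]  # illustrative labels only

positives = sum(step_labels)
majority_baseline = max(positives, len(step_labels) - positives) / len(step_labels)
print(f"always-predict-majority accuracy: {majority_baseline:.2%}")
```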
## Overview
Stepwise or process reward models were proposed in [Solving math word problems with processand outcome-based feedback](https://arxiv.org/pdf/2211.14275) by Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving and Irina Higgins.
Suggested change:

Stepwise or process reward models were proposed in [Solving math word problems with process- and outcome-based feedback](https://huggingface.co/papers/2211.14275) by Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving, and Irina Higgins.
What does this PR do?
Adding support for process-supervised reward training to TRL, as requested in #2110.
List of papers using PRMs: [1], [2], [3], [4]...
Fixes #2110
Before submitting

- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
Who can review?
@lewtun @kashif