-
Notifications
You must be signed in to change notification settings - Fork 191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Run on RTX 4090 #38
base: master
Are you sure you want to change the base?
Run on RTX 4090 #38
Conversation
Thanks for your work. I will have to think about this and understand the options before merging this. One thing I would merge immediately would be a reproducible log generated by this. It can be put in the |
Perhaps a reasonable alternative is
|
@lapp0 Hey what does do? Thanks! flex_kernel_options={
"BLOCK_M": 64, "BLOCK_N": 64, # forward
"BLOCK_M1": 32, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 32 # backwards
} Presumably you are passing values to the Triton kernels, but how did you come up with those values? Thanks! |
@vgoklani I stole the values from pytorch/pytorch#133254 (comment) |
definitely prefer this over how it is implemented now, especially if you want to use multiple 4090 and lets say we want to keep the context length same but change the batch size instead . |
Updated code and docs. You can now specify
Unrelated: added |
I linked to this from the README. I don't think I'll merge it because I don't want to add any arguments to |
Perhaps we could have
This is equivalent to the current code in Then we can create |
@lapp0 many thanks for your fork! I've started your fork successfully on a single RTX 4090 with minor version upgrades and a tiny fix in log-output writing:
However I am running out of GPU memory. Probably because I use single GPU setup? If you have an idea what parameters to tweak to fit in a single RTX 4090 24GB, please let me know, |
@vak you can always decrease sequence length further and increase batch size in the same ratio. |
Changes to make the code run on RTX 4090 / 3090.
Fixes #29
Runs in 2 hours 3 minutes, Runs range from 3.275 to 3.285. This finished at 3.2817,
These settings are intended to replicate the training dynamics on 8xH100, not to be optimal. This is accomplished by halving the sequence length and doubling the batch size. You can train in ~90 minutes by setting batch_size of 1.
#38 (comment)