Reinvestigate Fancy Gradient Checkpointing #427

Open
3 tasks
dlwh opened this issue Jan 23, 2024 · 0 comments

dlwh commented Jan 23, 2024

Tasks:

  • Finish/debug the checkpointed_scan PR (Add checkpointed_scan, haliax#60). It works, but I think it throws out too much.
  • Benchmark with and without it. Make sure memory use goes down without slowing things down too much.
  • Add config and expose it in the various models, maybe with reusable config options.
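One possible shape for a reusable option, sketched as a plain Python dataclass. `GradientCheckpointConfig`, its fields, and `resolve_block_size` are hypothetical names for illustration, not existing Levanter or Haliax config. By default the block size is the largest divisor of the layer count that is at most sqrt(L), so the blocks tile the layer stack evenly.

```python
import math
from dataclasses import dataclass
from typing import Literal, Optional


@dataclass(frozen=True)
class GradientCheckpointConfig:
    """Hypothetical option controlling how a stacked-layer scan is rematerialized."""

    # "none": no rematerialization; keep every saved activation
    # "per_layer": recompute inside each layer, but keep every layer's input carry
    # "nested": keep only block-boundary carries and recompute whole blocks
    strategy: Literal["none", "per_layer", "nested"] = "nested"
    # Layers per inner block for "nested"; None means "pick roughly sqrt(num_layers)"
    block_size: Optional[int] = None

    def resolve_block_size(self, num_layers: int) -> int:
        if num_layers < 1:
            raise ValueError("num_layers must be positive")
        if self.block_size is not None:
            if num_layers % self.block_size != 0:
                raise ValueError("block_size must divide num_layers")
            return self.block_size
        # Largest divisor of num_layers that does not exceed sqrt(num_layers);
        # the search always terminates because 1 divides everything.
        root = math.isqrt(num_layers)
        return next(b for b in range(root, 0, -1) if num_layers % b == 0)


# Usage: an 80-layer model gets blocks of 8 layers (10 blocks).
print(GradientCheckpointConfig().resolve_block_size(80))
```

Requiring the block size to divide the layer count keeps the nested scan a simple reshape of the stacked weights; a real implementation could instead handle a ragged final block.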

Inspired by #424, we should reinvestigate fancier gradient checkpointing.


  1. Size: 2.50G
     Operator: op_name="jit(train_step)/jit(main)/jvp(LlamaTransformer)/broadcast_in_dim[shape=(80, 64, 2048, 8192) broadcast_dimensions=()]" source_file="/home/dlwh/venv310/lib/python3.10/site-packages/haliax/hof.py" source_line=88
     Shape: bf16[80,1,2048,8192]{3,2,1,0:T(8,128)(2,1)}
     Unpadded size: 2.50G
     XLA label: broadcast.659 = broadcast(constant.983), dimensions={}
     Allocation type: HLO temp
     ==========================

This is a [NumLayers, Batch, Pos, Embed] tensor that is, I believe, the cached activations from a layers.fold(x) call in Llama. When memory isn't tight, this is probably a good tensor to keep: it's not too huge, and it lets the gradients be computed in O(L) time. However, we can be fancier, and 2.5G is still a lot of memory! In particular, multi-level (nested) scans need only O(sqrt(L)) memory while still taking O(L) time, which is better!
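A back-of-the-envelope check of the profile entry, assuming the per-device shape it reports (`bf16[80,1,2048,8192]`) and counting only the carried [Batch, Pos, Embed] activations. Weights, optimizer state, and intra-layer temporaries are ignored, so this is an estimate of the savings, not a prediction of total memory.

```python
num_layers, batch, pos, embed = 80, 1, 2048, 8192  # per-device shape from the profile
bytes_per_elem = 2  # bf16

# One layer's carried activation: [Batch, Pos, Embed]
per_layer = batch * pos * embed * bytes_per_elem
# Saving every layer's carry gives the [NumLayers, Batch, Pos, Embed] stack
full_stack = num_layers * per_layer
print(full_stack / 2**30)  # → 2.5 (GiB)

# Two-level scan with block size B: keep num_layers // B block-boundary carries,
# plus B inner carries while one block is recomputed during the backward pass.
block = 8  # largest divisor of 80 that is <= sqrt(80) ≈ 8.94
two_level = (num_layers // block + block) * per_layer
print(two_level / 2**30)  # → 0.5625 (GiB)
```

So 80 stored carries drop to 18 (10 + 8), at the cost of roughly one extra forward pass through the layers.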

To eliminate the 2.5GB of unnecessary activations, we could probably do a much fancier checkpointed scan, possibly by rolling it ourselves. Equinox has something (checkpointed_while) and Flax has another (remat_scan) that could help. I have a version of the latter implemented in stanford-crfm/haliax#60 that seems to eliminate the 2.5GB, but it is too aggressive: it throws out the sqrt(L) activations, which we should keep.
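A minimal two-level sketch in plain JAX, assuming a homogeneous stack of layers whose weights are stacked along a leading layer axis and a layer count divisible by the block size. `layer`, `plain_scan`, and `two_level_scan` are hypothetical names for illustration; this is not the haliax#60 code or any Haliax/Levanter API. The outer `jax.lax.scan` keeps one carry per block boundary (the sqrt(L) activations worth keeping), and `jax.checkpoint` on each block means its inner per-layer carries are recomputed during the backward pass instead of stored.

```python
import jax
import jax.numpy as jnp


def layer(h, w):
    # Toy stand-in for one transformer layer: a residual tanh projection.
    return h + jnp.tanh(h @ w)


def plain_scan(h, ws):
    # Baseline: scan over the stacked weights; autodiff saves every layer's carry.
    def body(carry, w):
        return layer(carry, w), None

    out, _ = jax.lax.scan(body, h, ws)
    return out


def two_level_scan(h, ws, block_size):
    # Hypothetical nested scan: the outer scan walks blocks of `block_size` layers.
    num_layers = ws.shape[0]
    if num_layers % block_size != 0:
        raise ValueError("block_size must divide the number of layers")
    blocks = ws.reshape((num_layers // block_size, block_size) + ws.shape[1:])

    def inner_body(carry, w):
        return layer(carry, w), None

    # jax.checkpoint makes autodiff keep only this block's inputs; the inner
    # scan (and its per-layer carries) is recomputed during the backward pass.
    @jax.checkpoint
    def run_block(carry, block_ws):
        out, _ = jax.lax.scan(inner_body, carry, block_ws)
        return out

    def outer_body(carry, block_ws):
        return run_block(carry, block_ws), None

    # The outer scan saves one carry per block boundary (num_layers // block_size).
    out, _ = jax.lax.scan(outer_body, h, blocks)
    return out


# Usage: outputs and gradients should match the plain scan up to float error.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
h0 = jax.random.normal(k1, (4, 8))
ws = 0.1 * jax.random.normal(k2, (16, 8, 8))
g_plain = jax.grad(lambda w: plain_scan(h0, w).sum())(ws)
g_nested = jax.grad(lambda w: two_level_scan(h0, w, 4).sum())(ws)
print(jnp.max(jnp.abs(g_plain - g_nested)))
```

With 16 layers and blocks of 4, the backward pass holds about 4 boundary carries plus 4 inner carries at a time instead of 16. Flax's remat_scan generalizes the same idea to more than two levels.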

We should also improve the API (in Haliax), add config (in Levanter), and benchmark.
