Inspired by #424, we should reinvestigate fancier gradient checkpointing.
This is a [NumLayers, Batch, Pos, Embed] tensor: I believe it is the cached activations from a layers.fold(x) call in llama. It is probably a reasonable tensor to keep around in the non-memory-limited case, since it is not too large and it keeps gradient recomputation at O(L) time. However, we can be fancier, and 2.5 GB is still a lot of memory. In particular, we can implement multi-level (nested) scans that need only O(sqrt(L)) memory while keeping O(L) time, which is better.

For the 2.5 GB of unnecessary activations, we could probably do a much fancier checkpointed scan, possibly by rolling it ourselves. There is something in Equinox (checkpointed_while) and something else in Flax (remat_scan) that could help. I have a version of the latter implemented in stanford-crfm/haliax#60 that seems to eliminate the 2.5 GB, but it is too aggressive: it also throws out the O(sqrt(L)) block-boundary activations, which we should keep.
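To make the multi-level scan idea concrete, here is a minimal plain-JAX sketch of sqrt(L) checkpointing: an outer scan over blocks of layers with a remat'd body, so only block-boundary activations are saved and the per-layer intermediates inside each block are recomputed in the backward pass. This is not Haliax's or Levanter's actual API; layer_fn, the parameter layout, and sqrt_checkpointed_fold are hypothetical stand-ins used only for illustration.

```python
import jax
import jax.numpy as jnp


def layer_fn(layer_params, x):
    # Hypothetical stand-in for a single transformer block.
    return jnp.tanh(x @ layer_params["w"] + layer_params["b"])


def sqrt_checkpointed_fold(params, x, num_layers, block_size):
    """Apply num_layers layers, saving only block-boundary activations.

    Assumes every leaf of `params` is stacked with a leading axis of
    length num_layers (as produced by a stacked/scanned layer module).
    """
    assert num_layers % block_size == 0
    num_blocks = num_layers // block_size

    # Reshape stacked layer params to [num_blocks, block_size, ...].
    blocked = jax.tree_util.tree_map(
        lambda p: p.reshape((num_blocks, block_size) + p.shape[1:]), params
    )

    def run_block(x, block_params):
        # Inner scan over one block's layers; with the outer remat, its
        # intermediate activations are recomputed during backprop.
        def run_layer(x, layer_params):
            return layer_fn(layer_params, x), None

        x, _ = jax.lax.scan(run_layer, x, block_params)
        return x, None

    # jax.checkpoint (a.k.a. remat) on the outer scan body means only the
    # num_blocks carries are saved: O(sqrt(L)) memory when
    # block_size ~ sqrt(num_layers), with only a constant-factor extra
    # forward pass per block, so still O(L) time.
    x, _ = jax.lax.scan(jax.checkpoint(run_block), x, blocked)
    return x


# Example usage with made-up shapes: 64 layers in blocks of 8 (~sqrt(64)).
key = jax.random.PRNGKey(0)
embed, num_layers, block_size = 128, 64, 8
params = {
    "w": jax.random.normal(key, (num_layers, embed, embed)) / jnp.sqrt(embed),
    "b": jnp.zeros((num_layers, embed)),
}
x = jax.random.normal(key, (4, embed))  # [Batch, Embed]
loss = lambda p: sqrt_checkpointed_fold(p, x, num_layers, block_size).sum()
grads = jax.grad(loss)(params)
```

The Flax remat_scan and Equinox checkpointed-while approaches mentioned above get at the same tradeoff; the point of the sketch is just that the block-boundary carries are the O(sqrt(L)) activations we want to keep, while everything inside a block is safe to recompute.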
We should also improve the API (in Haliax), add config (in Levanter), and benchmark the result.