mamba-tiny

Tiny implementation of Mamba in PyTorch.

Featuring:

  • Numerical output equivalent to the official implementation for both the forward and backward passes
  • Simplified, readable, annotated code
  • An alternative to parallel scan (not currently available in PyTorch) using cumsum, inspired by heinsen_sequence
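As a rough illustration of the cumsum idea (a sketch under stated assumptions, not this repo's code): in the discretized selective scan the hidden state follows h_t = a_t * h_{t-1} + b_t, where a_t = exp(Δ_t A) elementwise. Because log a_t = Δ_t A is available directly, the recurrence unrolls to h_t = P_t * Σ_{s≤t} b_s / P_s with P_t = exp(cumsum of log a), so two cumulative sums replace the sequential loop. The function names `scan_sequential` and `scan_cumsum` are hypothetical. This naive form divides by P_s, which can overflow on long sequences; heinsen_sequence instead works fully in log space (via logcumsumexp) for numerical stability.

```python
import torch

def scan_sequential(log_a, b):
    # Reference loop: h_t = exp(log_a_t) * h_{t-1} + b_t, with h_0 = 0, scanning over dim 0
    h = torch.zeros_like(b[0])
    out = []
    for t in range(b.shape[0]):
        h = torch.exp(log_a[t]) * h + b[t]
        out.append(h)
    return torch.stack(out)

def scan_cumsum(log_a, b):
    # Closed form: h_t = P_t * sum_{s<=t} b_s / P_s, where P_t = prod_{k<=t} a_k.
    # Naive (not log-space) version: exp(-log_p) can overflow for long sequences.
    log_p = torch.cumsum(log_a, dim=0)  # log P_t
    return torch.exp(log_p) * torch.cumsum(b * torch.exp(-log_p), dim=0)

# Usage: both functions agree on random decays in (0, 1] and random inputs
torch.manual_seed(0)
log_a = -0.5 * torch.rand(16, 4, dtype=torch.float64)
b = torch.randn(16, 4, dtype=torch.float64)
print(torch.allclose(scan_sequential(log_a, b), scan_cumsum(log_a, b)))
```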

Does NOT include:

  • The recurrent mode of the network intended for inference. The demo code (sentence generation) effectively runs the network as it would during the training forward pass, which is much slower than the recurrent mode.
  • Kernel fusion. This repo makes no attempt to fuse the selective scan operations with the other dense operations, so all internal states of the model are explicitly materialized and memory usage may become a bottleneck.
  • Proper parameter initialization (though this could be added without sacrificing readability)
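To give a sense of scale for the memory point above, here is a back-of-the-envelope helper. It assumes, as in a straightforward unfused PyTorch selective scan, that tensors such as the discretized A and B·x terms are held in memory at shape (batch, seq_len, d_inner, d_state); the function name and the dimension values used below are illustrative, not taken from this repo.

```python
def materialized_state_bytes(batch, seq_len, d_inner, d_state, bytes_per_elem=4):
    # Size of ONE tensor of shape (batch, seq_len, d_inner, d_state);
    # bytes_per_elem=4 assumes float32. An unfused scan may hold several such tensors per layer.
    return batch * seq_len * d_inner * d_state * bytes_per_elem

# Illustrative (assumed) dimensions: batch 1, 2048 tokens, d_inner 2048, d_state 16
size = materialized_state_bytes(1, 2048, 2048, 16)
print(size / 2**20)  # 256.0 (MiB, for a single tensor in a single layer)
```

Since the cost grows linearly in sequence length and is paid per layer, long contexts quickly dominate memory in an unfused implementation.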

Demo

See demo.ipynb for examples of prompt completions.

from model import Mamba
from transformers import AutoTokenizer

model = Mamba.from_pretrained('state-spaces/mamba-370m')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

# `generate` is not imported above; it is assumed to be the helper defined in demo.ipynb
generate(model, tokenizer, 'Mamba is the')

Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)

150 meters... 🫢 scary!
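For readers without the notebook at hand, a minimal greedy decoding loop might look like the sketch below. It is hypothetical and assumes that `model(ids)` returns logits shaped (batch, seq_len, vocab_size); the notebook's `generate` may differ (for example, it may sample instead of taking the argmax). Note how every step re-runs the full sequence, which is exactly the non-recurrent behavior described above.

```python
import torch

@torch.no_grad()
def greedy_generate_ids(model, input_ids, n_new_tokens):
    # Assumption: model(ids) -> logits of shape (batch, seq_len, vocab_size).
    # No recurrent state is cached, so step t reprocesses all t tokens.
    ids = input_ids
    for _ in range(n_new_tokens):
        logits = model(ids)                                     # (b, l, vocab)
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)    # (b, 1)
        ids = torch.cat([ids, next_id], dim=1)
    return ids

# Usage with the model and tokenizer from the demo above (not run here):
# ids = tokenizer('Mamba is the', return_tensors='pt').input_ids
# print(tokenizer.decode(greedy_generate_ids(model, ids, 40)[0]))
```

A recurrent-mode implementation would instead carry the hidden state forward and process only the newest token at each step.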

References

The Mamba architecture was introduced by Albert Gu and Tri Dao. The official implementation is here: https://github.com/state-spaces/mamba/tree/main

Related works using parallel scans in log-space: