Feb. 9, 2024, 12:39 a.m.

[Mamba](https://arxiv.org/abs/2312.00752) is a state space model with data-dependent coefficients. It was originally trained with an [associative scan](https://en.wikipedia.org/wiki/Prefix_sum), which is currently [not supported directly by PyTorch](https://github.com/pytorch/pytorch/issues/95408). This is why the authors wrote custom CUDA kernels for it (which has the additional benefit of kernel fusion). To simplify things, someone wrote a [minimal version of Mamba](https://github.com/johnma2006/mamba-minimal) in a single file, where the associative scan is replaced by a for loop. This trades efficiency for simplicity of implementation.
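To make the difference between the two approaches concrete, here is a minimal sketch in plain Python. It assumes the recurrence is reduced to scalars, `h_t = a_t * h_{t-1} + b_t` with `h_{-1} = 0`. In Mamba, `a_t` and `b_t` roughly correspond to the discretized, input-dependent state and input terms, which are tensors rather than scalars. The function names (`sequential_scan`, `combine`, `associative_scan`) are hypothetical and are not taken from either the official kernels or mamba-minimal. The scan shown is a simple Hillis-Steele inclusive scan, one of several ways to implement an associative scan.

```python
from typing import List, Sequence, Tuple

# Hypothetical helpers for illustration; scalars stand in for Mamba's tensors.
Pair = Tuple[float, float]


def sequential_scan(a: Sequence[float], b: Sequence[float]) -> List[float]:
    """Compute h_t = a_t * h_{t-1} + b_t (h_{-1} = 0) with a plain for loop,
    as in a minimal implementation: O(n) steps, each depending on the last."""
    h = 0.0
    out = []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


def combine(left: Pair, right: Pair) -> Pair:
    """Associative operator: each pair (a, b) is the affine map h -> a*h + b.
    Applying `left` first and then `right` gives h -> a_r*(a_l*h + b_l) + b_r."""
    a_l, b_l = left
    a_r, b_r = right
    return (a_r * a_l, a_r * b_l + b_r)


def associative_scan(a: Sequence[float], b: Sequence[float]) -> List[float]:
    """Hillis-Steele inclusive scan with `combine`: about log2(n) rounds, and
    within each round every position can be updated independently (in parallel)."""
    elems: List[Pair] = list(zip(a, b))
    n = len(elems)
    step = 1
    while step < n:
        # Each round reads only the previous round's values, so the list
        # comprehension mirrors one parallel step.
        elems = [
            elems[i] if i < step else combine(elems[i - step], elems[i])
            for i in range(n)
        ]
        step *= 2
    # elems[i] now composes maps 0..i; applied to h_{-1} = 0 it yields b.
    return [b_t for _, b_t in elems]


if __name__ == "__main__":
    a = [0.5, 0.5, 0.5]
    b = [1.0, 1.0, 1.0]
    print(sequential_scan(a, b))   # [1.0, 1.5, 1.75]
    print(associative_scan(a, b))  # [1.0, 1.5, 1.75]
```

Both functions return the same hidden states. The point of the associative form is that the rounds have no sequential dependency between positions, which is what makes a parallel GPU kernel possible, whereas the for loop must run one time step after another.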

However, I think there is a way …

