The attention operation has a memory bottleneck between HBM and SRAM
HBM has large capacity but low bandwidth, while SRAM is much smaller but far faster. Flash attention reduces the GPU's memory IO at the cost of a slight increase in computation.
Instead of materializing the full attention matrix, you build small blocks and keep running statistics to compute the softmax along the way, so each small part of the matrix is computed entirely in SRAM.
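A minimal NumPy sketch of the running-statistics idea, under my own naming (`online_softmax_stats`, block count) rather than the paper's exact notation: a running max and a running sum of exponentials are enough to normalize the softmax after seeing each block only once.

```python
import numpy as np

# Running max (m) and running sum of exponentials (l) let each block be
# processed once; earlier contributions are rescaled whenever a larger max appears.
def online_softmax_stats(score_blocks):
    m, l = -np.inf, 0.0
    for block in score_blocks:
        m_new = max(m, block.max())
        l = l * np.exp(m - m_new) + np.exp(block - m_new).sum()
        m = m_new
    return m, l  # softmax(x_i) = exp(x_i - m) / l

row = np.random.randn(64)
m, l = online_softmax_stats(np.split(row, 8))
reference = np.exp(row - row.max()) / np.exp(row - row.max()).sum()
assert np.allclose(np.exp(row - m) / l, reference)
```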
- Compute-bound computation: matrix multiplication
- Memory-bound computation: softmax, activation, batch/layer normalization
GPU architecture is optimized for compute-bound operations; flash attention instead targets the memory-bound parts by reducing memory traffic.
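A back-of-the-envelope calculation of why the softmax over the attention matrix is memory-bound; the sequence length, FLOP count per element, and A100 specs below are illustrative assumptions, not measurements from the paper.

```python
# Arithmetic intensity of softmax over an N x N fp16 score matrix (one head).
N = 4096
bytes_moved = N * N * 2        # read the score matrix once: ~32 MB
flops = N * N * 5              # roughly max, subtract, exp, sum, divide per element
intensity = flops / bytes_moved

# An A100 offers ~312 TFLOP/s (fp16 tensor cores) with ~2 TB/s of HBM bandwidth,
# i.e. it needs ~150 FLOPs per byte to stay compute-bound; softmax delivers only a few.
print(f"arithmetic intensity ≈ {intensity:.1f} FLOPs/byte")
```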
Tiling
- Split the attention computation into blocks whose size is chosen to fit in SRAM
- Load each block from HBM to SRAM only once
- Apply all of the computation in SRAM and write the result back to HBM
- Repeat this per block and aggregate the results via kernel fusion (see the sketch below)
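A NumPy sketch of the tiling idea (not a fused CUDA kernel): the inner loop plays the role of the work done in SRAM, and the running max `m` and denominator `l` rescale the partial output so each K/V block is read exactly once. The function name and block sizes are my own choices.

```python
import numpy as np

def tiled_attention(Q, K, V, block_q=64, block_k=64):
    """Block-wise attention in the spirit of Flash Attention (NumPy sketch)."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros_like(Q)

    for qs in range(0, N, block_q):
        Qi = Q[qs:qs + block_q] * scale
        m = np.full(Qi.shape[0], -np.inf)   # running row max
        l = np.zeros(Qi.shape[0])           # running softmax denominator
        Oi = np.zeros((Qi.shape[0], d))     # unnormalized partial output

        for ks in range(0, N, block_k):
            Kj, Vj = K[ks:ks + block_k], V[ks:ks + block_k]
            S = Qi @ Kj.T                           # score block ("in SRAM")
            m_new = np.maximum(m, S.max(axis=1))
            alpha = np.exp(m - m_new)               # rescale earlier contributions
            P = np.exp(S - m_new[:, None])
            l = l * alpha + P.sum(axis=1)
            Oi = Oi * alpha[:, None] + P @ Vj
            m = m_new

        O[qs:qs + block_q] = Oi / l[:, None]        # normalize once at the end
    return O

# Matches vanilla attention computed with the full N x N matrix
N, d = 256, 32
Q, K, V = (np.random.randn(N, d) for _ in range(3))
S = (Q @ K.T) / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(tiled_attention(Q, K, V), ref)
```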
Recomputation
Like gradient checkpointing, flash attention does not save and reuse the intermediate attention matrix; to reduce IO, it recomputes it in the backward pass from the saved softmax statistics.
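A small sketch of what that means in practice, under the assumption that the forward pass keeps only the per-row logsumexp (here called `L`) alongside the output: the attention probabilities can then be rebuilt block by block during backward instead of being stored.

```python
import numpy as np

N, d = 128, 32
Q, K = np.random.randn(N, d), np.random.randn(N, d)
S = (Q @ K.T) / np.sqrt(d)
# Forward saves only the per-row softmax statistics (logsumexp), not the N x N matrix.
L = S.max(axis=1) + np.log(np.exp(S - S.max(axis=1, keepdims=True)).sum(axis=1))

# Backward recomputes a block of probabilities on the fly
# (the score block itself would be recomputed from Q @ K[block].T).
block = slice(0, 64)
P_block = np.exp(S[:, block] - L[:, None])

P_full = np.exp(S - S.max(axis=1, keepdims=True))
P_full /= P_full.sum(axis=1, keepdims=True)
assert np.allclose(P_block, P_full[:, block])
```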
Conclusion
Traditional attention loads keys, queries, and values from HBM to GPU on-chip SRAM, performs a single step of the attention mechanism, writes the result back to HBM, and repeats this for every step. Flash Attention instead loads keys, queries, and values once, fuses the operations of the attention mechanism, and only then writes the output back.
Flash Attention usage
It is a very effective method, but flash attention requires a CUDA implementation that must be compiled for each target device.
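One way to use a flash attention kernel without writing or compiling CUDA yourself is PyTorch's built-in `scaled_dot_product_attention`; the sketch below assumes a recent PyTorch release (where `torch.nn.attention.sdpa_kernel` exists) and a supported GPU/dtype/head-dim combination.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Flash attention kernels generally require fp16/bf16 tensors on a CUDA device.
q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3))

# Restrict SDPA to the flash attention backend; PyTorch errors out if the
# device/dtype/head-dim combination is not supported by that kernel.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # (1, 8, 1024, 64)
```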
Paper
Windows
Device support (compute capability)
RuntimeError: FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800
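A hedged way to guard against this at runtime: check the GPU's compute capability before relying on large head dimensions. The threshold of 192 comes from the error above; the capability values are standard (A100/A800 report (8, 0), H100/H800 report (9, 0)), but the check itself is just an illustrative helper.

```python
import torch

head_dim = 256  # example head dimension

# A100/A800 are compute capability (8, 0); H100/H800 are (9, 0).
# Consumer sm_86 / sm_89 cards are excluded by the error above.
cap = torch.cuda.get_device_capability()
supports_large_head_dim = cap == (8, 0) or cap[0] >= 9

if head_dim > 192 and not supports_large_head_dim:
    print("This GPU cannot run the FlashAttention backward pass for head dim > 192")
```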