Flash Attention

Creator
Seonglae Cho
Created
2023 Jun 29 14:42
Edited
2024 Jul 6 3:11
Refs
CUDA PTX

The attention operation has a memory bottleneck between HBM and SRAM

HBM has a large capacity but is slow to access, while SRAM is much smaller but far faster. Flash Attention therefore reduces the GPU's memory IO between HBM and SRAM, at the cost of a slight increase in computation.
The attention matrix is never materialized: the computation is split into small blocks, and running statistics are kept to compute the softmax along the way, so each small block of the matrix is computed entirely in SRAM.
  • Compute-bound operations: matrix multiplication
  • Memory-bound operations: softmax, activation, batch/layer normalization
GPU architectures are optimized for compute-bound operations, so Flash Attention instead optimizes the memory-bound ones.
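For intuition, the softmax of a long vector can be computed block by block while keeping only a running maximum and a running sum of exponentials. Below is a minimal NumPy sketch of this online-softmax idea; the function and variable names are illustrative, not FlashAttention's actual kernel code.

```python
import numpy as np

def online_softmax(blocks):
    """Softmax over concatenated blocks without holding all scores at once."""
    m, l = -np.inf, 0.0            # running max and running sum of exponentials
    for block in blocks:
        m_new = max(m, block.max())
        # rescale the old sum to the new maximum before adding the new block
        l = l * np.exp(m - m_new) + np.exp(block - m_new).sum()
        m = m_new
    # a second pass is used here for clarity; FlashAttention instead folds the
    # rescaling into the output accumulation so no extra pass is needed
    return [np.exp(block - m) / l for block in blocks]

x = np.random.randn(1024)
ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(np.concatenate(online_softmax(np.split(x, 8))), ref)
```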
Accelerating Attention using Tiling and Recomputation

Tiling

  1. Split the attention computation into blocks sized to fit in SRAM
  2. Load each block from HBM to SRAM once
  3. Apply all computation in SRAM and write the result back to HBM
  4. Repeat this per block and aggregate the results via kernel fusion, as in the sketch below
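Combining the tiling loop with the running softmax statistics, a rough NumPy sketch of the forward pass over K/V blocks could look like the following. The block size `Bk` and all variable names are assumptions for illustration; the real implementation is a fused CUDA kernel operating on SRAM tiles.

```python
import numpy as np

def flash_attention_forward(Q, K, V, Bk=64):
    """Tiled attention: stream K/V blocks, keep running softmax statistics."""
    scale = 1.0 / np.sqrt(Q.shape[-1])
    O = np.zeros_like(Q)                     # unnormalized output accumulator
    m = np.full(Q.shape[0], -np.inf)         # running row maxima
    l = np.zeros(Q.shape[0])                 # running row sums of exponentials
    for start in range(0, K.shape[0], Bk):   # one K/V block per iteration
        Kb, Vb = K[start:start + Bk], V[start:start + Bk]
        S = Q @ Kb.T * scale                 # scores for this block only
        m_new = np.maximum(m, S.max(axis=1))
        P = np.exp(S - m_new[:, None])
        alpha = np.exp(m - m_new)            # rescale previous partial results
        l = l * alpha + P.sum(axis=1)
        O = O * alpha[:, None] + P @ Vb
        m = m_new
    return O / l[:, None]

# check against the naive attention that materializes the full score matrix
Q = np.random.randn(128, 64)
K = np.random.randn(256, 64)
V = np.random.randn(256, 64)
S = Q @ K.T / np.sqrt(64)
P = np.exp(S - S.max(axis=1, keepdims=True))
assert np.allclose(flash_attention_forward(Q, K, V),
                   (P / P.sum(axis=1, keepdims=True)) @ V)
```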

Recomputation

To reduce IO, Flash Attention does not save and reuse intermediate matrices; like gradient checkpointing, it recomputes them during the backward pass.
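The same memory-for-compute trade can be seen with PyTorch's generic gradient checkpointing API; this sketch illustrates the general idea, not FlashAttention's kernel.

```python
import torch
from torch.utils.checkpoint import checkpoint

# a toy module standing in for any expensive sub-network
mlp = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024)
)
x = torch.randn(8, 1024, requires_grad=True)

# Intermediate activations inside `mlp` are not stored; they are recomputed
# during the backward pass, trading extra FLOPs for lower memory use.
y = checkpoint(mlp, x, use_reentrant=False)
y.sum().backward()
```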

Conclusion

Traditional attention loads keys, queries, and values from HBM to GPU on-chip SRAM, performs a single step of the attention mechanism, writes the result back to HBM, and repeats this for every step. Flash Attention instead loads keys, queries, and values once, fuses the operations of the attention mechanism, and writes the output back to HBM a single time.
https://huggingface.co/docs/text-generation-inference/conceptual/flash_attention
Flash Attention usage
It is a very effective method, but Flash Attention requires a CUDA implementation compiled for each device.
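For instance, recent PyTorch releases expose fused attention kernels behind `torch.nn.functional.scaled_dot_product_attention`, so a model can often benefit without building the separate flash-attn package; whether the FlashAttention backend is actually selected depends on the device, dtype, and head dimension.

```python
import torch
import torch.nn.functional as F

# (batch, heads, sequence length, head dimension); fp16 on a CUDA device
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

# PyTorch dispatches to a FlashAttention-style fused kernel when supported,
# and silently falls back to other implementations otherwise.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```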
 

Paper

Windows

Device support

RuntimeError: FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800