The attention operation has a memory bottleneck between HBM and SRAM
HBM has large capacity but relatively low bandwidth, while SRAM is much smaller but far faster. Flash Attention therefore reduces GPU memory I/O at the cost of a slight increase in computation.
The attention matrix is never materialized: the computation is split into small blocks, and running statistics are kept so the softmax can be computed along the way. Each block is then processed entirely in SRAM.
- Compute-bound computation: matrix multiplication
- Memory-bound computation: softmax, activation, batch/layer normalization
GPU architectures are optimized for compute-bound workloads; Flash Attention restructures attention so that the memory-bound parts are optimized as well, as sketched below.
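A minimal NumPy sketch of the running-statistics (online softmax) idea mentioned above: the row maximum m and normalizer l are updated chunk by chunk, and earlier partial results are rescaled whenever the maximum changes, so the full row of scores is never needed at once. Names like `online_softmax` are illustrative, not from the paper's code.
```python
import numpy as np

def online_softmax(chunks):
    """Softmax over a row of scores that arrives chunk by chunk.
    Keeps only a running max (m) and normalizer (l); earlier partial
    results are rescaled whenever the running max changes."""
    m = -np.inf                      # running row maximum
    l = 0.0                          # running sum of exp(x - m)
    parts = []                       # un-normalized exp(x - m) per chunk
    for x in chunks:
        m_new = max(m, x.max())
        scale = np.exp(m - m_new)    # correction factor for the new maximum
        l = l * scale + np.exp(x - m_new).sum()
        parts = [p * scale for p in parts]
        parts.append(np.exp(x - m_new))
        m = m_new
    return np.concatenate(parts) / l

# agrees with the safe softmax computed in one shot
x = np.random.randn(16)
ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(online_softmax(np.split(x, 4)), ref)
```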

Tiling
- Split the attention computation into blocks sized to fit in SRAM
- Load each block from HBM into SRAM only once
- Perform all of the block's computation in SRAM, then write the result back to HBM
- Repeat this per block and aggregate the results via kernel fusion (see the sketch below)
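A single-head sketch of the tiling above (NumPy, illustrative block size): K and V are visited one tile at a time, and the output accumulator is rescaled with the running softmax statistics, so the full N×N score matrix is never materialized. This mirrors the forward recurrence only, not the actual Triton/CUDA kernel.
```python
import numpy as np

def flash_attention_forward(Q, K, V, block=32):
    """Block-wise attention forward pass; each K/V tile would live in SRAM,
    here tiles are just array slices."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros_like(Q)                     # output accumulator
    m = np.full(N, -np.inf)                  # running row max
    l = np.zeros(N)                          # running softmax normalizer
    for j in range(0, N, block):
        Kj, Vj = K[j:j+block], V[j:j+block]  # one tile "loaded into SRAM"
        S = Q @ Kj.T * scale                 # scores for this tile only
        m_new = np.maximum(m, S.max(axis=1))
        corr = np.exp(m - m_new)             # rescale old statistics/output
        P = np.exp(S - m_new[:, None])
        l = l * corr + P.sum(axis=1)
        O = O * corr[:, None] + P @ Vj
        m = m_new
    return O / l[:, None]

# agrees with naive attention that materializes the full score matrix
N, d = 128, 64
Q, K, V = (np.random.randn(N, d) for _ in range(3))
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(flash_attention_forward(Q, K, V), ref)
```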
Recomputation
Like gradient checkpointing, Flash Attention does not store the intermediate attention matrix for reuse; it recomputes it during the backward pass, trading a little extra compute for much less memory I/O.
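The trade-off is the same one PyTorch's generic gradient checkpointing makes: do not keep an intermediate, recompute it when gradients are needed. The sketch below uses `torch.utils.checkpoint` purely to illustrate that trade-off; Flash Attention itself recomputes the attention matrix block by block inside its fused backward kernel rather than through this API.
```python
import torch
from torch.utils.checkpoint import checkpoint

def attention_block(q, k, v):
    # the intermediate softmax(QK^T) matrix is not kept for backward;
    # under checkpointing it is rebuilt when gradients are computed
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v

q, k, v = (torch.randn(128, 64, requires_grad=True) for _ in range(3))
out = checkpoint(attention_block, q, k, v, use_reentrant=False)
out.sum().backward()  # attention_block runs again here to rebuild intermediates
```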
Conclusion
Traditional attention loads queries, keys, and values from HBM into GPU on-chip SRAM, performs a single step of the attention computation, writes the result back to HBM, and repeats this for every step. Flash Attention instead loads queries, keys, and values once, fuses all of the attention operations, and writes only the final output back to HBM.

Flash Attention usages
It is a very effective method, but Flash Attention requires a CUDA implementation that has to be compiled for each device.
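In practice you rarely compile the kernels yourself: recent PyTorch versions (2.3+ for the `torch.nn.attention` API used below) ship a Flash Attention backend behind `scaled_dot_product_attention`. A hedged sketch; the backend requires a supported GPU, dtype, and head dimension, and PyTorch falls back or raises otherwise.
```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# (batch, heads, seq_len, head_dim) in fp16 on a CUDA device
q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

# explicitly request the Flash Attention backend
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```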
Paper
Flash Attention derived and coded from first principles with Triton (Python)
In this video, I'll be deriving and coding Flash Attention from scratch.
I'll be deriving every operation we do in Flash Attention using only pen and "paper". Moreover, I'll explain CUDA and Triton from zero, so no prior knowledge of CUDA is required. To code the backwards pass, I'll first explain how the autograd system works in PyTorch and then derive the Jacobian of the matrix multiplication and the Softmax operation and use it to code the backwards pass.
All the code will be written in Python with Triton, but no prior knowledge of Triton is required. I'll also explain the CUDA programming model from zero.
Chapters
00:00:00 - Introduction
00:03:10 - Multi-Head Attention
00:09:06 - Why Flash Attention
00:12:50 - Safe Softmax
00:27:03 - Online Softmax
00:39:44 - Online Softmax (Proof)
00:47:26 - Block Matrix Multiplication
01:28:38 - Flash Attention forward (by hand)
01:44:01 - Flash Attention forward (paper)
01:50:53 - Intro to CUDA with examples
02:26:28 - Tensor Layouts
02:40:48 - Intro to Triton with examples
02:54:26 - Flash Attention forward (coding)
04:22:11 - LogSumExp trick in Flash Attention 2
04:32:53 - Derivatives, gradients, Jacobians
04:45:54 - Autograd
05:00:00 - Jacobian of the MatMul operation
05:16:14 - Jacobian through the Softmax
05:47:33 - Flash Attention backwards (paper)
06:13:11 - Flash Attention backwards (coding)
07:21:10 - Triton Autotuning
07:23:29 - Triton tricks: software pipelining
07:33:38 - Running the code
This video won't only teach you one of the most influential algorithms in deep learning history; it'll also give you the knowledge you need to solve any new problem that involves writing CUDA or Triton kernels. Moreover, it'll give you the mathematical foundations to derive backwards passes!
As usual, the code is available on GitHub: https://github.com/hkproj/triton-flash-attention
https://www.youtube.com/watch?v=zy8ChVd_oTM
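For reference, the softmax Jacobian derived in the video has a compact closed form; with s = softmax(x):
```latex
\frac{\partial s_i}{\partial x_j} = s_i\,(\delta_{ij} - s_j)
\qquad\Longleftrightarrow\qquad
J_{\operatorname{softmax}}(x) = \operatorname{diag}(s) - s\,s^{\top}
% vector-Jacobian product used when coding the backward pass:
% J^{\top} g = s \odot \big(g - (s^{\top} g)\,\mathbf{1}\big)
```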

Windows
Device support
RuntimeError: FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800
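A rough guard mirroring that error message (assuming PyTorch; A100/A800 report compute capability (8, 0) and H100/H800 report (9, 0)). This is only a heuristic sketch, not the library's own capability check.
```python
import torch

def flash_backward_supported(head_dim: int) -> bool:
    """Heuristic for the error above: the backward pass with head_dim > 192
    needs an A100/A800 (sm80) or H100/H800 (sm90) class GPU."""
    if not torch.cuda.is_available():
        return False
    cap = torch.cuda.get_device_capability()
    return head_dim <= 192 or cap in ((8, 0), (9, 0))

print(flash_backward_supported(256))
```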
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness | Wonbeom Jang
Optimizing transformers on GPU devices
https://www.wonbeomjang.kr/blog/2023/fastattention/
A wrong Flash Attention implementation breaks the model
Flash Attention 4
We reverse-engineered Flash Attention 4
Asynchrony, fast approximate exponents, and 10x more efficient softmax.
https://modal.com/blog/reverse-engineer-flash-attention-4
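A toy illustration of the "fast approximate exponents" idea: split off the integer part of the exponent and approximate 2**frac with a low-degree polynomial. The cubic here is simply least-squares fitted for the sketch; it only shows the general technique, not FA4's actual kernel code.
```python
import numpy as np

# fit a cubic to 2**f on [0, 1) as a stand-in for a hand-tuned polynomial
f_grid = np.linspace(0.0, 1.0, 256)
coeffs = np.polyfit(f_grid, np.exp2(f_grid), deg=3)

def approx_exp2(x):
    """exp2 via range reduction: 2**x = 2**floor(x) * 2**frac(x),
    with the fractional part handled by the fitted cubic."""
    n = np.floor(x)
    f = x - n
    return np.ldexp(np.polyval(coeffs, f), n.astype(int))

# softmax arguments are <= 0 after subtracting the row max
x = np.random.uniform(-20, 0, 1000)
rel_err = np.abs(approx_exp2(x) - np.exp2(x)) / np.exp2(x)
print(rel_err.max())  # small; attention weights tolerate low precision
```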

Lecture 80: How FlashAttention 4 Works
Speaker: Charles Frye
The source code (in CuTe) for FlashAttention 4 on Blackwell GPUs has recently been released for the forward pass. The following blog, https://modal.com/blog/reverse-engineer-flash-attention-4, goes over their findings from reading through the source code and the changes between FA1, 2, 3, and now 4!
https://www.youtube.com/watch?v=VPslgC9piIw


Seonglae Cho