Flash Attention

Creator
Seonglae Cho
Created
2023 Jun 29 14:42
Edited
2024 Nov 22 20:51
Refs
CUDA PTX

The attention operation has a memory bottleneck between HBM and SRAM.

HBM
is large but slow to access, while SRAM is much smaller but far faster. Flash Attention therefore reduces GPU memory I/O, even though it requires a slight increase in computation.
The attention matrix is never fully materialized: the computation is split into small blocks, and running statistics are kept to compute the softmax along the way, so each block can be processed entirely in SRAM.
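A minimal sketch of this running-statistics (online softmax) trick in plain PyTorch rather than a fused kernel; the function name `online_softmax` and the chunk size are illustrative only:

```python
import torch

def online_softmax(score_chunks):
    """Softmax statistics for a row that arrives in chunks: keep only a
    running max and a running normalizer instead of the whole row."""
    running_max = torch.tensor(float("-inf"))
    running_sum = torch.tensor(0.0)
    for chunk in score_chunks:
        new_max = torch.maximum(running_max, chunk.max())
        # Rescale the old normalizer to the new max, then add this chunk's terms.
        running_sum = running_sum * torch.exp(running_max - new_max) + torch.exp(chunk - new_max).sum()
        running_max = new_max
    return running_max, running_sum

# Same normalizer as a softmax computed over the full row at once.
row = torch.randn(1024)
m, s = online_softmax(row.split(128))
assert torch.allclose(s, torch.exp(row - row.max()).sum())
```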
  • Compute-bound computation: matrix multiplication
  • Memory-bound computation: softmax, activation, batch/layer normalization
GPU architectures are optimized for compute-bound operations, so Flash Attention forces the optimization of the memory-bound ones instead.
Accelerating Attention using Tiling and Recomputation

Tiling

  1. Split the attention computation into blocks sized to fit in SRAM
  2. Load each block from HBM into SRAM once
  3. Perform all computation for that block in SRAM and write the result back to HBM
  4. Repeat per block and aggregate the results via kernel fusion (see the sketch below)
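A readable PyTorch sketch of this per-block loop (single head, no masking; block size and names chosen for illustration). A real Flash Attention kernel performs the same rescaling, but fused inside SRAM in a CUDA/Triton kernel:

```python
import torch

def tiled_attention(q, k, v, block=128):
    """Single-head attention computed one K/V tile at a time, never
    materializing the full (seq_q x seq_k) attention matrix."""
    seq_q, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((seq_q, 1), float("-inf"))
    row_sum = torch.zeros(seq_q, 1)
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]   # one K/V tile
        s = (q @ kb.T) * scale                                    # (seq_q, block) scores
        new_max = torch.maximum(row_max, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - new_max)                                # this tile's softmax numerator
        rescale = torch.exp(row_max - new_max)                    # correct previously accumulated tiles
        row_sum = row_sum * rescale + p.sum(dim=-1, keepdim=True)
        out = out * rescale + p @ vb
        row_max = new_max
    return out / row_sum

# Matches standard attention computed with the full matrix.
q, k, v = (torch.randn(512, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(tiled_attention(q, k, v), ref, atol=1e-5)
```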

Recomputation

To reduce I/O, Flash Attention does not save and reuse the intermediate attention matrix; instead it recomputes it during the backward pass, similar to
Gradient checkpointing
.
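For comparison, the same store-versus-recompute trade-off expressed with PyTorch's generic gradient checkpointing API; this is not Flash Attention's fused kernel, and the layer and shapes are illustrative:

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.GELU())
x = torch.randn(8, 512, requires_grad=True)

# Activations inside `layer` are not kept for backward; they are recomputed
# during the backward pass, trading extra FLOPs for lower memory use.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
```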

Conclusion

Flash Attention loads keys, queries, and values once, fuses the operations of the attention mechanism, and writes the result back to HBM. Traditional attention, by contrast, loads keys, queries, and values from HBM to the GPU's on-chip SRAM, performs a single step of the attention mechanism, writes it back to HBM, and repeats this for every attention step.
https://huggingface.co/docs/text-generation-inference/conceptual/flash_attention
Flash Attention usages
 
 
Flash Attention is a very effective method, but it requires a CUDA implementation compiled for each device.
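When a prebuilt kernel is available, the simplest way to use it is through PyTorch's SDPA dispatcher rather than compiling flash-attn yourself. A minimal usage sketch, assuming PyTorch 2.x on a CUDA device the flash backend supports (shapes are illustrative; newer releases expose the same switch as torch.nn.attention.sdpa_kernel):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) in fp16 on GPU -- shapes the flash kernel accepts
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

# Restrict the SDPA dispatcher to the Flash Attention kernel only.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```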
 

Paper

Flash Attention derived and coded from first principles with Triton (Python)
In this video, I'll be deriving and coding Flash Attention from scratch. I'll be deriving every operation we do in Flash Attention using only pen and "paper". Moreover, I'll explain CUDA and Triton from zero, so no prior knowledge of CUDA is required. To code the backwards pass, I'll first explain how the autograd system works in PyTorch and then derive the Jacobian of the matrix multiplication and the Softmax operation and use it to code the backwards pass. All the code will be written in Python with Triton, but no prior knowledge of Triton is required. I'll also explain the CUDA programming model from zero.
Chapters
00:00:00 - Introduction
00:03:10 - Multi-Head Attention
00:09:06 - Why Flash Attention
00:12:50 - Safe Softmax
00:27:03 - Online Softmax
00:39:44 - Online Softmax (Proof)
00:47:26 - Block Matrix Multiplication
01:28:38 - Flash Attention forward (by hand)
01:44:01 - Flash Attention forward (paper)
01:50:53 - Intro to CUDA with examples
02:26:28 - Tensor Layouts
02:40:48 - Intro to Triton with examples
02:54:26 - Flash Attention forward (coding)
04:22:11 - LogSumExp trick in Flash Attention 2
04:32:53 - Derivatives, gradients, Jacobians
04:45:54 - Autograd
05:00:00 - Jacobian of the MatMul operation
05:16:14 - Jacobian through the Softmax
05:47:33 - Flash Attention backwards (paper)
06:13:11 - Flash Attention backwards (coding)
07:21:10 - Triton Autotuning
07:23:29 - Triton tricks: software pipelining
07:33:38 - Running the code
This video won't only teach you one of the most influential algorithms in deep learning history; it'll also give you the knowledge you need to solve any new problem that involves writing CUDA or Triton kernels. Moreover, it'll give you the mathematical foundations to derive backwards passes! As usual, the code is available on GitHub: https://github.com/hkproj/triton-flash-attention

Windows

Device support capacity

RuntimeError: FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800
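A small guard for this limitation, assuming PyTorch with CUDA; the helper name is hypothetical and the capability values follow from the error above (A100/A800 report compute capability (8, 0), H100/H800 report (9, 0)):

```python
import torch

def flash_backward_supported(head_dim: int) -> bool:
    # FlashAttention's backward pass for head_dim > 192 only runs on
    # A100/A800 (sm80) or H100/H800 (sm90); consumer Ampere/Ada cards
    # such as (8, 6) or (8, 9) raise the RuntimeError above.
    major, minor = torch.cuda.get_device_capability()
    return head_dim <= 192 or (major, minor) in {(8, 0), (9, 0)}

# Example: decide whether to fall back to another attention backend.
if not flash_backward_supported(head_dim=256):
    print("Flash backward unavailable for this head dim on this GPU")
```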
 
 
 
