Hardware-Aligned and Natively Trainable Sparse Attention
Dynamic hierarchical sparse strategy combining:
- coarse-grained token compression (context-aware compressed attention)
- fine-grained token selection
to preserve both global context awareness and local precision
Training-aware design with efficient backward operators enables NSA to support both efficient deployment and end-to-end training.
Key results:
- NSA maintains or exceeds the performance of Full Attention models
- NSA achieves substantial speedups over Full Attention on 64k-length sequences, across decoding, forward, and backward passes

Three parallel attention branches:
For a given query $q_t$, the preceding keys and values are processed into compressed attention for coarse-grained patterns, selected attention for important token blocks, and sliding-window attention for local context. The remapped keys and values $\tilde{K}_t^c, \tilde{V}_t^c$ of each branch are dynamically constructed based on the current query $q_t$ and the contextual memory $k_{:t}, v_{:t}$, and the branch outputs are combined through learned gates:

$$o_t^* = \sum_{c \in \mathcal{C}} g_t^c \cdot \mathrm{Attn}\!\left(q_t, \tilde{K}_t^c, \tilde{V}_t^c\right), \qquad \mathcal{C} = \{\mathrm{cmp}, \mathrm{slc}, \mathrm{win}\}$$

where $g_t^c \in [0, 1]$ is the gate score for the corresponding strategy $c$, computed from the input features via an MLP with sigmoid activation.
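As a concrete illustration, here is a minimal PyTorch-style sketch of the gated combination of the three branch outputs. The module name, the gate MLP's hidden layer, and the SiLU activation are assumptions for illustration; the text above only specifies sigmoid-activated gates produced by an MLP.

```python
import torch
import torch.nn as nn

class GatedBranchCombine(nn.Module):
    """Combine compression / selection / sliding-window branch outputs with
    per-token gate scores in [0, 1] (sketch, not the official NSA code)."""

    def __init__(self, d_model: int, n_branches: int = 3):
        super().__init__()
        # Gate MLP: maps each query token's input features to one sigmoid gate per branch
        self.gate_mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, n_branches),
        )

    def forward(self, x, branch_outputs):
        # x:              [batch, seq, d_model]  query-token input features
        # branch_outputs: list of n_branches tensors, each [batch, seq, d_model]
        gates = torch.sigmoid(self.gate_mlp(x))            # [batch, seq, n_branches]
        stacked = torch.stack(branch_outputs, dim=-1)      # [batch, seq, d_model, n_branches]
        return (stacked * gates.unsqueeze(2)).sum(dim=-1)  # gate-weighted sum over branches
```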
1. Token Compression
By aggregating sequential blocks of keys or values into block-level representations, compressed keys and values capture the information of the entire block.
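A minimal sketch of such block-level compression, assuming a block length `l`, a stride `d` between adjacent blocks, and a small MLP with a learned intra-block position embedding standing in for the compression function; the layer sizes and activation are illustrative:

```python
import torch
import torch.nn as nn

class BlockCompressor(nn.Module):
    """Compress overlapping blocks of keys (or values) into one vector per block (sketch)."""

    def __init__(self, d_head: int, block_len: int = 32, stride: int = 16):
        super().__init__()
        self.l, self.d = block_len, stride
        # Learned intra-block position embedding, added before compression
        self.pos = nn.Parameter(torch.zeros(block_len, d_head))
        self.mlp = nn.Sequential(
            nn.Linear(block_len * d_head, d_head),
            nn.SiLU(),
            nn.Linear(d_head, d_head),
        )

    def forward(self, k):
        # k: [batch, seq, d_head]  ->  [batch, n_blocks, d_head]
        # unfold produces overlapping windows of length l with stride d
        blocks = k.unfold(1, self.l, self.d)           # [B, n_blocks, d_head, l]
        blocks = blocks.permute(0, 1, 3, 2) + self.pos  # [B, n_blocks, l, d_head]
        return self.mlp(blocks.flatten(-2))             # [B, n_blocks, d_head]
```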
2. Token Selection
Blockwise selection is crucial for two reasons: it enables efficient computation on modern GPUs, and it follows the inherent blockwise distribution patterns of attention scores.
Blockwise sparse attention is optimized for Tensor Core utilization and memory access, ensuring balanced arithmetic intensity.
Importance Score Computation
Block importance scores are derived from the attention scores the query already computes against the compressed keys, so selection requires little additional computation.
Top-𝑛 Block Selection
After obtaining the block importance scores, we retain the tokens within the top-𝑛 sparse blocks ranked by those scores.
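A sketch of both steps for a single query position, assuming for simplicity that compression blocks and selection blocks use the same blocking (so compressed-attention scores map one-to-one onto candidate selection blocks) and that all heads in a GQA group share one selection; function and variable names are illustrative:

```python
import torch
import torch.nn.functional as F

def select_topn_blocks(q, k_cmp, n_blocks_to_keep, scale=None):
    """Rank candidate KV blocks for one query via compressed-key attention scores.

    q:      [n_heads, d_head]   query heads at the current position
    k_cmp:  [n_blocks, d_head]  one compressed key per candidate block
    Returns indices of the top-n blocks (sketch; the real kernel works blockwise
    on GPU and shares one selection across all heads of a GQA group)."""
    scale = scale or q.shape[-1] ** -0.5
    # Attention of each head over the compressed blocks
    scores = F.softmax(q @ k_cmp.T * scale, dim=-1)   # [n_heads, n_blocks]
    # Aggregate across heads so the whole group shares a single selection
    block_importance = scores.sum(dim=0)               # [n_blocks]
    n = min(n_blocks_to_keep, block_importance.shape[0])
    return torch.topk(block_importance, n).indices

# Example: keep the 4 most important of 16 candidate blocks for an 8-head query
q = torch.randn(8, 64)
k_cmp = torch.randn(16, 64)
kept = select_topn_blocks(q, k_cmp, n_blocks_to_keep=4)
```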
3. Sliding Window Attention
A dedicated branch attends only to the most recent tokens within a fixed window, explicitly handling local context so that the compression and selection branches are not short-circuited by local patterns.
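A minimal sketch of this branch as plain causal attention restricted to the last `window` tokens (the window size and helper name are illustrative; in practice this branch can run on existing FlashAttention-style kernels, per the kernel discussion below):

```python
import torch
import torch.nn.functional as F

def sliding_window_attention(q, k, v, window: int = 512):
    """Causal attention where each query sees only the last `window` keys (sketch)."""
    # q, k, v: [batch, n_heads, seq, d_head]
    seq = q.shape[-2]
    idx = torch.arange(seq, device=q.device)
    # position j is visible to position i iff  i - window < j <= i
    mask = (idx[None, :] <= idx[:, None]) & (idx[None, :] > idx[:, None] - window)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```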
Kernel Design
While compression and sliding window attention are compatible with existing FlashAttention-style kernels, the authors introduce a specialized kernel design for sparse selection attention. FlashAttention's strategy of loading temporally continuous query blocks into SRAM would result in inefficient memory access here, since queries within a block may require disjoint KV blocks.
To address this, a different query grouping strategy is used: for each position in the query sequence, all query heads within a GQA (Grouped-Query Attention) group, which share the same sparse KV blocks, are loaded into SRAM together (see the loop-structure sketch after the list below).
- Group-Centric Data Loading
- Shared KV Fetching
- Outer Loop on Grid
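To make this loop structure concrete, here is a plain-PyTorch reference (not a GPU kernel; all names and shapes are illustrative) that mirrors the three points above: the outer loop plays the role of the kernel grid over (GQA group, query position), all query heads of a group are loaded together, and only that position's selected KV blocks are fetched and shared across those heads:

```python
import torch

def selection_attention_reference(q, k, v, block_ids, block_size):
    """Mirror of the kernel's loop structure in plain PyTorch (NOT a GPU kernel).

    q:         [n_groups, heads_per_group, seq, d_head]  queries
    k, v:      [n_groups, seq, d_head]                   one shared KV per GQA group
    block_ids: [n_groups, seq, n_sel]                     top-n selected block indices
    """
    n_groups, h, seq, d = q.shape
    out = torch.zeros_like(q)
    scale = d ** -0.5
    # "Outer Loop on Grid": one program instance per (GQA group, query position)
    for g in range(n_groups):
        for t in range(seq):
            # "Group-Centric Data Loading": all heads' queries at this position at once
            q_t = q[g, :, t]                                   # [h, d]
            # "Shared KV Fetching": gather only the selected, causally valid KV blocks
            cols = []
            for b in block_ids[g, t].tolist():
                start = b * block_size
                idx = torch.arange(start, min(start + block_size, t + 1))
                if idx.numel() > 0:
                    cols.append(idx)
            if not cols:                      # degenerate case: nothing selected yet
                cols = [torch.tensor([t])]
            cols = torch.cat(cols)
            k_sel, v_sel = k[g, cols], v[g, cols]              # [n_kv, d]
            # Attention over the fetched blocks, shared by all heads in the group
            attn = torch.softmax(q_t @ k_sel.T * scale, dim=-1)
            out[g, :, t] = attn @ v_sel
    return out
```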