FlexAttention

Creator
Seonglae Cho
Created
2024 Aug 13 13:45
Edited
2024 Aug 13 13:51
Refs

Score modification

# score_mod takes an attention score and its (batch, head, query, key) indices
def score_mod(score: f32[], b: i32[], h: i32[], q_idx: i32[], kv_idx: i32[]):
    return score  # identity score_mod recovers standard attention

# Semantically, score_mod is applied to every attention score before the softmax:
for b in range(batch_size):
    for h in range(num_heads):
        for q_idx in range(sequence_length):
            for kv_idx in range(sequence_length):
                modified_scores[b, h, q_idx, kv_idx] = score_mod(
                    scores[b, h, q_idx, kv_idx], b, h, q_idx, kv_idx
                )
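A minimal sketch of how such a score_mod is passed to flex_attention, assuming the torch.nn.attention.flex_attention module from a recent nightly build and a CUDA device; the causal example mirrors the masking pattern from the FlexAttention announcement.

import torch
from torch.nn.attention.flex_attention import flex_attention

# Toy inputs shaped [batch, heads, seq_len, head_dim]
query = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)
key = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)
value = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)

def causal(score, b, h, q_idx, kv_idx):
    # Send scores for future positions to -inf so they vanish after softmax
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

output = flex_attention(query, key, value, score_mod=causal)  # [2, 8, 256, 64]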

ALiBi bias

Similar to Relative Positional Encoding, but uses a per-head factor that is typically precomputed, and it has beneficial properties for length extrapolation at inference.
alibi_bias = generate_alibi_bias()  # [num_heads]

def alibi(score, b, h, q_idx, kv_idx):
    bias = alibi_bias[h] * (q_idx - kv_idx)
    return score + bias
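generate_alibi_bias() is not defined in the snippet above; one possible sketch uses the geometric per-head slopes from the ALiBi paper. The num_heads argument and the sign convention here are assumptions chosen to fit the bias = alibi_bias[h] * (q_idx - kv_idx) form above.

import torch

def generate_alibi_bias(num_heads: int) -> torch.Tensor:
    # Hypothetical helper: one slope per head, following the geometric
    # sequence 2^(-8*i/num_heads) from the ALiBi paper. Negated so that
    # alibi_bias[h] * (q_idx - kv_idx) penalizes larger query-key distances
    # under causal attention.
    exponents = torch.arange(1, num_heads + 1, dtype=torch.float32)
    return -(2.0 ** (-8.0 * exponents / num_heads))

alibi_bias = generate_alibi_bias(8)  # [num_heads]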
 

Soft-capping

import torch

softcap = 20

def soft_cap(score, b, h, q_idx, kv_idx):
    # Squash attention logits into (-softcap, softcap) with a scaled tanh
    score = score / softcap
    score = torch.tanh(score)
    score = score * softcap
    return score
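Soft-capping of this tanh form is used in models such as Gemma 2 to keep attention logits bounded. A usage sketch, reusing the soft_cap score_mod defined above and assuming the same flex_attention import and a CUDA device as in the earlier example: wrapping flex_attention in torch.compile lets the score_mod be fused into the attention kernel rather than materializing the full score matrix.

import torch
from torch.nn.attention.flex_attention import flex_attention

q = torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.float16)

# Compilation fuses soft_cap into a single attention kernel instead of
# materializing the [batch, heads, q_len, kv_len] score matrix.
compiled_flex_attention = torch.compile(flex_attention)
out = compiled_flex_attention(q, k, v, score_mod=soft_cap)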
FlexAttention is currently available in PyTorch nightly releases; it is planned to ship as a prototype feature in PyTorch 2.5.0.
 
 
 
 

Recommendations