Score modification
# score_mod takes the scalar attention score together with its batch, head, query,
# and key/value indices, and returns a modified score.
def score_mod(score: f32[], b: i32[], h: i32[], q_idx: i32[], kv_idx: i32[]):
    return score  # identity score_mod shown here; user-defined logic replaces this

# Semantically, score_mod is applied to every entry of the score matrix:
for b in range(batch_size):
    for h in range(num_heads):
        for q_idx in range(sequence_length):
            for kv_idx in range(sequence_length):
                modified_scores[b, h, q_idx, kv_idx] = score_mod(
                    scores[b, h, q_idx, kv_idx], b, h, q_idx, kv_idx
                )
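For context, a minimal usage sketch assuming the torch.nn.attention.flex_attention API; the tensor shapes and the relative_positional score_mod below are illustrative, not from the original text:

import torch
from torch.nn.attention.flex_attention import flex_attention

batch_size, num_heads, sequence_length, head_dim = 2, 8, 128, 64
query = torch.randn(batch_size, num_heads, sequence_length, head_dim)
key = torch.randn(batch_size, num_heads, sequence_length, head_dim)
value = torch.randn(batch_size, num_heads, sequence_length, head_dim)

# An example score_mod: bias each score by the signed distance between query and key positions.
def relative_positional(score, b, h, q_idx, kv_idx):
    return score + (q_idx - kv_idx)

# In practice flex_attention is typically wrapped in torch.compile to generate a fused kernel.
out = flex_attention(query, key, value, score_mod=relative_positional)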
ALiBi bias
ALiBi is similar to Relative Positional Encoding, but it uses a per-head factor that is typically precomputed, and it has beneficial properties for length extrapolation at inference time.
alibi_bias = generate_alibi_bias()  # [num_heads]

def alibi(score, b, h, q_idx, kv_idx):
    bias = alibi_bias[h] * (q_idx - kv_idx)
    return score + bias
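generate_alibi_bias is referenced above but not defined; a minimal sketch, assuming the ALiBi paper's geometric slopes 2^(-8i/num_heads), negated so that alibi_bias[h] * (q_idx - kv_idx) penalizes distant key positions under a causal mask:

import torch

# Hypothetical helper (not part of the original snippet).
def generate_alibi_bias(num_heads: int = 8) -> torch.Tensor:
    # Slope for head i is -2^(-8 * i / num_heads), i = 1..num_heads, so later
    # heads apply a weaker distance penalty.
    return -torch.exp2(-torch.arange(1, num_heads + 1) * 8.0 / num_heads)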
Soft-capping
softcap = 20

def soft_cap(score, b, h, q_idx, kv_idx):
    score = score / softcap
    score = torch.tanh(score)
    score = score * softcap
    return score
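To see the effect, a small standalone check (not from the original post) that applies the same tanh transform element-wise:

import torch

softcap = 20.0
# tanh soft-capping squashes arbitrarily large logits into the open interval
# (-softcap, softcap) while leaving small scores nearly unchanged.
raw_scores = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
capped = softcap * torch.tanh(raw_scores / softcap)
print(capped)  # roughly [-20.0, -1.0, 0.0, 1.0, 20.0]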
FlexAttention is currently available in PyTorch nightly releases; we plan to release it as a prototype feature in PyTorch 2.5.0.