Harsh Menon

Attention Optimizations

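The attention operator is a key component of transformer models (Attention Is All You Need), which are used extensively in natural language processing and vision. The operator takes three tensors, the query $Q$, key $K$ and value $V$, and performs the following computation: $$\begin{aligned} S &= Q K^{T} \newline P &= \text{softmax}(S) \newline O &= P V \end{aligned}$$ where $Q \in \mathbb{R}^{S \times d}$ and $K, V \in \mathbb{R}^{N \times d}$. The score matrix $S$ and the probabilities $P$ are therefore $S \times N$, the softmax is applied row-wise (over the $N$ keys for each query), and the output $O$ is $S \times d$. When the queries, keys and values are all computed from the same sequence (so $S = N$), the operator is referred to as self-attention. In autoregressive (causal) models, a mask is typically applied to $S$ before the softmax, setting the entries above the diagonal to $-\infty$ so that each query attends only to itself and earlier positions.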
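As a reference point, here is a minimal NumPy sketch of the three equations above for a single head with no batching. The function names `softmax` and `attention` are hypothetical, and the scores are left unscaled to match the equations; the original Transformer paper additionally divides $QK^T$ by $\sqrt{d}$.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability; this does not change the result.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V, causal=False):
    """Reference attention: Q is (S, d), K and V are (N, d); returns O of shape (S, d)."""
    S = Q @ K.T                       # (S, N) score matrix
    if causal:
        # Mask entries above the diagonal so query i only sees keys j <= i.
        mask = np.triu(np.ones(S.shape, dtype=bool), k=1)
        S = np.where(mask, -np.inf, S)
    P = softmax(S, axis=-1)           # (S, N), each row sums to 1
    return P @ V                      # (S, d)

rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 8))
K = rng.standard_normal((6, 8))
V = rng.standard_normal((6, 8))
O = attention(Q, K, V)
print(O.shape)  # (4, 8)
```

Note that this version materializes the full $S \times N$ matrix $P$, which is exactly what the optimizations below try to avoid or shrink.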

Flash Attention

This optimization was introduced by FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. The key idea is to fuse the three operations into a single kernel and tile them so that the $S \times N$ softmax intermediate $P$ is never materialized in GPU high-bandwidth memory. Because the softmax is also tiled, and each row of the softmax depends on every key, this requires keeping track of two additional per-row statistics, the running maximum and the running sum of exponentials, and rescaling the partial output whenever the maximum changes (which the authors refer to as algebraic aggregation).
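The running max and sum bookkeeping can be sketched in NumPy as follows. This is a simplified illustration, not the FlashAttention kernel: it tiles only along the key/value dimension (the real kernel also tiles $Q$ and keeps tiles in on-chip SRAM), it has no causal mask, and the name `flash_attention` and its `block_size` parameter are hypothetical.

```python
import numpy as np

def flash_attention(Q, K, V, block_size=2):
    """Tiled attention with an online softmax.

    Processes K and V in blocks of `block_size` rows, so the full (S, N)
    matrix P is never formed; only an (S, block_size) tile exists at a time.
    """
    S_len = Q.shape[0]
    m = np.full((S_len, 1), -np.inf)     # running row-wise max of the scores
    l = np.zeros((S_len, 1))             # running row-wise sum of exp(score - m)
    acc = np.zeros((S_len, V.shape[1]))  # unnormalized output accumulator
    for start in range(0, K.shape[0], block_size):
        Kb = K[start:start + block_size]
        Vb = V[start:start + block_size]
        Sb = Q @ Kb.T                                    # (S, block) score tile
        m_new = np.maximum(m, Sb.max(axis=1, keepdims=True))
        # Rescale previous statistics to the new max (exp(-inf) = 0 on the first block).
        scale = np.exp(m - m_new)
        Pb = np.exp(Sb - m_new)                          # unnormalized probabilities
        l = l * scale + Pb.sum(axis=1, keepdims=True)
        acc = acc * scale + Pb @ Vb
        m = m_new
    # Normalize once at the end, which is equivalent to dividing by the full softmax sum.
    return acc / l

rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 8))
K = rng.standard_normal((10, 8))
V = rng.standard_normal((10, 8))
# Compare against the untiled computation.
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
print(np.allclose(flash_attention(Q, K, V, block_size=3), P @ V))  # True
```

With `block_size=3` and 10 keys the last tile is ragged, which the slicing handles without special cases.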

Multi-Query Attention

This optimization was introduced by Fast Transformer Decoding: One Write-Head is All You Need. The key idea is to share a single key head and a single value head across all the query heads in the model. This shrinks the key/value cache by a factor equal to the number of heads, thereby greatly reducing the memory bandwidth requirements during incremental decoding.
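A minimal NumPy sketch of multi-query attention for one sequence, assuming $H$ query heads stacked along the first axis of `Q` and a single shared `K` and `V`. The name `multi_query_attention` is hypothetical and the scores are unscaled, as in the equations above. NumPy's `@` broadcasts the shared 2-D `K` and `V` across all heads.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_query_attention(Q, K, V):
    """Q: (H, S, d), one slice per query head; K, V: (N, d), shared by every head."""
    S = Q @ K.T              # broadcasts: (H, S, d) @ (d, N) -> (H, S, N)
    P = softmax(S, axis=-1)
    return P @ V             # (H, S, N) @ (N, d) -> (H, S, d)

# Size of the key/value cache during incremental decoding (values stored, per layer).
H, N, d = 8, 128, 64
cache_mha = 2 * H * N * d    # separate K and V for every head
cache_mqa = 2 * N * d        # a single K and V shared by all heads
print(cache_mha // cache_mqa)  # 8
```

The cache ratio is simply $H$, which is where the bandwidth savings during decoding come from.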

Grouped Query Attention

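This optimization was introduced by GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints and builds on the multi-query attention described above. Instead of sharing one key/value head across all $H$ query heads, GQA divides the query heads into $G$ groups, each of which shares a single key/value head. $G = 1$ recovers multi-query attention and $G = H$ recovers standard multi-head attention, so intermediate values of $G$ trade off between the quality of multi-head attention and the decoding speed of multi-query attention. The paper also shows how to convert (uptrain) existing multi-head checkpoints into GQA models.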
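A minimal NumPy sketch of grouped-query attention, assuming query heads are assigned to key/value heads in contiguous groups (heads $0 \ldots H/G - 1$ use key/value head 0, and so on). The name `grouped_query_attention` is hypothetical, and `np.repeat` is used only for clarity; an efficient kernel would index the shared heads rather than copy them.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(Q, K, V):
    """Q: (H, S, d) query heads; K, V: (G, N, d) key/value heads, with H divisible by G.

    Query heads are split into G contiguous groups of H // G heads; every head
    in group g attends using K[g] and V[g].
    """
    H, G = Q.shape[0], K.shape[0]
    assert H % G == 0, "number of query heads must be a multiple of KV heads"
    # Expand each KV head so it lines up with the H // G query heads in its group.
    K_exp = np.repeat(K, H // G, axis=0)    # (H, N, d)
    V_exp = np.repeat(V, H // G, axis=0)    # (H, N, d)
    S = Q @ np.swapaxes(K_exp, -1, -2)      # (H, S, N)
    return softmax(S, axis=-1) @ V_exp      # (H, S, d)

rng = np.random.default_rng(0)
Q = rng.standard_normal((8, 4, 16))   # H = 8 query heads
K = rng.standard_normal((2, 10, 16))  # G = 2 key/value heads
V = rng.standard_normal((2, 10, 16))
print(grouped_query_attention(Q, K, V).shape)  # (8, 4, 16)
```

Passing `G = H` key/value heads gives ordinary multi-head attention, and passing a single one gives multi-query attention, mirroring the two limits described above.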

Sliding Window Attention