---
abstract: Compared with LSTMs, attention's advantage is its ability to encode additional positional data into the inputs, which helps with longer context lengths and better memory retrieval. Note that most LLMs are decoder-only, given their superior zero-shot benchmark performance.
date: '2024-02-07'
description: and posteriori information retrieval.
id: Attention
modified: 2026-06-08 14:59:16 GMT-04:00
seealso:
  - '[[lectures/2/convexity|empirical finding]]'
  - '[[lectures/2/afp|attention from first principles]]'
socials:
  efficient: https://www.youtube.com/watch?v=Y-o545eYjXM
tags:
  - technical
  - llm
  - ml
title: Attention
transclude:
  title: false
created: '2024-02-07'
published: '2024-02-07'
pageLayout: default
slug: thoughts/Attention
permalink: https://aarnphm.xyz/thoughts/Attention.md
generator:
  quartz: v4.6.0
  hostedProvider: Cloudflare
  baseUrl: aarnphm.xyz
full: https://aarnphm.xyz/llms-full.txt
---
Attention operates on sequences of query $Q$, key $K$, and value $V$ vectors. The attention output for a sequence is then computed as \[@vaswani2023attentionneed\]:

$$
A(Q, K, V) = \operatorname{softmax}(\frac{Q \cdot K^{T}}{\sqrt{d}})V \space \space \text{ for } Q_{L \times d}, K_{L \times d}, V_{L \times d}
$$

Scaled dot-product attention was introduced in @vaswani2023attentionneed. One can think of the roles of $Q$, $K$, and $V$ as:

- Q: what I’m looking for
- K: what information I have
- V: what information I pass on to the tokens that attend to me.

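A minimal stdlib-only sketch of the formula, assuming unbatched row-vector inputs and no causal mask; `matmul`, `softmax_rows`, and `attention` are illustrative helpers, not a library API. Each query is scored against every key, each row of the weight matrix is a distribution over keys, and the output mixes the value rows with those weights.

```python
import math


def matmul(A, B):
    """Multiply an n x k matrix by a k x m matrix (both given as lists of rows)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def softmax_rows(S):
    """Row-wise softmax; subtracting each row's max keeps exp() from overflowing."""
    out = []
    for row in S:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        z = sum(exps)
        out.append([e / z for e in exps])
    return out


def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d)) V for L x d inputs; returns (output, attention weights)."""
    d = len(Q[0])
    scores = [[s / math.sqrt(d) for s in row] for row in matmul(Q, transpose(K))]
    A = softmax_rows(scores)
    return matmul(A, V), A


# Three tokens with d = 2 (toy values). Query 0 lines up with key 0;
# query 2 lines up equally with keys 0 and 1.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[4.0, 0.0], [0.0, 4.0], [0.0, 0.0]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]

out, A = attention(Q, K, V)
for row in A:
    print([round(a, 3) for a in row])  # each row sums to 1
```
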
> \[!note\]+ equivalent
>
> Following @elhage2021mathematical, we can rewrite the attention function (composed of multiple [[thoughts/induction heads|attention-heads]]) head by head:
>
> $$
> \text{Attn}^{l,h}(X_{\leq i}^{l-1}) = \sum_{j \leq i}a^{l,h}_{i,j}\, x^{l-1}_j W^{l,h}_{V} W_{O}^{l,h}
> $$
>
> where $a^{l,h}_{i,j}$ is the attention weight from position $i$ to position $j$, and the <mark>learnable</mark> weight matrices $W_{V}^{l,h} \in \mathbb{R}^{d \times d_h}$ and $W_{O}^{l,h} \in \mathbb{R}^{d_h \times d}$ (with $d_h$ the dimension per head) combine into the OV matrix $W_{OV}^{l,h} = W_{V}^{l,h} W_{O}^{l,h} \in \mathbb{R}^{d \times d}$. The layer's output is the sum of these per-head terms.

```jsx imports={Zoomable,AttentionCircuits}
<Zoomable label="QK/OV circuit decomposition">
  <AttentionCircuits caption="The same attention layer two ways: textbook softmax on the left, the Anthropic circuit decomposition on the right. Hover a circuit name to see which weights it covers: $QK$ sets where to read, $OV$ sets what to write." />
</Zoomable>
```
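A stdlib-only sketch checking the head-by-head rewriting numerically, with toy sizes and a hypothetical causal pattern `A` whose rows sum to 1. It confirms that the per-position sum $\sum_{j\le i} a_{i,j}\, x_j W_V W_O$ equals the matrix form $A X W_V W_O$, and that the combined OV matrix $W_V W_O$ is $d\times d$ but has rank at most $d_h$. The helpers `matmul` and `rank` are illustrative, not from a library.

```python
import random


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def rank(M, tol=1e-9):
    """Numerical rank via Gaussian elimination with partial pivoting."""
    M = [row[:] for row in M]
    r = 0
    for c in range(len(M[0])):
        if r == len(M):
            break
        pivot = max(range(r, len(M)), key=lambda i: abs(M[i][c]))
        if abs(M[pivot][c]) < tol:
            continue
        M[r], M[pivot] = M[pivot], M[r]
        for i in range(r + 1, len(M)):
            f = M[i][c] / M[r][c]
            M[i] = [x - f * y for x, y in zip(M[i], M[r])]
        r += 1
    return r


def gauss(n, m):
    return [[random.gauss(0, 1) for _ in range(m)] for _ in range(n)]


random.seed(0)
L, d, d_h = 4, 6, 2  # toy sizes with d_h < d, as in a real head
X, W_V, W_O = gauss(L, d), gauss(d, d_h), gauss(d_h, d)

# Hypothetical causal pattern: a_ij = 1/(i+1) for j <= i, so each row sums to 1.
A = [[1.0 / (i + 1) if j <= i else 0.0 for j in range(L)] for i in range(L)]

W_OV = matmul(W_V, W_O)  # combined d x d OV matrix, computed once
XW = matmul(X, W_OV)     # x_j W_V W_O for every position j
per_token = [[sum(A[i][j] * XW[j][k] for j in range(i + 1)) for k in range(d)] for i in range(L)]
matrix_form = matmul(matmul(matmul(A, X), W_V), W_O)

err = max(abs(p - q) for pr, qr in zip(per_token, matrix_form) for p, q in zip(pr, qr))
print("max |per-token sum - matrix form| =", err)
print("rank(W_OV) =", rank(W_OV), "for a", len(W_OV), "x", len(W_OV[0]), "matrix")
```

The low rank is why each head writes into at most a $d_h$-dimensional subspace of the residual stream.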

## Multi-head Attention

Multi-head attention lets the model jointly attend to information from different representation subspaces at different positions:

$$
\begin{aligned}
\text{MHA}(Q,K,V) &= \operatorname{concat}(\text{head}_1, \ldots, \text{head}_h)\, W_O\\
\text{head}_i &= \operatorname{softmax}\!\left(\frac{Q W_{Q,i}\,(K W_{K,i})^{\top}}{\sqrt{d_h}}\right)\, V W_{V,i}\\
& W_O \in \mathbb{R}^{(h d_h) \times d_{\text{model}}},\; W_{Q,i},W_{K,i} \in \mathbb{R}^{d_{\text{model}} \times d_h},\; W_{V,i} \in \mathbb{R}^{d_{\text{model}} \times d_h}
\end{aligned}
$$

Each head can specialise on a distinct relational pattern in the same context window.

```jsx imports={Zoomable,MultiHeadAttention}
<Zoomable label="multi-head attention diagram">
  <MultiHeadAttention
    caption="Slide $h$ to rebalance the $d_m$ budget across heads. Each head's softmax is an independent normaliser; $h$ heads $\neq$ one wider head."
    heads={4}
  />
</Zoomable>
```
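A stdlib-only sketch of the MHA equation, specialised to self-attention ($Q=K=V=X$ before projection) and assuming packed $d_{\text{model}}\times d_{\text{model}}$ projections whose column groups act as $W_{Q,i}, W_{K,i}, W_{V,i}$; sizes and helper names are illustrative. It also checks that concatenating the heads and applying $W_O$ is the same as summing each head's output through its own row block $W_O^{(i)}$, which is why heads can be read as independent writers.

```python
import math
import random


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def softmax_rows(S):
    out = []
    for row in S:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        z = sum(exps)
        out.append([e / z for e in exps])
    return out


def gauss(n, m):
    return [[random.gauss(0, 1) for _ in range(m)] for _ in range(n)]


def mha(X, W_Q, W_K, W_V, W_O, h):
    """Multi-head self-attention with packed projections split into h column groups."""
    d_h = len(X[0]) // h
    Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)
    heads = []
    for i in range(h):
        cols = slice(i * d_h, (i + 1) * d_h)
        Qi, Ki, Vi = ([row[cols] for row in M] for M in (Q, K, V))
        scores = [[s / math.sqrt(d_h) for s in row] for row in matmul(Qi, transpose(Ki))]
        heads.append(matmul(softmax_rows(scores), Vi))  # one softmax normaliser per head
    concat = [[x for head in heads for x in head[t]] for t in range(len(X))]
    return matmul(concat, W_O), heads


random.seed(0)
L, d_model, h = 5, 8, 2
d_h = d_model // h
X = gauss(L, d_model)
W_Q, W_K, W_V, W_O = (gauss(d_model, d_model) for _ in range(4))
Y, heads = mha(X, W_Q, W_K, W_V, W_O, h)

# Concatenate-then-project equals summing each head through its row block W_O^{(i)}.
Y_sum = [[0.0] * d_model for _ in range(L)]
for i, head in enumerate(heads):
    block = matmul(head, W_O[i * d_h:(i + 1) * d_h])
    Y_sum = [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(Y_sum, block)]

err = max(abs(a - b) for r1, r2 in zip(Y, Y_sum) for a, b in zip(r1, r2))
print("output:", len(Y), "x", len(Y[0]), "| max |concat - sum of heads| =", err)
```
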

One head may focus on positional offsets (e.g., “next token” dependencies) while another emphasises semantic alignment (e.g., subject $\leftrightarrow$ predicate links). The concatenation and final projection $W_O$ then recombine these perspectives into the [[thoughts/Transformers|transformer]] [[thoughts/mechanistic interpretability#residual stream|residual stream]].

> \[!motivation\]+ why split the model into heads?
>
> Each head learns a slightly different relational probe over the same sequence. One head might focus on syntactic structure,
> another on long-distance coreference.
>
> Projecting $Q$, $K$, and $V$ into lower-dimensional spaces ($d_h = d_{\text{model}}/h$) lets those probes specialise at roughly
> the same total cost as one full-width head: the $L^2$ score computation is split across heads, not avoided. One intuition for the
> empirical benefit is that the model can reuse a single context to answer multiple “questions” about it in parallel, rather than re-reading the sequence each time.

> \[!example\]- two offset heads vs one big head
>
> Consider a length-$L$ sequence with two heads: head 1 attends to the next token $(+1)$ and head 2 attends to the previous token $(-1)$. Let $\beta \gg 0$ so each head’s softmax is nearly an argmax on its offset.
>
> $$
> S^{(+1)}_{ij} = \beta\,[j=i+1], \quad S^{(-1)}_{ij} = \beta\,[j=i-1], \quad P^{(\pm1)} = \operatorname{softmax}_j S^{(\pm1)}.
> $$
>
> With values $V \in \mathbb{R}^{L\times d_h}$ and output projection blocks $W_O^{(1)}, W_O^{(2)}$, the MHA output is
>
> $$
> Y_{\text{MHA}} = \big(P^{(+1)} V\big) W_O^{(1)} + \big(P^{(-1)} V\big) W_O^{(2)}.
> $$
>
> A single head with scores $S = S^{(+1)} + S^{(-1)}$ yields $P=\operatorname{softmax}(S)$ and output $Y_{\text{SH}} = (PV)\,\tilde W_O$.
>
> Because softmax is non-additive, $Y_{\text{SH}}$ cannot match $Y_{\text{MHA}}$ for all inputs even if $\tilde W_O$ is chosen {{sidenotes[adversarially]: one normaliser is coupled where two normalisers are separable. This structural independence of normalisers increases expressivity \[@cordonnier2019relationshipselfattentionconvolution; @yun2019universaltransformers\]}}
>
> ```python shell
> import torch
>
> torch.set_printoptions(precision=3, sci_mode=False)
> L, d_h = 6, 4
> beta = 10.0  # large inverse temperature -> each softmax is near-argmax
>
> # Build the (+1) and (-1) score matrices: beta on the offset diagonal, 0 elsewhere,
> # matching S_ij = beta * [j = i +/- 1]. (With -inf off the diagonal, the boundary
> # rows and every entry of S_p1 + S_m1 would be -inf, and softmax would return NaN.)
> S_p1 = torch.zeros(L, L)
> S_m1 = torch.zeros(L, L)
> for i in range(L - 1):
>     S_p1[i, i + 1] = beta
> for i in range(1, L):
>     S_m1[i, i - 1] = beta
>
> P1, P2 = S_p1.softmax(dim=-1), S_m1.softmax(dim=-1)  # torch's softmax is already max-shifted
>
> V = torch.randn(L, d_h)
> WO1, WO2 = torch.randn(d_h, d_h), torch.randn(d_h, d_h)
> Y_mha = P1 @ V @ WO1 + P2 @ V @ WO2
>
> # Single-head surrogate: add the scores and use one normaliser, then give it the
> # best possible post-projection (least squares) rather than a random one.
> P = (S_p1 + S_m1).softmax(dim=-1)
> WOt = torch.linalg.lstsq(P @ V, Y_mha).solution
> Y_sh = P @ V @ WOt
>
> print('||Y_mha - Y_sh||_F =', torch.linalg.norm(Y_mha - Y_sh).item())
> ```
>
> Across random seeds this residual stays clearly nonzero, even for the least-squares-optimal $\tilde W_O$. No choice of a single post-projection can remove the coupling induced by the single softmax normaliser; two heads give two independent distributions you can recombine downstream. Empirically, heads do specialise and can be pruned selectively \[@voita2019analyzingmha; @michel2019sixteenheads\].

```jsx imports={Zoomable,OffsetHeadsToy}
<Zoomable label="two-head softmax demo">
  <OffsetHeadsToy
    caption="Two heads ($P^{(+1)}$, $P^{(-1)}$) against the single-head surrogate ($P = \operatorname{softmax}(S^{(+1)} + S^{(-1)})$). Raise $\beta$ to sharpen each diagonal; reseed $V$ and the projections and $\|Y_{\text{MHA}} - Y_{\text{SH}}\|_F$ stays stubbornly $O(1)$ — no single-head projection closes the gap."
    length={6}
  />
</Zoomable>
```

> \[!math\] 1. softmax factorisation barrier
>
> Let $S_i = QW_{Q,i}(KW_{K,i})^\top/\sqrt{d_h}$ and $P_i = \operatorname{softmax}(S_i)$.
>
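> If a single-head self-attention with some score matrix $S$ and post-projection $\tilde W_O$ reproduced an $h$-head layer for all $Q,K,V$, then we would need $\operatorname{softmax}(S) V\tilde W_O = \sum_{i=1}^h P_i V W_{V,i} W_O^{(i)}$ for all $V$.
>
> Matching the coefficient of each entry of $V$ on both sides gives $(\tilde W_O)_{ab}\operatorname{softmax}(S) = \sum_i (W_{V,i}W_O^{(i)})_{ab}\,P_i$ for every $a,b$, i.e. $\tilde W_O \otimes \operatorname{softmax}(S) = \sum_i (W_{V,i}W_O^{(i)}) \otimes P_i$.
>
> The left side is a rank-one tensor. When the patterns $P_i$ are linearly independent (the generic case, since they depend on disjoint parameter sets), the right side has rank equal to the dimension of the span of the OV maps $W_{V,i}W_O^{(i)}$. Equality therefore requires $h=1$ or all OV maps to be scalar multiples $c_i\tilde W_O$ of one matrix, and even then one softmax must reproduce the mixture $\sum_i c_i P_i$ for every input, which a single bilinear score generally cannot because softmax is not additive.
>
> Hence in one layer, multi-head is strictly more expressive due to independent normalisers. See also \[@cordonnier2019relationshipselfattentionconvolution; @yun2019universaltransformers\].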

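A stdlib-only numerical check of where the single-head reproduction breaks down. Matching coefficients of $V$ in the identity above requires $\tilde W_O \otimes \operatorname{softmax}(S) = \sum_i (W_{V,i}W_O^{(i)}) \otimes P_i$, and the left side is always a rank-one tensor. The sketch reuses the $\pm1$ offset patterns from the example (with $\beta = 10$), draws hypothetical random OV maps `B1` and `B2`, and computes the rank of the flattened sum: two heads with unrelated OV maps give rank 2, so no single head can match them, while proportional OV maps collapse back to rank 1.

```python
import math
import random


def softmax_rows(S):
    out = []
    for row in S:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        z = sum(exps)
        out.append([e / z for e in exps])
    return out


def rank(M, tol=1e-9):
    """Numerical rank via Gaussian elimination with partial pivoting."""
    M = [row[:] for row in M]
    r = 0
    for c in range(len(M[0])):
        if r == len(M):
            break
        pivot = max(range(r, len(M)), key=lambda i: abs(M[i][c]))
        if abs(M[pivot][c]) < tol:
            continue
        M[r], M[pivot] = M[pivot], M[r]
        for i in range(r + 1, len(M)):
            f = M[i][c] / M[r][c]
            M[i] = [x - f * y for x, y in zip(M[i], M[r])]
        r += 1
    return r


def flatten(M):
    return [x for row in M for x in row]


def tensor_matrix(pairs):
    """Matrix form of sum_i B_i (x) P_i: rows index entries of B, columns entries of P."""
    vb = [flatten(B) for B, _ in pairs]
    vp = [flatten(P) for _, P in pairs]
    return [[sum(b[a] * p[c] for b, p in zip(vb, vp)) for c in range(len(vp[0]))]
            for a in range(len(vb[0]))]


random.seed(0)
L, d, beta = 6, 3, 10.0
# Offset patterns from the two-head example: beta on the +1 / -1 diagonal, 0 elsewhere.
P_next = softmax_rows([[beta if j == i + 1 else 0.0 for j in range(L)] for i in range(L)])
P_prev = softmax_rows([[beta if j == i - 1 else 0.0 for j in range(L)] for i in range(L)])
B1 = [[random.gauss(0, 1) for _ in range(d)] for _ in range(d)]  # hypothetical OV maps
B2 = [[random.gauss(0, 1) for _ in range(d)] for _ in range(d)]
B2_prop = [[2.0 * x for x in row] for row in B1]                  # proportional to B1

r_single = rank(tensor_matrix([(B1, P_next)]))
r_two = rank(tensor_matrix([(B1, P_next), (B2, P_prev)]))
r_prop = rank(tensor_matrix([(B1, P_next), (B2_prop, P_prev)]))
print("single head:", r_single, "| two heads:", r_two, "| proportional OV maps:", r_prop)
```
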
> \[!note\] Parameter and compute at fixed $d_{\text{model}}$
>
> Write $d_m = d_{\text{model}}$ and $d_h = d_m/h$.
>
> - Params (packed projections): $W_Q,W_K,W_V,W_O \in \mathbb{R}^{d_m\times d_m}$ $\Rightarrow$ about $4d_m^2$ weights, independent of head count $h$ (the implementation splits the columns into $h$ groups of width $d_h$).
> - FLOPs (naive): $\Theta(L^2 d_m)$ per layer per sequence for $QK^\top$ and $AV$ (plus $\Theta(L d_m^2)$ for the projections); choosing $h$ changes per-head tile sizes and kernel efficiency, not asymptotics.
> - KV cache: each token stores $K$ and $V$ for every layer, $2d_m$ values per layer (in bytes: $2d_m$ times the dtype size). Changing $h$ does not change the total width $h\,d_h = d_m$, but it affects the layout and can impact IO-bound kernels in practice.

```jsx imports={Zoomable,AttentionCostCalculator}
<Zoomable label="attention cost calculator">
  <AttentionCostCalculator caption="Dial $d_{\text{model}}$, layers, $h$, sequence length and batch (click any value to type it), plus the KV-cache dtype; weights stay bf16. The cache is $L \times B \times N \times 2d_{\text{model}} \times \operatorname{bytes}$, so quantising it (int8, int4) shrinks the salmon bar against the fixed weight bars while $h$ alone stays free." />
</Zoomable>
```
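A back-of-envelope sketch of the three bullets, assuming $d_h = d_m/h$, 2 FLOPs per multiply-add, and counting only the $QK^\top$ and $AV$ products for FLOPs. The `n_kv_heads` argument is an added assumption that models the MQA/GQA sharing listed in the cheatsheet, and the configuration below is hypothetical. For standard MHA the cache term reduces to the calculator's layers $\times B \times N \times 2d_{\text{model}} \times$ bytes.

```python
def attention_costs(d_model, n_heads, n_layers, seq_len, batch, kv_bytes=2, n_kv_heads=None):
    """Per-layer attention params, per-layer score FLOPs, and total KV-cache bytes.

    Assumes d_h = d_model // n_heads. n_kv_heads < n_heads models MQA/GQA-style
    K/V sharing (an assumption beyond the note, which covers standard MHA).
    """
    d_h = d_model // n_heads
    n_kv = n_heads if n_kv_heads is None else n_kv_heads
    # W_Q and W_O are d_model x d_model; W_K and W_V are d_model x (n_kv * d_h).
    params_per_layer = 2 * d_model * d_model + 2 * d_model * n_kv * d_h
    # QK^T and AV: seq_len^2 * d_model multiply-adds each, 2 FLOPs per multiply-add.
    score_flops_per_layer = 2 * 2 * seq_len * seq_len * d_model
    # K and V for every layer, sequence in the batch, and token.
    kv_cache_bytes = n_layers * batch * seq_len * 2 * n_kv * d_h * kv_bytes
    return params_per_layer, score_flops_per_layer, kv_cache_bytes


# Hypothetical config: d_model = 4096, 32 layers, 4096-token context, batch 1, bf16 cache.
for label, kw in [("MHA, h=32", dict(n_heads=32)),
                  ("MHA, h=64", dict(n_heads=64)),
                  ("GQA, 8 KV heads", dict(n_heads=32, n_kv_heads=8)),
                  ("MHA, int8 cache", dict(n_heads=32, kv_bytes=1))]:
    p, f, kv = attention_costs(d_model=4096, n_layers=32, seq_len=4096, batch=1, **kw)
    print(f"{label:16s} params/layer={p:,}  score FLOPs/layer={f:.3e}  KV cache={kv / 2**30:.2f} GiB")
```

Doubling $h$ at fixed $d_m$ leaves all three numbers unchanged; only sharing K/V heads or a narrower cache dtype shrinks the cache.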

## optimization

> \[!note\] Note
>
> Exact kernels cut memory traffic; sparse/local patterns reduce the number of attended pairs; linear/approximate methods trade exactness for sub-quadratic (often $O(L)$) time.

### Exact, IO-aware kernels

- [[thoughts/flash attention|Flash Attention]]: tile/fuse softmax, read each tile once from HBM; FA-3 exploits FP8 and hardware copy engines for higher throughput \[@dao2022flashattentionfastmemoryefficientexact; @dao2023flashattention2fasterattentionbetter; @shah2024flashattention3fastaccurateattention\].
- Flash-Decoding (++): specialised kernels for decode that parallelise across KV blocks and fuse reductions.

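The step that lets FlashAttention read each K/V tile only once is the online softmax: keep a running max and running denominator, and rescale the earlier partial sums whenever the max grows. A stdlib-only sketch for a single query with no masking; it shows only the arithmetic, not the GPU tiling, on-chip memory management, or backward pass, and the function names are illustrative.

```python
import math
import random


def exact_attention(q, keys, values):
    """Reference: softmax over all scores at once."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, values)) / z for c in range(len(values[0]))]


def tiled_attention(q, keys, values, block=4):
    """Stream K/V in blocks, keeping only a running max, denominator, and numerator."""
    m = -math.inf                 # running max of the scores seen so far
    denom = 0.0                   # running sum of exp(score - m)
    acc = [0.0] * len(values[0])  # running sum of exp(score - m) * v
    for start in range(0, len(keys), block):
        kb, vb = keys[start:start + block], values[start:start + block]
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in kb]
        m_new = max(m, max(s))
        scale = math.exp(m - m_new)  # rescales old statistics (0.0 on the first block)
        p = [math.exp(x - m_new) for x in s]
        denom = denom * scale + sum(p)
        acc = [a * scale + sum(pi * v[c] for pi, v in zip(p, vb)) for c, a in enumerate(acc)]
        m = m_new
    return [a / denom for a in acc]  # normalise once at the end


random.seed(0)
d, n = 8, 10  # 10 keys, so block=4 leaves a ragged last block of 2
q = [random.gauss(0, 1) for _ in range(d)]
keys = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
values = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]

ref = exact_attention(q, keys, values)
for block in (1, 4, n):
    out = tiled_attention(q, keys, values, block)
    print(block, max(abs(a - b) for a, b in zip(ref, out)))
```
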
### Sparse/local attention

- Longformer: sliding window with optional global tokens for linear scaling \[@beltagy2020longformerlongdocumenttransformer\].
- BigBird: block-sparse (window + random + global) with theoretical guarantees \[@zaheer2021bigbirdtransformerslonger\].

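A simplified stdlib-only sketch of the sliding-window-plus-global pattern, assuming a symmetric window of half-width $w/2$ and global tokens that attend to, and are attended by, every position. The published models add refinements (dilated windows in Longformer, random blocks in BigBird) that are not modelled here, and `window_global_mask` is a hypothetical name. Counting allowed pairs shows the $O(L\cdot w)$ growth: doubling $L$ roughly doubles the count, whereas dense attention quadruples.

```python
def window_global_mask(L, w, global_idx=()):
    """Token i attends to j if |i - j| <= w // 2, or if either i or j is a global token."""
    g = set(global_idx)
    return [[abs(i - j) <= w // 2 or i in g or j in g for j in range(L)] for i in range(L)]


def num_edges(mask):
    return sum(sum(row) for row in mask)  # True counts as 1


for L in (256, 512):
    sparse = num_edges(window_global_mask(L, w=16, global_idx=(0,)))
    print(f"L={L}: {sparse} allowed pairs vs {L * L} for dense attention")
```
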
### Linear/approximate attention

- Reformer: LSH buckets for sub-quadratic attention + reversible layers \[@kitaev2020reformerefficienttransformer\].
- Linformer: low-rank projection along sequence dimension \[@wang2020linformerselfattentionlinearcomplexity\].
- Performer: FAVOR+ kernel features approximate softmax for linear time \[@choromanski2022rethinkingattentionperformers\].
- Nystromformer: landmark-based Nystrom approximation \[@xiong2021nystromformernystrombasedalgorithmapproximating\].

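A stdlib-only sketch of the linear-attention reordering with Performer-style positive random features, assuming $\phi(x) = \exp(w^\top x - \|x\|^2/2)/\sqrt{m}$ with i.i.d. Gaussian rows $w$ of `W`, so that $\mathbb{E}[\phi(q)^\top\phi(k)] = \exp(q^\top k)$. It omits FAVOR+'s orthogonal features and any stability tricks; all names are illustrative. Computing $\phi(K)^\top V$ once makes the cost linear in $L$ and matches the explicit $L\times L$ computation with the same features up to rounding; only the random-feature estimate of the softmax kernel is approximate.

```python
import math
import random


def gauss(n, k):
    return [[random.gauss(0, 1) for _ in range(k)] for _ in range(n)]


def features(x, W):
    """Positive random features phi(x) = exp(w.x - |x|^2 / 2) / sqrt(m), one per row w of W."""
    half_sq_norm = sum(v * v for v in x) / 2.0
    return [math.exp(sum(a * b for a, b in zip(w, x)) - half_sq_norm) / math.sqrt(len(W)) for w in W]


def scaled(x):
    """Split the 1/sqrt(d) softmax temperature evenly between q and k."""
    return [v * len(x) ** -0.25 for v in x]


def linear_attention(Q, K, V, W):
    """Linear in L: phi(K)^T V and phi(K)^T 1 are built once and reused for every query."""
    phiK = [features(scaled(k), W) for k in K]
    m, d_v = len(W), len(V[0])
    KV = [[sum(pk[r] * v[c] for pk, v in zip(phiK, V)) for c in range(d_v)] for r in range(m)]
    Ksum = [sum(pk[r] for pk in phiK) for r in range(m)]
    out = []
    for q in Q:
        pq = features(scaled(q), W)
        den = sum(a * b for a, b in zip(pq, Ksum))  # positive, since every feature is positive
        out.append([sum(pq[r] * KV[r][c] for r in range(m)) / den for c in range(d_v)])
    return out


def quadratic_attention(Q, K, V, kernel):
    """Reference O(L^2) path: explicit L x L weights from a kernel, then row normalisation."""
    out = []
    for q in Q:
        w = [kernel(q, k) for k in K]
        z = sum(w)
        out.append([sum(wi * v[c] for wi, v in zip(w, V)) / z for c in range(len(V[0]))])
    return out


random.seed(0)
L, d, m = 6, 4, 512
Q, K, V = gauss(L, d), gauss(L, d), gauss(L, d)
W = gauss(m, d)  # m random directions with rows ~ N(0, I)


def feature_kernel(q, k):
    return sum(a * b for a, b in zip(features(scaled(q), W), features(scaled(k), W)))


def softmax_kernel(q, k):
    return math.exp(sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)))


def max_diff(A, B):
    return max(abs(a - b) for ra, rb in zip(A, B) for a, b in zip(ra, rb))


approx = linear_attention(Q, K, V, W)
same_features = quadratic_attention(Q, K, V, feature_kernel)
exact = quadratic_attention(Q, K, V, softmax_kernel)
print("linear vs quadratic, same features:", max_diff(approx, same_features))  # rounding only
print("random-feature estimate vs exact softmax:", max_diff(approx, exact))     # approximation error
```

Raising `m` tends to shrink the gap to exact softmax attention, at $O(L\cdot m)$ cost per head.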
### Multi-device and prefix-aware inference

- [[thoughts/ring attention|Ring]]/Striped Attention: partition long sequences across devices, overlapping compute and communication \[@liu2023ringattentionblockwisetransformers; @brandon2023stripedattentionfasterring\].
- [[thoughts/cascade attention|Cascade]]/Tree-aware kernels: exploit shared prefixes and tree layouts to reuse KV IO \[@zheng2024sglangefficientexecutionstructured; @shyam2025treeattentiontopologyawaredecoding\].

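Both lines of work rely on the fact that softmax attention over a concatenation of KV chunks can be rebuilt from per-chunk results, as long as each chunk also reports the log-sum-exp of its scores. A stdlib-only sketch for a single query with no masking; `partial_attention` and `merge` are illustrative names, and the cited systems do this blockwise on accelerators with communication and scheduling not shown here. Splitting off a shared prefix and folding `merge` over per-device shards both reproduce full attention up to rounding.

```python
import math
import random
from functools import reduce


def partial_attention(q, keys, values):
    """One query against one KV chunk: returns (normalised output, log-sum-exp of scores)."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    out = [sum(wi * v[c] for wi, v in zip(w, values)) / z for c in range(len(values[0]))]
    return out, m + math.log(z)


def merge(a, b):
    """Combine states from disjoint KV chunks, weighting each by its share of the total mass."""
    (out_a, lse_a), (out_b, lse_b) = a, b
    top = max(lse_a, lse_b)
    lse = top + math.log(math.exp(lse_a - top) + math.exp(lse_b - top))
    wa, wb = math.exp(lse_a - lse), math.exp(lse_b - lse)
    return [wa * x + wb * y for x, y in zip(out_a, out_b)], lse


def gauss_rows(n, d):
    return [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]


random.seed(0)
d = 4
keys, values = gauss_rows(12, d), gauss_rows(12, d)
q = gauss_rows(1, d)[0]
full_out, full_lse = partial_attention(q, keys, values)

# Shared-prefix split: attend to the prefix (first 8 tokens) and the request's own
# suffix separately, then merge. Cascade-style kernels batch the prefix part across
# every request that shares it, so the prefix KV is read once.
prefix_state = partial_attention(q, keys[:8], values[:8])
cascade_out, _ = merge(prefix_state, partial_attention(q, keys[8:], values[8:]))

# Four hypothetical "devices", each holding a 3-token shard, folded together with merge.
shards = [partial_attention(q, keys[i:i + 3], values[i:i + 3]) for i in range(0, 12, 3)]
ring_out, ring_lse = reduce(merge, shards)

print(max(abs(a - b) for a, b in zip(full_out, cascade_out)),
      max(abs(a - b) for a, b in zip(full_out, ring_out)))
```
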
## cheatsheet

| Method              | Type         |      Complexity (seq) | Key idea                                    | Typical win                    |
| ------------------- | ------------ | --------------------: | ------------------------------------------- | ------------------------------ |
| FlashAttention-3    | exact kernel |              $O(L^2)$ | tiled IO-minimal attention; FP8/TMA overlap | large train speedups on Hopper |
| Flash-Decoding / ++ | exact decode |      $O(L)$ per token | block-parallel KV, fused reductions         | multi-x decode on long ctx     |
| Longformer          | sparse       | $\approx O(L\cdot w)$ | local window + global tokens                | linear scaling for long docs   |
| BigBird             | sparse       | $\approx O(L\cdot w)$ | window + random + global blocks             | theory + strong practice       |
| Reformer            | approx       |        $O(L \log{L})$ | LSH attention; reversible layers            | memory/time reductions         |
| Linformer           | approx       |         $O(L\cdot k)$ | low-rank K/V along L                        | linear time/space              |
| Performer           | approx       |         $O(L\cdot d)$ | FAVOR+ random features                      | linear attention               |
| Nystromformer       | approx       |         $O(L\cdot m)$ | landmark Nystrom approximation              | fewer tokens, good quality     |
| MQA/GQA             | arch/infer   | $O(H_k\, d_h)$ KV per token | share K/V across all heads (MQA) or head groups (GQA); $H_k$ K/V heads | KV/bandwidth savings           |
| Ring/Striped        | parallel     |              $O(L^2)$ | pipeline across devices                     | million-token context          |
| Cascade/Tree-aware  | kernel       |              $O(L^2)$ | KV reuse on shared prefixes                 | big wins on shared prompts     |


