---
date: '2025-08-21'
description: Self-attention from first principles, formal properties, and efficiency bounds.
id: attention
modified: 2026-06-06 00:04:40 GMT-04:00
tags:
  - ml
title: attention primer
created: '2025-08-21'
published: '2025-08-21'
pageLayout: default
slug: lectures/2/afp
permalink: https://aarnphm.xyz/lectures/2/afp.md
generator:
  quartz: v4.6.0
  hostedProvider: Cloudflare
  baseUrl: aarnphm.xyz
full: https://aarnphm.xyz/llms-full.txt
---
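Self-attention is a routing mechanism: each query is scored against every key, and the normalized scores weight the values. At inference time the binding constraint is memory, specifically the number of key/value elements cached per token per layer, so the variants below are compared on that count.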

## preliminaries

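Let $X\in\mathbb{R}^{n\times d}$ be the input sequence of $n$ tokens with model dimension $d$. Learned projections $W_Q,W_K,W_V$ map $X$ to queries, keys, and values. With $h$ heads, each head has dimension $d_h$ (commonly $d_h=d/h$).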

For $z\in\mathbb{R}^m$, the softmax and logsumexp (LSE) functions are:

$$
\operatorname{softmax}(z)_i = \frac{e^{z_i}}{\sum_j e^{z_j}},\quad \operatorname{LSE}(z)=\log\sum_j e^{z_j},\quad \nabla\!\operatorname{LSE}(z)=\operatorname{softmax}(z).
$$

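Temperature $T>0$ rescales the logits to $z/T$: $T<1$ sharpens the distribution and $T>1$ flattens it. For numerical stability, shift by the maximum before exponentiating, $z\leftarrow z-\max_j z_j$. Softmax is invariant to this shift, and every exponent becomes non-positive, so nothing overflows. \[@dabah2025temperaturescalingconformalprediction\]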
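
As a minimal sketch of the max-shift trick (pure Python, with the hypothetical helper names `logsumexp` and `softmax`), the following computes softmax through LSE so that logits near 1000 do not overflow:

```python
import math

def logsumexp(z, T=1.0):
    """Stable LSE of z/T: shift by the max so every exponent is <= 0."""
    scaled = [zi / T for zi in z]
    m = max(scaled)
    return m + math.log(sum(math.exp(s - m) for s in scaled))

def softmax(z, T=1.0):
    """softmax(z/T)_i = exp(z_i/T - LSE(z/T))."""
    lse = logsumexp(z, T)
    return [math.exp(zi / T - lse) for zi in z]

# A naive math.exp(1000.0) raises OverflowError; the shifted version is fine.
logits = [1000.0, 1001.0, 1002.0]
probs = softmax(logits)
shifted = softmax([0.0, 1.0, 2.0])   # same distribution: softmax is shift-invariant
flat = softmax(logits, T=10.0)       # higher temperature flattens the weights
```

The same shift appears in the CUDA and Triton kernels later in the note, where the running maximum plays the role of `m`.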

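[[thoughts/RoPE|RoPE]] (rotary position embeddings) encodes relative position by rotating each coordinate pair of $q$ and $k$, viewed as a complex number, by an angle proportional to the token position before the dot product. The resulting score depends only on the offset between the two positions. \[@su2023roformerenhancedtransformerrotary\]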
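
The relative-position property can be checked numerically. The sketch below assumes the interleaved pairing $(x_{2i},x_{2i+1})$ with frequencies $\text{base}^{-2i/d}$ and `base=10000`; real implementations differ in how they pair coordinates, and the helper names `rope` and `dot` are illustrative only.

```python
import math

def rope(x, pos, base=10000.0):
    """Rotate each pair (x[i], x[i+1]), i even, by angle pos * base**(-i/d)."""
    d = len(x)
    out = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        out += [x[i] * c - x[i + 1] * s, x[i] * s + x[i + 1] * c]
    return out

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]

# Both position pairs have offset 3, so the scores should agree.
s_near = dot(rope(q, 5), rope(k, 2))
s_far = dot(rope(q, 103), rope(k, 100))
```

Because each rotation is orthogonal, `rope` also preserves the norm of its input.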

[[thoughts/Kullback-Leibler divergence]] upper-bounds total variation distance through Pinsker’s inequality, $\mathrm{TV}(P,Q)\le\sqrt{\tfrac{1}{2}D_{\mathrm{KL}}(P\|Q)}$ (KL in nats), and is defined as:

$$
D_{\mathrm{KL}}(P\|Q)=\sum_i P_i\log\frac{P_i}{Q_i}
$$
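
A small numerical check of Pinsker's inequality on two hand-picked distributions (the values of `P` and `Q` are arbitrary illustrations):

```python
import math

def kl(p, q):
    """D_KL(P || Q) in nats; terms with P_i = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def total_variation(p, q):
    """TV(P, Q) = (1/2) * sum_i |P_i - Q_i|."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))

P = [0.7, 0.2, 0.1]
Q = [0.5, 0.3, 0.2]

tv = total_variation(P, Q)
bound = math.sqrt(0.5 * kl(P, Q))   # Pinsker: tv <= bound
```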

## scaled dot-product self-attention

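Let $Q=XW_Q,\;K=XW_K,\;V=XW_V$, and let $d_k$ denote the key dimension (the number of columns of $K$, equal to $d_h$ within a single head).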

$$
\mathrm{Attn}(Q,K,V)=\sigma\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V,\quad \sigma=\text{row-softmax}.
$$

> \[!proposition\] Proposition 1. Permutation Equivariance
>
> Without positional encodings, for any permutation matrix $P\in\mathbb{R}^{n\times n}$:
>
> $$
> \mathrm{Attn}(PQ,PK,PV)=P\,\mathrm{Attn}(Q,K,V).
> $$

Proof. $PQ(PK)^\top=PQK^\top P^\top$. Row-wise softmax commutes with permutations: left-multiplying by $P$ reorders whole rows, and right-multiplying by $P^\top$ reorders the entries within each row, which leaves every row's normalizing sum unchanged. Hence $\sigma_{\text{row}}(PZP^\top)=P\,\sigma_{\text{row}}(Z)\,P^\top$ for $Z=QK^\top/\sqrt{d_k}$. Right-multiplying by $PV$ and using $P^\top P=I$ yields $P \sigma_{\text{row}}(Z) P^\top P V = P\,\mathrm{Attn}(Q,K,V)$.
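
Proposition 1 can be sanity-checked numerically. The sketch below is a plain-Python reference attention (no positional encoding) applied to random inputs; the permutation `perm` and the sizes are arbitrary choices.

```python
import math
import random

def softmax(row):
    m = max(row)
    e = [math.exp(x - m) for x in row]
    s = sum(e)
    return [x / s for x in e]

def attention(Q, K, V):
    """Row-wise softmax(Q K^T / sqrt(d)) V on lists of lists."""
    d = len(Q[0])
    out = []
    for q in Q:
        w = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K])
        out.append([sum(wj * v[t] for wj, v in zip(w, V)) for t in range(len(V[0]))])
    return out

def permute(M, perm):
    """Apply a row permutation, i.e. left-multiply by P."""
    return [M[i] for i in perm]

random.seed(0)
n, d = 5, 4
Q, K, V = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(n)] for _ in range(3))
perm = [3, 0, 4, 1, 2]

lhs = attention(permute(Q, perm), permute(K, perm), permute(V, perm))
rhs = permute(attention(Q, K, V), perm)
max_err = max(abs(a - b) for ra, rb in zip(lhs, rhs) for a, b in zip(ra, rb))
```

The two sides agree up to floating-point rounding, since permuting the keys only changes the order of summation.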

> \[!proposition\] Proposition 2. Variance Scaling
>
> Assume $q,k\in\mathbb{R}^d$ are statistically independent at initialization. Let $\mathbb{E}[q]=\mathbb{E}[k]=0$, $\mathrm{Cov}(q)=\Sigma_q$, $\mathrm{Cov}(k)=\Sigma_k$. For $S=q^\top k$ and $Z=S/\sqrt{d}$:
>
> $$
> \mathrm{Var}(S)=\operatorname{tr}(\Sigma_q\Sigma_k),\quad \mathrm{Var}(Z)=\frac{1}{d}\operatorname{tr}(\Sigma_q\Sigma_k).
> $$
>
> Under isotropy $\Sigma_q=\Sigma_k=I_d$, $\mathrm{Var}(S)=d$ and $\mathrm{Var}(Z)=1$. \[@vaswani2023attentionneed\]

Proof. Independence and zero means give $\mathbb{E}[S]=0$, so $\mathrm{Var}(S)=\mathbb{E}[S^2]$. By iterated expectation (using the independence of $q$ and $k$), $\mathbb{E}[S^2] = \mathbb{E}_q[\mathbb{E}_k[q^\top k k^\top q \mid q]] = \mathbb{E}_q[q^\top \Sigma_k q]$. Since this is a scalar, we apply the trace trick: $\mathbb{E}_q[\operatorname{tr}(q^\top \Sigma_k q)] = \operatorname{tr}(\Sigma_k \mathbb{E}[q q^\top]) = \operatorname{tr}(\Sigma_k \Sigma_q)$. Finally, $Z=S/\sqrt{d}$ gives $\mathrm{Var}(Z)=\mathrm{Var}(S)/d$.

_Note_: In real models, $q$ and $k$ are derived from the same residual stream, so this independence assumption only loosely holds post-initialization.
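
A Monte Carlo sketch of the isotropic case in Proposition 2, assuming i.i.d. standard normal coordinates for $q$ and $k$ (the function name `logit_variance`, the trial count, and the seed are illustrative):

```python
import random

def logit_variance(d, trials=5000, seed=0):
    """Empirical variance of S = q^T k for independent q, k ~ N(0, I_d)."""
    rng = random.Random(seed)
    vals = [
        sum(rng.gauss(0, 1) * rng.gauss(0, 1) for _ in range(d))
        for _ in range(trials)
    ]
    mean = sum(vals) / trials
    return sum((v - mean) ** 2 for v in vals) / trials

d = 64
var_s = logit_variance(d)   # expected to land near d
var_z = var_s / d           # expected to land near 1 after the 1/sqrt(d) scaling
```

Without the $1/\sqrt{d}$ factor the logits would have standard deviation around $\sqrt{d}$, which pushes softmax toward a near one-hot output at initialization.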

## kernel regression

> \[!proposition\] Proposition 3. RBF Weights under Strict Normalization
>
> Fix $q\in\mathbb{R}^d$ and keys $\{k_j\}_{j=1}^n$. If all keys share the same norm, $\|k_j\|=\beta$ for every $j$, then scaled dot-product attention at temperature $T$, with logits $\langle q,k_j\rangle/(T\sqrt{d})$, coincides exactly with Nadaraya-Watson Gaussian kernel regression:
>
> $$
> w_j(q)=\frac{\exp(-\|q-k_j\|^2/(2\sigma^2))}{\sum_{\ell}\exp(-\|q-k_\ell\|^2/(2\sigma^2))}\quad\text{with}\quad \sigma^2=T\sqrt{d}.
> $$
>
> [^tweet]

[^tweet]: <blockquote class="twitter-tweet" data-lang="fr" data-dnt="true"><p lang="en" dir="ltr">If you&#39;re an “ML Engineer” and you think attention means “the model focuses on relevant parts of the input,” you’re missing the actual math that makes Transformers work.<br><br>Concept 22: Attention Is All You Need, a different perspective.<br><br>The core attention equation:<br><br>Attention(q,…</p>&mdash; chastronomic (@chastronomic) <a href="https://x.com/chastronomic/status/1995604876823593374?ref_src=twsrc%5Etfw">1 décembre 2025</a></blockquote>



Proof. Expand the squared distance: $\|q-k_j\|^2 = \|q\|^2+\|k_j\|^2-2\langle q,k_j\rangle$. The $\|q\|^2$ term is constant across all $j$ and cancels in the softmax fraction. The $\|k_j\|^2$ term cancels only if it equals the same constant $\beta^2$ for every $j$. In that case the kernel weight $\exp(-\|q-k_j\|^2/(2\sigma^2))$ reduces, up to a factor independent of $j$, to $\exp(\langle q,k_j\rangle/\sigma^2)$. Equating this exponent to the attention logit $\langle q,k_j\rangle/(T\sqrt{d})$ forces $\sigma^2=T\sqrt{d}$. Without uniform key norms, standard attention deviates from true Gaussian kernel regression by the per-key factor $\exp(\|k_j\|^2/(2\sigma^2))$.
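
The identity in Proposition 3 can be verified on random data. This sketch rescales random keys to a common norm `beta` and compares attention weights against Gaussian kernel weights with $\sigma^2=T\sqrt{d}$; all sizes and constants are arbitrary.

```python
import math
import random

def normalize(w):
    s = sum(w)
    return [x / s for x in w]

random.seed(1)
d, n, T, beta = 8, 6, 0.7, 2.0
q = [random.gauss(0, 1) for _ in range(d)]

keys = []
for _ in range(n):
    k = [random.gauss(0, 1) for _ in range(d)]
    norm = math.sqrt(sum(x * x for x in k))
    keys.append([beta * x / norm for x in k])   # enforce ||k_j|| = beta

sigma2 = T * math.sqrt(d)

# Softmax attention weights with logits <q, k_j> / (T sqrt(d)).
attn = normalize([
    math.exp(sum(a * b for a, b in zip(q, k)) / (T * math.sqrt(d))) for k in keys
])

# Nadaraya-Watson weights with a Gaussian kernel of variance sigma2.
rbf = normalize([
    math.exp(-sum((a - b) ** 2 for a, b in zip(q, k)) / (2 * sigma2)) for k in keys
])

gap = max(abs(a - b) for a, b in zip(attn, rbf))
```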

## architectural variants

Multi-Head Attention (MHA) concatenates outputs from $h$ independent routing spaces:

$$
\text{MHA}(X)=[O_1,\ldots,O_h]\,W_O,\quad O_i=\sigma\!\left(\frac{Q_i K_i^\top}{\sqrt{d_h}}\right)V_i.
$$

During decoding, memory is dominated by the KV cache, which grows linearly with context length. [[thoughts/GQA|Grouped-Query Attention]] (GQA) partitions the $h$ query heads into $G$ groups, each sharing one key/value head, while Multi-Query Attention (MQA) shares a single key/value head across all query heads ($G=1$); $G=h$ recovers MHA. \[@ainslie2023gqatraininggeneralizedmultiquery; @shazeer2019fasttransformerdecodingwritehead\]
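
One way to picture the sharing is as a map from query heads to key/value heads. The sketch below assumes contiguous groups of $h/G$ query heads, which is a common layout but an assumption here; `kv_head_for` is a hypothetical helper.

```python
def kv_head_for(query_head, h, G):
    """Index of the shared K/V head for a query head, assuming contiguous groups."""
    if h % G != 0:
        raise ValueError("h must be divisible by G")
    return query_head // (h // G)

h = 8
gqa = [kv_head_for(i, h, G=2) for i in range(h)]   # [0, 0, 0, 0, 1, 1, 1, 1]
mqa = [kv_head_for(i, h, G=1) for i in range(h)]   # every query head reads K/V head 0
mha = [kv_head_for(i, h, G=h) for i in range(h)]   # one K/V head per query head
```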

> \[!proposition\] Proposition 4. GQA Robustness Bound
>
> Replacing the exact $K_i$ with a group-shared $K_g$ shifts the logits by $\delta z= q(K_g-K_i)^\top/\sqrt{d_h}$. Write $\operatorname{softmax}_\lambda(z)=\operatorname{softmax}(\lambda z)$ for inverse temperature $\lambda=1/T$. This map is $(\lambda/2)$-Lipschitz in the $\ell_2$ norm:
>
> $$
> \|\operatorname{softmax}_\lambda(z+\delta z)-\operatorname{softmax}_\lambda(z)\|_2\le \frac{\lambda}{2}\,\|\delta z\|_2.
> $$

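Proof. The Jacobian of $z\mapsto\operatorname{softmax}(\lambda z)$ is $J = \lambda(\operatorname{diag}(\sigma) - \sigma\sigma^\top)$ with $\sigma=\operatorname{softmax}(\lambda z)$. $J$ is symmetric positive semidefinite. Row $i$ has diagonal entry $\lambda\sigma_i(1-\sigma_i)$ and absolute off-diagonal sum $\lambda\sigma_i(1-\sigma_i)$, so by Gershgorin's theorem every eigenvalue is at most $2\lambda\sigma_i(1-\sigma_i)\le\lambda/2$. Hence $\|J\|_2\le\lambda/2$ everywhere, and the bound follows from the mean value inequality along the segment from $z$ to $z+\delta z$.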
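
An empirical check of the Lipschitz constant, as a sketch: random logits and perturbations should never exceed the ratio $\lambda/2$, and the two-logit case at $z=(0,0)$ should approach it. The value `lam = 3.0` and the sample sizes are arbitrary.

```python
import math
import random

def softmax(z, lam=1.0):
    """softmax(lam * z) with the usual max shift."""
    m = max(z)
    e = [math.exp(lam * (x - m)) for x in z]
    s = sum(e)
    return [x / s for x in e]

def l2(v):
    return math.sqrt(sum(x * x for x in v))

def ratio(z, dz, lam):
    """||softmax(z + dz) - softmax(z)||_2 / ||dz||_2 at inverse temperature lam."""
    shifted = [a + b for a, b in zip(z, dz)]
    diff = [a - b for a, b in zip(softmax(shifted, lam), softmax(z, lam))]
    return l2(diff) / l2(dz)

random.seed(2)
lam = 3.0
worst = 0.0
for _ in range(2000):
    z = [random.gauss(0, 1) for _ in range(6)]
    dz = [random.gauss(0, 0.1) for _ in range(6)]
    worst = max(worst, ratio(z, dz, lam))

# Two equal logits pushed in opposite directions come close to the bound.
eps = 1e-4
tight = ratio([0.0, 0.0], [eps, -eps], lam)
```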

### Multi-Head Latent Attention (MLA)

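DeepSeek’s [[thoughts/MLA|MLA]] caches a compressed latent $c^{KV}_t\in\mathbb{R}^{d_c}$ instead of per-head keys and values. A position-dependent [[thoughts/RoPE|RoPE]] rotation cannot be folded into fixed projection matrices, so positional information is carried by a separate decoupled key $k_t^R \in\mathbb{R}^{d_h^R}$ that is cached alongside the latent.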

During inference, the up-projections $W_U^K$ and $W_U^V$ would map the latent back to full dimension. Because matrix multiplication is associative, they can instead be folded into the query-side and output projections, where $W_U^Q$ denotes the query up-projection:

$$
W_Q' = W_U^Q (W_U^K)^\top,\quad W_O' = W_U^V W_O.
$$

This removes the up-projections from the decoding critical path: scores and outputs are computed directly against the cached latents.
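
The absorption is plain associativity, which a toy example makes concrete. The sketch below covers only the non-rotary score path and uses hypothetical names and sizes (`W_UQ`, `W_UK`, `c_q`, `C_kv`, `d_q`); it is not DeepSeek's implementation.

```python
import random

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(r) for r in zip(*A)]

def rand(rows, cols, rng):
    return [[rng.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
d_q, d_c, d_h, n = 6, 4, 8, 5     # toy sizes, chosen arbitrarily
W_UQ = rand(d_q, d_h, rng)        # query up-projection
W_UK = rand(d_c, d_h, rng)        # key up-projection
c_q = rand(1, d_q, rng)           # query-side input for the current token
C_kv = rand(n, d_c, rng)          # cached latents for n past tokens

# Naive path: up-project every cached latent to a full key, then score.
naive = matmul(matmul(c_q, W_UQ), transpose(matmul(C_kv, W_UK)))

# Absorbed path: fold W_UK into the query projection once, score against latents.
W_Q_abs = matmul(W_UQ, transpose(W_UK))               # shape [d_q, d_c]
absorbed = matmul(matmul(c_q, W_Q_abs), transpose(C_kv))

max_err = max(abs(a - b) for a, b in zip(naive[0], absorbed[0]))
```

The absorbed path never forms the $n\times d_h$ key matrix, so per-step work scales with $d_c$ rather than with the full head dimension times the context length.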

### memory limits

Elements cached per token per layer:

- **MHA**: $2 h d_h$
- **GQA**: $2 G d_h$
- **MQA**: $2 d_h$
- **MLA**: $d_c + d_h^R$

_Example:_ At $h=128$, $d_h=128$, MHA caches $2\cdot128\cdot128=32{,}768$ elements per token per layer. MLA ($d_c=512$, $d_h^R=64$) caches $512+64=576$ elements, a reduction of roughly 57:1 ($32{,}768/576\approx56.9$).
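
The per-token counts above are simple enough to script. In this sketch the function name `kv_cache_elements` and the GQA setting `G=8` are illustrative choices, not values from a specific model.

```python
def kv_cache_elements(variant, h=None, d_h=None, G=None, d_c=None, d_rope=None):
    """Cached elements per token per layer, following the list above."""
    if variant == "MHA":
        return 2 * h * d_h
    if variant == "GQA":
        return 2 * G * d_h
    if variant == "MQA":
        return 2 * d_h
    if variant == "MLA":
        return d_c + d_rope
    raise ValueError(f"unknown variant: {variant}")

mha = kv_cache_elements("MHA", h=128, d_h=128)
gqa = kv_cache_elements("GQA", G=8, d_h=128)        # G=8 is an assumed example
mqa = kv_cache_elements("MQA", d_h=128)
mla = kv_cache_elements("MLA", d_c=512, d_rope=64)
ratio = mha / mla
```

Multiplying by the number of layers, the context length, and the bytes per element turns these counts into a total cache size.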

## reference cuda kernels

> \[!note\] implementation
>
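> The reference CUDA kernels below materialize the full $n\times n$ score matrix, apply a max-shifted row softmax in place, and then multiply by $V$. They favor clarity over speed.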
>
> - Scores: ```cuda title="qk\_scores.cu" path="lectures/2/qk\_scores.cu"
>   #include <cuda_runtime.h>
>   #include <cmath>
>
>   __global__ void qk_scores_kernel(
>       const float* __restrict__ Q,  // [n, d]
>       const float* __restrict__ K,  // [n, d]
>       float* __restrict__ S,        // [n, n]
>       int n, int d, float inv_sqrt_d) {
>
>     int row = blockIdx.x * blockDim.x + threadIdx.x;  // query index i
>     int col = blockIdx.y * blockDim.y + threadIdx.y;  // key   index j
>     if (row >= n || col >= n) return;
>
>     const float* q = Q + row * d;
>     const float* k = K + col * d;
>
>     float acc = 0.f;
>     for (int t = 0; t < d; ++t) acc += q[t] * k[t];
>
>     S[row * n + col] = acc * inv_sqrt_d;
>   }
>
>   extern "C" void qk_scores(const float* Q, const float* K, float* S, int n, int d) {
>     dim3 block(16, 16);
>     dim3 grid((n + block.x - 1) / block.x, (n + block.y - 1) / block.y);
>     float inv_sqrt_d = 1.f / sqrtf((float)d);  // plain sqrtf: std::sqrtf is missing from some standard libraries
>     qk_scores_kernel<<<grid, block>>>(Q, K, S, n, d, inv_sqrt_d);
>   }
>
>   ```
> - Softmax: ```cuda title="row\_softmax.cu" path="lectures/2/row\_softmax.cu"
>   #include <cuda_runtime.h>
>   #include <float.h>
>   #include <math.h>
>
>   __global__ void row_softmax_kernel(float* __restrict__ S, int n) {
>     int i = blockIdx.x * blockDim.x + threadIdx.x;
>     if (i >= n) return;
>
>     float m = -FLT_MAX;
>     for (int j = 0; j < n; ++j) m = fmaxf(m, S[i*n + j]);
>
>     float sum = 0.f;
>     for (int j = 0; j < n; ++j) {
>       float e = expf(S[i*n + j] - m);
>       S[i*n + j] = e;
>       sum += e;
>     }
>
>     float inv = 1.f / fmaxf(sum, 1e-12f);
>     for (int j = 0; j < n; ++j) S[i*n + j] *= inv;
>   }
>
>   extern "C" void row_softmax(float* S, int n) {
>     int block = 256;
>     int grid  = (n + block - 1) / block;
>     row_softmax_kernel<<<grid, block>>>(S, n);
>   }
>
>   ```
> - Apply: ```cuda title="apply\_values.cu" path="lectures/2/apply\_values.cu"
>   #include <cuda_runtime.h>
>
>   __global__ void apply_values_kernel(
>       const float* __restrict__ S,  // [n, n]
>       const float* __restrict__ V,  // [n, d_v]
>       float* __restrict__ O,        // [n, d_v]
>       int n, int d_v) {
>
>     int i = blockIdx.x * blockDim.x + threadIdx.x;  // row in S / O
>     int t = blockIdx.y * blockDim.y + threadIdx.y;  // value dim
>     if (i >= n || t >= d_v) return;
>
>     float acc = 0.f;
>     for (int j = 0; j < n; ++j) acc += S[i*n + j] * V[j*d_v + t];
>     O[i*d_v + t] = acc;
>   }
>
>   extern "C" void apply_values(const float* S, const float* V, float* O, int n, int d_v) {
>     dim3 block(16, 16);
>     dim3 grid((n + block.x - 1) / block.x, (d_v + block.y - 1) / block.y);
>     apply_values_kernel<<<grid, block>>>(S, V, O, n, d_v);
>   }
>
>
>   ```
>
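> [[thoughts/flash attention|FlashAttention]] avoids materializing the $n\times n$ score matrix in global memory: it streams $K$/$V$ tiles through shared memory while maintaining a running row maximum and denominator, rescaling partial results as each tile arrives. \[@dao2022flashattentionfastmemoryefficientexact\] The Triton kernel that follows is a simpler two-pass variant of the same tiling: pass 1 computes the row maxima and denominators, pass 2 recomputes the scores and accumulates the output.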
>
> Triton equivalent: ```python title="attention\_triton.py" path="lectures/2/attention\_triton.py"
> # attention_triton.py
> import triton
> import triton.language as tl
>
>
> @triton.jit
> def attn_two_pass(
>   Q,
>   K,
>   V,
>   O,
>   n,
>   d,
>   dv,
>   stride_qm,
>   stride_qd,
>   stride_km,
>   stride_kd,
>   stride_vm,
>   stride_vd,
>   stride_om,
>   stride_od,
>   BLOCK_M: tl.constexpr,
>   BLOCK_N: tl.constexpr,
>   BLOCK_D: tl.constexpr,
> ):
>   # Row block we compute
>   row_offs = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
>   d_offs = tl.arange(0, BLOCK_D)  # BLOCK_D must be >= max(d, dv): these offsets index Q/K and V/O columns
>   n_offs = tl.arange(0, BLOCK_N)
>
>   # ----------------------------
>   # Pass 1: compute per-row max & denom for stable softmax
>   # ----------------------------
>   m_i = tl.full((BLOCK_M,), -1e9, dtype=tl.float32)
>   l_i = tl.zeros((BLOCK_M,), dtype=tl.float32)
>
>   for col_start in range(0, n, BLOCK_N):
>     # Load Q-block [M, D]
>     q = tl.load(
>       Q + row_offs[:, None] * stride_qm + d_offs[None, :] * stride_qd,
>       mask=(row_offs[:, None] < n) & (d_offs[None, :] < d),
>       other=0.0,
>     )
>     # Load K-block [N, D]
>     k = tl.load(
>       K
>       + (col_start + n_offs)[:, None] * stride_km
>       + d_offs[None, :] * stride_kd,
>       mask=((col_start + n_offs)[:, None] < n) & (d_offs[None, :] < d),
>       other=0.0,
>     )
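>     # scores [M, N] = q @ k^T / sqrt(d); d is a runtime integer, so cast it before the sqrt
>     scores = tl.dot(q, tl.trans(k)) * (1.0 / tl.sqrt(d.to(tl.float32)))
>     # Padded key columns were loaded as 0; mask them so they add nothing to the denominator
>     scores = tl.where((col_start + n_offs)[None, :] < n, scores, float("-inf"))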
>
>     # Online softmax: track max and rescale accumulated sum
>     m_i_new = tl.maximum(m_i, tl.max(scores, axis=1))
>     # Rescale previous accumulator when max increases
>     alpha = tl.exp(m_i - m_i_new)
>     l_i = l_i * alpha
>     # Add current tile's contribution
>     p = tl.exp(scores - m_i_new[:, None])
>     l_i += tl.sum(p, axis=1)
>     # Update running max
>     m_i = m_i_new
>
>   # ----------------------------
>   # Pass 2: accumulate O = softmax(S) @ V
>   # ----------------------------
>   inv_l_i = 1.0 / l_i
>   Oacc = tl.zeros((BLOCK_M, BLOCK_D), dtype=tl.float32)
>
>   for col_start in range(0, n, BLOCK_N):
>     q = tl.load(
>       Q + row_offs[:, None] * stride_qm + d_offs[None, :] * stride_qd,
>       mask=(row_offs[:, None] < n) & (d_offs[None, :] < d),
>       other=0.0,
>     )
>     k = tl.load(
>       K
>       + (col_start + n_offs)[:, None] * stride_km
>       + d_offs[None, :] * stride_kd,
>       mask=((col_start + n_offs)[:, None] < n) & (d_offs[None, :] < d),
>       other=0.0,
>     )
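>     scores = tl.dot(q, tl.trans(k)) * (1.0 / tl.sqrt(d.to(tl.float32)))
>     # Mask padded key columns so their weights are exactly zero
>     scores = tl.where((col_start + n_offs)[None, :] < n, scores, float("-inf"))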
>
>     # Recenter using m_i computed in pass 1
>     scores = scores - m_i[:, None]
>     p = tl.exp(scores) * inv_l_i[:, None]  # [M, N]
>
>     # Load V-block [N, Dv]
>     v = tl.load(
>       V
>       + (col_start + n_offs)[:, None] * stride_vm
>       + d_offs[None, :] * stride_vd,
>       mask=((col_start + n_offs)[:, None] < n) & (d_offs[None, :] < dv),
>       other=0.0,
>     )
>
>     Oacc += tl.dot(p, v)  # [M, Dv]
>
>   # Store result
>   tl.store(
>     O + row_offs[:, None] * stride_om + d_offs[None, :] * stride_od,
>     Oacc,
>     mask=(row_offs[:, None] < n) & (d_offs[None, :] < dv),
>   )
>
> ```
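
For comparison with the two-pass kernel, the single-pass online-softmax recurrence can be sketched for one query row in plain Python. The names `attend_row_direct` and `attend_row_online` and the tile size are illustrative; this is a CPU reference for the arithmetic, not a port of any FlashAttention kernel.

```python
import math
import random

def attend_row_direct(scores, values):
    """Reference: softmax(scores) @ values for a single query row."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wj * v[t] for wj, v in zip(w, values)) / total
            for t in range(len(values[0]))]

def attend_row_online(scores, values, block=4):
    """One pass over key tiles, rescaling the accumulator when the max grows."""
    m, l = float("-inf"), 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_new = max(m, max(s_blk))
        alpha = math.exp(m - m_new)          # exp(-inf) == 0.0 on the first tile
        p = [math.exp(s - m_new) for s in s_blk]
        l = l * alpha + sum(p)
        acc = [a * alpha + sum(pj * v[t] for pj, v in zip(p, v_blk))
               for t, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]

random.seed(3)
n, dv = 10, 3                                # n is not a multiple of the tile size
scores = [random.gauss(0, 2) for _ in range(n)]
values = [[random.gauss(0, 1) for _ in range(dv)] for _ in range(n)]

direct = attend_row_direct(scores, values)
online = attend_row_online(scores, values)
max_err = max(abs(a - b) for a, b in zip(direct, online))
```

The rescaling factor `alpha` is the same quantity the Triton kernel applies to `l_i` in pass 1; carrying the output accumulator through the same rescaling is what lets the second pass be dropped.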

