Transformers from scratch

Notes · the transformer architecture · Mar 2023

The transformer (Vaswani et al. 2017) is a stack of identical layers that each read and write a shared residual stream: a sequence of vectors, one per token, that every layer refines. Write the input as a matrix $X \in \mathbb{R}^{n \times d}$ for a sequence of $n$ tokens and model width $d$. Every layer is a map $\mathbb{R}^{n \times d} \to \mathbb{R}^{n \times d}$, so the shape is invariant and layers compose freely. Two sublayers do all the work: attention, which moves information between positions, and a position-wise feed-forward network, which does computation within a position. Residual connections and normalization exist only to make a deep stack of those two sublayers trainable.

The reason this architecture displaced recurrent networks is not accuracy in principle but hardware utilization in practice. An RNN has a sequential dependency through its hidden state and underuses a parallel processor; attention has no such dependency and turns the bulk of the computation into large dense matrix multiplications that saturate a GPU. The cost it pays is quadratic interaction in the sequence length, the tension that drives most of the architecture's later evolution and the efficiency work cited throughout this note.

Scaled dot-product attention

Each token vector is linearly projected into a query, a key, and a value with learned matrices $W^Q, W^K \in \mathbb{R}^{d \times d_k}$ and $W^V \in \mathbb{R}^{d \times d_v}$, giving $Q = XW^Q$, $K = XW^K$, $V = XW^V$. The output is

$$\operatorname{Attention}(Q,K,V) = \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V.$$

Read it from the inside out: $QK^\top \in \mathbb{R}^{n \times n}$ holds every query-key dot product, the softmax turns each row into weights that are nonnegative and sum to one, and multiplying by $V$ returns, for each position, a convex combination of all value vectors. A token's new representation is therefore a weighted average of what every token (itself included) offers, with the weights set by content similarity.

One easily-forgotten detail is why this uses a dot product rather than the additive, MLP-based attention of earlier models. The two are close in quality; scaled dot-product attention won because the whole score matrix is a single matrix multiply, the one operation GPUs execute most efficiently, so it is dramatically faster at scale for the same result.

The $1/\sqrt{d_k}$ factor is not cosmetic. If the entries of $q$ and $k$ are independent with mean zero and unit variance, then $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has mean zero and variance $d_k$. Without rescaling, the logits grow like $\sqrt{d_k}$, the softmax saturates toward a one-hot vector, and its Jacobian $\operatorname{diag}(p) - p p^\top$ collapses toward zero, so gradients vanish. Dividing by $\sqrt{d_k}$ restores unit-variance logits and keeps the map trainable. The softmax itself is always evaluated in shift-invariant form, $\operatorname{softmax}(z)_i = e^{\,z_i - m}/\sum_k e^{\,z_k - m}$ with $m = \max_j z_j$. The reason is concrete floating-point behavior: a 32-bit float tops out near $3.4 \times 10^{38}$, which $e^{x}$ exceeds once $x$ passes about $88.7$ (around $709$ for 64-bit), so an unshifted large logit overflows to $+\infty$, and the subsequent $\infty/\infty$ in the normalization evaluates to $\mathrm{NaN}$ that then propagates through every downstream operation and ruins the run. Subtracting the row max makes the largest exponent $e^{0} = 1$ and every other term at most $1$, so overflow becomes impossible; the only rounding error left is the harmless underflow of very negative terms to $0$, which is exactly the weight they should have.
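
Both claims in this paragraph can be checked numerically. The sketch below is illustrative only: the sample count, $d_k = 64$, and the test logits are arbitrary choices, and the names `softmax_naive` and `softmax_stable` are introduced here, not taken from any library.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 64                                   # arbitrary head width for the demo

# 1) Dot products of unit-variance vectors have variance ~ d_k; scaling restores ~ 1.
q = rng.standard_normal((10000, d_k))
k = rng.standard_normal((10000, d_k))
raw = (q * k).sum(axis=1)
scaled = raw / np.sqrt(d_k)
var_raw, var_scaled = raw.var(), scaled.var()

# 2) Unshifted softmax overflows in float32; the shifted form does not.
def softmax_naive(z):
    e = np.exp(z)
    return e / e.sum()

def softmax_stable(z):
    e = np.exp(z - z.max())                # largest exponent becomes exp(0) = 1
    return e / e.sum()

z = np.array([100.0, 0.0, -100.0], dtype=np.float32)
with np.errstate(over="ignore", invalid="ignore"):
    p_naive = softmax_naive(z)             # exp(100) -> inf, then inf/inf -> nan
p_stable = softmax_stable(z)               # finite, sums to 1

print(var_raw, var_scaled)
print(p_naive, p_stable)
```

The first printed pair should sit near $d_k$ and near $1$; the naive softmax returns a `nan` in the first slot while the shifted version puts essentially all the weight there.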

Multi-head attention

One attention map encodes only a single notion of relevance. Multi-head attention runs $h$ maps in parallel, each in a $d_k = d/h$ dimensional subspace with its own projections, then concatenates and mixes them:

$$\operatorname{head}_i = \operatorname{Attention}(XW_i^Q, XW_i^K, XW_i^V), \qquad \operatorname{MHA}(X) = \operatorname{Concat}(\operatorname{head}_1, \dots, \operatorname{head}_h)\,W^O.$$

With $d_k = d_v = d/h$, the per-head matrices stacked across heads make $W^Q, W^K, W^V, W^O$ each effectively $d \times d$, so attention contributes $4d^2$ parameters per layer. Heads demonstrably specialize: some become induction heads that copy a previously seen token, others resolve coreference or attend to delimiters. This gives the layer several independent views of context at the same total compute as one wide head.

Shrinking the cache: MQA, GQA, and MLA

At inference the keys and values of all past tokens are cached (below), and that cache, not the arithmetic, is the bottleneck. Multi-query attention (Shazeer 2019) keeps $h$ query heads but a single shared key/value head, shrinking the cache by a factor of $h$. Grouped-query attention (Ainslie et al. 2023) interpolates with $g$ key/value heads, $1 < g < h$, and is the modern default because it keeps near-full quality at close to MQA throughput; it is what Mistral 7B (2023) and the later models in the Llama line use (Touvron et al. 2023; the original LLaMA and the smaller Llama 2 models still used full multi-head attention). Multi-head latent attention (introduced in DeepSeek-V2 and carried into DeepSeek-V3, 2024) goes further, jointly compressing keys and values into a low-rank latent that is the only thing cached, cutting cache memory dramatically while retaining multi-head expressivity.
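
A minimal sketch of grouped-query attention, assuming consecutive query heads share a key/value head (one common layout, not the only possible one). The function names and the head counts in the usage lines are hypothetical, chosen only to show the cache arithmetic.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(Q, K, V):
    """Q: (h, n, d_head); K, V: (g, n, d_head), with g dividing h.

    g == h is ordinary multi-head attention, g == 1 is multi-query attention.
    """
    h, g = Q.shape[0], K.shape[0]
    assert h % g == 0, "query heads must split evenly across key/value heads"
    rep = h // g
    K_full = np.repeat(K, rep, axis=0)     # each K/V head serves `rep` query heads
    V_full = np.repeat(V, rep, axis=0)
    scores = Q @ K_full.swapaxes(-1, -2) / np.sqrt(Q.shape[-1])
    return softmax(scores) @ V_full

def kv_cache_scalars(num_layers, seq_len, num_kv_heads, d_head):
    # Factor 2 covers keys plus values; only the K/V heads are cached.
    return 2 * num_layers * seq_len * num_kv_heads * d_head

rng = np.random.default_rng(0)
h, g, n, d_head = 8, 2, 5, 4
Q = rng.standard_normal((h, n, d_head))
K = rng.standard_normal((g, n, d_head))
V = rng.standard_normal((g, n, d_head))
out = grouped_query_attention(Q, K, V)     # shape (h, n, d_head)

mha = kv_cache_scalars(32, 4096, 32, 128)  # hypothetical model: 32 K/V heads
gqa = kv_cache_scalars(32, 4096, 8, 128)   # same model with 8 K/V heads
print(out.shape, mha // gqa)
```

The output keeps all $h$ heads, while the cache shrinks by the ratio $h/g$, here a factor of four.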

Position: sinusoids, RoPE, and ALiBi

Attention is permutation-equivariant, so order has to be supplied. The original model added fixed sinusoids. Modern models prefer rotary position embeddings (Su et al. 2021), which rotate $q$ and $k$ in 2D coordinate pairs by an angle proportional to absolute position. Writing $R_m$ for the rotation at position $m$,

$$\langle R_m q,\; R_n k \rangle = \langle q,\; R_{n-m} k \rangle,$$

so the logit depends on position only through the relative offset $n-m$. That relative structure is why RoPE handles longer sequences better than learned absolute embeddings, although in practice reaching well past the training length usually relies on position interpolation or a similar rescaling of the angles. A different take, ALiBi (Press et al. 2021), adds a linear distance penalty to the logits and also extrapolates, with no learned positional parameters at all.
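
The relative-offset identity is easy to verify numerically. This is a sketch under assumptions: the helper name `rope`, the base of 10000 for the frequency schedule, and the pairing of adjacent coordinates are conventional choices, not details stated in this note.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate consecutive coordinate pairs of x by angles pos * theta_i.

    x: (..., d) with d even. `base` sets the frequency schedule (assumed value).
    """
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)      # one frequency per 2D pair
    ang = pos * theta
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(8), rng.standard_normal(8)

a = rope(q, 3) @ rope(k, 10)       # positions m = 3, n = 10
b = rope(q, 103) @ rope(k, 110)    # both shifted by 100, same offset
c = q @ rope(k, 7)                 # <q, R_{n-m} k> with n - m = 7
print(a, b, c)
```

All three numbers agree to floating-point precision, which is the statement $\langle R_m q, R_n k \rangle = \langle q, R_{n-m} k \rangle$.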

Normalization and the residual stream

Each sublayer is wrapped as $x \leftarrow x + \operatorname{Sublayer}(\operatorname{Norm}(x))$. The point of the residual is gradient flow: differentiating the sum gives an identity term plus the sublayer's Jacobian, so the gradient always has a path back that is not multiplied toward zero, which is what lets very deep stacks train at all (the same fix as in ResNets, where accuracy otherwise degrades as layers are added). It also reframes each layer's job as a small additive edit to the stream rather than a full rewrite, an easier function to learn. Pre-norm, normalizing inside the residual branch, matters because post-norm, $x \leftarrow \operatorname{Norm}(x + \operatorname{Sublayer}(x))$, puts the normalization on the residual path itself: every layer rescales the stream, the backward signal passes through a normalization Jacobian at each layer instead of a pure identity, and training tends to be unstable without a learning-rate warmup. Pre-norm leaves the stream untouched, so each sublayer sees a normalized input while a clean identity path runs from output back to input; the stream's norm can then grow with depth, which is why pre-norm models apply one final normalization before the output head.
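
The identity path can be made concrete with a toy experiment: give every layer a sublayer that outputs zero and see what survives 24 layers. This is a sketch, not a training setup; `pre_norm`, `post_norm`, and the gain-free `rms_norm` are names and simplifications introduced here.

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # Gain-free RMS normalization, enough for this illustration.
    return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)

def pre_norm(x, sublayer):
    return x + sublayer(rms_norm(x))       # norm sits inside the residual branch

def post_norm(x, sublayer):
    return rms_norm(x + sublayer(x))       # norm sits on the residual path itself

zero = lambda h: np.zeros_like(h)          # a sublayer that contributes nothing

x = np.array([[3.0, -4.0, 0.0, 12.0]])
y_pre, y_post = x, x
for _ in range(24):
    y_pre = pre_norm(y_pre, zero)
    y_post = post_norm(y_post, zero)

print(y_pre)    # exactly x: the stream passes through untouched
print(y_post)   # rescaled to unit RMS: the path is not a pure identity
```

With pre-norm the stack reduces to the identity when the sublayers are silent; with post-norm the input has already been rewritten by the normalization.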

LayerNorm standardizes each token vector and rescales it, $\operatorname{LN}(x) = \gamma \odot \dfrac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$. Most recent models instead use RMSNorm (Zhang & Sennrich 2019), which drops mean-centering and bias, $\operatorname{RMSNorm}(x) = \dfrac{x}{\sqrt{\tfrac{1}{d}\sum_i x_i^2 + \epsilon}} \odot \gamma$. It is cheaper and in practice about as effective, and why it can drop the mean is itself instructive: the stabilizing benefit of normalization comes mostly from rescaling activations to a controlled magnitude, not from re-centering them, so removing the mean subtraction (and the learned bias) costs little while saving compute on an operation that runs twice in every layer.
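
The two formulas translate directly into NumPy. A minimal sketch, with $\epsilon = 10^{-5}$ as an assumed default; on an input that already has zero mean the two coincide (with $\beta = 0$), which is one way to see that RMSNorm only removes the centering step.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)            # population variance, as in the formula
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

def rms_norm(x, gamma, eps=1e-5):
    ms = (x ** 2).mean(axis=-1, keepdims=True)     # mean square, no centering
    return x / np.sqrt(ms + eps) * gamma

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 16)) * 5.0 + 2.0       # 3 tokens, width 16, nonzero mean
gamma, beta = np.ones(16), np.zeros(16)

ln = layer_norm(x, gamma, beta)
rn = rms_norm(x, gamma)

x_centered = x - x.mean(axis=-1, keepdims=True)
same = np.allclose(layer_norm(x_centered, gamma, beta), rms_norm(x_centered, gamma))
print(same)
```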

The feed-forward network

The position-wise FFN holds most of the parameters and raw capacity. Classically it is two linear layers with a nonlinearity, $\operatorname{FFN}(x) = \phi(x W_1 + b_1)\,W_2 + b_2$, with inner width about $4d$, roughly $8d^2$ parameters per layer.

Modern models use the gated SwiGLU variant (Shazeer 2020), which is worth stating exactly because it is easy to half-remember. Begin with a gated linear unit (GLU): take two independent linear projections of the input and multiply them elementwise, so one branch acts as a learned gate on the other. SwiGLU is the GLU whose active branch uses the Swish activation (also called SiLU), $\operatorname{Swish}(z) = z\,\sigma(z)$, a smooth, non-monotonic relative of ReLU. Written out, $\operatorname{FFN}(x) = \big(\operatorname{Swish}(xW_1) \odot xW_3\big)W_2$, with three weight matrices instead of two. Because of that third matrix, the inner width is usually scaled to about $\tfrac{2}{3}\cdot 4d = \tfrac{8}{3}d$ so the parameter count stays comparable to a classic $4d$ FFN. A block is then about $12d^2$ parameters, roughly one third attention and two thirds FFN.
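
Written as code, the SwiGLU block and its parameter budget look like this. A sketch only: the width $d = 768$ is a hypothetical value picked because $\tfrac{8}{3}d$ is then an integer, and the small shapes in the forward pass are arbitrary.

```python
import numpy as np

def swish(z):
    return z / (1.0 + np.exp(-z))                  # z * sigmoid(z), also called SiLU

def swiglu_ffn(x, W1, W3, W2):
    # Gate branch swish(x W1) multiplies the linear branch x W3, then W2 projects back.
    return (swish(x @ W1) * (x @ W3)) @ W2

# Forward pass on small, arbitrary shapes.
rng = np.random.default_rng(0)
x = rng.standard_normal((3, 6))                    # 3 tokens, width 6
W1, W3 = rng.standard_normal((6, 16)), rng.standard_normal((6, 16))
W2 = rng.standard_normal((16, 6))
y = swiglu_ffn(x, W1, W3, W2)                      # shape (3, 6)

# Parameter accounting (biases ignored) for a hypothetical width.
d = 768
d_ff = 8 * d // 3                                  # 2048
params_swiglu = 3 * d * d_ff                       # W1, W3, W2
params_classic = 2 * d * (4 * d)                   # W1, W2 at inner width 4d
params_attn = 4 * d * d                            # W^Q, W^K, W^V, W^O
attn_share = params_attn / (params_attn + params_swiglu)
print(params_swiglu, params_classic, attn_share)
```

The three-matrix block at width $\tfrac{8}{3}d$ matches the two-matrix block at $4d$ exactly, and attention comes out at one third of the $12d^2$ total.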

Causal masking and the objective

For autoregressive language modeling, position $i$ must not see positions $> i$. A causal mask sets those logits to $-\infty$ before the softmax, which works precisely because $e^{-\infty} = 0$, so the forbidden positions receive exactly zero weight and drop out of the convex combination with no renormalization or special-casing. Training minimizes the average next-token cross-entropy, and because the mask makes every position a valid prediction, one forward pass over a length-$n$ sequence yields $n$ supervised signals at once, which is what makes pretraining so sample-efficient in wall-clock terms (Brown et al. 2020).
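
A small sketch of both halves of this paragraph: the mask as $-\infty$ logits, and the loss that scores every position at once. The all-zero scores and logits are deliberate, since they make the expected values obvious; `next_token_loss` is a name introduced here.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

n = 4
allowed = np.tril(np.ones((n, n), dtype=bool))     # position i may see j <= i
scores = np.zeros((n, n))                          # flat scores, for illustration
weights = softmax(np.where(allowed, scores, -np.inf))
# Row i is uniform over positions 0..i and exactly zero afterwards.

def next_token_loss(logits, tokens):
    """Mean cross-entropy where position i predicts tokens[i + 1].

    logits: (n, vocab); tokens: (n + 1,) integer ids.
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    targets = tokens[1:]
    return -log_probs[np.arange(len(targets)), targets].mean()

vocab = 5
tokens = np.array([2, 0, 4, 1, 3])                 # n + 1 token ids
loss = next_token_loss(np.zeros((n, vocab)), tokens)
print(weights)
print(loss, np.log(vocab))                         # uniform logits give log(vocab)
```

The diagonal is always allowed, so every row has a finite maximum and the shifted softmax never sees an all-$-\infty$ row.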

Inference: prefill, decode, and the KV cache

Generation has two regimes. Prefill processes the whole prompt in parallel, like training. Decode emits one token at a time, and re-reading the entire prefix every step would be quadratic waste, so each layer caches the keys and values of past tokens and computes only $q,k,v$ for the new token. The KV cache holds about $2 \cdot L \cdot n \cdot d_{kv}$ scalars (layers $L$, length $n$, key/value width $d_{kv}$), which is why long contexts and large batches are memory-bound and why MQA, GQA, and MLA matter so much.
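
The equivalence between cached decoding and full causal attention can be checked directly. This is a single-head, single-layer sketch with random weights; the variable names and sizes are illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n, d = 6, 8
X = rng.standard_normal((n, d))
Wq, Wk, Wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

# Prefill-style reference: the whole sequence at once under a causal mask.
Q, K, V = X @ Wq, X @ Wk, X @ Wv
scores = Q @ K.T / np.sqrt(d)
scores = np.where(np.tril(np.ones((n, n), dtype=bool)), scores, -np.inf)
full = softmax(scores) @ V

# Decode-style: one token per step, appending its key and value to the cache.
K_cache, V_cache, outs = np.empty((0, d)), np.empty((0, d)), []
for t in range(n):
    x_t = X[t:t + 1]                               # only the new token is projected
    q_t, k_t, v_t = x_t @ Wq, x_t @ Wk, x_t @ Wv
    K_cache = np.vstack([K_cache, k_t])
    V_cache = np.vstack([V_cache, v_t])
    w = softmax(q_t @ K_cache.T / np.sqrt(d))      # no mask needed: cache holds the past only
    outs.append(w @ V_cache)
cached = np.vstack(outs)

print(np.allclose(full, cached), K_cache.size + V_cache.size)
```

The cache ends at $2 \cdot n \cdot d$ scalars for this one layer, the per-layer term in the formula above.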

Decode moves many weights and cache bytes per few FLOPs, so it has low arithmetic intensity and is bandwidth-bound; prefill is compute-bound. FlashAttention (Dao et al. 2022; FlashAttention-2, 2023; FlashAttention-3, 2024) attacks the prefill memory wall by never materializing the full $n \times n$ score matrix: it tiles $Q,K,V$ and runs a numerically stable online softmax over blocks, reducing attention memory from $O(n^2)$ to $O(n)$ and cutting slow high-bandwidth-memory traffic, which makes it both faster and far more memory-frugal at long context.
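
The online softmax at the heart of this idea fits in a few lines. The sketch below is a NumPy illustration of the arithmetic only: it tiles over keys and values but not queries, ignores masking, and says nothing about the GPU memory hierarchy that the real kernels are built around. The function names and block size are assumptions.

```python
import numpy as np

def attention_naive(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[-1])             # full (n, n) score matrix
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def attention_tiled(Q, K, V, block=4):
    n, d = Q.shape
    m = np.full((n, 1), -np.inf)                   # running row max
    l = np.zeros((n, 1))                           # running softmax denominator
    acc = np.zeros((n, V.shape[-1]))               # running unnormalized output
    for s in range(0, K.shape[0], block):
        Kb, Vb = K[s:s + block], V[s:s + block]
        S = Q @ Kb.T / np.sqrt(d)                  # only an (n, block) slice of scores
        m_new = np.maximum(m, S.max(axis=-1, keepdims=True))
        scale = np.exp(m - m_new)                  # re-express old sums under the new max
        P = np.exp(S - m_new)
        l = l * scale + P.sum(axis=-1, keepdims=True)
        acc = acc * scale + P @ Vb
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((10, 8)) for _ in range(3))
match = np.allclose(attention_naive(Q, K, V), attention_tiled(Q, K, V, block=4))
print(match)
```

Each step needs only the running max, the running denominator, and the running output, so the full score matrix is never held at once.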

Scaling laws and the modern recipe

How to spend compute is governed by empirical scaling laws. Early work (Kaplan et al. 2020) found smooth power-law relationships between loss and scale; the Chinchilla analysis (Hoffmann et al. 2022) then showed many models were undertrained and that, for a fixed budget, tokens and parameters should grow together at roughly twenty tokens per parameter. The open Llama and Mistral models combine this data-heavy recipe with the components above (pre-norm RMSNorm, RoPE, GQA, SwiGLU); Llama's smaller models in fact train well past twenty tokens per parameter, spending extra training compute to get a smaller model that is cheaper to serve.
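
The twenty-tokens-per-parameter rule turns into arithmetic once a cost model is fixed. The sketch below assumes the common approximation $C \approx 6ND$ training FLOPs for $N$ parameters and $D$ tokens, which is not stated in this note; the function name and the example budget are likewise hypothetical.

```python
import math

def chinchilla_allocation(compute_flops, tokens_per_param=20.0):
    """Split a training budget between parameters N and tokens D.

    Assumes C ~ 6 * N * D (a standard rough estimate) and D = tokens_per_param * N,
    so N = sqrt(C / (6 * tokens_per_param)).
    """
    n_params = math.sqrt(compute_flops / (6.0 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens

N, D = chinchilla_allocation(1e23)                 # hypothetical budget of 1e23 FLOPs
print(f"{N:.3e} parameters, {D:.3e} tokens")
```

Under these assumptions a $10^{23}$ FLOP budget lands at roughly 29 billion parameters and 580 billion tokens; quadrupling the budget doubles both.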

Training memory and numerical precision

Two practical points come up constantly and are worth knowing precisely. First, large models train in bf16 rather than fp16 because bf16 keeps the full 8-bit exponent of fp32, so it has the same dynamic range and rarely overflows or underflows; fp16's 5-bit exponent has a narrow range that forces a loss-scaling trick to stop small gradients from flushing to zero. bf16's mantissa is shorter than fp16's (7 bits versus 10), so it is less precise per number, but for training, range matters more than precision. Second, it is activations, not parameters, that usually dominate training memory, because every intermediate must be retained for the backward pass. Gradient (activation) checkpointing keeps only a few activations and recomputes the rest during the backward pass, trading roughly one extra forward pass of compute for a large memory saving, which is often what makes a long sequence or a large batch fit on the device at all.
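
The range-versus-precision trade can be seen with three numbers. NumPy has no bfloat16 type, so the sketch emulates it by truncating a float32 to its top 16 bits; that helper (`to_bf16`) is an assumption made for illustration, and real hardware typically rounds to nearest rather than truncating.

```python
import numpy as np

def to_bf16(x):
    """Emulate bfloat16: keep the sign, 8 exponent bits, and top 7 mantissa bits."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)

# A large value, a tiny gradient-sized value, and a value just above 1.
x = np.array([1e5, 1e-8, 1.001], dtype=np.float32)

with np.errstate(over="ignore"):
    as_fp16 = x.astype(np.float16)     # 1e5 -> inf, 1e-8 -> 0, 1.001 kept approximately
as_bf16 = to_bf16(x)                   # range survives, 1.001 collapses to 1.0

# Loss scaling: multiplying by a large constant lifts the tiny value into fp16 range.
scaled_fp16 = (x[1] * np.float32(65536.0)).astype(np.float16)

print(as_fp16)
print(as_bf16)
print(scaled_fp16)
```

fp16 loses the first two values entirely but resolves the third; emulated bf16 keeps the first two within about one percent and loses the third's fractional part.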

Beyond dense attention: long context, MoE, and state-space models

The quadratic cost has produced three durable directions. Sparse and windowed attention, such as Longformer (2020), restricts which key positions each query may see. The basic pattern is a sliding window: position $i$ attends only to the $w$ neighbors in $[\,i - w/2,\; i + w/2\,]$, which drops the cost from $O(n^2)$ to $O(n\,w)$, linear in length for a fixed $w$; stacking such layers grows the effective receptive field the way convolutions do, since a token reaches a little further each layer. To widen reach without enlarging $w$, dilated (strided) windows skip positions, attending to every $k$-th neighbor so a width-$w$ window spans $k\,w$ tokens. A few global tokens, say a classification token or sentence boundaries, are then allowed to attend to, and be attended by, everything, restoring the long-range links a purely local window drops. Longformer combines exactly these, local windows for most tokens plus a handful of global ones, for linear time and memory on long documents. Mixture of experts decouples parameters from per-token compute by routing each token to a few expert FFNs (Fedus et al. 2021; Mixtral 2024; DeepSeek-V3 2024), so a model can have hundreds of billions of parameters while activating only a fraction per token. And state-space models like Mamba (2023) replace attention with a selective linear recurrence that is linear in sequence length, trading some in-context recall for cheaper long-range modeling. Hybrids that interleave attention and SSM layers are an active area.
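
The local-plus-global pattern described above comes down to a boolean mask. A minimal sketch, assuming a symmetric (non-causal) window and a dense mask for readability; `local_global_mask` is a name introduced here, and a real implementation would avoid materializing the $n \times n$ array.

```python
import numpy as np

def local_global_mask(n, w, global_idx=()):
    """True where attention is allowed.

    Each position sees neighbors within w // 2 on either side; positions in
    global_idx attend to everything and are attended to by everything.
    """
    i = np.arange(n)
    mask = np.abs(i[:, None] - i[None, :]) <= w // 2
    for g in global_idx:
        mask[g, :] = True
        mask[:, g] = True
    return mask

local = local_global_mask(8, 2)                    # window only
mixed = local_global_mask(8, 2, global_idx=(0,))   # plus one global token

print(np.flatnonzero(mixed[4]))                    # [0 3 4 5]
print(local.sum(), mixed.sum())
```

The number of allowed pairs grows linearly in $n$ for fixed $w$ and a fixed number of global tokens, which is where the $O(n\,w)$ cost comes from.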

Step by step

  1. Project the residual stream into queries, keys, and values per head.
  2. Form scaled scores $QK^\top/\sqrt{d_k}$ and apply the causal mask.
  3. Softmax each row (shift by the row max for stability) to get attention weights.
  4. Multiply weights by $V$, concatenate heads, and project with $W^O$.
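  5. Add back to the residual stream, then apply the pre-norm FFN sublayer the same way.

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)          # shift for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V, mask=None):
    # mask: boolean, True where attention is allowed; broadcasts against (..., q, k)
    d_k = Q.shape[-1]
    scores = Q @ K.swapaxes(-1, -2) / np.sqrt(d_k)   # (..., q, k)
    if mask is not None:
        # a large finite negative stands in for -inf: its exponential underflows to 0
        scores = np.where(mask, scores, -1e9)        # causal / padding mask
    weights = softmax(scores, axis=-1)
    return weights @ V, weights

def multi_head_attention(X, Wq, Wk, Wv, Wo, num_heads, causal=False):
    seq, d_model = X.shape
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_head = d_model // num_heads
    Q, K, V = X @ Wq, X @ Wk, X @ Wv                 # step 1: project
    # (seq, d_model) -> (heads, seq, d_head)
    split = lambda t: t.reshape(seq, num_heads, d_head).transpose(1, 0, 2)
    # step 2: optional causal mask, shared by all heads via broadcasting
    mask = np.tril(np.ones((seq, seq), dtype=bool)) if causal else None
    out, _ = scaled_dot_product_attention(split(Q), split(K), split(V), mask)
    out = out.transpose(1, 0, 2).reshape(seq, d_model)   # concat heads
    return out @ Wo                                  # step 4: output projection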

Complexity (time and space)

Per layer, attention costs $O(n^2 d)$ time, from the $n \times n$ scores and their product with $V$, while the FFN (and the attention projections) cost $O(n d^2)$. Up to constant factors the crossover is at $n \approx d$, in practice a few multiples of $d$ once the constants are counted: short sequences are FFN-bound, long sequences attention-bound, which is why context length is the expensive axis. Naive attention uses $O(n^2)$ activation memory per head, the term FlashAttention removes by tiling and recomputation. At inference the KV cache adds $O(L n d_{kv})$ persistent memory, and each decode step does $O(n d)$ attention work per layer on top of $O(d^2)$ for the weight matrices, but is bandwidth-bound in practice.
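
Counting matmul FLOPs makes the crossover concrete. The sketch assumes a classic FFN of inner width $4d$, counts two FLOPs per multiply-add, and ignores softmax, normalization, and masking savings; `layer_flops` and the example sizes are hypothetical.

```python
def layer_flops(n, d):
    """Approximate forward matmul FLOPs for one layer at sequence length n, width d."""
    proj = 4 * 2 * n * d * d          # Q, K, V, and output projections
    attn = 2 * 2 * n * n * d          # Q K^T, then weights @ V
    ffn = 2 * 2 * n * d * (4 * d)     # two linear layers at inner width 4d
    return {"proj": proj, "attn": attn, "ffn": ffn}

d = 1024
short, balanced, long_ = layer_flops(512, d), layer_flops(4 * d, d), layer_flops(8192, d)

print(short["attn"] < short["ffn"])        # True: short sequences are FFN-bound
print(balanced["attn"] == balanced["ffn"]) # True: the n^2 term catches the FFN at n = 4d
print(long_["attn"] > long_["ffn"])        # True: long sequences are attention-bound
```

Under these assumptions the quadratic term overtakes the FFN at $n = 4d$ and overtakes everything else in the layer at $n = 6d$, consistent with a crossover at $n \approx d$ up to constants.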

Worked example

Attention preserves the sequence shape, and each token's attention weights form a valid distribution. Three tokens of width four through a two-head layer:

import numpy as np
np.random.seed(0)

X = np.random.randn(3, 4)                    # 3 tokens, model dim 4
Wq, Wk, Wv, Wo = [np.random.randn(4, 4) for _ in range(4)]

out = multi_head_attention(X, Wq, Wk, Wv, Wo, num_heads=2)
print(out.shape)                            # (3, 4)  -- shape in == shape out

_, weights = scaled_dot_product_attention(X, X, X)
print(weights.sum(axis=-1))                 # [1. 1. 1.]  -- each row is a distribution

Follow-up questions

  • Derive why scores are divided by $\sqrt{d_k}$. For zero-mean unit-variance independent $q_i,k_i$, $\operatorname{Var}(q\cdot k)=d_k$; dividing by $\sqrt{d_k}$ normalizes the logit variance to one, preventing softmax saturation and vanishing gradients.
  • Compare MQA, GQA, and MLA. MQA shares one key/value head (smallest cache, some quality loss); GQA uses $g$ groups for a quality/throughput middle ground; MLA caches a low-rank latent of keys and values, shrinking the cache while keeping multi-head expressivity.
  • Why is decode bandwidth-bound but prefill compute-bound? Decode processes one token against cached state, moving many weights and KV bytes per few FLOPs (low arithmetic intensity); prefill runs dense matmuls over the whole prompt.
  • How does FlashAttention avoid $O(n^2)$ memory? It tiles Q, K, V and keeps a running online softmax over blocks, so the full score matrix is never stored; memory drops to $O(n)$ and slow memory traffic falls.
  • What do scaling laws prescribe? Loss falls as a power law in compute, data, and parameters; Chinchilla says scale tokens and parameters together (about 20 tokens per parameter) for a fixed budget.
  • When would you reach past dense attention? For very long context, windowed/sparse attention or a state-space model (Mamba); to grow capacity cheaply, mixture-of-experts routing.

References

  1. Vaswani et al., Attention Is All You Need (2017).
  2. Shazeer, Fast Transformer Decoding (Multi-Query Attention, 2019).
  3. Zhang & Sennrich, Root Mean Square Layer Normalization (RMSNorm, 2019).
  4. Kaplan et al., Scaling Laws for Neural Language Models (2020).
  5. Brown et al., Language Models are Few-Shot Learners (GPT-3, 2020).
  6. Shazeer, GLU Variants Improve Transformer (SwiGLU, 2020).
  7. Beltagy et al., Longformer (2020).
  8. Su et al., RoFormer: Rotary Position Embedding (2021).
  9. Press et al., Train Short, Test Long (ALiBi, 2021).
  10. Fedus et al., Switch Transformers (2021).
  11. Hoffmann et al., Training Compute-Optimal LLMs (Chinchilla, 2022).
  12. Dao et al., FlashAttention (2022).
  13. Touvron et al., LLaMA (2023).
  14. Ainslie et al., GQA: Grouped-Query Attention (2023).
  15. Dao, FlashAttention-2 (2023).
  16. Jiang et al., Mistral 7B (2023).
  17. Gu & Dao, Mamba: Selective State Spaces (2023).
  18. Jiang et al., Mixtral of Experts (2024).
  19. Shah et al., FlashAttention-3 (2024).
  20. DeepSeek-AI, DeepSeek-V3 Technical Report (Multi-head Latent Attention, 2024).