The Transformer (Vaswani et al., 2017) is the dominant deep learning architecture for sequence processing. It replaces recurrence with self-attention: every token in the sequence directly attends to every other token simultaneously, enabling full parallelisation and direct capture of long-range dependencies. The architecture has three variants: encoder-only (BERT) for understanding tasks, decoder-only (GPT, Claude) for generation, and encoder-decoder (T5, translation models). Every major AI system — GPT-4, Claude, Gemini, Llama, Stable Diffusion — is based on transformers.
Self-attention — the core operation
Scaled dot-product attention. Q (queries), K (keys), V (values) are linear projections of the input. QKᵀ/√dₖ = attention scores (how much each token attends to each other). Softmax normalises to weights summing to 1. Multiply by V = weighted sum of values. √dₖ prevents dot products from growing too large (which would saturate softmax).
Self-attention and multi-head attention from scratch
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input into queries, keys and values, scores every token
    pair, softmax-normalises the scores, and returns the weighted sum of
    values (after an output projection) together with the attention weights.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_k = d_model
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        """x: (batch, seq_len, d_model); mask: positions equal to 0 are blocked."""
        # All three projections come from the same x — that is what makes
        # this SELF-attention.
        q, k, v = self.W_Q(x), self.W_K(x), self.W_V(x)
        # (batch, seq, seq) pairwise scores, scaled by √d_k so that large
        # dot products do not saturate the softmax.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # e.g. a causal (lower-triangular) mask for decoder-style models
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)   # each row sums to 1
        context = torch.matmul(attn, v)    # weighted sum of values
        return self.W_O(context), attn
class MultiHeadAttention(nn.Module):
    """Multi-head attention: n_heads scaled dot-product heads in parallel.

    Each head works on a d_model / n_heads slice of the projected input;
    the per-head outputs are concatenated back together and mixed by W_O.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # per-head width
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        """Reshape (batch, seq, d_model) → (batch, n_heads, seq, d_k)."""
        batch, seq, _ = x.shape
        x = x.view(batch, seq, self.n_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(self, x, mask=None):
        """x: (batch, seq, d_model); mask: positions equal to 0 are blocked."""
        batch, seq, dim = x.shape
        q = self.split_heads(self.W_Q(x))
        k = self.split_heads(self.W_K(x))
        v = self.split_heads(self.W_V(x))
        # Per-head scores, scaled to keep the softmax well-conditioned.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, v)  # (batch, heads, seq, d_k)
        # Re-merge the heads back into a single d_model-wide representation.
        context = context.transpose(1, 2).reshape(batch, seq, dim)
        return self.W_O(context)
# Test multi-head attention on a dummy batch.
mha = MultiHeadAttention(d_model=512, n_heads=8) # 8 heads, each 512/8 = 64-dim
x = torch.randn(2, 20, 512) # Batch 2, seq 20, d_model 512
out = mha(x) # no mask: full bidirectional attention
print(f"MHA input: {x.shape}") # [2, 20, 512]
print(f"MHA output: {out.shape}") # [2, 20, 512]
Positional encoding and the full transformer block
Positional encoding and transformer block
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding — adds position information to embeddings.

    Precomputes a (1, max_seq_len, d_model) table of sine/cosine waves at
    geometrically spaced frequencies; forward() adds the first seq_len rows
    to the token embeddings, since self-attention alone is order-blind.
    """

    def __init__(self, d_model: int, max_seq_len: int = 5000):
        super().__init__()
        positions = torch.arange(max_seq_len, dtype=torch.float).unsqueeze(1)
        # 10000^(-2i/d_model) for each even dimension index 2i.
        freqs = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_seq_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)  # even dimensions
        table[:, 1::2] = torch.cos(positions * freqs)  # odd dimensions
        # Buffer, not a parameter: saved with the model but never trained.
        self.register_buffer('pe', table.unsqueeze(0))  # (1, max_seq, d_model)

    def forward(self, x):
        """Add positional encodings to x (batch, seq, d_model), broadcast over batch."""
        return x + self.pe[:, :x.size(1)]
class TransformerBlock(nn.Module):
    """One Transformer encoder block.

    Two sub-layers — multi-head attention and a position-wise feed-forward
    network — each wrapped in dropout, a residual connection, and a
    post-layer-norm ("Add & Norm").
    """

    def __init__(self, d_model: int, n_heads: int, ff_dim: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.GELU(),  # GELU — modern default activation
            nn.Dropout(dropout),
            nn.Linear(ff_dim, d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """x: (batch, seq, d_model) → same shape; mask is forwarded to attention."""
        # Sub-layer 1: multi-head attention, then Add & Norm.
        x = self.norm1(x + self.dropout(self.attention(x, mask)))
        # Sub-layer 2: feed-forward network, then Add & Norm.
        return self.norm2(x + self.dropout(self.ffn(x)))
# Stack transformer blocks into an encoder
class TransformerEncoder(nn.Module):
    """Stack of TransformerBlocks over token embeddings plus sinusoidal
    positional encodings, finished with a layer norm."""

    def __init__(self, vocab_size, d_model=512, n_heads=8,
                 ff_dim=2048, n_layers=6, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        blocks = [
            TransformerBlock(d_model, n_heads, ff_dim, dropout)
            for _ in range(n_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """x: (batch, seq) token ids → (batch, seq, d_model) representations."""
        h = self.pos_enc(self.embedding(x))  # token + position
        for block in self.layers:
            h = block(h, mask)
        return self.norm(h)
# BERT-base equivalent: 12 layers, 12 heads, hidden 768, FFN 3072,
# WordPiece vocabulary of 30522 tokens.
bert_encoder = TransformerEncoder(
    vocab_size=30522, d_model=768, n_heads=12, ff_dim=3072, n_layers=12
)
tokens = torch.randint(0, 30522, (2, 128)) # Batch 2, seq 128 of random token ids
output = bert_encoder(tokens)
print(f"Encoder output: {output.shape}") # [2, 128, 768]
# Total parameter count
params = sum(p.numel() for p in bert_encoder.parameters())
print(f"Parameters: {params:,}") # ~110M — close to actual BERT-base (110M)
| Architecture | Layers used | Key capability | Examples |
|---|---|---|---|
| Encoder-only | Encoder blocks only | Bidirectional context — understands text | BERT, RoBERTa, DistilBERT |
| Decoder-only | Decoder blocks (causal mask) | Auto-regressive text generation | GPT-4, Claude, Llama, Gemini |
| Encoder-Decoder | Both encoder + decoder | Sequence-to-sequence tasks | T5, BART, MarianMT, Whisper |
Practice questions
- Why is √d_k used to scale the dot product in attention? (Answer: For large d_k (e.g., 64), the dot products QKᵀ can become very large in magnitude. Feeding large values to softmax creates near-one-hot distributions with tiny gradients (saturation). Dividing by √d_k keeps values in a reasonable range before softmax, maintaining useful gradient signal.)
- Multi-head attention uses 8 heads of 64-dim each. Why multiple heads instead of one 512-dim head? (Answer: Different heads learn to attend to different types of relationships simultaneously: one head may focus on syntactic dependencies, another on coreference, another on topic. Single large attention computes one combined representation — multiple smaller heads capture diverse relationship patterns in parallel.)
- Encoder-only vs Decoder-only transformers — what structural difference enables generation? (Answer: Decoder uses a causal (triangular) mask: token i can only attend to tokens 0...i (no future tokens). This enables auto-regressive generation — predict next token, append, predict again. Encoder has no mask — each token attends to all others in both directions for full context understanding.)
- Why does the Transformer need positional encoding? (Answer: Self-attention is permutation-invariant — "the cat sat on the mat" and "the mat sat on the cat" produce the same attention scores. Positional encoding injects position information into token embeddings so the model knows WHERE each token is, enabling it to understand word order and sentence structure.)
- What are residual connections (Add & Norm) in transformers and why are they critical? (Answer: Residual connections: output = LayerNorm(x + sublayer(x)). They allow gradients to flow directly from output to early layers without going through the full network — solving the vanishing gradient problem for very deep transformers (100+ layers in GPT-4). Without residuals, 12+ layer transformers would not train effectively.)
On LumiChats
Claude is a decoder-only transformer — this architecture is the foundation of every response you receive. The self-attention mechanism described here is exactly what processes your message: each word attends to every other word, computing weighted representations before generating the response token by token.
Try it free