KL (Kullback-Leibler) divergence measures how much one probability distribution differs from a reference distribution. In LLM training it appears in three critical roles: (1) RLHF/GRPO KL penalty — keeps the fine-tuned policy from drifting too far from the pretrained model, guarding against reward hacking and catastrophic forgetting. (2) VAE regularisation — pulls the latent space toward a Gaussian prior. (3) Knowledge distillation — measures how far the student's distribution is from the teacher's. Understanding KL divergence is essential for anyone working on LLM alignment.
KL divergence — the mathematics
KL divergence from P to Q is the expected log-ratio under P. It is always non-negative (D_KL(P||Q) ≥ 0), equals zero only when P = Q exactly, and is not symmetric: D_KL(P||Q) ≠ D_KL(Q||P). Intuitively, it counts the 'extra bits' needed to encode samples from P using the optimal code for Q.
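In symbols, for discrete distributions over the same support:

```latex
D_{\mathrm{KL}}(P \,\|\, Q) \;=\; \sum_{x} P(x)\,\log\frac{P(x)}{Q(x)} \;=\; \mathbb{E}_{x \sim P}\big[\log P(x) - \log Q(x)\big]
```

The expectation is taken under P, which is exactly why the direction matters: terms where P(x) is large dominate the sum.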
KL divergence in RLHF, distillation, and VAE
import torch
import torch.nn.functional as F
import numpy as np
# ── 1. Basic KL Divergence computation ──
P = torch.tensor([0.3, 0.5, 0.2]) # Target distribution (e.g., teacher output)
Q = torch.tensor([0.35, 0.4, 0.25]) # Model distribution (e.g., student output)
# Manual computation
kl_manual = (P * (P.log() - Q.log())).sum()
print(f"KL(P||Q) manual: {kl_manual:.4f}")
# PyTorch (input must be log-probabilities for kl_div)
kl_torch = F.kl_div(Q.log(), P, reduction='sum')
print(f"KL(P||Q) torch: {kl_torch:.4f}")
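As a sanity check, the same value can be computed through `torch.distributions`, which has a registered closed-form KL for categorical distributions (a standalone sketch using the same P and Q as above):

```python
import torch
from torch.distributions import Categorical, kl_divergence

P = torch.tensor([0.3, 0.5, 0.2])
Q = torch.tensor([0.35, 0.4, 0.25])

# Manual expected log-ratio under P
kl_manual = (P * (P.log() - Q.log())).sum()

# Closed-form KL between two Categorical distributions — should agree
kl_dist = kl_divergence(Categorical(probs=P), Categorical(probs=Q))
print(float(kl_manual), float(kl_dist))
```

Both routes implement the same expectation under P, so they agree to floating-point precision.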
# ── 2. KL penalty in RLHF/GRPO (prevents policy collapse) ──
# RLHF objective: maximise reward while staying close to reference policy
# Total loss = -reward + β × KL(π_current || π_reference)
def rlhf_loss(reward, logprobs_current, logprobs_reference, beta=0.05):
    """
    reward: scalar reward from the reward model
    logprobs_current: log P(response | π_θ) [current fine-tuned policy]
    logprobs_reference: log P(response | π_ref) [frozen reference: pretrained model]
    beta: KL penalty weight (0.01-0.1 typical)
    """
    # Naive per-token estimate: KL(π_current || π_ref) ≈ mean(logP_current - logP_ref)
    kl_penalty = (logprobs_current - logprobs_reference).mean()
    loss = -reward + beta * kl_penalty
    return loss, kl_penalty.item()
# Example: reward model gives 2.5, but KL is very high (policy drifted)
reward = torch.tensor(2.5)
logprobs_cur = torch.tensor([-0.5, -1.2, -0.8, -0.3]) # Fine-tuned model log-probs
logprobs_ref = torch.tensor([-0.3, -0.5, -0.4, -0.2]) # Reference model log-probs (pretrained)
loss, kl = rlhf_loss(reward, logprobs_cur, logprobs_ref, beta=0.05)
print(f"RLHF loss: {loss:.3f} (reward={reward}, KL={kl:.3f})")
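The per-token difference used above is the naive ("k1") KL estimator: unbiased but noisy, and it can go negative even though true KL cannot. Many PPO/GRPO implementations instead use the low-variance "k3" estimator from Schulman's KL-approximation note. A sketch, reusing the example tensors above:

```python
import torch

logprobs_cur = torch.tensor([-0.5, -1.2, -0.8, -0.3])  # current policy log-probs
logprobs_ref = torch.tensor([-0.3, -0.5, -0.4, -0.2])  # reference policy log-probs

# k1 (naive): mean log-ratio — can be negative on a finite sample
k1 = (logprobs_cur - logprobs_ref).mean()

# k3: exp(log r) - log r - 1 with log r = log π_ref - log π_θ
# Each term is >= 0 (since e^x - x - 1 >= 0), so the estimate never goes negative
log_ratio = logprobs_ref - logprobs_cur
k3 = (log_ratio.exp() - log_ratio - 1).mean()
print(f"k1: {k1:.4f}  k3: {k3:.4f}")
```

Note that with these particular tensors k1 comes out negative while k3 stays non-negative, which is the practical reason the k3 form is preferred as a penalty term.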
# ── 3. KL in knowledge distillation ──
# Soft targets: temperature scaling softens the teacher's distribution
def distillation_loss(student_logits, teacher_logits, hard_labels, T=3.0, alpha=0.5):
    """
    T: temperature (higher = softer distributions)
    alpha: mix between soft (KL) and hard (CE) loss
    """
    # Soft targets from the teacher
    soft_teacher = F.softmax(teacher_logits / T, dim=-1)
    soft_student = F.log_softmax(student_logits / T, dim=-1)
    # Soft KL loss: match the teacher's distribution; T² keeps gradient scale comparable
    kl_soft = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T**2)
    # Hard cross-entropy loss: match the ground-truth labels
    ce_hard = F.cross_entropy(student_logits, hard_labels)
    return alpha * kl_soft + (1 - alpha) * ce_hard
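A quick smoke test of the same recipe, inlined so it runs standalone (shapes are made up: batch of 4, 10 classes):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
hard_labels = torch.randint(0, 10, (4,))
T, alpha = 3.0, 0.5

# Same computation as distillation_loss, step by step
soft_teacher = F.softmax(teacher_logits / T, dim=-1)
soft_student = F.log_softmax(student_logits / T, dim=-1)
kl_soft = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T ** 2)
ce_hard = F.cross_entropy(student_logits, hard_labels)
loss = alpha * kl_soft + (1 - alpha) * ce_hard
print(f"soft KL: {kl_soft:.4f}  hard CE: {ce_hard:.4f}  total: {loss:.4f}")
```

Note the argument order for `F.kl_div`: the first argument must be log-probabilities (the student), the second probabilities (the teacher target).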
# ── 4. KL in VAE (regularises latent space) ──
def vae_kl_loss(mu, logvar):
    """
    KL(N(μ,σ²) || N(0,1)) — keeps the latent distribution close to a standard Gaussian
    = -0.5 × Σ(1 + log(σ²) - μ² - σ²)
    """
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
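The closed form can be verified against a Monte Carlo estimate of the KL definition, E_{z~q}[log q(z) − log p(z)] (a standalone sketch with made-up μ and log σ²):

```python
import torch

torch.manual_seed(0)
mu = torch.tensor([0.5, -0.3])
logvar = torch.tensor([0.2, -0.1])

# Closed form: -0.5 * Σ(1 + logσ² - μ² - σ²)
kl_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Monte Carlo: sample z ~ q = N(μ, σ), average the log-ratio against p = N(0, 1)
q = torch.distributions.Normal(mu, (0.5 * logvar).exp())
p = torch.distributions.Normal(torch.zeros(2), torch.ones(2))
z = q.sample((100_000,))
kl_mc = (q.log_prob(z) - p.log_prob(z)).sum(-1).mean()
print(float(kl_closed), float(kl_mc))
```

With 100k samples the two estimates should agree to a couple of decimal places, confirming that the one-liner really is the KL to the standard Gaussian prior.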
# ── 5. Forward KL vs Reverse KL ──
# D_KL(P||Q): "forward" / inclusive — mass-covering, zero-avoiding (Q must cover all of P's support)
# D_KL(Q||P): "reverse" / exclusive — mode-seeking, zero-forcing (Q locks onto a dominant mode of P)
# RLHF penalises KL(π_current || π_ref) — the expectation is taken under the current policy
# Variational inference minimises reverse KL (ELBO maximisation)
print("KL is asymmetric:")
print(f" KL(P||Q) = {kl_manual:.4f}")
print(f" KL(Q||P) = {(Q * (Q.log() - P.log())).sum():.4f}")
Why the KL penalty prevents reward hacking
Without a KL penalty, RLHF training quickly degrades: the LLM learns to output repetitive phrases that happen to fool the reward model (e.g. 'The answer is: ' repeated 50 times can score highly because it superficially resembles the start of a thorough answer). The KL penalty punishes any output distribution that deviates sharply from the pretrained model's distribution, forcing the fine-tuned model to stay linguistically coherent while improving alignment.
| KL penalty β | Effect | Risk |
|---|---|---|
| β = 0 (no penalty) | Full optimisation of reward | Reward hacking, collapsed language quality |
| β = 0.01 (very light) | Strong reward optimisation with slight constraint | Still some reward hacking risk |
| β = 0.05-0.1 (typical) | Good balance of reward + language quality | Few — this is the standard RLHF setting |
| β = 0.5 (strong) | Minimal drift from reference | Reward signal too weak to change behaviour |
| β = ∞ | No change from reference (SFT only) | Defeats purpose of RLHF |
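In practice β is often adapted during training rather than fixed. A minimal sketch of a PPO-style adaptive KL controller — the target KL and the 1.5× multipliers here are illustrative choices, not canonical values:

```python
def adapt_beta(beta, observed_kl, target_kl=0.05):
    """Raise beta when the policy drifts too fast, relax it when training stalls."""
    if observed_kl > 2.0 * target_kl:      # drifting from reference: tighten the leash
        beta *= 1.5
    elif observed_kl < 0.5 * target_kl:    # barely moving: loosen it
        beta /= 1.5
    return beta

beta = 0.05
for kl in [0.02, 0.15, 0.20, 0.01]:        # made-up per-step KL measurements
    beta = adapt_beta(beta, kl)
print(f"final beta: {beta:.4f}")
```

The controller keeps the observed KL hovering near the target, which is more robust than committing to a single β for the whole run.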
Practice questions
- D_KL(P||Q) = 0.7 vs D_KL(Q||P) = 1.2. What does this asymmetry mean? (Answer: KL divergence is not symmetric — the direction matters. D_KL(P||Q) measures the extra bits needed to encode P using Q's code. Different directions penalise different types of mismatch. In RLHF, we use D_KL(π_current || π_ref) — penalises current policy deviating from reference.)
- In RLHF, why is β (the KL penalty weight) a crucial hyperparameter? (Answer: Too small β: reward hacking (model exploits reward model weaknesses). Too large β: model barely changes from pretrained policy (reward signal overwhelmed). Typical β = 0.05-0.1. Often adapted during training: start small, increase if reward hacking is detected, decrease if model is not improving.)
- In knowledge distillation with temperature T=5, how does the teacher distribution change? (Answer: Temperature T scales logits before softmax: softmax(logits/T). Higher T → softer distribution (probabilities more uniform, smaller gaps between classes). The student learns not just the most likely class but the relative similarity structure: "this looks a bit like class 3 even though it's class 1". This dark knowledge is why distillation outperforms training from hard labels alone.)
- Forward KL D_KL(P||Q) vs Reverse KL D_KL(Q||P) — which is mode-seeking and which is mode-covering? (Answer: Forward KL D_KL(P||Q) is mode-covering: P is true distribution, Q is learned. Minimising forward KL forces Q to cover all modes of P (zero-avoiding — Q must not be zero where P is nonzero). Reverse KL D_KL(Q||P) is mode-seeking: Q focuses on the dominant mode of P, ignoring smaller modes. RLHF uses forward KL (current || reference) to prevent the policy from collapsing to a narrow mode.)
- In a VAE, what does KL(q(z|x) || p(z)) penalise? (Answer: It penalises the encoder's posterior distribution q(z|x) from deviating from the prior p(z) = N(0,I). Without KL: each input maps to a distinct region of latent space — no overlap, no continuity, cannot sample from latent space for generation. With KL: all inputs map to distributions near N(0,I) — overlapping, smooth latent space that enables generation by sampling z ~ N(0,I).)
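The mode-covering vs mode-seeking behaviour in the answers above can be checked on a toy discrete example (the distributions are made up for illustration):

```python
import torch

def kl(p, q):
    """Discrete KL(p||q) = Σ p·(log p − log q)."""
    return (p * (p.log() - q.log())).sum()

P = torch.tensor([0.49, 0.02, 0.49])        # bimodal "true" distribution
Q_cover = torch.tensor([0.34, 0.32, 0.34])  # spreads over both modes
Q_seek = torch.tensor([0.90, 0.05, 0.05])   # commits to one mode

# Forward KL(P||Q) punishes Q for missing a mode of P → prefers the coverer
print(f"forward: cover={kl(P, Q_cover):.3f}  seek={kl(P, Q_seek):.3f}")
# Reverse KL(Q||P) punishes Q for putting mass where P has little → prefers the seeker
print(f"reverse: cover={kl(Q_cover, P):.3f}  seek={kl(Q_seek, P):.3f}")
```

Forward KL ranks the mode-covering approximation better, reverse KL ranks the mode-seeking one better — the asymmetry in one small computation.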
On LumiChats
Every time Claude is fine-tuned or aligned, KL divergence is used to prevent it from forgetting its pretrained language capabilities. The β-KL penalty in the training objective is why Claude retains coherent language generation after RLHF — without it, alignment training would destroy the LLM's core ability to communicate.