
KL Divergence in LLM Training — Preventing Policy Collapse

The mathematical guardrail that keeps fine-tuned LLMs from forgetting everything they know.


Definition

KL (Kullback-Leibler) divergence measures how much one probability distribution differs from a reference distribution. In LLM training it appears in three critical roles: (1) the RLHF/GRPO KL penalty, which keeps the fine-tuned policy from drifting too far from the pretrained model and thereby guards against reward hacking and catastrophic forgetting; (2) VAE regularisation, which pulls the latent space toward a Gaussian prior; (3) knowledge distillation, which measures how far the student's output distribution is from the teacher's. Understanding KL divergence is essential for anyone working on LLM alignment.

KL divergence — the mathematics

KL divergence from P to Q is the expected log-ratio under P: D_KL(P||Q) = Σ_x P(x) log(P(x)/Q(x)). It is always ≥ 0 (Gibbs' inequality), with equality exactly when P = Q. It is NOT symmetric: in general D_KL(P||Q) ≠ D_KL(Q||P). Interpretation: the expected number of extra bits (with log base 2) needed to encode samples from P using the optimal code for Q.

KL divergence in RLHF, distillation, and VAE

import torch
import torch.nn.functional as F
import numpy as np

# ── 1. Basic KL Divergence computation ──
P = torch.tensor([0.3, 0.5, 0.2])   # Target/true distribution (e.g., teacher output)
Q = torch.tensor([0.35, 0.4, 0.25]) # Approximating distribution (e.g., student output)

# Manual computation
kl_manual = (P * (P.log() - Q.log())).sum()
print(f"KL(P||Q) manual: {kl_manual:.4f}")

# PyTorch (input must be log-probabilities for kl_div)
kl_torch = F.kl_div(Q.log(), P, reduction='sum')
print(f"KL(P||Q) torch:  {kl_torch:.4f}")

# ── 2. KL penalty in RLHF/GRPO (prevents policy collapse) ──
# RLHF objective: maximise reward while staying close to reference policy
# Total loss = -reward + β × KL(π_current || π_reference)

def rlhf_loss(reward, logprobs_current, logprobs_reference, beta=0.05):
    """
    reward:           scalar reward from reward model
    logprobs_current: log P(response | π_θ)  [current fine-tuned policy]
    logprobs_reference: log P(response | π_ref) [frozen reference: pretrained model]
    beta:             KL penalty weight (0.01-0.1 typical)
    """
    # k1 estimate of KL(π_current || π_ref): mean of (logP_current - logP_ref) over tokens
    kl_penalty = (logprobs_current - logprobs_reference).mean()
    loss = -reward + beta * kl_penalty
    return loss, kl_penalty.item()

# Example: reward model gives 2.5 and the policy has drifted: the current model
# assigns higher probability to its own tokens than the reference does, so KL > 0
reward = torch.tensor(2.5)
logprobs_cur = torch.tensor([-0.3, -0.5, -0.4, -0.2])   # Fine-tuned model log-probs
logprobs_ref = torch.tensor([-0.5, -1.2, -0.8, -0.3])   # Frozen reference log-probs (pretrained)

loss, kl = rlhf_loss(reward, logprobs_cur, logprobs_ref, beta=0.05)
print(f"RLHF loss: {loss:.3f} (reward={reward}, KL={kl:.3f})")

# ── 3. KL in knowledge distillation ──
# Soft targets: temperature scaling softens the teacher's distribution
def distillation_loss(student_logits, teacher_logits, hard_labels, T=3.0, alpha=0.5):
    """
    T:     Temperature (higher = softer distributions)
    alpha: Mix of soft (KL) and hard (CE) loss
    """
    # Soft targets from teacher
    soft_teacher = F.softmax(teacher_logits / T, dim=-1)
    soft_student = F.log_softmax(student_logits / T, dim=-1)

    # KL divergence loss (soft): match teacher distribution shape
    kl_soft = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T**2)

    # Cross-entropy loss (hard): match ground truth labels
    ce_hard  = F.cross_entropy(student_logits, hard_labels)

    return alpha * kl_soft + (1 - alpha) * ce_hard

# ── 4. KL in VAE (regularises latent space) ──
def vae_kl_loss(mu, logvar):
    """
    KL(N(μ,σ²) || N(0,1)) — keeps latent distribution close to standard Gaussian
    = -0.5 × Σ(1 + log(σ²) - μ² - σ²)
    """
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# ── 5. Forward KL vs Reverse KL ──
# D_KL(P||Q): "forward" / inclusive: mass-covering, zero-avoiding
#   (Q must put mass wherever P does)
# D_KL(Q||P): "reverse" / exclusive: mode-seeking, zero-forcing
#   (Q concentrates on a dominant mode of P)
# The RLHF penalty is KL(π_current || π_ref), estimated under the current policy
# Variational inference minimises reverse KL, KL(q || p) (ELBO maximisation)
print("KL is asymmetric:")
print(f"  KL(P||Q) = {kl_manual:.4f}")
print(f"  KL(Q||P) = {(Q * (Q.log() - P.log())).sum():.4f}")
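As a quick sanity check on the closed form in section 4 (the helper is redefined here so the snippet runs standalone), the VAE KL term vanishes exactly when the posterior equals the prior and grows as μ drifts from 0 or σ from 1:

```python
import torch

def vae_kl_loss(mu, logvar):
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), matching the helper above
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# mu = 0, log(sigma^2) = 0  ->  posterior equals the prior, KL = 0
zero = vae_kl_loss(torch.zeros(4), torch.zeros(4))

# mu = 1 in every dimension, sigma = 1  ->  KL = 4 * 0.5 = 2
shifted = vae_kl_loss(torch.ones(4), torch.zeros(4))
```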

Why the KL penalty prevents reward hacking

Without a KL penalty, RLHF training quickly degrades: the LLM learns to emit degenerate outputs that happen to fool the reward model (e.g. 'The answer is:' repeated fifty times can score highly if the reward model mistakes it for the start of a long, thorough answer). The KL penalty punishes any output distribution that deviates significantly from the pretrained model's, forcing the fine-tuned model to stay linguistically coherent while the reward signal improves alignment.
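In PPO-style RLHF implementations the penalty is often applied per token, folded into the reward stream rather than added to the loss directly. A minimal sketch, assuming this per-token shaping scheme (the function name, shapes, and log-prob values are illustrative):

```python
import torch

def kl_shaped_rewards(rm_reward, logprobs_current, logprobs_reference, beta=0.05):
    """Fold the KL penalty into a per-token reward stream.
    Sketch only: the function name and shapes are illustrative."""
    kl_per_token = logprobs_current - logprobs_reference   # k1 KL estimate per token
    shaped = -beta * kl_per_token                          # KL penalty at every position
    shaped[-1] = shaped[-1] + rm_reward                    # RM score added at the final token
    return shaped

# Illustrative log-probs: the current policy is more confident than the reference
logprobs_cur = torch.tensor([-0.3, -0.5, -0.4, -0.2])
logprobs_ref = torch.tensor([-0.5, -1.2, -0.8, -0.3])
shaped = kl_shaped_rewards(2.5, logprobs_cur, logprobs_ref)
```

Every token where the policy has drifted from the reference receives a small negative reward, so the policy is discouraged from drifting anywhere in the sequence, not just at the end.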

| KL penalty β | Effect | Risk / Notes |
|---|---|---|
| β = 0 (no penalty) | Full optimisation of reward | Reward hacking, collapsed language quality |
| β = 0.01 (very light) | Strong reward optimisation with slight constraint | Still some reward-hacking risk |
| β = 0.05-0.1 (typical) | Good balance of reward and language quality | Standard RLHF setting |
| β = 0.5 (strong) | Minimal drift from reference | Reward signal too weak to change behaviour |
| β → ∞ | No change from reference (SFT only) | Defeats the purpose of RLHF |
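The trade-off in the table can be made concrete with a small β sweep over the same kind of token-level k1 KL estimate used in the code above (the log-prob values are illustrative):

```python
import torch

# Illustrative per-token log-probs: the current policy assigns higher
# probability to its own tokens than the frozen reference does
logprobs_cur = torch.tensor([-0.3, -0.5, -0.4, -0.2])
logprobs_ref = torch.tensor([-0.5, -1.2, -0.8, -0.3])
kl = (logprobs_cur - logprobs_ref).mean().item()   # k1 estimate, approx 0.35

reward = 2.5
for beta in [0.0, 0.01, 0.05, 0.5]:
    loss = -reward + beta * kl
    print(f"beta={beta:<4}  KL term={beta * kl:.4f}  loss={loss:+.4f}")
```

As β grows, the KL term claims a larger share of the objective, until at large β the reward signal can no longer move the policy.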

Practice questions

  1. D_KL(P||Q) = 0.7 vs D_KL(Q||P) = 1.2. What does this asymmetry mean? (Answer: KL divergence is not symmetric — the direction matters. D_KL(P||Q) measures the extra bits needed to encode P using Q's code. Different directions penalise different types of mismatch. In RLHF, we use D_KL(π_current || π_ref) — penalises current policy deviating from reference.)
  2. In RLHF, why is β (the KL penalty weight) a crucial hyperparameter? (Answer: Too small β: reward hacking (model exploits reward model weaknesses). Too large β: model barely changes from pretrained policy (reward signal overwhelmed). Typical β = 0.05-0.1. Often adapted during training: start small, increase if reward hacking is detected, decrease if model is not improving.)
  3. In knowledge distillation with temperature T=5, how does the teacher distribution change? (Answer: Temperature T scales logits before softmax: softmax(logits/T). Higher T → softer distribution (probabilities more uniform, smaller gaps between classes). The student learns not just the most likely class but the relative similarity structure: "this looks a bit like class 3 even though it's class 1". This dark knowledge is why distillation outperforms training from hard labels alone.)
  4. Forward KL D_KL(P||Q) vs Reverse KL D_KL(Q||P) — which is mode-seeking and which is mode-covering? (Answer: With P the true distribution and Q the learned one, forward KL D_KL(P||Q) is mode-covering and zero-avoiding: minimising it forces Q to put mass everywhere P does, so Q spreads over all of P's modes. Reverse KL D_KL(Q||P) is mode-seeking and zero-forcing: Q concentrates on a dominant mode of P and is punished for putting mass where P has almost none. Note that the RLHF penalty KL(π_current || π_ref) takes the expectation under the learned policy, so in this convention it is the reverse direction: the policy may sharpen onto high-reward behaviour, but is heavily penalised for generating text the reference model finds implausible.)
  5. In a VAE, what does KL(q(z|x) || p(z)) penalise? (Answer: It penalises the encoder's posterior distribution q(z|x) from deviating from the prior p(z) = N(0,I). Without KL: each input maps to a distinct region of latent space — no overlap, no continuity, cannot sample from latent space for generation. With KL: all inputs map to distributions near N(0,I) — overlapping, smooth latent space that enables generation by sampling z ~ N(0,I).)
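The adaptive-β schedule mentioned in question 2 can be sketched as a simple controller, similar in spirit to the adaptive KL coefficient used in PPO; the 2x / 0.5x thresholds and 1.5x scaling factors below are illustrative, not canonical:

```python
class AdaptiveKLController:
    """Keep the measured KL near a target by adjusting beta.
    Similar in spirit to PPO's adaptive KL coefficient; the thresholds
    and scaling factors here are illustrative."""
    def __init__(self, beta=0.05, target_kl=0.1):
        self.beta = beta
        self.target_kl = target_kl

    def update(self, measured_kl):
        if measured_kl > 2.0 * self.target_kl:      # drifting too fast: penalise harder
            self.beta *= 1.5
        elif measured_kl < 0.5 * self.target_kl:    # barely moving: loosen the constraint
            self.beta /= 1.5
        return self.beta

ctrl = AdaptiveKLController(beta=0.05, target_kl=0.1)
beta_up = ctrl.update(0.5)     # measured KL well above target, so beta rises
beta_down = ctrl.update(0.01)  # measured KL well below target, so beta falls back
```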

On LumiChats

When a model such as Claude is fine-tuned or aligned, a KL penalty against a frozen reference model is the standard mechanism for preventing it from forgetting its pretrained language capabilities. The β-weighted KL term in the training objective is what lets a model retain coherent language generation after RLHF; without it, alignment training could destroy the LLM's core ability to communicate.

