Mixed precision training uses lower-precision floating-point formats (FP16, BF16, FP8) for most computation while keeping FP32 for critical operations like gradient accumulation and weight updates. This can roughly halve memory usage and double throughput on modern GPUs (NVIDIA Tensor Cores are optimised for FP16/BF16/FP8). BF16 (Brain Float 16) is the default for modern LLM training: it has the same dynamic range as FP32 and is far less prone to overflow than FP16. FP8 (supported on NVIDIA H100 and used in the Llama GRPO notebook) halves BF16 memory again, enabling roughly 2× larger models or batch sizes.
Real-life analogy: The recipe shortcut
A recipe calls for precise gram-level measurements for safety-critical ingredients (salt, yeast) but accepts rough cup measures for bulk ingredients (flour, water). Mixed precision training does the same: keep exact FP32 precision for gradients and master weights (the critical ingredients), but use fast approximate FP16/BF16 for forward passes and activations (the bulk computation). You get 2× speed with negligible quality loss.
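The "critical ingredients" intuition can be made concrete with a tiny pure-Python sketch. It emulates FP16 by round-tripping values through `struct`'s half-precision `'e'` format (the names and numbers here are illustrative, not from any training framework): a small update that vanishes in FP16 survives in an FP32 master copy.

```python
import struct

def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE half precision."""
    return struct.unpack("e", struct.pack("e", x))[0]

weight, update = 1.0, 1e-4  # lr * grad, smaller than FP16's step size near 1.0

# Naive FP16 update: the step is below FP16 resolution (~9.8e-4 at 1.0),
# so the weight never moves and training stalls
fp16_weight = to_fp16(to_fp16(weight) + update)
print(fp16_weight)  # 1.0, the update was lost

# Mixed precision: keep an FP32 master weight and apply updates there
master = weight
for _ in range(10):
    master += update  # accumulates correctly in FP32
print(master)  # ~1.001, progress preserved
```

This is exactly why AMP keeps FP32 master weights: individual updates are often far smaller than the spacing between representable FP16 values near the weight's magnitude.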
Floating point formats compared
| Format | Bits | Range | Precision | Memory (1B params) | Use |
|---|---|---|---|---|---|
| FP32 | 32 | ±3.4×10³⁸ | High (23-bit mantissa) | 4 GB | Master weights, gradients |
| FP16 | 16 | ±65504 | Medium (10-bit mantissa) | 2 GB | Activations, compute (overflow risk!) |
| BF16 | 16 | ±3.4×10³⁸ | Lower (7-bit mantissa) | 2 GB | LLM training standard (same range as FP32) |
| FP8 (E4M3) | 8 | ±448 | Low (3-bit mantissa) | 1 GB | H100 forward pass (Llama GRPO notebook) |
| INT8 | 8 | −128..127 | Integer only | 1 GB | Inference quantization (not training) |
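The range column follows directly from each format's bit layout. A minimal pure-Python sketch (covering only normal numbers; `max_normal` is an illustrative helper, and FP8 E4M3 needs special handling because its "fn" variant reclaims the all-ones exponent code instead of reserving it for infinity):

```python
def max_normal(exp_bits: int, man_bits: int) -> float:
    """Largest finite value of an IEEE-style format (top exponent code reserved)."""
    bias = 2 ** (exp_bits - 1) - 1
    max_exp = (2 ** exp_bits - 2) - bias  # all-ones exponent encodes inf/NaN
    return (2 - 2 ** -man_bits) * 2.0 ** max_exp

print(f"FP32: {max_normal(8, 23):.4e}")  # 3.4028e+38
print(f"FP16: {max_normal(5, 10):.1f}")  # 65504.0
print(f"BF16: {max_normal(8, 7):.4e}")   # 3.3895e+38
# FP8 E4M3 (fn variant) keeps the top exponent code for values, so its max
# is (2 - 2**-2) * 2**8 = 448 rather than what the rule above would give
print((2 - 2 ** -2) * 2 ** 8)            # 448.0
```

Note how BF16 and FP32 share the same 8-bit exponent, hence the identical ±3.4×10³⁸ range; FP16 trades range (5-bit exponent) for an extra 3 bits of mantissa.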
Mixed precision training with PyTorch AMP and Unsloth
import torch
from torch.amp import autocast, GradScaler  # unified API (PyTorch >= 2.3); torch.cuda.amp is deprecated

# ── PyTorch Automatic Mixed Precision (AMP) ──
model = MyModel().cuda()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
scaler = GradScaler("cuda")  # Scales loss to prevent FP16 gradient underflow

for X, y in dataloader:
    X, y = X.cuda(), y.cuda()
    optimizer.zero_grad()
    with autocast("cuda", dtype=torch.float16):  # Forward pass in FP16
        output = model(X)
        loss = criterion(output, y)
    scaler.scale(loss).backward()   # Backward in FP16 (scaled)
    scaler.unscale_(optimizer)      # Unscale before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)          # Update FP32 master weights
    scaler.update()

# ── BF16 (recommended for A100/H100; no scaler needed) ──
with autocast("cuda", dtype=torch.bfloat16):
    output = model(X)
    loss = criterion(output, y)
loss.backward()   # BF16 gradients: same range as FP32, no overflow
optimizer.step()
# ── Unsloth: FP8 training for maximum efficiency (from the Llama GRPO notebook) ──
# Unsloth automatically uses FP8 on H100s for the fastest possible training
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-1B-Instruct-FP8-Block",  # FP8-quantised model
    max_seq_length = 8192,
    load_in_4bit = False,
    # Unsloth handles FP8 weight loading and BF16 compute automatically
)
# "FP8-Block" models store weights in FP8 (half the VRAM of BF16)
# but compute in BF16 precision: best of both worlds
# ── Hugging Face Transformers: torch_dtype ──
from transformers import AutoModelForCausalLM

model_bf16 = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    torch_dtype=torch.bfloat16,  # Load in BF16: ~2 GB instead of ~4 GB
    device_map="auto",
)
# ── Memory comparison for a 7B model ──
params = 7_000_000_000
print(f"FP32 (4 bytes): {params*4/1e9:.0f} GB")    # 28 GB
print(f"BF16 (2 bytes): {params*2/1e9:.0f} GB")    # 14 GB
print(f"FP8 (1 byte): {params*1/1e9:.0f} GB")      # 7 GB
print(f"INT4 (0.5 byte): {params*0.5/1e9:.1f} GB") # 3.5 GB (quantization)
Loss scaling and why BF16 is preferred for LLMs
FP16 overflow problem: FP16's maximum value is 65504. During training, gradients can easily exceed this, producing NaN. The fix is loss scaling: multiply the loss by a large factor (e.g. 1024) before the backward pass, then divide the gradients by the same factor before the optimizer step. BF16 fix: same 16-bit memory footprint as FP16 but the same exponent range as FP32 (±3.4×10³⁸), so typical gradient values never overflow and no loss scaling is needed. BF16 is the default in modern LLM training (Llama, Mistral, and most open models). On H100 GPUs, FP8 takes this further: hardware FP8 tensor ops deliver near-BF16 accuracy at half the memory.
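The underflow side of loss scaling can also be seen in pure Python, again emulating FP16 via `struct`'s half-precision `'e'` format (the 1024× factor here is illustrative; PyTorch's GradScaler starts with a much larger scale and adapts it during training):

```python
import struct

def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE half precision."""
    return struct.unpack("e", struct.pack("e", x))[0]

grad = 1e-8          # a tiny gradient, below FP16's smallest subnormal (~6e-8)
print(to_fp16(grad)) # 0.0, the gradient vanishes in FP16

scale = 1024.0
scaled = to_fp16(grad * scale)  # the scaled gradient IS representable in FP16
print(scaled / scale)           # ~1e-8, recovered after unscaling in FP32
```

This is why `scaler.scale(loss).backward()` comes before `scaler.unscale_(optimizer)` in the AMP loop above: gradients flow backward at the inflated magnitude, then get divided back down in FP32 where the tiny values survive.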
Practice questions
- A 7B LLM in FP32 requires 28GB VRAM. In BF16? (Answer: 14GB — BF16 uses 2 bytes per parameter vs 4 bytes for FP32. BF16 is the standard for LLM inference and fine-tuning on modern GPUs.)
- Why is BF16 preferred over FP16 for LLM training despite both being 16-bit? (Answer: BF16 has the same exponent range as FP32 (8-bit exponent) — no overflow on large gradient values. FP16 has only a 5-bit exponent, max value 65504, causing frequent overflows during training. BF16 requires no loss scaling; FP16 requires GradScaler.)
- What is the GradScaler in PyTorch AMP and why is it needed for FP16? (Answer: GradScaler multiplies the loss by a large scale factor before backward pass to prevent FP16 underflow (very small gradients rounding to zero). Before the optimizer step, it divides the scaled gradients back. If inf/NaN is detected, it reduces the scale factor. Not needed for BF16.)
- What does FP8-Block mean in the Unsloth model name "Llama-3.2-1B-Instruct-FP8-Block"? (Answer: Weights are stored in FP8 format with block-wise quantization (groups of weights share a scaling factor). This halves VRAM vs BF16. Computations still happen in BF16/FP16 precision — the FP8 storage is dequantized for compute. Enables loading 2× larger models on the same GPU.)
- When would you NOT use mixed precision training? (Answer: When training very small models where FP32 precision matters (numerical methods, scientific computing). Also for final evaluation of trained models where exact reproducibility is needed. And when debugging — FP32 gives cleaner error signals since no precision loss.)
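The block-wise idea behind "FP8-Block" (and most weight-quantization schemes) can be sketched in a few lines. Real FP8 dtypes are not available in plain Python, so this toy uses int8-style absmax quantization instead, but the structure is the same: each block of weights shares one scale factor, and values are dequantized back to full precision before compute. The function names and `block_size` are illustrative, not Unsloth's API.

```python
def quantize_blockwise(weights, block_size=4):
    """Absmax-quantize to int8 range in blocks; each block shares one FP32 scale."""
    q_blocks, scales = [], []
    for i in range(0, len(weights), block_size):
        block = weights[i:i + block_size]
        scale = max(abs(w) for w in block) / 127 or 1.0  # avoid div-by-zero on all-zero blocks
        q_blocks.append([round(w / scale) for w in block])
        scales.append(scale)
    return q_blocks, scales

def dequantize_blockwise(q_blocks, scales):
    """Restore approximate full-precision values before compute, as FP8-Block models do."""
    return [q * s for qs, s in zip(q_blocks, scales) for q in qs]

w = [0.01, -0.5, 2.0, 0.03, -0.002, 0.004, 0.001, -0.003]
q, s = quantize_blockwise(w)
restored = dequantize_blockwise(q, s)
# Per-block scales keep the tiny second block accurate;
# a single global scale would crush its values to zero
print(max(abs(a - b) for a, b in zip(w, restored)))
```

The design choice to quantize per block rather than per tensor is what limits the error: each block's worst-case error is half its own scale, so a block of tiny weights is not forced onto the coarse grid set by the largest weight in the whole tensor.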
On LumiChats
LumiChats runs on BF16-precision models: every inference uses BF16 compute for 2× memory efficiency vs FP32. The FP8 support mentioned in the Llama GRPO notebook represents the next generation of efficiency: models that fit in half the VRAM of BF16 while maintaining near-identical quality.
Try it free