
Attention Mechanism

How AI decides what to focus on.


Definition

The attention mechanism is a technique that allows neural networks to dynamically focus on the most relevant parts of an input when producing an output. In Transformers, self-attention lets every token attend to every other token in parallel (or, in decoder models, to every earlier token), building context-aware representations from Query, Key, and Value vectors produced by learned projection matrices.

Intuition: what attention does

Imagine reading: "The trophy didn't fit in the suitcase because it was too big." To understand what 'it' refers to, your brain attends to 'trophy', not 'suitcase'. The attention mechanism captures this. For every token in a sequence, attention computes a weighted sum over the representations of all tokens in the sequence (including the token itself), where each weight expresses how strongly that token should influence the current token's new representation.

Key Intuition

High attention weight = strong influence. Near-zero weight = ignored. The model learns which tokens to attend to — not from hand-coded rules, but purely from training data.
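A toy illustration of that weighted sum, assuming made-up one-hot value vectors and hand-picked relevance scores for the token 'it' (a trained model would learn these; the numbers here are purely illustrative):

```python
import numpy as np

# Hypothetical toy values: one 3-d value vector per token (not from a real model).
tokens = ["trophy", "suitcase", "it"]
values = np.array([
    [1.0, 0.0, 0.0],   # "trophy"
    [0.0, 1.0, 0.0],   # "suitcase"
    [0.0, 0.0, 1.0],   # "it"
])

# Hand-picked relevance scores of each token for the query "it".
scores = np.array([4.0, 1.0, 2.0])

# Softmax turns scores into non-negative weights that sum to 1.
weights = np.exp(scores - scores.max())
weights /= weights.sum()

# The new representation of "it" is the weighted sum of all value vectors.
it_repr = weights @ values

for tok, w in zip(tokens, weights):
    print(f"{tok:>9}: {w:.3f}")
print("new 'it' vector:", it_repr.round(3))
```

Because the value vectors are one-hot, the output vector simply equals the weights, which makes the mixing easy to see: most of the new 'it' comes from 'trophy'.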

In the Transformer, this mechanism replaced recurrence entirely. Instead of passing information step by step along a chain, as a recurrent network does, attention connects every pair of positions directly, giving every token a view of the entire sequence within a single layer.

Query, Key, and Value: the math

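Each input token embedding x is linearly projected into three vectors using learned weight matrices W_Q, W_K, W_V. Stacking the token embeddings as the rows of a matrix X gives:

$$Q = XW_Q,\qquad K = XW_K,\qquad V = XW_V$$

Q (query) = what am I looking for? K (key) = what do I contain? V (value) = what information do I provide?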
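A minimal sketch of the three projections, assuming random matrices in place of learned ones and illustrative sizes (4 tokens, d_model = 8, d_k = d_v = 4); the names `W_Q`, `W_K`, `W_V` mirror the text:

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_model, d_k, d_v = 4, 8, 4, 4

# Token embeddings stacked as rows: X has shape (n_tokens, d_model).
X = rng.standard_normal((n_tokens, d_model))

# Random stand-ins for the learned projection matrices.
W_Q = rng.standard_normal((d_model, d_k))
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_v))

# Every token is projected with the same three matrices.
Q = X @ W_Q   # queries: (n_tokens, d_k)
K = X @ W_K   # keys:    (n_tokens, d_k)
V = X @ W_V   # values:  (n_tokens, d_v)

# Row i of Q dotted with row j of K gives the raw score of token i for token j.
raw_scores = Q @ K.T  # (n_tokens, n_tokens)
print(Q.shape, K.shape, V.shape, raw_scores.shape)
```

Since the same matrices apply to every row, projecting the whole sequence is one matrix multiply per projection rather than a per-token loop.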

Attention scores are computed as the dot product of each token's Query with every token's Key (its own included). We scale by √d_k to prevent the dot products from growing so large that the softmax saturates (producing near-zero gradients):

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

The scaled dot-product attention formula from "Attention Is All You Need" (Vaswani et al., 2017)

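Where $Q \in \mathbb{R}^{n \times d_k}$, $K \in \mathbb{R}^{m \times d_k}$, $V \in \mathbb{R}^{m \times d_v}$, n is the number of queries, m the number of keys (n = m in self-attention), and d_k is the key dimension. The softmax is applied to each row of $QK^\top/\sqrt{d_k}$, turning each query's scores into weights that sum to 1. The output for each token is the corresponding weighted sum of all Value vectors, giving an n × d_v result.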

Why √d_k?

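If the components of a query q and a key k are independent with mean 0 and variance 1, their dot product q · k has variance d_k. So as d_k grows, the scores grow in magnitude and push the softmax into regions where gradients are tiny. Dividing by √d_k brings the variance back to about 1 regardless of dimension size.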
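A quick empirical check of the scaling argument, assuming query and key entries drawn i.i.d. from a standard normal distribution (the assumption under which the variance of q · k equals d_k):

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 5_000
results = {}

for d_k in (16, 64, 256, 1024):
    # Assumed: i.i.d. standard-normal query and key components.
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = (q * k).sum(axis=1)      # one raw dot product per sample
    scaled = dots / np.sqrt(d_k)    # the scaled score used in attention
    results[d_k] = (dots.var(), scaled.var())
    print(f"d_k={d_k:5d}  var(raw)={dots.var():8.1f}  var(scaled)={scaled.var():.3f}")
```

The raw variance grows roughly in proportion to d_k while the scaled variance stays near 1. Learned Q and K are not i.i.d. normal, so this illustrates the motivation rather than giving an exact guarantee.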

Implementing attention in Python

Here's a clean NumPy implementation of scaled dot-product attention — the core operation inside every Transformer:

Scaled dot-product attention — the core of every Transformer block

import numpy as np

def softmax(x, axis=-1):
    x_max = np.max(x, axis=axis, keepdims=True)
    e_x = np.exp(x - x_max)          # numerically stable
    return e_x / e_x.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Args:
        Q: queries  shape (seq_len, d_k)
        K: keys     shape (seq_len, d_k)
        V: values   shape (seq_len, d_v)
        mask: optional (seq_len, seq_len) mask, 0/False = blocked (e.g. causal)

    Returns:
        output: shape (seq_len, d_v)
        weights: attention weights (seq_len, seq_len)
    """
    d_k = Q.shape[-1]

    # Compute raw attention scores: (seq_len, seq_len)
    scores = Q @ K.T / np.sqrt(d_k)

    # Apply causal mask (for decoder / autoregressive models)
    if mask is not None:
        scores = np.where(mask == 0, -1e9, scores)

    # Softmax to get attention weights
    weights = softmax(scores, axis=-1)

    # Weighted sum of values
    output = weights @ V
    return output, weights


# --- Example usage ---
np.random.seed(42)
seq_len, d_k, d_v = 5, 64, 64

Q = np.random.randn(seq_len, d_k)
K = np.random.randn(seq_len, d_k)
V = np.random.randn(seq_len, d_v)

output, weights = scaled_dot_product_attention(Q, K, V)

print("Output shape:          ", output.shape)   # (5, 64)
print("Attention weights shape:", weights.shape)  # (5, 5)
print("Weights row 0 (sum=1): ", weights[0].sum().round(4))  # 1.0

PyTorch version, batched over (batch, heads) dimensions

import torch
import torch.nn.functional as F

def attention(Q, K, V, mask=None):
    """
    Q, K: (batch, heads, seq_len, d_k)
    V:    (batch, heads, seq_len, d_v)
    """
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / d_k ** 0.5

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights


# PyTorch 2.0+ also ships a fused built-in, which can dispatch to FlashAttention
# or memory-efficient kernels when the hardware, dtype and inputs allow:
# F.scaled_dot_product_attention(Q, K, V, is_causal=True)

Multi-head attention

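A single attention operation produces one set of weights per token, so it has to blend every kind of relationship into one weighted average. Multi-head attention runs h attention heads in parallel, each with its own Q/K/V projections, then concatenates their outputs and projects the result with an output matrix W^O:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^O$$

where $\mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)$.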

Different heads can learn to capture different relationship types: one head might track syntactic dependencies, another semantic similarity, another positional patterns. GPT-3 has 96 heads in each of its 96 layers. Because each head forms its own attention pattern, a single layer can represent several kinds of linguistic and conceptual structure at once.
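A minimal NumPy sketch of multi-head self-attention, assuming random weights in place of learned ones, no masking and no biases; the helper names `split_heads` and `multi_head_attention` are illustrative. Each head works on d_model/h dimensions, and the concatenated heads are projected back to d_model by W_O:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def split_heads(x, h):
    """(seq_len, d_model) -> (h, seq_len, d_model // h)"""
    seq_len, d_model = x.shape
    return x.reshape(seq_len, h, d_model // h).transpose(1, 0, 2)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, h):
    seq_len, d_model = X.shape
    assert d_model % h == 0, "d_model must be divisible by the number of heads"
    # Project once, then split the last dimension into h heads.
    Q = split_heads(X @ W_Q, h)
    K = split_heads(X @ W_K, h)
    V = split_heads(X @ W_V, h)
    d_head = d_model // h
    # Each head computes its own (seq_len, seq_len) weight matrix.
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)
    weights = softmax(scores, axis=-1)          # (h, seq_len, seq_len)
    heads = weights @ V                          # (h, seq_len, d_head)
    # Concatenate heads back to (seq_len, d_model) and mix them with W_O.
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ W_O, weights

rng = np.random.default_rng(0)
seq_len, d_model, h = 6, 32, 4
X = rng.standard_normal((seq_len, d_model))
W_Q, W_K, W_V, W_O = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
                      for _ in range(4))

out, weights = multi_head_attention(X, W_Q, W_K, W_V, W_O, h)
print(out.shape, weights.shape)
```

Projecting with one d_model × d_model matrix and reshaping is equivalent to using h separate d_model × (d_model/h) matrices W_i^Q (one column block per head); it is just a more efficient way to compute them.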

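| Model | Layers | Heads | d_model | Parameters |
|---|---|---|---|---|
| GPT-2 Small | 12 | 12 | 768 | 117M |
| GPT-3 | 96 | 96 | 12,288 | 175B |
| LLaMA 3 8B | 32 | 32 | 4,096 | 8B |
| LLaMA 3 70B | 80 | 64 | 8,192 | 70B |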
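The table also implies each model's per-head dimension, d_model / heads. A small arithmetic check using only the table's figures:

```python
# (layers, heads, d_model) copied from the table above.
models = {
    "GPT-2 Small": (12, 12, 768),
    "GPT-3": (96, 96, 12_288),
    "LLaMA 3 8B": (32, 32, 4_096),
    "LLaMA 3 70B": (80, 64, 8_192),
}

for name, (layers, heads, d_model) in models.items():
    d_head = d_model // heads   # dimensions each head works in
    print(f"{name:12s} d_head = {d_model} / {heads} = {d_head}")
```

In these four models the per-head size stays at 64 or 128 while the number of heads grows with d_model.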

Causal masking in decoder-only models

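In autoregressive models (GPT, LLaMA, Claude), the model generates tokens left to right. During training all positions are predicted in parallel, so without a mask a token could simply look at the future tokens it is supposed to predict. A causal mask prevents this 'cheating': the scores for future positions are set to −∞ before the softmax, which makes their weights zero: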

Causal mask for autoregressive training

import torch

def make_causal_mask(seq_len):
    """
    Creates a lower-triangular boolean mask.
    True = allowed to attend. False = masked (future token).

    For seq_len=4:
    [[T, F, F, F],
     [T, T, F, F],
     [T, T, T, F],
     [T, T, T, T]]
    """
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    return mask

mask = make_causal_mask(6)
print(mask.int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 1]])
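Putting the mask to work: a NumPy sketch (random inputs, a single head, and the same boolean convention as `make_causal_mask`, True = allowed) showing that masked positions end up with exactly zero weight and that the first token can only attend to itself. The helper name `causal_attention_weights` is illustrative:

```python
import numpy as np

def causal_attention_weights(Q, K):
    n, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    allowed = np.tril(np.ones((n, n), dtype=bool))   # True = may attend
    scores = np.where(allowed, scores, -np.inf)      # block future positions
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)                         # exp(-inf) = 0 exactly
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(1)
Q = rng.standard_normal((6, 16))
K = rng.standard_normal((6, 16))
W = causal_attention_weights(Q, K)
print(W.round(2))
```

Using −∞ rather than a large negative constant such as −1e9 gives exact zeros, but a row with every position masked would then produce NaN. A causal mask never does this, because the diagonal is always allowed.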

Encoder vs Decoder

Encoder-only models like BERT use bidirectional (full) attention — every token can see every other. Decoder-only models like GPT use causal (masked) attention — each token can only see itself and prior tokens. This is what makes decoders suitable for text generation.
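One way to see the difference: under a causal mask, changing a later token cannot change an earlier token's output, while under full (bidirectional) attention it can. A small sketch, assuming random inputs, Q = K = V = X for brevity (no learned projections), and a hypothetical helper `attend`:

```python
import numpy as np

def attend(X, causal):
    n, d = X.shape
    scores = X @ X.T / np.sqrt(d)
    if causal:
        # Block every position j > i (the future).
        future = np.triu(np.ones((n, n), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ X

rng = np.random.default_rng(2)
X = rng.standard_normal((5, 8))
X_edited = X.copy()
X_edited[-1] += 1.0                      # change only the last token

for causal in (False, True):
    delta = np.abs(attend(X, causal)[0] - attend(X_edited, causal)[0]).max()
    print(f"causal={causal}: change in token 0 output = {delta:.4f}")
```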

Efficient attention variants

Vanilla attention is O(n²) in memory and compute: a 128K-token context (131,072 tokens) produces 131,072² ≈ 17 billion attention weights per head, per layer. Several techniques address this:
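The arithmetic behind the quadratic cost, as a short sketch. It counts only one head's n × n weight matrix in one layer, and assumes 2 bytes per entry for fp16/bf16 and 4 bytes for fp32:

```python
def attention_matrix_bytes(n_tokens, bytes_per_entry):
    """Memory for a single head's full n x n attention matrix (one layer)."""
    return n_tokens * n_tokens * bytes_per_entry

GiB = 2 ** 30
for n in (4_096, 32_768, 131_072):          # 4K, 32K and 128K tokens
    entries = n * n
    fp16 = attention_matrix_bytes(n, 2) / GiB
    fp32 = attention_matrix_bytes(n, 4) / GiB
    print(f"n={n:>7,}: {entries:>14,} weights, {fp16:6.2f} GiB fp16, {fp32:6.2f} GiB fp32")
```

Doubling n quadruples every figure; at 128K tokens the single matrix is 32 GiB in fp16, which is why efficient kernels avoid materializing it.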

  • FlashAttention (2022) and FlashAttention-2 (2023): Restructure the computation into tiles to minimize reads and writes to GPU memory. Same mathematical result as standard attention (up to floating-point rounding), with reported 2–4× speedups and O(n) instead of O(n²) extra memory. Widely used in modern LLM training and inference.
  • Grouped Query Attention (GQA): Shares each K/V head across a group of Q heads, reducing KV-cache memory during inference. Used in LLaMA 3, Mistral, and many other recent models.
  • Sliding Window Attention (Mistral): Each token attends only to a local window of W tokens, reducing compute from O(n²) to O(n·W). Stacking layers lets information propagate further than W tokens, and some variants (e.g. Longformer) add a few globally attending tokens for very long contexts.
  • Multi-Query Attention (MQA): All Q heads share a single K/V head. This is the extreme case of GQA, giving the largest KV-cache savings at the cost of some quality.
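A minimal NumPy sketch of grouped-query attention with an optional sliding-window causal mask, under illustrative assumptions: random inputs, no learned projections, and hypothetical helper names. Setting the number of K/V heads to 1 gives multi-query attention, and setting it equal to the number of query heads gives standard multi-head attention:

```python
import numpy as np

def sliding_window_causal_mask(n, window):
    """True where query i may attend to key j: j <= i and i - j < window."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return (j <= i) & (i - j < window)

def grouped_query_attention(Q, K, V, mask=None):
    """
    Q:    (n_q_heads, n, d)
    K, V: (n_kv_heads, n, d), with n_q_heads divisible by n_kv_heads
    """
    n_q_heads, n_kv_heads = Q.shape[0], K.shape[0]
    group = n_q_heads // n_kv_heads
    # Each K/V head is shared by `group` consecutive query heads.
    K = np.repeat(K, group, axis=0)
    V = np.repeat(V, group, axis=0)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
n, d = 8, 16
Q = rng.standard_normal((8, n, d))      # 8 query heads
K = rng.standard_normal((2, n, d))      # 2 shared K/V heads (what gets cached)
V = rng.standard_normal((2, n, d))
mask = sliding_window_causal_mask(n, window=3)

out, weights = grouped_query_attention(Q, K, V, mask)
print(out.shape)          # (8, 8, 16): one output per query head
print(mask.astype(int))   # each row allows at most 3 positions
```

Only K and V need caching during generation, so 2 K/V heads instead of 8 gives a 4× smaller KV cache in this configuration. The window limits each query to its last 3 positions, so per-token work depends on W rather than on the full sequence length.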

Practice questions

  1. What is multi-head attention and why is it better than single-head attention? (Answer: Multi-head attention runs h attention functions in parallel on d_model/h-dimensional projections of Q, K, V. Each head can specialise in different relationship types: head 1 may capture syntactic relationships, head 2 semantic similarity, head 3 coreference. Single-head attention is forced to average all relationship types into one representation. Multi-head concatenates all h outputs (each d_v dimensional) and projects back to d_model — capturing diverse interaction patterns simultaneously. BERT-base: 12 heads, GPT-3: 96 heads.)
  2. What is cross-attention (used in encoder-decoder transformers) and how does it differ from self-attention? (Answer: Self-attention: Q, K, V all come from the same sequence. Each token attends to all tokens in the same sequence. Cross-attention: Q comes from the decoder sequence; K and V come from the encoder's output. The decoder queries the encoder's representations. This allows the decoder to 'look at' the full input context while generating each output token. Used in T5, BART, and the original Transformer for machine translation — each output word attends to relevant input words.)
  3. What is the computational complexity of self-attention and why does it limit context length? (Answer: Self-attention costs O(n² · d), where n = sequence length and d = model dimension. For n = 128K and d = 4,096: 128,000² × 4,096 ≈ 67 trillion multiply-adds per layer for QKᵀ alone, with a similar amount for multiplying the weights by V. The quadratic term in n is the bottleneck: doubling context length quadruples attention computation. Memory: storing the n×n attention matrix requires O(n²) memory, about 32 GiB in fp16 for a single head's weights at 128K context. FlashAttention (Dao et al.) reduces memory to O(n) using tiled computation but cannot reduce the O(n²) compute.)
  4. What does attention weight visualisation reveal about what transformers 'pay attention to'? (Answer: Attention weights visualise which tokens the model focuses on when computing each token's representation. Research findings: (1) Some heads track syntactic structure (subject-verb agreement). (2) Others track coreference (pronoun → antecedent). (3) Early layers tend to show local syntactic patterns, later layers more semantic relationships. Caveat: attention ≠ explanation; a high attention weight does not necessarily mean causal importance. Attribution methods such as Integrated Gradients (gradient-based) or SHAP (based on Shapley values) are generally considered more reliable for explanation.)
  5. What is FlashAttention and why is it critical for production LLM serving? (Answer: FlashAttention (Dao et al. 2022) is an IO-aware attention algorithm that tiles the attention computation into blocks held in on-chip GPU SRAM, so the full n×n attention matrix is never written to HBM (GPU main memory). Standard attention writes the n×n score and weight matrices to HBM and reads them back; FlashAttention avoids this round trip. Reported speedups are roughly 2–4×, and extra memory drops from O(n²) to O(n). This matters in production because long-context serving (128K) is practical only with kernels that avoid materializing the full matrix: a single head's 128K × 128K attention matrix would take 64 GiB in fp32 (32 GiB in fp16). FlashAttention-2 improves throughput through better parallelism and work partitioning across GPU threads, and FlashAttention-3 further exploits features of Hopper GPUs.)
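To make the tiling idea in question 5 concrete, here is a minimal NumPy sketch of the online-softmax trick that FlashAttention builds on. Keys and values are processed in blocks, and only a running max, running denominator and running output are kept per query, so the full n × n matrix is never stored. It runs on the CPU and models nothing about GPU SRAM/HBM traffic; the block size and function names are illustrative assumptions:

```python
import numpy as np

def standard_attention(Q, K, V):
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    return (w / w.sum(axis=-1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block_size=16):
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full(n, -np.inf)              # running max of scores per query
    l = np.zeros(n)                      # running softmax denominator
    acc = np.zeros((n, V.shape[-1]))     # running unnormalized output
    for start in range(0, K.shape[0], block_size):
        Kb = K[start:start + block_size]
        Vb = V[start:start + block_size]
        s = Q @ Kb.T * scale             # scores for this block only: (n, block)
        m_new = np.maximum(m, s.max(axis=-1))
        correction = np.exp(m - m_new)   # rescale earlier partial sums
        p = np.exp(s - m_new[:, None])
        l = l * correction + p.sum(axis=-1)
        acc = acc * correction[:, None] + p @ Vb
        m = m_new
    return acc / l[:, None]

rng = np.random.default_rng(0)
Q = rng.standard_normal((100, 32))
K = rng.standard_normal((100, 32))
V = rng.standard_normal((100, 32))
print(np.abs(tiled_attention(Q, K, V) - standard_attention(Q, K, V)).max())
```

The result matches standard attention up to floating-point rounding, which is the sense in which FlashAttention is exact. The total work is still proportional to n²; only the memory needed for scores shrinks to one block at a time.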
