This notebook explores attention mechanisms in detail.
The Problem: How can tokens learn relationships with other tokens?
The Solution: Attention! Let each token "look at" every other token.
Token "cat" might learn, for example:
- "is" is relevant (verb)
- "a" is relevant (article)
- "," is less relevant (punctuation)
- "." at the previous sentence boundary is irrelevant
Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V
Step 1: Compatibility (Q * K^T)
- Does Query match this Key?
- Higher score = better match
Step 2: Scaling (/ sqrt(d_k))
- Keeps dot products from growing with the dimension d_k
- Avoids a saturated softmax, whose gradients are close to zero
Step 3: Softmax
- Convert scores to probabilities (0-1)
- All probabilities sum to 1
Step 4: Weighting (multiply by V)
- Take a weighted sum of the value vectors using the attention weights
- High attention weight → high influence
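The four steps above can be sketched in plain JavaScript. This is a minimal illustration on nested arrays, not the implementation in `../03_attention/scaled_dot_product`; the names `softmax`, `dot`, and `scaledDotProductAttention` are assumptions made for this example.

```javascript
// Numerically stable softmax over one row of scores.
function softmax(row) {
  const max = Math.max(...row);
  const exps = row.map((s) => Math.exp(s - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Dot product of two equal-length vectors.
function dot(a, b) {
  return a.reduce((acc, x, i) => acc + x * b[i], 0);
}

// Q: [n][d_k], K: [m][d_k], V: [m][d_v], all as nested arrays.
function scaledDotProductAttention(Q, K, V) {
  const dK = K[0].length;
  const weights = Q.map((q) => {
    // Steps 1 and 2: compatibility scores, scaled by sqrt(d_k).
    const scores = K.map((k) => dot(q, k) / Math.sqrt(dK));
    // Step 3: turn the scores into weights that sum to 1.
    return softmax(scores);
  });
  // Step 4: each output row is a weighted sum of the rows of V.
  const output = weights.map((w) =>
    V[0].map((_, j) => w.reduce((acc, wi, i) => acc + wi * V[i][j], 0))
  );
  return { output, weights };
}

// One query that matches the first key better than the second.
const demo = scaledDotProductAttention(
  [[1, 0]],
  [[1, 0], [0, 1]],
  [[10, 0], [0, 10]]
);
console.log(demo.weights[0]); // the first weight is the larger one
console.log(demo.output[0]);  // output leans toward the first value row
```

Subtracting the row maximum inside `softmax` does not change the result; it only keeps `Math.exp` from overflowing on large scores.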
Text: "I like cats"
Query for "cats": [0.2, -0.1, 0.5]
Keys: "I" ← [0.1, 0.2, 0.3]
"like" ← [0.3, 0.4, 0.2]
"cats" ← [0.2, -0.1, 0.5]
Scores = Q * [K_I, K_like, K_cats]^T
= [0.15, 0.12, 0.30] (unscaled)
= [0.09, 0.07, 0.17] (scaled by sqrt(3))
= [0.33, 0.32, 0.35] (after softmax)
Weighted sum = 0.33 * V_I + 0.32 * V_like + 0.35 * V_cats
= a near-even blend of all three values, tilted slightly toward "cats"
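The arithmetic of this example can be checked directly from the query and key vectors listed above. The snippet below is a small sketch that recomputes the dot products, applies the sqrt(d_k) scaling with d_k = 3, and runs softmax; the variable names are chosen for illustration only.

```javascript
const q = [0.2, -0.1, 0.5]; // query for "cats"
const keys = {
  I: [0.1, 0.2, 0.3],
  like: [0.3, 0.4, 0.2],
  cats: [0.2, -0.1, 0.5],
};
const dK = q.length; // 3

const names = Object.keys(keys);
// Step 1: dot product of the query with each key.
const rawScores = names.map((t) =>
  keys[t].reduce((acc, k, i) => acc + q[i] * k, 0)
);
// Step 2: scale by sqrt(d_k).
const scaledScores = rawScores.map((s) => s / Math.sqrt(dK));
// Step 3: softmax.
const exps = scaledScores.map((s) => Math.exp(s));
const total = exps.reduce((a, b) => a + b, 0);
const attnWeights = exps.map((e) => e / total);

names.forEach((t, i) => {
  console.log(t, rawScores[i].toFixed(2), attnWeights[i].toFixed(2));
});
// I 0.15 0.33
// like 0.12 0.32
// cats 0.30 0.35
```

Because the three raw scores are small and close together, the softmax output is nearly uniform. Larger or more spread-out scores would produce a sharper distribution.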
Different heads learn different patterns:
Head 1: Focuses on direct object
"cats" ← [100% to "like"]
Head 2: Focuses on subject
"cats" ← [100% to "I"]
Head 3: Focuses on syntax
"cats" ← [50% to nouns, 30% to articles, the rest spread over other tokens]
Input: [64-dim embedding] × 3 tokens
Split into 8 heads:
Head 1-8: [8-dim embedding] × 3 tokens each
Attention on each head:
Head 1 output: [8-dim] × 3
Head 2 output: [8-dim] × 3
...
Head 8 output: [8-dim] × 3
Concatenate:
[64-dim] × 3 (back to original!)
Each position now has info from all heads.
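The shape bookkeeping above can be sketched as follows. This covers only the split and the concatenation; it skips the per-head attention step and the learned projections, and the helper names `splitHeads` and `concatHeads` are assumptions for this example rather than part of the repository's `MultiHeadAttention` class.

```javascript
// Split each token's embedding into numHeads equal chunks.
// x: [tokens][dModel] -> heads: [numHeads][tokens][dModel / numHeads]
function splitHeads(x, numHeads) {
  const dModel = x[0].length;
  if (dModel % numHeads !== 0) {
    throw new Error("dModel must be divisible by numHeads");
  }
  const dHead = dModel / numHeads;
  return Array.from({ length: numHeads }, (_, h) =>
    x.map((token) => token.slice(h * dHead, (h + 1) * dHead))
  );
}

// Concatenate per-head outputs back into [tokens][dModel].
function concatHeads(heads) {
  return heads[0].map((_, t) => heads.flatMap((head) => head[t]));
}

// 3 tokens with 64-dim embeddings (dummy values, used only to check shapes).
const x = Array.from({ length: 3 }, (_, t) =>
  Array.from({ length: 64 }, (_, i) => t * 64 + i)
);
const heads = splitHeads(x, 8);
console.log(heads.length, heads[0].length, heads[0][0].length); // 8 3 8
const merged = concatHeads(heads);
console.log(merged.length, merged[0].length); // 3 64
```

In a full layer, attention would run on each entry of `heads` before `concatHeads` is called, so the merged vector carries one 8-dim slice per head.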
Without masking:
Training: The position at "Y" predicts the next token "Z", yet it can look at "Z" (the answer!)
Testing: When generating from "Y", "Z" doesn't exist yet!
Result: Mismatch → bad inference
Mask out future tokens:
Position 0: can attend to [0] (only itself)
Position 1: can attend to [0, 1] (past + self)
Position 2: can attend to [0, 1, 2] (past + self)
Position 3: can attend to [0, 1, 2, 3] (past + self)
Causal mask (lower triangular):
[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1]]
(1 = can attend, 0 = cannot attend)
Scores after masking future positions (before softmax):
[[s00, -∞, -∞, -∞ ],
[s10, s11, -∞, -∞ ],
[s20, s21, s22, -∞ ],
[s30, s31, s32, s33]]
exp(-∞) = 0, so softmax gives masked positions zero weight!
Result: Can't attend to future.
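A minimal sketch of the masking step, assuming scores are stored as nested arrays. The helpers `createCausalMask` and `maskedSoftmax` are hypothetical names for this example and are not the repository's `CausalMask` API.

```javascript
// Lower-triangular mask: mask[i][j] is 1 when position i may attend to j.
function createCausalMask(n) {
  return Array.from({ length: n }, (_, i) =>
    Array.from({ length: n }, (_, j) => (j <= i ? 1 : 0))
  );
}

// Replace disallowed scores with -Infinity, then softmax each row.
function maskedSoftmax(scores, mask) {
  return scores.map((row, i) => {
    const masked = row.map((s, j) => (mask[i][j] === 1 ? s : -Infinity));
    const max = Math.max(...masked);
    const exps = masked.map((s) => Math.exp(s - max)); // exp(-Infinity) is 0
    const sum = exps.reduce((a, b) => a + b, 0);
    return exps.map((e) => e / sum);
  });
}

const causalMask = createCausalMask(4);
// Uniform scores make the effect of the mask easy to read.
const uniformScores = Array.from({ length: 4 }, () => [0, 0, 0, 0]);
const causalWeights = maskedSoftmax(uniformScores, causalMask);
console.log(causalWeights[0]); // [ 1, 0, 0, 0 ]
console.log(causalWeights[1]); // [ 0.5, 0.5, 0, 0 ]
```

Every row keeps at least its diagonal entry, so no row is entirely -Infinity and the softmax never divides by zero.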
const ScaledDotProductAttention = require("../03_attention/scaled_dot_product");
const MultiHeadAttention = require("../03_attention/multi_head_attention");
const CausalMask = require("../03_attention/causal_mask");
// Single head
const attention = new ScaledDotProductAttention(64); // d_k = 64
const result = attention.forward(Q, K, V);
console.log(result.attention_weights);
// Multi-head (8 heads, 64 dimensions)
const mha = new MultiHeadAttention(64, 8);
const output = mha.forward(Q, K, V);
// Causal mask
const mask = CausalMask.create_mask(4);
CausalMask.visualize_mask(4);
CausalMask.print_heatmap(result.attention_weights, tokens);
- Self-Attention: Q, K, V from same sequence → tokens learn relationships
- Multi-Head: Different heads = different relationship types
- Causal Mask: Prevents looking at future (critical for generation!)
- Learned Projection: W_Q, W_K, W_V are trainable
- Soft Attention: Uses probabilities (softmax), not hard selection
Attention is the cornerstone of modern LLMs! 🧠