diff --git a/demos/Realtime_Training_Telemetry_Demo.ipynb b/demos/Realtime_Training_Telemetry_Demo.ipynb new file mode 100644 index 000000000..d77d67274 --- /dev/null +++ b/demos/Realtime_Training_Telemetry_Demo.ipynb @@ -0,0 +1,464 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Real-time Training Telemetry Demo\n", + "\n", + "This notebook demonstrates the real-time extraction and visualization of mechanistic metrics during a model's training loop using TransformerLens.\n", + "\n", + "By leveraging dynamic dictionary logging alongside the `ActivationCache`, we can isolate the training window where localized phase transitions\u2014such as the formation of induction heads\u2014begin to emerge.\n", + "\n", + "**A note on scaling**: While this 2-layer toy model allows for high-granularity tracking with minimal computational overhead, achieving similar resolution in larger architectures is non-trivial. It requires highly targeted caching and direct manipulation of the telemetry bridge to surface this level of detail without memory exhaustion.\n", + "\n", + "**Compute requirements**: A standard CPU is entirely sufficient for this demonstration. The 500-step training loop will execute rapidly in standard local or cloud-based notebook environments without the need for hardware acceleration." + ], + "metadata": { + "id": "ctqPlNKhJFgZ" + } + }, + { + "cell_type": "markdown", + "source": [ + "Initializes the workspace, configures Plotly renderers, and defines the toy model architecture.\n", + "\n", + "**Visualization Context** `Plotly Renderers`:\n", + "\n", + "`Plotly` generates interactive, JavaScript-based visualizations. Google Colab handles these DOM interactions differently than local Jupyter or VS Code environments. We detect the active environment to set the appropriate `plotly.io` renderer (`\"colab\"` vs `\"notebook_connected\"`), ensuring the dynamic telemetry plots render correctly without blank output blocks.\n", + "\n", + "**Architectural Rationale:**\n", + "\n", + "\n", + "* **2 Layers (`n_layers=2`):** The theoretical minimum depth required for induction circuits. Layer 0 creates \"previous token\" representations, and Layer 1 queries these to predict the next token based on earlier context.\n", + "\n", + "* **2 Heads (`n_heads=2`):** Provides just enough capacity for heads to specialize (e.g., dedicating one head to induction) without creating excessive noise in the telemetry visualizations.\n", + "\n", + "* **GELU Activation (`act_fn=\"gelu\"`):** Selected over ReLU to mirror the smooth non-linearities of modern production LLMs, ensuring the activation dynamics remain representative of real-world architectures.\n", + "\n", + "* **Miniaturized Dimensions:** `d_model=64`, `d_vocab=64`, and `n_ctx=32` are intentionally bottlenecked to force rapid convergence, reliably inducing the phase transition within a brief 500-step training window." + ], + "metadata": { + "id": "tQ_xnziNJlsF" + } + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "import numpy as np\n", + "\n", + "# Detect execution environment\n", + "try:\n", + " import google.colab # noqa: F401\n", + " IN_COLAB = True\n", + " print(\"Environment: Google Colab\")\n", + "except ImportError:\n", + " IN_COLAB = False\n", + " print(\"Environment: Local / Standard Jupyter\")\n", + "\n", + "# Environment-specific dependency management\n", + "if IN_COLAB:\n", + " %pip install -q transformer_lens\n", + " %pip install -q circuitsvis\n", + "\n", + "import plotly.io as pio\n", + "import plotly.graph_objects as go\n", + "from plotly.subplots import make_subplots\n", + "\n", + "# Configure Plotly renderer for correct JavaScript execution\n", + "if IN_COLAB:\n", + " pio.renderers.default = \"colab\"\n", + "else:\n", + " pio.renderers.default = \"notebook_connected\"\n", + "print(f\"Plotly Renderer: {pio.renderers.default}\")\n", + "\n", + "# Must be imported after Colab pip install\n", + "from transformer_lens import HookedTransformer, HookedTransformerConfig # noqa: E402\n", + "\n", + "# Configuration\n", + "torch.manual_seed(42)\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "print(f\"\ud83d\ude80 Running on {device}\")\n", + "\n", + "cfg = HookedTransformerConfig(\n", + " n_layers=2,\n", + " d_model=64,\n", + " d_head=32,\n", + " n_heads=2,\n", + " d_mlp=256,\n", + " d_vocab=64,\n", + " n_ctx=32,\n", + " act_fn=\"gelu\",\n", + " normalization_type=\"LN\",\n", + " seed=42,\n", + ")\n", + "model = HookedTransformer(cfg).to(device)" + ], + "metadata": { + "id": "GyA-sxLOLRVA", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "1cc176e2-7638-4106-b063-edce512c6f6d" + }, + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Environment: Google Colab\n", + " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m968.6/968.6 kB\u001b[0m \u001b[31m23.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m56.4/56.4 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for transformers-stream-generator (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m35.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hPlotly Renderer: colab\n", + "\ud83d\ude80 Running on cpu\n", + "Moving model to device: cpu\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Attention Telemetry Extraction\n", + "\n", + "This bridge extracts mechanistic metrics directly from the `ActivationCache` during the forward pass.\n", + "\n", + "* **Head Coherence:** Measures attention focus sharpness via normalized entropy. A score of $1.0$ indicates perfect focus on a single token, while $0.0$ indicates uniformly distributed attention.\n", + "\n", + "* **Head Agreement:** Evaluates intra-layer behavioral similarity among attention heads.\n", + "* **Variance Normalization:** The agreement metric normalizes against a `0.005` variance constant. This serves as an expected baseline for inter-head variance at the point of specialization in 2-layer models.\n", + "\n", + "**Note:** This constant is a localized architectural assumption. Recalibrating this threshold will likely be necessary when porting the telemetry bridge to larger, higher-dimensional models." + ], + "metadata": { + "id": "0hMNO0WpKXhR" + } + }, + { + "cell_type": "code", + "source": [ + "class AttentionTelemetry:\n", + " \"\"\"Lightweight bridge extracting mechanistic metrics from ActivationCache.\"\"\"\n", + "\n", + " @staticmethod\n", + " def compute_metrics(cache, layer_idx):\n", + " \"\"\"Computes attention coherence and agreement for a given layer.\n", + "\n", + " Args:\n", + " cache (ActivationCache): The cached activations from the forward pass.\n", + " layer_idx (int): The index of the layer to analyze.\n", + "\n", + " Returns:\n", + " dict: A dictionary containing layer_idx, head_coherence, and head_agreement.\n", + "\n", + " Notes on v_max (0.005):\n", + " The agreement normalization constant is derived from the expected inter-head\n", + " attention variance at convergence in 2-layer induction head toy models.\n", + " At convergence, heads specialize (low variance); pre-convergence variance\n", + " peaks near 0.005. This value is task- and architecture-specific; adjust\n", + " if adapting to larger models or different tasks.\n", + " \"\"\"\n", + "\n", + " pattern_name = f\"blocks.{layer_idx}.attn.hook_pattern\"\n", + "\n", + " # Shape: [batch, heads, seq, seq]\n", + " attn_probs = cache[pattern_name]\n", + "\n", + " # 1. Head Coherence: 1.0 - (Entropy / Max_Entropy)\n", + " probs = attn_probs + 1e-9\n", + " entropy = -torch.sum(probs * torch.log(probs), dim=-1) # [batch, heads, seq]\n", + " head_coherence = 1.0 - (entropy.mean(dim=[0, 2]) / np.log(attn_probs.shape[-1]))\n", + "\n", + " # 2. Head Agreement: 1.0 - clip(Variance / v_max)\n", + " mean_var = torch.var(attn_probs, dim=1).mean() # Variance across heads\n", + " head_agreement = 1.0 - torch.clamp(mean_var / 0.005, 0.0, 1.0)\n", + "\n", + " return {\n", + " \"layer_idx\": layer_idx,\n", + " \"head_coherence\": head_coherence.mean().item(),\n", + " \"head_agreement\": head_agreement.item()\n", + " }" + ], + "metadata": { + "id": "fd_0TLTUUNf6" + }, + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Synthetic Data Generation & Training Loop\n", + "\n", + "This cell generates a highly constrained dataset to force circuit formation and executes the training loop, capturing telemetry at specified intervals.\n", + "\n", + "**A Note on Synthetic Data:**\n", + "\n", + "The repeated sequence generator (`[A, B, C, ..., A, B, C]`) used below is strictly illustrative. It is engineered specifically as a shortcut to force the rapid emergence of in-context look-back circuits (induction heads). It is not meant to serve as an educational standard for model training.\n", + "\n", + "**Transitioning to Real Data:**\n", + "\n", + "Applying this telemetry extraction to real-world datasets requires rigorous attention to detail regarding:\n", + "\n", + "* **Data Quality:** Unstructured noise in the input distribution will severely obscure the mechanistic signals (like coherence and agreement) you are attempting to isolate.\n", + "\n", + "* **Data Type & Tokenization:** Real-world text requires careful handling of padding, EOS/BOS tokens, and sequence packing, all of which dynamically alter attention patterns and can skew your baseline metrics.\n", + "\n", + "* **Ingestion Methodology:** Managing data loaders, batching, and ensuring that telemetry logging steps align with representative samples is critical to preventing metric distortion.\n", + "\n", + "**Performance Optimization:** To prevent the telemetry capture from bottlenecking the training process, `model.run_with_cache` is exclusively executed at logging intervals. Standard forward passes bypass the cache entirely." + ], + "metadata": { + "id": "p2TYMU23Kt2K" + } + }, + { + "cell_type": "code", + "source": [ + "def generate_induction_data(batch_size, seq_len, vocab_size, device=\"cpu\"):\n", + " half_len = seq_len // 2\n", + " first_half = torch.randint(0, vocab_size, (batch_size, half_len))\n", + " data = torch.cat([first_half, first_half], dim=1)\n", + " return data.to(device)" + ], + "metadata": { + "id": "t6NIpzbKK0_K" + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# The Impact of Tokenization on Telemetry\n", + "\n", + "When moving from synthetic generators to real-world text, understanding the mechanistic role of special tokens is vital to preventing metric distortion.\n", + "\n", + "**BOS (Begin of Sequence) as an Attention Sink:**\n", + "\n", + "In many autoregressive models, attention heads route their focus to the first token (BOS) when they do not have a strong contextual match elsewhere. This is known as an \"attention sink.\" If unaccounted for, your telemetry will show artificially high **Head Coherence** (because the head is sharply focused on a single token). However, this represents a resting state rather than active circuit engagement.\n", + "\n", + "**EOS (End of Sequence) & Padding:**\n", + "\n", + "Real data requires batching sequences of variable lengths, necessitating padding tokens and EOS markers to denote where a document ends.\n", + "\n", + "* **Masking Failures:** If attention masks are not perfectly aligned with your telemetry extraction, heads might attend to padding tokens, introducing garbage data that drastically skews your **Agreement** metrics.\n", + "\n", + "* **Context Resets:** The transition between unrelated documents (separated by an EOS token) disrupts the contiguous context window. This resets the look-back mechanisms that induction circuits rely on, causing momentary drops in otherwise stable telemetry" + ], + "metadata": { + "id": "M8sQ0pCuLZnB" + } + }, + { + "cell_type": "markdown", + "source": [ + "# The Training & Real-Time Telemetry Loop\n", + "\n", + "Executing the training sequence while dynamically tracking phase transitions.\n", + "\n", + "**Compute Efficiency (Selective Caching):**\n", + "\n", + "To prevent the extraction process from suffocating the CPU/GPU memory bandwidth, we employ selective caching. Standard forward passes operate normally; `model.run_with_cache` is exclusively invoked at defined logging intervals to extract the telemetry state without severely bottlenecking the training step.\n", + "\n", + "**Rendering Optimization (The Real-Time UI):**\n", + "\n", + "To achieve real-time visualization without crashing the browser's DOM or throttling the PyTorch loop:\n", + "\n", + "1. **Memory Pre-allocation:** We pre-allocate `NaN` arrays for the telemetry traces, completely bypassing costly array reallocation during the loop.\n", + "\n", + "2. **In-Place Mutation:** Instead of generating hundreds of static Plotly objects, we mutate the figure's trace data directly and use `IPython.display.clear_output` to cleanly redraw the frame in the exact same output block.\n", + "\n", + "3. **Static Fallback:** *If you wish to bypass real-time rendering to maximize training speed, simply comment out the `clear_output(wait=True)` and `fig.show()` lines inside the loop, and call `fig.show()` once at the very end of the cell.\n", + "\n", + "**Mechanistic Observation (The Phase Transition):**\n", + "\n", + "Watch the dual-plot for the localized phase transition: a distinct window where the model suddenly \"discovers\" the induction algorithm. This is marked by a violent crash in the loss curve and a simultaneous, sharp spike in the last layer's Attention Coherence." + ], + "metadata": { + "id": "TGucpV41Ltxc" + } + }, + { + "cell_type": "code", + "source": [ + "from IPython.display import clear_output\n", + "import numpy as np # noqa: F811\n", + "import torch\n", + "\n", + "# --- 1. Self-Contained Synthetic Data Generator ---\n", + "def get_batch(batch_size=16, seq_len=model.cfg.n_ctx, vocab_size=model.cfg.d_vocab):\n", + " \"\"\"Generates repeated sequences [A, B, C, A, B, C] to force induction circuitry.\"\"\"\n", + " half_len = seq_len // 2\n", + " random_tokens = torch.randint(0, vocab_size, (batch_size, half_len), device=device)\n", + " return torch.cat([random_tokens, random_tokens], dim=1)\n", + "\n", + "# --- 2. Pre-allocate Memory for Real-Time Plotting ---\n", + "TOTAL_STEPS = 500\n", + "LOG_INTERVAL = 10\n", + "num_logging_steps = TOTAL_STEPS // LOG_INTERVAL\n", + "\n", + "logged_steps = np.arange(0, TOTAL_STEPS, LOG_INTERVAL)\n", + "\n", + "loss_data = np.full(num_logging_steps, np.nan)\n", + "coherence_data = np.full(num_logging_steps, np.nan)\n", + "heatmap_data = np.full((model.cfg.n_layers, num_logging_steps), np.nan)\n", + "\n", + "# --- 3. Initialize the Plotly Figure ---\n", + "fig = make_subplots(\n", + " rows=2, cols=1,\n", + " shared_xaxes=True,\n", + " vertical_spacing=0.1,\n", + " subplot_titles=(\n", + " \"Phase Transition: Loss Crash vs. Circuit Formation\",\n", + " \"Attention Coherence Heatmap by Layer Depth\"\n", + " ),\n", + " specs=[[{\"secondary_y\": True}], [{\"type\": \"heatmap\"}]]\n", + ")\n", + "\n", + "# Trace 0: Loss Curve\n", + "fig.add_trace(\n", + " go.Scatter(x=logged_steps, y=loss_data, name=\"Loss (CE)\", line=dict(color='gray', dash='dash')),\n", + " row=1, col=1, secondary_y=False\n", + ")\n", + "\n", + "# Trace 1: Last Layer Coherence Curve\n", + "last_layer_idx = model.cfg.n_layers - 1\n", + "fig.add_trace(\n", + " go.Scatter(x=logged_steps, y=coherence_data, name=f\"Layer {last_layer_idx} Coherence\", line=dict(color='#1f77b4', width=2.5)),\n", + " row=1, col=1, secondary_y=True\n", + ")\n", + "\n", + "# Trace 2: Layer Heatmap\n", + "fig.add_trace(\n", + " go.Heatmap(\n", + " z=heatmap_data, x=logged_steps, y=[f\"L{i}\" for i in range(model.cfg.n_layers)],\n", + " colorscale='Magma', zmin=0.0, zmax=1.0,\n", + " colorbar=dict(title=\"Coherence (0-1)\", orientation='h', y=-0.25, len=0.5)\n", + " ),\n", + " row=2, col=1\n", + ")\n", + "\n", + "fig.update_layout(height=700, template=\"plotly_white\", margin=dict(t=50, b=50))\n", + "fig.update_yaxes(title_text=\"Cross Entropy Loss\", secondary_y=False, row=1, col=1)\n", + "fig.update_yaxes(title_text=\"Coherence\", secondary_y=True, range=[0, 1.1], row=1, col=1)\n", + "fig.update_xaxes(range=[0, TOTAL_STEPS])\n", + "# --- 4. The Training & Telemetry Loop ---\n", + "model.train()\n", + "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)\n", + "\n", + "log_idx = 0\n", + "for step in range(TOTAL_STEPS):\n", + " batch = get_batch()\n", + "\n", + " # Selective Caching Strategy\n", + " if step % LOG_INTERVAL == 0:\n", + " loss, cache = model.run_with_cache(batch, return_type=\"loss\")\n", + "\n", + " # Update baseline data arrays\n", + " loss_data[log_idx] = loss.item()\n", + "\n", + " # Extract mechanistic metrics using the static method from Cell 3\n", + " for layer in range(model.cfg.n_layers):\n", + " layer_metrics = AttentionTelemetry.compute_metrics(cache, layer)\n", + " heatmap_data[layer, log_idx] = layer_metrics['head_coherence']\n", + "\n", + " # Specifically grab the last layer for the line graph\n", + " if layer == last_layer_idx:\n", + " coherence_data[log_idx] = layer_metrics['head_coherence']\n", + "\n", + " # Mutate Plotly traces in-place\n", + " fig.data[0].x = logged_steps\n", + " fig.data[0].y = loss_data\n", + "\n", + " fig.data[1].x = logged_steps\n", + " fig.data[1].y = coherence_data\n", + "\n", + " fig.data[2].x = logged_steps\n", + " fig.data[2].z = heatmap_data\n", + "\n", + " # Redraw the UI\n", + " clear_output(wait=True)\n", + " fig.show()\n", + "\n", + " log_idx += 1\n", + " else:\n", + " # Standard forward pass (bypassing the cache for speed)\n", + " loss = model(batch, return_type=\"loss\")\n", + "\n", + " # Standard PyTorch Optimization\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 717 + }, + "outputId": "c1d1d8a9-a3e8-4bec-8336-a0829f7c0872", + "id": "Yj90yewlLvuh" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {} + } + ] + } + ] +} \ No newline at end of file