"""Modular norm normalization for GPT-2 weights.

Uses the modula package to compute per-layer target norms from the recursive
modular norm formula, then normalizes weights via spectral normalization
(Linear layers) or row normalization (Embedding layers) after each optimizer step.

Reference: https://github.com/jxbz/modula
"""

from dataclasses import dataclass

import torch
from modula.abstract import CompositeModule, TupleModule
from modula.atom import Embedding as ModulaEmbedding
from modula.atom import Linear as ModulaLinear
from modula.compound import GPT as ModulaGPT


def _extract_atoms_and_norms(
    module, target_norm: float = 1.0
) -> list[tuple[ModulaLinear | ModulaEmbedding, float]]:
    """Walk the modula architecture tree and extract (atom, target_norm) pairs.

    The target norm for each atom is computed recursively from the architecture's
    mass and sensitivity values, following modula's normalization formula.
    """
    if isinstance(module, (ModulaLinear, ModulaEmbedding)):
        return [(module, target_norm)]

    if isinstance(module, CompositeModule):
        m0, m1 = module.children
        if module.mass > 0:
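            # Composition: split the norm budget between the two children in
            # proportion to mass; the inner child's share is further divided
            # by the outer child's sensitivity, since the outer module
            # amplifies perturbations made to the inner one (a reading of
            # modula's normalization formula, per the docstring above).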
            t0 = m0.mass / module.mass * target_norm / m1.sensitivity
            t1 = m1.mass / module.mass * target_norm
            return _extract_atoms_and_norms(m0, t0) + _extract_atoms_and_norms(m1, t1)
        return []

    if isinstance(module, TupleModule):
        if module.mass > 0:
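            # Parallel (tuple) children share the budget by mass alone; no
            # sensitivity correction is needed since none feeds into another.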
            result: list[tuple[ModulaLinear | ModulaEmbedding, float]] = []
            for child in module.children:
                t = child.mass / module.mass * target_norm
                result.extend(_extract_atoms_and_norms(child, t))
            return result
        return []

    return []


@dataclass
class _EmbeddingEntry:
    name: str
    target_norm: float


@dataclass
class _LinearEntry:
    name: str
    target_norm: float
    transpose: bool  # True for Conv1D params stored as (in, out)
    u: torch.Tensor  # Power iteration vector, shape (in_features,)


@dataclass
class _CAttnEntry:
    """Entry for the combined c_attn weight that holds Q, K, V projections."""

    name: str
    target_norms: list[float]  # [Q, K, V]
    u_vectors: list[torch.Tensor]  # Power iteration vectors for Q, K, V
    n_embd: int


class ModulaNormalizer:
    """Normalizes GPT-2 weights in the modular norm after each optimizer step.

    Builds a modula architecture matching the GPT-2 config, computes the
    per-layer target norms from the recursive modular norm formula, and
    normalizes weights via spectral normalization (Linear) or row normalization
    (Embedding) after each optimizer step.

    Supports both in-place (no-grad) normalization for training and
    differentiable normalization for the MAGIC backward pass.
    """

    def __init__(self, model_config, device: str | torch.device):
        n_embd = model_config.n_embd
        n_layer = model_config.n_layer
        n_head = model_config.n_head
        vocab_size = model_config.vocab_size
        n_positions = model_config.n_positions

        self.n_embd = n_embd
        self.n_layer = n_layer

        # Build modula GPT architecture and extract target norms
        gpt = ModulaGPT(
            vocab_size=vocab_size,
            context=n_positions,
            num_heads=n_head,
            d_embed=n_embd,
            d_query=n_embd // n_head,
            d_value=n_embd // n_head,
            num_blocks=n_layer,
        )

        # Initialize to create power iteration vectors on the atoms
        gpt.initialize(device)

        atoms_and_norms = _extract_atoms_and_norms(gpt, target_norm=1.0)
        expected = 2 + n_layer * 6 + 1
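        # 2 embeddings (wte, wpe) + 6 linears per block (Q, K, V, attention
        # out-projection, MLP up, MLP down) + the lm_head readout.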
        assert (
            len(atoms_and_norms) == expected
        ), f"Expected {expected} atoms, got {len(atoms_and_norms)}"

        # Build entries mapping modula atoms to HF parameter names
        self._entries: list[_EmbeddingEntry | _LinearEntry | _CAttnEntry] = []
        idx = 0

        # Token embedding
        _, tn = atoms_and_norms[idx]
        idx += 1
        self._entries.append(_EmbeddingEntry("transformer.wte.weight", tn))

        # Position embedding
        _, tn = atoms_and_norms[idx]
        idx += 1
        self._entries.append(_EmbeddingEntry("transformer.wpe.weight", tn))

        # Transformer blocks
        for i in range(n_layer):
            # Q, K, V from combined c_attn
            qkv_norms = []
            qkv_us = []
            for _ in range(3):
                atom, tn = atoms_and_norms[idx]
                idx += 1
                qkv_norms.append(tn)
                assert isinstance(atom, ModulaLinear)
                qkv_us.append(atom.u.to(device))
            self._entries.append(
                _CAttnEntry(
                    f"transformer.h.{i}.attn.c_attn.weight",
                    qkv_norms,
                    qkv_us,
                    n_embd,
                )
            )

            # Attention output projection (Conv1D: stored as (in, out))
            atom, tn = atoms_and_norms[idx]
            idx += 1
            assert isinstance(atom, ModulaLinear)
            self._entries.append(
                _LinearEntry(
                    f"transformer.h.{i}.attn.c_proj.weight",
                    tn,
                    transpose=True,
                    u=atom.u.to(device),
                )
            )

            # MLP up projection (Conv1D)
            atom, tn = atoms_and_norms[idx]
            idx += 1
            assert isinstance(atom, ModulaLinear)
            self._entries.append(
                _LinearEntry(
                    f"transformer.h.{i}.mlp.c_fc.weight",
                    tn,
                    transpose=True,
                    u=atom.u.to(device),
                )
            )

            # MLP down projection (Conv1D)
            atom, tn = atoms_and_norms[idx]
            idx += 1
            assert isinstance(atom, ModulaLinear)
            self._entries.append(
                _LinearEntry(
                    f"transformer.h.{i}.mlp.c_proj.weight",
                    tn,
                    transpose=True,
                    u=atom.u.to(device),
                )
            )

        # lm_head: skip — in GPT-2 it is tied to transformer.wte.weight,
        # which is already normalized as an Embedding above. This keeps the
        # architecture identical between baseline and modula runs.
        idx += 1  # consume the atom but don't create an entry

    @torch.no_grad()
    def warmup(self, params: dict[str, torch.Tensor], n_steps: int = 10):
        """Run power iteration steps to converge u vectors without scaling weights.

        Should be called once on the initial weights before the first normalize().
        """
        for entry in self._entries:
            if entry.name not in params:
                continue

            weight = params[entry.name]

            if isinstance(entry, _LinearEntry):
                wt = weight.t() if entry.transpose else weight
                for _ in range(n_steps):
                    _power_iter_step(wt, entry.u)

            elif isinstance(entry, _CAttnEntry):
                d = entry.n_embd
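                # HF GPT-2 stores c_attn as a Conv1D whose weight has shape
                # (n_embd, 3 * n_embd); columns [0:d] are Q, [d:2d] K, [2d:3d] V.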
                for j in range(3):
                    part = weight[:, j * d : (j + 1) * d]
                    wt = part.t()
                    for _ in range(n_steps):
                        _power_iter_step(wt, entry.u_vectors[j])

    def normalize(
        self, params: dict[str, torch.Tensor], trace: bool = False
    ) -> dict[str, torch.Tensor]:
        """Normalize weights in the modular norm.

        Args:
            params: Dict mapping parameter names to tensors.
            trace: If True, returns a new dict with differentiable normalized
                tensors. If False, normalizes in-place and returns the same dict.
        """
        if trace:
            return self._normalize_traced(params)

        self._normalize_inplace(params)
        return params

    @torch.no_grad()
    def _normalize_inplace(self, params: dict[str, torch.Tensor]):
        for entry in self._entries:
            if entry.name not in params:
                continue

            weight = params[entry.name]

            if isinstance(entry, _EmbeddingEntry):
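                # Row normalization: rescale each embedding vector to
                # target_norm; the clamp guards against all-zero rows.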
                norms = weight.norm(dim=1, keepdim=True).clamp_(min=1e-8)
                weight.mul_(entry.target_norm / norms)

            elif isinstance(entry, _LinearEntry):
                wt = weight.t() if entry.transpose else weight
                sigma = _power_iter_step(wt, entry.u)
                weight.mul_(entry.target_norm / sigma)

            elif isinstance(entry, _CAttnEntry):
                d = entry.n_embd
                for j in range(3):
                    part = weight[:, j * d : (j + 1) * d]
                    wt = part.t()
                    sigma = _power_iter_step(wt, entry.u_vectors[j])
                    part.mul_(entry.target_norms[j] / sigma)

    def _normalize_traced(
        self, params: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        new_params = dict(params)
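        # Shallow copy: entries we don't touch keep aliasing the original
        # tensors; normalized entries are rebuilt as differentiable tensors.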

        for entry in self._entries:
            if entry.name not in params:
                continue

            weight = params[entry.name]

            if isinstance(entry, _EmbeddingEntry):
                norms = weight.norm(dim=1, keepdim=True).clamp(min=1e-8)
                new_params[entry.name] = weight * (entry.target_norm / norms)

            elif isinstance(entry, _LinearEntry):
                wt = weight.t() if entry.transpose else weight
                sigma = _spectral_norm_diff(wt, entry.u.detach())
                new_params[entry.name] = weight * (entry.target_norm / sigma)

            elif isinstance(entry, _CAttnEntry):
                d = entry.n_embd
                parts = []
                for j in range(3):
                    part = weight[:, j * d : (j + 1) * d]
                    wt = part.t()
                    sigma = _spectral_norm_diff(wt, entry.u_vectors[j].detach())
                    parts.append(part * (entry.target_norms[j] / sigma))
                new_params[entry.name] = torch.cat(parts, dim=1)

        return new_params


@torch.no_grad()
def _power_iter_step(weight: torch.Tensor, u: torch.Tensor) -> float:
    """One step of power iteration. Updates u in-place. Returns approx spectral norm."""
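    # u carries scale between calls but is never renormalized: the v-step
    # divides it out, and with unit v, ||W^T v|| is the spectral-norm estimate.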
    v = torch.mv(weight, u)
    v /= v.norm().clamp(min=1e-8)
    torch.mv(weight.t(), v, out=u)
    return u.norm().item()


def _spectral_norm_diff(weight: torch.Tensor, u_detached: torch.Tensor) -> torch.Tensor:
    """Differentiable spectral norm using a detached power iteration vector."""
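    # u is detached so the graph contains only this one matvec round trip; at
    # convergence the gradient approximates the rank-one outer product of the
    # top singular vectors, as in standard spectral-norm regularization.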
    v = torch.mv(weight, u_detached)
    v = v / v.norm().clamp(min=1e-8)
    u = torch.mv(weight.t(), v)
    return u.norm().clamp(min=1e-8)
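

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the training
    # pipeline): build a tiny GPT-2, warm up the power iteration vectors, then
    # normalize the weights once. Assumes `transformers` is installed; the
    # config sizes below are arbitrary placeholders.
    from transformers import GPT2Config, GPT2LMHeadModel

    config = GPT2Config(
        n_embd=64, n_layer=2, n_head=4, vocab_size=256, n_positions=32
    )
    model = GPT2LMHeadModel(config)
    normalizer = ModulaNormalizer(config, device="cpu")

    params = dict(model.named_parameters())
    normalizer.warmup(params)
    normalizer.normalize(params)
    print(f"normalized {len(params)} parameters")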