
Commit ee49a8d

luciaquirke and claude committed
Add modular norm normalization for MAGIC attribution
Normalize weight updates in the modular norm after each optimizer step to improve metasmoothness of the training trajectory. Uses the modula package (pip install modula) to compute per-layer target norms from the recursive modular norm formula.

- bergson/magic/modula_norm.py: ModulaNormalizer class that builds a modula GPT architecture matching GPT-2, extracts target norms per atom, and applies spectral norm (Linear) or row norm (Embedding) normalization. Supports differentiable normalization for the MAGIC backward attribution pass.
- bergson/magic/trainer.py: Accept optional normalizer, apply after each optimizer step (including traced steps for attribution).
- bergson/magic/cli.py: Add use_modula flag to MagicConfig.
- examples/modula/: Reproducible experiment comparing baseline vs. modula MAGIC attribution on GPT-2 / WikiText-2.

Usage: bash examples/modula/run_experiment.sh
Requires: pip install modula

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 05b5677 commit ee49a8d

6 files changed

Lines changed: 449 additions & 3 deletions
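
The per-layer target norms mentioned in the commit message come from recursively splitting a unit norm budget across the architecture by mass, with one child's share discounted by the other child's sensitivity. A toy sketch of that split, mirroring _extract_atoms_and_norms in bergson/magic/modula_norm.py below (the mass and sensitivity values here are made up for illustration, not numbers modula would report for GPT-2):

def split_composite(mass0, mass1, sensitivity1, target=1.0):
    """Split a composite module's target norm between its two children.

    Mirrors the CompositeModule branch of _extract_atoms_and_norms: child 0's
    mass fraction is divided by child 1's sensitivity, child 1 keeps its
    plain mass fraction of the budget.
    """
    total = mass0 + mass1
    t0 = mass0 / total * target / sensitivity1
    t1 = mass1 / total * target
    return t0, t1


# Made-up numbers: masses 1 and 2, sensitivity 2, unit budget.
# Child 0 gets 1/3 * 1/2 = 1/6; child 1 gets 2/3.
print(split_composite(1.0, 2.0, 2.0))  # (0.1666..., 0.6666...)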


bergson/magic/cli.py

Lines changed: 12 additions & 1 deletion
@@ -66,6 +66,10 @@ class MagicConfig(AttributionConfig, TrainingConfig):
     per_token: bool = False
     """Whether to compute attribution scores per token (instead of per sequence)."""

+    use_modula: bool = False
+    """Normalize weight updates in the modular norm (modula package).
+    Improves metasmoothness of the training trajectory for better attribution."""
+
     def __post_init__(self):
         assert not self.fsdp, "PyTorch FSDP is not currently supported for MAGIC."

@@ -155,6 +159,13 @@ def prepare_trainer(
     )
     model.to(f"cuda:{rank}")  # type: ignore[reportArgumentType]

+    normalizer = None
+    use_modula = getattr(cfg, "use_modula", False)
+    if use_modula:
+        from .modula_norm import ModulaNormalizer
+
+        normalizer = ModulaNormalizer(model.config, device=f"cuda:{rank}")
+
     if target_modules:
         # Only train the PEFT adapter parameters
         model.requires_grad_(False)

@@ -200,7 +211,7 @@ def prepare_trainer(
         case other:
             raise ValueError(f"Unsupported optimizer: {other}")

-    trainer, fwd_state = Trainer.initialize(model, opt)
+    trainer, fwd_state = Trainer.initialize(model, opt, normalizer=normalizer)
     return trainer, fwd_state, model
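
For reference, the new field can also be set when constructing the config in Python rather than through the CLI. A minimal, hypothetical sketch; the required model, data, and training fields inherited from AttributionConfig and TrainingConfig are omitted here and would be needed in a real run:

from bergson.magic.cli import MagicConfig

# Hypothetical minimal construction; real runs also supply the inherited
# AttributionConfig / TrainingConfig fields.
cfg = MagicConfig(use_modula=True)
assert cfg.use_modula
assert cfg.per_token is False  # existing default, unchanged by this commit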

bergson/magic/modula_norm.py

Lines changed: 307 additions & 0 deletions
@@ -0,0 +1,307 @@
+"""Modular norm normalization for GPT-2 weights.
+
+Uses the modula package to compute per-layer target norms from the recursive
+modular norm formula, then normalizes weights via spectral normalization
+(Linear layers) or row normalization (Embedding layers) after each optimizer step.
+
+Reference: https://github.com/jxbz/modula
+"""
+
+from dataclasses import dataclass
+
+import torch
+from modula.abstract import CompositeModule, TupleModule
+from modula.atom import Embedding as ModulaEmbedding
+from modula.atom import Linear as ModulaLinear
+from modula.compound import GPT as ModulaGPT
+
+
+def _extract_atoms_and_norms(
+    module, target_norm: float = 1.0
+) -> list[tuple[ModulaLinear | ModulaEmbedding, float]]:
+    """Walk the modula architecture tree and extract (atom, target_norm) pairs.
+
+    The target norm for each atom is computed recursively from the architecture's
+    mass and sensitivity values, following modula's normalization formula.
+    """
+    if isinstance(module, (ModulaLinear, ModulaEmbedding)):
+        return [(module, target_norm)]
+
+    if isinstance(module, CompositeModule):
+        m0, m1 = module.children
+        if module.mass > 0:
+            t0 = m0.mass / module.mass * target_norm / m1.sensitivity
+            t1 = m1.mass / module.mass * target_norm
+            return _extract_atoms_and_norms(m0, t0) + _extract_atoms_and_norms(m1, t1)
+        return []
+
+    if isinstance(module, TupleModule):
+        if module.mass > 0:
+            result: list[tuple[ModulaLinear | ModulaEmbedding, float]] = []
+            for child in module.children:
+                t = child.mass / module.mass * target_norm
+                result.extend(_extract_atoms_and_norms(child, t))
+            return result
+        return []
+
+    return []
+
+
+@dataclass
+class _EmbeddingEntry:
+    name: str
+    target_norm: float
+
+
+@dataclass
+class _LinearEntry:
+    name: str
+    target_norm: float
+    transpose: bool  # True for Conv1D params stored as (in, out)
+    u: torch.Tensor  # Power iteration vector, shape (in_features,)
+
+
+@dataclass
+class _CAttnEntry:
+    """Entry for the combined c_attn weight that holds Q, K, V projections."""
+
+    name: str
+    target_norms: list[float]  # [Q, K, V]
+    u_vectors: list[torch.Tensor]  # Power iteration vectors for Q, K, V
+    n_embd: int
+
+
+class ModulaNormalizer:
+    """Normalizes GPT-2 weights in the modular norm after each optimizer step.
+
+    Builds a modula architecture matching the GPT-2 config, computes the
+    per-layer target norms from the recursive modular norm formula, and
+    normalizes weights via spectral normalization (Linear) or row normalization
+    (Embedding) after each optimizer step.
+
+    Supports both in-place (no-grad) normalization for training and
+    differentiable normalization for the MAGIC backward pass.
+    """
+
+    def __init__(self, model_config, device: str | torch.device):
+        n_embd = model_config.n_embd
+        n_layer = model_config.n_layer
+        n_head = model_config.n_head
+        vocab_size = model_config.vocab_size
+        n_positions = model_config.n_positions
+
+        self.n_embd = n_embd
+        self.n_layer = n_layer
+
+        # Build modula GPT architecture and extract target norms
+        gpt = ModulaGPT(
+            vocab_size=vocab_size,
+            context=n_positions,
+            num_heads=n_head,
+            d_embed=n_embd,
+            d_query=n_embd // n_head,
+            d_value=n_embd // n_head,
+            num_blocks=n_layer,
+        )
+
+        # Initialize to create power iteration vectors on the atoms
+        gpt.initialize(device)
+
+        atoms_and_norms = _extract_atoms_and_norms(gpt, target_norm=1.0)
+        expected = 2 + n_layer * 6 + 1
+        assert (
+            len(atoms_and_norms) == expected
+        ), f"Expected {expected} atoms, got {len(atoms_and_norms)}"
+
+        # Build entries mapping modula atoms to HF parameter names
+        self._entries: list[_EmbeddingEntry | _LinearEntry | _CAttnEntry] = []
+        idx = 0
+
+        # Token embedding
+        _, tn = atoms_and_norms[idx]
+        idx += 1
+        self._entries.append(_EmbeddingEntry("transformer.wte.weight", tn))
+
+        # Position embedding
+        _, tn = atoms_and_norms[idx]
+        idx += 1
+        self._entries.append(_EmbeddingEntry("transformer.wpe.weight", tn))
+
+        # Transformer blocks
+        for i in range(n_layer):
+            # Q, K, V from combined c_attn
+            qkv_norms = []
+            qkv_us = []
+            for _ in range(3):
+                atom, tn = atoms_and_norms[idx]
+                idx += 1
+                qkv_norms.append(tn)
+                assert isinstance(atom, ModulaLinear)
+                qkv_us.append(atom.u.to(device))
+            self._entries.append(
+                _CAttnEntry(
+                    f"transformer.h.{i}.attn.c_attn.weight",
+                    qkv_norms,
+                    qkv_us,
+                    n_embd,
+                )
+            )
+
+            # Attention output projection (Conv1D: stored as (in, out))
+            atom, tn = atoms_and_norms[idx]
+            idx += 1
+            assert isinstance(atom, ModulaLinear)
+            self._entries.append(
+                _LinearEntry(
+                    f"transformer.h.{i}.attn.c_proj.weight",
+                    tn,
+                    transpose=True,
+                    u=atom.u.to(device),
+                )
+            )
+
+            # MLP up projection (Conv1D)
+            atom, tn = atoms_and_norms[idx]
+            idx += 1
+            assert isinstance(atom, ModulaLinear)
+            self._entries.append(
+                _LinearEntry(
+                    f"transformer.h.{i}.mlp.c_fc.weight",
+                    tn,
+                    transpose=True,
+                    u=atom.u.to(device),
+                )
+            )
+
+            # MLP down projection (Conv1D)
+            atom, tn = atoms_and_norms[idx]
+            idx += 1
+            assert isinstance(atom, ModulaLinear)
+            self._entries.append(
+                _LinearEntry(
+                    f"transformer.h.{i}.mlp.c_proj.weight",
+                    tn,
+                    transpose=True,
+                    u=atom.u.to(device),
+                )
+            )
+
+        # lm_head: skip — in GPT-2 it is tied to transformer.wte.weight,
+        # which is already normalized as an Embedding above. Keeping the
+        # architecture identical between baseline and modula runs.
+        idx += 1  # consume the atom but don't create an entry
+
+    @torch.no_grad()
+    def warmup(self, params: dict[str, torch.Tensor], n_steps: int = 10):
+        """Run power iteration steps to converge u vectors without scaling weights.
+
+        Should be called once on the initial weights before the first normalize().
+        """
+        for entry in self._entries:
+            if entry.name not in params:
+                continue
+
+            weight = params[entry.name]
+
+            if isinstance(entry, _LinearEntry):
+                wt = weight.t() if entry.transpose else weight
+                for _ in range(n_steps):
+                    _power_iter_step(wt, entry.u)
+
+            elif isinstance(entry, _CAttnEntry):
+                d = entry.n_embd
+                for j in range(3):
+                    part = weight[:, j * d : (j + 1) * d]
+                    wt = part.t()
+                    for _ in range(n_steps):
+                        _power_iter_step(wt, entry.u_vectors[j])
+
+    def normalize(
+        self, params: dict[str, torch.Tensor], trace: bool = False
+    ) -> dict[str, torch.Tensor]:
+        """Normalize weights in the modular norm.
+
+        Args:
+            params: Dict mapping parameter names to tensors.
+            trace: If True, returns new dict with differentiable normalized tensors.
+                If False, normalizes in-place and returns the same dict.
+        """
+        if trace:
+            return self._normalize_traced(params)
+
+        self._normalize_inplace(params)
+        return params
+
+    @torch.no_grad()
+    def _normalize_inplace(self, params: dict[str, torch.Tensor]):
+        for entry in self._entries:
+            if entry.name not in params:
+                continue
+
+            weight = params[entry.name]
+
+            if isinstance(entry, _EmbeddingEntry):
+                norms = weight.norm(dim=1, keepdim=True).clamp_(min=1e-8)
+                weight.mul_(entry.target_norm / norms)
+
+            elif isinstance(entry, _LinearEntry):
+                wt = weight.t() if entry.transpose else weight
+                sigma = _power_iter_step(wt, entry.u)
+                weight.mul_(entry.target_norm / sigma)
+
+            elif isinstance(entry, _CAttnEntry):
+                d = entry.n_embd
+                for j in range(3):
+                    part = weight[:, j * d : (j + 1) * d]
+                    wt = part.t()
+                    sigma = _power_iter_step(wt, entry.u_vectors[j])
+                    part.mul_(entry.target_norms[j] / sigma)
+
+    def _normalize_traced(
+        self, params: dict[str, torch.Tensor]
+    ) -> dict[str, torch.Tensor]:
+        new_params = dict(params)
+
+        for entry in self._entries:
+            if entry.name not in params:
+                continue
+
+            weight = params[entry.name]
+
+            if isinstance(entry, _EmbeddingEntry):
+                norms = weight.norm(dim=1, keepdim=True).clamp(min=1e-8)
+                new_params[entry.name] = weight * (entry.target_norm / norms)
+
+            elif isinstance(entry, _LinearEntry):
+                wt = weight.t() if entry.transpose else weight
+                sigma = _spectral_norm_diff(wt, entry.u.detach())
+                new_params[entry.name] = weight * (entry.target_norm / sigma)
+
+            elif isinstance(entry, _CAttnEntry):
+                d = entry.n_embd
+                parts = []
+                for j in range(3):
+                    part = weight[:, j * d : (j + 1) * d]
+                    wt = part.t()
+                    sigma = _spectral_norm_diff(wt, entry.u_vectors[j].detach())
+                    parts.append(part * (entry.target_norms[j] / sigma))
+                new_params[entry.name] = torch.cat(parts, dim=1)
+
+        return new_params
+
+
+@torch.no_grad()
+def _power_iter_step(weight: torch.Tensor, u: torch.Tensor) -> float:
+    """One step of power iteration. Updates u in-place. Returns approx spectral norm."""
+    v = torch.mv(weight, u)
+    v /= v.norm()
+    torch.mv(weight.t(), v, out=u)
+    return u.norm().item()
+
+
+def _spectral_norm_diff(weight: torch.Tensor, u_detached: torch.Tensor) -> torch.Tensor:
+    """Differentiable spectral norm using a detached power iteration vector."""
+    v = torch.mv(weight, u_detached)
+    v = v / v.norm().clamp(min=1e-8)
+    u = torch.mv(weight.t(), v)
+    return u.norm().clamp(min=1e-8)
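
Standalone, the normalizer can be exercised outside the trainer. A minimal sketch, assuming the modula and transformers packages are installed and a GPT-2 config exposing n_embd / n_layer / n_head / vocab_size / n_positions, as the constructor expects:

import torch
from transformers import GPT2LMHeadModel

from bergson.magic.modula_norm import ModulaNormalizer

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)

normalizer = ModulaNormalizer(model.config, device=device)

params = dict(model.named_parameters())
normalizer.warmup(params)  # converge the power-iteration u vectors on the initial weights
normalizer.normalize(params)  # in-place spectral / row normalization, no autograd
traced = normalizer.normalize(params, trace=True)  # differentiable copies for the MAGIC backward pass

Parameter names with no matching entry (biases, LayerNorm weights, the tied lm_head) pass through untouched.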

bergson/magic/trainer.py

Lines changed: 19 additions & 2 deletions
@@ -221,6 +221,7 @@ def initialize(
         cls,
         model: nn.Module,
         optimizer: GradientTransformation,
+        normalizer=None,
     ) -> tuple["Trainer", TrainerState]:
         """Convenience method for initializing the trainer and state."""
         # Create new tensor objects for the parameters and buffers to ensure that they

@@ -236,10 +237,21 @@ def initialize(
         buffers = shallow_copy(dict(model.named_buffers(remove_duplicate=False)))
         opt_state = optimizer.init(params)

+        # Warm up power iteration vectors on the actual weights, then normalize
+        # so training starts on the modular norm constraint surface
+        if normalizer is not None:
+            normalizer.warmup(params)
+            normalizer.normalize(params, trace=False)
+
         state = TrainerState(params, opt_state, buffers)
-        return cls(model, optimizer), state
+        return cls(model, optimizer, normalizer), state

-    def __init__(self, model: nn.Module, optimizer: GradientTransformation):
+    def __init__(
+        self,
+        model: nn.Module,
+        optimizer: GradientTransformation,
+        normalizer=None,
+    ):
         # Move only trainable parameters to the meta device, leaving frozen params
         # on device so they don't need to be managed by TrainerState.
         for mod in model.modules():

@@ -251,6 +263,7 @@ def __init__(self, model: nn.Module, optimizer: GradientTransformation):

         self.model = model
         self.optimizer = optimizer
+        self.normalizer = normalizer

     def step(
         self,

@@ -300,6 +313,10 @@ def step(
             grads, state.opt_state, inplace=inplace, params=state.params
         )
         new_params = torchopt.apply_updates(state.params, updates, inplace=inplace)
+
+        if self.normalizer is not None:
+            new_params = self.normalizer.normalize(new_params, trace=trace)
+
         state = TrainerState(
             new_params,
             new_state,
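
After a step with the normalizer attached, each normalized weight should sit approximately at its target (approximately because only one power-iteration step runs per update): spectral norm for Linear/Conv1D entries, per-row L2 norm for the embeddings. A rough, illustrative check on a params dict such as state.params, or the params dict from the standalone sketch above:

import torch

with torch.no_grad():
    # Conv1D weights are stored (in, out); the spectral norm is transpose-invariant,
    # so the stored orientation can be checked directly.
    w = params["transformer.h.0.mlp.c_fc.weight"]
    print(torch.linalg.matrix_norm(w, ord=2).item())  # close to that layer's target norm

    wte = params["transformer.wte.weight"]
    row_norms = wte.norm(dim=1)
    print(row_norms.min().item(), row_norms.max().item())  # both close to the embedding's target norm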
