gh#9: diffusion_pair_source flag + freeze_trunk helper
Implements the bottlenecked-conditioning experiment from issue #9:
swap the diffusion module's pair input from the trunk's final pair
representation z (B, N_tok, N_tok, 128) to the distogram-head logits
(B, N_tok, N_tok, 64). Freeze the trunk; train only the diffusion
module from this lower-rank signal.
HelicoConfig:
- New diffusion_pair_source: "z" (default, legacy) | "distogram_logits".
DiffusionConditioning (the only place the swap is needed — by the
time z reaches the AtomAttentionEncoder it's already z_cond, the
post-conditioning 128-d tensor):
- Parallel pair_norm_dist + pair_proj_dist sized for n_distogram_bins
+ d_pair input. Always present so checkpoints round-trip; only used
when config.diffusion_pair_source == "distogram_logits".
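For illustration only, a minimal sketch of the parallel-projection layout described above. The class name PairConditioningSketch, the rel_pos feature, and the 2 * d_pair width of the legacy branch are assumptions made for this example, not the actual DiffusionConditioning code; only the n_distogram_bins + d_pair input width of the new branch comes from this commit.

```python
import torch
import torch.nn as nn


class PairConditioningSketch(nn.Module):
    """Hypothetical stand-in for the pair branch of DiffusionConditioning."""

    def __init__(self, d_pair=128, n_distogram_bins=64, pair_source="z"):
        super().__init__()
        self.pair_source = pair_source
        # Legacy branch. Assumed layout: trunk z concatenated with a
        # d_pair-wide relative-position feature.
        self.pair_norm = nn.LayerNorm(2 * d_pair)
        self.pair_proj = nn.Linear(2 * d_pair, d_pair, bias=False)
        # Parallel branch, always constructed so checkpoints round-trip.
        self.pair_norm_dist = nn.LayerNorm(n_distogram_bins + d_pair)
        self.pair_proj_dist = nn.Linear(n_distogram_bins + d_pair, d_pair, bias=False)

    def forward(self, pair_input, rel_pos):
        x = torch.cat([pair_input, rel_pos], dim=-1)
        if self.pair_source == "distogram_logits":
            return self.pair_proj_dist(self.pair_norm_dist(x))
        return self.pair_proj(self.pair_norm(x))


torch.manual_seed(0)
cond = PairConditioningSketch(pair_source="distogram_logits")
logits = torch.randn(1, 8, 8, 64)     # (B, N_tok, N_tok, n_distogram_bins)
rel_pos = torch.randn(1, 8, 8, 128)   # (B, N_tok, N_tok, d_pair), assumed
z_cond = cond(logits, rel_pos)
z_cond.sum().backward()
```

In this sketch only the branch selected by pair_source receives a gradient; the other branch's parameters keep grad = None, which is the behavior the new tests check.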
Helico.forward / Helico.predict:
- Run distogram_head before diffusion when in distogram mode; pass
detached logits to diffusion as z_trunk arg. detach() so the trunk
graph isn't pinned through the diffusion backward when the trunk is
frozen (memory hygiene from the issue's compute estimate).
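A toy illustration of the detach() behavior described above. The two nn.Linear layers are hypothetical stand-ins for the distogram head and the first diffusion layer, not the real modules.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
distogram_head = nn.Linear(128, 64)  # stand-in for the trunk's distogram head
diffusion_in = nn.Linear(64, 128)    # stand-in for the diffusion module input

z = torch.randn(1, 8, 8, 128)        # trunk pair representation
logits = distogram_head(z)           # (B, N_tok, N_tok, 64)

# detach() cuts the autograd edge back into the trunk, so the diffusion
# backward pass neither reaches nor retains the trunk graph.
z_for_diffusion = logits.detach()

out = diffusion_in(z_for_diffusion)
out.sum().backward()
```

After backward, the stand-in head has no gradient while the diffusion layer does.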
train.py:
- TrainConfig fields diffusion_pair_source + freeze_trunk; CLI args
--diffusion-pair-source / --freeze-trunk.
- New _freeze_trunk(model) helper: sets requires_grad=False on every param
outside model.diffusion.*. The optimizer is built only over
requires_grad=True params, so AdamW allocates no state for frozen ones.
- For the trainer, freeze runs before DDP wrapping so DDP sees the
correct mask.
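A rough sketch of the freeze-then-build-optimizer pattern described above, on a toy model. TinyModel and freeze_trunk_sketch are hypothetical names for this example; the real _freeze_trunk may differ in how it matches parameter names.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Toy model with a 'trunk' and a 'diffusion' submodule."""

    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(4, 4)
        self.diffusion = nn.Linear(4, 2)

    def forward(self, x):
        return self.diffusion(self.trunk(x))


def freeze_trunk_sketch(model, trainable_prefix="diffusion."):
    """Set requires_grad=False on every param outside the given prefix."""
    n_frozen = 0
    for name, param in model.named_parameters():
        if not name.startswith(trainable_prefix):
            param.requires_grad_(False)
            n_frozen += 1
    return n_frozen


model = TinyModel()
n_frozen = freeze_trunk_sketch(model)

# In a distributed run, DDP wrapping would happen here, after freezing.

# Build the optimizer only over params that still require gradients.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-3)

loss = model(torch.randn(3, 4)).sum()
loss.backward()
optimizer.step()
```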
modal/train.py:
- HELICO_TRAIN_DIFFUSION_PAIR_SOURCE / HELICO_TRAIN_FREEZE_TRUNK env
vars threaded through.
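One way the two env vars could be read, as a sketch. The function name read_train_overrides, the defaults, and the accepted truthy strings are assumptions for illustration; only the variable names come from this commit.

```python
import os


def read_train_overrides(env=None):
    """Parse the two env vars into TrainConfig-style fields (hypothetical)."""
    env = os.environ if env is None else env

    pair_source = env.get("HELICO_TRAIN_DIFFUSION_PAIR_SOURCE", "z")
    if pair_source not in ("z", "distogram_logits"):
        raise ValueError(f"unknown diffusion_pair_source: {pair_source!r}")

    # Assumed convention: "1", "true", or "yes" (case-insensitive) enable it.
    freeze_raw = env.get("HELICO_TRAIN_FREEZE_TRUNK", "0")
    freeze_trunk = freeze_raw.strip().lower() in ("1", "true", "yes")

    return {"diffusion_pair_source": pair_source, "freeze_trunk": freeze_trunk}


overrides = read_train_overrides({
    "HELICO_TRAIN_DIFFUSION_PAIR_SOURCE": "distogram_logits",
    "HELICO_TRAIN_FREEZE_TRUNK": "1",
})
```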
Tests (tests/test_diffusion_pair_source.py, 4 new tests):
- Default "z" mode leaves pair_proj_dist with no gradient.
- distogram mode: pair_proj_dist gets gradient, pair_proj does not.
- _freeze_trunk: every non-diffusion param has requires_grad=False
AND zero gradient after backward.
- Distogram-head output is independent of which mode the diffusion
module reads from (sanity that the swap is downstream of the head).
Smoketest: 32-token synthetic batch, distogram mode, freeze_trunk:
0 trunk params with nonzero grad, 227 diffusion params with grad,
finite loss. 173-test suite green.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>