Skip to content

Commit 08307ab

Browse files
timodonnellclaude
andcommitted
gh#9: diffusion_pair_source flag + freeze_trunk helper
Implements the bottlenecked-conditioning experiment from issue #9: swap the diffusion module's pair input from the trunk's final pair representation z (B, N_tok, N_tok, 128) to the distogram-head logits (B, N_tok, N_tok, 64). Freeze the trunk; train only the diffusion module from this lower-rank signal. HelicoConfig: - New diffusion_pair_source: "z" (default, legacy) | "distogram_logits". DiffusionConditioning (the only place the swap is needed — by the time z reaches the AtomAttentionEncoder it's already z_cond, the post-conditioning 128-d tensor): - Parallel pair_norm_dist + pair_proj_dist sized for n_distogram_bins + d_pair input. Always present so checkpoints round-trip; only used when config.diffusion_pair_source == "distogram_logits". Helico.forward / Helico.predict: - Run distogram_head before diffusion when in distogram mode; pass detached logits to diffusion as z_trunk arg. detach() so the trunk graph isn't pinned through the diffusion backward when the trunk is frozen (memory hygiene from the issue's compute estimate). train.py: - TrainConfig fields diffusion_pair_source + freeze_trunk; CLI args --diffusion-pair-source / --freeze-trunk. - New _freeze_trunk(model) helper: requires_grad=False on every param outside model.diffusion.*. Optimizer is built only over requires_grad=True params so AdamW state doesn't grow uselessly. - For the trainer, freeze runs before DDP wrapping so DDP sees the correct mask. modal/train.py: - HELICO_TRAIN_DIFFUSION_PAIR_SOURCE / HELICO_TRAIN_FREEZE_TRUNK env vars threaded through. Tests (tests/test_diffusion_pair_source.py, 4 new tests): - Default "z" mode leaves pair_proj_dist with no gradient. - distogram mode: pair_proj_dist gets gradient, pair_proj does not. - _freeze_trunk: every non-diffusion param has requires_grad=False AND zero gradient after backward. - Distogram-head output is independent of which mode the diffusion module reads from (sanity that the swap is downstream of the head). 
Smoketest: 32-token synthetic batch, distogram mode, freeze_trunk: 0 trunk params with nonzero grad, 227 diffusion params with grad, finite loss. 173-test suite green. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent ef0fa65 commit 08307ab

6 files changed

Lines changed: 235 additions & 13 deletions

File tree

modal/train.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
HELICO_TRAIN_VAL_EVERY=0 # 0 disables; e.g. 500 runs val every 500 steps
1414
HELICO_TRAIN_VAL_SAMPLES=32
1515
HELICO_TRAIN_N_DIFFUSION_SAMPLES=8 # Diffusion noise samples per trunk forward (gh#6)
16+
HELICO_TRAIN_DIFFUSION_PAIR_SOURCE=z # "z" or "distogram_logits" (gh#9)
17+
HELICO_TRAIN_FREEZE_TRUNK=0 # 1 = freeze trunk, train only diffusion (gh#9)
1618
HELICO_TRAIN_RESUME= # /ckpts/<run>/step_<N>.pt to resume
1719
HELICO_TRAIN_PROTENIX_INIT=1 # warm-start from Protenix v1 weights
1820
HELICO_TRAIN_CUTOFF=2021-09-30 # train = release_date < this (AF3/Protenix/OF3 shared cutoff)
@@ -112,6 +114,8 @@ def _env_float(name: str, default: float) -> float:
112114
"val_every": _env_int("HELICO_TRAIN_VAL_EVERY", 0),
113115
"val_samples": _env_int("HELICO_TRAIN_VAL_SAMPLES", 32),
114116
"n_diffusion_samples": _env_int("HELICO_TRAIN_N_DIFFUSION_SAMPLES", 8),
117+
"diffusion_pair_source": os.environ.get("HELICO_TRAIN_DIFFUSION_PAIR_SOURCE", "z"),
118+
"freeze_trunk": os.environ.get("HELICO_TRAIN_FREEZE_TRUNK", "0") == "1",
115119
"resume_from": os.environ.get("HELICO_TRAIN_RESUME", ""),
116120
"protenix_init": os.environ.get("HELICO_TRAIN_PROTENIX_INIT", "1") == "1",
117121
"train_cutoff": os.environ.get("HELICO_TRAIN_CUTOFF", "2021-09-30"),
@@ -190,6 +194,7 @@ def train_remote(args: dict) -> dict:
190194
"--val-every", str(args["val_every"]),
191195
"--val-samples", str(args["val_samples"]),
192196
"--n-diffusion-samples", str(args["n_diffusion_samples"]),
197+
"--diffusion-pair-source", args["diffusion_pair_source"],
193198
"--checkpoint-dir", str(run_ckpt_dir),
194199
"--train-cutoff", args["train_cutoff"],
195200
"--val-cutoff-start", args["val_cutoff_start"],
@@ -199,6 +204,8 @@ def train_remote(args: dict) -> dict:
199204
base_cli += ["--msa-dir", str(msa_dir)]
200205
if resume_from:
201206
base_cli += ["--resume", resume_from]
207+
if args.get("freeze_trunk"):
208+
base_cli += ["--freeze-trunk"]
202209

203210
if n_gpus > 1:
204211
cmd = [

src/helico/model/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,19 @@ class HelicoConfig:
7777
# (amortize the expensive trunk — AF3 SI §3.7.1 Fig 2c).
7878
n_diffusion_samples: int = 8
7979

80+
# --- Pair source for diffusion conditioning (gh#9) ---
81+
# "z" (default): use the trunk's final pair representation
82+
# z_trunk : (B, N_tok, N_tok, d_pair)
83+
# "distogram_logits": substitute the trunk's distogram-head output
84+
# (B, N_tok, N_tok, n_distogram_bins). Forces an
85+
# information bottleneck at the trunk → diffusion
86+
# interface; intended for gh#9-style experiments
87+
# where we freeze the trunk and only retrain the
88+
# diffusion module from this lower-rank signal.
89+
# The diffusion module gains a parallel set of input projections sized
90+
# for the alternate channel count; only one path is active per forward.
91+
diffusion_pair_source: str = "z"
92+
8093
# --- Atom feature dims (from AF3 SI §2.8 Table 5) ---
8194
n_elements: int = UNK_ELEM_IDX + 1 # Number of element types + 1 UNK
8295
n_token_types: int = NUM_TOKEN_TYPES

src/helico/model/diffusion.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ class AtomAttentionEncoder(nn.Module):
288288
def __init__(self, config, has_coords: bool = True,
289289
c_token_override: int | None = None):
290290
super().__init__()
291+
self.config = config
291292
c = config
292293
c_atom = c.c_atom
293294
c_atompair = c.c_atompair
@@ -309,7 +310,12 @@ def __init__(self, config, has_coords: bool = True,
309310
# Noisy coords projection
310311
self.noisy_pos_proj = linear_no_bias(3, c_atom)
311312

312-
# Trunk s,z injection (zero-init → no-op at start of training)
313+
# Trunk s,z injection (zero-init → no-op at start of training).
314+
# NOTE: in DiffusionModule, this z_trunk arg is actually z_cond
315+
# (DiffusionConditioning output, always c_z), not the raw trunk
316+
# pair tensor. So gh#9's distogram swap only needs to happen
317+
# inside DiffusionConditioning — by the time z reaches the atom
318+
# encoder it's already been projected back to c_z.
313319
self.trunk_s_norm = LayerNorm(c_s)
314320
self.trunk_s_proj = linear_no_bias(c_s, c_atom, zeros_init=True)
315321
self.trunk_z_norm = LayerNorm(c_z)
@@ -413,7 +419,9 @@ def forward(
413419
p = p + self.pair_valid_proj(v_lm)
414420
p = p * pad_mask.unsqueeze(0).unsqueeze(-1).to(diff.dtype)
415421

416-
# 5. Trunk pair injection (windowed gather of token-pair z into atom-pair)
422+
# 5. Trunk pair injection (windowed gather of token-pair z into atom-pair).
423+
# ``z_trunk`` here is z_cond (post DiffusionConditioning, always c_z
424+
# channels) — see note in __init__.
417425
if self.has_coords and z_trunk is not None:
418426
z_trunk_proj = self.trunk_z_proj(self.trunk_z_norm(z_trunk))
419427
z_windowed = self._gather_trunk_pair_windowed(
@@ -536,6 +544,7 @@ class DiffusionConditioning(nn.Module):
536544

537545
def __init__(self, config):
538546
super().__init__()
547+
self.config = config
539548
c = config
540549
c_s = c.d_single
541550
c_z = c.d_pair
@@ -547,6 +556,13 @@ def __init__(self, config):
547556
self.pair_transition_1 = Transition(c_z, factor=2)
548557
self.pair_transition_2 = Transition(c_z, factor=2)
549558

559+
# gh#9: parallel pair input for diffusion_pair_source="distogram_logits".
560+
# Input is concat(distogram_logits, relpe) — c.n_distogram_bins + c_z.
561+
# Always present so checkpoints from "z" mode round-trip; only
562+
# active when config.diffusion_pair_source == "distogram_logits".
563+
self.pair_norm_dist = nn.LayerNorm(c.n_distogram_bins + c_z)
564+
self.pair_proj_dist = linear_no_bias(c.n_distogram_bins + c_z, c_z)
565+
550566
# Single path
551567
self.fourier = FourierEmbedding(c.c_noise_embedding)
552568
self.s_inputs_dim = c_s + 65 # s_inputs dim = c_s (from atom encoder) + 65
@@ -566,9 +582,16 @@ def forward(self, s_trunk: torch.Tensor, z_trunk: torch.Tensor,
566582
"""
567583
sigma_data = 16.0 # EDM constant (σ_data)
568584

569-
# Pair conditioning: concat(z_trunk, relpe) → norm → linear → 2x Transition
585+
# Pair conditioning: concat(z_trunk, relpe) → norm → linear → 2x Transition.
586+
# gh#9: in "distogram_logits" mode, z_trunk is actually the distogram
587+
# logits (B, N, N, n_distogram_bins) and we use the parallel
588+
# pair_proj_dist sized for that channel count.
570589
relpe = self.relpe(**relpe_feats)
571-
z = self.pair_proj(self.pair_norm(torch.cat([z_trunk, relpe], dim=-1)))
590+
z_in = torch.cat([z_trunk, relpe], dim=-1)
591+
if self.config.diffusion_pair_source == "distogram_logits":
592+
z = self.pair_proj_dist(self.pair_norm_dist(z_in))
593+
else:
594+
z = self.pair_proj(self.pair_norm(z_in))
572595
z = z + self.pair_transition_1(z)
573596
z = z + self.pair_transition_2(z)
574597

src/helico/model/helico.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,24 @@ def forward(
138138

139139
results = {"single": s, "pair": z}
140140

141-
# 4. Diffusion — s_inputs is already (B, N_tok, 449 = d_single + 65)
141+
# 4a. Distogram (always computed; needs to be available *before*
142+
# diffusion when diffusion_pair_source == "distogram_logits" so the
143+
# diffusion module can read from it. distogram_head is itself a
144+
# single Linear (z → 64-bin logits, symmetrized).
145+
distogram_logits = self.distogram_head(z)
146+
results["distogram_logits"] = distogram_logits
147+
148+
# 4b. Diffusion — s_inputs is already (B, N_tok, 449 = d_single + 65)
142149
# n_diffusion_samples > 1 amortizes the expensive trunk over several
143150
# denoising passes per batch entry (gh#6). Outputs are (B*N_d, ...).
151+
# gh#9: when configured, swap z_trunk for the trunk's distogram
152+
# output (information bottleneck). detach() ensures the trunk
153+
# graph isn't pinned through the diffusion backward when the
154+
# trunk is frozen — saves activation memory.
155+
if self.config.diffusion_pair_source == "distogram_logits":
156+
z_for_diffusion = distogram_logits.detach()
157+
else:
158+
z_for_diffusion = z
144159
n_d = max(1, self.config.n_diffusion_samples)
145160
x_denoised, gt_coords, sigma = self.diffusion.forward_training(
146161
gt_coords=batch["atom_coords"],
@@ -150,7 +165,7 @@ def forward(
150165
atom_to_token=batch["atom_to_token"],
151166
atom_mask=atom_mask,
152167
s_trunk=s,
153-
z_trunk=z,
168+
z_trunk=z_for_diffusion,
154169
s_inputs=s_inputs,
155170
relpe_feats=relpe_feats,
156171
n_samples=n_d,
@@ -163,10 +178,6 @@ def forward(
163178
atom_mask_d = atom_mask.repeat_interleave(n_d, dim=0) if n_d > 1 else atom_mask
164179
results["diffusion_loss"] = diffusion_loss(x_denoised, gt_coords, sigma, atom_mask_d)
165180

166-
# 5. Distogram (from trunk pair)
167-
distogram_logits = self.distogram_head(z)
168-
results["distogram_logits"] = distogram_logits
169-
170181
# 6. Confidence head (uses pred_coords from diffusion). Use only
171182
# the first denoising sample per batch entry — the head expects
172183
# (B, N_atoms, 3), not (B*N_d, ...).
@@ -285,6 +296,13 @@ def _expand(t):
285296
return t.unsqueeze(1).expand(-1, n_samples, *[-1] * (t.dim() - 1)).reshape(B * n_samples, *t.shape[1:])
286297

287298
ref_space_uid = batch.get("ref_space_uid")
299+
# gh#9: same swap as in forward — at inference, when the diffusion
300+
# module is configured to read from the distogram, run the head
301+
# before sampling and feed those logits in place of z.
302+
if self.config.diffusion_pair_source == "distogram_logits":
303+
z_for_diffusion = self.distogram_head(z)
304+
else:
305+
z_for_diffusion = z
288306
t_diffusion_start = _sync_time()
289307
batched_coords = self.diffusion.sample(
290308
ref_pos=_expand(batch["ref_coords"]),
@@ -293,7 +311,7 @@ def _expand(t):
293311
atom_to_token=_expand(batch["atom_to_token"]),
294312
atom_mask=_expand(atom_mask),
295313
s_trunk=_expand(s),
296-
z_trunk=_expand(z),
314+
z_trunk=_expand(z_for_diffusion),
297315
s_inputs=_expand(s_inputs),
298316
relpe_feats={k: _expand(v) for k, v in relpe_feats.items()},
299317
ref_space_uid=_expand(ref_space_uid) if ref_space_uid is not None else None,

src/helico/train.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ class TrainConfig:
9797
# DDP
9898
distributed: bool = False
9999

100+
# gh#9: pair source for diffusion conditioning ("z" or "distogram_logits"),
101+
# plus a flag to freeze the trunk so only the diffusion module trains.
102+
diffusion_pair_source: str = "z"
103+
freeze_trunk: bool = False
104+
100105
def get_torch_dtype(self) -> torch.dtype:
101106
return {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[self.dtype]
102107

@@ -147,6 +152,29 @@ def get_lr(step: int, config: TrainConfig, stage_lr: float | None = None) -> flo
147152
# EMA
148153
# ============================================================================
149154

155+
def _freeze_trunk(model: nn.Module) -> tuple[int, int]:
156+
"""Freeze every parameter outside ``model.diffusion`` (gh#9).
157+
158+
The convention here is "trunk" = everything except the diffusion module:
159+
input embedder, trunk linears, MSA module, pairformer, distogram head,
160+
confidence head, template embedder. The diffusion module's two new
161+
distogram-input projections (``pair_proj_dist``, ``trunk_z_proj_dist``)
162+
live under ``model.diffusion.*`` and stay trainable automatically.
163+
164+
Returns (n_frozen, n_trainable) parameter counts.
165+
"""
166+
base = model.module if hasattr(model, "module") else model
167+
n_frozen = n_trainable = 0
168+
for name, param in base.named_parameters():
169+
if name.startswith("diffusion."):
170+
param.requires_grad = True
171+
n_trainable += param.numel()
172+
else:
173+
param.requires_grad = False
174+
n_frozen += param.numel()
175+
return n_frozen, n_trainable
176+
177+
150178
class EMAModel:
151179
"""Exponential Moving Average of model weights."""
152180

@@ -390,16 +418,25 @@ def train(
390418

391419
model = model.to(device)
392420

421+
# gh#9: freeze the trunk so only the diffusion module trains. Done
422+
# before DDP wrapping so DDP sees the right requires_grad mask.
423+
if config.freeze_trunk:
424+
n_frozen, n_trainable = _freeze_trunk(model)
425+
logger.info(
426+
f"freeze_trunk=True: froze {n_frozen:,} params, "
427+
f"{n_trainable:,} remain trainable"
428+
)
429+
393430
if config.distributed:
394431
# find_unused_parameters=True: conditionally-used sub-modules (e.g.
395432
# MSA paths when no MSA is present) don't receive gradients on
396433
# every batch. Without this flag, DDP's all-reduce deadlocks with
397434
# "Expected to have finished reduction" on step 1.
398435
model = DDP(model, device_ids=[device], find_unused_parameters=True)
399436

400-
# Optimizer
437+
# Optimizer (skip frozen params so AdamW state doesn't grow uselessly).
401438
optimizer = torch.optim.AdamW(
402-
model.parameters(),
439+
[p for p in model.parameters() if p.requires_grad],
403440
lr=config.lr,
404441
weight_decay=config.weight_decay,
405442
betas=(0.9, 0.999),
@@ -741,6 +778,11 @@ def main():
741778
parser.add_argument("--n-diffusion-token-blocks", type=int, default=24, help="Number of diffusion token transformer blocks")
742779
parser.add_argument("--n-diffusion-samples", type=int, default=8,
743780
help="Diffusion noise samples per trunk forward (gh#6). 1 = legacy.")
781+
parser.add_argument("--diffusion-pair-source", type=str, default="z",
782+
choices=["z", "distogram_logits"],
783+
help="Pair conditioning source for the diffusion module (gh#9).")
784+
parser.add_argument("--freeze-trunk", action="store_true",
785+
help="Freeze the trunk (gh#9). Only the diffusion module trains.")
744786
parser.add_argument("--crop-size", type=int, default=384, help="Initial crop size")
745787
parser.add_argument("--batch-size", type=int, default=1, help="Batch size per GPU")
746788
parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
@@ -798,12 +840,15 @@ def main():
798840
val_samples=args.val_samples,
799841
checkpoint_dir=args.checkpoint_dir,
800842
distributed=args.distributed,
843+
diffusion_pair_source=args.diffusion_pair_source,
844+
freeze_trunk=args.freeze_trunk,
801845
)
802846

803847
model_config = HelicoConfig(
804848
n_pairformer_blocks=args.n_blocks,
805849
n_diffusion_token_blocks=args.n_diffusion_token_blocks,
806850
n_diffusion_samples=args.n_diffusion_samples,
851+
diffusion_pair_source=args.diffusion_pair_source,
807852
)
808853

809854
model = Helico(model_config)

0 commit comments

Comments
 (0)