Skip to content

Commit 08990d6

Browse files
timodonnell and claude committed
gh#9: stage_config CLI override fix + smart-init pair_proj_dist + strict=False
Three fixes from the LR-sweep instability (loss in the 1000s, grad norm in the 30k-150k range):

1. StageConfig silently overrode CLI --lr. Training loop called get_lr(step, config, stage.get("lr")) and "stage_lr or config.lr" meant the hardcoded stage_config[lr]=1e-3 always won. So every distinct --lr in the LR sweep ended up running at lr=1e-3. Switch to get_lr(step, config) — config.lr from CLI is now the only knob. Crop-size staging unaffected.

2. pair_proj_dist starts at default-random init, so distogram-mode training feeds garbage into the diffusion module out of the gate. Smart-init: copy pair_proj's relpe-column weights into pair_proj_dist's relpe columns (so the relpe contribution matches z-mode at step 0), zero the distogram-input columns. Diffusion module starts in a near-legacy state and learns to use the distogram signal gradually. Smoketest: diff_loss drops from 1000s to ~13 with this init.

3. load_checkpoint used strict=True, so loading a legacy "z"-mode checkpoint into a distogram-mode model failed on the missing pair_proj_dist keys. Switch to strict=False; return the missing keys list. train() runs the smart init when pair_proj_dist was missing from the loaded state (i.e., resuming from a "z" seed or starting fresh) — skipped when resuming an already-distogram run.

Tests still green (39/39 in test_diffusion_pair_source + test_data). load_checkpoint signature changed to return (step, missing_keys) tuple — callers in this file already updated.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a745e40 commit 08990d6

1 file changed

Lines changed: 76 additions & 8 deletions

File tree

src/helico/train.py

Lines changed: 76 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -152,6 +152,39 @@ def get_lr(step: int, config: TrainConfig, stage_lr: float | None = None) -> flo
152152
# EMA
153153
# ============================================================================
154154

155+
def _init_distogram_proj_from_z(model: nn.Module) -> None:
    """gh#9: warm-start ``pair_proj_dist`` from ``pair_proj``'s relpe weights.

    Without this, training a distogram-mode model from a legacy "z"-mode
    checkpoint starts with random ``pair_proj_dist`` weights — the diffusion
    module sees a totally fresh input distribution and explodes (grad norms
    in the 30k-150k range observed during the gh#9 LR sweep).

    Smart init:
      - distogram-channel weights → 0 (distogram contributes nothing initially)
      - relpe-channel weights → copy of legacy pair_proj's relpe weights
        (relpe contribution matches z-mode at step 0)

    The diffusion module then starts in a near-legacy state and gradually
    learns to use the distogram signal as training proceeds.

    NOTE: this unconditionally overwrites ``pair_proj_dist.weight``. Calling
    it twice in a row is harmless, but calling it on a layer that already
    carries trained weights discards them — callers must skip it when the
    checkpoint supplied ``pair_proj_dist``. Only ``weight`` is touched; any
    bias on either layer is left as-is.

    Args:
        model: The model, or a wrapper exposing the real model as
            ``.module`` (e.g. DDP). Must provide ``config.n_distogram_bins``,
            ``config.d_pair`` and ``diffusion.conditioning`` with
            ``pair_proj`` / ``pair_proj_dist`` layers.

    Raises:
        RuntimeError: If the relpe column blocks of the two layers do not
            have the same shape.
    """
    base = model.module if hasattr(model, "module") else model
    cond = base.diffusion.conditioning
    n_distogram_bins = base.config.n_distogram_bins
    c_z = base.config.d_pair

    # Slice and write entirely under no_grad so the in-place edits on the
    # parameter views are not tracked by autograd.
    with torch.no_grad():
        dist_weight = cond.pair_proj_dist.weight
        # Input-column layout: pair_proj is [z (c_z) | relpe],
        # pair_proj_dist is [distogram (n_distogram_bins) | relpe].
        relpe_src = cond.pair_proj.weight[:, c_z:]
        relpe_dst = dist_weight[:, n_distogram_bins:]

        # copy_ broadcasts, so a mismatched block could be copied silently;
        # fail loudly instead.
        if relpe_dst.shape != relpe_src.shape:
            raise RuntimeError(
                "_init_distogram_proj_from_z: relpe block shape mismatch — "
                f"pair_proj_dist has {tuple(relpe_dst.shape)}, "
                f"pair_proj has {tuple(relpe_src.shape)}"
            )

        dist_weight[:, :n_distogram_bins].zero_()
        relpe_dst.copy_(relpe_src)
        n_relpe = relpe_src.shape[1]

    # Report the number of relpe columns actually copied (not c_z, which is
    # the count of z columns that were skipped).
    logger.info(
        f"_init_distogram_proj_from_z: zeroed {n_distogram_bins} distogram "
        f"input channels; copied {n_relpe} relpe channels from pair_proj"
    )
186+
187+
155188
def _freeze_trunk(model: nn.Module) -> tuple[int, int]:
156189
"""Freeze every parameter outside ``model.diffusion`` (gh#9).
157190
@@ -232,13 +265,28 @@ def load_checkpoint(
232265
model: nn.Module,
233266
optimizer: torch.optim.Optimizer | None = None,
234267
ema: EMAModel | None = None,
235-
) -> int:
268+
) -> tuple[int, list[str]]:
269+
"""Load a checkpoint. Returns ``(start_step, missing_keys)``.
270+
271+
``missing_keys`` is the list of state-dict keys the model expected
272+
that the checkpoint didn't have. Used by gh#9's smart-init: when
273+
pair_proj_dist is in missing_keys, we know we're loading a legacy
274+
"z"-mode seed and need to warm-start the new layers.
275+
"""
236276
state = torch.load(path, map_location="cpu", weights_only=False)
237277

238-
if isinstance(model, DDP):
239-
model.module.load_state_dict(state["model_state_dict"])
240-
else:
241-
model.load_state_dict(state["model_state_dict"])
278+
# strict=False so legacy "z"-mode checkpoints (which lack pair_proj_dist
279+
# / pair_norm_dist) can be loaded into a distogram-mode model.
280+
target = model.module if isinstance(model, DDP) else model
281+
missing, unexpected = target.load_state_dict(
282+
state["model_state_dict"], strict=False,
283+
)
284+
if missing:
285+
logger.info(f"Loaded checkpoint with {len(missing)} missing keys "
286+
f"(first: {missing[:3]})")
287+
if unexpected:
288+
logger.warning(f"Loaded checkpoint had {len(unexpected)} unexpected keys "
289+
f"(first: {unexpected[:3]})")
242290

243291
if optimizer is not None and "optimizer_state_dict" in state:
244292
optimizer.load_state_dict(state["optimizer_state_dict"])
@@ -251,7 +299,7 @@ def load_checkpoint(
251299

252300
step = state.get("step", 0)
253301
logger.info(f"Loaded checkpoint from {path} at step {step}")
254-
return step
302+
return step, list(missing)
255303

256304

257305
# ============================================================================
@@ -460,8 +508,24 @@ def train(
460508

461509
# Resume
462510
start_step = 0
511+
missing_keys: list[str] = []
463512
if resume_path:
464-
start_step = load_checkpoint(resume_path, model, optimizer, ema)
513+
start_step, missing_keys = load_checkpoint(resume_path, model, optimizer, ema)
514+
515+
# gh#9 smart init: when starting a distogram-mode run from a legacy
516+
# "z"-mode checkpoint (or freshly without resume), warm-start
517+
# pair_proj_dist from pair_proj's relpe weights so training doesn't
518+
# explode out of the gate. Skip when resuming a distogram run — the
519+
# checkpoint already has the trained pair_proj_dist.
520+
if config.diffusion_pair_source == "distogram_logits":
521+
loaded_dist = (
522+
resume_path is not None
523+
and not any("pair_proj_dist" in k for k in missing_keys)
524+
)
525+
if not loaded_dist:
526+
_init_distogram_proj_from_z(model)
527+
else:
528+
logger.info("pair_proj_dist loaded from checkpoint — skipping smart init")
465529

466530
# W&B (rank 0 only). Enabled by HELICO_WANDB_ENABLE=1 in env.
467531
wandb_run = None
@@ -526,7 +590,11 @@ def train(
526590

527591
# Update stage-specific settings
528592
stage = stage_config.get_stage(step)
529-
current_lr = get_lr(step, config, stage.get("lr"))
593+
# Always use config.lr (from CLI / TrainConfig) — the per-stage
594+
# ``lr`` field in StageConfig is legacy and silently overrode
595+
# the CLI before, making LR sweeps meaningless. Crop-size
596+
# staging below is unaffected.
597+
current_lr = get_lr(step, config)
530598
for param_group in optimizer.param_groups:
531599
param_group["lr"] = current_lr
532600

0 commit comments

Comments
 (0)