Summary
When the collate function pads token sequences (e.g. so all batches reach a cuDNN-flash-attn-friendly multiple of 64 tokens), training goes NaN within ~10 steps. The root cause is that `pair_mask` is built as the symmetric outer product of `token_mask`, so every padded query position has all of its keys masked. Softmax over a row of all `-inf` produces NaN, and the NaN propagates through autograd to corrupt the weights.
This currently doesn't bite us because we always train with `batch_size=1` and skip padding entirely (the per-batch max equals the actual sample's `n_tokens`, so no fake positions exist). But it blocks several real workflows:
- Multi-sample batching: any `batch_size > 1` will have shorter samples padded to the longest in the batch, which can add tens to hundreds of fake tokens.
- Padding for kernel constraints: cuDNN flash-attn (used by cuequivariance's `attention_pair_bias`) sometimes rejects awkward sequence lengths with "No valid execution plans built". Padding to a kernel-friendly multiple is the standard fix, but currently triggers this NaN bug. We worked around it by wrapping val forwards in try/except (commit `85759f0`) so a crashing sample is skipped rather than killing the run, but that is a tactical workaround, not a fix.
- DDP load balancing: any plan to bucket-batch by length and pad within buckets is similarly blocked.
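The failure mode described in the summary can be reproduced in isolation with a few lines of PyTorch. This is a minimal sketch, not code from `src/helico/model.py`: the 4-token mask and the all-zero attention scores are made up for illustration, and only the mask construction mirrors the symmetric outer product.

```python
import torch

# Hypothetical batch: one sample with 3 real tokens and 1 padded position.
token_mask = torch.tensor([[True, True, True, False]])  # (B, N)

# Symmetric mask: entry (i, j) is valid only if BOTH query i and key j are real.
pair_mask = token_mask.unsqueeze(-1) & token_mask.unsqueeze(-2)  # (B, N, N)

# Stand-in attention logits; their values do not matter for the failure.
scores = torch.zeros(1, 4, 4)
masked_scores = scores.masked_fill(~pair_mask, float("-inf"))
attn = torch.softmax(masked_scores, dim=-1)

# Real query rows are fine; the padded query row is all -inf, so softmax gives NaN.
real_rows_ok = bool(torch.isfinite(attn[0, :3]).all())
padded_row_is_nan = bool(torch.isnan(attn[0, 3]).all())
print(real_rows_ok, padded_row_is_nan)
```

Multiplying the padded row by zero afterwards does not help, since `NaN * 0` is still NaN, which is why masking the output alone is not enough.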
Reproduction
Until the bug is fixed, the simplest repro is:
- In `src/helico/data.py`, change `collate_fn` to round `max_tokens` up to a multiple of 64:
```python
max_tokens = ((max_tokens + 63) // 64) * 64
```
- Run `HELICO_TRAIN_GPU=H100:1 HELICO_TRAIN_MAX_STEPS=20 HELICO_TRAIN_CROP=256 modal run modal/train.py`
- Step 0 logs sensible loss; by step 10 every metric is NaN. `eval_metrics: pred coords contain NaN/inf` is logged.
(See smoketest 8 in the 2026-04-22 session — commit `40594c7` introduced the padding, `85759f0` reverted it.)
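For reference, the effect of that rounding on a batch can be sketched as follows. `pad_to_multiple` is a hypothetical helper written for this issue, not the real `collate_fn`; the feature shapes are assumptions chosen only to show how many fake positions appear.

```python
import torch


def pad_to_multiple(features, multiple=64):
    """Pad per-sample (n_tokens, dim) tensors to a shared, rounded-up length.

    Hypothetical helper for illustration; not the real collate_fn.
    """
    max_tokens = max(f.shape[0] for f in features)
    # Same rounding as in the repro step: round up to the next multiple.
    max_tokens = ((max_tokens + multiple - 1) // multiple) * multiple
    dim = features[0].shape[1]
    batch = torch.zeros(len(features), max_tokens, dim)
    token_mask = torch.zeros(len(features), max_tokens, dtype=torch.bool)
    for i, f in enumerate(features):
        batch[i, : f.shape[0]] = f
        token_mask[i, : f.shape[0]] = True  # True marks real tokens
    return batch, token_mask


# Two made-up samples with 70 and 33 tokens and 8 feature channels.
features = [torch.randn(70, 8), torch.randn(33, 8)]
batch, token_mask = pad_to_multiple(features)
print(batch.shape, token_mask.sum(dim=1))
```

Here the longest sample has 70 tokens, so both samples are padded to 128 and every padded position becomes a query row with no valid keys under the symmetric mask.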
Where to look
- `src/helico/model.py` — search for `pair_mask` construction. Currently around `pair_mask = (mask.unsqueeze(-1) & mask.unsqueeze(-2)).float()` in `Helico.forward`. This is the symmetric mask.
- `src/helico/model.py` — `attention_pair_bias` call sites in the pairformer and diffusion module pass that pair_mask through.
- The cuequivariance kernel itself (`cuequivariance_ops_torch/attention_pair_bias_torch.py:752`) calls `torch.nn.functional.scaled_dot_product_attention` and assumes the `attn_mask` follows standard SDPA conventions.
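As a reminder of the SDPA conventions referred to above: PyTorch documents that a boolean `attn_mask` marks positions that take part in attention with `True`, while a floating-point `attn_mask` is added to the attention scores. The sketch below checks that the two forms agree for a key-only mask. The tensor shapes and values are made up, and whether cuequivariance forwards the mask to SDPA unchanged is an assumption to verify against its source.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, D = 1, 2, 4, 8
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))

# Hypothetical sample: 3 real tokens, 1 padded position.
token_mask = torch.tensor([[True, True, True, False]])  # (B, N)

# Key-only boolean mask: True means "this key may be attended to".
key_mask = token_mask[:, None, None, :].expand(B, H, N, N)

# Equivalent additive mask: 0 for allowed keys, -inf for masked keys.
additive_mask = torch.zeros(B, H, N, N).masked_fill(~key_mask, float("-inf"))

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=key_mask)
out_additive = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)

# Every query row keeps three valid keys, so both outputs are finite and agree.
same = torch.allclose(out_bool, out_additive, atol=1e-5)
finite = bool(torch.isfinite(out_bool).all())
print(same, finite)
```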
Suggested fix
Use an asymmetric mask for attention: only mask keys, not queries. Padded query positions then produce some output (which gets discarded downstream by the output mask), and softmax always has at least one valid key to attend to, provided every sample contains at least one real token.
Concretely:
- Replace the symmetric `pair_mask = mask⊗mask` at attention call sites with `key_mask = mask.unsqueeze(-2).expand_as(pair)` (or however cuequivariance wants the bias additive mask shaped).
- Verify the math: padded queries' outputs aren't fed into anything that contributes to the loss (loss is masked by `token_mask` / `atom_mask`).
- Add a unit test: collate two structures of different sizes with deliberate padding, run a forward + backward, assert no NaN in any param gradient.
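A rough shape for such a test is sketched below. `key_only_attention` and the single linear projection are hypothetical stand-ins for the real model; an actual test would collate two structures and call `Helico.forward`. The point is only to show the key-only mask and the gradient check.

```python
import torch


def key_only_attention(x, token_mask, proj):
    """Toy single-head attention that masks keys but not queries (illustrative)."""
    q = k = v = proj(x)  # (B, N, D)
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5  # (B, N, N)
    key_mask = token_mask.unsqueeze(-2)  # (B, 1, N), broadcasts over queries
    scores = scores.masked_fill(~key_mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)  # every row keeps its real keys
    return attn @ v


def test_padded_batch_has_finite_grads():
    torch.manual_seed(0)
    lengths = torch.tensor([5, 3])  # two samples of different sizes
    n_max = 8  # deliberately longer than both samples
    x = torch.randn(2, n_max, 4)
    token_mask = torch.arange(n_max).unsqueeze(0) < lengths.unsqueeze(1)  # (B, N)
    proj = torch.nn.Linear(4, 4)

    out = key_only_attention(x, token_mask, proj)
    # Loss is masked by token_mask, so padded query outputs are discarded.
    loss = (out.pow(2).sum(-1) * token_mask).sum() / token_mask.sum()
    loss.backward()

    assert torch.isfinite(loss)
    for p in proj.parameters():
        assert torch.isfinite(p.grad).all()


if __name__ == "__main__":
    test_padded_batch_has_finite_grads()
```

Swapping the key-only mask for the symmetric `mask.unsqueeze(-1) & mask.unsqueeze(-2)` in this toy function should make the same assertions fail, which gives a quick sanity check that the test actually exercises the bug.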
Additional context
- cuequivariance hardcodes its SDPA backend list to `[CUDNN_ATTENTION, FLASH_ATTENTION, EFFICIENT_ATTENTION]` with `set_priority=True` and omits MATH. So when cuDNN rejects a shape and the others also reject, there's no fallback. We can't easily change that from outside without monkey-patching (which `AGENTS.md` rules out).
- The val crash that originally led us here is in commits `727b52a`, `08c81fd`, `177e9a6`, `a3d5ada` — multiple workaround attempts, all unsuccessful; the only thing that worked was the try/except skip in `85759f0`.
🤖 Generated with Claude Code