
Padded query positions produce NaN in attention (symmetric pair_mask bug) #2

@timodonnell


Summary

When the collate function pads token sequences (e.g. so that every batch's token count is a cuDNN-flash-attn-friendly multiple of 64), training goes NaN within ~10 steps. The root cause is that `pair_mask` is built as the symmetric outer product of `token_mask`, so every padded query position has ALL of its keys masked. Softmax over a row that is entirely -inf produces NaN, and that NaN propagates through autograd and corrupts the weights.
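
A minimal pure-Python sketch of the failure, assuming a single sample with 3 real tokens and 1 padded position. `symmetric_pair_mask` and `masked_softmax` are hypothetical helpers that mirror `mask.unsqueeze(-1) & mask.unsqueeze(-2)` and a numerically stable masked softmax; they are not code from `src/helico`.

```python
import math

def symmetric_pair_mask(token_mask):
    # pair_mask[i][j] is True only if BOTH query i and key j are real tokens
    # (hypothetical stand-in for mask.unsqueeze(-1) & mask.unsqueeze(-2)).
    return [[qi and kj for kj in token_mask] for qi in token_mask]

def masked_softmax(logits, row_mask):
    # Masked keys get -inf, then a standard max-subtracted softmax.
    masked = [x if keep else -math.inf for x, keep in zip(logits, row_mask)]
    m = max(masked)                            # -inf if every key is masked
    exps = [math.exp(x - m) for x in masked]   # -inf - (-inf) is nan, exp(nan) is nan
    total = sum(exps)
    return [e / total for e in exps]

token_mask = [True, True, True, False]  # 3 real tokens, 1 padded position
pair_mask = symmetric_pair_mask(token_mask)
logits = [0.1, 0.2, 0.3, 0.4]

real_row = masked_softmax(logits, pair_mask[0])    # finite; padded key gets weight 0.0
padded_row = masked_softmax(logits, pair_mask[3])  # every key masked
print(padded_row)  # [nan, nan, nan, nan]
```

The real query row is fine because it still has unmasked keys; only the padded query row, whose mask row is all False, turns into NaN.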

This doesn't bite us today because we always train with `batch_size=1` and skip padding entirely (the per-batch max equals the sample's own `n_tokens`, so no fake positions exist). But it blocks several real workflows:

  • Multi-sample batching: with any `batch_size > 1`, shorter samples are padded to the longest in the batch, and those padded positions can be tens to hundreds of fake tokens.
  • Padding for kernel constraints: cuDNN flash-attn (used by cuequivariance's `attention_pair_bias`) sometimes rejects awkward sequence lengths with "No valid execution plans built". Padding to a kernel-friendly multiple is the standard fix, but it currently triggers this NaN bug. We worked around the crash by wrapping validation forward passes in try/except (commit `85759f0`) so that a failing sample is skipped instead of killing the run, but that is only a tactical fix.
  • DDP load balancing: any plan to bucket-batch by length and pad within buckets is blocked for the same reason.

Reproduction

Until the bug is fixed, the simplest repro is:

  1. In `src/helico/data.py`, change `collate_fn` to round `max_tokens` up to a multiple of 64:
    ```python
    max_tokens = ((max_tokens + 63) // 64) * 64
    ```
  2. Run `HELICO_TRAIN_GPU=H100:1 HELICO_TRAIN_MAX_STEPS=20 HELICO_TRAIN_CROP=256 modal run modal/train.py`
  3. Step 0 logs a sensible loss; by step 10 every metric is NaN, and `eval_metrics: pred coords contain NaN/inf` is logged.

(See smoketest 8 in the 2026-04-22 session — commit `40594c7` introduced the padding, `85759f0` reverted it.)
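
For a sense of how many fake query rows the repro introduces, here is the same rounding arithmetic as the `collate_fn` change in step 1, wrapped in a hypothetical helper `round_up_tokens` (not a function in `src/helico/data.py`). Any sample whose token count is not already a multiple of 64 gains padded positions, and under the symmetric mask each of those is an all-masked query row.

```python
def round_up_tokens(n_tokens: int, multiple: int = 64) -> int:
    # Hypothetical helper: adding (multiple - 1) before floor division
    # rounds up to the next multiple, matching ((max_tokens + 63) // 64) * 64.
    return ((n_tokens + multiple - 1) // multiple) * multiple

for n in (1, 64, 65, 200, 256):
    padded = round_up_tokens(n)
    # e.g. 200 tokens -> 256, i.e. 56 padded query rows
    print(f"{n} -> {padded} ({padded - n} padded positions)")
```

Lengths that are already multiples of 64 (64, 256) get no padding, which is why the bug only shows up for some samples.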

Where to look

  • `src/helico/model.py`: search for where `pair_mask` is constructed. It is currently around `pair_mask = (mask.unsqueeze(-1) & mask.unsqueeze(-2)).float()` in `Helico.forward`; this is the symmetric mask.
  • `src/helico/model.py`: the `attention_pair_bias` call sites in the pairformer and the diffusion module pass that `pair_mask` through.
  • The cuequivariance kernel itself (`cuequivariance_ops_torch/attention_pair_bias_torch.py:752`) calls `torch.nn.functional.scaled_dot_product_attention` and assumes the `attn_mask` follows standard SDPA conventions.

Suggested fix

Use an asymmetric mask for attention: mask only keys, not queries. Padded query positions then produce some finite output (which is discarded downstream by the output mask), and as long as each sample has at least one real token, every softmax row has at least one valid key to attend to.
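
A minimal pure-Python sketch of the key-only variant, under the same toy assumptions as before (one sample, 3 real tokens, 1 padded position). `key_only_mask` and `masked_softmax` are hypothetical helpers; `key_only_mask` is meant to mimic broadcasting `mask.unsqueeze(-2)` over the query axis, not cuequivariance's actual mask layout.

```python
import math

def key_only_mask(token_mask):
    # Row i (query), column j (key) is True iff key j is a real token.
    # Every row is identical, so padded queries still see all real keys.
    return [list(token_mask) for _ in token_mask]

def masked_softmax(logits, row_mask):
    # Masked keys get -inf, then a standard max-subtracted softmax.
    masked = [x if keep else -math.inf for x, keep in zip(logits, row_mask)]
    m = max(masked)
    exps = [math.exp(x - m) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]

token_mask = [True, True, True, False]  # 3 real tokens, 1 padded position
logits = [0.1, 0.2, 0.3, 0.4]
km = key_only_mask(token_mask)

real_row = masked_softmax(logits, km[0])
padded_row = masked_softmax(logits, km[3])
# The padded query now gets finite weights over the real keys;
# its output is meaningless but finite, and the output mask drops it.
```

The design choice is that masking queries buys nothing for correctness (their outputs are discarded anyway), while masking them is exactly what creates the all -inf rows.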

Concretely:

  • Replace the symmetric `pair_mask = mask⊗mask` at attention call sites with `key_mask = mask.unsqueeze(-2).expand_as(pair)` (or whatever shape cuequivariance expects for its additive bias mask).
  • Verify the math: padded queries' outputs must not feed into anything that contributes to the loss (the loss is masked by `token_mask` / `atom_mask`).
  • Add a unit test: collate two structures of different sizes with deliberate padding, run a forward + backward pass, and assert that no parameter gradient contains NaN.
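
The unit test idea can be prototyped without the real model. The sketch below is a forward-only toy in pure Python: `pad_batch`, `attend`, `masked_loss` and `forward_is_finite` are all hypothetical, and the toy uses scalar features with `logit(i, j) = x[i] * x[j]`. It also shows why the loss mask alone cannot save us: `nan * 0.0` is still `nan`, so a NaN in a padded row poisons a multiplicatively masked loss. The real test would instead run `loss.backward()` on the actual model and check every `p.grad` with `torch.isfinite`.

```python
import math

def pad_batch(samples, multiple):
    # Hypothetical collate: pad to the longest sample, rounded up to `multiple`.
    longest = max(len(s) for s in samples)
    n = ((longest + multiple - 1) // multiple) * multiple
    feats = [s + [0.0] * (n - len(s)) for s in samples]
    masks = [[True] * len(s) + [False] * (n - len(s)) for s in samples]
    return feats, masks

def symmetric_pair_mask(m):
    # Query AND key must be real (the current, buggy behavior).
    return [[a and b for b in m] for a in m]

def key_only_mask(m):
    # Only keys are masked (the suggested fix).
    return [list(m) for _ in m]

def attend(x, pair_mask):
    # Toy single-head attention over scalar features.
    out = []
    for i, row_mask in enumerate(pair_mask):
        logits = [x[i] * xj if keep else -math.inf for xj, keep in zip(x, row_mask)]
        m = max(logits)
        w = [math.exp(l - m) for l in logits]
        z = sum(w)
        out.append(sum(wj / z * xj for wj, xj in zip(w, x)))
    return out

def masked_loss(out, token_mask):
    # Loss masked by multiplication; nan * 0.0 is still nan.
    terms = [o * o * float(keep) for o, keep in zip(out, token_mask)]
    return sum(terms) / sum(token_mask)

def forward_is_finite(samples, multiple, mask_fn):
    feats, masks = pad_batch(samples, multiple)
    for x, m in zip(feats, masks):
        if not math.isfinite(masked_loss(attend(x, mask_fn(m)), m)):
            return False
    return True

# Two structures of different sizes (3 and 5 tokens), padded to 8.
samples = [[0.5, -1.0, 2.0], [1.0, 0.3, -0.7, 0.2, 1.5]]
print(forward_is_finite(samples, 4, symmetric_pair_mask))  # False
print(forward_is_finite(samples, 4, key_only_mask))        # True
```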

Additional context

  • cuequivariance hardcodes its SDPA backend list to `[CUDNN_ATTENTION, FLASH_ATTENTION, EFFICIENT_ATTENTION]` with `set_priority=True` and omits MATH. So when cuDNN rejects a shape and the other two backends reject it as well, there is no fallback. We can't easily change that from outside without monkey-patching (which `AGENTS.md` rules out).
  • The workaround attempts for the val crash that originally led us here are commits `727b52a`, `08c81fd`, `177e9a6` and `a3d5ada`; none of them succeeded, and the only thing that worked was the try/except skip in `85759f0`.

🤖 Generated with Claude Code
