Skip to content

Fix batched supervised chi periodicity loss#573

Open
taivu1998 wants to merge 1 commit into
aqlaboratory:mainfrom
taivu1998:tdv/issue-381-supervised-chi-batch
Open

Fix batched supervised chi periodicity loss#573
taivu1998 wants to merge 1 commit into
aqlaboratory:mainfrom
taivu1998:tdv/issue-381-supervised-chi-batch

Conversation

@taivu1998
Copy link
Copy Markdown

Summary

Fixes #381.

This PR fixes batched chi-periodicity handling in supervised_chi_loss(). The previous einsum used to construct chi_pi_periodic omitted the ellipsis from the output equation:

"...ij,jk->ik"

For batched aatype tensors, that reduces leading batch dimensions and produces a [N, 4] periodicity mask instead of [*batch_dims, N, 4]. As a result, side-chain chi periodicity could be shared across different examples at the same residue index.

Changes

  • Replace the one-hot/einsum construction with direct residue-constant table lookup:
    • chi_pi_periodic = table[aatype, ...]
  • Preserve dtype/device behavior by constructing the table from angles_sin_cos.new_tensor(...).
  • Correct the local supervised_chi_loss() docstring for chi-only ground-truth tensors.
  • Add a CPU regression test that mixes ASP and ARG at the same batch/residue position:
    • ASP chi2 is pi-periodic, so a pi-shifted prediction should have zero chi loss.
    • ARG chi2 is not pi-periodic, so the same prediction should retain squared error.

Validation

  • PYTHONPATH=/private/tmp/openfold-issue-381-test-stubs:$PYTHONPATH /private/tmp/openfold-issue-381-venv/bin/python -m pytest tests/test_loss.py -k "supervised_chi_loss" -q
    • 1 passed, 1 skipped, 21 deselected
    • The skipped test is the existing AlphaFold/JAX parity test, skipped because AlphaFold/JAX is not installed in this local environment.
  • python -m py_compile openfold/utils/loss.py tests/test_loss.py
  • git diff --check
  • Direct behavioral check:
    • old mask shape: (1, 4), old loss: 0.0
    • fixed mask shape: (2, 1, 4), fixed loss: 1.9997999668121338

Note: the local macOS environment does not have OpenFold's compiled attn_core_inplace_cuda extension. For the focused pytest selection, I used a temporary import stub outside the repository so the loss test module could collect; the stub was not executed by the supervised-chi tests.

@taivu1998 taivu1998 marked this pull request as ready for review May 11, 2026 03:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Incorrect (batched) einsum in supervised_chi_loss()?

1 participant