Skip to content

SwinUNETR: optional flash attention (scaled_dot_product_attention) in WindowAttention#8977

Open
aymuos15 wants to merge 2 commits into
Project-MONAI:devfrom
aymuos15:perf/swinunetr-flash-attention
Open

SwinUNETR: optional flash attention (scaled_dot_product_attention) in WindowAttention#8977
aymuos15 wants to merge 2 commits into
Project-MONAI:devfrom
aymuos15:perf/swinunetr-flash-attention

Conversation

@aymuos15

@aymuos15 aymuos15 commented Jul 2, 2026

Copy link
Copy Markdown
Contributor

Fixes #8973 .

Description

SwinUNETR's WindowAttention builds the full (nWindows*heads, N, N) score matrix by hand before softmax. This adds an opt-in use_flash_attention flag (default False, so existing behaviour is unchanged) that routes attention through torch.nn.functional.scaled_dot_product_attention, folding the relative position bias and, for shifted windows, the attention mask into one additive attn_mask cast to the query dtype. This mirrors the flash-attention option already in MONAI's SelfAttention, CrossAttention and CABlock.

Measured on the SwinUNETR encoder (SwinViT) forward, inference, single GPU, best-of-5. Float32 output matches the default path to within 3e-6 and is bit-exact in float64, verified across 2D and 3D, batch sizes 1 to 4, and non-cubic inputs.

ROI dtype default flash speedup
96^3 fp32 12.59 ms 8.34 ms 1.51x
96^3 bf16 10.71 ms 5.46 ms 1.96x
128^3 fp32 34.27 ms 21.74 ms 1.58x
128^3 bf16 29.45 ms 14.12 ms 2.09x
160^3 fp32 59.05 ms 37.66 ms 1.57x
160^3 bf16 51.25 ms 25.33 ms 2.02x

The flag is threaded through SwinTransformer, BasicLayer and SwinTransformerBlock to WindowAttention, exactly as use_v2 and use_checkpoint are. No parameters or buffers change, so pretrained weights load unchanged. The flash path is used only when autograd is disabled and the module is not scripted, so training and TorchScript keep the original path byte-for-byte; this is a deliberate choice to leave training numerics untouched.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • New tests added to cover the changes.
  • In-line docstrings updated.

@coderabbitai

coderabbitai Bot commented Jul 2, 2026

Copy link
Copy Markdown
Contributor

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 22073760-8de3-421c-b0b0-4905b5eba1a8

📥 Commits

Reviewing files that changed from the base of the PR and between 30e09af and 1713e89.

📒 Files selected for processing (2)
  • monai/networks/nets/swin_unetr.py
  • tests/networks/nets/test_swin_unetr.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/networks/nets/test_swin_unetr.py
  • monai/networks/nets/swin_unetr.py

📝 Walkthrough

Walkthrough

This PR adds an opt-in use_flash_attention flag to SwinUNETR and its internal Swin attention stack. WindowAttention.forward now uses scaled_dot_product_attention during inference when gradients are disabled and scripting is off, combining relative position bias with the shifted-window mask. A new test loads matching weights into flash and non-flash models and compares their outputs.

Estimated code review effort: 3 (Moderate) | ~20 minutes

Sequence Diagram(s)

sequenceDiagram
  participant SwinUNETR
  participant WindowAttention
  participant SDPA as scaled_dot_product_attention

  SwinUNETR->>WindowAttention: forward(x, mask) [use_flash_attention=True, no_grad]
  WindowAttention->>WindowAttention: compute relative position bias
  WindowAttention->>WindowAttention: combine bias + mask
  WindowAttention->>SDPA: scaled_dot_product_attention(q, k, v, attn_bias)
  SDPA-->>WindowAttention: attention output
  WindowAttention->>WindowAttention: proj + dropout
  WindowAttention-->>SwinUNETR: return output
Loading

Related issues

Related issues: #8973

Suggested labels

enhancement, networks

Suggested reviewers

None specified.

Poem

Flash path folds the bias in,
Old and new now match again.

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed The title is concise and accurately summarizes the main change: optional flash attention for SwinUNETR WindowAttention.
Description check ✅ Passed The description follows the template well, includes the issue reference, a clear summary, and completed change-type checkboxes.
Linked Issues check ✅ Passed The changes satisfy #8973 by adding opt-in SDPA flash attention, threading the flag through the stack, and adding parity tests.
Out of Scope Changes check ✅ Passed The PR stays focused on SwinUNETR flash attention and its test coverage, with no明显 unrelated code changes.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (2)
monai/networks/nets/swin_unetr.py (1)

545-548: 🚀 Performance & Scalability | 🔵 Trivial | 💤 Low value

Shifted-window branch materializes a full (b, heads, n, n) bias; also prefer reshape over view.

Two small points:

  • The expand(...).reshape(b, heads, n, n) allocates a tensor the size of the full score matrix — the exact thing the PR aims to avoid. Non-shifted windows stay cheap (broadcast), so the memory win is limited to unshifted blocks. Worth confirming against your benchmark.
  • mask.view(1, nw, 1, n, n) requires mask contiguous; the non-flash path uses unsqueeze which doesn't. Use reshape to be safe.
♻️ Defensive reshape
-                bias = relative_position_bias.view(1, 1, self.num_heads, n, n) + mask.view(1, nw, 1, n, n)
+                bias = relative_position_bias.view(1, 1, self.num_heads, n, n) + mask.reshape(1, nw, 1, n, n)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/nets/swin_unetr.py` around lines 545 - 548, The shifted-window
bias path in the attention logic still materializes a full (b, heads, n, n)
tensor and uses mask.view in the `WindowAttention`/relative-position-bias
branch, so update that branch to preserve broadcasting instead of expanding into
the full score matrix and replace the contiguous-only `view` on `mask` with
`reshape` (or equivalent safe reshaping) to match the non-flash path behavior.
tests/networks/nets/test_swin_unetr.py (1)

115-115: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

Tolerance far looser than the stated parity target.

PR objective claims float32 parity to ~3e-6, but this asserts atol=rtol=1e-3. A regression a couple orders of magnitude short of the goal would still pass. Consider tightening to match the objective, or a float64 bit-exact check.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/networks/nets/test_swin_unetr.py` at line 115, The parity assertion in
the Swin UNETR test is too loose relative to the stated float32 target, so
tighten the tolerance in the test that uses assert_allclose to reflect the
intended ~3e-6 parity goal. Update the check in test_swin_unetr.py near the
existing assert_allclose call so it either uses a much stricter atol/rtol or
switches to a float64 bit-exact comparison if that is the intended standard,
keeping the test aligned with the PR objective.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@monai/networks/nets/swin_unetr.py`:
- Line 543: The flash-attention branch in swinuNetr is only reached when
gradients are disabled, so `model.eval()` alone will still take the slow path.
Update the `SwinUNETR`/`forward`-path documentation to call out that
`use_flash_attention` requires `torch.no_grad()` or `torch.inference_mode()`,
and consider adding a one-time warning in the `use_flash_attention` guard when
`torch.is_grad_enabled()` is true so users can spot the mismatch.

In `@tests/networks/nets/test_swin_unetr.py`:
- Line 114: The loop in the test that iterates over ref and out should use zip
with strict=True to make the equal-length assumption explicit and guard against
future mismatches. Update the zip call inside the relevant test in
test_swin_unetr.py so the iteration over ref and out is strict, preserving the
existing comparison logic while documenting intent.

---

Nitpick comments:
In `@monai/networks/nets/swin_unetr.py`:
- Around line 545-548: The shifted-window bias path in the attention logic still
materializes a full (b, heads, n, n) tensor and uses mask.view in the
`WindowAttention`/relative-position-bias branch, so update that branch to
preserve broadcasting instead of expanding into the full score matrix and
replace the contiguous-only `view` on `mask` with `reshape` (or equivalent safe
reshaping) to match the non-flash path behavior.

In `@tests/networks/nets/test_swin_unetr.py`:
- Line 115: The parity assertion in the Swin UNETR test is too loose relative to
the stated float32 target, so tighten the tolerance in the test that uses
assert_allclose to reflect the intended ~3e-6 parity goal. Update the check in
test_swin_unetr.py near the existing assert_allclose call so it either uses a
much stricter atol/rtol or switches to a float64 bit-exact comparison if that is
the intended standard, keeping the test aligned with the PR objective.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 320211db-4262-434b-a821-dcac487b9efc

📥 Commits

Reviewing files that changed from the base of the PR and between f1dcac4 and 30e09af.

📒 Files selected for processing (2)
  • monai/networks/nets/swin_unetr.py
  • tests/networks/nets/test_swin_unetr.py

Comment thread monai/networks/nets/swin_unetr.py
Comment thread tests/networks/nets/test_swin_unetr.py Outdated
aymuos15 added 2 commits July 2, 2026 22:19
Add a use_flash_attention flag (default False) threaded from SwinUNETR
through SwinTransformer, BasicLayer and SwinTransformerBlock to
WindowAttention. When enabled and autograd is disabled (and not
scripting), attention is computed with
torch.nn.functional.scaled_dot_product_attention, folding the relative
position bias and the shifted-window mask into a single additive
attn_mask cast to the query dtype. The fused kernel avoids materializing
the score matrix. Output matches the default path; training, scripting
and the default path are unchanged. This mirrors the flash-attention
option already in MONAI's SelfAttention, CrossAttention and CABlock.

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Assert the SwinViT encoder features with use_flash_attention=True match
the default path, in float64 for a tight bit-level comparison.

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
@aymuos15 aymuos15 force-pushed the perf/swinunetr-flash-attention branch from 30e09af to 1713e89 Compare July 2, 2026 21:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

SwinUNETR WindowAttention: add flash attention (scaled_dot_product_attention)

1 participant