SwinUNETR: optional flash attention (scaled_dot_product_attention) in WindowAttention #8977
aymuos15 wants to merge 2 commits into
📝 Walkthrough
This PR adds an opt-in
Estimated code review effort: 3 (Moderate) | ~20 minutes
Sequence Diagram(s)
sequenceDiagram
participant SwinUNETR
participant WindowAttention
participant SDPA as scaled_dot_product_attention
SwinUNETR->>WindowAttention: forward(x, mask) [use_flash_attention=True, no_grad]
WindowAttention->>WindowAttention: compute relative position bias
WindowAttention->>WindowAttention: combine bias + mask
WindowAttention->>SDPA: scaled_dot_product_attention(q, k, v, attn_bias)
SDPA-->>WindowAttention: attention output
WindowAttention->>WindowAttention: proj + dropout
WindowAttention-->>SwinUNETR: return output
Actionable comments posted: 2
🧹 Nitpick comments (2)
monai/networks/nets/swin_unetr.py (1)
545-548: 🚀 Performance & Scalability | 🔵 Trivial | 💤 Low value
Shifted-window branch materializes a full `(b, heads, n, n)` bias; also prefer `reshape` over `view`.
Two small points:
- The `expand(...).reshape(b, heads, n, n)` allocates a tensor the size of the full score matrix, the exact thing the PR aims to avoid. Non-shifted windows stay cheap (broadcast), so the memory win is limited to unshifted blocks. Worth confirming against your benchmark.
- `mask.view(1, nw, 1, n, n)` requires `mask` contiguous; the non-flash path uses `unsqueeze` which doesn't. Use `reshape` to be safe.
♻️ Defensive reshape
    - bias = relative_position_bias.view(1, 1, self.num_heads, n, n) + mask.view(1, nw, 1, n, n)
    + bias = relative_position_bias.view(1, 1, self.num_heads, n, n) + mask.reshape(1, nw, 1, n, n)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/networks/nets/swin_unetr.py` around lines 545 - 548, The shifted-window bias path in the attention logic still materializes a full (b, heads, n, n) tensor and uses mask.view in the `WindowAttention`/relative-position-bias branch, so update that branch to preserve broadcasting instead of expanding into the full score matrix and replace the contiguous-only `view` on `mask` with `reshape` (or equivalent safe reshaping) to match the non-flash path behavior.
tests/networks/nets/test_swin_unetr.py (1)
115-115: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value
Tolerance far looser than the stated parity target.
PR objective claims float32 parity to ~`3e-6`, but this asserts `atol=rtol=1e-3`. A regression a couple orders of magnitude short of the goal would still pass. Consider tightening to match the objective, or a float64 bit-exact check.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/networks/nets/test_swin_unetr.py` at line 115, The parity assertion in the Swin UNETR test is too loose relative to the stated float32 target, so tighten the tolerance in the test that uses assert_allclose to reflect the intended ~3e-6 parity goal. Update the check in test_swin_unetr.py near the existing assert_allclose call so it either uses a much stricter atol/rtol or switches to a float64 bit-exact comparison if that is the intended standard, keeping the test aligned with the PR objective.
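The memory point in the first nitpick can be illustrated with a small sketch. The sizes and tensor names below are hypothetical stand-ins, not the PR's code: the broadcast sum of bias and mask stays small, `expand` over the batch is still a stride-0 view, and it is the final `reshape` that forces a full-size copy.

```python
import torch

# Hypothetical sizes: batch 2, 4 windows, 3 heads, 8 tokens per window.
b, nw, heads, n = 2, 4, 3, 8
bias = torch.randn(heads, n, n)   # stand-in for the relative position bias
mask = torch.zeros(nw, n, n)      # stand-in for the shifted-window mask

# Broadcast sum: (1, 1, heads, n, n) + (1, nw, 1, n, n) -> (1, nw, heads, n, n).
combined = bias.reshape(1, 1, heads, n, n) + mask.reshape(1, nw, 1, n, n)

# expand() only sets the batch stride to 0; no new memory is allocated here.
expanded = combined.expand(b, nw, heads, n, n)

# Merging the stride-0 batch dim with the window dim cannot be a view,
# so reshape() copies into a tensor b times larger.
flat = expanded.reshape(b * nw, heads, n, n)

print(combined.numel(), flat.numel())
```

Whether that copy matters in practice depends on the batch size and window count, which is why the reviewer asks for it to be checked against the benchmark.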
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@monai/networks/nets/swin_unetr.py`:
- Line 543: The flash-attention branch in `SwinUNETR` is only reached when
gradients are disabled, so `model.eval()` alone will still take the slow path.
Update the `SwinUNETR`/`forward`-path documentation to call out that
`use_flash_attention` requires `torch.no_grad()` or `torch.inference_mode()`,
and consider adding a one-time warning in the `use_flash_attention` guard when
`torch.is_grad_enabled()` is true so users can spot the mismatch.
In `@tests/networks/nets/test_swin_unetr.py`:
- Line 114: The loop in the test that iterates over ref and out should use zip
with strict=True to make the equal-length assumption explicit and guard against
future mismatches. Update the zip call inside the relevant test in
test_swin_unetr.py so the iteration over ref and out is strict, preserving the
existing comparison logic while documenting intent.
Add a use_flash_attention flag (default False) threaded from SwinUNETR through SwinTransformer, BasicLayer and SwinTransformerBlock to WindowAttention. When enabled and autograd is disabled (and not scripting), attention is computed with torch.nn.functional.scaled_dot_product_attention, folding the relative position bias and the shifted-window mask into a single additive attn_mask cast to the query dtype. The fused kernel avoids materializing the score matrix. Output matches the default path; training, scripting and the default path are unchanged. This mirrors the flash-attention option already in MONAI's SelfAttention, CrossAttention and CABlock.
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Assert the SwinViT encoder features with use_flash_attention=True match the default path, in float64 for a tight bit-level comparison.
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
30e09af to 1713e89
Fixes #8973.
Description
SwinUNETR's WindowAttention builds the full (nWindows*heads, N, N) score matrix by hand before softmax. This adds an opt-in use_flash_attention flag (default False, so existing behaviour is unchanged) that routes attention through torch.nn.functional.scaled_dot_product_attention, folding the relative position bias and, for shifted windows, the attention mask into one additive attn_mask cast to the query dtype. This mirrors the flash-attention option already in MONAI's SelfAttention, CrossAttention and CABlock.
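As a rough sketch of the idea described above (not the PR's actual code), the snippet below compares a hand-built attention with a call to `scaled_dot_product_attention` where the bias and mask are folded into one additive `attn_mask`. All sizes, the random bias, and the mask pattern are hypothetical stand-ins; it assumes PyTorch 2.0 or newer.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: batch 2, 4 windows, 3 heads, 8 tokens per window, head dim 16.
b, nw, heads, n, d = 2, 4, 3, 8, 16
q, k, v = (torch.randn(b * nw, heads, n, d, dtype=torch.float64) for _ in range(3))

# Stand-ins for the learned relative position bias and the shifted-window mask.
relative_position_bias = torch.randn(heads, n, n, dtype=torch.float64)
mask = torch.zeros(nw, n, n, dtype=torch.float64)
mask[:, : n // 2, n // 2 :] = -100.0  # block attention between the two halves
mask[:, n // 2 :, : n // 2] = -100.0


def manual_attention(q, k, v, bias, mask):
    """Default-style path: build the full score matrix, add bias and mask, softmax."""
    attn = (q * q.shape[-1] ** -0.5) @ k.transpose(-2, -1)   # (b*nw, heads, n, n)
    attn = attn + bias.unsqueeze(0)
    attn = attn.reshape(b, nw, heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
    attn = attn.reshape(b * nw, heads, n, n).softmax(dim=-1)
    return attn @ v


def sdpa_attention(q, k, v, bias, mask):
    """Fused-style path: fold bias and mask into one additive attn_mask."""
    attn_bias = bias.reshape(1, 1, heads, n, n) + mask.reshape(1, nw, 1, n, n)
    attn_bias = attn_bias.expand(b, nw, heads, n, n).reshape(b * nw, heads, n, n)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias.to(q.dtype))


out_ref = manual_attention(q, k, v, relative_position_bias, mask)
out_sdpa = sdpa_attention(q, k, v, relative_position_bias, mask)
print((out_ref - out_sdpa).abs().max())
```

A floating-point `attn_mask` is added to the scaled scores before softmax, and the default scale is `1/sqrt(d)`, which is why the two paths are expected to agree up to rounding.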
Measured on the SwinUNETR encoder (SwinViT) forward, inference, single GPU, best-of-5. Float32 output matches the default path to within 3e-6 and is bit-exact in float64, verified across 2D and 3D, batch sizes 1 to 4, and non-cubic inputs.
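To make the parity figure concrete, here is a small illustration of how the choice of tolerance decides whether a drift is caught. It uses `torch.testing.assert_close` as a stand-in for the test suite's own assertion helper, and the `1e-4` drift is a made-up value chosen only to sit well above `3e-6`.

```python
import torch

ref = torch.ones(4, dtype=torch.float32)
drifted = ref + 1e-4  # hypothetical regression, far larger than 3e-6


def passes(actual, expected, tol):
    """Return True if actual matches expected with rtol = atol = tol."""
    try:
        torch.testing.assert_close(actual, expected, rtol=tol, atol=tol)
        return True
    except AssertionError:
        return False


print(passes(drifted, ref, 1e-3))  # loose tolerance
print(passes(drifted, ref, 1e-5))  # tolerance closer to the stated parity
```

The loose tolerance accepts the drift while the tighter one rejects it, which is the gap a tolerance matched to the stated parity would close.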
The flag is threaded through SwinTransformer, BasicLayer and SwinTransformerBlock to WindowAttention, exactly as use_v2 and use_checkpoint are. No parameters or buffers change, so pretrained weights load unchanged. The flash path is used only when autograd is disabled and the module is not scripted, so training and TorchScript keep the original path byte-for-byte; this is a deliberate choice to leave training numerics untouched.
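The two properties in this paragraph can be sketched on a toy module. `ToyWindowAttention` below is a hypothetical stand-in, not the MONAI class: it omits the relative position bias and mask, and only shows that a plain boolean flag leaves the `state_dict` untouched and that the fused path is taken only with autograd disabled and outside TorchScript.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyWindowAttention(nn.Module):
    """Hypothetical stand-in for WindowAttention; not the MONAI implementation."""

    def __init__(self, dim: int, num_heads: int, use_flash_attention: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.use_flash_attention = use_flash_attention  # plain attribute, no parameter or buffer
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        b, n, c = x.shape
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each (b, heads, n, head_dim)
        if self.use_flash_attention and not torch.is_grad_enabled() and not torch.jit.is_scripting():
            out = F.scaled_dot_product_attention(q, k, v)  # fused path, inference only
        else:
            attn = (q * q.shape[-1] ** -0.5) @ k.transpose(-2, -1)
            out = attn.softmax(dim=-1) @ v                 # original path
        return self.proj(out.transpose(1, 2).reshape(b, n, c))


torch.manual_seed(0)
default = ToyWindowAttention(dim=32, num_heads=4).double().eval()
flash = ToyWindowAttention(dim=32, num_heads=4, use_flash_attention=True).double().eval()

# Same keys and shapes, so weights saved from one variant load strictly into the other.
flash.load_state_dict(default.state_dict(), strict=True)

x = torch.randn(2, 8, 32, dtype=torch.float64)
with torch.no_grad():
    y_default, y_flash = default(x), flash(x)
```

Because the flag never enters the `state_dict`, checkpoints remain interchangeable between the two settings; calling `flash(x)` without `torch.no_grad()` would simply fall back to the original branch.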
Types of changes