Commit c5b2a1c
committed
test(attention): cover merged causal- and attn_mask-bias flash branches
Address CodeRabbit review on PR #8842:
- Narrow the use_flash_attention docstring in SABlock and
CrossAttentionBlock so it reflects the actual implementation: pure
causal masking keeps the fast path via is_causal=True; only an
additive bias (rel_pos_embedding, or causal/attn_mask merged with
another bias) forces SDPA to fall back to the memory-efficient or
cuDNN backend.
- Extend the numerical-equivalence tests to cover the new merged-bias
paths: causal=True + rel_pos_embedding for both blocks, and
attn_mask + rel_pos_embedding for SABlock. All cases assert
assert_allclose(out_flash, out_ref, atol=1e-4) on 2D and 3D inputs.
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>1 parent b7d4786 commit c5b2a1c
4 files changed
Lines changed: 86 additions & 6 deletions
File tree
- monai/networks/blocks
- tests/networks/blocks
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
63 | 63 | | |
64 | 64 | | |
65 | 65 | | |
66 | | - | |
67 | | - | |
68 | | - | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
69 | 71 | | |
70 | 72 | | |
71 | 73 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
65 | 65 | | |
66 | 66 | | |
67 | 67 | | |
68 | | - | |
69 | | - | |
70 | | - | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
71 | 73 | | |
72 | 74 | | |
73 | 75 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
92 | 92 | | |
93 | 93 | | |
94 | 94 | | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
95 | 120 | | |
96 | 121 | | |
97 | 122 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
90 | 90 | | |
91 | 91 | | |
92 | 92 | | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
93 | 144 | | |
94 | 145 | | |
95 | 146 | | |
| |||
0 commit comments