
Commit c5b2a1c

test(attention): cover merged causal- and attn_mask-bias flash branches
Address CodeRabbit review on PR #8842:

- Narrow the use_flash_attention docstring in SABlock and CrossAttentionBlock so it reflects the actual implementation: pure causal masking keeps the fast path via is_causal=True; only an additive bias (rel_pos_embedding, or causal/attn_mask merged with another bias) forces SDPA to fall back to the memory-efficient or cuDNN backend.
- Extend the numerical-equivalence tests to cover the new merged-bias paths: causal=True + rel_pos_embedding for both blocks, and attn_mask + rel_pos_embedding for SABlock. All cases check assert_allclose(out_flash, out_ref, atol=1e-4) on 2D and 3D inputs.

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent: b7d4786

4 files changed: 86 additions & 6 deletions


monai/networks/blocks/crossattention.py

Lines changed: 5 additions & 3 deletions
@@ -63,9 +63,11 @@ def __init__(
             attention_dtype: cast attention operations to this dtype.
             use_flash_attention: if True, dispatch attention through
                 ``torch.nn.functional.scaled_dot_product_attention``. PyTorch selects the backend;
-                the true flash kernel is used only when no attention bias is present. When combined
-                with ``rel_pos_embedding`` or ``causal``, PyTorch will fall back to the
-                memory-efficient or cuDNN SDPA backend.
+                the true flash kernel is used when no custom additive attention bias is passed.
+                Pure ``causal`` masking (with no ``rel_pos_embedding``) keeps the fast path via
+                ``is_causal=True``. When an additive bias is required (for example,
+                ``rel_pos_embedding``, or ``causal`` merged with another bias), PyTorch falls
+                back to the memory-efficient or cuDNN SDPA backend.
         """
 
         super().__init__()

monai/networks/blocks/selfattention.py

Lines changed: 5 additions & 3 deletions
@@ -65,9 +65,11 @@ def __init__(
             use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
             use_flash_attention: if True, dispatch attention through
                 ``torch.nn.functional.scaled_dot_product_attention``. PyTorch selects the backend;
-                the true flash kernel is used only when no attention bias is present. When combined
-                with ``rel_pos_embedding``, ``causal``, or ``attn_mask``, PyTorch will fall back to
-                the memory-efficient or cuDNN SDPA backend.
+                the true flash kernel is used when no custom additive attention bias is passed.
+                Pure ``causal`` masking (with no ``rel_pos_embedding`` or ``attn_mask``) keeps the
+                fast path via ``is_causal=True``. When an additive bias is required (for example,
+                ``rel_pos_embedding``, or ``causal``/``attn_mask`` merged with another bias),
+                PyTorch falls back to the memory-efficient or cuDNN SDPA backend.
 
         """
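For readers skimming the diff, here is a minimal, illustrative sketch (not MONAI source) of the dispatch rule both docstrings describe: with no additive bias, pure causal masking is expressed as ``is_causal=True`` and the true flash kernel stays eligible; once a bias is needed, the causal mask is merged into it and passed via SDPA's ``attn_mask``, which pushes PyTorch to the memory-efficient or cuDNN backend. The ``dispatch_sdpa`` helper and its argument names are hypothetical, not part of the MONAI API.

import torch
import torch.nn.functional as F

def dispatch_sdpa(q, k, v, causal=False, bias=None):
    # q, k, v: (batch, heads, seq_len, head_dim); bias: optional additive float
    # mask broadcastable to (batch, heads, seq_len, seq_len).
    if bias is None:
        # No custom bias: pure causal masking goes through is_causal=True,
        # keeping the flash kernel eligible.
        return F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    if causal:
        # Causal masking has to be merged into the additive bias, so is_causal
        # is dropped and SDPA selects the memory-efficient or cuDNN backend.
        seq_len = q.shape[-2]
        upper = torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device).triu(diagonal=1)
        bias = bias.masked_fill(upper, float("-inf"))
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

# Example: same q/k/v, with and without a (hypothetical) relative positional bias.
q = k = v = torch.randn(2, 4, 16, 32)
rel_bias = torch.randn(1, 4, 16, 16)
out_causal_only = dispatch_sdpa(q, k, v, causal=True)                  # flash-eligible path
out_merged_bias = dispatch_sdpa(q, k, v, causal=True, bias=rel_bias)   # merged-bias fallback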

tests/networks/blocks/test_crossattention.py

Lines changed: 25 additions & 0 deletions
@@ -92,6 +92,31 @@ def test_rel_pos_embedding_with_flash_attention(self):
                 out_ref = block_ref(test_data)
             assert_allclose(out_flash, out_ref, atol=1e-4)
 
+    @skipUnless(has_einops, "Requires einops")
+    def test_causal_rel_pos_with_flash_attention(self):
+        # Exercise the merged causal-bias branch: causal=True together with
+        # rel_pos_embedding builds an additive bias and disables is_causal.
+        for input_size in [(16, 32), (8, 8, 8)]:
+            seq_len = int(np.prod(input_size))
+            input_param = {
+                "hidden_size": 128,
+                "num_heads": 4,
+                "dropout_rate": 0.0,
+                "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
+                "input_size": input_size,
+                "causal": True,
+                "sequence_length": seq_len,
+            }
+            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            block_flash = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device)
+            block_ref = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device)
+            block_ref.load_state_dict(block_flash.state_dict())
+            test_data = torch.randn(2, seq_len, 128).to(device)
+            with eval_mode(block_flash), eval_mode(block_ref):
+                out_flash = block_flash(test_data)
+                out_ref = block_ref(test_data)
+            assert_allclose(out_flash, out_ref, atol=1e-4)
+
     @skipUnless(has_einops, "Requires einops")
     def test_attention_dim_not_multiple_of_heads(self):
         with self.assertRaises(ValueError):

tests/networks/blocks/test_selfattention.py

Lines changed: 51 additions & 0 deletions
@@ -90,6 +90,57 @@ def test_rel_pos_embedding_with_flash_attention(self):
                 out_ref = block_ref(test_data)
             assert_allclose(out_flash, out_ref, atol=1e-4)
 
+    @skipUnless(has_einops, "Requires einops")
+    def test_causal_rel_pos_with_flash_attention(self):
+        # Exercise the merged causal-bias branch: causal=True together with
+        # rel_pos_embedding builds an additive bias and disables is_causal,
+        # so flash and reference paths must still match numerically.
+        for input_size in [(16, 32), (8, 8, 8)]:
+            seq_len = int(np.prod(input_size))
+            input_param = {
+                "hidden_size": 128,
+                "num_heads": 4,
+                "dropout_rate": 0.0,
+                "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
+                "input_size": input_size,
+                "causal": True,
+                "sequence_length": seq_len,
+            }
+            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            block_flash = SABlock(**input_param, use_flash_attention=True).to(device)
+            block_ref = SABlock(**input_param, use_flash_attention=False).to(device)
+            block_ref.load_state_dict(block_flash.state_dict())
+            test_data = torch.randn(2, seq_len, 128).to(device)
+            with eval_mode(block_flash), eval_mode(block_ref):
+                out_flash = block_flash(test_data)
+                out_ref = block_ref(test_data)
+            assert_allclose(out_flash, out_ref, atol=1e-4)
+
+    @skipUnless(has_einops, "Requires einops")
+    def test_attn_mask_rel_pos_with_flash_attention(self):
+        # Exercise the user-attn-mask + rel_pos branch: the user mask is
+        # merged into the additive bias passed via SDPA's attn_mask argument.
+        for input_size in [(16, 32), (8, 8, 8)]:
+            seq_len = int(np.prod(input_size))
+            input_param = {
+                "hidden_size": 128,
+                "num_heads": 4,
+                "dropout_rate": 0.0,
+                "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
+                "input_size": input_size,
+            }
+            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            block_flash = SABlock(**input_param, use_flash_attention=True).to(device)
+            block_ref = SABlock(**input_param, use_flash_attention=False).to(device)
+            block_ref.load_state_dict(block_flash.state_dict())
+            test_data = torch.randn(2, seq_len, 128).to(device)
+            attn_mask = torch.ones(2, seq_len, dtype=torch.bool, device=device)
+            attn_mask[:, seq_len // 2 :] = False  # mask out the second half
+            with eval_mode(block_flash), eval_mode(block_ref):
+                out_flash = block_flash(test_data, attn_mask=attn_mask)
+                out_ref = block_ref(test_data, attn_mask=attn_mask)
+            assert_allclose(out_flash, out_ref, atol=1e-4)
+
     def test_save_attn_with_flash_attention(self):
         with self.assertRaises(ValueError):
             SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True)
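Assuming a standard MONAI development checkout with pytest available, the new branches can be exercised directly (the selection expression below is illustrative):

python -m pytest tests/networks/blocks/test_selfattention.py tests/networks/blocks/test_crossattention.py -k "flash_attention"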
