Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions monai/networks/blocks/crossattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,13 @@ def __init__(
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional
parameter size.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
use_flash_attention: if True, dispatch attention through
``torch.nn.functional.scaled_dot_product_attention``. PyTorch selects the backend;
the true flash kernel is used when no custom additive attention bias is passed.
Pure ``causal`` masking (with no ``rel_pos_embedding``) keeps the fast path via
``is_causal=True``. When an additive bias is required (for example,
``rel_pos_embedding``, or ``causal`` merged with another bias), PyTorch falls
back to the memory-efficient or cuDNN SDPA backend.
"""

super().__init__()
Expand All @@ -88,9 +93,6 @@ def __init__(
"to True. save_attn can only be used if use_flash_attention is False"
)

if use_flash_attention and rel_pos_embedding is not None:
raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")

self.num_heads = num_heads
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.context_input_size = context_input_size if context_input_size else hidden_size
Expand Down Expand Up @@ -155,8 +157,31 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None):
k = k.to(self.attention_dtype)

if self.use_flash_attention:
# Additive bias path mirrors SABlock: null bias preserves the true
# flash kernel fast path; any of rel_pos_embedding / causal forces
# fallback to the efficient or cuDNN SDPA backend.
bias: torch.Tensor | None = None
lq, lk = q.shape[-2], k.shape[-2]

if self.rel_positional_embedding is not None:
zero_logits = torch.zeros(q.shape[0], self.num_heads, lq, lk, dtype=q.dtype, device=q.device)
bias = self.rel_positional_embedding(x, zero_logits, q)

is_causal_arg = self.causal
if self.causal and bias is not None:
causal_bias = torch.zeros(lq, lk, dtype=q.dtype, device=q.device)
causal_bias.masked_fill_(self.causal_mask[0, 0, :lq, :lk] == 0, float("-inf"))
bias = bias + causal_bias
is_causal_arg = False

x = torch.nn.functional.scaled_dot_product_attention(
query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
query=q,
key=k,
value=v,
attn_mask=bias,
scale=self.scale,
dropout_p=self.dropout_rate,
is_causal=is_causal_arg,
)
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
Expand Down
42 changes: 35 additions & 7 deletions monai/networks/blocks/selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,13 @@ def __init__(
attention_dtype: cast attention operations to this dtype.
include_fc: whether to include the final linear layer. Default to True.
use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
use_flash_attention: if True, dispatch attention through
``torch.nn.functional.scaled_dot_product_attention``. PyTorch selects the backend;
the true flash kernel is used when no custom additive attention bias is passed.
Pure ``causal`` masking (with no ``rel_pos_embedding`` or ``attn_mask``) keeps the
fast path via ``is_causal=True``. When an additive bias is required (for example,
``rel_pos_embedding``, or ``causal``/``attn_mask`` merged with another bias),
PyTorch falls back to the memory-efficient or cuDNN SDPA backend.

"""

Expand Down Expand Up @@ -94,9 +99,6 @@ def __init__(
"to True. save_attn can only be used if use_flash_attention is False."
)

if use_flash_attention and rel_pos_embedding is not None:
raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")

self.num_heads = num_heads
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.out_proj: nn.Linear | nn.Identity
Expand Down Expand Up @@ -174,14 +176,40 @@ def forward(self, x, attn_mask: torch.Tensor | None = None):
k = k.to(self.attention_dtype)

if self.use_flash_attention:
# Build an additive attention bias when we have to combine
# rel_pos_embedding, a causal mask, or a user attn_mask. A null bias
# preserves the no-mask fast path so PyTorch can still pick the true
# flash kernel when available.
bias: torch.Tensor | None = None
lq, lk = q.shape[-2], k.shape[-2]

if self.rel_positional_embedding is not None:
zero_logits = torch.zeros(q.shape[0], self.num_heads, lq, lk, dtype=q.dtype, device=q.device)
bias = self.rel_positional_embedding(x, zero_logits, q)

is_causal_arg = self.causal
if self.causal and (bias is not None or attn_mask is not None):
causal_bias = torch.zeros(lq, lk, dtype=q.dtype, device=q.device)
causal_bias.masked_fill_(self.causal_mask[0, 0, :lq, :lk] == 0, float("-inf"))
bias = causal_bias if bias is None else bias + causal_bias
is_causal_arg = False

if attn_mask is not None:
if self.causal:
raise ValueError("Causal attention does not support attention masks.")
mask_bias = torch.zeros_like(attn_mask, dtype=q.dtype)
mask_bias.masked_fill_(attn_mask == 0, float("-inf"))
mask_bias = mask_bias.unsqueeze(1).unsqueeze(2)
bias = mask_bias if bias is None else bias + mask_bias

x = F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attn_mask=attn_mask,
attn_mask=bias,
scale=self.scale,
dropout_p=self.dropout_rate,
is_causal=self.causal,
is_causal=is_causal_arg,
)
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
Expand Down
57 changes: 47 additions & 10 deletions tests/networks/blocks/test_crossattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
[
{
**{k: v for k, v in params.items() if k not in ["rel_pos_embedding_val"]},
"rel_pos_embedding": params["rel_pos_embedding_val"] if not params["use_flash_attention"] else None,
"rel_pos_embedding": params["rel_pos_embedding_val"],
},
(2, 512, params["hidden_size"]),
(2, 512, params["hidden_size"]),
Expand Down Expand Up @@ -69,16 +69,53 @@ def test_save_attn_with_flash_attention(self):
hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True
)

@skipUnless(has_einops, "Requires einops")
def test_rel_pos_embedding_with_flash_attention(self):
    """rel_pos_embedding combined with use_flash_attention must match the explicit path.

    The combination now dispatches via SDPA with an additive attention bias
    instead of raising, so the flash and reference code paths are checked for
    numerical equivalence given identical weights.
    """
    for input_size in [(16, 32), (8, 8, 8)]:
        # subTest so a failure reports which spatial size broke.
        with self.subTest(input_size=input_size):
            input_param = {
                "hidden_size": 128,
                "num_heads": 4,
                "dropout_rate": 0.0,
                "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
                "input_size": input_size,
            }
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
            block_flash = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device)
            block_ref = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device)
            # Copy weights so the two blocks are directly comparable.
            block_ref.load_state_dict(block_flash.state_dict())
            seq_len = int(np.prod(input_size))
            test_data = torch.randn(2, seq_len, 128).to(device)
            with eval_mode(block_flash), eval_mode(block_ref):
                out_flash = block_flash(test_data)
                out_ref = block_ref(test_data)
                assert_allclose(out_flash, out_ref, atol=1e-4)

Comment thread
coderabbitai[bot] marked this conversation as resolved.
@skipUnless(has_einops, "Requires einops")
def test_causal_rel_pos_with_flash_attention(self):
    """Exercise the merged causal-bias branch of CrossAttentionBlock.

    causal=True together with rel_pos_embedding builds an additive bias and
    disables is_causal, so flash and reference paths must still match
    numerically for identical weights.
    """
    for input_size in [(16, 32), (8, 8, 8)]:
        # subTest so a failure reports which spatial size broke.
        with self.subTest(input_size=input_size):
            seq_len = int(np.prod(input_size))
            input_param = {
                "hidden_size": 128,
                "num_heads": 4,
                "dropout_rate": 0.0,
                "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
                "input_size": input_size,
                "causal": True,
                "sequence_length": seq_len,
            }
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
            block_flash = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device)
            block_ref = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device)
            # Copy weights so the two blocks are directly comparable.
            block_ref.load_state_dict(block_flash.state_dict())
            test_data = torch.randn(2, seq_len, 128).to(device)
            with eval_mode(block_flash), eval_mode(block_ref):
                out_flash = block_flash(test_data)
                out_ref = block_ref(test_data)
                assert_allclose(out_flash, out_ref, atol=1e-4)

@skipUnless(has_einops, "Requires einops")
def test_attention_dim_not_multiple_of_heads(self):
Expand Down
83 changes: 73 additions & 10 deletions tests/networks/blocks/test_selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"input_size": input_size,
"include_fc": include_fc,
"use_combined_linear": use_combined_linear,
"use_flash_attention": True if rel_pos_embedding is None else False,
"use_flash_attention": True,
},
(2, 512, hidden_size),
(2, 512, hidden_size),
Expand All @@ -67,16 +67,79 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)

@skipUnless(has_einops, "Requires einops")
def test_rel_pos_embedding_with_flash_attention(self):
    """rel_pos_embedding is now allowed with use_flash_attention in SABlock.

    SDPA picks a fused backend that supports an additive attention bias; the
    two code paths must be numerically equivalent for the same weights.
    """
    for input_size in [(16, 32), (8, 8, 8)]:
        # subTest so a failure reports which spatial size broke.
        with self.subTest(input_size=input_size):
            input_param = {
                "hidden_size": 128,
                "num_heads": 4,
                "dropout_rate": 0.0,
                "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
                "input_size": input_size,
            }
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
            block_flash = SABlock(**input_param, use_flash_attention=True).to(device)
            block_ref = SABlock(**input_param, use_flash_attention=False).to(device)
            # Copy weights so the two blocks are directly comparable.
            block_ref.load_state_dict(block_flash.state_dict())
            test_data = torch.randn(2, int(np.prod(input_size)), 128).to(device)
            with eval_mode(block_flash), eval_mode(block_ref):
                out_flash = block_flash(test_data)
                out_ref = block_ref(test_data)
                assert_allclose(out_flash, out_ref, atol=1e-4)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

@skipUnless(has_einops, "Requires einops")
def test_causal_rel_pos_with_flash_attention(self):
    """Exercise the merged causal-bias branch of SABlock.

    causal=True together with rel_pos_embedding builds an additive bias and
    disables is_causal, so flash and reference paths must still match
    numerically for identical weights.
    """
    for input_size in [(16, 32), (8, 8, 8)]:
        # subTest so a failure reports which spatial size broke.
        with self.subTest(input_size=input_size):
            seq_len = int(np.prod(input_size))
            input_param = {
                "hidden_size": 128,
                "num_heads": 4,
                "dropout_rate": 0.0,
                "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
                "input_size": input_size,
                "causal": True,
                "sequence_length": seq_len,
            }
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
            block_flash = SABlock(**input_param, use_flash_attention=True).to(device)
            block_ref = SABlock(**input_param, use_flash_attention=False).to(device)
            # Copy weights so the two blocks are directly comparable.
            block_ref.load_state_dict(block_flash.state_dict())
            test_data = torch.randn(2, seq_len, 128).to(device)
            with eval_mode(block_flash), eval_mode(block_ref):
                out_flash = block_flash(test_data)
                out_ref = block_ref(test_data)
                assert_allclose(out_flash, out_ref, atol=1e-4)

@skipUnless(has_einops, "Requires einops")
def test_attn_mask_rel_pos_with_flash_attention(self):
    """Exercise the user-attn-mask + rel_pos branch of SABlock.

    The user mask is merged into the additive bias passed via SDPA's
    attn_mask argument; flash and reference paths must agree numerically.
    """
    for input_size in [(16, 32), (8, 8, 8)]:
        # subTest so a failure reports which spatial size broke.
        with self.subTest(input_size=input_size):
            seq_len = int(np.prod(input_size))
            input_param = {
                "hidden_size": 128,
                "num_heads": 4,
                "dropout_rate": 0.0,
                "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
                "input_size": input_size,
            }
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
            block_flash = SABlock(**input_param, use_flash_attention=True).to(device)
            block_ref = SABlock(**input_param, use_flash_attention=False).to(device)
            # Copy weights so the two blocks are directly comparable.
            block_ref.load_state_dict(block_flash.state_dict())
            test_data = torch.randn(2, seq_len, 128).to(device)
            attn_mask = torch.ones(2, seq_len, dtype=torch.bool, device=device)
            attn_mask[:, seq_len // 2 :] = False  # mask out the second half
            with eval_mode(block_flash), eval_mode(block_ref):
                out_flash = block_flash(test_data, attn_mask=attn_mask)
                out_ref = block_ref(test_data, attn_mask=attn_mask)
                assert_allclose(out_flash, out_ref, atol=1e-4)

def test_save_attn_with_flash_attention(self):
with self.assertRaises(ValueError):
Expand Down
Loading