35 changes: 29 additions & 6 deletions monai/networks/blocks/crossattention.py
@@ -61,8 +61,11 @@ def __init__(
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional
parameter size.
attention_dtype: cast attention operations to this dtype.
- use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
- (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ use_flash_attention: if True, dispatch attention through
+ ``torch.nn.functional.scaled_dot_product_attention``. PyTorch selects the backend;
+ the true flash kernel is used only when no attention bias is present. When combined
+ with ``rel_pos_embedding`` or ``causal``, PyTorch will fall back to the
+ memory-efficient or cuDNN SDPA backend.
"""

super().__init__()
Expand All @@ -88,9 +91,6 @@ def __init__(
"to True. save_attn can only be used if use_flash_attention is False"
)

- if use_flash_attention and rel_pos_embedding is not None:
-     raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")

self.num_heads = num_heads
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.context_input_size = context_input_size if context_input_size else hidden_size
@@ -155,8 +155,31 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None):
k = k.to(self.attention_dtype)

if self.use_flash_attention:
+ # Additive bias path mirrors SABlock: null bias preserves the true
+ # flash kernel fast path; any of rel_pos_embedding / causal forces
+ # fallback to the efficient or cuDNN SDPA backend.
+ bias: torch.Tensor | None = None
+ lq, lk = q.shape[-2], k.shape[-2]
+
+ if self.rel_positional_embedding is not None:
+     zero_logits = torch.zeros(q.shape[0], self.num_heads, lq, lk, dtype=q.dtype, device=q.device)
+     bias = self.rel_positional_embedding(x, zero_logits, q)
+
+ is_causal_arg = self.causal
+ if self.causal and bias is not None:
+     causal_bias = torch.zeros(lq, lk, dtype=q.dtype, device=q.device)
+     causal_bias.masked_fill_(self.causal_mask[0, 0, :lq, :lk] == 0, float("-inf"))
+     bias = bias + causal_bias
+     is_causal_arg = False

x = torch.nn.functional.scaled_dot_product_attention(
- query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+ query=q,
+ key=k,
+ value=v,
+ attn_mask=bias,
+ scale=self.scale,
+ dropout_p=self.dropout_rate,
+ is_causal=is_causal_arg,
)
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
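Note on the pattern above: any relative-position module that adds a learned term to the attention logits can be reduced to a standalone additive bias by evaluating it on a zero logits tensor, and the causal mask can be folded into that same bias, since SDPA rejects is_causal=True together with an explicit attn_mask. A minimal self-contained sketch of the idea outside MONAI (the helper name sdpa_with_bias and all shapes are illustrative):

import torch
import torch.nn.functional as F

def sdpa_with_bias(q, k, v, rel_pos_bias=None, causal=False, scale=None):
    # Fold relative-position and causal terms into one additive bias so
    # F.scaled_dot_product_attention can still dispatch to a fused backend.
    lq, lk = q.shape[-2], k.shape[-2]
    bias = rel_pos_bias  # (batch, heads, lq, lk) additive logits, or None
    is_causal = causal
    if causal and bias is not None:
        # is_causal=True and an explicit attn_mask are mutually exclusive,
        # so bake the causal structure into the bias instead.
        keep = torch.tril(torch.ones(lq, lk, dtype=torch.bool, device=q.device))
        causal_bias = torch.zeros(lq, lk, dtype=q.dtype, device=q.device)
        causal_bias.masked_fill_(~keep, float("-inf"))
        bias = bias + causal_bias
        is_causal = False
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bias, scale=scale, is_causal=is_causal)

q = torch.randn(2, 4, 16, 32)  # (batch, heads, tokens, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)
rel_bias = torch.randn(2, 4, 16, 16)  # stand-in for a learned rel-pos bias
out = sdpa_with_bias(q, k, v, rel_pos_bias=rel_bias, causal=True)
print(out.shape)  # torch.Size([2, 4, 16, 32])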
40 changes: 33 additions & 7 deletions monai/networks/blocks/selfattention.py
@@ -63,8 +63,11 @@ def __init__(
attention_dtype: cast attention operations to this dtype.
include_fc: whether to include the final linear layer. Default to True.
use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
- use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
- (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ use_flash_attention: if True, dispatch attention through
+ ``torch.nn.functional.scaled_dot_product_attention``. PyTorch selects the backend;
+ the true flash kernel is used only when no attention bias is present. When combined
+ with ``rel_pos_embedding``, ``causal``, or ``attn_mask``, PyTorch will fall back to
+ the memory-efficient or cuDNN SDPA backend.

"""

@@ -94,9 +97,6 @@ def __init__(
"to True. save_attn can only be used if use_flash_attention is False."
)

- if use_flash_attention and rel_pos_embedding is not None:
-     raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")

self.num_heads = num_heads
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.out_proj: nn.Linear | nn.Identity
@@ -174,14 +174,40 @@ def forward(self, x, attn_mask: torch.Tensor | None = None):
k = k.to(self.attention_dtype)

if self.use_flash_attention:
+ # Build an additive attention bias when we have to combine
+ # rel_pos_embedding, a causal mask, or a user attn_mask. A null bias
+ # preserves the no-mask fast path so PyTorch can still pick the true
+ # flash kernel when available.
+ bias: torch.Tensor | None = None
+ lq, lk = q.shape[-2], k.shape[-2]
+
+ if self.rel_positional_embedding is not None:
+     zero_logits = torch.zeros(q.shape[0], self.num_heads, lq, lk, dtype=q.dtype, device=q.device)
+     bias = self.rel_positional_embedding(x, zero_logits, q)
+
+ is_causal_arg = self.causal
+ if self.causal and (bias is not None or attn_mask is not None):
+     causal_bias = torch.zeros(lq, lk, dtype=q.dtype, device=q.device)
+     causal_bias.masked_fill_(self.causal_mask[0, 0, :lq, :lk] == 0, float("-inf"))
+     bias = causal_bias if bias is None else bias + causal_bias
+     is_causal_arg = False
+
+ if attn_mask is not None:
+     if self.causal:
+         raise ValueError("Causal attention does not support attention masks.")
+     mask_bias = torch.zeros_like(attn_mask, dtype=q.dtype)
+     mask_bias.masked_fill_(attn_mask == 0, float("-inf"))
+     mask_bias = mask_bias.unsqueeze(1).unsqueeze(2)
+     bias = mask_bias if bias is None else bias + mask_bias

x = F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
- attn_mask=attn_mask,
+ attn_mask=bias,
scale=self.scale,
dropout_p=self.dropout_rate,
- is_causal=self.causal,
+ is_causal=is_causal_arg,
)
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
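To see the backend fallback the docstring describes, SDPA can be pinned to one backend at a time and probed with and without a bias. A rough probe, assuming PyTorch 2.3+ (where sdpa_kernel lives in torch.nn.attention) and a CUDA device; exact support varies by GPU, dtype, and PyTorch build:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
bias = torch.randn(2, 4, 128, 128, device="cuda", dtype=torch.float16)

for name, backend in [("flash", SDPBackend.FLASH_ATTENTION), ("mem_efficient", SDPBackend.EFFICIENT_ATTENTION)]:
    for mask in (None, bias):
        try:
            with sdpa_kernel(backend):  # restrict SDPA to a single backend
                F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
            status = "ok"
        except RuntimeError:
            status = "unsupported"
        print(f"{name} with bias={mask is not None}: {status}")
# Typically the flash backend rejects an additive attn_mask while the
# memory-efficient backend accepts it, matching the docstrings above.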
32 changes: 22 additions & 10 deletions tests/networks/blocks/test_crossattention.py
@@ -30,7 +30,7 @@
[
{
**{k: v for k, v in params.items() if k not in ["rel_pos_embedding_val"]},
"rel_pos_embedding": params["rel_pos_embedding_val"] if not params["use_flash_attention"] else None,
"rel_pos_embedding": params["rel_pos_embedding_val"],
},
(2, 512, params["hidden_size"]),
(2, 512, params["hidden_size"]),
@@ -69,16 +69,28 @@ def test_save_attn_with_flash_attention(self):
hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True
)

@skipUnless(has_einops, "Requires einops")
def test_rel_pos_embedding_with_flash_attention(self):
-     with self.assertRaises(ValueError):
-         CrossAttentionBlock(
-             hidden_size=128,
-             num_heads=3,
-             dropout_rate=0.1,
-             use_flash_attention=True,
-             save_attn=False,
-             rel_pos_embedding=RelPosEmbedding.DECOMPOSED,
-         )
+     # rel_pos_embedding combined with use_flash_attention now dispatches
+     # via SDPA with an additive bias. Must match the explicit path.
+     for input_size in [(16, 32), (8, 8, 8)]:
+         input_param = {
+             "hidden_size": 128,
+             "num_heads": 4,
+             "dropout_rate": 0.0,
+             "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
+             "input_size": input_size,
+         }
+         device = "cuda:0" if torch.cuda.is_available() else "cpu"
+         block_flash = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device)
+         block_ref = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device)
+         block_ref.load_state_dict(block_flash.state_dict())
+         seq_len = int(np.prod(input_size))
+         test_data = torch.randn(2, seq_len, 128).to(device)
+         with eval_mode(block_flash), eval_mode(block_ref):
+             out_flash = block_flash(test_data)
+             out_ref = block_ref(test_data)
+         assert_allclose(out_flash, out_ref, atol=1e-4)

@skipUnless(has_einops, "Requires einops")
def test_attention_dim_not_multiple_of_heads(self):
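As a usage note, the configuration this test exercises is now accepted at construction time. A brief sketch mirroring the test parameters (assuming "decomposed" is the string value behind RelPosEmbedding.DECOMPOSED):

import torch
from monai.networks.blocks.crossattention import CrossAttentionBlock

block = CrossAttentionBlock(
    hidden_size=128,
    num_heads=4,
    dropout_rate=0.0,
    rel_pos_embedding="decomposed",  # used to raise ValueError with flash attention
    input_size=(16, 32),
    use_flash_attention=True,
)
x = torch.randn(2, 16 * 32, 128)  # (batch, seq_len, hidden_size)
with torch.no_grad():
    out = block(x)  # context=None, i.e. self-attention style, as in the test
print(out.shape)  # torch.Size([2, 512, 128])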
32 changes: 22 additions & 10 deletions tests/networks/blocks/test_selfattention.py
@@ -43,7 +43,7 @@
"input_size": input_size,
"include_fc": include_fc,
"use_combined_linear": use_combined_linear,
"use_flash_attention": True if rel_pos_embedding is None else False,
"use_flash_attention": True,
},
(2, 512, hidden_size),
(2, 512, hidden_size),
@@ -67,16 +67,28 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)

@skipUnless(has_einops, "Requires einops")
def test_rel_pos_embedding_with_flash_attention(self):
-     with self.assertRaises(ValueError):
-         SABlock(
-             hidden_size=128,
-             num_heads=3,
-             dropout_rate=0.1,
-             use_flash_attention=True,
-             save_attn=False,
-             rel_pos_embedding=RelPosEmbedding.DECOMPOSED,
-         )
+     # rel_pos_embedding is now allowed with use_flash_attention; SDPA picks
+     # a fused backend that supports an additive attention bias. The two
+     # code paths must be numerically equivalent for the same weights.
+     for input_size in [(16, 32), (8, 8, 8)]:
+         input_param = {
+             "hidden_size": 128,
+             "num_heads": 4,
+             "dropout_rate": 0.0,
+             "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
+             "input_size": input_size,
+         }
+         device = "cuda:0" if torch.cuda.is_available() else "cpu"
+         block_flash = SABlock(**input_param, use_flash_attention=True).to(device)
+         block_ref = SABlock(**input_param, use_flash_attention=False).to(device)
+         block_ref.load_state_dict(block_flash.state_dict())
+         test_data = torch.randn(2, int(np.prod(input_size)), 128).to(device)
+         with eval_mode(block_flash), eval_mode(block_ref):
+             out_flash = block_flash(test_data)
+             out_ref = block_ref(test_data)
+         assert_allclose(out_flash, out_ref, atol=1e-4)

def test_save_attn_with_flash_attention(self):
with self.assertRaises(ValueError):
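Likewise for SABlock, which can now combine the previously conflicting options in one block. A hedged sketch (assuming SABlock's causal option requires sequence_length, as in the non-flash path, and the same "decomposed" string value; untested combinations may differ):

import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(
    hidden_size=128,
    num_heads=4,
    dropout_rate=0.0,
    rel_pos_embedding="decomposed",  # RelPosEmbedding.DECOMPOSED's string value
    input_size=(16, 32),
    causal=True,             # folded into the additive bias on the SDPA path
    sequence_length=16 * 32,
    use_flash_attention=True,
)
x = torch.randn(2, 16 * 32, 128)
with torch.no_grad():
    out = block(x)
print(out.shape)  # torch.Size([2, 512, 128])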