Skip to content

Commit 82f0e0e

Browse files
author
Rahul Mangalampalli
committed
Avoid full mask allocation in unfused padding causal attention
1 parent 3fffa55 commit 82f0e0e

2 files changed

Lines changed: 216 additions & 20 deletions

File tree

tests/pytorch/attention/test_attention.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from transformer_engine.pytorch.attention.dot_product_attention import (
2727
_attention_backends,
2828
)
29+
from transformer_engine.pytorch.attention.dot_product_attention import backends as dpa_backends
30+
import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils
2931
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
3032
FlashAttentionUtils,
3133
check_set_window_size,
@@ -647,6 +649,75 @@ def test_dpa_mask(dtype, model_configs, model):
647649
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
648650

649651

652+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.")
def test_unfused_thd_padding_causal_uses_sdpa_without_full_mask(monkeypatch):
    """Check the unfused THD ``padding_causal`` path never builds the full mask.

    ``dpa_utils.get_full_mask`` is replaced with a function that raises, so the
    test fails if the backend calls it. The output is compared against a
    per-sequence causal ``scaled_dot_product_attention`` reference, and a
    backward pass confirms gradients reach query, key and value.
    """
    reset_rng_states()

    # Two packed sequences of lengths 3 and 5, two heads of width 16.
    num_heads, head_dim = 2, 16
    batch_size = 2
    seqlens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    total_seqlen = int(cu_seqlens[-1].item())
    max_seqlen = int(seqlens.max().item())

    # Packed [t, h, d] inputs; query is drawn first so key/value inherit its
    # shape, dtype and device.
    query = torch.randn(
        total_seqlen, num_heads, head_dim, dtype=torch.float16, device="cuda", requires_grad=True
    )
    key = torch.randn_like(query, requires_grad=True)
    value = torch.randn_like(query, requires_grad=True)
    softmax_scale = head_dim**-0.5

    def reference_for(start, end):
        """Causal SDPA over one packed sequence, returned as [s, h * d]."""
        q, k, v = (t[start:end].permute(1, 0, 2).unsqueeze(0) for t in (query, key, value))
        attended = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, dropout_p=0.0, is_causal=True, scale=softmax_scale
        )
        return attended.squeeze(0).permute(1, 0, 2).reshape(end - start, -1)

    with torch.no_grad():
        per_sequence = []
        for batch_id in range(batch_size):
            start = int(cu_seqlens[batch_id].item())
            end = int(cu_seqlens[batch_id + 1].item())
            per_sequence.append(reference_for(start, end))
    expected = torch.cat(per_sequence, dim=0)

    def forbid_full_mask(*args, **kwargs):
        raise AssertionError("get_full_mask should not be called for this path")

    monkeypatch.setattr(dpa_utils, "get_full_mask", forbid_full_mask)

    attention = dpa_backends.UnfusedDotProductAttention(
        softmax_scale=softmax_scale,
        attention_type="self",
        attention_dropout=0.0,
    ).eval()
    output = attention(
        {},
        query,
        key,
        value,
        qkv_layout="thd_thd_thd",
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_kv=max_seqlen,
        attn_mask_type="padding_causal",
        window_size=(-1, 0),
    )

    torch.testing.assert_close(output, expected, rtol=1e-3, atol=1e-3)

    # The path must stay differentiable with respect to every input.
    output.float().sum().backward()
    for leaf in (query, key, value):
        assert leaf.grad is not None
719+
720+
650721
model_configs_bias = {
651722
# test: ModelConfig(b, sq, hq, dqk)
652723
"bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),

transformer_engine/pytorch/attention/dot_product_attention/backends.py

Lines changed: 145 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,111 @@ def fast_setattr(self, name: str, value: Any) -> None:
365365
"""Fast attribute set for non-parameter fields."""
366366
self.__dict__[name] = value
367367

368+
def _use_varlen_sdpa(
    self,
    attn_mask_type: str,
    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
    window_size: Optional[Tuple[int, int]],
    core_attention_bias_type: str,
    alibi_slopes: Optional[torch.Tensor],
    fp8: bool,
) -> bool:
    """Whether PyTorch SDPA can replace unfused attention without materializing masks.

    The per-sequence SDPA path is only selected when all of the following hold:

    * ``self.attention_type`` is ``"self"`` and ``attn_mask_type`` is
      ``"padding_causal"``;
    * ``window_size`` is ``None``, ``(-1, 0)`` or ``(-1, -1)``;
    * ``attention_mask`` is a single 4-D tensor whose dims 1 and 2 are both 1,
      because the SDPA path reads sequence lengths from
      ``attention_mask[:, 0, 0, :]``;
    * there is no bias, no ALiBi slopes, no dropout (``p == 0.0``), the
      softmax type is ``"vanilla"``, max-logit output is not requested, and
      FP8 is disabled.

    Returns:
        ``True`` if the SDPA path may be used, ``False`` to fall back to the
        general unfused implementation.
    """
    if self.attention_type != "self" or attn_mask_type != "padding_causal":
        return False
    if window_size not in (None, (-1, 0), (-1, -1)):
        return False
    # Without a mask there are no per-sequence lengths to read; a tuple means
    # separate q/kv masks, which the SDPA path does not handle.
    if attention_mask is None or isinstance(attention_mask, tuple):
        return False
    # Any other mask shape cannot be reduced to one length per batch entry via
    # [:, 0, 0, :]; fall back rather than crash or derive wrong lengths.
    if attention_mask.dim() != 4 or tuple(attention_mask.shape[1:3]) != (1, 1):
        return False
    return (
        core_attention_bias_type == "no_bias"
        and self.attention_dropout.p == 0.0
        and alibi_slopes is None
        and self.softmax_type == "vanilla"
        and not self.return_max_logit
        and not fp8
    )
396+
397+
def _format_context(
    self,
    context_layer: torch.Tensor,
    q_format: str,
    max_seqlen_q: int,
    batch_size: int,
    cu_seqlens_q: Optional[torch.Tensor],
) -> torch.Tensor:
    """Reshape a [b, h, sq, d] context tensor into the layout named by ``q_format``.

    * ``"sbhd"`` -> ``[max_seqlen_q, batch_size, h * d]``
    * ``"bshd"`` -> ``[batch_size, max_seqlen_q, h * d]``
    * ``"thd"``  -> output of ``ConvertBSHDtoTHD`` (given ``cu_seqlens_q``)
      with its trailing dims flattened into one.

    Raises:
        ValueError: if ``q_format`` is none of the three layouts above.
    """
    if q_format not in ("sbhd", "bshd", "thd"):
        raise ValueError(f"Unsupported q_format = {q_format}!")

    if q_format == "sbhd":
        seq_first = context_layer.permute(2, 0, 1, 3).contiguous()
        return seq_first.view(max_seqlen_q, batch_size, -1)

    # Both remaining layouts start from a contiguous [b, sq, h, d] tensor.
    batch_first = context_layer.permute(0, 2, 1, 3).contiguous()
    if q_format == "bshd":
        return batch_first.view(batch_size, max_seqlen_q, -1)

    packed = ConvertBSHDtoTHD.apply(batch_first, cu_seqlens_q)
    return packed.view(packed.shape[0], -1)
417+
418+
def _forward_varlen_sdpa(
419+
self,
420+
query_layer: torch.Tensor,
421+
key_layer: torch.Tensor,
422+
value_layer: torch.Tensor,
423+
q_format: str,
424+
batch_size: int,
425+
max_seqlen_q: int,
426+
cu_seqlens_q: Optional[torch.Tensor],
427+
attention_mask: Optional[torch.Tensor],
428+
scale: float,
429+
) -> torch.Tensor:
430+
"""Run causal self-attention without expanding padding masks to [b, 1, sq, sk]."""
431+
context_layer = torch.zeros(
432+
batch_size,
433+
query_layer.size(2),
434+
max_seqlen_q,
435+
value_layer.size(3),
436+
dtype=query_layer.dtype,
437+
device=query_layer.device,
438+
)
439+
440+
if attention_mask is not None:
441+
seqlens_q = attention_mask.logical_not()[:, 0, 0, :].sum(dim=1)
442+
else:
443+
seqlens_q = torch.full(
444+
(batch_size,), max_seqlen_q, dtype=torch.int64, device=query_layer.device
445+
)
446+
447+
dropout_p = self.attention_dropout.p if self.training else 0.0
448+
with self.attention_dropout_ctx():
449+
for batch_id in range(batch_size):
450+
seqlen_q = int(seqlens_q[batch_id].item())
451+
if seqlen_q == 0:
452+
continue
453+
query = query_layer[:seqlen_q, batch_id].permute(1, 0, 2).unsqueeze(0)
454+
key = key_layer[:seqlen_q, batch_id].permute(1, 0, 2).unsqueeze(0)
455+
value = value_layer[:seqlen_q, batch_id].permute(1, 0, 2).unsqueeze(0)
456+
context_layer[batch_id, :, :seqlen_q, :] = F.scaled_dot_product_attention(
457+
query,
458+
key,
459+
value,
460+
dropout_p=dropout_p,
461+
is_causal=True,
462+
scale=scale,
463+
).squeeze(0)
464+
465+
return self._format_context(
466+
context_layer,
467+
q_format,
468+
max_seqlen_q,
469+
batch_size,
470+
cu_seqlens_q,
471+
)
472+
368473
def forward(
369474
self,
370475
_alibi_cache: Dict[str, Any],
@@ -457,22 +562,6 @@ def forward(
457562
max_seqlen_kv,
458563
self.attention_type,
459564
)
460-
attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = (
461-
dpa_utils.get_full_mask(
462-
max_seqlen_q,
463-
max_seqlen_kv,
464-
attn_mask_type=attn_mask_type,
465-
attention_mask=attention_mask,
466-
window_size=window_size,
467-
attention_type=self.attention_type,
468-
bottom_right_alignment=(
469-
attn_mask_type not in ["causal", "padding_causal"]
470-
if bottom_right_diagonal is None
471-
else bottom_right_diagonal
472-
),
473-
)
474-
)
475-
476565
apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16
477566

478567
# [b, h, sq, sk]
@@ -494,6 +583,46 @@ def forward(
494583
int(query_layer.shape[2] / value_layer.shape[2]), dim=2
495584
)
496585

586+
scale = self.softmax_scale
587+
if apply_qk_layer_scaling:
588+
scale /= self.layer_number
589+
590+
if self._use_varlen_sdpa(
591+
attn_mask_type,
592+
attention_mask,
593+
window_size,
594+
core_attention_bias_type,
595+
alibi_slopes,
596+
fp8,
597+
):
598+
return self._forward_varlen_sdpa(
599+
query_layer,
600+
key_layer,
601+
value_layer,
602+
q_format,
603+
batch_size,
604+
max_seqlen_q,
605+
cu_seqlens_q,
606+
attention_mask,
607+
self.softmax_scale,
608+
)
609+
610+
attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = (
611+
dpa_utils.get_full_mask(
612+
max_seqlen_q,
613+
max_seqlen_kv,
614+
attn_mask_type=attn_mask_type,
615+
attention_mask=attention_mask,
616+
window_size=window_size,
617+
attention_type=self.attention_type,
618+
bottom_right_alignment=(
619+
attn_mask_type not in ["causal", "padding_causal"]
620+
if bottom_right_diagonal is None
621+
else bottom_right_diagonal
622+
),
623+
)
624+
)
625+
497626
# preallocating result tensor: [b * h, sq, sk]
498627
matmul_result = torch.empty(
499628
output_size[0] * output_size[1],
@@ -503,10 +632,6 @@ def forward(
503632
device=torch.cuda.current_device(),
504633
)
505634

506-
scale = self.softmax_scale
507-
if apply_qk_layer_scaling:
508-
scale /= self.layer_number
509-
510635
if fp8:
511636
# get fp8 recipe for DPA
512637
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()

0 commit comments

Comments
 (0)