Skip to content

Commit 51d298b

Browse files
committed
Avoid full mask allocation in unfused padding causal attention
1 parent cabc6b6 commit 51d298b

2 files changed

Lines changed: 216 additions & 20 deletions

File tree

tests/pytorch/attention/test_attention.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from transformer_engine.pytorch.attention.dot_product_attention import (
2727
_attention_backends,
2828
)
29+
from transformer_engine.pytorch.attention.dot_product_attention import backends as dpa_backends
30+
import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils
2931
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
3032
FlashAttentionUtils,
3133
check_set_window_size,
@@ -667,6 +669,75 @@ def test_dpa_mask(dtype, model_configs, model):
667669
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
668670

669671

672+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.")
673+
def test_unfused_thd_padding_causal_uses_sdpa_without_full_mask(monkeypatch):
674+
"""Unfused THD padding_causal should avoid materializing a full quadratic mask."""
675+
reset_rng_states()
676+
batch_size = 2
677+
num_heads = 2
678+
head_dim = 16
679+
seqlens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
680+
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
681+
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
682+
total_seqlen = int(cu_seqlens[-1].item())
683+
max_seqlen = int(seqlens.max().item())
684+
685+
query = torch.randn(
686+
total_seqlen, num_heads, head_dim, dtype=torch.float16, device="cuda", requires_grad=True
687+
)
688+
key = torch.randn_like(query, requires_grad=True)
689+
value = torch.randn_like(query, requires_grad=True)
690+
softmax_scale = head_dim**-0.5
691+
692+
expected = []
693+
with torch.no_grad():
694+
for batch_id in range(batch_size):
695+
start = int(cu_seqlens[batch_id].item())
696+
end = int(cu_seqlens[batch_id + 1].item())
697+
q = query[start:end].permute(1, 0, 2).unsqueeze(0)
698+
k = key[start:end].permute(1, 0, 2).unsqueeze(0)
699+
v = value[start:end].permute(1, 0, 2).unsqueeze(0)
700+
expected.append(
701+
torch.nn.functional.scaled_dot_product_attention(
702+
q, k, v, dropout_p=0.0, is_causal=True, scale=softmax_scale
703+
)
704+
.squeeze(0)
705+
.permute(1, 0, 2)
706+
.reshape(end - start, -1)
707+
)
708+
expected = torch.cat(expected, dim=0)
709+
710+
def fail_get_full_mask(*args, **kwargs):
711+
raise AssertionError("get_full_mask should not be called for this path")
712+
713+
monkeypatch.setattr(dpa_utils, "get_full_mask", fail_get_full_mask)
714+
715+
attention = dpa_backends.UnfusedDotProductAttention(
716+
softmax_scale=softmax_scale,
717+
attention_type="self",
718+
attention_dropout=0.0,
719+
).eval()
720+
output = attention(
721+
{},
722+
query,
723+
key,
724+
value,
725+
qkv_layout="thd_thd_thd",
726+
cu_seqlens_q=cu_seqlens,
727+
cu_seqlens_kv=cu_seqlens,
728+
max_seqlen_q=max_seqlen,
729+
max_seqlen_kv=max_seqlen,
730+
attn_mask_type="padding_causal",
731+
window_size=(-1, 0),
732+
)
733+
734+
torch.testing.assert_close(output, expected, rtol=1e-3, atol=1e-3)
735+
output.float().sum().backward()
736+
assert query.grad is not None
737+
assert key.grad is not None
738+
assert value.grad is not None
739+
740+
670741
model_configs_bias = {
671742
# test: ModelConfig(b, sq, hq, dqk)
672743
"bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),

transformer_engine/pytorch/attention/dot_product_attention/backends.py

Lines changed: 145 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,111 @@ def fast_setattr(self, name: str, value: Any) -> None:
342342
"""Fast attribute set for non-parameter fields."""
343343
self.__dict__[name] = value
344344

345+
def _use_varlen_sdpa(
346+
self,
347+
attn_mask_type: str,
348+
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
349+
window_size: Optional[Tuple[int, int]],
350+
core_attention_bias_type: str,
351+
alibi_slopes: Optional[torch.Tensor],
352+
fp8: bool,
353+
) -> bool:
354+
"""Whether PyTorch SDPA can replace unfused attention without materializing masks."""
355+
if self.attention_type != "self":
356+
return False
357+
if attn_mask_type != "padding_causal":
358+
return False
359+
if window_size not in [None, (-1, 0), (-1, -1)]:
360+
return False
361+
if attn_mask_type == "padding_causal" and attention_mask is None:
362+
return False
363+
if isinstance(attention_mask, tuple):
364+
return False
365+
return (
366+
core_attention_bias_type == "no_bias"
367+
and self.attention_dropout.p == 0.0
368+
and alibi_slopes is None
369+
and self.softmax_type == "vanilla"
370+
and not self.return_max_logit
371+
and not fp8
372+
)
373+
374+
def _format_context(
375+
self,
376+
context_layer: torch.Tensor,
377+
q_format: str,
378+
max_seqlen_q: int,
379+
batch_size: int,
380+
cu_seqlens_q: Optional[torch.Tensor],
381+
) -> torch.Tensor:
382+
"""Convert context from [b, h, sq, d] to the requested output layout."""
383+
if q_format == "sbhd":
384+
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
385+
return context_layer.view(max_seqlen_q, batch_size, -1)
386+
if q_format == "bshd":
387+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
388+
return context_layer.view(batch_size, max_seqlen_q, -1)
389+
if q_format == "thd":
390+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
391+
context_layer = ConvertBSHDtoTHD.apply(context_layer, cu_seqlens_q)
392+
return context_layer.view(context_layer.shape[0], -1)
393+
raise ValueError(f"Unsupported q_format = {q_format}!")
394+
395+
def _forward_varlen_sdpa(
396+
self,
397+
query_layer: torch.Tensor,
398+
key_layer: torch.Tensor,
399+
value_layer: torch.Tensor,
400+
q_format: str,
401+
batch_size: int,
402+
max_seqlen_q: int,
403+
cu_seqlens_q: Optional[torch.Tensor],
404+
attention_mask: Optional[torch.Tensor],
405+
scale: float,
406+
) -> torch.Tensor:
407+
"""Run causal self-attention without expanding padding masks to [b, 1, sq, sk]."""
408+
context_layer = torch.zeros(
409+
batch_size,
410+
query_layer.size(2),
411+
max_seqlen_q,
412+
value_layer.size(3),
413+
dtype=query_layer.dtype,
414+
device=query_layer.device,
415+
)
416+
417+
if attention_mask is not None:
418+
seqlens_q = attention_mask.logical_not()[:, 0, 0, :].sum(dim=1)
419+
else:
420+
seqlens_q = torch.full(
421+
(batch_size,), max_seqlen_q, dtype=torch.int64, device=query_layer.device
422+
)
423+
424+
dropout_p = self.attention_dropout.p if self.training else 0.0
425+
with self.attention_dropout_ctx():
426+
for batch_id in range(batch_size):
427+
seqlen_q = int(seqlens_q[batch_id].item())
428+
if seqlen_q == 0:
429+
continue
430+
query = query_layer[:seqlen_q, batch_id].permute(1, 0, 2).unsqueeze(0)
431+
key = key_layer[:seqlen_q, batch_id].permute(1, 0, 2).unsqueeze(0)
432+
value = value_layer[:seqlen_q, batch_id].permute(1, 0, 2).unsqueeze(0)
433+
context_layer[batch_id, :, :seqlen_q, :] = F.scaled_dot_product_attention(
434+
query,
435+
key,
436+
value,
437+
dropout_p=dropout_p,
438+
is_causal=True,
439+
scale=scale,
440+
).squeeze(0)
441+
442+
return self._format_context(
443+
context_layer,
444+
q_format,
445+
max_seqlen_q,
446+
batch_size,
447+
cu_seqlens_q,
448+
)
449+
345450
def forward(
346451
self,
347452
_alibi_cache: Dict[str, Any],
@@ -434,22 +539,6 @@ def forward(
434539
max_seqlen_kv,
435540
self.attention_type,
436541
)
437-
attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = (
438-
dpa_utils.get_full_mask(
439-
max_seqlen_q,
440-
max_seqlen_kv,
441-
attn_mask_type=attn_mask_type,
442-
attention_mask=attention_mask,
443-
window_size=window_size,
444-
attention_type=self.attention_type,
445-
bottom_right_alignment=(
446-
attn_mask_type not in ["causal", "padding_causal"]
447-
if bottom_right_diagonal is None
448-
else bottom_right_diagonal
449-
),
450-
)
451-
)
452-
453542
apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16
454543

455544
# [b, h, sq, sk]
@@ -471,6 +560,46 @@ def forward(
471560
int(query_layer.shape[2] / value_layer.shape[2]), dim=2
472561
)
473562

563+
scale = self.softmax_scale
564+
if apply_qk_layer_scaling:
565+
scale /= self.layer_number
566+
567+
if self._use_varlen_sdpa(
568+
attn_mask_type,
569+
attention_mask,
570+
window_size,
571+
core_attention_bias_type,
572+
alibi_slopes,
573+
fp8,
574+
):
575+
return self._forward_varlen_sdpa(
576+
query_layer,
577+
key_layer,
578+
value_layer,
579+
q_format,
580+
batch_size,
581+
max_seqlen_q,
582+
cu_seqlens_q,
583+
attention_mask,
584+
self.softmax_scale,
585+
)
586+
587+
attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = (
588+
dpa_utils.get_full_mask(
589+
max_seqlen_q,
590+
max_seqlen_kv,
591+
attn_mask_type=attn_mask_type,
592+
attention_mask=attention_mask,
593+
window_size=window_size,
594+
attention_type=self.attention_type,
595+
bottom_right_alignment=(
596+
attn_mask_type not in ["causal", "padding_causal"]
597+
if bottom_right_diagonal is None
598+
else bottom_right_diagonal
599+
),
600+
)
601+
)
602+
474603
# preallocting result tensor: [b * h, sq, sk]
475604
matmul_result = torch.empty(
476605
output_size[0] * output_size[1],
@@ -480,10 +609,6 @@ def forward(
480609
device=torch.cuda.current_device(),
481610
)
482611

483-
scale = self.softmax_scale
484-
if apply_qk_layer_scaling:
485-
scale /= self.layer_number
486-
487612
if fp8:
488613
# get fp8 recipe for DPA
489614
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()

0 commit comments

Comments
 (0)