@@ -342,6 +342,111 @@ def fast_setattr(self, name: str, value: Any) -> None:
342342 """Fast attribute set for non-parameter fields."""
343343 self .__dict__ [name ] = value
344344
def _use_varlen_sdpa(
    self,
    attn_mask_type: str,
    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
    window_size: Optional[Tuple[int, int]],
    core_attention_bias_type: str,
    alibi_slopes: Optional[torch.Tensor],
    fp8: bool,
) -> bool:
    """Tell whether the per-sequence PyTorch SDPA path may replace unfused attention.

    The path is selected only when every condition below holds:

    * ``self.attention_type`` is ``"self"`` and ``attn_mask_type`` is
      ``"padding_causal"``;
    * ``window_size`` is ``None``, ``(-1, 0)`` or ``(-1, -1)``;
    * ``attention_mask`` is given and is a single object, not a tuple;
    * ``core_attention_bias_type`` is ``"no_bias"`` and ``alibi_slopes`` is ``None``;
    * ``self.attention_dropout.p`` is zero, ``self.softmax_type`` is ``"vanilla"``,
      ``self.return_max_logit`` is falsy and ``fp8`` is off.

    Returns:
        ``True`` when all conditions hold, ``False`` otherwise.
    """
    # Requirements on the call's inputs; evaluated left to right with short-circuiting.
    inputs_supported = (
        self.attention_type == "self"
        and attn_mask_type == "padding_causal"
        and window_size in [None, (-1, 0), (-1, -1)]
        and attention_mask is not None
        and not isinstance(attention_mask, tuple)
    )
    if not inputs_supported:
        return False

    # Requirements on features that this path does not handle.
    if core_attention_bias_type != "no_bias":
        return False
    if self.attention_dropout.p != 0.0:
        return False
    if alibi_slopes is not None:
        return False
    if self.softmax_type != "vanilla":
        return False
    if self.return_max_logit:
        return False
    return not fp8
373+
def _format_context(
    self,
    context_layer: torch.Tensor,
    q_format: str,
    max_seqlen_q: int,
    batch_size: int,
    cu_seqlens_q: Optional[torch.Tensor],
) -> torch.Tensor:
    """Rearrange a [b, h, sq, d] context tensor into the layout named by ``q_format``.

    Args:
        context_layer: attention output laid out as [b, h, sq, d].
        q_format: one of ``"sbhd"``, ``"bshd"`` or ``"thd"``.
        max_seqlen_q: sequence length used for the output view.
        batch_size: batch size used for the output view.
        cu_seqlens_q: passed to ``ConvertBSHDtoTHD`` for ``"thd"``; unused otherwise.

    Returns:
        ``"sbhd"`` -> [sq, b, h * d]; ``"bshd"`` -> [b, sq, h * d];
        ``"thd"`` -> the ``ConvertBSHDtoTHD`` result flattened to two dimensions.

    Raises:
        ValueError: if ``q_format`` is none of the supported layouts.
    """
    if q_format == "sbhd":
        seq_major = context_layer.permute(2, 0, 1, 3).contiguous()
        return seq_major.view(max_seqlen_q, batch_size, -1)

    # Reject anything unknown before touching the tensor.
    if q_format not in ("bshd", "thd"):
        raise ValueError(f"Unsupported q_format = {q_format}!")

    # Both remaining layouts start from a batch-major [b, sq, h, d] tensor.
    batch_major = context_layer.permute(0, 2, 1, 3).contiguous()
    if q_format == "bshd":
        return batch_major.view(batch_size, max_seqlen_q, -1)

    packed = ConvertBSHDtoTHD.apply(batch_major, cu_seqlens_q)
    return packed.view(packed.shape[0], -1)
394+
395+ def _forward_varlen_sdpa (
396+ self ,
397+ query_layer : torch .Tensor ,
398+ key_layer : torch .Tensor ,
399+ value_layer : torch .Tensor ,
400+ q_format : str ,
401+ batch_size : int ,
402+ max_seqlen_q : int ,
403+ cu_seqlens_q : Optional [torch .Tensor ],
404+ attention_mask : Optional [torch .Tensor ],
405+ scale : float ,
406+ ) -> torch .Tensor :
407+ """Run causal self-attention without expanding padding masks to [b, 1, sq, sk]."""
408+ context_layer = torch .zeros (
409+ batch_size ,
410+ query_layer .size (2 ),
411+ max_seqlen_q ,
412+ value_layer .size (3 ),
413+ dtype = query_layer .dtype ,
414+ device = query_layer .device ,
415+ )
416+
417+ if attention_mask is not None :
418+ seqlens_q = attention_mask .logical_not ()[:, 0 , 0 , :].sum (dim = 1 )
419+ else :
420+ seqlens_q = torch .full (
421+ (batch_size ,), max_seqlen_q , dtype = torch .int64 , device = query_layer .device
422+ )
423+
424+ dropout_p = self .attention_dropout .p if self .training else 0.0
425+ with self .attention_dropout_ctx ():
426+ for batch_id in range (batch_size ):
427+ seqlen_q = int (seqlens_q [batch_id ].item ())
428+ if seqlen_q == 0 :
429+ continue
430+ query = query_layer [:seqlen_q , batch_id ].permute (1 , 0 , 2 ).unsqueeze (0 )
431+ key = key_layer [:seqlen_q , batch_id ].permute (1 , 0 , 2 ).unsqueeze (0 )
432+ value = value_layer [:seqlen_q , batch_id ].permute (1 , 0 , 2 ).unsqueeze (0 )
433+ context_layer [batch_id , :, :seqlen_q , :] = F .scaled_dot_product_attention (
434+ query ,
435+ key ,
436+ value ,
437+ dropout_p = dropout_p ,
438+ is_causal = True ,
439+ scale = scale ,
440+ ).squeeze (0 )
441+
442+ return self ._format_context (
443+ context_layer ,
444+ q_format ,
445+ max_seqlen_q ,
446+ batch_size ,
447+ cu_seqlens_q ,
448+ )
449+
345450 def forward (
346451 self ,
347452 _alibi_cache : Dict [str , Any ],
@@ -434,22 +539,6 @@ def forward(
434539 max_seqlen_kv ,
435540 self .attention_type ,
436541 )
437- attn_mask_type , attention_mask , actual_seqlens_q , actual_seqlens_kv = (
438- dpa_utils .get_full_mask (
439- max_seqlen_q ,
440- max_seqlen_kv ,
441- attn_mask_type = attn_mask_type ,
442- attention_mask = attention_mask ,
443- window_size = window_size ,
444- attention_type = self .attention_type ,
445- bottom_right_alignment = (
446- attn_mask_type not in ["causal" , "padding_causal" ]
447- if bottom_right_diagonal is None
448- else bottom_right_diagonal
449- ),
450- )
451- )
452-
453542 apply_qk_layer_scaling = self .apply_qk_layer_scaling and key_layer .dtype == torch .float16
454543
455544 # [b, h, sq, sk]
@@ -471,6 +560,46 @@ def forward(
471560 int (query_layer .shape [2 ] / value_layer .shape [2 ]), dim = 2
472561 )
473562
563+ scale = self .softmax_scale
564+ if apply_qk_layer_scaling :
565+ scale /= self .layer_number
566+
567+ if self ._use_varlen_sdpa (
568+ attn_mask_type ,
569+ attention_mask ,
570+ window_size ,
571+ core_attention_bias_type ,
572+ alibi_slopes ,
573+ fp8 ,
574+ ):
575+ return self ._forward_varlen_sdpa (
576+ query_layer ,
577+ key_layer ,
578+ value_layer ,
579+ q_format ,
580+ batch_size ,
581+ max_seqlen_q ,
582+ cu_seqlens_q ,
583+ attention_mask ,
584+ self .softmax_scale ,
585+ )
586+
587+ attn_mask_type , attention_mask , actual_seqlens_q , actual_seqlens_kv = (
588+ dpa_utils .get_full_mask (
589+ max_seqlen_q ,
590+ max_seqlen_kv ,
591+ attn_mask_type = attn_mask_type ,
592+ attention_mask = attention_mask ,
593+ window_size = window_size ,
594+ attention_type = self .attention_type ,
595+ bottom_right_alignment = (
596+ attn_mask_type not in ["causal" , "padding_causal" ]
597+ if bottom_right_diagonal is None
598+ else bottom_right_diagonal
599+ ),
600+ )
601+ )
602+
474603 # preallocating result tensor: [b * h, sq, sk]
475604 matmul_result = torch .empty (
476605 output_size [0 ] * output_size [1 ],
@@ -480,10 +609,6 @@ def forward(
480609 device = torch .cuda .current_device (),
481610 )
482611
483- scale = self .softmax_scale
484- if apply_qk_layer_scaling :
485- scale /= self .layer_number
486-
487612 if fp8 :
488613 # get fp8 recipe for DPA
489614 fp8_recipe = FP8GlobalStateManager .get_fp8_recipe ()
0 commit comments