@@ -365,6 +365,111 @@ def fast_setattr(self, name: str, value: Any) -> None:
365365 """Fast attribute set for non-parameter fields."""
366366 self .__dict__ [name ] = value
367367
def _use_varlen_sdpa(
    self,
    attn_mask_type: str,
    attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
    window_size: Optional[Tuple[int, int]],
    core_attention_bias_type: str,
    alibi_slopes: Optional[torch.Tensor],
    fp8: bool,
) -> bool:
    """Decide whether the per-sequence PyTorch SDPA path may be taken.

    The path is eligible only when every condition below holds:

    * ``self.attention_type`` is ``"self"`` and ``attn_mask_type`` is
      ``"padding_causal"``;
    * ``window_size`` is ``None``, ``(-1, 0)`` or ``(-1, -1)``;
    * ``attention_mask`` is provided and is not a tuple;
    * there is no bias (``core_attention_bias_type == "no_bias"``), no ALiBi
      slopes, the dropout probability is ``0.0``, the softmax type is
      ``"vanilla"``, max-logit output is not requested and FP8 is off.

    Returns:
        ``True`` when SDPA can replace unfused attention without materializing
        masks, ``False`` otherwise.
    """
    padded_causal_self_attn = (
        self.attention_type == "self" and attn_mask_type == "padding_causal"
    )
    if not padded_causal_self_attn:
        return False

    if window_size not in (None, (-1, 0), (-1, -1)):
        return False

    # A single (non-tuple) padding mask is required to derive sequence lengths.
    if attention_mask is None or isinstance(attention_mask, tuple):
        return False

    # Any one of these features rules the SDPA path out.
    unsupported_feature = (
        core_attention_bias_type != "no_bias"
        or self.attention_dropout.p != 0.0
        or alibi_slopes is not None
        or self.softmax_type != "vanilla"
        or self.return_max_logit
        or fp8
    )
    return not unsupported_feature
396+
def _format_context(
    self,
    context_layer: torch.Tensor,
    q_format: str,
    max_seqlen_q: int,
    batch_size: int,
    cu_seqlens_q: Optional[torch.Tensor],
) -> torch.Tensor:
    """Rearrange a ``[b, h, sq, d]`` context tensor into the ``q_format`` layout.

    Args:
        context_layer: attention output laid out as ``[b, h, sq, d]``.
        q_format: target layout, one of ``"sbhd"``, ``"bshd"`` or ``"thd"``.
        max_seqlen_q: sequence dimension used for the ``sbhd``/``bshd`` views.
        batch_size: batch dimension used for the ``sbhd``/``bshd`` views.
        cu_seqlens_q: forwarded to ``ConvertBSHDtoTHD`` for the ``thd`` layout;
            unused otherwise.

    Returns:
        ``[sq, b, h*d]`` for ``sbhd``, ``[b, sq, h*d]`` for ``bshd``, or a 2-D
        tensor whose first dimension is that of the ``ConvertBSHDtoTHD`` output
        for ``thd``.

    Raises:
        ValueError: if ``q_format`` is none of the supported layouts.
    """
    if q_format not in ("sbhd", "bshd", "thd"):
        raise ValueError(f"Unsupported q_format = { q_format } !")

    if q_format == "sbhd":
        seq_first = context_layer.permute(2, 0, 1, 3).contiguous()
        return seq_first.view(max_seqlen_q, batch_size, -1)

    # Both remaining layouts start from a contiguous [b, sq, h, d] tensor.
    batch_first = context_layer.permute(0, 2, 1, 3).contiguous()
    if q_format == "bshd":
        return batch_first.view(batch_size, max_seqlen_q, -1)

    packed = ConvertBSHDtoTHD.apply(batch_first, cu_seqlens_q)
    return packed.view(packed.shape[0], -1)
417+
418+ def _forward_varlen_sdpa (
419+ self ,
420+ query_layer : torch .Tensor ,
421+ key_layer : torch .Tensor ,
422+ value_layer : torch .Tensor ,
423+ q_format : str ,
424+ batch_size : int ,
425+ max_seqlen_q : int ,
426+ cu_seqlens_q : Optional [torch .Tensor ],
427+ attention_mask : Optional [torch .Tensor ],
428+ scale : float ,
429+ ) -> torch .Tensor :
430+ """Run causal self-attention without expanding padding masks to [b, 1, sq, sk]."""
431+ context_layer = torch .zeros (
432+ batch_size ,
433+ query_layer .size (2 ),
434+ max_seqlen_q ,
435+ value_layer .size (3 ),
436+ dtype = query_layer .dtype ,
437+ device = query_layer .device ,
438+ )
439+
440+ if attention_mask is not None :
441+ seqlens_q = attention_mask .logical_not ()[:, 0 , 0 , :].sum (dim = 1 )
442+ else :
443+ seqlens_q = torch .full (
444+ (batch_size ,), max_seqlen_q , dtype = torch .int64 , device = query_layer .device
445+ )
446+
447+ dropout_p = self .attention_dropout .p if self .training else 0.0
448+ with self .attention_dropout_ctx ():
449+ for batch_id in range (batch_size ):
450+ seqlen_q = int (seqlens_q [batch_id ].item ())
451+ if seqlen_q == 0 :
452+ continue
453+ query = query_layer [:seqlen_q , batch_id ].permute (1 , 0 , 2 ).unsqueeze (0 )
454+ key = key_layer [:seqlen_q , batch_id ].permute (1 , 0 , 2 ).unsqueeze (0 )
455+ value = value_layer [:seqlen_q , batch_id ].permute (1 , 0 , 2 ).unsqueeze (0 )
456+ context_layer [batch_id , :, :seqlen_q , :] = F .scaled_dot_product_attention (
457+ query ,
458+ key ,
459+ value ,
460+ dropout_p = dropout_p ,
461+ is_causal = True ,
462+ scale = scale ,
463+ ).squeeze (0 )
464+
465+ return self ._format_context (
466+ context_layer ,
467+ q_format ,
468+ max_seqlen_q ,
469+ batch_size ,
470+ cu_seqlens_q ,
471+ )
472+
368473 def forward (
369474 self ,
370475 _alibi_cache : Dict [str , Any ],
@@ -457,22 +562,6 @@ def forward(
457562 max_seqlen_kv ,
458563 self .attention_type ,
459564 )
460- attn_mask_type , attention_mask , actual_seqlens_q , actual_seqlens_kv = (
461- dpa_utils .get_full_mask (
462- max_seqlen_q ,
463- max_seqlen_kv ,
464- attn_mask_type = attn_mask_type ,
465- attention_mask = attention_mask ,
466- window_size = window_size ,
467- attention_type = self .attention_type ,
468- bottom_right_alignment = (
469- attn_mask_type not in ["causal" , "padding_causal" ]
470- if bottom_right_diagonal is None
471- else bottom_right_diagonal
472- ),
473- )
474- )
475-
476565 apply_qk_layer_scaling = self .apply_qk_layer_scaling and key_layer .dtype == torch .float16
477566
478567 # [b, h, sq, sk]
@@ -494,6 +583,46 @@ def forward(
494583 int (query_layer .shape [2 ] / value_layer .shape [2 ]), dim = 2
495584 )
496585
586+ scale = self .softmax_scale
587+ if apply_qk_layer_scaling :
588+ scale /= self .layer_number
589+
590+ if self ._use_varlen_sdpa (
591+ attn_mask_type ,
592+ attention_mask ,
593+ window_size ,
594+ core_attention_bias_type ,
595+ alibi_slopes ,
596+ fp8 ,
597+ ):
598+ return self ._forward_varlen_sdpa (
599+ query_layer ,
600+ key_layer ,
601+ value_layer ,
602+ q_format ,
603+ batch_size ,
604+ max_seqlen_q ,
605+ cu_seqlens_q ,
606+ attention_mask ,
607+ self .softmax_scale ,
608+ )
609+
610+ attn_mask_type , attention_mask , actual_seqlens_q , actual_seqlens_kv = (
611+ dpa_utils .get_full_mask (
612+ max_seqlen_q ,
613+ max_seqlen_kv ,
614+ attn_mask_type = attn_mask_type ,
615+ attention_mask = attention_mask ,
616+ window_size = window_size ,
617+ attention_type = self .attention_type ,
618+ bottom_right_alignment = (
619+ attn_mask_type not in ["causal" , "padding_causal" ]
620+ if bottom_right_diagonal is None
621+ else bottom_right_diagonal
622+ ),
623+ )
624+ )
625+
497626 # preallocting result tensor: [b * h, sq, sk]
498627 matmul_result = torch .empty (
499628 output_size [0 ] * output_size [1 ],
@@ -503,10 +632,6 @@ def forward(
503632 device = torch .cuda .current_device (),
504633 )
505634
506- scale = self .softmax_scale
507- if apply_qk_layer_scaling :
508- scale /= self .layer_number
509-
510635 if fp8 :
511636 # get fp8 recipe for DPA
512637 fp8_recipe = FP8GlobalStateManager .get_fp8_recipe ()
0 commit comments