@@ -352,6 +352,8 @@ class _HubKernelConfig:
     AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn2",
         function_attr="flash_attn_varlen_func",
+        wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_varlen_forward",
+        wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_varlen_backward",
         version=1,
     ),
     AttentionBackendName.SAGE_HUB: _HubKernelConfig(
@@ -636,6 +638,13 @@ def _prepare_for_flash_attn_or_sage_varlen(
     return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)


+def _unpad_to_padded(packed: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
+    """Scatter a packed `(nnz, ...)` tensor back to a padded `(batch_size, seq_len, ...)` tensor."""
+    output = torch.zeros(batch_size * seq_len, *packed.shape[1:], dtype=packed.dtype, device=packed.device)
+    output[indices] = packed
+    return output.view(batch_size, seq_len, *packed.shape[1:])
+
+
 def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
     """
     Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in
@@ -1292,6 +1301,178 @@ def _flash_attention_hub_backward_op(
     return grad_query, grad_key, grad_value


+def _flash_varlen_attention_hub_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: torch.Tensor | None = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: float | None = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config: "ParallelConfig" | None = None,
+    *,
+    window_size: tuple[int, int] = (-1, -1),
+):
+    if enable_gqa:
+        raise ValueError("`enable_gqa` is not yet supported for flash-attn varlen hub kernels.")
+
+    config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB]
+    wrapped_forward_fn = config.wrapped_forward_fn
+    wrapped_backward_fn = config.wrapped_backward_fn
+    if wrapped_forward_fn is None or wrapped_backward_fn is None:
+        raise RuntimeError(
+            "Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_forward` and "
+            "`_wrapped_flash_attn_varlen_backward` for context parallel execution."
+        )
+
+    if scale is None:
+        scale = query.shape[-1] ** (-0.5)
+
+    softcap = 0.0
+    alibi_slopes = None
+    deterministic = False
+    grad_enabled = any(x.requires_grad for x in (query, key, value))
+
+    if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
+        dropout_p = dropout_p if dropout_p > 0 else 1e-30
+
+    batch_size, seq_len_q, num_heads, _ = query.shape
+    _, seq_len_kv, _, _ = key.shape
+
+    if attn_mask is not None:
+        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
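+        # Derive per-sample KV lengths and cumulative sequence offsets from the (batch, seq_len_kv) boolean mask.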
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (_, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, query.device)
+        )
+        indices_k = attn_mask.flatten().nonzero(as_tuple=False).flatten()
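+        # Pack q/k/v into the (total_tokens, heads, head_dim) layout the varlen kernel expects;
+        # queries keep every position, keys/values keep only the unmasked positions selected by `indices_k`.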
+        query_packed = query.flatten(0, 1)
+        key_packed = key.reshape(-1, *key.shape[2:])[indices_k]
+        value_packed = value.reshape(-1, *value.shape[2:])[indices_k]
+        max_seqlen_q = seq_len_q
+    else:
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device)
+        )
+        query_packed = query.flatten(0, 1)
+        key_packed = key.flatten(0, 1)
+        value_packed = value.flatten(0, 1)
+        seqlens_k = None
+
+    with torch.set_grad_enabled(grad_enabled):
+        out_packed, lse, _, rng_state = wrapped_forward_fn(
+            query_packed,
+            key_packed,
+            value_packed,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            dropout_p,
+            scale,
+            is_causal,
+            window_size[0],
+            window_size[1],
+            softcap,
+            alibi_slopes,
+            return_lse,
+        )
+
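+    # Queries were packed densely (no query masking), so a plain view restores the padded (batch, seq_len_q, ...) layout.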
+    out = out_packed.view(batch_size, seq_len_q, *out_packed.shape[1:])
+
+    if _save_ctx:
+        ctx.save_for_backward(
+            query_packed, key_packed, value_packed, out_packed, lse, rng_state, cu_seqlens_q, cu_seqlens_k
+        )
+        ctx.seqlens_k = seqlens_k  # None if unmasked
+        ctx.indices_k = indices_k if attn_mask is not None else None
+        ctx.max_seqlen_q = max_seqlen_q
+        ctx.max_seqlen_k = max_seqlen_k
+        ctx.batch_size = batch_size
+        ctx.seq_len_q = seq_len_q
+        ctx.seq_len_kv = seq_len_kv
+        ctx.num_heads = num_heads
+        ctx.dropout_p = dropout_p
+        ctx.scale = scale
+        ctx.is_causal = is_causal
+        ctx.window_size = window_size
+        ctx.softcap = softcap
+        ctx.alibi_slopes = alibi_slopes
+        ctx.deterministic = deterministic
+
+    # (num_heads, batch_size * seq_len_q) -> (batch_size, seq_len_q, num_heads)
+    lse_sp = lse.view(num_heads, batch_size, seq_len_q).permute(1, 2, 0).contiguous()
+
+    return (out, lse_sp) if return_lse else out
+
+
+def _flash_varlen_attention_hub_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB]
+    wrapped_backward_fn = config.wrapped_backward_fn
+    if wrapped_backward_fn is None:
+        raise RuntimeError(
+            "Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_backward` "
+            "for context parallel execution."
+        )
+
+    query_packed, key_packed, value_packed, out_packed, lse, rng_state, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
+
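+    # grad_out arrives padded as (batch, seq_len_q, heads, head_dim); flatten it to the packed layout used in the forward.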
+    grad_out_packed = grad_out.flatten(0, 1)
+    grad_query, grad_key, grad_value = (
+        torch.empty_like(query_packed),
+        torch.empty_like(key_packed),
+        torch.empty_like(value_packed),
+    )
+
+    _ = wrapped_backward_fn(
+        grad_out_packed,
+        query_packed,
+        key_packed,
+        value_packed,
+        out_packed,
+        lse,
+        grad_query,
+        grad_key,
+        grad_value,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        ctx.max_seqlen_q,
+        ctx.max_seqlen_k,
+        ctx.dropout_p,
+        ctx.scale,
+        ctx.is_causal,
+        ctx.window_size[0],
+        ctx.window_size[1],
+        ctx.softcap,
+        ctx.alibi_slopes,
+        ctx.deterministic,
+        rng_state,
+    )
+
+    grad_query = grad_query.view(ctx.batch_size, ctx.seq_len_q, *grad_query.shape[1:])
+
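+    # Key/value grads are packed over unmasked positions; scatter them back to the padded shape when a mask was used.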
+    if ctx.seqlens_k is not None:
+        grad_key = _unpad_to_padded(grad_key, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv)
+        grad_value = _unpad_to_padded(grad_value, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv)
+    else:
+        grad_key = grad_key.view(ctx.batch_size, ctx.seq_len_kv, *grad_key.shape[1:])
+        grad_value = grad_value.view(ctx.batch_size, ctx.seq_len_kv, *grad_value.shape[1:])
+
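+    # Defensive trim to grad_out's head dim; a no-op when query/key/value share the same head dimension.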
+    grad_query = grad_query[..., : grad_out.shape[-1]]
+    grad_key = grad_key[..., : grad_out.shape[-1]]
+    grad_value = grad_value[..., : grad_out.shape[-1]]
+
+    return grad_query, grad_key, grad_value
+
+
 def _flash_attention_3_hub_forward_op(
     ctx: torch.autograd.function.FunctionCtx,
     query: torch.Tensor,
@@ -2557,7 +2738,7 @@ def _flash_attention_hub(
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN_HUB,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
-    supports_context_parallel=False,
+    supports_context_parallel=True,
 )
 def _flash_varlen_attention_hub(
     query: torch.Tensor,
@@ -2571,46 +2752,69 @@ def _flash_varlen_attention_hub(
     return_lse: bool = False,
     _parallel_config: "ParallelConfig" | None = None,
 ) -> torch.Tensor:
+    if _parallel_config is not None and _parallel_config.context_parallel_config.ring_degree > 1:
+        raise NotImplementedError("`ring_degree > 1` is not yet supported for the FLASH_VARLEN_HUB backend.")
+
+    lse = None
     batch_size, seq_len_q, _, _ = query.shape
     _, seq_len_kv, _, _ = key.shape

-    if attn_mask is not None:
-        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
-
-    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-        _prepare_for_flash_attn_or_sage_varlen(
-            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-        )
-    )
+    if _parallel_config is None:
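+        # No parallel config: build the varlen metadata locally and call the hub kernel directly.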
+        if attn_mask is not None:
+            attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+            (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+                _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, query.device)
+            )
+            indices_k = attn_mask.flatten().nonzero(as_tuple=False).flatten()
+            key_packed = key.reshape(-1, *key.shape[2:])[indices_k]
+            value_packed = value.reshape(-1, *value.shape[2:])[indices_k]
+        else:
+            (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+                _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device)
+            )
+            key_packed = key.flatten(0, 1)
+            value_packed = value.flatten(0, 1)

-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])
+        query_packed = query.flatten(0, 1)

-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
-
-    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn
-    out = func(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
-        cu_seqlens_q=cu_seqlens_q,
-        cu_seqlens_k=cu_seqlens_k,
-        max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_k,
-        dropout_p=dropout_p,
-        softmax_scale=scale,
-        causal=is_causal,
-        window_size=window_size,
-        return_attn_probs=return_lse,
-    )
-    out = out.unflatten(0, (batch_size, -1))
+        func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn
+        out = func(
+            q=query_packed,
+            k=key_packed,
+            v=value_packed,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            dropout_p=dropout_p,
+            softmax_scale=scale,
+            causal=is_causal,
+            window_size=window_size,
+            return_attn_probs=return_lse,
+        )
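+        # With return_attn_probs=True the kernel is expected to return (out, lse, ...) rather than just out.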
+        if return_lse:
+            out, lse, *_ = out
+        out = out.unflatten(0, (batch_size, -1))
+    else:
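+        # Context-parallel path: route through the templated attention using the varlen forward/backward ops defined above.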
+        forward_op = functools.partial(_flash_varlen_attention_hub_forward_op, window_size=window_size)
+        out = _templated_context_parallel_attention(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            False,
+            return_lse,
+            forward_op=forward_op,
+            backward_op=_flash_varlen_attention_hub_backward_op,
+            _parallel_config=_parallel_config,
+        )
+        if return_lse:
+            out, lse = out

-    return out
+    return (out, lse) if return_lse else out


 @_AttentionBackendRegistry.register(