@@ -329,7 +329,11 @@ class _HubKernelConfig:
329329_HUB_KERNELS_REGISTRY : dict ["AttentionBackendName" , _HubKernelConfig ] = {
330330 # TODO: temporary revision for now. Remove when merged upstream into `main`.
331331 AttentionBackendName ._FLASH_3_HUB : _HubKernelConfig (
332- repo_id = "kernels-community/flash-attn3" , function_attr = "flash_attn_func" , revision = "fake-ops-return-probs"
332+ repo_id = "kernels-community/flash-attn3" ,
333+ function_attr = "flash_attn_func" ,
334+ revision = "fake-ops-return-probs" ,
335+ wrapped_forward_attr = "flash_attn_interface._flash_attn_forward" ,
336+ wrapped_backward_attr = "flash_attn_interface._flash_attn_backward" ,
333337 ),
334338 AttentionBackendName ._FLASH_3_VARLEN_HUB : _HubKernelConfig (
335339 repo_id = "kernels-community/flash-attn3" ,
@@ -1290,36 +1294,62 @@ def _flash_attention_3_hub_forward_op(
12901294 if enable_gqa :
12911295 raise ValueError ("`enable_gqa` is not yet supported for flash-attn 3 hub kernels." )
12921296
1293- func = _HUB_KERNELS_REGISTRY [AttentionBackendName ._FLASH_3_HUB ].kernel_fn
1294- out = func (
1295- q = query ,
1296- k = key ,
1297- v = value ,
1298- softmax_scale = scale ,
1297+ config = _HUB_KERNELS_REGISTRY [AttentionBackendName ._FLASH_3_HUB ]
1298+ wrapped_forward_fn = config .wrapped_forward_fn
1299+ if wrapped_forward_fn is None :
1300+ raise RuntimeError (
1301+ "Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_forward` "
1302+ "for context parallel execution."
1303+ )
1304+
1305+ if scale is None :
1306+ scale = query .shape [- 1 ] ** (- 0.5 )
1307+
1308+ out , softmax_lse , * _ = wrapped_forward_fn (
1309+ query ,
1310+ key ,
1311+ value ,
1312+ None ,
1313+ None , # k_new, v_new
1314+ None , # qv
1315+ None , # out
1316+ None ,
1317+ None ,
1318+ None , # cu_seqlens_q/k/k_new
1319+ None ,
1320+ None , # seqused_q/k
1321+ None ,
1322+ None , # max_seqlen_q/k
1323+ None ,
1324+ None ,
1325+ None , # page_table, kv_batch_idx, leftpad_k
1326+ None ,
1327+ None ,
1328+ None , # rotary_cos/sin, seqlens_rotary
1329+ None ,
1330+ None ,
1331+ None , # q_descale, k_descale, v_descale
1332+ scale ,
12991333 causal = is_causal ,
1300- qv = None ,
1301- q_descale = None ,
1302- k_descale = None ,
1303- v_descale = None ,
1304- window_size = window_size ,
1334+ window_size_left = window_size [0 ],
1335+ window_size_right = window_size [1 ],
1336+ attention_chunk = 0 ,
13051337 softcap = softcap ,
13061338 num_splits = num_splits ,
13071339 pack_gqa = pack_gqa ,
1308- deterministic = deterministic ,
13091340 sm_margin = sm_margin ,
1310- return_attn_probs = return_lse ,
13111341 )
13121342
1313- lse = None
1314- if return_lse :
1315- out , lse = out
1316- lse = lse .permute (0 , 2 , 1 ).contiguous ()
1343+ lse = softmax_lse .permute (0 , 2 , 1 ).contiguous () if return_lse else None
13171344
13181345 if _save_ctx :
1319- ctx .save_for_backward (query , key , value )
1346+ ctx .save_for_backward (query , key , value , out , softmax_lse )
13201347 ctx .scale = scale
13211348 ctx .is_causal = is_causal
1322- ctx ._hub_kernel = func
1349+ ctx .window_size = window_size
1350+ ctx .softcap = softcap
1351+ ctx .deterministic = deterministic
1352+ ctx .sm_margin = sm_margin
13231353
13241354 return (out , lse ) if return_lse else out
13251355
@@ -1328,55 +1358,50 @@ def _flash_attention_3_hub_backward_op(
13281358 ctx : torch .autograd .function .FunctionCtx ,
13291359 grad_out : torch .Tensor ,
13301360 * args ,
1331- window_size : tuple [int , int ] = (- 1 , - 1 ),
1332- softcap : float = 0.0 ,
1333- num_splits : int = 1 ,
1334- pack_gqa : bool | None = None ,
1335- deterministic : bool = False ,
1336- sm_margin : int = 0 ,
1361+ ** kwargs ,
13371362):
1338- query , key , value = ctx .saved_tensors
1339- kernel_fn = ctx ._hub_kernel
1340- # NOTE: Unlike the FA2 hub kernel, the FA3 hub kernel does not expose separate wrapped forward/backward
1341- # primitives (no `wrapped_forward_attr`/`wrapped_backward_attr` in its `_HubKernelConfig`). We
1342- # therefore rerun the forward pass under `torch.enable_grad()` and differentiate through it with
1343- # `torch.autograd.grad()`. This is a second forward pass during backward; it can be avoided once
1344- # the FA3 hub exposes a dedicated fused backward kernel (analogous to `_wrapped_flash_attn_backward`
1345- # in the FA2 hub), at which point this can be refactored to match `_flash_attention_hub_backward_op`.
1346- with torch .enable_grad ():
1347- query_r = query .detach ().requires_grad_ (True )
1348- key_r = key .detach ().requires_grad_ (True )
1349- value_r = value .detach ().requires_grad_ (True )
1350-
1351- out = kernel_fn (
1352- q = query_r ,
1353- k = key_r ,
1354- v = value_r ,
1355- softmax_scale = ctx .scale ,
1356- causal = ctx .is_causal ,
1357- qv = None ,
1358- q_descale = None ,
1359- k_descale = None ,
1360- v_descale = None ,
1361- window_size = window_size ,
1362- softcap = softcap ,
1363- num_splits = num_splits ,
1364- pack_gqa = pack_gqa ,
1365- deterministic = deterministic ,
1366- sm_margin = sm_margin ,
1367- return_attn_probs = False ,
1368- )
1369- if isinstance (out , tuple ):
1370- out = out [0 ]
1371-
1372- grad_query , grad_key , grad_value = torch .autograd .grad (
1373- out ,
1374- (query_r , key_r , value_r ),
1375- grad_out ,
1376- retain_graph = False ,
1377- allow_unused = False ,
1363+ config = _HUB_KERNELS_REGISTRY [AttentionBackendName ._FLASH_3_HUB ]
1364+ wrapped_backward_fn = config .wrapped_backward_fn
1365+ if wrapped_backward_fn is None :
1366+ raise RuntimeError (
1367+ "Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_backward` "
1368+ "for context parallel execution."
13781369 )
13791370
1371+ query , key , value , out , softmax_lse = ctx .saved_tensors
1372+ grad_query = torch .empty_like (query )
1373+ grad_key = torch .empty_like (key )
1374+ grad_value = torch .empty_like (value )
1375+
1376+ wrapped_backward_fn (
1377+ grad_out ,
1378+ query ,
1379+ key ,
1380+ value ,
1381+ out ,
1382+ softmax_lse ,
1383+ None ,
1384+ None , # cu_seqlens_q, cu_seqlens_k
1385+ None ,
1386+ None , # seqused_q, seqused_k
1387+ None ,
1388+ None , # max_seqlen_q, max_seqlen_k
1389+ grad_query ,
1390+ grad_key ,
1391+ grad_value ,
1392+ ctx .scale ,
1393+ ctx .is_causal ,
1394+ ctx .window_size [0 ],
1395+ ctx .window_size [1 ],
1396+ ctx .softcap ,
1397+ ctx .deterministic ,
1398+ ctx .sm_margin ,
1399+ )
1400+
1401+ grad_query = grad_query [..., : grad_out .shape [- 1 ]]
1402+ grad_key = grad_key [..., : grad_out .shape [- 1 ]]
1403+ grad_value = grad_value [..., : grad_out .shape [- 1 ]]
1404+
13801405 return grad_query , grad_key , grad_value
13811406
13821407
0 commit comments