@@ -277,8 +277,8 @@ class _HubKernelConfig:
277277 repo_id = "kernels-community/flash-attn2" ,
278278 function_attr = "flash_attn_func" ,
279279 revision = None ,
280- wrapped_forward_attr = "_wrapped_flash_attn_forward" ,
281- wrapped_backward_attr = "_wrapped_flash_attn_backward" ,
280+ wrapped_forward_attr = "flash_attn_interface. _wrapped_flash_attn_forward" ,
281+ wrapped_backward_attr = "flash_attn_interface. _wrapped_flash_attn_backward" ,
282282 ),
283283 AttentionBackendName .FLASH_VARLEN_HUB : _HubKernelConfig (
284284 repo_id = "kernels-community/flash-attn2" , function_attr = "flash_attn_varlen_func" , revision = None
@@ -602,27 +602,39 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
602602
603603
604604# ===== Helpers for downloading kernels =====
def _resolve_kernel_attr(module, attr_path: str):
    """Resolve a (possibly dotted) attribute path against a kernel module.

    The path is walked one segment at a time starting from ``module``, so a
    value such as ``"flash_attn_interface._wrapped_flash_attn_forward"``
    reaches an attribute that lives on a nested object rather than directly
    on the top-level module.

    Args:
        module: Object the lookup starts from.
        attr_path: A single attribute name or a dot-separated chain of
            attribute names. Whitespace surrounding a segment is ignored.

    Returns:
        The object found at the end of ``attr_path``.

    Raises:
        AttributeError: If any segment of ``attr_path`` cannot be found. The
            message names the starting module and the full requested path.
    """
    # Unique sentinel: lets one getattr() call per segment distinguish
    # "attribute missing" from "attribute present but None/falsy".
    missing = object()

    target = module
    for segment in attr_path.split("."):
        # Strip so that a path written as "a. b" resolves the same as "a.b".
        target = getattr(target, segment.strip(), missing)
        if target is missing:
            # Do not assume the start object has __name__; building the error
            # message must never raise a different, unrelated AttributeError.
            module_name = getattr(module, "__name__", None) or repr(module)
            raise AttributeError(
                f"Kernel module '{module_name}' does not define attribute path '{attr_path}'."
            )
    return target
612+
613+
605614def _maybe_download_kernel_for_backend (backend : AttentionBackendName ) -> None :
606615 if backend not in _HUB_KERNELS_REGISTRY :
607616 return
608617 config = _HUB_KERNELS_REGISTRY [backend ]
609618
610- if config .kernel_fn is not None :
619+ needs_kernel = config .kernel_fn is None
620+ needs_wrapped_forward = config .wrapped_forward_attr is not None and config .wrapped_forward_fn is None
621+ needs_wrapped_backward = config .wrapped_backward_attr is not None and config .wrapped_backward_fn is None
622+
623+ if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward ):
611624 return
612625
613626 try :
614627 from kernels import get_kernel
615628
616629 kernel_module = get_kernel (config .repo_id , revision = config .revision )
617- kernel_func = getattr (kernel_module , config .function_attr )
618- # Cache the downloaded kernel function in the config object
619- config .kernel_fn = kernel_func
630+ if needs_kernel :
631+ config .kernel_fn = _resolve_kernel_attr (kernel_module , config .function_attr )
620632
621- if config . wrapped_forward_attr is not None and config . wrapped_forward_attr is not None :
622- wrapped_forward_fn = getattr (kernel_module , config .wrapped_forward_attr )
623- wrapped_backward_fn = getattr ( kernel_module , config . wrapped_backward_attr )
624- config . wrapped_forward_fn = wrapped_forward_fn
625- config .wrapped_backward_fn = wrapped_backward_fn
633+ if needs_wrapped_forward :
634+ config . wrapped_forward_fn = _resolve_kernel_attr (kernel_module , config .wrapped_forward_attr )
635+
636+ if needs_wrapped_backward :
637+ config .wrapped_backward_fn = _resolve_kernel_attr ( kernel_module , config . wrapped_backward_attr )
626638
627639 except Exception as e :
628640 logger .error (f"An error occurred while fetching kernel '{ config .repo_id } ' from the Hub: { e } " )
@@ -1307,6 +1319,7 @@ def _sage_attention_hub_forward_op(
13071319
13081320 return (out , lse ) if return_lse else out
13091321
1322+
13101323# ===== Context parallel =====
13111324
13121325
0 commit comments