@@ -246,6 +246,7 @@ class AttentionBackendName(str, Enum):
     _NATIVE_FLASH = "_native_flash"
     _NATIVE_MATH = "_native_math"
     _NATIVE_NPU = "_native_npu"
+    _NATIVE_NEURON = "_native_neuron"
     _NATIVE_XLA = "_native_xla"
 
     # `sageattention`
@@ -576,6 +577,9 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
                 f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`."
             )
 
+    elif backend == AttentionBackendName._NATIVE_NEURON:
+        pass  # No extra dependency check needed; torch_neuronx overrides the ATen op at import time.
+
     elif backend == AttentionBackendName._NATIVE_XLA:
         if not _CAN_USE_XLA_ATTN:
            raise RuntimeError(
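As the comment in the new branch notes, importing `torch_neuronx` is what installs the fused-attention ATen override, so the requirements check has nothing to probe. A rough usage sketch, not part of this patch: it assumes `torch_neuronx` is installed, that `set_attention_backend` (the backend-switching helper on diffusers models) accepts the new `"_native_neuron"` value, and uses a placeholder checkpoint; the `transformer` attribute is illustrative.

    import torch
    import torch_neuronx  # noqa: F401 -- importing registers the fused-attention ATen override
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("<checkpoint>", torch_dtype=torch.bfloat16)
    # Assumption: the model's set_attention_backend helper accepts the new enum value
    # and routes attention through _native_neuron_attention.
    pipe.transformer.set_attention_backend("_native_neuron")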
@@ -3218,6 +3222,126 @@ def _native_npu_attention(
     return out
 
 
+def _neuron_attention_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask=None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale=None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config=None,
+):
+    """Forward op for Neuron ring attention using _scaled_dot_product_fused_attention_overrideable.
+
+    Saves query, key, value, out, lse, philox_seed, philox_offset for backward.
+    Follows the same pattern as _cudnn_attention_forward_op.
+    """
+    import math
+
+    q, k, v = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+    if scale is None:
+        scale = 1.0 / math.sqrt(q.shape[-1])
+
+    result = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
+        q, k, v,
+        attn_bias=attn_mask,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        return_debug_mask=False,
+        scale=scale,
+    )
+    out_bhsd, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, _ = result
+
+    if _save_ctx:
+        ctx.save_for_backward(q, k, v, out_bhsd, lse, philox_seed, philox_offset)
+        ctx.attn_mask = attn_mask
+        ctx.dropout_p = dropout_p
+        ctx.is_causal = is_causal
+        ctx.scale = scale
+        ctx.max_q = max_q
+        ctx.max_k = max_k
+
+    out = out_bhsd.permute(0, 2, 1, 3)  # [B, S, H, D]
+    # [B, H, S] → [B, S, H, 1] for broadcasting in ring accumulation against out [B, S, H, D]
+    lse_out = lse.permute(0, 2, 1).unsqueeze(-1)
+    return (out, lse_out) if return_lse else out
+
+
+def _neuron_attention_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    """Backward op for Neuron ring attention using _scaled_dot_product_fused_attention_overrideable_backward."""
+    q, k, v, out_bhsd, lse, philox_seed, philox_offset = ctx.saved_tensors
+
+    grad_out_bhsd = grad_out.permute(0, 2, 1, 3)  # [B, S, H, D] → [B, H, S, D]
+    grad_input_mask = [True, True, True, False]  # grad for q, k, v; not attn_bias
+
+    attn_bias = ctx.attn_mask if ctx.attn_mask is not None else torch.zeros((1,), dtype=q.dtype, device=q.device)
+    cum_seq_q = cum_seq_k = torch.zeros((1,), dtype=torch.int32, device=q.device)
+
+    grad_q, grad_k, grad_v, _ = torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
+        grad_out_bhsd, q, k, v,
+        attn_bias,
+        grad_input_mask,
+        out_bhsd,
+        lse,
+        cum_seq_q, cum_seq_k,
+        ctx.max_q, ctx.max_k,
+        ctx.dropout_p,
+        ctx.is_causal,
+        philox_seed, philox_offset,
+        scale=ctx.scale,
+    )
+    # [B, H, S, D] → [B, S, H, D]
+    return grad_q.permute(0, 2, 1, 3), grad_k.permute(0, 2, 1, 3), grad_v.permute(0, 2, 1, 3)
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._NATIVE_NEURON,
+    constraints=[],
+    supports_context_parallel=True,
+)
+def _native_neuron_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask=None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale=None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _parallel_config=None,
+) -> torch.Tensor:
+    if _parallel_config is not None:
+        return _templated_context_parallel_attention(
+            query, key, value,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+            enable_gqa=enable_gqa,
+            return_lse=return_lse,
+            forward_op=_neuron_attention_forward_op,
+            backward_op=_neuron_attention_backward_op,
+            _parallel_config=_parallel_config,
+        )
+    # Non-ring path
+    return _neuron_attention_forward_op(
+        None, query, key, value,
+        attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal,
+        scale=scale, return_lse=return_lse, _save_ctx=False,
+    )
+
+
 # Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
 @_AttentionBackendRegistry.register(
     AttentionBackendName._NATIVE_XLA,
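A note on the `[B, S, H, 1]` log-sum-exp layout that `_neuron_attention_forward_op` returns: ring attention sees one key/value block per step, and the per-block outputs are merged using their LSEs, which is why the LSE must broadcast against `out`. A minimal standalone sketch of that merge step (not part of this patch; `merge_partial_attention` is a hypothetical helper shown only for illustration):

    import torch

    def merge_partial_attention(out_a, lse_a, out_b, lse_b):
        # out_*: [B, S, H, D] partial attention outputs over disjoint KV blocks.
        # lse_*: [B, S, H, 1] log-sum-exp of the corresponding attention logits,
        #        i.e. the layout produced by _neuron_attention_forward_op.
        lse = torch.logaddexp(lse_a, lse_b)  # combined softmax normalizer, still [B, S, H, 1]
        out = torch.exp(lse_a - lse) * out_a + torch.exp(lse_b - lse) * out_b
        return out, lse  # same layouts as the inputs, ready for the next ring step

Keeping the normalizer in log space keeps this merge numerically stable regardless of how many blocks the ring visits.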