
Commit 38b4dd5
[attention backends] Add Neuron backend for context parallel
1 parent: e9c092d

1 file changed: src/diffusers/models/attention_dispatch.py (124 additions, 0 deletions)
@@ -246,6 +246,7 @@ class AttentionBackendName(str, Enum):
     _NATIVE_FLASH = "_native_flash"
     _NATIVE_MATH = "_native_math"
     _NATIVE_NPU = "_native_npu"
+    _NATIVE_NEURON = "_native_neuron"
     _NATIVE_XLA = "_native_xla"

     # `sageattention`
@@ -576,6 +577,9 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`."
         )

+    elif backend == AttentionBackendName._NATIVE_NEURON:
+        pass  # No extra dependency check needed; torch_neuronx overrides the ATen op at import time.
+
     elif backend == AttentionBackendName._NATIVE_XLA:
         if not _CAN_USE_XLA_ATTN:
             raise RuntimeError(
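
Note on the no-op branch above: the Neuron path relies on PyTorch's overrideable fused-SDPA hook rather than a package API that can be probed for a version. A minimal sketch of that assumption (the torch_neuronx import side effect and the "xla" device string are illustrative, not verified by this diff):

import torch
import torch_neuronx  # noqa: F401 -- assumed to register the fused-SDPA override at import time, per the comment above

# [B, H, S, D] layout, as the overrideable op expects; "xla" device assumed on a Neuron host
q = k = v = torch.randn(1, 8, 128, 64, device="xla")
out = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
    q, k, v, attn_bias=None, dropout_p=0.0, is_causal=False, return_debug_mask=False, scale=None
)[0]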
@@ -3218,6 +3222,126 @@ def _native_npu_attention(
     return out


+def _neuron_attention_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask=None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale=None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config=None,
+):
+    """Forward op for Neuron ring attention using _scaled_dot_product_fused_attention_overrideable.
+
+    Saves query, key, value, out, lse, philox_seed, philox_offset for backward.
+    Follows the same pattern as _cudnn_attention_forward_op.
+    """
+    import math
+
+    # [B, S, H, D] → [B, H, S, D], the layout the fused ATen op expects
+    q, k, v = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+    if scale is None:
+        scale = 1.0 / math.sqrt(q.shape[-1])
+
+    result = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
+        q, k, v,
+        attn_bias=attn_mask,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        return_debug_mask=False,
+        scale=scale,
+    )
+    out_bhsd, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, _ = result
+
+    if _save_ctx:
+        ctx.save_for_backward(q, k, v, out_bhsd, lse, philox_seed, philox_offset)
+        ctx.attn_mask = attn_mask
+        ctx.dropout_p = dropout_p
+        ctx.is_causal = is_causal
+        ctx.scale = scale
+        ctx.max_q = max_q
+        ctx.max_k = max_k
+
+    out = out_bhsd.permute(0, 2, 1, 3)  # [B, H, S, D] → [B, S, H, D]
+    # [B, H, S] → [B, S, H, 1] for broadcasting in ring accumulation against out [B, S, H, D]
+    lse_out = lse.permute(0, 2, 1).unsqueeze(-1)
+    return (out, lse_out) if return_lse else out
+
+
+def _neuron_attention_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    """Backward op for Neuron ring attention using _scaled_dot_product_fused_attention_overrideable_backward."""
+    q, k, v, out_bhsd, lse, philox_seed, philox_offset = ctx.saved_tensors
+
+    grad_out_bhsd = grad_out.permute(0, 2, 1, 3)  # [B, S, H, D] → [B, H, S, D]
+    grad_input_mask = [True, True, True, False]  # grads for q, k, v; not for attn_bias
+
+    attn_bias = ctx.attn_mask if ctx.attn_mask is not None else torch.zeros((1,), dtype=q.dtype, device=q.device)
+    cum_seq_q = cum_seq_k = torch.zeros((1,), dtype=torch.int32, device=q.device)
+
+    grad_q, grad_k, grad_v, _ = torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
+        grad_out_bhsd, q, k, v,
+        attn_bias,
+        grad_input_mask,
+        out_bhsd,
+        lse,
+        cum_seq_q, cum_seq_k,
+        ctx.max_q, ctx.max_k,
+        ctx.dropout_p,
+        ctx.is_causal,
+        philox_seed, philox_offset,
+        scale=ctx.scale,
+    )
+    # [B, H, S, D] → [B, S, H, D]
+    return grad_q.permute(0, 2, 1, 3), grad_k.permute(0, 2, 1, 3), grad_v.permute(0, 2, 1, 3)
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._NATIVE_NEURON,
+    constraints=[],
+    supports_context_parallel=True,
+)
+def _native_neuron_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask=None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale=None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _parallel_config=None,
+) -> torch.Tensor:
+    if _parallel_config is not None:
+        return _templated_context_parallel_attention(
+            query, key, value,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+            enable_gqa=enable_gqa,
+            return_lse=return_lse,
+            forward_op=_neuron_attention_forward_op,
+            backward_op=_neuron_attention_backward_op,
+            _parallel_config=_parallel_config,
+        )
+    # Non-ring path
+    return _neuron_attention_forward_op(
+        None, query, key, value,
+        attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal,
+        scale=scale, return_lse=return_lse, _save_ctx=False,
+    )
+
+
 # Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
 @_AttentionBackendRegistry.register(
     AttentionBackendName._NATIVE_XLA,
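
Why _neuron_attention_forward_op returns the LSE reshaped to [B, S, H, 1]: in the ring path, each step attends over one rank's K/V shard, and the partial outputs are merged in log-space against a running accumulator. A minimal sketch of that merge (the function name is illustrative; the actual accumulation happens inside _templated_context_parallel_attention):

import torch

def _merge_ring_step(out_acc, lse_acc, out_step, lse_step):
    # out_*: [B, S, H, D]; lse_*: [B, S, H, 1], broadcasting over the head dim D
    lse = torch.logaddexp(lse_acc, lse_step)
    out = out_acc * torch.exp(lse_acc - lse) + out_step * torch.exp(lse_step - lse)
    return out, lse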

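End-to-end, the new backend can then be selected like the other registered backends. A minimal usage sketch (the model id, set_attention_backend, and the ContextParallelConfig/enable_parallelism calls follow diffusers' existing context-parallel API and are assumptions, not part of this commit):

import torch
from diffusers import ContextParallelConfig, FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.set_attention_backend("_native_neuron")  # AttentionBackendName._NATIVE_NEURON
# Ring attention across 2 ranks; assumes torch.distributed is initialized (one process per Neuron core)
transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))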