Skip to content

Commit aa082fa

Browse files
committed
fix: handle non-zero storage_offset in sage attention for ring compatibility
1 parent c8c8401 commit aa082fa

1 file changed

Lines changed: 12 additions & 0 deletions

File tree

src/diffusers/models/attention_dispatch.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1451,6 +1451,12 @@ def _sage_attention_forward_op(
1451 1451
if enable_gqa:
1452 1452
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
1453 1453

1454 +
# In ring attention, key/value views from chunked AsyncCollectiveTensor may have invalid data_ptr.
1455 +
if key.storage_offset() > 0:
1456 +
key = key.clone()
1457 +
if value.storage_offset() > 0:
1458 +
value = value.clone()
1459 +
1454 1460
out = sageattn(
1455 1461
q=query,
1456 1462
k=key,
@@ -1489,6 +1495,12 @@ def _sage_attention_hub_forward_op(
1489 1495
if enable_gqa:
1490 1496
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
1491 1497

1498 +
# In ring attention, key/value views from chunked AsyncCollectiveTensor may have invalid data_ptr.
1499 +
if key.storage_offset() > 0:
1500 +
key = key.clone()
1501 +
if value.storage_offset() > 0:
1502 +
value = value.clone()
1503 +
1492 1504
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
1493 1505
out = func(
1494 1506
q=query,

0 commit comments

Comments
 (0)