
Commit ee4b022

Merge pull request #392 from modelscope/sage_attention
Sage attention
2 parents 3db824c + da8e1fe

File tree

2 files changed: +19 −0 lines changed

diffsynth/models/wan_video_dit.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -18,6 +18,12 @@
 except ModuleNotFoundError:
     FLASH_ATTN_2_AVAILABLE = False
 
+try:
+    from sageattention import sageattn
+    SAGE_ATTN_AVAILABLE = True
+except ModuleNotFoundError:
+    SAGE_ATTN_AVAILABLE = False
+
 import warnings
 
 
@@ -127,6 +133,12 @@ def half(x):
             causal=causal,
             window_size=window_size,
             deterministic=deterministic).unflatten(0, (b, lq))
+    elif SAGE_ATTN_AVAILABLE:
+        q = q.unsqueeze(0).transpose(1, 2).to(dtype)
+        k = k.unsqueeze(0).transpose(1, 2).to(dtype)
+        v = v.unsqueeze(0).transpose(1, 2).to(dtype)
+        x = sageattn(q, k, v, dropout_p=dropout_p, is_causal=causal)
+        x = x.transpose(1, 2).contiguous()
     else:
         q = q.unsqueeze(0).transpose(1, 2).to(dtype)
         k = k.unsqueeze(0).transpose(1, 2).to(dtype)
```
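The change follows a common optional-dependency pattern: attempt the import once at module load, record the result in a module-level flag, and branch on the flags at call time so the fastest installed backend wins. A minimal, dependency-free sketch of that dispatch logic (the flag names mirror the diff; `pick_attention_backend` is an illustrative helper, not part of DiffSynth's API):

```python
# Guarded optional imports: each flag is resolved once, at module load time.
try:
    import flash_attn  # optional: Flash Attention 2 kernels
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn  # optional: SageAttention kernels
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False


def pick_attention_backend() -> str:
    """Return the highest-priority backend that imported successfully.

    Mirrors the if/elif chain in wan_video_dit.py: Flash Attention is
    preferred, then SageAttention, with plain PyTorch as the fallback.
    """
    if FLASH_ATTN_2_AVAILABLE:
        return "flash_attn_2"
    elif SAGE_ATTN_AVAILABLE:
        return "sage_attention"
    else:
        return "torch_sdpa"
```

Because the flags are computed at import time, the per-call cost of the dispatch is a couple of boolean checks, and an uninstalled backend never causes an `ImportError` at inference time.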

examples/wanvideo/README.md

Lines changed: 7 additions & 0 deletions
````diff
@@ -10,6 +10,13 @@ cd DiffSynth-Studio
 pip install -e .
 ```
 
+Wan-Video supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority.
+
+* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
+* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
+* [Sage Attention](https://github.com/thu-ml/SageAttention)
+* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
+
 ## Inference
 
 ### Wan-Video-1.3B-T2V
````
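One detail worth noting from the `wan_video_dit.py` hunk: the SageAttention branch converts q/k/v from a `(sequence, heads, head_dim)` layout to a batched `(batch, heads, sequence, head_dim)` layout via `unsqueeze(0).transpose(1, 2)`, then transposes the output back. The shape bookkeeping can be sketched without torch using plain shape tuples (function names here are illustrative, not from the codebase):

```python
def to_sage_layout(shape):
    """(seq, heads, dim) -> (1, heads, seq, dim).

    Models q.unsqueeze(0).transpose(1, 2): add a batch axis of 1,
    then swap the sequence and heads axes.
    """
    seq, heads, dim = shape
    return (1, heads, seq, dim)


def from_sage_layout(shape):
    """(batch, heads, seq, dim) -> (batch, seq, heads, dim).

    Models x.transpose(1, 2) on the attention output, restoring the
    sequence-major layout the rest of the model expects.
    """
    batch, heads, seq, dim = shape
    return (batch, seq, heads, dim)
```

For example, a 128-token input with 12 heads of width 64 goes in as `(128, 12, 64)`, is presented to the kernel as `(1, 12, 128, 64)`, and comes back out as `(1, 128, 12, 64)`; the `.contiguous()` call in the diff then materializes that transposed view in memory.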
