Add Ulysses attention #376

Merged

copybara-service[bot] merged 1 commit into main from ulysses-attention-benchmark on Apr 17, 2026

Conversation

@csgoogle csgoogle commented Apr 13, 2026

Summary

This PR adds Ulysses attention support for WAN TPU inference in MaxDiffusion and documents how to enable it.

Design Doc: https://docs.google.com/document/d/1_hrPGaIwj84iF8vFJrcdKdmwfKJPvW6O2Sy5ftLVn60/edit?usp=sharing&resourcekey=0-p0zkvHa_NJDwHPqLwNxNCg

What Changed

  • added a TPU Ulysses attention path for WAN that performs sequence-to-head all_to_all before local splash attention and restores the original layout afterward
  • refactored the TPU flash/Ulysses block-size resolution logic so both paths use the same helper
  • added a fail-fast ValueError raised when the attention head count is not divisible by the context-parallel shard count
  • added tests
  • updated the README to document Ulysses support for WAN inference, including the required attention="ulysses" and ici_context_parallelism>1 override pattern
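The sequence-to-head all_to_all at the heart of the Ulysses path can be illustrated with a single-host NumPy sketch (the function name and shapes here are hypothetical; the actual implementation uses a JAX all_to_all across the context-parallel mesh axis). Before the exchange, each of the cp shards holds a local sequence chunk of all heads; after it, each shard holds the full sequence for heads/cp heads, so local splash attention sees an unsharded sequence:

```python
import numpy as np

def ulysses_all_to_all(shards, cp):
    """Simulate the seq-to-head all_to_all on one host.

    shards: list of cp arrays, each (batch, seq_local, heads, dim).
    Returns: list of cp arrays, each (batch, seq_local * cp, heads // cp, dim).
    """
    batch, seq_local, heads, dim = shards[0].shape
    if heads % cp != 0:
        # mirrors the fail-fast check added in this PR
        raise ValueError(f"heads={heads} not divisible by context shards cp={cp}")
    h_per = heads // cp
    out = []
    for j in range(cp):
        # shard j gathers its head block from every sequence shard,
        # concatenated along the sequence axis -> full sequence, fewer heads
        out.append(np.concatenate(
            [s[:, :, j * h_per:(j + 1) * h_per, :] for s in shards], axis=1))
    return out

# tiny demo: cp=4 shards, 8 heads, local sequence of 3
cp = 4
shards = [np.random.rand(1, 3, 8, 16) for _ in range(cp)]
out = ulysses_all_to_all(shards, cp)
assert out[0].shape == (1, 12, 2, 16)  # full sequence, heads / cp heads
```

After local attention runs on the (full sequence, heads/cp) layout, the inverse all_to_all restores the original sequence-sharded, all-heads layout.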

Performance

TPU v6e

Wan2.2 I2V

Setup:

  • model: Wan-AI/Wan2.2-I2V-A14B-Diffusers
  • hardware: 8x TPU v6e
  • parallelism: dp=2, cp=4, fsdp=1, tp=1
  • timing config: 40 inference steps, 81 frames, 720x1280
| Global Batch Size | Flash | Ulysses | Delta |
|---|---|---|---|
| 1 | 285.56s | 251.45s | -11.9% |
| 2 | 533.67s | 491.22s | -8.0% |
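The Delta column throughout is the relative change of Ulysses over Flash wall-clock time; for example, the batch-1 row above:

```python
# relative speedup of Ulysses vs. Flash for the batch-1 I2V run
flash, ulysses = 285.56, 251.45
delta = (ulysses - flash) / flash * 100
print(f"{delta:.1f}%")  # -11.9%
```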

Wan2.2 T2V

Setup:

  • model: Wan-AI/Wan2.2-T2V-A14B-Diffusers
  • hardware: 8x TPU v6e
  • parallelism: dp=2, cp=4, fsdp=1, tp=1
  • timing config: 40 inference steps, 81 frames, 720x1280
| Global Batch Size | Flash | Ulysses | Delta |
|---|---|---|---|
| 1 | 275.54s | 246.90s | -10.4% |
| 2 | 535.40s | 480.24s | -10.3% |

TPU v7x

Wan2.2 I2V

Setup:

  • model: Wan-AI/Wan2.2-I2V-A14B-Diffusers
  • hardware: TPU v7-8 (8 chips)
  • parallelism: ici_context_parallelism=4, ici_data_parallelism=2
  • timing config: 40 inference steps, 81 frames, 720x1280
  • flash block sizes: block_q=2048, block_kv=2048, block_kv_compute=1024
| Global Batch Size | Flash | Ulysses | Delta |
|---|---|---|---|
| 1 | 209s | 199s | -5% |
| 2 | 414s | 394s | -5% |
| 4 | 829s | 780s | -6% |
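The shared block-size resolution mentioned under "What Changed" can be sketched roughly as follows (a hypothetical helper; the real names and defaults live in attention_flax.py and may differ). The idea is that both the flash and Ulysses paths clamp the configured splash-attention block sizes to the local sequence length instead of duplicating the logic:

```python
def resolve_block_sizes(seq_len, block_q=2048, block_kv=2048, block_kv_compute=1024):
    """Clamp configured splash-attention block sizes to the local
    sequence length. Hypothetical sketch of the shared helper; the
    defaults here match the v7x run above, not necessarily the config."""
    block_q = min(block_q, seq_len)
    block_kv = min(block_kv, seq_len)
    block_kv_compute = min(block_kv_compute, block_kv)
    return block_q, block_kv, block_kv_compute

# a short local sequence forces all three sizes down
print(resolve_block_sizes(1536))  # (1536, 1536, 1024)
```

With Ulysses, the local sequence length seen by the kernel is the full (unsharded) sequence, so the same helper serves both paths.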


@csgoogle csgoogle changed the title from "working code" to "Add Ulysses attention" Apr 15, 2026
@csgoogle csgoogle marked this pull request as ready for review April 15, 2026 09:18
@csgoogle csgoogle requested a review from entrpn as a code owner April 15, 2026 09:18
Comment thread on src/maxdiffusion/models/attention_flax.py
entrpn previously approved these changes Apr 15, 2026
Perseus14 previously approved these changes Apr 16, 2026
@Perseus14 (Collaborator) commented:

@csgoogle Please squash your commits

@csgoogle csgoogle dismissed stale reviews from Perseus14 and entrpn via 292fd84 April 16, 2026 17:05
@csgoogle csgoogle force-pushed the ulysses-attention-benchmark branch from 656e150 to 1b3bbe2 Compare April 16, 2026 17:31
Perseus14 previously approved these changes Apr 16, 2026
@csgoogle csgoogle force-pushed the ulysses-attention-benchmark branch 2 times, most recently from 673b2a9 to 5f75432 Compare April 16, 2026 17:48
@csgoogle csgoogle force-pushed the ulysses-attention-benchmark branch from 5f75432 to a4e0ae7 Compare April 16, 2026 17:54
@csgoogle (Collaborator, Author) commented:

> @csgoogle Please squash your commits

done

copybara-service[bot] merged commit 702cadd into main on Apr 17, 2026
12 checks passed
3 participants