
Integrate torchax custom attention kernel into ulysses #392

Merged
copybara-service[bot] merged 1 commit into main from torchax_attention on Apr 30, 2026

Conversation

eltsai (Collaborator) commented Apr 27, 2026

Adds the torchax path's custom attention kernel to ulysses (triggered when attention=ulysses_custom).

Inference time:

==================================================
  TIMING SUMMARY
==================================================
  Load (checkpoint):      80.6s
  Compile:               186.9s
  ────────────────────────────────────────
  Inference:             167.2s
==================================================
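
A minimal sketch of the dispatch this option implies, assuming the attention implementation is selected by the attention config value (the function and variable names below are illustrative, not the merged code):

# Illustrative sketch only; names are placeholders, not the actual maxdiffusion code.
def pick_attention_impl(attention: str, impls: dict):
    # impls maps config values (e.g. "ulysses", "ulysses_custom") to attention callables;
    # "ulysses_custom" is the new torchax custom splash-attention path added in this PR.
    if attention not in impls:
        raise ValueError(f"Unknown attention implementation: {attention}")
    return impls[attention]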

@eltsai eltsai requested a review from entrpn as a code owner April 27, 2026 05:21

Comment on lines +689 to +693
bq = 2048
bkv = 2048
bkv_compute = 1024
bkv_compute_in = 256
heads_per_tile = 1
A collaborator commented:

Updating this to

bq = 4864
bkv = 1024
bkv_compute = 1024
bkv_compute_in = 1024
heads_per_tile = 1

and using this command gave me the following latency

  Load (checkpoint):     297.0s
  Compile:               219.8s
  ───────────────────────────────
  Inference:             147.4s
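
As a sketch of how these tile sizes could be carried around together, assuming the custom kernel takes them as one config object (the dataclass and its field meanings are assumptions, not the merged API):

from dataclasses import dataclass

# Assumed grouping of the tile sizes discussed above; not the actual kernel API.
@dataclass(frozen=True)
class CustomSplashBlockSizes:
    bq: int = 4864              # query block length (assumed meaning)
    bkv: int = 1024             # key/value block length (assumed meaning)
    bkv_compute: int = 1024     # key/value block length per compute step (assumed meaning)
    bkv_compute_in: int = 1024  # inner key/value compute block (assumed meaning)
    heads_per_tile: int = 1     # attention heads handled per tile (assumed meaning)

# The retuned values above brought inference from 167.2s down to 147.4s on this workload.
tuned = CustomSplashBlockSizes()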

Comment thread src/maxdiffusion/kernels/custom_splash_attention.py

Comment thread src/maxdiffusion/kernels/custom_splash_attention.py
Comment thread src/maxdiffusion/kernels/custom_splash_attention.py
Comment thread src/maxdiffusion/models/attention_flax.py Outdated
eltsai force-pushed the torchax_attention branch from 56c76b8 to daf4a31 on April 27, 2026 20:21
eltsai (Collaborator, Author) commented Apr 27, 2026

Updated stats:

Accelerator   Sharding            E2E time
v7x-8         dp2-context4-tp1    139.3s
v7x-16        dp2-context8-tp1    70.2s

entrpn previously approved these changes Apr 27, 2026

Perseus14 (Collaborator) left a comment:

@eltsai I have left a few comments on the code. I am hoping we can reuse the user-defined configs like use_base2_exp, use_experimental_scheduler, and flash_block_sizes for the custom kernel.

Let me know if we can incorporate these changes now or attempt them later.

cc: @entrpn
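
A rough sketch of the kind of entry point being requested, with the existing user-defined config values threaded through instead of hard-coded constants (all names and defaults below are assumptions, not the merged API):

# Assumed shape of the kernel entry point; not the merged maxdiffusion code.
def custom_splash_attention(query, key, value, *,
                            flash_block_sizes=None,          # reuse the existing block-size config
                            use_base2_exp=False,             # (assumed) base-2 exponent in softmax
                            use_experimental_scheduler=False):
    # Body omitted; the point is that tuning knobs come from user config,
    # not constants baked into the kernel file.
    raise NotImplementedError

# Assumed call site, reading the same values from the run config:
# out = custom_splash_attention(q, k, v,
#                               flash_block_sizes=config.flash_block_sizes,
#                               use_base2_exp=config.use_base2_exp,
#                               use_experimental_scheduler=config.use_experimental_scheduler)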

Comment thread src/maxdiffusion/models/custom_splash_attention.py Outdated
Comment thread src/maxdiffusion/models/custom_splash_attention.py Outdated
Comment thread src/maxdiffusion/models/attention_flax.py Outdated
Comment thread src/maxdiffusion/models/attention_flax.py
eltsai (Collaborator, Author) commented Apr 29, 2026

Updated the PR based on @Perseus14's comments: use_base2_exp, use_experimental_scheduler, and flash_block_sizes are now passed in from the user config instead of being hard-coded.

Also moved the custom kernel file under src/maxdiffusion/kernels.

Updated the profiling code to log performance with more granularity:

==================================================
  TIMING SUMMARY
==================================================
  Load (checkpoint):     130.9s
  Compile:               137.5s
  ────────────────────────────────────────
  Inference:             137.4s
  Conditioning:            2.4s
  Denoise Total:         121.8s
  VAE Decode:             13.3s
==================================================

eltsai (Collaborator, Author) commented Apr 29, 2026

Updated the perf logging code; it was not logging each component's time correctly. In WanPipeline2_2.__call__, we start the denoise_total timer, run the loop, and stop it. Because of JAX's asynchronous execution, the timer stops before the TPU has actually finished executing all 40 steps, so the remaining device time leaks into whichever later section forces the result.
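
A minimal sketch of the corrected timing pattern, assuming the denoise loop returns a JAX array that can be blocked on (jax.block_until_ready is the real JAX call; the surrounding names are placeholders):

import time
import jax

def timed_denoise(run_denoise_loop, latents):
    # Time the loop correctly: without blocking, the Python timer stops as soon
    # as the last step has been dispatched, not when the TPU has finished it.
    start = time.perf_counter()
    latents = run_denoise_loop(latents)
    jax.block_until_ready(latents)  # wait for asynchronous device execution to complete
    return latents, time.perf_counter() - start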

New run:

==================================================
  TIMING SUMMARY
==================================================
  Load (checkpoint):     128.6s
  Compile:               137.6s
  ────────────────────────────────────────
  Inference:             136.9s
  Conditioning:            2.2s
  Denoise Total:         130.6s
  VAE Decode:              4.0s
==================================================

Also added the two new flags from @rishabhmanoj:

disable_bounds_checks=True,
skip_device_barrier=True,

which give us about 1 second of gain.
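
A sketch of how these flags might be threaded into the call, continuing the assumed entry point sketched earlier (this thread does not show where the kernel consumes them, so treat the snippet as illustrative):

# Illustrative only: the flag names are from this PR, but how the kernel
# consumes them is an assumption, not the merged code.
attention_kwargs = dict(
    disable_bounds_checks=True,  # assumed meaning: skip bounds checks in the kernel
    skip_device_barrier=True,    # assumed meaning: skip a device barrier between steps
)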

eltsai force-pushed the torchax_attention branch from 3f0f13d to 82033df on April 29, 2026 20:59
eltsai force-pushed the torchax_attention branch from 82033df to 589c3d5 on April 29, 2026 21:04
copybara-service[bot] merged commit 3ef0fdd into main on Apr 30, 2026
12 checks passed