
[FLYDSL] [TRITON] Attention backward mxfp8 gfx950 #3094

Open
lburzawa wants to merge 3 commits into main from attn_bwd_mxfp8_gfx950

Conversation

lburzawa (Contributor) commented May 8, 2026

Motivation

Support mxfp8 attention backward in FlyDSL on gfx950.

Technical Details

  • Main attn bwd kernel in FlyDSL
  • Bwd preprocess kernel in Triton
  • Quant kernels in Triton
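For reviewers unfamiliar with the format: mxfp8 groups elements into small blocks that share a power-of-two scale (per the OCP Microscaling spec, typically 32 elements with an E8M0 scale and fp8 e4m3 payloads). A minimal PyTorch reference for such a quantizer might look like the sketch below; it is an illustration under those assumptions, not the PR's Triton kernel, and all names are hypothetical.

import torch

def mxfp8_quant_ref(x: torch.Tensor, block: int = 32):
    # Hypothetical reference oracle, not the PR's kernel.
    # Assumes x.shape[-1] is a multiple of `block`.
    xb = x.float().reshape(-1, block)
    amax = xb.abs().amax(dim=-1, keepdim=True).clamp(min=2.0**-127)
    # Shared per-block scale: a power of two chosen so the block max
    # lands near e4m3's representable range (max normal = 448).
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 8)
    # Clamp before the cast, since values in [448, 512) would overflow e4m3.
    q = (xb / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return q.reshape(x.shape), scale.reshape(*x.shape[:-1], -1)

Dequantizing (q.float() * scale, broadcast per block) and comparing against x gives a quick sanity check on any Triton implementation.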

Test Plan

Correctness tests for each kernel.
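The description doesn't detail the harness; a typical shape for such a check (a sketch with assumed tolerances, since fp8 quantization makes bit-exact comparison unrealistic) is:

import torch

def check_kernel(out: torch.Tensor, ref: torch.Tensor):
    # Tolerances are assumptions sized for fp8 round-trip error,
    # not values taken from the PR's tests.
    torch.testing.assert_close(out.float(), ref.float(), atol=2e-2, rtol=2e-2)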

Test Result

Tests pass.

Submission Checklist

@lburzawa lburzawa requested a review from a team May 8, 2026 22:58

github-actions Bot commented May 8, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label            Tests
ci:triton-300x   Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang        SGLang integration tests
ci:atom          ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm          vLLM benchmark
ci:all           All of the above

Add labels via the sidebar or gh pr edit 3094 --add-label <label>

qk = qk * sm_scale
m = qk.max(dim=-1)[0]
p = (qk - m[:, :, None]).exp()
l = p.sum(dim=-1)

⚠️ [ruff] <E741> reported by reviewdog 🐶
Ambiguous variable name: l
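E741 has no auto-fix; one way to address it (a sketch, not a change from the PR) is a descriptive rename:

qk = qk * sm_scale
m = qk.max(dim=-1)[0]                  # per-row max for numerical stability
p = (qk - m[:, :, None]).exp()         # shifted exponentials
row_sum = p.sum(dim=-1)                # softmax denominator; was `l` (E741)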

causal=causal,
waves_per_eu=_wpe,
)
print(f"✓ Kernel prepared")

⚠️ [ruff] <F541> reported by reviewdog 🐶
f-string without any placeholders

Suggested change
print(f"✓ Kernel prepared")
print("✓ Kernel prepared"d")

qk = qk * sm_scale
m = qk.max(dim=-1)[0]
p = (qk - m[:, :, None]).exp()
l = p.sum(dim=-1)

⚠️ [ruff] <E741> reported by reviewdog 🐶
Ambiguous variable name: l

Comment on lines +8 to +9
import time
import numpy as np

⚠️ [ruff] <F401> reported by reviewdog 🐶
time imported but unused

Suggested change
import time
import numpy as np
import numpy as np

)
non_torch_memory_before = cuda_memory_before - torch_memory_before

data = func(*args, **kwargs)

⚠️ [ruff] <F841> reported by reviewdog 🐶
Local variable data is assigned to but never used

Suggested change
data = func(*args, **kwargs)
func(*args, **kwargs)

@lburzawa lburzawa requested a review from vgokhale May 8, 2026 23:03
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@lburzawa lburzawa requested a review from coderfeli May 8, 2026 23:07
