Skip to content

[ROCm] Fix gpt_oss test suite: AITER attention fallback, SonicMoE guard, and distributed test fixes#46160

Open
Abdennacer-Badaoui wants to merge 4 commits into
huggingface:mainfrom
Abdennacer-Badaoui:fix-gpt-oss
Open

[ROCm] Fix gpt_oss test suite: AITER attention fallback, SonicMoE guard, and distributed test fixes#46160
Abdennacer-Badaoui wants to merge 4 commits into
huggingface:mainfrom
Abdennacer-Badaoui:fix-gpt-oss

Conversation

@Abdennacer-Badaoui
Copy link
Copy Markdown
Member

@Abdennacer-Badaoui Abdennacer-Badaoui commented May 22, 2026

Description:

Hub kernels like kernels-community/vllm-flash-attn3 ship CUDA-only wheels. On ROCm, loading them raises ValueError: Cannot find a build variant, which caused GptOssModelTest::test_eager_matches_fa2_generate (among many others) to crash.

This PR adds a ROCm-specific fallback in _lazy_imports: when a model routes to a hub kernel and we're on ROCm with AITER installed, we use AMD's AITER Triton MHA kernel instead of attempting to fetch the hub wheel. AITER supports the same interface including learnable attention sinks (s_aux) used by gpt_oss. The CUDA path is completely unchanged.

Changes:

  • integrations/aiter_flash_attention.py — thin wrappers around aiter.ops.triton.attention.mha.flash_attn_func / flash_attn_varlen_func, mapping s_aux -> sink

  • modeling_flash_attention_utils.py — ROCm+AITER branch in the kernel fallback section of _lazy_imports; registers the attention function under the hub kernel name so model forward resolves it correctly

  • utils/import_utils.py — adds is_aiter_available()

  • docker/transformers-pytorch-amd-gpu/Dockerfile — pins the AITER wheel matching the ROCm 7.2 base image

  • tests/models/gpt_oss/test_modeling_gpt_oss.py — skips the hub-kernel availability check on ROCm+AITER in test_default_flash_implementation_auto_correction

  • SonicMoE ROCm guard : AMD MI300X GPUs return torch.cuda.get_device_capability() >= (9, 0) via the CUDA compatibility layer, causing sonicmoe to be included in expert implementations on ROCm even though kernels-community/sonic-moe has no ROCm build. This PR adds is_rocm_platform() guards in _load_sonicmoe_kernel() and test_modeling_common.py to fully disable SonicMoE on ROCm.

Tested on: ROCm 7.2.2 / MI300X / Python 3.10 / torch 2.10.0+rocm7.2.2
This fixes ~140 failing gpt_oss tests on ROCm.

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Abdennacer-Badaoui Abdennacer-Badaoui requested a review from vasqu May 22, 2026 11:30
@Abdennacer-Badaoui
Copy link
Copy Markdown
Member Author

I aslo fixed distributed worker for current TP API

  • Replace deprecated tp_plan="auto" kwarg with distributed_config=DistributedConfig(tp_size=WORLD_SIZE); the old kwarg was silently forwarded to the model __init__ and raised a TypeError
  • Fix skip_special_tokens=False -> True in distributed_worker to match the load_and_forward path used to generate the fixtures; the mismatch caused spurious output comparison failures on both NVIDIA and ROCm

@Abdennacer-Badaoui Abdennacer-Badaoui marked this pull request as draft May 22, 2026 16:07
@Abdennacer-Badaoui Abdennacer-Badaoui marked this pull request as ready for review May 23, 2026 14:41
@github-actions
Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: gpt_oss

@Abdennacer-Badaoui Abdennacer-Badaoui changed the title [ROCm] Use AITER as Triton backend for hub attention kernels with no ROCm build [ROCm] Fix gpt_oss test suite: AITER attention fallback, SonicMoE guard, and distributed test fixes May 23, 2026
@github-actions
Copy link
Copy Markdown
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=46160&sha=089a66

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants