Skip to content

[AMD][RDNA4]Fix RDNA4 (gfx1201 / Wave32) CI Failures#2210

Open
zhangnju wants to merge 7 commits into
tile-ai:mainfrom
zhangnju:rdna4_ci
Open

[AMD][RDNA4]Fix RDNA4 (gfx1201 / Wave32) CI Failures#2210
zhangnju wants to merge 7 commits into
tile-ai:mainfrom
zhangnju:rdna4_ci

Conversation

@zhangnju
Copy link
Copy Markdown
Collaborator

@zhangnju zhangnju commented May 15, 2026

This PR fixes several correctness bugs that caused CI failures on AMD RDNA4 (gfx1201, Wave32) hardware.

Changes

  1. Add gfx1201 support to the Carver module
  • Extended RDNA arch class and get_arch() dispatcher to accept RDNA generation 12 (gfx12xx) in addition to generation 11 (gfx11xx).
  • Added WMMA 16×16 tensor instruction entry for generation 12.
  1. Fix AllReduce / SharedReduceWarp hardcoded Wave64 assumption (src/tl_templates/hip/reduce.h, src/backend/rocm/op/finalize_reducer.cc)
  • AllReduce and SharedReduceWarp had constexpr warpSize = 64, which is wrong on RDNA Wave32 devices.
  • Replaced with __builtin_amdgcn_wavefrontsize() so the correct wavefront size is used at runtime on both RDNA (wave32) and CDNA (wave64).
  • Also fixed FinalizeReducer::WarpSize() in the ROCm backend to call TargetGetWarpSize(target) instead of returning a hardcoded 64.
  1. Fix CumSum1D / CumSum2D hardcoded SEG=64 (src/tl_templates/hip/reduce.h)
  • CumSum1D and CumSum2D used a fixed segment size of 64 (wave64), producing incorrect results on Wave32.
  • Refactored into run_seg<T, SEG>() (inner impl) and run() (dispatcher), which selects SEG=32 or SEG=64 based on __builtin_amdgcn_wavefrontsize().
  • Extended static_assert to also allow threads == 32.
  1. Fix WMMA k_pack > 1 index calculation (tilelang/rocm/intrinsics/wmma_macro_generator.py)
  • ldmatrix_a / ldmatrix_b passed the raw local_id (ranging 0..k_pack*local_size-1) to reverse_index_map, which was designed for local_id in 0..local_size-1.
  • Fix: use local_id % local_size for the reverse map and add (local_id // local_size) * micro_size_k as an explicit K-tile offset, correctly addressing each K-tile in shared memory.
  1. Fix FP8 tensor DLPack conversion in TVM FFI adapter (tilelang/jit/adapter/tvm_ffi.py)
  • PyTorch does not support DLPack export for FP8 tensors, causing a runtime error when FP8 inputs are passed to TVM FFI kernels.
  • Workaround: view FP8 tensors as int8, convert via from_dlpack, then reinterpret with _create_view(shape, dtype=fp8_dtype_str).
  1. Fix eager JIT test for non-CUDA targets (testing/python/language/test_tilelang_language_eager_jit.py)
  • Excluded float32 input dtype from the WMMA test matrix on non-CUDA targets (ROCm WMMA does not support float32 inputs).
  • Fixed incorrect use of out_dtype(A @ B) (a torch.dtype is not callable); replaced with (A @ B).to(out_dtype).

Impact: No breaking changes. RDNA3 (gfx11) behavior is unchanged.

Summary by CodeRabbit

  • New Features

    • Added support for RDNA generation 12 (gfx12) GPU targets.
  • Improvements

    • CumSum and reduction kernels adapt to wavefront width (32/64) at runtime.
    • FP8 tensors are specially handled when invoking compiled kernels.
    • Fixed shared-memory indexing for WMMA RDNA loads.
    • Warp-size now follows the compilation target.
    • JIT tests adjust dtypes based on detected backend.
  • Tests

    • Added and strengthened tests to reject unsupported RDNA generations and validate tensor-instruction lookups.

Review Change Stack

@github-actions
Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 15, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds gfx12 support alongside gfx11, strengthens RDNA generation rejection tests, makes HIP reductions and CumSum dispatch on runtime wavefront size, adds FP8 handling in the TVM FFI, and updates WMMA shared-memory indexing.

Changes

RDNA Generation 11/12 Support Expansion

Layer / File(s) Summary
RDNA device model generation support
tilelang/carver/arch/rdna.py, testing/python/target/test_tilelang_rocm_target.py
Import target utils for generation lookup and extend tensor-instruction table to include gfx12 alongside gfx11; RDNA init reads generation via utils and validates gfx11/gfx12; tests added to reject unsupported RDNA generations and assert tensor-instruction shapes and list type.

HIP CumSum & Warp-size Dispatch

Layer / File(s) Summary
SharedReduce / FinalizeReducer warp size
src/tl_templates/hip/reduce.h, src/backend/rocm/op/finalize_reducer.cc
Derive warp/wavefront size at runtime via __builtin_amdgcn_wavefrontsize() / TargetGetWarpSize instead of hardcoded 64; update FinalizeReducer accordingly.
AllReduce and CumSum SEG dispatch
src/tl_templates/hip/reduce.h
AllReduce::run/run_batch use runtime wavefront size; CumSum1D/2D add run_seg<T,SEG> helpers and public run<T> dispatches to SEG=32 or SEG=64 based on wavefront size; static_asserts updated to allow 32-thread waves.

TVM FFI FP8 handling

Layer / File(s) Summary
FP8 detection and argument marshalling
tilelang/jit/adapter/tvm_ffi.py
Import is_float8_dtype and build ffi_arg_list that converts FP8 PyTorch tensors to int8 via DLPack and rewraps a FP8 view before invoking the TVM executable.

WMMA RDNA ldmatrix indexing changes

Layer / File(s) Summary
ldmatrix_a / ldmatrix_b shared-buffer address update
tilelang/rocm/intrinsics/wmma_macro_generator.py
Compute (row,col) using local_id % local_size, derive k_tile = local_id // local_size, and add k_tile * micro_size_k into shared-buffer address calculations in both transposed and non-transposed branches.

Language test adjustments

Layer / File(s) Summary
Conditional dtype selection in tests
testing/python/language/test_tilelang_language_eager_jit.py
Import determine_target/target_is_cuda and conditionally limit input dtypes for GEMM test; cast reference result with .to(out_dtype).

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • tile-ai/tilelang#2179: Related FP8/ROCm TVM FFI float8-tensor marshalling work and gating for the same path.
  • tile-ai/tilelang#2127: Prior RDNA generation-aware tensor-instruction lookup and gfx11-only validation this PR extends to gfx12.
  • tile-ai/tilelang#1976: Related changes to HIP AllReduce/CumSum implementations and warp-size behavior.

Poem

"I nibble bytes and hop through code,
From gfx11 to gfx12 road,
I patch the waves and stitch the tiles,
Reject the odd, debug with smiles,
A rabbit cheers — the kernels load!"

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title describes fixing RDNA4 CI failures, which aligns with the main objective of supporting gfx1201 Wave32 hardware and fixing related correctness bugs across multiple modules.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Warning

Review ran into problems

🔥 Problems

Git: Failed to clone repository. Please run the @coderabbitai full review command to re-trigger a full review. If the issue persists, set path_filters to include or exclude specific files.

Tip

💬 Introducing Slack Agent: The best way for teams to turn conversations into code.

Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.

  • Generate code and open pull requests
  • Plan features and break down work
  • Investigate incidents and troubleshoot customer tickets together
  • Automate recurring tasks and respond to alerts with triggers
  • Summarize progress and report instantly

Built for teams:

  • Shared memory across your entire org—no repeating context
  • Per-thread sandboxes to safely plan and execute work
  • Governance built-in—scoped access, auditability, and budget controls

One agent for your entire SDLC. Right inside Slack.

👉 Get started


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/tl_templates/hip/reduce.h`:
- Around line 151-152: The static_assert allowing threads==32 must also enforce
that threads is at least the hardware wavefront size to avoid wave64/threads32
mismatches; update the check that currently lists allowed thread counts (the
static_assert containing "threads == 1024 or ... or threads == 32") to
additionally require threads >= __builtin_amdgcn_wavefrontsize(), or replace it
with a combined condition (allowed sizes AND threads >=
__builtin_amdgcn_wavefrontsize()) so that configurations used by run_seg<T, 64>
never run with fewer active lanes than the wavefront.
- Around line 232-239: The template CumSum2D currently allows thread counts
smaller than the hardware wavefront which leads to TILE_H = threads/SEG becoming
zero and broken shuffles; add a compile-time constraint to enforce threads >=
hardware wavefront size by augmenting the existing static_assert in struct
CumSum2D (affecting run_seg and the TILE_H calculation) to require threads >=
the platform wavefront constant (e.g., TL_WAVE_SZ or equivalent), so invalid
instantiations (like threads=32 on wave64) fail to compile instead of producing
TILE_H==0 and corrupted shuffle behavior.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2b7e4095-87e2-4fc9-a871-03f9485e7bcd

📥 Commits

Reviewing files that changed from the base of the PR and between b4bb6b5 and f085364.

📒 Files selected for processing (3)
  • src/tl_templates/hip/reduce.h
  • tilelang/jit/adapter/tvm_ffi.py
  • tilelang/rocm/intrinsics/wmma_macro_generator.py

Comment thread src/tl_templates/hip/reduce.h Outdated
Comment on lines +151 to +152
static_assert(threads == 1024 or threads == 512 or threads == 256 or
threads == 128 or threads == 64);
template <typename T, int SEG = 64>
static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
int N) {
if (N <= 0)
return;
threads == 128 or threads == 64 or threads == 32);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Add constraint: threads must be >= wavefront size.

Allowing threads == 32 without enforcing threads >= __builtin_amdgcn_wavefrontsize() causes a critical mismatch on CDNA (wave64). When compiled for CDNA with threads=32, the dispatch at lines 224–228 calls run_seg<T, 64>, but only 32 threads are active. Shuffle operations in run_seg then expect 64 active lanes while only lanes 0–31 are populated, producing incorrect cumsum results.

🔒 Proposed fix: enforce threads >= wavefront size
 template <int threads, bool reverse = false> struct CumSum1D {
   static_assert(threads == 1024 or threads == 512 or threads == 256 or
-                threads == 128 or threads == 64 or threads == 32);
+                threads == 128 or threads == 64 or threads == 32);
+  static_assert(threads >= __builtin_amdgcn_wavefrontsize(),
+                "CumSum1D: threads must be >= hardware wavefront size");
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/tl_templates/hip/reduce.h` around lines 151 - 152, The static_assert
allowing threads==32 must also enforce that threads is at least the hardware
wavefront size to avoid wave64/threads32 mismatches; update the check that
currently lists allowed thread counts (the static_assert containing "threads ==
1024 or ... or threads == 32") to additionally require threads >=
__builtin_amdgcn_wavefrontsize(), or replace it with a combined condition
(allowed sizes AND threads >= __builtin_amdgcn_wavefrontsize()) so that
configurations used by run_seg<T, 64> never run with fewer active lanes than the
wavefront.

Comment thread src/tl_templates/hip/reduce.h
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/tl_templates/hip/reduce.h (1)

51-57: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Enforce the full-wavefront precondition in SharedReduceWarp.

This path now derives kWarpSize from the target, but it still treats Threads / kWarpSize as a valid warp count. On wave64 targets, any SharedReduceWarp<..., 32, ...> instantiation makes num_warps == 0, so the loop at Line 59 never progresses (dest_idx += num_warps). Please either guard this template against sub-wavefront thread counts or confirm the ROCm lowering path never emits such instantiations.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/tl_templates/hip/reduce.h` around lines 51 - 57, SharedReduceWarp
currently derives kWarpSize at runtime and computes num_warps = Threads /
kWarpSize, which can be zero on wave64 hardware for templates instantiated with
Threads < hardware wave size (e.g., SharedReduceWarp<...,32,...>), so add a
guard to prevent num_warps==0: in SharedReduceWarp (and where kWarpSize, tid,
warp_id, lane, num_warps are computed) enforce either a compile-time check
(static_assert) when __builtin_amdgcn_wavefrontsize() is a constant to require
Threads >= kWarpSize and Threads % kWarpSize == 0, or add a runtime fallback
that sets num_warps = max(1, Threads / kWarpSize) or returns/handles the
single-partition case so the dest_idx loop (which increments by num_warps) never
stalls; reference the symbols SharedReduceWarp, kWarpSize, Threads, num_warps,
and dest_idx when making the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@testing/python/language/test_tilelang_language_eager_jit.py`:
- Around line 77-83: The try/except around determine_target("auto",
return_object=True) currently catches Exception and defaults _is_cuda = True,
which masks unexpected errors and incorrectly enables CUDA-only float32 paths;
update the block to catch only specific expected errors (e.g., ValueError,
RuntimeError, OSError) and use a conservative fallback of _is_cuda = False when
detection fails; locate the code around determine_target and target_is_cuda and
change the except Exception to except (ValueError, RuntimeError, OSError) (or
the specific exceptions your detection can raise) and set _is_cuda = False so
in_dtypes selects the non-CUDA-safe list when detection cannot confirm CUDA.

---

Outside diff comments:
In `@src/tl_templates/hip/reduce.h`:
- Around line 51-57: SharedReduceWarp currently derives kWarpSize at runtime and
computes num_warps = Threads / kWarpSize, which can be zero on wave64 hardware
for templates instantiated with Threads < hardware wave size (e.g.,
SharedReduceWarp<...,32,...>), so add a guard to prevent num_warps==0: in
SharedReduceWarp (and where kWarpSize, tid, warp_id, lane, num_warps are
computed) enforce either a compile-time check (static_assert) when
__builtin_amdgcn_wavefrontsize() is a constant to require Threads >= kWarpSize
and Threads % kWarpSize == 0, or add a runtime fallback that sets num_warps =
max(1, Threads / kWarpSize) or returns/handles the single-partition case so the
dest_idx loop (which increments by num_warps) never stalls; reference the
symbols SharedReduceWarp, kWarpSize, Threads, num_warps, and dest_idx when
making the change.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: fcb4a65b-119a-4e0d-ad97-3333a6d265c1

📥 Commits

Reviewing files that changed from the base of the PR and between f085364 and c17a2b0.

📒 Files selected for processing (3)
  • src/backend/rocm/op/finalize_reducer.cc
  • src/tl_templates/hip/reduce.h
  • testing/python/language/test_tilelang_language_eager_jit.py

Comment thread testing/python/language/test_tilelang_language_eager_jit.py
@zhangnju zhangnju changed the title [AMD][RDNA4]add gfx1201 support for carver module [AMD][RDNA4]Fix RDNA4 (gfx1201 / Wave32) CI Failures May 15, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant