[AMD][CDNA4] Add MXFP4 (FP4 E2M1) support for gfx950 #2132

Open
zhangnju wants to merge 5 commits into tile-ai:main from zhangnju:gfx950_mxfp4

zhangnju (Collaborator) commented Apr 30, 2026

This PR adds end-to-end MXFP4 (FP4 E2M1) support for AMD gfx950 (CDNA4 / MI350) in the HIP backend.

HIP codegen (codegen_hip.cc/.h)

  • Added GetFP4Type() to map float4_e2m1fn DataType to HIP C types (fp4_e2_t, fp4_e2_2_t, ..., fp4_e2_32_t).
  • Implemented vectorized FP4 ↔ float16 / float32 / bfloat16 / float64 cast codegen, processing two lanes at a time via pairwise helpers (__tl_cvt_fp4x2_to_half2, __tl_cvt_fp4x2_to_bfloat162, etc.).
  • Added enable_fp4_ flag: when an FP4 type is encountered, Finish() automatically includes hip_fp4.h.

New header src/tl_templates/hip/hip_fp4.h

  • Defines FP4 scalar and vector types for gfx950.
  • Provides conversion intrinsic wrappers, guarded by #if defined(__gfx950__).

Dequantization kernels (tilelang/quantize/mxfp.py)

  • decode_f4_to_bf16_twiddling_hip: ports the CUDA PTX bit-twiddling algorithm to portable HIP C++ (no inline PTX), numerically equivalent to the CUDA reference.
  • decode_f4_to_bf16_simple_hip: a static-LUT fallback path for non-twiddling dequantization.
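
As a rough illustration of what a LUT-based FP4 decode does, here is a minimal Python sketch. It is not the PR's HIP source. The names `E2M1_MAGNITUDES`, `FP4_LUT`, and `dequant_mxfp4_block` are hypothetical, the low-nibble-first packing order is an assumption, and the shared power-of-two scale with bias 127 follows the usual MXFP4 (E8M0) convention rather than anything confirmed by this PR.

```python
# Magnitudes encoded by the low 3 bits of an E2M1 nibble
# (2 exponent bits, 1 mantissa bit). Nibble 1 is the only subnormal: 0.5.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)

# 16-entry lookup table; bit 3 of the nibble is the sign bit.
FP4_LUT = tuple(
    (-1.0 if n & 0x8 else 1.0) * E2M1_MAGNITUDES[n & 0x7] for n in range(16)
)


def dequant_mxfp4_block(packed, scale_e8m0):
    """Decode packed FP4 bytes that share one E8M0 scale.

    Assumption: each byte holds two elements, low nibble first.
    """
    scale = 2.0 ** (scale_e8m0 - 127)  # E8M0 is a biased power-of-two exponent
    out = []
    for byte in packed:
        out.append(FP4_LUT[byte & 0xF] * scale)
        out.append(FP4_LUT[byte >> 4] * scale)
    return out


if __name__ == "__main__":
    # Byte 0x21 holds nibble 1 (0.5) and nibble 2 (1.0); scale 127 means x1.
    print(dequant_mxfp4_block(bytes([0x21]), 127))
```

A table lookup keeps the decode branch-free, which is presumably why the PR offers it as the fallback when the bit-twiddling path does not apply.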

MFMA layout (tilelang/intrinsics/mfma_layout.py)

  • Extended MFMA matrix layout to support FP4 input operands on gfx950.

Tests & examples

  • testing/python/amd/test_tilelang_mxfp4_gfx950.py: covers FP4 copy, vectorized casts, and MXFP4 dequant-GEMM (both twiddling and simple paths); guarded by @requires_gfx950, silently skipped on other targets.
  • examples/: added an end-to-end BF16-output MXFP4 GEMM example (example_dequant_gemm_bf16_mxfp4_cdna4.py).

CI Test:

  • MI300 CI test: Pass
  • MI350 CI test: Pass

Summary by CodeRabbit

  • New Features

    • MXFP4 (FP4→BF16) support for AMD gfx950 with a fast hardware-backed dequant path.
  • Examples

    • New gfx950-targeted MXFP4 dequantize-GEMM example with autotuning, reference implementations, correctness checks, and benchmarks.
  • Tests

    • New tests for FP4 packed round-trip, both dequantize-GEMM modes, and source-generation validation.
  • Bug Fixes

    • Corrected MFMA 32×32 shared/register layout mapping.

Review Change Stack

coderabbitai Bot (Contributor) commented Apr 30, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b9119e2a-efcf-4b8c-a66d-7c0681238533

📥 Commits

Reviewing files that changed from the base of the PR and between e1c3005 and dd4e993.

📒 Files selected for processing (3)
  • src/backend/rocm/codegen/codegen_hip.cc
  • src/tl_templates/hip/hip_fp4.h
  • tilelang/quantize/mxfp.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • src/backend/rocm/codegen/codegen_hip.cc
  • src/tl_templates/hip/hip_fp4.h
  • tilelang/quantize/mxfp.py

📝 Walkthrough

Walkthrough

Adds gfx950 HIP FP4 types and device conversions, packed scalar FP4 storage/load/store, two gfx950 FP4→BF16 decoder intrinsics with target dispatch, a dequantize-GEMM example/tests, and fixes MFMA 32x32 shared→local layout mappings.

Changes

HIP codegen FP4 support for gfx950

Layer / File(s) | Summary
HIP FP4 codegen state and declarations: src/backend/rocm/codegen/codegen_hip.h | BufferStoreNode visitor override and FP4 tracking state (enable_fp4_, fp4_packed_buffers_ map).
FP4 type printing and header inclusion: src/backend/rocm/codegen/codegen_hip.cc | GetFP4Type helper, PrintType FP4 handling, and conditional inclusion of hip_fp4.h in Finish().
FP4 vectorized type conversions: src/backend/rocm/codegen/codegen_hip.cc | VisitExpr_(CastNode) emits pairwise conversions between FP4 and float/half/bfloat16/double using hip_fp4 intrinsics and packed-byte intermediates.
FP4 packed-buffer allocation, load and store: src/backend/rocm/codegen/codegen_hip.cc | Local scalar FP4 allocations become packed fp4_e2_2_t arrays; scalar loads use tl_fp4_packed_load; scalar stores emit tl_fp4_packed_store.
hip_fp4 device helpers: src/tl_templates/hip/hip_fp4.h | gfx950-only header defining fp4_e2_t/packed group types, float↔FP4 conversion, FP4↔half/float/double/bfloat16 helpers, and tl_fp4_packed_load/tl_fp4_packed_store.

MXFP4 dequant intrinsics and dispatch

Layer / File(s) | Summary
gfx950 twiddling and simple decoders: tilelang/quantize/mxfp.py | Adds decode_f4_to_bf16_twiddling_hip (portable C++ twiddling) and decode_f4_to_bf16_simple_hip (LUT) c_source strings for gfx950 HIP.
Target-aware intrin dispatch: tilelang/quantize/mxfp.py | Extends get_mxfp_intrin_group(..., target=None) to detect gfx950 and return the HIP c_source for FP4→BF16 (enforcing 4-bit→bfloat16); falls back to the existing CUDA/PTX mapping otherwise.

MFMA 32x32 C layout mapping fix

Layer / File(s) | Summary
32x32 C layout mapping and accessors: tilelang/rocm/intrinsics/mfma_layout.py | Fix shared_32x32_to_local_64x16_layout_C to compute tid_high/thread_id/local_id correctly and update thread_id_shared_access_* index formulas for n_m and m_n orderings.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999
  • LJC00118

Poem

🐰 In tiny nibbles four bits play,
hip_fp4 hums at break of day,
packed loads, stores, and casts align,
MFMA threads now march in line,
dequant blooms fast — a rabbit's rhyme.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name | Status | Explanation | Resolution
Docstring Coverage | ⚠️ Warning | Docstring coverage is 45.31% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name | Status | Explanation
Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled.
Title check | ✅ Passed | The title '[AMD][CDNA4] Add MXFP4 (FP4 E2M1) support for gfx950' accurately and concisely summarizes the main objective of the changeset: adding MXFP4 support for AMD gfx950.
Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/target/codegen_hip.cc (1)

1727-1757: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Scalar FP4 packing is only applied to local, but the accessors assume it everywhere.

GetBufferRef() (Line 1079 onward) and the new BufferStoreNode path (Line 1789 onward) always use tl_fp4_packed_load/store for scalar FP4 buffers, but this allocator only packs scope == "local". shared and local.var allocations still have fp4_e2_t layout, so their logical indexing no longer matches the physical storage and neighboring elements can alias/corrupt each other. Either pack every scalar-FP4 allocation that goes through the nibble helpers, or gate those helpers to buffers recorded in fp4_packed_buffers_.
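
To make the hazard concrete, here is a small Python model of nibble-packed access. It is a sketch only: `packed_load` and `packed_store` are hypothetical stand-ins for tl_fp4_packed_load/tl_fp4_packed_store, and the low-nibble-first ordering is an assumption, not something taken from hip_fp4.h.

```python
def packed_store(buf, index, nibble):
    """Write one 4-bit element without disturbing its neighbour in the byte."""
    byte_idx, is_high = divmod(index, 2)
    if is_high:
        buf[byte_idx] = (buf[byte_idx] & 0x0F) | ((nibble & 0xF) << 4)
    else:
        buf[byte_idx] = (buf[byte_idx] & 0xF0) | (nibble & 0xF)


def packed_load(buf, index):
    """Read one 4-bit element from a buffer that stores two elements per byte."""
    byte_idx, is_high = divmod(index, 2)
    return (buf[byte_idx] >> 4) & 0xF if is_high else buf[byte_idx] & 0xF


if __name__ == "__main__":
    # Packed buffer: 4 logical elements live in 2 bytes.
    packed = bytearray(2)
    for i, value in enumerate([1, 2, 3, 4]):
        packed_store(packed, i, value)
    print([packed_load(packed, i) for i in range(4)])

    # Unpacked buffer: one element per byte. Applying the packed accessor
    # to it reads byte index // 2, so logical element 2 comes back as 2.
    unpacked = bytearray([1, 2, 3, 4])
    print(packed_load(unpacked, 2))
```

The second case mirrors the mismatch described above: the accessor assumes two elements per byte, so on a one-element-per-byte allocation the logical index and the physical location diverge.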

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/target/codegen_hip.cc` around lines 1727 - 1757, The bug: GetBufferRef
and the BufferStoreNode path assume scalar-FP4 buffers are packed but the
allocator only packs when is_fp4_scalar_local (scope == "local"), causing
mis-indexing for shared/local.var buffers; to fix, either (A) consistently emit
packed storage for every scalar FP4 allocation that uses the nibble helpers
(change the allocation path that creates fp4_e2_2_t in the code that sets
fp4_packed_buffers_), or (B) safer minimal change—gate all uses of
tl_fp4_packed_load/tl_fp4_packed_store to only buffers recorded in
fp4_packed_buffers_: modify GetBufferRef(...) and the BufferStoreNode handling
to check fp4_packed_buffers_.count(op->buffer_var.get()) (or equivalent) before
emitting packed helpers, and fall back to fp4_e2_t accessors for others; update
any code that constructs vid_packed (where fp4_e2_2_t is emitted) to ensure
fp4_packed_buffers_ is the single source-of-truth.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/target/codegen_hip.cc`:
- Around line 806-808: The fp4 handling currently sets fp4_pair_cast based on
fp4_lanes (variable fp4_lanes) but only allows 2/4/8, which lets x16/x32 fall
through to generic vector casting and causes PrintVecElemLoad/Store to emit
invalid fp4_e2_* accesses; either extend the pairwise lowering to support 16 and
32 lanes (update fp4_pair_cast to include 16 and 32 and implement corresponding
pairwise lowering logic) or explicitly reject unsupported fp4 lane widths by
adding an ICHECK (or CHECK) where fp4_pair_cast is computed to assert fp4_lanes
is one of the supported values; update the same area in
src/target/codegen_hip.cc so casts for FP4x16/FP4x32 never fall through to
PrintVecElemLoad/Store.
- Around line 818-827: The code currently takes the address of a prvalue from
__tl_cvt_fp4x2_to_half2, which is invalid in C++; instead, materialize the
conversion into a local half2 temporary (e.g. declare a variable like half2
tmp_half2 = __tl_cvt_fp4x2_to_half2(((uint8_t*)&(src))[i/2]);) and then build
v0/v1 to reference its components (prefer member access like tmp_half2.x and
tmp_half2.y) before calling PrintVecElemStore(sret, target_ty, i, ...) and
PrintVecElemStore(sret, target_ty, i+1, ...). Ensure the temporary name is
unique per loop iteration to avoid collisions.

In `@src/tl_templates/hip/hip_fp4.h`:
- Around line 35-38: The FP4 decoder in hip_fp4.h currently maps exp==0 and
mant==1 to 0.25f, which conflicts with __tl_float_to_fp4() and the new LUTs that
treat nibble 1 as 0.5; update the denormal handling in the decoder (the exp==0
branch that sets result based on mant) to return 0.5f for mant==1 (instead of
0.25f) so decoding matches the encoder and the mxfp LUTs (or alternatively make
the encoder/LUTs consistent with the decoder, but prefer updating the decoder to
match __tl_float_to_fp4() and tilelang/quantize/mxfp.py).

In `@tilelang/quantize/mxfp.py`:
- Around line 201-207: The current try/except around importing and calling
target_is_gfx950 silently swallows all exceptions causing a HIP target to fall
back to the CUDA/PTX path; update the block that references target_is_gfx950 and
_is_gfx950 to only catch expected import/availability errors (e.g.,
ImportError/ModuleNotFoundError/AttributeError) or, if any other exception is
raised by target_is_gfx950, re-raise it with additional context (e.g., include
the provided target) instead of pass so target-detection failures do not
silently change backend selection.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 529c506e-c76e-4ad7-82a3-d5c03629a45b

📥 Commits

Reviewing files that changed from the base of the PR and between 936ae92 and fb1d0c3.

📒 Files selected for processing (7)
  • examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_cdna4.py
  • src/target/codegen_hip.cc
  • src/target/codegen_hip.h
  • src/tl_templates/hip/hip_fp4.h
  • testing/python/amd/test_tilelang_mxfp4_gfx950.py
  • tilelang/intrinsics/mfma_layout.py
  • tilelang/quantize/mxfp.py

Comment thread src/backend/rocm/codegen/codegen_hip.cc
Comment thread src/backend/rocm/codegen/codegen_hip.cc Outdated
Comment thread src/tl_templates/hip/hip_fp4.h
Comment thread tilelang/quantize/mxfp.py
zhangnju and others added 2 commits May 15, 2026 09:51
Four issues from CodeRabbit review fixed:

1. hip_fp4.h: Fix FP4 E2M1 denormal decoder round-trip bug.
   The decoder returned 0.25f for mant==1 (exp==0) but the encoder
   maps nibble 1 to 0.5f, breaking float→fp4→float round-trips.
   Fix: return 0.5f to match the encoder and LUT tables.

2. codegen_hip.cc: Fix invalid address-of-prvalue in FP4→float16 cast.
   The generated code took &(__tl_cvt_fp4x2_to_half2(...)), which
   takes the address of a temporary and is not valid C++. Fix: materialize the
   uint1 return value into a local variable first, matching the
   pattern already used in the float32/bfloat16 cast paths.

3. codegen_hip.cc: Extend fp4_pair_cast to cover FP4x16 and FP4x32.
   Previously only 2/4/8-lane FP4 casts were handled; x16/x32 fell
   through silently to generic vector code that cannot index fp4_e2_*
   aggregates. The pairwise uint8_t byte logic is correct for any
   even lane count, so extend to include 16 and 32 lanes.

4. mxfp.py: Narrow catch-all except to expected import-time errors.
   The catch-all except Exception silently swallowed all errors from
   target_is_gfx950(), causing HIP targets to fall back to CUDA/PTX.
   Fix: catch only ImportError/ModuleNotFoundError/AttributeError.
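
The pattern behind item 4 can be sketched in a few lines of Python. This is an illustration, not the code in mxfp.py: `is_gfx950` and `load_probe` are hypothetical names, with `load_probe` standing in for the import of target_is_gfx950.

```python
def is_gfx950(target, load_probe):
    """Return True if `target` is gfx950, tolerating only a missing helper.

    `load_probe` is a hypothetical callable that imports and returns the
    detection function; it may raise ImportError or AttributeError when the
    helper is unavailable.
    """
    try:
        probe = load_probe()
    except (ImportError, AttributeError):
        # ModuleNotFoundError is a subclass of ImportError, so it lands here too.
        return False
    # Any error raised by the probe itself propagates to the caller instead of
    # silently selecting a different backend.
    return bool(probe(target))


if __name__ == "__main__":
    def missing_helper():
        raise ImportError("detection helper not available")

    print(is_gfx950("hip", missing_helper))
```

Keeping the probe call outside the `try` block is one way to make sure only the import step is treated as optional.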

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 4

🧹 Nitpick comments (1)
tilelang/rocm/intrinsics/mfma_layout.py (1)

131-137: 🏗️ Heavy lift

Remove unused 32x32 layout helper functions or consolidate duplicates.

The functions thread_id_shared_access_64x16_to_32x32_layout_C_n_m and shared_32x32_to_local_64x16_layout_C are never imported or called anywhere in the codebase. Only thread_id_shared_access_64x16_to_32x32_layout_C_m_n is actively used. Additionally, _n_m and _m_n contain identical logic (only variable names differ), creating dead code and API confusion. Consider removing the unused functions or consolidating them into a single, well-named implementation.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/rocm/intrinsics/mfma_layout.py` around lines 131 - 137, The two
helper functions thread_id_shared_access_64x16_to_32x32_layout_C_n_m and
shared_32x32_to_local_64x16_layout_C are dead/duplicated with
thread_id_shared_access_64x16_to_32x32_layout_C_m_n; remove the unused ones (or
consolidate them into a single canonical function) so only the active API
thread_id_shared_access_64x16_to_32x32_layout_C_m_n remains; if you consolidate,
merge identical logic, keep one clear name, update any internal references to
that single function, and delete the other definitions to avoid API confusion.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/backend/rocm/codegen/codegen_hip.cc`:
- Around line 243-245: The code emits FP4-specific types (e.g., fp4_e2_2_t,
tl_fp4_packed_*) directly in several places (including the block guarded by
enable_fp4_ and the ranges around lines 1079-1094, 1753-1760, 1799-1806) but
only conditionally includes "tl_templates/hip/hip_fp4.h" when enable_fp4_ is
true; update those emission sites so they either call PrintType() (which will
trigger Finish() to include the header) or explicitly check enable_fp4_ before
emitting FP4 names; specifically, modify the blocks that produce fp4_e2_2_t /
tl_fp4_packed_* to guard on enable_fp4_ or route through PrintType() so Finish()
will not skip adding hip_fp4.h when FP4 symbols are emitted.
- Around line 1079-1094: The FP4 scalar local path must be preserved: when
handling t.is_float4() && t.is_scalar(), detect if the target buffer is a scalar
local allocation emitted by AllocateNode (scope == "local.var") and avoid
treating its vid as a packed-buffer pointer; do not unconditionally call
tl_fp4_packed_load((fp4_e2_2_t*)vid, ...). Instead, when buffer_var/vid refers
to a local.var scalar (instead of an entry in fp4_packed_buffers_), emit the
scalar access path using vid directly (or the existing scalar load/store code)
and only use fp4_packed_buffers_ / tl_fp4_packed_load for actual packed buffers;
update the branch around fp4_packed_buffers_, vid, buffer_var and
t.is_float4()/t.is_scalar() to guard this case.
- Around line 1783-1786: The override CodeGenTileLangHIP::VisitStmt_(const
BufferStoreNode *op) currently rejects predicated or non-flat stores
unconditionally; move those ICHECKs into the FP4-specific handling path so they
only apply when emitting the FP4 path, and for all other cases delegate to the
base C implementation (call CodeGenC::VisitStmt_(op)) so generic BufferStoreNode
cases (predicated or non-flat) are not hard-failed by the HIP override.

In `@tilelang/rocm/intrinsics/mfma_layout.py`:
- Around line 140-159: The function
thread_id_shared_access_64x16_to_32x32_layout_C_n_m currently computes i and j
identically to thread_id_shared_access_64x16_to_32x32_layout_C_m_n but returns
(i, j), violating the _n_m convention; change its return to swap the order
(return j, i or equivalently return n, m) so it returns (n, m) (column, row)
consistent with the established transpose convention and keep the same index
computations.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ca0cb00c-f409-4b0c-8168-d85bfa077830

📥 Commits

Reviewing files that changed from the base of the PR and between fb1d0c3 and b20d2e4.

📒 Files selected for processing (3)
  • src/backend/rocm/codegen/codegen_hip.cc
  • src/backend/rocm/codegen/codegen_hip.h
  • tilelang/rocm/intrinsics/mfma_layout.py

Comment on lines +243 to +245
  if (enable_fp4_) {
    decl_stream << "#include <tl_templates/hip/hip_fp4.h>\n";
  }

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Set enable_fp4_ anywhere FP4 names are emitted directly.

These branches emit fp4_e2_2_t / tl_fp4_packed_* without going through PrintType(), so Finish() can still skip hip_fp4.h for FP4-only packed-local code paths.

Also applies to: 1079-1094, 1753-1760, 1799-1806

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/backend/rocm/codegen/codegen_hip.cc` around lines 243 - 245, The code
emits FP4-specific types (e.g., fp4_e2_2_t, tl_fp4_packed_*) directly in several
places (including the block guarded by enable_fp4_ and the ranges around lines
1079-1094, 1753-1760, 1799-1806) but only conditionally includes
"tl_templates/hip/hip_fp4.h" when enable_fp4_ is true; update those emission
sites so they either call PrintType() (which will trigger Finish() to include
the header) or explicitly check enable_fp4_ before emitting FP4 names;
specifically, modify the blocks that produce fp4_e2_2_t / tl_fp4_packed_* to
guard on enable_fp4_ or route through PrintType() so Finish() will not skip
adding hip_fp4.h when FP4 symbols are emitted.

Comment on lines +1079 to +1094
  // FP4 scalar access on gfx950: redirect to tl_fp4_packed_load helper.
  // Non-scalar FP4 accesses fall through to the normal path (the vector
  // types fp4_e2_4_t etc. are directly addressable as structs).
  if (t.is_float4() && t.is_scalar()) {
    std::string idx_str = PrintExpr(index);
    auto packed_it = fp4_packed_buffers_.find(buffer_var);
    if (packed_it != fp4_packed_buffers_.end()) {
      // Packed local buffer: use the pre-allocated fp4_e2_2_t array.
      os << "tl_fp4_packed_load(" << packed_it->second << ", " << idx_str
         << ")";
    } else {
      // Non-packed (e.g. shared) buffer: reinterpret as fp4_e2_2_t*.
      os << "tl_fp4_packed_load((fp4_e2_2_t*)" << vid << ", " << idx_str << ")";
    }
    return os.str();
  }

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Preserve the existing local.var scalar path for FP4.

AllocateNode still emits fp4_e2_t <vid> for scope == "local.var", but both the load and store fast paths unconditionally reinterpret that scalar as fp4_e2_2_t*. That turns a scalar object into a fake packed buffer pointer and generates invalid C++ for FP4 scalar locals.

Suggested fix
   if (t.is_float4() && t.is_scalar()) {
+    if (scope == "local.var") {
+      os << vid;
+      return os.str();
+    }
     std::string idx_str = PrintExpr(index);
     auto packed_it = fp4_packed_buffers_.find(buffer_var);
     if (packed_it != fp4_packed_buffers_.end()) {
       os << "tl_fp4_packed_load(" << packed_it->second << ", " << idx_str
          << ")";
@@
   if (element_dtype.is_float4() && element_dtype.is_scalar() &&
       value_dtype.is_scalar()) {
+    auto it = alloc_storage_scope_.find(buffer_var.get());
+    if (it != alloc_storage_scope_.end() && it->second == "local.var") {
+      CodeGenC::VisitStmt_(op);
+      return;
+    }
     std::string idx_str = PrintExpr(op->indices[0]);
     std::string value = this->PrintExpr(op->value);

Also applies to: 1753-1761, 1792-1808

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/backend/rocm/codegen/codegen_hip.cc` around lines 1079 - 1094, The FP4
scalar local path must be preserved: when handling t.is_float4() &&
t.is_scalar(), detect if the target buffer is a scalar local allocation emitted
by AllocateNode (scope == "local.var") and avoid treating its vid as a
packed-buffer pointer; do not unconditionally call
tl_fp4_packed_load((fp4_e2_2_t*)vid, ...). Instead, when buffer_var/vid refers
to a local.var scalar (instead of an entry in fp4_packed_buffers_), emit the
scalar access path using vid directly (or the existing scalar load/store code)
and only use fp4_packed_buffers_ / tl_fp4_packed_load for actual packed buffers;
update the branch around fp4_packed_buffers_, vid, buffer_var and
t.is_float4()/t.is_scalar() to guard this case.

Comment on lines +1783 to +1786
void CodeGenTileLangHIP::VisitStmt_(const BufferStoreNode *op) {
  ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported.";
  ICHECK(!op->predicate.defined())
      << "Predicated buffer store is not supported.";

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Don't reject generic BufferStoreNodes in the override.

These ICHECKs run before the non-FP4 fallback, so predicated or non-flat stores now hard-fail on HIP even when they never touch the FP4 path. Keep the constraints inside the FP4-specific branch and delegate everything else directly to CodeGenC.

Suggested fix
 void CodeGenTileLangHIP::VisitStmt_(const BufferStoreNode *op) {
-  ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported.";
-  ICHECK(!op->predicate.defined())
-      << "Predicated buffer store is not supported.";
-
   DataType value_dtype = op->value.dtype();
   DataType element_dtype = op->buffer->dtype;
   Var buffer_var = op->buffer->data;
 
   // FP4 scalar store: use tl_fp4_packed_store to correctly handle nibble-level
   // writes without corrupting the neighbouring nibble.
   if (element_dtype.is_float4() && element_dtype.is_scalar() &&
       value_dtype.is_scalar()) {
+    ICHECK_EQ(op->indices.size(), 1)
+        << "Store to non-flat FP4 memory not supported.";
+    ICHECK(!op->predicate.defined())
+        << "Predicated FP4 buffer store is not supported.";
     std::string idx_str = PrintExpr(op->indices[0]);
     std::string value = this->PrintExpr(op->value);
     this->PrintIndent();
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/backend/rocm/codegen/codegen_hip.cc` around lines 1783 - 1786, The
override CodeGenTileLangHIP::VisitStmt_(const BufferStoreNode *op) currently
rejects predicated or non-flat stores unconditionally; move those ICHECKs into
the FP4-specific handling path so they only apply when emitting the FP4 path,
and for all other cases delegate to the base C implementation (call
CodeGenC::VisitStmt_(op)) so generic BufferStoreNode cases (predicated or
non-flat) are not hard-failed by the HIP override.

Comment on lines 140 to 159
 def thread_id_shared_access_64x16_to_32x32_layout_C_n_m(thread_id, local_id):
-    i = (thread_id // 32) * 4 + local_id % 4 + (local_id // 4) * 8
-    j = thread_id % 32
+    # Returns (row=M, col=N) for v_mfma_i32_32x32x32_i8 output layout.
+    # tid%32 = M_row, (tid//32)*4 + lid%4 + (lid//4)*8 = N_col.
+    i = thread_id % 32
+    j = (thread_id // 32) * 4 + local_id % 4 + (local_id // 4) * 8
     return i, j
 
 
 def thread_id_shared_access_64x16_to_32x32_layout_C_m_n(thread_id, local_id):
     """Return (m, n) = (row, col) for the 32x32 MFMA output register layout.
 
     For v_mfma_i32_32x32x32_i8 (gfx950), each wave-64 lane holds 16 output
-    i32 values. The column (N-dimension) is indexed by ``thread_id % 32``
-    and the row (M-dimension) is given by the interleaved formula below.
+    i32 values. The row (M-dimension) is indexed by ``thread_id % 32``
+    and the column (N-dimension) is given by the interleaved formula below.
     This function returns ``(m_idx, n_idx)`` matching the ``(row, col)``
     convention expected by ``stmatrix``.
     """
-    m = (thread_id // 32) * 4 + local_id % 4 + (local_id // 4) * 8
-    n = thread_id % 32
+    m = thread_id % 32
+    n = (thread_id // 32) * 4 + local_id % 4 + (local_id // 4) * 8
     return m, n

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
rg -nP -C5 '\bthread_id_shared_access_64x16_to_32x32_layout_C_(n_m|m_n)\b'
# Compare against the established 16x16 usage to confirm the (n, m) vs (m, n) convention.
rg -nP -C5 '\bthread_id_shared_access_64x4_to_16x16_layout_C_(n_m|m_n)\b'

Repository: tile-ai/tilelang

Length of output: 6592


🏁 Script executed:

# Check if the 32x32 _n_m function is used anywhere at all
rg 'thread_id_shared_access_64x16_to_32x32_layout_C_n_m' -A 2 -B 2

# Also check if there's a __all__ export list in mfma_layout.py
rg '__all__' tilelang/rocm/intrinsics/mfma_layout.py -A 10

Repository: tile-ai/tilelang

Length of output: 475


thread_id_shared_access_64x16_to_32x32_layout_C_n_m violates the naming convention and will produce incorrect results if called.

The 32×32 _n_m function is defined but unused. However, it violates the established transpose convention seen in the 16×16 case (lines 71–80), where:

  • _m_n returns (m, n) with values in row-major order
  • _n_m returns (n, m) with values swapped for column-major / transposed access

The 32×32 _n_m currently returns the same tuple as _m_n instead of swapping. If this function is ever called—especially if later code adopts the 16×16 pattern of using _n_m for transposed shared-memory layouts—it will silently return incorrect indices.

Swap the return values to restore the convention:

Proposed fix
 def thread_id_shared_access_64x16_to_32x32_layout_C_n_m(thread_id, local_id):
-    # Returns (row=M, col=N) for v_mfma_i32_32x32x32_i8 output layout.
-    # tid%32 = M_row, (tid//32)*4 + lid%4 + (lid//4)*8 = N_col.
-    i = thread_id % 32
-    j = (thread_id // 32) * 4 + local_id % 4 + (local_id // 4) * 8
+    # Transposed (N, M) view of the v_mfma_i32_32x32x32_i8 output layout.
+    # tid%32 = M, (tid//32)*4 + lid%4 + (lid//4)*8 = N. First returned index is N.
+    i = (thread_id // 32) * 4 + local_id % 4 + (local_id // 4) * 8
+    j = thread_id % 32
     return i, j
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/rocm/intrinsics/mfma_layout.py` around lines 140 - 159, The function
thread_id_shared_access_64x16_to_32x32_layout_C_n_m currently computes i and j
identically to thread_id_shared_access_64x16_to_32x32_layout_C_m_n but returns
(i, j), violating the _n_m convention; change its return to swap the order
(return j, i or equivalently return n, m) so it returns (n, m) (column, row)
consistent with the established transpose convention and keep the same index
computations.
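
The index formulas discussed in this thread can be checked in isolation. The sketch below copies the arithmetic from the diff into two short, hypothetically named functions (`layout_c_m_n`, `layout_c_n_m`) and confirms that 64 lanes times 16 registers cover every cell of a 32x32 tile exactly once. It models only the index math and makes no claim about the hardware layout itself.

```python
def layout_c_m_n(thread_id, local_id):
    """(row, col) for one register of one lane, using the formulas from the diff."""
    m = thread_id % 32
    n = (thread_id // 32) * 4 + local_id % 4 + (local_id // 4) * 8
    return m, n


def layout_c_n_m(thread_id, local_id):
    """Transposed view: same computation, indices returned as (col, row)."""
    m, n = layout_c_m_n(thread_id, local_id)
    return n, m


if __name__ == "__main__":
    cells = {layout_c_m_n(t, r) for t in range(64) for r in range(16)}
    # 1024 distinct (row, col) pairs means the mapping is a bijection
    # onto the 32x32 tile.
    print(len(cells) == 32 * 32)
    print(layout_c_m_n(33, 5), layout_c_n_m(33, 5))
```

Defining the transposed accessor in terms of the row-major one, as done here, is one way to keep the two from drifting apart.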
