[Feature] Support Blackwell FP4(float4_e2m1fn) GEMM for SM100 & SM120#2182

Open
Hale423 wants to merge 15 commits into
tile-ai:mainfrom
Hale423:feat/gemm-fp4

Contributor

@Hale423 Hale423 commented May 11, 2026

Summary

Extended and separated from PR 1918, this PR adds and fixes Blackwell FP4 GEMM examples for both SM100/SM110 and SM120.

For SM100/SM110, it implements non-block-scaled FP4 GEMM through the TCGEN05 path using T.float4_e2m1fn, TMA, TMEM accumulation, and CUTLASS-aligned 16U4_ALIGN16B unpacksmem semantics. It also adds SM100 A8W4 and simplified fused MoE examples using FP8 activations and FP4 weights.

For SM120, it keeps the fragment MMA path and fixes FP4 lowering so unpacked uint8 storage is interpreted as float4_e2m1fn compute operands with the correct m16n8k32 MMA shape and FP4 ldmatrix offsets.
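
As a reading aid for the "uint8 storage interpreted as float4_e2m1fn" wording above, here is a minimal sketch of how an E2M1 code (1 sign bit, 2 exponent bits, 1 mantissa bit) maps to a float through a lookup table. The function names are hypothetical and are not taken from the PR's examples, and the low-nibble-first order used for packed bytes is an assumption.

```python
# Magnitudes of the eight non-negative E2M1 codes (exponent bias 1).
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code held in the low nibble of an integer."""
    code &= 0xF
    sign = -1.0 if code & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[code & 0x7]


def decode_unpacked(storage: bytes) -> list:
    """One FP4 value per uint8 (unpacked storage): use the low nibble only."""
    return [decode_e2m1(b) for b in storage]


def decode_packed(storage: bytes) -> list:
    """Two FP4 values per uint8; low nibble first is an assumption."""
    out = []
    for b in storage:
        out.append(decode_e2m1(b & 0xF))
        out.append(decode_e2m1(b >> 4))
    return out
```

A host-side reference check could decode both operands this way and compare a plain float matmul against the kernel output.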

Changes

  • Add SM100 FP4 examples under examples/gemm_fp4/:
    • example_gemm_fp4_sm100.py
    • example_gemm_a8w4_sm100.py
    • example_fusedmoe_a8w4_sm100.py
  • Keep SM120 FP4 examples under examples/gemm_fp4/:
    • example_gemm_fp4_sm120.py
    • example_gemm_a8w4_sm120.py
    • example_fusedmoe_a8w4_sm120.py
  • Remove generated .cu files and temporary FP4 diagnostic/probe files.
  • Add SM100 FP4 unpacksmem layout support aligned with 16U4_ALIGN16B.
  • Fix FP4 shared-memory codegen so SM100 unpacksmem uses 8-bit container semantics instead of packed /2 addressing.
  • Fix TMA transaction byte accounting for FP4 descriptor loads.
  • Add mixed A/B dtype support in TCGEN05 instruction descriptor generation for A8W4.
  • Fix SM120 FP4 lowering:
    • map unpacked uint8 A/B storage to float4_e2m1fn compute dtype
    • cap FP4 MMA micro-K at k32
    • restore FP4-specific ldmatrix offsets for unpacked uint8 storage
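
The byte-accounting items above come down to simple arithmetic, sketched below. The helper name is hypothetical, and the sketch assumes the unpacked shared-memory form occupies one byte per FP4 element on average; the PR description does not state which of the two counts the barrier is programmed with, so treat this only as an illustration of why the two differ.

```python
def fp4_tile_bytes(rows: int, cols: int) -> dict:
    """Byte counts for a rows x cols FP4 tile under two views (hypothetical helper).

    payload:   4 bits per element, two elements per byte.
    container: assumed one byte per element for the unpacked form.
    """
    n = rows * cols
    if n % 2:
        raise ValueError("FP4 tiles need an even element count to pack into bytes")
    return {"payload": n * 4 // 8, "container": n}


# A 128 x 64 tile: half as many payload bytes as container bytes.
sizes = fp4_tile_bytes(128, 64)
```

If the expected transaction count and the bytes actually delivered disagree by this factor of two, the barrier wait will not line up with the copy, which is the kind of mismatch the fix addresses.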

Test Plan

Verified on SM100/SM110:

TILELANG_DISABLE_CACHE=1 python examples/gemm_fp4/example_gemm_fp4_sm100.py
TILELANG_DISABLE_CACHE=1 python examples/gemm_fp4/example_gemm_a8w4_sm100.py
TILELANG_DISABLE_CACHE=1 python examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py

Verified on SM120:

TILELANG_DISABLE_CACHE=1 python examples/gemm_fp4/example_gemm_fp4_sm120.py
TILELANG_DISABLE_CACHE=1 python examples/gemm_fp4/example_gemm_a8w4_sm120.py
TILELANG_DISABLE_CACHE=1 python examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py

Summary by CodeRabbit

  • New Features

    • Added broad FP4 (4‑bit float) support across CUDA backends, mixed FP8×FP4 MMA/GEMM paths, device helpers, and shared‑memory/layout transforms.
    • Added multiple end‑to‑end example scripts demonstrating fused MoE and GEMM workflows with validation and performance benchmarks.
  • Bug Fixes / Correctness

    • Improved FP4 packing/unpacking and scope‑aware memory handling to ensure correct accesses and layout inference for sub‑byte types.

Hale423 and others added 8 commits March 12, 2026 10:55
Add the plumbing required to route float4_e2m1fn through the TCGEN5 MMA
code-generation path so that FP4 GEMM kernels can be emitted on SM100.

Changes:
- ptx.h / ptx.cc: add kFloat4_e2m1fn enum, string tables, DTypeFromString
- common.h: add kFloat4_e2m1fn to device-side DataType enum
- tcgen5_meta.h: add FP4 branch in encode_dtype (format code 2)
- tcgen05mma.h: add kFloat4_e2m1fn specializations for SS/TS/WS_SS
  (delegates to the existing f8f6f4 PTX kind)
- mma_macro_generator.py: add dtype_abbrv mapping for float4_e2m1fn
- docs/GEMM_NV_FP4_FEATURE_STEPS.md: design doc and progress tracker

Addresses tile-ai#1592

Made-with: Cursor
… raw bytes, get_ldmatrix_offset add fp4 special layout, add SM120_FP4_FP4_F32_TN MmaDispatcher specialization + fp4 << 2 bit-shift.
Co-authored-by: Cursor <cursoragent@cursor.com>
… as payload, functionality verified valid for naive gemm-fp4
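
The "fp4 << 2 bit-shift" mentioned in the commit title above can be pictured as follows. This is a sketch under the assumption that the f8f6f4 MMA kind expects the four E2M1 bits in the middle of an 8-bit container, with two padding bits on each side; the function name is hypothetical.

```python
def fp4_to_container(nibble: int) -> int:
    """Place a 4-bit E2M1 code in bits [5:2] of an 8-bit container (assumed layout)."""
    return (nibble & 0xF) << 2


# 0b0111 (the E2M1 code for 6.0) becomes 0b00011100.
shifted = fp4_to_container(0b0111)
```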
@coderabbitai
Contributor

coderabbitai Bot commented May 11, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Adds FP4 (float_e2m1) support end-to-end: new FP4 type wrapper and CUDA aliases, ALIGN16B sub-byte swizzle/layout and ldmatrix loaders, TCGEN5 instruction descriptor split and Python packer, SM100/SM120 MMA traits and dispatch (including mixed FP8×FP4), scope-aware CUDA FP4 codegen and TMA copy fixes, TileLang dtype/layout wiring, and runnable FP4/A8W4 GEMM and fused MoE examples with verification and benchmarks.

Changes

FP4 Infrastructure and Examples

| Layer | File(s) | Summary |
| --- | --- | --- |
| Fused MoE & A8W4 examples (SM100) | examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py, examples/gemm_fp4/example_gemm_a8w4_sm100.py | Add fused MoE A8W4 SM100 example and A8W4 GEMM example with FP4 LUT/unpack, prim_func kernels, zero tests, PyTorch reference validation, and benchmarks. |
| SM120 examples and fused MoE | examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py, examples/gemm_fp4/example_gemm_a8w4_sm120.py, examples/gemm_fp4/example_gemm_fp4_sm120.py | Add SM120 FP4/A8W4 examples with FP4 LUT/unpack helpers, fragment GEMM prim_funcs, zeros tests, LUT-based numeric validation, and 100-iteration timing/TFLOPS. |
| SM100 FP4 GEMM example | examples/gemm_fp4/example_gemm_fp4_sm100.py | Add TCGEN05/TMEM async FP4 GEMM example, packed-FP4 handling, unpack-to-float reference, zeros/numeric checks, and benchmark. |
| TCGEN5 instr desc (C++ & FFI) | src/op/tcgen5_meta.h, src/backend/cuda/op/gemm.cc | Change GetTCGEN5InstrDesc signature to accept separate a_dtype and b_dtype; update FFI binding to forward both dtypes. |
| TCGEN5 descriptor builder (Python) | tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py | Replace FFI path with Python-side 64-bit descriptor builder; unwrap BufferRegion for layout checks; fix B-desc bytes computation. |
| CUDA fp4 codegen scope handling | src/backend/cuda/codegen/codegen_cuda.cc | Gate packed-FP4 index/div arithmetic and packed-load/store to packed scopes only; remove local packed-scalar special-case so local FP4 uses normal allocation. |
| CUDA TMA bulk-copy (sub-byte) changes | src/backend/cuda/op/copy.cc | Detect ALIGN16B sub-byte layout and adjust instruction_dim, inner_box_dim_ and mbarrier byte sizing to account for FP4 payload bytes vs 16U4 container. |
| ALIGN16B sub-byte layout | src/layout/gemm_layouts.cc | Add MakeAlign16BSwizzleLayout2D + makeAlign16BSwizzleLayout; integrate into makeGemmABLayoutSm100 and DetectSwizzleMode for element_size<8. |
| FP4 type & CUDA templates | src/tl_templates/cuda/common.h, src/tl_templates/cuda/cuda_fp4.h | Add tl::float_e2m1_t wrapper and to_cute_type mapping; alias fp4_e2_t to tl::float_e2m1_t and update packing helpers to use raw(). |
| SM120 ldmatrix helpers | src/tl_templates/cuda/ldsm.h | Add ptx_ldmatrix_b4x16_x1/_x2/_x4 helpers emitting PTX ldmatrix variants that unpack 4-bit values into 8-bit containers. |
| SM100/SM120 MMA dispatch & mma_sync | src/tl_templates/cuda/gemm_mma.h, src/tl_templates/cuda/instruction/mma.h | Include cuda_fp4.h in dispatch regions; add DispatchInstruction/MmaDispatcher specializations for FP4 and mixed FP8×FP4; preprocess FP4 operands (left-shift) in mma_sync. |
| SM100 MMA traits | src/tl_templates/cuda/gemm_sm100.h | Add MMA_Traits partial specializations and DispatchInstruction entries for FP4×FP4 and mixed FP8×FP4 on SM100 (SS/WS/TS variants) and extend tl_tcgen5mma dispatch constraints. |
| TileLang ldmatrix layout helpers & utils | tilelang/cuda/intrinsics/layout/mma_layout.py, tilelang/cuda/intrinsics/layout/utils.py | Add ldmatrix_32x16_to_shared_16x32_fp4_layout_a/b; extend get_ldmatrix_offset to support dtype_bits==4 and adjust 8-bit handling to remove packing-factor scaling. |
| MMA macro K-dim special-case | tilelang/cuda/intrinsics/macro/mma_macro_generator.py | Cap k_dim at 32 for float4_e2m1fn and enforce chunk >= 32. |
| GemmBase & GEMM TCGEN5 wiring | tilelang/tileop/gemm/gemm_base.py, tilelang/cuda/op/gemm/gemm_mma.py, tilelang/cuda/op/gemm/gemm_tcgen05.py | Remove A==B dtype assertion; add in_dtype_b; infer_shared_layout now accepts dtype, continuity, k_major and handles sub-8-bit case (consults tl.disable_tma_lower); emitter/lowering now pass b_dtype=self.in_dtype_b and allocate B_local using in_dtype_b. |
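
The "MMA macro K-dim special-case" row can be illustrated with a small sketch. The helper names are hypothetical, and the `256 // dtype_bits` width rule is an assumption made for illustration, not a quote from mma_macro_generator.py; only the cap at 32 and the `chunk >= 32` requirement come from the summary above.

```python
FP4_MAX_MICRO_K = 32  # the PR describes m16n8k32 as the FP4 MMA shape


def select_micro_k(dtype_bits: int, is_fp4: bool) -> int:
    """Pick the per-instruction K extent (hypothetical helper)."""
    k_dim = 256 // dtype_bits  # assumed width-based rule
    if is_fp4:
        # Without the cap, 4-bit operands would yield 64 under the assumed rule.
        k_dim = min(k_dim, FP4_MAX_MICRO_K)
    return k_dim


def check_chunk(chunk: int, is_fp4: bool) -> None:
    """Reject K chunks smaller than one FP4 MMA step."""
    if is_fp4 and chunk < FP4_MAX_MICRO_K:
        raise ValueError(f"FP4 requires chunk >= {FP4_MAX_MICRO_K}, got {chunk}")
```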

Sequence Diagram(s)

sequenceDiagram
  participant Host as Python Host (examples)
  participant TL as TileLang (prim_func)
  participant TVM as TVM/Ffi (tl.get_tcgen5_instr_desc)
  participant NVCC as CUDA Compiler
  participant GPU as GPU Kernel (TCGEN05/Fragment MMA)

  Host->>TL: construct prim_func (tiles, TCGEN05/GEMM intrinsics)
  TL->>TVM: call Python TCGEN5 descriptor builder (pack a_dtype/b_dtype)
  TL->>NVCC: compile prim_func -> CUDA (pass config)
  NVCC->>GPU: load kernel (including ldmatrix/ldsm helpers)
  Host->>GPU: launch kernel with packed/uint8 FP4 weights
  GPU->>GPU: shared-memory ldmatrix (b4x16 loaders) -> MMA (FP4 / FP8×FP4)
  GPU-->>Host: output tensor
  Host->>Host: unpack LUT and verify / benchmark

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999
  • LJC00118

"🐰 I hop through LUTs and tiled arrays,
nibble-shifted floats in CUDA ballets,
ALIGN16B and TCGEN5 in my hop,
Benchmarks hum — the kernels don't stop!"

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 31.52% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The PR title '[Feature] Support Blackwell FP4(float4_e2m1fn) GEMM for SM100 & SM120' is clear, specific, and directly describes the main feature being added: FP4 GEMM support for two GPU architectures. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 10

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
build_sm120/.cmake/api/v1/reply/target-tvm_runtime-Release-69a4c57b2cf0c7008847.json (1)

1-688: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Build artifacts should not be committed to version control.

The entire build_sm120/ directory contains CMake-generated build artifacts and should be excluded from version control. These CMake File API reply JSON files are auto-generated during the CMake configuration phase and contain environment-specific absolute paths (e.g., /home/wahao/projects/tilelang/, /home/wahao/.cache/uv/).

Why this is a critical issue:

  • Build directories are local to each developer's machine
  • These files will conflict across different environments
  • They bloat the repository with machine-generated metadata
  • They contain user-specific paths that won't work for other developers

Recommended action:

  1. Remove the entire build_sm120/ directory from this PR
  2. Add build*/ or build_*/ to .gitignore to prevent future commits of build artifacts
  3. Document the build process in README/documentation if needed
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@build_sm120/.cmake/api/v1/reply/target-tvm_runtime-Release-69a4c57b2cf0c7008847.json`
around lines 1 - 688, Build artifacts under build_sm120/ (e.g.,
.cmake/api/v1/reply/*.json) were committed; remove them and ignore future
commits by deleting the directory from the PR (use git rm -r --cached
build_sm120/ and commit the removal), add ignore patterns like build*/ and
build_*/ (or explicitly build_sm120/) to .gitignore, and commit the updated
.gitignore; optionally add a short note in README about the project build/output
location.
build_sm120/.cmake/api/v1/reply/target-tvm_ffi_testing-Release-f49f6812d8501c253709.json (1)

1-354: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Build artifacts should not be committed to version control.

All files under build_sm120/ are machine-generated CMake build outputs containing absolute paths specific to the developer's machine (e.g., /home/wahao/projects/tilelang/, /tmp/tmp0gh_by0x/wheel/platlib). Committing these causes repository bloat, merge conflicts, reproducibility issues, and potential security concerns from exposing local file system paths.

🔧 Recommended fix

Add the build directory to .gitignore:

+build_sm120/
+build*/

Then remove these files from the repository:

git rm -r build_sm120/
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@build_sm120/.cmake/api/v1/reply/target-tvm_ffi_testing-Release-f49f6812d8501c253709.json`
around lines 1 - 354, Remove the committed machine-generated CMake outputs under
build_sm120 and prevent them from being re-added: add build_sm120/ to
.gitignore, then run git rm -r --cached build_sm120/ (or git rm -r build_sm120/)
and commit the removal; verify no build artifacts such as the JSON reply file
(target-tvm_ffi_testing-Release-*.json) or lib/libtvm_ffi.so remain tracked
before pushing.
🧹 Nitpick comments (1)
examples/gemm_fp4/example_gemm_fp4_sm120.py (1)

125-126: 💤 Low value

Consider gating the generated-CUDA dump behind a flag.

f.write(jit_kernel.get_kernel_source()) writes gemm_fp4_sm120.cu into the examples directory on every run, which is useful while debugging the SM120 FP4 lowering but is a side effect users running the example won't expect. Gating it behind something like TL_FP4_DUMP_CUDA (or removing it now that the lowering is fixed) would keep the example self-contained.

♻️ Optional: gate behind env var
-with open(os.path.join(os.path.dirname(__file__), "gemm_fp4_sm120.cu"), "w") as f:
-    f.write(jit_kernel.get_kernel_source())
+if os.environ.get("TL_FP4_DUMP_CUDA", "0") != "0":
+    with open(os.path.join(os.path.dirname(__file__), "gemm_fp4_sm120.cu"), "w") as f:
+        f.write(jit_kernel.get_kernel_source())
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 125 - 126, The
example currently always writes the generated CUDA source via
jit_kernel.get_kernel_source() to gemm_fp4_sm120.cu, causing an unexpected
side-effect; change the block that opens and writes the file to run only when an
environment flag is enabled (e.g., TL_FP4_DUMP_CUDA). Concretely, guard the
write with a check like os.getenv("TL_FP4_DUMP_CUDA") (or similar) surrounding
the existing with open(...) f.write(jit_kernel.get_kernel_source()) call so the
file is only created when the env var is truthy; keep the same filename
gemm_fp4_sm120.cu and the same use of jit_kernel.get_kernel_source() so behavior
is identical when the flag is set.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In
`@build_sm120/.cmake/api/v1/reply/target-tilelang_objs-Release-45ef627ca6e2f9aad6ad.json`:
- Around line 438-458: The committed CMake File API reply JSON (entries with
keys "path" and "backtrace" in
target-tilelang_objs-Release-45ef627ca6e2f9aad6ad.json) contains developer-local
absolute paths and must be removed from repo; delete or git rm the generated
JSON, add a rule to .gitignore (e.g., ignore
build_sm120/.cmake/api/v1/reply/*.json) so these artifacts aren't committed, and
adjust CI/build to regenerate this metadata as part of the build instead of
storing it; optionally add a small post-process step in the build (or configure
the CMake invocation) to sanitize or relativize the "path" values in CMake File
API replies if you need to commit similar outputs in the future.

In
`@build_sm120/.cmake/api/v1/reply/target-tvm_libinfo_objs-Release-41b633979ad96c614875.json`:
- Around line 632-679: The checked-in CMake reply JSON (e.g.,
target-tvm_libinfo_objs-Release-41b633979ad96c614875.json) contains absolute
local paths and must be removed from the PR; delete the file(s) from the branch
(git rm --cached if you need to keep them locally), add a rule to .gitignore to
exclude the generated .cmake/api/v1/reply/* artifacts, and commit the removal
and ignore change so these build-generated files are no longer tracked.

In `@build_sm120/.ninja_log`:
- Around line 1-719: Delete the generated file build_sm120/.ninja_log from the
PR and stop tracking it; remove it from the commit (e.g., revert/delete the file
in this PR) and update the repo ignore rules to prevent future commits by adding
ignore patterns such as build*/ and **/.ninja_log (or the specific
build_sm120/.ninja_log) to .gitignore; ensure the file is also removed from the
index if already tracked (git rm --cached) so it won't be reintroduced.

In `@build_sm120/CMakeCache.txt`:
- Around line 1-1157: The PR includes generated build artifacts under
build_sm120/ (e.g., CMakeCache.txt) that must not be committed; update
.gitignore to include patterns like build*/ and build_*/ (and optionally
build_sm120/) and remove the committed build artifacts from version control with
a cached removal (e.g., run git rm --cached -r build_sm120/), then commit the
.gitignore change so future CMake-generated files (like CMakeCache.txt) are
ignored.

In `@build_sm120/CMakeInit.txt`:
- Around line 6-27: This file contains machine-local absolute paths (e.g.,
set(PYTHON_EXECUTABLE ...), set(Python3_EXECUTABLE ...), set(SKBUILD_PLATLIB_DIR
...), set(SKBUILD_DATA_DIR ...), etc.) that must not be committed; remove or
neutralize these forced CACHE assignments and replace them with generic,
non-user-specific defaults or let CMake compute them at configure time (e.g.,
unset/remove the explicit set(... CACHE ... FORCE) lines for PYTHON_EXECUTABLE,
Python3_EXECUTABLE, Python_INCLUDE_DIR, SKBUILD_*_DIR, CMAKE_PREFIX_PATH, etc.),
regenerate the CMake init from a clean environment, or convert them to
configurable placeholders (no absolute /home or /tmp paths) so the tree contains
no host-specific paths.

In `@build_sm120/install_manifest.txt`:
- Around line 1-7: The install_manifest.txt currently contains host-specific
absolute paths (e.g., entries referencing tilelang_cython_wrapper.abi3.so,
libz3.so, libtilelang.so and other /tmp or /home paths) and must be removed from
version control; delete build_sm120/install_manifest.txt from the commit, add an
appropriate ignore rule (e.g., ignore generated install manifests and build
artifacts) to your repo's ignore file, and ensure the build or packaging process
emits/install a manifest only in CI or build artifacts (not committed) so paths
are not leaked; verify any packaging step that writes install_manifest.txt is
updated to write to a temp/build artifact location rather than the source tree.

In `@build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/.libs/libbacktrace.la`:
- Line 1: Remove the generated libtool stub libbacktrace.la from the commit
(it's a build artifact) and restore the repository to exclude build outputs;
delete libbacktrace.la from the index/commit and ensure the build directory
artifacts (e.g., the libbacktrace/.libs output) are added to .gitignore or the
appropriate global ignore so future generated files are not committed.

In
`@build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/install-debuginfo-for-buildid.sh`:
- Around line 50-65: The script is vulnerable to word-splitting/globbing because
variable expansions (buildid, prefix, remainder, dir, dst and src) are unquoted
in command invocations and assignments; update all uses to quote expansions so
commands like $readelf -n "$src", $grep "Build ID" (input via pipe is fine) and
$awk '{print $3}' receive quoted values, and ensure assignments use quoted
substitutions (e.g., prefix="$(echo "$buildid" | $sed ...)" and
remainder="$(echo "$buildid" | $sed ...)" ) and that mkdir_p and objcopy are
called with quoted paths ($mkdir_p "$dir" and $objcopy --only-keep-debug "$src"
"$dst"); apply this pattern consistently for build_id_dir, prefix, remainder,
dir, dst and src to prevent word-splitting/globbing.

In `@examples/amd_mxfp4_mm/asm-dump/src.s`:
- Around line 1-1283: Remove the compiler-generated assembly dump file
examples/amd_mxfp4_mm/asm-dump/src.s (and the whole
examples/amd_mxfp4_mm/asm-dump/ directory) from the PR; this file is a build
artifact (it contains markers like __CLANG_OFFLOAD_BUNDLE____START__, .ident
"AMD clang...", and symbols/kernels such as _Z9gemm_corePKhS0_PvS0_S0_iiiiiii,
_Z9reduce_skPKfP12hip_bfloat16ii and __hip_cuid_8b457700d3d88645) and should not
be committed. Add a .gitignore entry to exclude these dumps (e.g.,
examples/amd_mxfp4_mm/asm-dump*/ *.s or a pattern like *.s under that folder) so
future clang/HIP offload-bundle assembly outputs are ignored; if an assembly
excerpt is required keep a short annotated snippet in a docs/markdown file
instead of the full dump.

In `@examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py`:
- Around line 60-96: The kernel only computes gate+up for expert dimension but
declares output as (num_tokens, d_hidden) and also declares an unused W_down;
fix main by (1) renaming the input parameter (avoid shadowing builtin) e.g.,
input -> x_input, (2) remove the unused W_down parameter from the prim_func
signature, (3) change output's shape to T.Tensor((num_tokens, d_expert),
"float32") so the T.copy(up_local, output[bx * block_token, by * block_expert])
writes into the correct expert-width slice, and (4) update the host code to stop
allocating/passing/zeroing W_down and to pass the corrected output shape; keep
references to symbols main, input/input rename, W_down, output, up_local, and
gate_local to locate edits.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6dd3e61d-e76c-438e-9084-32c778049395

📥 Commits

Reviewing files that changed from the base of the PR and between beef5cf and c628d6a.

📒 Files selected for processing (105)
  • build_sm120/.cmake/api/v1/query/cache-v2
  • build_sm120/.cmake/api/v1/query/cmakeFiles-v1
  • build_sm120/.cmake/api/v1/query/codemodel-v2
  • build_sm120/.cmake/api/v1/query/toolchains-v1
  • build_sm120/.cmake/api/v1/reply/cache-v2-52c39eaf42c705276742.json
  • build_sm120/.cmake/api/v1/reply/cmakeFiles-v1-232f7282c5ca4bddbcce.json
  • build_sm120/.cmake/api/v1/reply/codemodel-v2-36a860bc847b2d542a0a.json
  • build_sm120/.cmake/api/v1/reply/directory-.-Release-a36f43283df684ea4229.json
  • build_sm120/.cmake/api/v1/reply/directory-tvm-Release-2640e16d11b10fc202fd.json
  • build_sm120/.cmake/api/v1/reply/directory-tvm.3rdparty.tvm-ffi-Release-fc80e485107cfd9f3711.json
  • build_sm120/.cmake/api/v1/reply/index-2026-05-11T07-09-30-0790.json
  • build_sm120/.cmake/api/v1/reply/target-cuda_stub-Release-8f1e2d53694614942136.json
  • build_sm120/.cmake/api/v1/reply/target-cudart_stub-Release-83108a69725c59f02368.json
  • build_sm120/.cmake/api/v1/reply/target-nvrtc_stub-Release-1f28ef3dff552eb390d9.json
  • build_sm120/.cmake/api/v1/reply/target-project_libbacktrace-Release-c2cdedfffcf35112576c.json
  • build_sm120/.cmake/api/v1/reply/target-runtime-Release-62afd9a385c447b59f4e.json
  • build_sm120/.cmake/api/v1/reply/target-tilelang-Release-1ff4e8816791a3d3fb50.json
  • build_sm120/.cmake/api/v1/reply/target-tilelang_cython_wrapper-Release-71bae53629834f4f17ac.json
  • build_sm120/.cmake/api/v1/reply/target-tilelang_objs-Release-45ef627ca6e2f9aad6ad.json
  • build_sm120/.cmake/api/v1/reply/target-tvm-Release-5a611c7d6f0430fb0290.json
  • build_sm120/.cmake/api/v1/reply/target-tvm_ffi_objs-Release-11a9b20bf3cc02e441a3.json
  • build_sm120/.cmake/api/v1/reply/target-tvm_ffi_shared-Release-1bc48936f9ca4def6d44.json
  • build_sm120/.cmake/api/v1/reply/target-tvm_ffi_static-Release-e8cf9af34d5cfa948ca4.json
  • build_sm120/.cmake/api/v1/reply/target-tvm_ffi_testing-Release-f49f6812d8501c253709.json
  • build_sm120/.cmake/api/v1/reply/target-tvm_libinfo_objs-Release-41b633979ad96c614875.json
  • build_sm120/.cmake/api/v1/reply/target-tvm_objs-Release-357520ed3665fa628873.json
  • build_sm120/.cmake/api/v1/reply/target-tvm_runtime-Release-69a4c57b2cf0c7008847.json
  • build_sm120/.cmake/api/v1/reply/target-tvm_runtime_objs-Release-de5cd576b5a2cad88217.json
  • build_sm120/.cmake/api/v1/reply/toolchains-v1-36d721b4a463902ae11c.json
  • build_sm120/.ninja_deps
  • build_sm120/.ninja_log
  • build_sm120/.skbuild-info.json
  • build_sm120/CMakeCache.txt
  • build_sm120/CMakeInit.txt
  • build_sm120/TVMBuildOptions.txt
  • build_sm120/build.ninja
  • build_sm120/cmake_install.cmake
  • build_sm120/compile_commands.json
  • build_sm120/install_manifest.txt
  • build_sm120/tilelang_cython_wrapper.cpp
  • build_sm120/tvm/3rdparty/tvm-ffi/cmake_install.cmake
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/.libs/libbacktrace.a
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/.libs/libbacktrace.la
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/.libs/libbacktrace.lai
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/Makefile
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/backtrace-supported.h
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/config.h
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/config.status
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/include/backtrace-supported.h
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/include/backtrace.h
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/install-debuginfo-for-buildid.sh
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/lib/libbacktrace.a
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/lib/libbacktrace.la
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/libbacktrace.la
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/libtool
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/src/project_libbacktrace-stamp/project_libbacktrace-build
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/src/project_libbacktrace-stamp/project_libbacktrace-build-Release.cmake
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/src/project_libbacktrace-stamp/project_libbacktrace-checkout
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/src/project_libbacktrace-stamp/project_libbacktrace-configure
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/src/project_libbacktrace-stamp/project_libbacktrace-configure-Release.cmake
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/src/project_libbacktrace-stamp/project_libbacktrace-done
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/src/project_libbacktrace-stamp/project_libbacktrace-download
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/src/project_libbacktrace-stamp/project_libbacktrace-install
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/src/project_libbacktrace-stamp/project_libbacktrace-install-Release.cmake
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/src/project_libbacktrace-stamp/project_libbacktrace-mkdir
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/src/project_libbacktrace-stamp/project_libbacktrace-patch
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/src/project_libbacktrace-stamp/project_libbacktrace-patch-info.txt
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/src/project_libbacktrace-stamp/project_libbacktrace-source_dirinfo.txt
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/src/project_libbacktrace-stamp/project_libbacktrace-update
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/src/project_libbacktrace-stamp/project_libbacktrace-update-info.txt
  • build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/stamp-h1
  • build_sm120/tvm/cmake_install.cmake
  • build_sm120/tvm/temp_config_file.cmake
  • build_sm120/tvm/tvmConfig.cmake
  • examples/amd_mxfp4_mm/asm-dump-v22/src_v22.s
  • examples/amd_mxfp4_mm/asm-dump/src.s
  • examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py
  • examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py
  • examples/gemm_fp4/example_gemm_a8w4_sm100.py
  • examples/gemm_fp4/example_gemm_a8w4_sm120.py
  • examples/gemm_fp4/example_gemm_fp4_sm100.py
  • examples/gemm_fp4/example_gemm_fp4_sm120.py
  • src/backend/cuda/codegen/codegen_cuda.cc
  • src/backend/cuda/op/copy.cc
  • src/backend/cuda/op/gemm.cc
  • src/layout/gemm_layouts.cc
  • src/layout/layout.cc
  • src/layout/layout.h
  • src/op/tcgen5_meta.h
  • src/op/utils.cc
  • src/tl_templates/cuda/common.h
  • src/tl_templates/cuda/cuda_fp4.h
  • src/tl_templates/cuda/gemm_mma.h
  • src/tl_templates/cuda/gemm_sm100.h
  • src/tl_templates/cuda/instruction/mma.h
  • src/tl_templates/cuda/ldsm.h
  • tilelang/intrinsics/mma_layout.py
  • tilelang/intrinsics/mma_macro_generator.py
  • tilelang/intrinsics/tcgen05_macro_generator.py
  • tilelang/intrinsics/utils.py
  • tilelang/layout/__init__.py
  • tilelang/layout/swizzle.py
  • tilelang/tileop/gemm/gemm_base.py
  • tilelang/tileop/gemm/gemm_mma.py
  • tilelang/tileop/gemm/gemm_tcgen05.py

Comment thread build_sm120/.ninja_log Outdated
Comment thread build_sm120/CMakeCache.txt Outdated
Comment thread build_sm120/CMakeInit.txt Outdated
Comment thread build_sm120/install_manifest.txt Outdated
Comment thread build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/.libs/libbacktrace.la Outdated
Comment thread build_sm120/tvm/3rdparty/tvm-ffi/libbacktrace/install-debuginfo-for-buildid.sh Outdated
Comment thread examples/amd_mxfp4_mm/asm-dump/src.s Outdated
Comment thread examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py
Co-authored-by: Cursor <cursoragent@cursor.com>
@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/intrinsics/tcgen05_macro_generator.py (1)

918-973: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Refactor to reuse FFI entry instead of duplicating C++ descriptor encoding logic.

The FFI entry tl.get_tcgen5_instr_desc (in src/backend/cuda/op/gemm.cc) now accepts separate a_dtype and b_dtype parameters, yet this Python method reimplements the entire encoding (encode_dtype and bit-field assembly) locally. This duplicates logic that should be sourced from a single implementation to prevent silent drift if C++ adds new dtype encodings or changes bit-layout. Refactor to route through _ffi_api.get_tcgen5_instr_desc(atom_m, atom_n, atom_k, DataType(self.a_dtype), DataType(self.b_dtype), DataType(self.accum_dtype), ...) and lift the result, mirroring the pattern used by get_tcgen5_blockscaled_instr_desc (line 811).

Also, the docstring claims "64-bit instruction descriptor" but the C++ implementation returns uint32_t (32-bit; max bit position is 28, well within 32 bits).

♻️ Duplicate comments (1)
examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py (1)

60-96: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Output shape and unused W_down still incorrect; docstring also overstates kernel scope.

The kernel still declares output: T.Tensor((num_tokens, d_hidden), "float32") (line 66) yet writes a (block_token, block_expert)-shaped tile at offset (bx*block_token, by*block_expert) (line 96). For d_expert != d_hidden this silently misplaces results or overruns. W_down (line 65) remains unreferenced in the kernel body, and the host code at lines 137/143/144/148/173 still allocates and forwards it. Separately, the file-level docstring (lines 3-7) advertises a 4-step pipeline including a Down GEMM that the kernel never performs — please either implement the down projection or trim the docstring (and the W_down parameter / output shape) to reflect the actual gate+up fusion. Renaming input → e.g. Input also resolves Ruff A002 (line 62).

🔧 Proposed fix — drop unused `W_down`, correct output shape, update host code, rename `input`
     `@T.prim_func`
     def main(
-        input: T.Tensor((num_tokens, d_hidden), "float8_e4m3fn"),
+        Input: T.Tensor((num_tokens, d_hidden), "float8_e4m3fn"),
         W_gate: T.Tensor((d_expert, d_hidden), "uint8"),
         W_up: T.Tensor((d_expert, d_hidden), "uint8"),
-        W_down: T.Tensor((d_hidden, d_expert), "uint8"),
-        output: T.Tensor((num_tokens, d_hidden), "float32"),
+        Output: T.Tensor((num_tokens, d_expert), "float32"),
     ):
@@
-                T.copy(input[bx * block_token, k * block_hidden], input_shared)
+                T.copy(Input[bx * block_token, k * block_hidden], input_shared)
@@
-            T.copy(up_local, output[bx * block_token, by * block_expert])
+            T.copy(up_local, Output[bx * block_token, by * block_expert])

Host-side:

 jit_kernel = tilelang.compile(
     func,
-    out_idx=[4],
+    out_idx=[3],
     target="cuda",
     ...
 )
@@
-W_down_uint8 = torch.randint(0, 16, (d_hidden, d_expert), device="cuda", dtype=torch.uint8)
@@
-z_down = torch.zeros(d_hidden, d_expert, device="cuda", dtype=torch.uint8)
-c_zero = jit_kernel(z_input, z_gate, z_up, z_down)
+c_zero = jit_kernel(z_input, z_gate, z_up)
@@
-out = jit_kernel(input_fp8, W_gate_uint8, W_up_uint8, W_down_uint8)
+out = jit_kernel(input_fp8, W_gate_uint8, W_up_uint8)
@@
-    jit_kernel(input_fp8, W_gate_uint8, W_up_uint8, W_down_uint8)
+    jit_kernel(input_fp8, W_gate_uint8, W_up_uint8)

And tighten the docstring so step 4 is not advertised, or document the kernel as gate+up fusion only (mirroring example_fusedmoe_a8w4_sm100.py).
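
To make the shape mismatch concrete, the following pure-Python sketch (not the kernel itself) computes which output columns a grid of (block_token, block_expert) tiles touches and compares that with a declared width of d_hidden. The helper names `written_columns` and `check_output_width`, and the assumption that the grid launches ceil(d_expert / block_expert) blocks along `by`, are illustrative and should be checked against the actual launch configuration.

```python
import math


def written_columns(d_expert: int, block_expert: int) -> tuple[int, int]:
    """Return the half-open column range [start, stop) touched by all tiles.

    Assumes ceil(d_expert / block_expert) blocks along `by`, with block `by`
    writing columns [by * block_expert, (by + 1) * block_expert).
    """
    num_blocks = math.ceil(d_expert / block_expert)
    return 0, num_blocks * block_expert


def check_output_width(d_expert: int, d_hidden: int, block_expert: int) -> str:
    """Classify the result of declaring the output as (num_tokens, d_hidden)."""
    _, stop = written_columns(d_expert, block_expert)
    if stop > d_hidden:
        return "overrun"  # tiles write past the last declared column
    if stop < d_hidden:
        return "partial"  # trailing columns are never written
    return "exact"
```

Under these assumptions only d_expert == d_hidden yields "exact", which is why declaring the output as (num_tokens, d_expert) is the smaller fix for a gate+up-only kernel.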

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py` around lines 60 - 96, The
kernel declares the wrong output shape and an unused parameter: in prim_func main, the
parameter named input should be renamed to Input (to satisfy Ruff A002), W_down is never
used, and output is typed as T.Tensor((num_tokens, d_hidden), "float32") while
the kernel writes a (block_token, block_expert) tile at offset (bx*block_token,
by*block_expert) which mismatches when d_expert != d_hidden; fix by removing the
unused W_down parameter and all host-side allocations/forwards of W_down, change
the output tensor shape to T.Tensor((num_tokens, d_expert), "float32") (or
implement the missing down projection using W_down and emit tiles into d_hidden
if you prefer the 4-step pipeline), and rename input → Input in the main
signature and all uses; also update the file docstring to reflect the actual
fused gate+up kernel if you choose to drop the down projection so the
documentation no longer advertises a Down GEMM.
🧹 Nitpick comments (4)
tilelang/tileop/gemm/gemm_base.py (1)

80-108: ⚖️ Poor tradeoff

Implicit uint8 → float4_e2m1fn mapping is fragile; consider an explicit marker.

Using (A.dtype == "uint8" and C.dtype in ("float32","float16")) as the sole signal for FP4 means any future legitimate use of uint8 operands with FP32/FP16 accumulators (e.g., uint8 activations + dequant in float, or mixed-precision quantization paths) will be silently routed through the FP4 MMA traits. An explicit attribute on gemm_node (e.g. is_fp4, compute_dtype_a/b), set by the lowering pass that converts FP4 → uint8 storage, would be more robust and self-documenting. Not blocking for this PR if no other uint8 path exists today, but worth tracking.
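
The difference between the implicit heuristic and an explicit marker can be sketched in plain Python. The `GemmOperands` dataclass and the `is_fp4` field below are hypothetical stand-ins for whatever attribute the lowering pass would set; they are not part of the existing gemm node.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GemmOperands:
    a_storage_dtype: str
    c_dtype: str
    is_fp4: bool = False  # hypothetical flag set by the FP4 -> uint8 lowering


def in_dtype_implicit(op: GemmOperands) -> str:
    """Heuristic from the review: uint8 storage + float accumulator means FP4."""
    if op.a_storage_dtype == "uint8" and op.c_dtype in ("float32", "float16"):
        return "float4_e2m1fn"
    return op.a_storage_dtype


def in_dtype_explicit(op: GemmOperands) -> str:
    """Only operands explicitly marked as FP4 are reinterpreted."""
    return "float4_e2m1fn" if op.is_fp4 else op.a_storage_dtype
```

With the implicit rule, a genuine uint8 operand paired with a float32 accumulator is reported as float4_e2m1fn; with the explicit flag it stays uint8 unless the lowering pass marked it.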

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/tileop/gemm/gemm_base.py` around lines 80 - 108, The current
implicit mapping from uint8 to "float4_e2m1fn" in in_dtype and in_dtype_b should
be replaced with an explicit flag on the gemm node (e.g., is_fp4 or
compute_dtype_a/compute_dtype_b) so only nodes that were intentionally lowered
from FP4 to uint8 trigger the MMA FP4 trait; update in_dtype and in_dtype_b to
check that explicit attribute (e.g., self.is_fp4 or self.compute_dtype_a/b)
instead of relying on (dtype == "uint8" and self.C.dtype in (...)), and ensure
the lowering pass that converts FP4→uint8 sets that attribute so mixed uint8
usages aren’t misclassified.
src/layout/gemm_layouts.cc (1)

914-947: 💤 Low value

No fallback for sub-byte layouts that don't satisfy ALIGN16B alignment.

makeGemmABLayoutSm100 unconditionally routes element_size < 8 to MakeAlign16BSwizzleLayout2D, which ICHECK-aborts when continuous % 128 != 0 or stride % 8 != 0. Non-128-aligned FP4 tile shapes (e.g., small-K kernels, exotic block sizes a user may attempt) will trigger a hard compiler abort rather than gracefully falling through to makeLinearLayout. Consider adding the same if (mat_continuous % vector_size == 0) return makeLinearLayout(...) fallback used by the ≥8-bit branch below.

Sketch
   if (element_size < 8) {
-    return MakeAlign16BSwizzleLayout2D(mat_stride, mat_continuous,
-                                       element_size);
+    if (mat_stride % 8 == 0 && mat_continuous % 128 == 0) {
+      return MakeAlign16BSwizzleLayout2D(mat_stride, mat_continuous,
+                                         element_size);
+    }
+    return makeLinearLayout(
+        Array<PrimExpr>{Integer(mat_stride), Integer(mat_continuous)});
   }
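
The selection rule in the sketch above can also be written as a small Python predicate for quick experimentation with tile shapes. The function name and the string labels are illustrative only; the two alignment conditions are the ones the review attributes to MakeAlign16BSwizzleLayout2D.

```python
def choose_subbyte_layout(mat_stride: int, mat_continuous: int) -> str:
    """Pick a layout label for a sub-byte (element_size < 8) tile.

    Uses the ALIGN16B swizzle only when its stated preconditions hold
    (stride % 8 == 0 and continuous % 128 == 0); otherwise falls back
    to a linear layout instead of aborting.
    """
    if mat_stride % 8 == 0 and mat_continuous % 128 == 0:
        return "align16b_swizzle"
    return "linear"
```

For example, a 64 x 128 tile would keep the swizzled path, while a 64 x 64 tile would fall back to the linear layout.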
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/layout/gemm_layouts.cc` around lines 914 - 947, In makeGemmABLayoutSm100,
handle sub-byte element_size (<8) cases that currently unconditionally return
MakeAlign16BSwizzleLayout2D (which ICHECKs on mat_continuous % 128 or mat_stride
% 8); compute the same vector_size = 128 / element_size and add a fallback: if
mat_continuous % vector_size == 0 return
makeLinearLayout(Array<PrimExpr>{Integer(mat_stride), Integer(mat_continuous)}),
otherwise keep returning MakeAlign16BSwizzleLayout2D (or ICHECK with a clear
message). This mirrors the ≥8-bit branch logic and prevents hard aborts for
non-ALIGN16B-aligned FP4 tiles while preserving existing optimized path (refer
to MakeAlign16BSwizzleLayout2D, vector_size, makeLinearLayout, mat_continuous,
mat_stride).
tilelang/tileop/gemm/gemm_tcgen05.py (1)

60-64: 💤 Low value

Narrow the swallowed exception around PassContext lookup.

tvm.transform.PassContext.current() and _pass_ctx.config.get(...) should not raise in normal flows; catching the base Exception will also hide real bugs (e.g., a future config schema change where config is no longer dict-like). Narrowing to the expected exception types makes the intent — "fall back to TMA-enabled layout when the config key isn't plumbed" — explicit, and resolves Ruff BLE001 (line 63).

♻️ Proposed fix
-            try:
-                _pass_ctx = tvm.transform.PassContext.current()
-                disable_tma = _pass_ctx.config.get("tl.disable_tma_lower", False)
-            except Exception:
-                disable_tma = False
+            try:
+                _pass_ctx = tvm.transform.PassContext.current()
+                disable_tma = _pass_ctx.config.get("tl.disable_tma_lower", False)
+            except (AttributeError, KeyError):
+                disable_tma = False
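
The effect of narrowing the handler can be demonstrated without TVM. In this sketch the three context classes are hypothetical stand-ins for a PassContext; only the try/except shape mirrors the proposed fix.

```python
def read_disable_tma(ctx) -> bool:
    """Read `tl.disable_tma_lower`, defaulting to False if it is not plumbed."""
    try:
        return bool(ctx.config.get("tl.disable_tma_lower", False))
    except (AttributeError, KeyError):
        return False


class NoConfig:
    """Stand-in context without a `config` attribute (raises AttributeError)."""


class BrokenConfig:
    """Stand-in context whose config lookup fails with an unexpected error."""

    @property
    def config(self):
        raise RuntimeError("unexpected failure")


class GoodConfig:
    """Stand-in context with a dict-like config."""

    config = {"tl.disable_tma_lower": True}
```

Here a missing `config` still falls back to False, while the RuntimeError from `BrokenConfig` propagates instead of being silently swallowed.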
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/tileop/gemm/gemm_tcgen05.py` around lines 60 - 64, The current broad
except around tvm.transform.PassContext.current() and _pass_ctx.config.get(...)
swallows unexpected errors; narrow it to the expected failure modes by catching
only the likely exceptions (e.g., AttributeError, KeyError, TypeError) and then
set disable_tma = False in that handler. Update the block that references
tvm.transform.PassContext.current(), _pass_ctx.config.get, and the disable_tma
assignment so only those specific exceptions are caught instead of Exception.
src/tl_templates/cuda/gemm_sm100.h (1)

344-357: ⚡ Quick win

Remove or translate the Chinese-language commentary block before merging.

Lines 344-357 contain a multi-line Chinese explanation of how the MMA_Traits<SM100_MMA_F8F6F4_TS<...>> partial specialization is selected. It reads like a chat transcript rather than design documentation, and a production header in an OSS project shouldn't carry maintainer-only notes in a non-English language — future reviewers and downstream contributors who don't read Chinese can't verify, update, or even reformat it. Please either delete it (the rest of the file's English doc comments already explain the SS/WS/TS variants) or replace it with a short English-language note focused on why this _TS partial specialization is needed (CUTLASS's to_UMMAFormat<float_e2m1_t> ↔ MXF8F6F4Format::E2M1 mismatch), consistent with lines 207-210 and 342-343.

♻️ Suggested replacement
-// 你的理解接近正确,但还需要澄清一点:
-// 使用时其实不用再对 template 里面的参数“再次特化”,而是——
-// 1. 外层 template <...> 这部分负责声明参数列表,即让这个偏特化对所有可能的
-// <c_type, M, N, ...> 组合都有效。
-// 2. 内层 struct MMA_Traits<SM100_MMA_F8F6F4_TS<...>> 说明:只要模板参数匹配到
-// SM100_MMA_F8F6F4_TS 那一整组,
-//    就会自动选择这个偏特化的实现,不需要“再特化”。
-// 换句话说,你写一个 MMA_Traits<X>,如果 X 恰好能匹配 SM100_MMA_F8F6F4_TS<...>
-// 那一组参数, 这个特化版本就会被用到。比如:
-//     using Traits = MMA_Traits<SM100_MMA_F8F6F4_TS<float_e2m1_t, float_e2m1_t,
-//     float, 64, 32, ...>>;
-// Traits的内容就是匹配到此特化,无需再继续特化template参数。
-// 总结:外层 template 支持泛型匹配,struct MMA_Traits<...>
-// 实现具体特化,实际用时只需传入对应参数即可自动选中。
+// Like the SS / WS_SS specializations above, this overrides MMA_Traits for
+// SM100_MMA_F8F6F4_TS<float_e2m1_t, float_e2m1_t, ...> so that the FP4
+// descriptor is encoded as MXF8F6F4Format::E2M1 (=5) rather than CUTLASS's
+// default MXF4Format::E2M1 (=1) returned by to_UMMAFormat<float_e2m1_t>.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/tl_templates/cuda/gemm_sm100.h` around lines 344 - 357, Remove the
Chinese-language commentary block (lines around the MMA_Traits partial
specialization) and replace it with a brief English note explaining why the
SM100_MMA_F8F6F4_TS partial specialization exists: state that this
specialization of struct MMA_Traits<SM100_MMA_F8F6F4_TS<...>> is required to
resolve CUTLASS format mismatches (e.g., to_UMMAFormat<float_e2m1_t> vs
MXF8F6F4Format::E2M1) and to select the TS implementation for those template
parameter combinations; keep the rest of the file's English documentation intact
and match the style of nearby comments (see existing notes around the SS/WS/TS
variants at lines ~207-210 and ~342-343).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/backend/cuda/op/copy.cc`:
- Around line 1217-1224: The mbarrier transaction byte count undercounts when
16U4_ALIGN16B tiles are split across multiple TMA instructions because the
is_align16b_subbyte_layout branch uses total_elements directly; compute the
number of slices (loop_extent = (*inner_box_dim) / instruction_dim) and multiply
total_elements by loop_extent when calling TMABytesFromElements for the
is_align16b_subbyte_layout path so that mbarrier_expect_tx (and its caller)
accounts for every sliced TMA instruction instead of a single slice.
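
The byte accounting described in this comment can be sketched as follows. This is a simplified model, assuming that `elements_per_slice` is the element count of one TMA instruction and that each element occupies `bytes_per_element` bytes in shared memory (1 for an 8-bit container); neither assumption is confirmed against copy.cc, and the function name is hypothetical.

```python
def expected_tx_bytes(
    elements_per_slice: int,
    inner_box_dim: int,
    instruction_dim: int,
    bytes_per_element: int = 1,
) -> int:
    """Total bytes an mbarrier should expect for one sliced tile load.

    The tile is split into inner_box_dim / instruction_dim TMA instructions,
    and every slice contributes to the transaction count.
    """
    if inner_box_dim % instruction_dim != 0:
        raise ValueError("inner_box_dim must be a multiple of instruction_dim")
    loop_extent = inner_box_dim // instruction_dim
    return elements_per_slice * loop_extent * bytes_per_element
```

When the tile is not split (inner_box_dim == instruction_dim) the result equals the single-slice count, so the multiplication only changes behavior for the multi-instruction case the comment flags.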

In `@tilelang/intrinsics/mma_macro_generator.py`:
- Around line 118-126: The FP4 path in _initialize_k_dim can produce an
unsupported m16n8k16 dispatcher when self.chunk < 32; add a validation in
_initialize_k_dim (the method) to detect when a_dtype is "float4_e2m1fn" and
self.chunk < 32 and fail early (raise an exception or log and exit) instead of
silently setting k_dim to 16; keep the existing behavior for valid chunks (use
min(32, self.chunk) to set self.k_dim) so mma_prefix generation (which produces
m16n8k16) cannot be emitted for an unsupported configuration.
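
A minimal sketch of the early validation requested here, written as a free function rather than the real `_initialize_k_dim` method; the error message and the NotImplementedError branch are illustrative assumptions.

```python
def initialize_k_dim(a_dtype: str, chunk: int) -> int:
    """Return the MMA micro-K for the FP4 path, rejecting unsupported chunks."""
    if a_dtype == "float4_e2m1fn":
        if chunk < 32:
            # Fail early instead of silently producing k_dim = 16 (m16n8k16).
            raise ValueError(
                f"FP4 MMA requires chunk >= 32 (m16n8k32), got chunk={chunk}"
            )
        return min(32, chunk)  # existing behavior for valid chunks
    raise NotImplementedError("only the FP4 branch is sketched here")
```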

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b61a79c9-1494-4134-be89-1b3631aa471b

📥 Commits

Reviewing files that changed from the base of the PR and between c628d6a and 63c9451.

📒 Files selected for processing (21)
  • examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py
  • examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py
  • examples/gemm_fp4/example_gemm_a8w4_sm100.py
  • examples/gemm_fp4/example_gemm_a8w4_sm120.py
  • examples/gemm_fp4/example_gemm_fp4_sm100.py
  • examples/gemm_fp4/example_gemm_fp4_sm120.py
  • src/backend/cuda/codegen/codegen_cuda.cc
  • src/backend/cuda/op/copy.cc
  • src/backend/cuda/op/gemm.cc
  • src/layout/gemm_layouts.cc
  • src/op/tcgen5_meta.h
  • src/tl_templates/cuda/cuda_fp4.h
  • src/tl_templates/cuda/gemm_mma.h
  • src/tl_templates/cuda/gemm_sm100.h
  • src/tl_templates/cuda/instruction/mma.h
  • src/tl_templates/cuda/ldsm.h
  • tilelang/intrinsics/mma_macro_generator.py
  • tilelang/intrinsics/tcgen05_macro_generator.py
  • tilelang/intrinsics/utils.py
  • tilelang/tileop/gemm/gemm_base.py
  • tilelang/tileop/gemm/gemm_tcgen05.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • examples/gemm_fp4/example_gemm_a8w4_sm120.py
  • examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py
  • examples/gemm_fp4/example_gemm_a8w4_sm100.py
  • examples/gemm_fp4/example_gemm_fp4_sm120.py
  • examples/gemm_fp4/example_gemm_fp4_sm100.py

Comment thread src/backend/cuda/op/copy.cc
Comment thread tilelang/cuda/intrinsics/macro/mma_macro_generator.py
@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py`:
- Around line 546-596: The TCGEN5 descriptor handling mixes A/B element widths:
update compute_tcgen05_b_desc_params to derive elems_in_bytes and any
byte-offset/swizzle calculations from self.b_dtype (not self.a_dtype), mirroring
how compute_tcgen05_a_desc_params uses self.a_dtype; locate
compute_tcgen05_b_desc_params and replace occurrences where elems_in_bytes (or
equivalent) are initialized from self.a_dtype with logic that calls
DataType(self.b_dtype) (or reuse encode_dtype(self.b_dtype) to determine byte
width), and ensure any downstream shifts/atom byte-offset math uses that
B-derived value so B swizzle atoms and offsets match the descriptor bits set by
encode_dtype(self.b_dtype).
- Around line 162-168: The layout equality checks incorrectly use the original
"buffer" (which may be a BufferRegion) instead of the unwrapped "tir_buffer",
causing valid region-backed layouts to be misclassified; update the comparisons
in the function (where tir_buffer is set) to call make_linear_layout(tir_buffer)
and make_align16b_swizzled_layout(tir_buffer) (i.e., replace uses of
make_*_layout(buffer) with make_*_layout(tir_buffer)) so both dtype and layout
operations consistently operate on the unwrapped buffer; keep the rest of the
SwizzleMode logic and the ValueError unchanged.
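
The first finding above (B's element width taken from A's dtype) can be illustrated with a toy lookup. The table uses nominal bit widths of the two dtypes named in this PR; whether the real descriptor math should use the nominal width or the 8-bit container width for unpacked FP4 is not settled by this sketch, and both function names are hypothetical.

```python
# Nominal bit widths of the operand dtypes used by the A8W4 examples.
DTYPE_BITS = {"float8_e4m3fn": 8, "float4_e2m1fn": 4}


def desc_elem_bits(a_dtype: str, b_dtype: str) -> tuple[int, int]:
    """Return (A bits, B bits), each derived from its own operand dtype."""
    return DTYPE_BITS[a_dtype], DTYPE_BITS[b_dtype]


def desc_elem_bits_from_a_only(a_dtype: str, b_dtype: str) -> tuple[int, int]:
    """Pattern flagged by the review: B's width is taken from A's dtype."""
    a_bits = DTYPE_BITS[a_dtype]
    return a_bits, a_bits
```

For an FP8 A operand and an FP4 B operand the two functions disagree on B's width, which is the situation where swizzle atoms and byte offsets for B would no longer match the dtype bits encoded in the descriptor.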
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4b80d8c8-b362-4962-8193-48fe8c96b660

📥 Commits

Reviewing files that changed from the base of the PR and between 63c9451 and 8732c4c.

📒 Files selected for processing (9)
  • src/backend/cuda/op/gemm.cc
  • src/tl_templates/cuda/common.h
  • src/tl_templates/cuda/gemm_mma.h
  • tilelang/cuda/intrinsics/layout/mma_layout.py
  • tilelang/cuda/intrinsics/layout/utils.py
  • tilelang/cuda/intrinsics/macro/mma_macro_generator.py
  • tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py
  • tilelang/cuda/op/gemm/gemm_mma.py
  • tilelang/cuda/op/gemm/gemm_tcgen05.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/backend/cuda/op/gemm.cc
  • src/tl_templates/cuda/gemm_mma.h

Comment thread tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py
Comment thread tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py
Hale423 commented May 11, 2026

Hi @LeiWang1999, this is the Blackwell FP4 (float4_e2m1fn) feature implementation, separated from the earlier PR.
The functionality has been verified on both SM100 and SM120. Please take a look at your convenience.
