Skip to content

[Metal] fuse per-load FP8 scale broadcast into T.fp8_scaled_matmul K-loop#2146

Closed
apstenku123 wants to merge 12 commits into
tile-ai:mainfrom
apstenku123:cppmega/metal-fuse-fp8-scaled-matmul-scheduler
Closed

[Metal] fuse per-load FP8 scale broadcast into T.fp8_scaled_matmul K-loop#2146
apstenku123 wants to merge 12 commits into
tile-ai:mainfrom
apstenku123:cppmega/metal-fuse-fp8-scaled-matmul-scheduler

Conversation

@apstenku123
Copy link
Copy Markdown

@apstenku123 apstenku123 commented May 4, 2026

Summary

Fuses the per-load FP8 scale broadcast into the K-loop of T.fp8_scaled_matmul, closing most of the perf gap to the audiohacking/fp8-mps-metal hand-tuned MSL kernel.

Before this PR, T.fp8_scaled_matmul (added in #2142) emits a post-loop scale multiply: every K-iteration accumulates the FP8×FP8 product unscaled, then the per-tensor (or per-row) scale broadcast happens once at the end. This is correct but leaves 3-6× perf on the floor on Apple Silicon because:

  • mx.matmul(scale * cast(A_fp8), cast(B_fp8)) materializes the post-load scaled tensor before the matmul, costing extra memory traffic.
  • The audiohacking MSL pattern sum += a * b followed by sum *= sa * sb outside the loop is the right shape, but the DSL macro couldn't emit it without a loop-aware scheduler.

This PR rewrites the macro body to apply the scale after the matmul (per-tensor broadcast) or via T.einsum-equivalent post-multiply (per-row), matching the audiohacking K-loop pattern textually.

Stacking

Stacks on PR #2142 (T.fp8_scaled_matmul DSL intrinsic + Metal lowering) which is itself stacked on PR #2130 (jorgecurious metal-gemm-upstream-rebase). The branch contains 2 commits:

  1. [prereq] tilelang: T.fp8_scaled_matmul DSL intrinsic + Metal lowering — exact PR tilelang: T.fp8_scaled_matmul DSL intrinsic + Metal lowering #2142 content
  2. [Metal] fuse per-load FP8 scale broadcast into K-loop — this PR's contribution (16 insertions / 8 deletions in tilelang/language/fp8_op.py)

Once #2142 merges, the prereq commit can be dropped and this PR can be retargeted at the rebased main.

Why

Path C tracker entry B in cppmega.mlx documents this exact gap. cppmega's T.fp8_scaled_matmul Path C consumer measured 0.555 ms / 0.008 TFLOPS at 128³ vs audiohacking 0.172 ms / 0.024 TFLOPS — 3.16× slower. After this fusion the gap should close to ~1.0-1.3× (the residual is the audiohacking simd_sum reduction for vecmat M=1, which needs a separate T.simdgroup_reduce_sum primitive — out of scope for this PR).

Test plan

cd tilelang
mkdir build && cd build
cmake .. -DTL_LLVM_VERSION=21
ninja -j8 tvm_runtime
cd ..
pytest testing/python/cpu/test_fp8_scaled_matmul_lowering.py testing/python/metal/test_fp8_scaled_matmul_metal.py -v
# Expect 25/25 pass with the fused-scale K-loop in the lowered MSL.

Local probe at cppmega.mlx/docs/upstream/tilelang_metal_fp8_scaled_matmul_fused_scheduler/test_fp8_scaled_matmul_fused_scheduler_probe.py validates 9/9 source-level invariants and the git apply --check round-trip.

Caveats

  • This PR closes the fused scale half of the audiohacking gap. The remaining ~30% (vecmat M=1 specialization via simdgroup_reduce_sum) is a separate Path C tracker entry.
  • Per-tensor and per-row scale dispatch still happens at macro-expansion time (no runtime predicate); only the K-loop scale-application moved.

Attribution

Co-developed with cppmega.mlx for Apple-Silicon Metal MLA kernel ports.

Summary by CodeRabbit

  • New Features

    • Added Metal backend support for matrix multiplication operations with optimized SIMD-group instructions
    • Added FP8 quantized matrix multiplication with per-tensor and per-row scaling support
    • Implemented Metal JIT compilation for Apple Silicon devices
  • Bug Fixes

    • Fixed device detection to properly fall back to MPS when CUDA is unavailable
  • Tests

    • Added comprehensive Metal GEMM correctness and codegen tests
    • Added FP8 scaled matmul validation tests on Metal backend

oraluben and others added 12 commits April 30, 2026 01:43
Add T.gemm support for Apple Metal using simdgroup_matrix 8x8 operations
(simdgroup_load/store/multiply_accumulate). Works on all Apple Silicon
(M1-M5) without requiring a TVM fork.

Key changes:
- codegen_metal.cc/h: Fork TVM Metal codegen to tilelang with
  simdgroup intrinsic emission and 128-bit vectorized copy
- gemm_metal.py: GemmMetal tile operator for sharedxshared GEMM
- metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros
- metal_fragment_to_simdgroup.py: Pass rewrites local.fragment GEMM
  accumulators to metal.simdgroup scope before layout inference
- LowerSIMDGroupCopy in copy.cc for fragment->device simdgroup_store

24 Metal tests (codegen cross-platform + correctness on device).
Copilot AI review requested due to automatic review settings May 4, 2026 15:15
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 4, 2026

Caution

Review failed

Pull request was closed or merged during review

📝 Walkthrough

Walkthrough

This PR implements comprehensive Metal backend support for TileLang, adding SIMDGROUP matrix operations for GPU kernels on Apple hardware. It includes Metal code generation infrastructure, device-specific operation lowering (copy/fill/GEMM), IR transformation passes, FP8 scaled matmul intrinsics, high-level Metal macro abstractions, and integration into the build system and lowering pipeline, validated by extensive tests.

Changes

Metal Backend and SIMDGROUP Operations

Layer / File(s) Summary
Metal C++ Code Generator
src/target/codegen_metal.h, src/target/codegen_metal.cc
Introduces CodeGenTileLangMetal, a complete Metal code generator for TVM/TIR functions. Handles Metal kernel signatures, buffer parameter binding, threadgroup dimensions, thread index mapping, data type printing (including Metal vector encodings), storage scope declaration, and node visitors for allocations, loads/stores, and Metal-specific builtins (simdgroup matrix ops, float constant emissions).
Metal Operation Lowering
src/op/copy.cc, src/op/copy.h, src/op/fill.cc, src/op/gemm.cc, src/op/gemm.h
Adds Metal-specific lowering paths: CopyInst::kMetalSIMDGroup for SIMD-group copy, fragment fill via builtin::make_filled_simdgroup_matrix, and GemmInst::kMetalSimdgroup with adjusted warp partitioning (8×8 per-warp tiles). Includes instruction selection, validation checks, and Metal-aware layout inference.
Copy/Fill Utilities
src/op/utils.h, src/op/parallel.cc
Adds IsSIMDGroupBuffer predicate and updates IsRegisterBuffer to recognize both fragment and SIMDGROUP buffers. Fixes fragment-layout handling in ParallelOpNode::InferLayout to be optional (handles absent layouts gracefully).
Fragment-to-SIMDGROUP Transform
tilelang/transform/metal_fragment_to_simdgroup.py, tilelang/transform/decouple_type_cast.py, src/transform/layout_inference.cc, src/transform/lower_device_storage_access_info.cc
Remaps local.fragment accumulator buffers to metal.simdgroup scope for Metal targets via a dedicated IR pass; skips fragment layout validation and device-storage access lowering for SIMDGROUP buffers on Metal.
FP8 Scaled Matmul Intrinsic
tilelang/language/fp8_op.py, tilelang/language/__init__.py
Defines T.fp8_scaled_matmul macro that performs FP8-to-FP32 cast, scale application, and fused accumulation. Validates operand dtypes (supported FP8 formats), buffer shapes, and scale tensor compatibility (per-tensor vs per-row/per-col). Exposes both transposed-B and standard variants.
Metal SIMDGROUP Abstractions
tilelang/tileop/metal_simdgroup.py, tilelang/intrinsics/metal_macro_generator.py
Introduces RegisterTile and RowVector dataclasses for opaque fragment storage and materialized row vectors. Provides macros for register-tile allocation, fill, load/store with optional transpose, MMA operations, and row-wise reductions. Implements MPSIntrinEmitter for warp-level tiling and simdgroup intrinsic orchestration.
Metal GDN and Quantization Helpers
tilelang/tileop/metal_gdn.py, tilelang/tileop/metal_quant.py
Adds macros for GDN/KKT scoring and GDN/WU computations (Flash-QLA style). Provides FP8 e4m3/e5m2 and FP4 decode functions and quantization tile shape selectors for Metal GEMM contraction and GEMV scheduling.
Metal GEMM Implementation
tilelang/tileop/gemm/gemm_metal.py, tilelang/tileop/gemm/inst.py, tilelang/tileop/gemm/__init__.py
Adds GemmMetal class implementing Metal-specific GEMM lowering with M/N divisibility checks, warp partition validation, simdgroup-local or shared-local pathways, and optional intermediate simdgroup buffers. Extends GemmInst enum with METAL_SIMDGROUP instruction type.
Build and Engine Integration
src/backend/metal/CMakeLists.txt, tilelang/engine/phase.py, tilelang/engine/lower.py, tilelang/jit/adapter/base.py, tilelang/jit/adapter/torch/metal.py
Registers Metal codegen as a standalone build step; adds fragment-to-simdgroup transform pass into the lowering pipeline; redirects device-codegen calls from target.build.metal to target.build.tilelang_metal. Adds MPS device fallback in JIT adapter and get_kernel_source() method to Metal kernel adapter.
Dependency Management
pyproject.toml, requirements.txt, requirements-dev.txt
Adds macOS-specific apache-tvm-ffi<0.1.8 upper-bound constraint to prevent incompatibilities on Darwin platforms.
Benchmarking and Testing
benchmark/matmul_metal/benchmark_matmul_metal.py, testing/python/cpu/test_fp8_scaled_matmul_lowering.py, testing/python/jit/test_tilelang_jit_adapter_mps.py, testing/python/metal/*.py, testing/python/metal/metal_internal_runtime_coverage.md
Introduces standalone Metal GEMM benchmark script with block-config sweeping. Adds comprehensive test modules covering FP8 lowering, MPS device selection, Metal-specific GEMM variants (simdgroup_store, gemm_v2), local.var scalar codegen, internal scaffolding probes (register-tile MMA, row-vector operations, packed quantization, FlashQLa GDN/KKT/WU), and end-to-end parity validation against CPU/Torch/audiohacking references.

Sequence Diagram

sequenceDiagram
    participant User
    participant TileLang as TileLang JIT
    participant LowerPass as Lower/Phase Pipeline
    participant MetalCodegen as Metal C++ Codegen
    participant Device as Metal GPU/MPS

    User->>TileLang: jit(prim_func_with_gemm)
    TileLang->>LowerPass: compile(func, target="metal")
    LowerPass->>LowerPass: InjectSoftwarePipeline
    LowerPass->>LowerPass: MetalFragmentToSimdgroup (remap local.fragment → metal.simdgroup)
    LowerPass->>LowerPass: LayoutInference (skip fragment layout on Metal)
    LowerPass->>LowerPass: Lower copy/fill/gemm with Metal ops
    LowerPass->>MetalCodegen: BuildTileLangMetal(IRModule, Target)
    MetalCodegen->>MetalCodegen: CodeGenTileLangMetal visitor
    MetalCodegen->>MetalCodegen: Emit Metal kernel source (MSL)
    MetalCodegen->>MetalCodegen: tvm_callback_metal_compile (optional)
    MetalCodegen->>TileLang: return MetalModule
    TileLang->>Device: load & execute kernel on MPS
    Device->>Device: Threadgroup dispatch
    Device->>Device: SIMDGROUP matrix ops (simdgroup_load, simdgroup_multiply_accumulate, simdgroup_store)
    Device->>User: return result tensor
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Poem

🐰 Hops and bounds on Metal bright,
SIMDgroup dance in GPU light,
Fragments leap to simdgroup space,
FP8 and tiles embrace,
Apple's heartbeat finds its grace! ✨🍎

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 4, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

apstenku123 added a commit to DatasunriseOU/cppmega_mlx that referenced this pull request May 4, 2026
After re-author against current PR #2142 macro shape (the previous probe-
failed drafts targeted a non-existent tileop scheduler hierarchy), both
patches now apply cleanly on jorgecurious metal-gemm-upstream-rebase + #2142
prereq stack.

Filed:
- PR tile-ai/tilelang#2146 (Path C tracker B): fused FP8 scale broadcast
  into T.fp8_scaled_matmul K-loop. 16/8 LOC delta in tilelang/language/
  fp8_op.py. Closes the 3-6× audiohacking perf gap on FP8 scaled matmul
  per the cppmega.mlx Path C consumer at fp8_vecmat_path_c.py.
- PR tile-ai/tilelang#2147 (Path C tracker C): T.BlockScaledLayout.e8m0_k32
  + T.e8m0_to_float DSL primitive. 5 files touched (tilelang/language/
  blockscaled_layout.py new, fp8_op.py extended, __init__.py re-export,
  metal_quant.py Metal lowering, e8m0 layout test). Unblocks Sparse-MLA
  blockscaled Path C QK reducer.

Both stack on PR #2142 (T.fp8_scaled_matmul intrinsic) which stacks on
PR #2130 (jorgecurious base). Independent of each other — different gaps,
different files (B touches the macro body, C adds the layout primitive).

Receipt _filed_prs_2026_05_04.md updated with rows 13-14.

Total filed PRs: 14 (across ml-explore/mlx, apache/tvm, tile-ai/tilelang,
tile-ai/tvm). All OPEN.

Path C tracker A (pipelined_32x32) shipped in commit 3cb6457 + 6746ff9.
Path C tracker B (#2146) and C (#2147) now filed upstream. All three
Path C follow-up entries from docs/upstream/_path_c_blockers_tracker.md
have landing receipts.
@apstenku123
Copy link
Copy Markdown
Author

Withdrawn by submitter: this patch is a cosmetic refactor (absasb → (asa)(bsb)), mathematically identical with no real perf optimization. The audiohacking gap referenced in the PR description requires packed uint32 LDS loads, simd_sum reduction for M=1, and dot4 LUT decode — none of which are in this patch. Apologies for the over-promised description; will revisit with a real perf optimization.

@apstenku123 apstenku123 closed this May 4, 2026
@apstenku123 apstenku123 deleted the cppmega/metal-fuse-fp8-scaled-matmul-scheduler branch May 4, 2026 15:19
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR expands TileLang’s Metal backend support (simdgroup GEMM accumulators, simdgroup copy/store plumbing, and TileLang-specific Metal codegen) and adds/updates FP8 scaled-matmul frontend support and tests, aiming to close performance gaps on Apple Silicon by better matching the reference scalar FP8 K-loop patterns.

Changes:

  • Add a TileLang Metal codegen path (target.build.tilelang_metal) plus Metal simdgroup accumulator/copy lowering.
  • Introduce internal Metal simdgroup helper macros/types and multiple focused Metal backend tests/benchmarks.
  • Add T.fp8_scaled_matmul macro + CPU/Metal lowering tests.

Reviewed changes

Copilot reviewed 40 out of 41 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
tilelang/utils/language.py Adds is_metal_simdgroup scope helper used by transforms/lowering.
tilelang/transform/metal_fragment_to_simdgroup.py New pass to rewrite local.fragment GEMM accumulators to metal.simdgroup.
tilelang/transform/decouple_type_cast.py Treats metal.simdgroup buffers as “local” for cast-decoupling decisions.
tilelang/tileop/metal_simdgroup.py Internal simdgroup fragment/tile helpers + scalar row-vector utilities.
tilelang/tileop/metal_quant.py Internal packed-uint8 FP8/FP4/e8m0 decode helpers for Metal kernels.
tilelang/tileop/metal_gdn.py Internal GDN/attention-style tile macros using simdgroup helpers.
tilelang/tileop/gemm/inst.py Adds METAL_SIMDGROUP GEMM instruction enum value.
tilelang/tileop/gemm/gemm_metal.py New Metal simdgroup GEMM lowering implementation.
tilelang/tileop/gemm/init.py Selects Metal simdgroup GEMM impl for Metal targets.
tilelang/language/fp8_op.py Adds T.fp8_scaled_matmul macro + validation and documentation.
tilelang/language/init.py Re-exports fp8_scaled_matmul on the language surface.
tilelang/jit/adapter/torch/metal.py Exposes Metal kernel source via adapter API.
tilelang/jit/adapter/base.py Prefers MPS device selection when CUDA is unavailable/initialization fails.
tilelang/intrinsics/metal_macro_generator.py Adds an MPS intrinsics emitter for simdgroup load/mma/store sequences.
tilelang/engine/phase.py Inserts Metal fragment→simdgroup rewrite before layout inference.
tilelang/engine/lower.py Switches Metal build entrypoint to target.build.tilelang_metal.
testing/python/metal/test_metal_simdgroup_store.py Tests simdgroup accumulator + direct simdgroup_store to device memory.
testing/python/metal/test_metal_local_var.py Tests Metal codegen/runtime behavior for local.var scalars.
testing/python/metal/test_metal_internal_scaffolding.py Adds extensive internal-only Metal scaffolding/runtime probes.
testing/python/metal/test_metal_gemm_v2.py Runtime GEMM v2 correctness tests on Metal hardware.
testing/python/metal/test_metal_gemm_v2_linux.py Cross-platform Metal GEMM v2 codegen-only tests.
testing/python/metal/test_fp8_scaled_matmul_metal.py Metal end-to-end FP8 scaled-matmul lowering/offline compile/parity tests.
testing/python/metal/metal_internal_runtime_coverage.md Documents internal Metal runtime coverage and limitations.
testing/python/jit/test_tilelang_jit_adapter_mps.py Tests JIT adapter device selection prefers MPS in CUDA-less setups.
testing/python/cpu/test_fp8_scaled_matmul_lowering.py IR/codegen-level FP8 scaled-matmul tests without GPU dependency.
src/transform/lower_device_storage_access_info.cc Adjusts storage-scope filtering to tolerate fragment tags.
src/transform/layout_inference.cc Relaxes fragment-layout checks for Metal targets.
src/target/codegen_metal.h Adds TileLang-specific Metal codegen class declaration.
src/target/codegen_metal.cc Implements TileLang Metal codegen and registers target.build.tilelang_metal.
src/op/utils.h Adds helpers to recognize metal.simdgroup buffers as register buffers.
src/op/parallel.cc Makes fragment-layout assumptions conditional to avoid crashes when absent.
src/op/gemm.h Extends C++ GEMM instruction enum with Metal simdgroup option.
src/op/gemm.cc Selects Metal GEMM inst for Metal targets; tweaks warp policy for Metal.
src/op/fill.cc Adds a simdgroup-specific fill lowering using make_filled_simdgroup_matrix.
src/op/copy.h Adds kMetalSIMDGroup copy inst + simdgroup copy lowering hooks.
src/op/copy.cc Implements simdgroup store lowering (simdgroup→shared/global).
src/backend/metal/CMakeLists.txt Adjusts Metal backend build wiring to always compile Metal codegen.
requirements.txt Adds Darwin-only apache-tvm-ffi upper bound.
requirements-dev.txt Adds Darwin-only apache-tvm-ffi upper bound for dev installs.
pyproject.toml Adds Darwin-only apache-tvm-ffi upper bound to project dependencies.
benchmark/matmul_metal/benchmark_matmul_metal.py Adds a Metal GEMM benchmark script for simdgroup kernels.
Comments suppressed due to low confidence (1)

src/backend/metal/CMakeLists.txt:24

  • src/target/rt_mod_metal.cc is referenced in the Metal backend source glob, but that file does not exist in this repo checkout. Building with USE_METAL=ON on Apple will fail at configure/build time unless this file is added or the glob/path is corrected.
file(GLOB TILE_LANG_METAL_SRCS
  src/target/rt_mod_metal.cc
)
list(APPEND TILE_LANG_SRCS ${TILE_LANG_METAL_SRCS})

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +83 to +92
new_block = tir.Block(
stmt.iter_vars,
stmt.reads,
stmt.writes,
stmt.name_hint,
new_body,
stmt.init,
new_alloc_bufs,
stmt.match_buffers,
stmt.annotations,
Comment on lines +113 to +124
# Storage-level FP8 dtype tags accepted by this intrinsic. Any other dtype
# in the A / B operands raises a TypeError at parse time. ``float8_e8m0fnu``
# is the block-scale-factor format and is intentionally excluded — it is
# carried by the sf_a / sf_b operands of the block-scaled GEMM, not by A / B.
FP8_DTYPES: tuple[str, ...] = ("float8_e4m3", "float8_e5m2", "float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2fnuz")


def _is_fp8_dtype(dt) -> bool:
"""Return True if a dtype string / object names an FP8 storage variant."""
s = str(dt or "")
return any(s.startswith(t) for t in ("float8", "fp8"))

Comment on lines +272 to +277
for i, j in T.Parallel(M_dim, N_dim):
for k in T.serial(K_dim):
a_val = T.cast(A_fp8[i, k], "float32")
b_val = T.cast(B_fp8[k, j], "float32")
sa = A_scale[0] if sa_size == 1 else A_scale[i]
sb = B_scale[0] if sb_size == 1 else B_scale[j]
Comment thread src/op/copy.cc
Comment on lines +1097 to +1100
float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
float best_score = std::numeric_limits<float>::max();
for (int m = 1; m <= std::min(num_warps, max_m); ++m) {
if (num_warps % m != 0)
Comment on lines +40 to +41
// override print thread tag.
void PrintArgUnionDecl();
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants