Skip to content

Commit 4b015bf

Browse files
committed
[Metal] Add Metal GEMM support with simdgroup_matrix MMA
Add T.gemm support for Apple Metal using simdgroup_matrix 8x8 operations (simdgroup_load/store/multiply_accumulate). Works on all Apple Silicon (M1-M5) without requiring a TVM fork. Key changes: - codegen_metal.cc/h: Fork TVM Metal codegen to tilelang with simdgroup intrinsic emission and 128-bit vectorized copy - gemm_metal.py: GemmMetal tile operator for sharedxshared GEMM - metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros - metal_fragment_to_simdgroup.py: Pass rewrites local.fragment GEMM accumulators to metal.simdgroup scope before layout inference - LowerSIMDGroupCopy in copy.cc for fragment->device simdgroup_store 24 Metal tests (codegen cross-platform + correctness on device).
1 parent 936ae92 commit 4b015bf

29 files changed

Lines changed: 1682 additions & 28 deletions

CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,13 @@ list(APPEND TILE_LANG_SRCS
190190
src/runtime/error_helpers.cc
191191
)
192192

193+
# Metal codegen is pure C++ (no Apple frameworks) and can generate Metal shader
194+
# source on any platform. Always compile it so that "target.build.tilelang_metal"
195+
# is available for cross-compilation on Linux/Windows.
196+
list(APPEND TILE_LANG_SRCS
197+
src/target/codegen_metal.cc
198+
)
199+
193200
set(TILELANG_OUTPUT_TARGETS tilelang tvm)
194201

195202
# Track if the user explicitly selected a backend via cache options.
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import argparse
2+
import logging
3+
import time
4+
5+
import torch
6+
7+
import tilelang
8+
import tilelang.language as T
9+
10+
logging.getLogger("tilelang").setLevel(logging.WARNING)
11+
12+
BLOCK_CONFIGS = [
13+
(16, 16, 16),
14+
(32, 32, 16),
15+
(32, 32, 32),
16+
(64, 64, 32),
17+
]
18+
19+
20+
@tilelang.jit
21+
def matmul_simdgroup(M, N, K, block_M=64, block_N=64, block_K=32, dtype=T.float16, accum_dtype=T.float32):
22+
23+
@T.prim_func
24+
def gemm_kernel(
25+
A: T.Tensor((M, K), dtype),
26+
B: T.Tensor((K, N), dtype),
27+
C: T.Tensor((M, N), accum_dtype),
28+
):
29+
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
30+
A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared")
31+
B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared")
32+
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
33+
T.clear(C_local)
34+
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
35+
T.copy(A[by * block_M, ko * block_K], A_shared)
36+
T.copy(B[ko * block_K, bx * block_N], B_shared)
37+
T.gemm(A_shared, B_shared, C_local)
38+
T.copy(C_local, C[by * block_M, bx * block_N])
39+
40+
return gemm_kernel
41+
42+
43+
def _tflops(M, N, K, seconds):
44+
return 2.0 * M * N * K / seconds / 1e12
45+
46+
47+
def _bench(fn, warmup, repeats):
48+
for _ in range(warmup):
49+
fn()
50+
torch.mps.synchronize()
51+
t0 = time.perf_counter()
52+
for _ in range(repeats):
53+
fn()
54+
torch.mps.synchronize()
55+
return (time.perf_counter() - t0) / repeats
56+
57+
58+
def bench_torch_mps(M, N, K, warmup, repeats):
59+
a = torch.randn(M, K, dtype=torch.float16, device="mps")
60+
b = torch.randn(K, N, dtype=torch.float16, device="mps")
61+
avg_s = _bench(lambda: torch.mm(a, b), warmup, repeats)
62+
return _tflops(M, N, K, avg_s)
63+
64+
65+
def bench_tilelang(M, N, K, block_M, block_N, block_K, warmup, repeats):
66+
kernel = matmul_simdgroup(M, N, K, block_M, block_N, block_K)
67+
a = torch.randn(M, K, dtype=torch.float16, device="mps")
68+
b = torch.randn(K, N, dtype=torch.float16, device="mps")
69+
c = torch.zeros(M, N, dtype=torch.float32, device="mps")
70+
avg_s = _bench(lambda: kernel(a, b, c), warmup, repeats)
71+
return _tflops(M, N, K, avg_s)
72+
73+
74+
if __name__ == "__main__":
75+
parser = argparse.ArgumentParser(description="Metal GEMM Benchmark (simdgroup)")
76+
parser.add_argument("--m", type=int, default=4096)
77+
parser.add_argument("--n", type=int, default=4096)
78+
parser.add_argument("--k", type=int, default=4096)
79+
parser.add_argument("--warmup", type=int, default=10)
80+
parser.add_argument("--repeats", type=int, default=100)
81+
parser.add_argument("--sweep", action="store_true", help="Sweep all block configs instead of using default (64,64,32)")
82+
args = parser.parse_args()
83+
84+
M, N, K = args.m, args.n, args.k
85+
86+
print(f"torch: {torch.__version__}")
87+
print(f"tilelang: {tilelang.__version__}")
88+
print(f"MPS: {torch.backends.mps.is_available()}")
89+
print(f"M={M}, N={N}, K={K}, warmup={args.warmup}, repeats={args.repeats}")
90+
print()
91+
92+
ref_tflops = bench_torch_mps(M, N, K, args.warmup, args.repeats)
93+
print(f"PyTorch MPS (torch.mm fp16): {ref_tflops:.1f} TFLOPS")
94+
print()
95+
96+
configs = BLOCK_CONFIGS if args.sweep else [(64, 64, 32)]
97+
98+
print(f"{'block (M,N,K)':>16s} | {'TileLang':>14s} | {'Ratio':>6s}")
99+
print("-" * 44)
100+
101+
best_tflops = 0.0
102+
best_config = configs[0]
103+
for bM, bN, bK in configs:
104+
try:
105+
tl = bench_tilelang(M, N, K, bM, bN, bK, args.warmup, args.repeats)
106+
ratio = tl / ref_tflops * 100
107+
tag = ""
108+
if tl > best_tflops:
109+
best_tflops = tl
110+
best_config = (bM, bN, bK)
111+
print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%")
112+
except Exception as e:
113+
print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}")
114+
115+
if args.sweep:
116+
print()
117+
print(f"Best config: {best_config}")
118+
print(f"Best TFlops: {best_tflops:.1f}")
119+
print(f"Reference TFlops (PyTorch MPS): {ref_tflops:.1f}")

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies = [
3131
# requirement as wide as possible to be compatible with other libraries
3232
# pip will try to use latest version whenever possible.
3333
"apache-tvm-ffi~=0.1.0,>=0.1.2",
34+
"apache-tvm-ffi<0.1.8; platform_system == 'Darwin'",
3435
# torch-c-dlpack-ext provides prebuilt torch extensions.
3536
# Without it, TVM FFI may require JIT compilation on first import.
3637
"torch-c-dlpack-ext; python_version < '3.14'",

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Requirements to run local build with `--no-build-isolation` or other developments
22

33
apache-tvm-ffi~=0.1.0,>=0.1.2
4+
apache-tvm-ffi<0.1.8; platform_system == 'Darwin'
45
build
56
cmake>=3.26
67
cython>=3.1.0

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Runtime requirements
22

33
apache-tvm-ffi~=0.1.0,>=0.1.2
4+
apache-tvm-ffi<0.1.8; platform_system == 'Darwin'
45
torch-c-dlpack-ext; python_version < '3.14'
56
cloudpickle
67
ml-dtypes

src/op/copy.cc

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,10 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
517517
return result_map;
518518
}
519519

520+
if (copy_inst == CopyInst::kMetalSIMDGroup) {
521+
return {};
522+
}
523+
520524
// for LDSM/STSM, the layout was deduced from register layout
521525
// so we can directly apply the layout of normal copy
522526
// Use parallel op to infer the layout
@@ -792,11 +796,16 @@ bool CopyNode::CheckCPAsyncCopy(Target target, const LayoutMap &layout_map,
792796
if (!CheckCPAsyncCopyPreconditions()) {
793797
return false;
794798
}
795-
// Skip vectorize size check here because, during the Infer Layout stage,
796-
// the layout is not stable and the vectorized size cannot be determined.
797799
return true;
798800
}
799801

802+
bool CopyNode::CheckSIMDGroupCopy(Target target) const {
803+
if (TargetIsMetal(target) && IsSIMDGroupBuffer(src)) {
804+
return IsSharedBuffer(dst) || IsGlobalBuffer(dst);
805+
}
806+
return false;
807+
}
808+
800809
// Selects the most specific copy instruction for the given target and buffers.
801810
// Priority: BulkLoad1D, BulkStore1D, BulkLoad, BulkStore, LDSM, STSM,
802811
// TMemLoad, TMemStore, CPAsync, Normal.
@@ -864,6 +873,8 @@ CopyInst CopyNode::GetCopyInst(Target target, const LayoutMap &layout_map,
864873
return CopyInst::kTMemLoad;
865874
} else if (CheckTMemStore(target)) {
866875
return CopyInst::kTMemStore;
876+
} else if (CheckSIMDGroupCopy(target)) {
877+
return CopyInst::kMetalSIMDGroup;
867878
} else {
868879
return CopyInst::kNormal;
869880
}
@@ -897,6 +908,8 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
897908
auto cp_async_copy = LowerCPAsyncCopy(T, analyzer);
898909
ICHECK(cp_async_copy.defined()) << "Failed to lower cp.async copy";
899910
return cp_async_copy;
911+
} else if (copy_inst == CopyInst::kMetalSIMDGroup) {
912+
return LowerSIMDGroupCopy(T, analyzer);
900913
} else if (copy_inst == CopyInst::kNormal) {
901914
return LowerNormalCopy(T, analyzer);
902915
} else {
@@ -982,7 +995,88 @@ Stmt CopyNode::LowerCPAsyncCopy(const LowerArgs &T,
982995
return cp_async_loop;
983996
}
984997

985-
// Lowers the copy using standard load/store with loop transformations.
998+
Stmt CopyNode::LowerSIMDGroupCopy(const LowerArgs &T,
999+
arith::Analyzer *analyzer) const {
1000+
ICHECK(IsSIMDGroupBuffer(src));
1001+
int total_elements = 1;
1002+
for (auto s : src->shape) {
1003+
auto imm = s.as<IntImmNode>();
1004+
ICHECK(imm) << "simdgroup buffer must have constant shape";
1005+
total_elements *= imm->value;
1006+
}
1007+
ICHECK(total_elements % 64 == 0)
1008+
<< "simdgroup buffer size must be multiple of 64 (8x8), got "
1009+
<< total_elements;
1010+
1011+
ICHECK(dst_range.size() == 2)
1012+
<< "Expected 2D destination for simdgroup store";
1013+
PrimExpr dst_row_base = dst_range[0]->min;
1014+
PrimExpr dst_col_base = dst_range[1]->min;
1015+
PrimExpr dst_stride = dst->shape[dst->shape.size() - 1];
1016+
1017+
int warp_size = TargetGetWarpSize(T.target);
1018+
int block_size = T.thread_bounds->extent.as<IntImmNode>()->value;
1019+
int num_warps = block_size / warp_size;
1020+
PrimExpr warp_id = FloorDiv(T.thread_var, warp_size);
1021+
1022+
int M = src_range[0]->extent.as<IntImmNode>()->value;
1023+
int N = src_range[1]->extent.as<IntImmNode>()->value;
1024+
1025+
int kMPerWarp = 8;
1026+
int kNPerWarp = 8;
1027+
int m_warp = 1, n_warp = num_warps;
1028+
int max_m = M / kMPerWarp;
1029+
int max_n = N / kNPerWarp;
1030+
float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
1031+
float best_score = std::numeric_limits<float>::max();
1032+
for (int m = 1; m <= std::min(num_warps, max_m); ++m) {
1033+
if (num_warps % m != 0)
1034+
continue;
1035+
int n = num_warps / m;
1036+
if (n > max_n)
1037+
continue;
1038+
float m_per = static_cast<float>(M) / (m * kMPerWarp);
1039+
float n_per = static_cast<float>(N) / (n * kNPerWarp);
1040+
float score = std::abs(m_per / n_per - ideal);
1041+
if (score < best_score) {
1042+
best_score = score;
1043+
m_warp = m;
1044+
n_warp = n;
1045+
}
1046+
}
1047+
1048+
ICHECK(M >= m_warp * 8 && N >= n_warp * 8)
1049+
<< "Cannot partition " << M << "x" << N << " matrix across " << m_warp
1050+
<< "x" << n_warp << " warps with 8x8 simdgroup tiles";
1051+
int warp_row_tiles = M / m_warp / 8;
1052+
int warp_col_tiles = N / n_warp / 8;
1053+
ICHECK(warp_row_tiles > 0 && warp_col_tiles > 0);
1054+
ICHECK(warp_row_tiles * warp_col_tiles * 64 <= total_elements)
1055+
<< "Warp partition produces more tiles than buffer capacity";
1056+
1057+
PrimExpr warp_m = FloorMod(warp_id, m_warp);
1058+
PrimExpr warp_n = FloorDiv(warp_id, m_warp);
1059+
1060+
Array<Stmt> stmts;
1061+
for (int i = 0; i < warp_row_tiles; i++) {
1062+
for (int j = 0; j < warp_col_tiles; j++) {
1063+
int tile_idx = i * warp_col_tiles + j;
1064+
PrimExpr row = dst_row_base + warp_m * (warp_row_tiles * 8) + i * 8;
1065+
PrimExpr col = dst_col_base + warp_n * (warp_col_tiles * 8) + j * 8;
1066+
PrimExpr ptr = Call(DataType::Handle(), builtin::address_of(),
1067+
{BufferLoad(dst, {row, col})});
1068+
stmts.push_back(Evaluate(
1069+
Call(DataType::Handle(), builtin::simdgroup_store(),
1070+
{src->data, IntImm(DataType::Int(32), tile_idx), ptr, dst_stride,
1071+
IntImm(DataType::Int(32), 8), IntImm(DataType::Int(32), 8),
1072+
Cast(DataType::Bool(), IntImm(DataType::Int(32), 0))})));
1073+
}
1074+
}
1075+
if (stmts.size() == 1)
1076+
return stmts[0];
1077+
return SeqStmt(stmts);
1078+
}
1079+
9861080
Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
9871081
arith::Analyzer *analyzer) const {
9881082
bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU;

src/op/copy.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ enum class CopyInst : uint8_t {
2424
kCPAsync = 5, // cp.async global->shared copy
2525
// we should separate the bulk load and store for 1d and multi-dim
2626
// as they have different memory access patterns
27-
kBulkLoad1D = 6, // utilize tma load 1d
28-
kBulkStore1D = 7, // utilize tma store 1d
29-
kTMemLoad = 8, // tcgen05.ld (tensor memory -> register)
30-
kTMemStore = 9, // tcgen05.st (register -> tensor memory)
27+
kBulkLoad1D = 6, // utilize tma load 1d
28+
kBulkStore1D = 7, // utilize tma store 1d
29+
kTMemLoad = 8, // tcgen05.ld (tensor memory -> register)
30+
kTMemStore = 9, // tcgen05.st (register -> tensor memory)
31+
kMetalSIMDGroup = 10, // Metal simdgroup load/store
3132
};
3233

3334
/// Convert CopyInst enum to string for debugging
@@ -53,6 +54,8 @@ inline const char *CopyInstToString(CopyInst inst) {
5354
return "TMemLoad";
5455
case CopyInst::kTMemStore:
5556
return "TMemStore";
57+
case CopyInst::kMetalSIMDGroup:
58+
return "MetalSIMDGroup";
5659
default:
5760
return "Unknown";
5861
}
@@ -290,6 +293,11 @@ class CopyNode : public TileOperatorNode {
290293
arith::Analyzer *analyzer) const;
291294

292295
protected:
296+
/*!
297+
* \brief Check if copy from Metal simdgroup to shared/global is supported.
298+
*/
299+
bool CheckSIMDGroupCopy(Target target) const;
300+
293301
/*!
294302
* \brief Get the copy instruction type.
295303
*/
@@ -331,6 +339,11 @@ class CopyNode : public TileOperatorNode {
331339
*/
332340
Stmt LowerCPAsyncCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
333341

342+
/*!
343+
* \brief Generate lowering for simdgroup store.
344+
*/
345+
Stmt LowerSIMDGroupCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
346+
334347
/*!
335348
* \brief Generate SIMT (thread-level) loop for copying.
336349
*/

src/op/fill.cc

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,30 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
156156
* @return Stmt The lowered TIR statement implementing the fill.
157157
*/
158158
Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
159-
if (IsFragmentBuffer(dst)) {
159+
if (IsSIMDGroupBuffer(dst)) {
160+
int region_elements = 1;
161+
for (auto r : region) {
162+
auto imm = r->extent.as<IntImmNode>();
163+
ICHECK(imm) << "simdgroup fill region must have constant extents";
164+
region_elements *= imm->value;
165+
}
166+
int total_elements = region_elements;
167+
ICHECK(total_elements % 64 == 0)
168+
<< "simdgroup buffer size must be multiple of 64 (8x8), got "
169+
<< total_elements;
170+
int num_matrices = total_elements / 64;
171+
PrimExpr fill_value = Cast(dst->dtype, value);
172+
Array<Stmt> stmts;
173+
for (int i = 0; i < num_matrices; i++) {
174+
stmts.push_back(Evaluate(
175+
Call(DataType::Handle(), builtin::make_filled_simdgroup_matrix(),
176+
{dst->data, IntImm(DataType::Int(32), i), fill_value,
177+
IntImm(DataType::Int(32), 8), IntImm(DataType::Int(32), 8)})));
178+
}
179+
if (stmts.size() == 1)
180+
return stmts[0];
181+
return SeqStmt(stmts);
182+
} else if (IsFragmentBuffer(dst)) {
160183
auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
161184
par_op->InferLayout({T.target,
162185
T.thread_bounds,

src/op/gemm.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ GemmInst GemmNode::getGemmInst(int block_size, Target target) const {
183183
return GemmInst::kMMA;
184184
} else if (TargetIsCPU(target)) {
185185
return GemmInst::kScalar;
186+
} else if (TargetIsMetal(target)) {
187+
return GemmInst::kMetalSimdgroup;
186188
} else {
187189
ICHECK(0) << "Unsupported target for gemm: " << target->str();
188190
return GemmInst::kMMA;
@@ -199,8 +201,11 @@ std::pair<int, int> GemmWarpPolicyNode::computeWarpPartition(
199201
}
200202

201203
int m_warp = 1, n_warp = 1;
202-
constexpr int kMPerWarp = 16; // Rows processed by a single warp
203-
int kNPerWarp = 8; // Columns processed by a single warp
204+
int kMPerWarp = 16; // Rows processed by a single warp
205+
if (TargetIsMetal(target)) {
206+
kMPerWarp = 8;
207+
}
208+
int kNPerWarp = 8; // Columns processed by a single warp
204209
if (TargetIsVolta(target)) {
205210
kNPerWarp = 16;
206211
} else if (TargetIsCDNA(target)) {

0 commit comments

Comments
 (0)