Skip to content

Commit 5a66c43

Browse files
committed
[Metal] Add Metal GEMM support with simdgroup_matrix MMA
Add T.gemm support for Apple Metal using simdgroup_matrix 8×8 operations (simdgroup_load/store/multiply_accumulate). Works on all Apple Silicon (M1-M5) without requiring a TVM fork. Key changes: - codegen_metal.cc/h: Fork TVM Metal codegen to tilelang with simdgroup intrinsic emission and 128-bit vectorized copy - gemm_metal.py: GemmMetal tile operator for shared×shared GEMM - metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros - metal_fragment_to_simdgroup.py: Pass rewrites local.fragment GEMM accumulators to metal.simdgroup scope before layout inference - LowerSIMDGroupCopy in copy.cc for fragment→device simdgroup_store 24 Metal tests (codegen cross-platform + correctness on device).
1 parent f309d81 commit 5a66c43

178 files changed

Lines changed: 10716 additions & 2684 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
---
2+
name: tilelang-build
3+
description: Repository-specific build, rebuild, install, and test instructions for tilelang. Use when working in the tilelang repository and the correct commands are needed for building from source, reinstalling after changes, or running project tests.
4+
---
5+
16
# Build & Install
27

38
## Installing / Rebuilding tilelang

.github/workflows/ci.yml

Lines changed: 128 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,8 @@ jobs:
309309
uv run --no-project -m --
310310
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
311311
)
312-
"${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
312+
"${PYTEST[@]}" --maxfail=3 --numprocesses=8 \
313+
--ignore=../examples/grouped_gemm/test_example_grouped_gemm.py \
313314
../examples
314315
315316
# NVIDIA CUDA tests
@@ -322,7 +323,7 @@ jobs:
322323
uv run --no-project -m --
323324
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
324325
)
325-
"${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
326+
"${PYTEST[@]}" --maxfail=3 --numprocesses=8 \
326327
./python
327328
328329
# AMD ROCm tests
@@ -336,7 +337,7 @@ jobs:
336337
uv run --no-project -m --
337338
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
338339
)
339-
"${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
340+
"${PYTEST[@]}" --maxfail=3 --numprocesses=8 \
340341
--ignore=./python/runtime --ignore=./python/transform \
341342
./python
342343
@@ -350,14 +351,132 @@ jobs:
350351
uv run --no-project -m --
351352
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
352353
)
353-
"${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
354+
"${PYTEST[@]}" --maxfail=3 --numprocesses=8 \
354355
-k metal \
355356
./python
356357
357-
# CuTeDSL backend: run examples with TILELANG_TARGET=cutedsl
358-
# Placed after core test steps so a CuTeDSL failure doesn't skip them.
359-
- name: Run CuTeDSL examples with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
360-
if: ${{ !cancelled() && contains(matrix.runner.toolkit, 'CUDA') }}
358+
- name: List generated files
359+
if: ${{ !cancelled() }}
360+
run: |
361+
find . -type f -name '*.py[co]' -delete
362+
find . -depth -type d -name "__pycache__" -exec rm -r "{}" +
363+
if git status --ignored --porcelain | grep -qvE '/$'; then
364+
ls -alh $(git status --ignored --porcelain | grep -vE '/$' | grep -oE '\S+$')
365+
fi
366+
367+
cutedsl:
368+
name: CuTeDSL Examples for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
369+
if: |
370+
github.repository_owner == 'tile-ai' &&
371+
(github.event_name != 'pull_request' || !github.event.pull_request.draft)
372+
needs: [tests]
373+
runs-on: [self-hosted, nvidia]
374+
timeout-minutes: 120
375+
376+
steps:
377+
- name: Checkout repository
378+
uses: actions/checkout@v6
379+
with:
380+
fetch-depth: 0
381+
submodules: recursive
382+
383+
- name: Set environment (self-hosted runners)
384+
run: |
385+
# Hide sensitive data in logs for self-hosted runners
386+
if [[ -n "${{ secrets.SECRET_PATH_PREFIXES }}" ]]; then
387+
echo "::add-mask::${{ secrets.SECRET_PATH_PREFIXES }}"
388+
# Colon separated list of secrets to mask
389+
for secret in $(echo "${{ secrets.SECRET_PATH_PREFIXES }}" | tr ':' '\n'); do
390+
echo "::add-mask::${secret}"
391+
done
392+
fi
393+
394+
# Use runner tool_cache as cache root for self-hosted runners to avoid internet connection
395+
# issues and to share cache between jobs.
396+
export XDG_CACHE_HOME="${{ runner.tool_cache }}/.ci-cache-${{ github.workflow }}"
397+
echo "XDG_CACHE_HOME=${XDG_CACHE_HOME}" | tee -a "${GITHUB_ENV}"
398+
echo "PIP_CACHE_DIR=${XDG_CACHE_HOME}/pip" | tee -a "${GITHUB_ENV}"
399+
echo "UV_CACHE_DIR=${XDG_CACHE_HOME}/uv" | tee -a "${GITHUB_ENV}"
400+
echo "PRE_COMMIT_HOME=${XDG_CACHE_HOME}/pip/.pre-commit" | tee -a "${GITHUB_ENV}"
401+
402+
- name: Set environment (CUDA)
403+
run: |
404+
TOOLKIT="CUDA-12.8"
405+
CUDA_VERSION="${TOOLKIT##*-}"
406+
CUDA_VERSION_MAJMIN="$(echo ${CUDA_VERSION} | cut -d '.' -f-2)"
407+
CUDA_VERSION_MAJMIN_NODOT="${CUDA_VERSION_MAJMIN//./}"
408+
export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu${CUDA_VERSION_MAJMIN_NODOT}"
409+
export UV_INDEX="${PIP_EXTRA_INDEX_URL}"
410+
411+
echo "USE_CUDA=ON" | tee -a "${GITHUB_ENV}"
412+
echo "CUDA_VERSION=${CUDA_VERSION}" | tee -a "${GITHUB_ENV}"
413+
echo "CUDA_VERSION_MAJMIN=${CUDA_VERSION_MAJMIN}" | tee -a "${GITHUB_ENV}"
414+
echo "CUDA_VERSION_MAJMIN_NODOT=${CUDA_VERSION_MAJMIN_NODOT}" | tee -a "${GITHUB_ENV}"
415+
echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" | tee -a "${GITHUB_ENV}"
416+
echo "UV_INDEX=${UV_INDEX}" | tee -a "${GITHUB_ENV}"
417+
418+
if [[ ! -x "$(command -v nvcc)" ]]; then
419+
export PATH="/usr/local/cuda/bin:${PATH}"
420+
export LD_LIBRARY_PATH="/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"
421+
echo "PATH=${PATH}" | tee -a "${GITHUB_ENV}"
422+
echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" | tee -a "${GITHUB_ENV}"
423+
fi
424+
if [[ -x "$(command -v nvcc)" ]]; then
425+
echo "\$ $(command -v nvcc) --version" && nvcc --version
426+
else
427+
echo "::warning::nvcc not found in PATH!"
428+
fi
429+
430+
- name: Setup Python and uv with caching
431+
id: setup-uv
432+
uses: astral-sh/setup-uv@v7
433+
with:
434+
python-version: "3.12"
435+
activate-environment: true
436+
enable-cache: false
437+
prune-cache: false
438+
cache-local-path: ${{ env.UV_CACHE_DIR }}
439+
ignore-nothing-to-cache: true
440+
cache-suffix: uv-${{ runner.os }}-${{ runner.arch }}-3.12-self-hosted-nvidia-CUDA-12.8
441+
cache-dependency-glob: |
442+
pyproject.toml
443+
requirements*.txt
444+
.pre-commit-config.yaml
445+
446+
- name: Setup venv
447+
id: setup-venv
448+
run: |
449+
set -o pipefail
450+
451+
uv pip install --upgrade pip setuptools wheel
452+
uv pip install -v -r requirements-test.txt
453+
echo "import torch; print(f'torch: {torch.__version__}')" | uv run --no-project --script -
454+
uv pip install --no-build-isolation-package=flash-attn -v -r requirements-test-cuda.txt
455+
echo "import flash_attn; print(f'flash_attn: {flash_attn.__version__}')" | uv run --no-project --script -
456+
echo "::group::torch.utils.collect_env"
457+
uv run --no-project -m -- torch.utils.collect_env
458+
echo "::endgroup::"
459+
460+
- name: Clear uv cache for self-hosted runners (if setup failed)
461+
if: >-
462+
${{
463+
failure() &&
464+
(steps.setup-uv.conclusion == 'failure' || steps.setup-venv.conclusion == 'failure')
465+
}}
466+
run: |
467+
echo "Clearing uv cache at ${UV_CACHE_DIR} due to failure."
468+
uv cache clean
469+
470+
- name: Install project (wheel form)
471+
run: |
472+
uv pip install -v .
473+
474+
- name: Clean up stale /tmp files (self-hosted runners)
475+
run: |
476+
rm -f /tmp/tmp*.so /tmp/tmp*.cu /tmp/tmp*.cubin /tmp/tmp*.cpp
477+
rm -rf /tmp/tvm-debug-mode-tempdirs /tmp/tilelang_cutedsl_*
478+
479+
- name: Run CuTeDSL examples with Python 3.12 (CUDA-12.8)
361480
env:
362481
TILELANG_TARGET: cutedsl
363482
run: |
@@ -366,7 +485,7 @@ jobs:
366485
uv run --no-project -m --
367486
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
368487
)
369-
"${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
488+
"${PYTEST[@]}" --maxfail=3 --numprocesses=8 \
370489
../examples
371490
372491
- name: List generated files

3rdparty/tvm

Submodule tvm updated from 882a774 to 0e15b27

CMakeLists.txt

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,8 @@ file(GLOB TILE_LANG_SRCS
179179
src/op/*.cc
180180
src/target/utils.cc
181181
src/target/codegen_c_host.cc
182-
src/target/codegen_cpp.cc
183-
src/target/rt_mod_cpp.cc
182+
src/target/codegen_c.cc
183+
src/target/rt_mod_c.cc
184184
# intrin_rule doesn't have system dependency
185185
src/target/intrin_rule*.cc
186186
)
@@ -190,6 +190,16 @@ list(APPEND TILE_LANG_SRCS
190190
src/runtime/error_helpers.cc
191191
)
192192

193+
# Metal codegen is pure C++ (no Apple frameworks) and can generate Metal shader
194+
# source on any platform. Always compile it so that "target.build.tilelang_metal"
195+
# is available for cross-compilation on Linux/Windows.
196+
# The Metal *runtime* (execution on GPU) still requires macOS and is handled by
197+
# TVM's Metal.cmake which links the real runtime on Apple or a source-only
198+
# fallback (build_metal_off.cc) elsewhere.
199+
list(APPEND TILE_LANG_SRCS
200+
src/target/codegen_metal.cc
201+
)
202+
193203
# Track if the user explicitly selected a backend via cache options.
194204
set(TILELANG_BACKEND_USER_SELECTED OFF)
195205
foreach(BACKEND IN LISTS TILELANG_BACKENDS)
@@ -229,10 +239,6 @@ if(USE_METAL)
229239
message(STATUS "Metal backend on non-Apple: enabling codegen-only mode (no Metal runtime)")
230240
set(USE_METAL OFF)
231241
endif()
232-
file(GLOB TILE_LANG_METAL_SRCS
233-
src/target/rt_mod_metal.cc
234-
)
235-
list(APPEND TILE_LANG_SRCS ${TILE_LANG_METAL_SRCS})
236242
# FIXME: CIBW failed with backtrace, why???
237243
set(TVM_FFI_USE_LIBBACKTRACE OFF)
238244
elseif(USE_ROCM)
@@ -426,9 +432,30 @@ if(USE_Z3 AND USE_PYPI_Z3)
426432
find_package(Z3 REQUIRED)
427433
endif()
428434

435+
# Enable custom logging so we control the output format (e.g. strip build paths
436+
# from __FILE__ so wheel users don't see CI machine paths in warnings).
437+
set(USE_CUSTOM_LOGGING ON CACHE BOOL "Use custom logging implementation" FORCE)
438+
439+
# Detect release (wheel) builds: in CI (cibuildwheel) or scikit-build-core wheel builds,
440+
# we strip source paths from LOG(WARNING) etc. for a cleaner user experience.
441+
# Local dev builds keep full paths for debugging.
442+
if(DEFINED ENV{CIBUILDWHEEL} OR "$ENV{SKBUILD_STATE}" STREQUAL "wheel")
443+
set(TILELANG_RELEASE_BUILD_DEFAULT ON)
444+
else()
445+
set(TILELANG_RELEASE_BUILD_DEFAULT OFF)
446+
endif()
447+
option(TILELANG_RELEASE_BUILD "Strip source paths from log messages (for wheel releases)" ${TILELANG_RELEASE_BUILD_DEFAULT})
448+
429449
# Include tvm after configs have been populated
430450
add_subdirectory(${TVM_SOURCE} tvm EXCLUDE_FROM_ALL)
431451

452+
# Provide the custom LogMessageImpl / LogFatalImpl implementation to TVM,
453+
# since TVM_LOG_CUSTOMIZE=1 requires them to be supplied by the user.
454+
target_sources(tvm_objs PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/logging.cc")
455+
if(TILELANG_RELEASE_BUILD)
456+
target_compile_definitions(tvm_objs PRIVATE TILELANG_RELEASE_BUILD=1)
457+
endif()
458+
432459
# Resolve compile warnings in tvm
433460
add_compile_definitions(DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
434461

@@ -442,6 +469,10 @@ if(CMAKE_BUILD_TYPE STREQUAL "Debug")
442469
endif()
443470

444471
target_include_directories(tilelang_objs PRIVATE ${TILE_LANG_INCLUDES})
472+
target_compile_definitions(tilelang_objs PRIVATE TVM_LOG_CUSTOMIZE=1)
473+
if(TILELANG_RELEASE_BUILD)
474+
target_compile_definitions(tilelang_objs PRIVATE TILELANG_RELEASE_BUILD=1)
475+
endif()
445476

446477
add_library(tilelang SHARED $<TARGET_OBJECTS:tilelang_objs>)
447478
target_link_libraries(tilelang PUBLIC tvm)

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.1.8
1+
0.1.9
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import argparse
2+
import logging
3+
import time
4+
5+
import torch
6+
7+
import tilelang
8+
import tilelang.language as T
9+
10+
logging.getLogger("tilelang").setLevel(logging.WARNING)
11+
12+
BLOCK_CONFIGS = [
13+
(16, 16, 16),
14+
(32, 32, 16),
15+
(32, 32, 32),
16+
(64, 64, 32),
17+
]
18+
19+
20+
@tilelang.jit
21+
def matmul_simdgroup(M, N, K, block_M=64, block_N=64, block_K=32, dtype=T.float16, accum_dtype=T.float32):
22+
23+
@T.prim_func
24+
def gemm_kernel(
25+
A: T.Tensor((M, K), dtype),
26+
B: T.Tensor((K, N), dtype),
27+
C: T.Tensor((M, N), accum_dtype),
28+
):
29+
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
30+
A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared")
31+
B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared")
32+
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
33+
T.clear(C_local)
34+
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
35+
T.copy(A[by * block_M, ko * block_K], A_shared)
36+
T.copy(B[ko * block_K, bx * block_N], B_shared)
37+
T.gemm(A_shared, B_shared, C_local)
38+
T.copy(C_local, C[by * block_M, bx * block_N])
39+
40+
return gemm_kernel
41+
42+
43+
def _tflops(M, N, K, seconds):
44+
return 2.0 * M * N * K / seconds / 1e12
45+
46+
47+
def _bench(fn, warmup, repeats):
48+
for _ in range(warmup):
49+
fn()
50+
torch.mps.synchronize()
51+
t0 = time.perf_counter()
52+
for _ in range(repeats):
53+
fn()
54+
torch.mps.synchronize()
55+
return (time.perf_counter() - t0) / repeats
56+
57+
58+
def bench_torch_mps(M, N, K, warmup, repeats):
59+
a = torch.randn(M, K, dtype=torch.float16, device="mps")
60+
b = torch.randn(K, N, dtype=torch.float16, device="mps")
61+
avg_s = _bench(lambda: torch.mm(a, b), warmup, repeats)
62+
return _tflops(M, N, K, avg_s)
63+
64+
65+
def bench_tilelang(M, N, K, block_M, block_N, block_K, warmup, repeats):
66+
kernel = matmul_simdgroup(M, N, K, block_M, block_N, block_K)
67+
a = torch.randn(M, K, dtype=torch.float16, device="mps")
68+
b = torch.randn(K, N, dtype=torch.float16, device="mps")
69+
c = torch.zeros(M, N, dtype=torch.float32, device="mps")
70+
avg_s = _bench(lambda: kernel(a, b, c), warmup, repeats)
71+
return _tflops(M, N, K, avg_s)
72+
73+
74+
if __name__ == "__main__":
75+
parser = argparse.ArgumentParser(description="Metal GEMM Benchmark (simdgroup)")
76+
parser.add_argument("--m", type=int, default=4096)
77+
parser.add_argument("--n", type=int, default=4096)
78+
parser.add_argument("--k", type=int, default=4096)
79+
parser.add_argument("--warmup", type=int, default=10)
80+
parser.add_argument("--repeats", type=int, default=100)
81+
parser.add_argument("--sweep", action="store_true", help="Sweep all block configs instead of using default (64,64,32)")
82+
args = parser.parse_args()
83+
84+
M, N, K = args.m, args.n, args.k
85+
86+
print(f"torch: {torch.__version__}")
87+
print(f"tilelang: {tilelang.__version__}")
88+
print(f"MPS: {torch.backends.mps.is_available()}")
89+
print(f"M={M}, N={N}, K={K}, warmup={args.warmup}, repeats={args.repeats}")
90+
print()
91+
92+
ref_tflops = bench_torch_mps(M, N, K, args.warmup, args.repeats)
93+
print(f"PyTorch MPS (torch.mm fp16): {ref_tflops:.1f} TFLOPS")
94+
print()
95+
96+
configs = BLOCK_CONFIGS if args.sweep else [(64, 64, 32)]
97+
98+
print(f"{'block (M,N,K)':>16s} | {'TileLang':>14s} | {'Ratio':>6s}")
99+
print("-" * 44)
100+
101+
best_tflops = 0.0
102+
best_config = configs[0]
103+
for bM, bN, bK in configs:
104+
try:
105+
tl = bench_tilelang(M, N, K, bM, bN, bK, args.warmup, args.repeats)
106+
ratio = tl / ref_tflops * 100
107+
tag = ""
108+
if tl > best_tflops:
109+
best_tflops = tl
110+
best_config = (bM, bN, bK)
111+
print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%")
112+
except Exception as e:
113+
print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}")
114+
115+
if args.sweep:
116+
print()
117+
print(f"Best config: {best_config}")
118+
print(f"Best TFlops: {best_tflops:.1f}")
119+
print(f"Reference TFlops (PyTorch MPS): {ref_tflops:.1f}")

0 commit comments

Comments
 (0)