Skip to content

Commit 8a1b3bb

Browse files
committed
feat: implement TileLang Path C topk selector with shape-specialized kernels
1 parent 8fb3e32 commit 8a1b3bb

16 files changed

Lines changed: 2049 additions & 111 deletions

bench/tilelang_ports/sparse_mla_blockscaled.json

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,19 @@
55
"shape": {
66
"q_shape": [
77
1,
8-
64,
98
4,
9+
2,
1010
64
1111
],
1212
"kv_shape": [
1313
1,
14-
64,
14+
4,
1515
1,
1616
64
1717
],
1818
"indices_shape": [
1919
1,
20-
64,
20+
4,
2121
1,
2222
16
2323
],
@@ -31,52 +31,89 @@
3131
"codegen_blocker_reason": "sparse_mla_blockscaled direct-MSL kernel built via mx.fast.metal_kernel is available; block-scaled MXFP8 dequant happens inline inside MSL on uint8 e4m3 storage with E8M0 block-scales. Apple MSL 4.0 has no native float8 simdgroup matrix, so the matmuls run as plain register fma.",
3232
"block_size": 32
3333
},
34+
"path_c_tilelang_e8m0_qk_status": {
35+
"available": false,
36+
"reason": "TileLang Path C E8M0 Sparse-MLA QK is not safe to dispatch: no simdgroup_multiply_accumulate; scale operands disappeared from emitted MSL; E8M0 exp2(byte - 127) decode markers missing; scale operands are not indexed by K/32; scalar fallback markers present; Sparse-MLA M=1/topk tile violates current Metal FP8 simdgroup tile constraints",
37+
"target": "metal",
38+
"m": 1,
39+
"n": 16,
40+
"k": 64,
41+
"transpose_B": true,
42+
"scale_block_size": 32,
43+
"scale_layout": "logical_unswizzled_k_axis_blocks",
44+
"features": {
45+
"kernel_void": 1,
46+
"simdgroup_multiply_accumulate": 0,
47+
"simdgroup_load": 0,
48+
"simdgroup_store": 0,
49+
"fp8_e4m3_decode_helper": 3,
50+
"A_scale_refs": 0,
51+
"B_scale_refs": 0,
52+
"signature_has_A_scale": false,
53+
"signature_has_B_scale": false,
54+
"e8m0_exp2": 0,
55+
"e8m0_bias_subtract_127": 0,
56+
"e8m0_sentinel_255": 0,
57+
"e8m0_zero_sentinel": 1,
58+
"k_block_shift_5": 0,
59+
"k_block_div_32": 0,
60+
"A_scale_collapsed_zero": 0,
61+
"B_scale_collapsed_zero": 0,
62+
"float_a_val": true,
63+
"float_b_val": true,
64+
"threadgroup_half": false,
65+
"scale_format": "e8m0_block_k32",
66+
"scale_block_size": 32,
67+
"scale_axis": "contracted_k",
68+
"scale_layout": "logical_unswizzled_k_axis_blocks"
69+
}
70+
},
3471
"parity": {
3572
"blockscaled_vs_bf16": {
36-
"max_abs_err": 0.011328823864459991,
37-
"max_rel_err": 0.11758861583070061
73+
"max_abs_err": 0.004521891474723816,
74+
"max_rel_err": 0.06304103596741027
3875
},
3976
"quantized_matmul_vs_bf16": {
4077
"max_abs_err": 0.0,
4178
"max_rel_err": 0.0
4279
},
4380
"msl_blockscaled_vs_bf16": {
44-
"max_abs_err": 0.011317778378725052,
45-
"max_rel_err": 0.1174739681502098
81+
"max_abs_err": 0.0045216078869998455,
82+
"max_rel_err": 0.06303708238647923
4683
},
4784
"msl_blockscaled_vs_bs_ref": {
48-
"max_abs_err": 3.0465424060821533e-05,
49-
"max_rel_err": 0.0003377691831805473
85+
"max_abs_err": 2.81408429145813e-05,
86+
"max_rel_err": 0.00039434889228588904
5087
}
5188
},
5289
"bench": {
5390
"bf16_reference": {
5491
"label": "bf16_reference",
55-
"median_ms": 0.9395829401910305,
56-
"min_ms": 0.665042083710432,
57-
"max_ms": 0.9992080740630627,
58-
"iters": 8
92+
"median_ms": 0.608749920502305,
93+
"min_ms": 0.5949169863015413,
94+
"max_ms": 0.6114158313721418,
95+
"iters": 3
5996
},
6097
"blockscaled_reference": {
6198
"label": "blockscaled_reference",
62-
"median_ms": 0.645665917545557,
63-
"min_ms": 0.3741669934242964,
64-
"max_ms": 1.2330419849604368,
65-
"iters": 8
99+
"median_ms": 0.5933749489486217,
100+
"min_ms": 0.5850829184055328,
101+
"max_ms": 0.6149171385914087,
102+
"iters": 3
66103
},
67104
"quantized_matmul_reference": {
68105
"label": "quantized_matmul_reference",
69-
"median_ms": 0.3338339738547802,
70-
"min_ms": 0.2703331410884857,
71-
"max_ms": 0.5257499869912863,
72-
"iters": 8
106+
"median_ms": 0.5981249269098043,
107+
"min_ms": 0.5954578518867493,
108+
"max_ms": 0.7106249686330557,
109+
"iters": 3
73110
},
74111
"path_b_msl_blockscaled_fwd": {
75112
"label": "path_b_msl_blockscaled_fwd",
76-
"median_ms": 0.9964159689843655,
77-
"min_ms": 0.5499999970197678,
78-
"max_ms": 1.009250059723854,
79-
"iters": 8
113+
"median_ms": 0.18091709353029728,
114+
"min_ms": 0.1712499652057886,
115+
"max_ms": 0.20291702821850777,
116+
"iters": 3
80117
}
81118
}
82119
}

bench/tilelang_ports/sparse_mla_fp8.json

Lines changed: 88 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,19 @@
55
"shape": {
66
"q_shape": [
77
1,
8-
64,
98
4,
9+
2,
1010
64
1111
],
1212
"kv_shape": [
1313
1,
14-
64,
14+
4,
1515
1,
1616
64
1717
],
1818
"indices_shape": [
1919
1,
20-
64,
20+
4,
2121
1,
2222
16
2323
],
@@ -31,52 +31,115 @@
3131
"codegen_blocker_reason": "sparse_mla_fp8 direct-MSL kernel built via mx.fast.metal_kernel is available; FP8 e4m3 dequant happens inline inside MSL on uint8 storage. Apple MSL 4.0 has no native float8 simdgroup matrix, so the matmuls run as plain register fma ops (still ~2x faster than the BF16-fallback reference because the dequant fuses with the QK loop).",
3232
"fp8_dtype": "float8_e4m3"
3333
},
34+
"path_c_tilelang_qk_status": {
35+
"available": false,
36+
"reason": "TileLang Path C FP8 Sparse-MLA QK is not safe to dispatch: no simdgroup_multiply_accumulate; scale operands disappeared from emitted MSL; scalar fallback markers present; Sparse-MLA M=1/topk tile violates current Metal FP8 simdgroup tile constraints",
37+
"target": "metal",
38+
"m": 1,
39+
"n": 16,
40+
"k": 64,
41+
"transpose_B": true,
42+
"features": {
43+
"kernel_void": 1,
44+
"simdgroup_multiply_accumulate": 0,
45+
"simdgroup_load": 0,
46+
"simdgroup_store": 0,
47+
"fp8_e4m3_decode_helper": 3,
48+
"A_scale_refs": 0,
49+
"B_scale_refs": 0,
50+
"signature_has_A_scale": false,
51+
"signature_has_B_scale": false,
52+
"float_a_val": true,
53+
"float_b_val": true,
54+
"threadgroup_half": false
55+
}
56+
},
57+
"path_c_tilelang_qk_reduce_status": {
58+
"available": true,
59+
"reason": "TileLang Path C FP8 Sparse-MLA real QK reducer is dispatchable for M=1/topk with per-row B scales",
60+
"target": "metal",
61+
"n": 16,
62+
"k": 64,
63+
"outputs_per_block": 4,
64+
"reduce_threads": 32,
65+
"vec": 4,
66+
"features": {
67+
"kernel_void": 1,
68+
"fp8_e4m3_decode_helper": 3,
69+
"scalar_fp8_byte_decode": 3,
70+
"scalar_fp8_byte_decode_calls": 2,
71+
"tvm_thread_allreduce": 0,
72+
"simd_sum": 0,
73+
"simd_shuffle_down": 5,
74+
"A_scale_refs": 1,
75+
"B_scale_refs": 1,
76+
"signature_has_A_scale": true,
77+
"signature_has_B_scale": true,
78+
"per_row_B_scale": true,
79+
"reinterpret_cast": 0,
80+
"device_const_uint": 0,
81+
"uchar4": 0,
82+
"threadgroup_half": false,
83+
"qk_shape": "m1_n_topk_k"
84+
}
85+
},
3486
"parity": {
3587
"fp8_vs_bf16": {
36-
"max_abs_err": 0.0030034519731998444,
37-
"max_rel_err": 0.031174618342377305
88+
"max_abs_err": 0.0034084729850292206,
89+
"max_rel_err": 0.047518537152928295
3890
},
3991
"quantized_matmul_vs_bf16": {
4092
"max_abs_err": 0.0,
4193
"max_rel_err": 0.0
4294
},
4395
"msl_fp8_vs_bf16": {
44-
"max_abs_err": 0.0030084550380706787,
45-
"max_rel_err": 0.0312265481349234
96+
"max_abs_err": 0.003423169255256653,
97+
"max_rel_err": 0.04772342223368997
4698
},
4799
"msl_fp8_vs_fp8_ref": {
48-
"max_abs_err": 3.0413269996643066e-05,
49-
"max_rel_err": 0.0003079714788127215
100+
"max_abs_err": 3.0234456062316895e-05,
101+
"max_rel_err": 0.00042000606965991613
102+
},
103+
"path_c_qk_reduce_vs_oracle": {
104+
"max_abs_err": 0.0,
105+
"max_rel_err": 0.0
50106
}
51107
},
52108
"bench": {
53109
"bf16_reference": {
54110
"label": "bf16_reference",
55-
"median_ms": 0.9395829401910305,
56-
"min_ms": 0.665042083710432,
57-
"max_ms": 0.9992080740630627,
58-
"iters": 8
111+
"median_ms": 0.608749920502305,
112+
"min_ms": 0.5949169863015413,
113+
"max_ms": 0.6114158313721418,
114+
"iters": 3
59115
},
60116
"fp8_reference": {
61117
"label": "fp8_reference",
62-
"median_ms": 0.4516250919550657,
63-
"min_ms": 0.33437483943998814,
64-
"max_ms": 0.5283341743052006,
65-
"iters": 8
118+
"median_ms": 0.6220829673111439,
119+
"min_ms": 0.6120421458035707,
120+
"max_ms": 0.6377911195158958,
121+
"iters": 3
66122
},
67123
"quantized_matmul_reference": {
68124
"label": "quantized_matmul_reference",
69-
"median_ms": 0.3338339738547802,
70-
"min_ms": 0.2703331410884857,
71-
"max_ms": 0.5257499869912863,
72-
"iters": 8
125+
"median_ms": 0.5981249269098043,
126+
"min_ms": 0.5954578518867493,
127+
"max_ms": 0.7106249686330557,
128+
"iters": 3
73129
},
74130
"path_b_msl_fp8_fwd": {
75131
"label": "path_b_msl_fp8_fwd",
76-
"median_ms": 0.9899579454213381,
77-
"min_ms": 0.21674996241927147,
78-
"max_ms": 2.00541689991951,
79-
"iters": 8
132+
"median_ms": 0.5564999300986528,
133+
"min_ms": 0.22545899264514446,
134+
"max_ms": 0.6050418596714735,
135+
"iters": 3
136+
},
137+
"path_c_tilelang_fp8_qk_reduce": {
138+
"label": "path_c_tilelang_fp8_qk_reduce",
139+
"median_ms": 0.20920787937939167,
140+
"min_ms": 0.1621670089662075,
141+
"max_ms": 0.5389999132603407,
142+
"iters": 3
80143
}
81144
}
82145
}

cppmega_mlx/nn/_tilelang/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,16 @@
104104
sparse_mla_blockscaled_metal_status,
105105
sparse_mla_blockscaled_reference,
106106
)
107+
from cppmega_mlx.nn._tilelang.sparse_mla_blockscaled_path_c import (
108+
E8M0_BLOCK_SIZE,
109+
E8M0_LAYOUT,
110+
E8M0_SCALE_FORMAT,
111+
SparseMLABlockScaledPathCStatus,
112+
blockscaled_sparse_mla_qk_msl_features,
113+
blockscaled_sparse_mla_qk_path_c_status,
114+
lower_blockscaled_sparse_mla_qk_msl,
115+
make_blockscaled_sparse_mla_qk_kernel,
116+
)
107117
from cppmega_mlx.nn._tilelang.sparse_mla_fp8 import (
108118
SparseMLAFp8MetalStatus,
109119
sparse_mla_fp8_apply,
@@ -123,11 +133,15 @@
123133
__all__ = [
124134
"FP8MSLKernelStatus",
125135
"FP8VecmatPathCStatus",
136+
"E8M0_BLOCK_SIZE",
137+
"E8M0_LAYOUT",
138+
"E8M0_SCALE_FORMAT",
126139
"M2RNNMetalStatus",
127140
"Mamba3MetalStatus",
128141
"MXFP8_BLOCK_SIZE",
129142
"PathBStatus",
130143
"SparseMLABlockScaledMetalStatus",
144+
"SparseMLABlockScaledPathCStatus",
131145
"SparseMLAFp8MetalStatus",
132146
"SparseMLAMetalStatus",
133147
"SparseMLAPathCStatus",
@@ -139,6 +153,8 @@
139153
"build_mlx_body",
140154
"bwd_dadt_fused",
141155
"bwd_dtrap_ddt",
156+
"blockscaled_sparse_mla_qk_msl_features",
157+
"blockscaled_sparse_mla_qk_path_c_status",
142158
"compute_dacs_segsum",
143159
"fp8_msl_kernels",
144160
"fp8_msl_status",
@@ -150,6 +166,8 @@
150166
"fp8_vecmat_path_c_status",
151167
"half_to_fp8",
152168
"lower_fp8_vecmat_msl",
169+
"lower_blockscaled_sparse_mla_qk_msl",
170+
"make_blockscaled_sparse_mla_qk_kernel",
153171
"make_fp8_vecmat_reduce_kernel",
154172
"m2rnn",
155173
"m2rnn_apply",

0 commit comments

Comments
 (0)