@@ -67,6 +67,44 @@ inline U qdot_4bit(
   return scale * accum + sum * bias;
 }
 
+// 4-bit load_vector_safe: same as load_vector_4bit but handles partial reads.
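+// Values are pre-scaled by 1/16, 1/256 and 1/4096 so that qdot can multiply by
+// the masked (unshifted) nibbles directly; the unused tail of x_thread
+// (indices N..values_per_thread-1) is zeroed.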
+template <typename T, typename U, int values_per_thread>
+inline U load_vector_safe_4bit(constant T* x, thread U* x_thread, int N) {
+  U sum = 0;
+  for (int i = 0; i < N; i += 4) {
+    sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
+    x_thread[i] = x[i];
+    x_thread[i + 1] = x[i + 1] / 16.0f;
+    x_thread[i + 2] = x[i + 2] / 256.0f;
+    x_thread[i + 3] = x[i + 3] / 4096.0f;
+  }
+  for (int i = N; i < values_per_thread; i++) {
+    x_thread[i] = 0;
+  }
+  return sum;
+}
+
+// 4-bit qdot_safe: handles partial K dimension.
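+// Each uint16 of w packs four 4-bit weights; they are masked but not shifted,
+// which is compensated by the pre-scaling in load_vector_safe_4bit. The final
+// scale * accum + sum * bias expands sum_i (scale * q_i + bias) * x_i.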
+template <typename U, int values_per_thread>
+inline U qdot_safe_4bit(
+    constant uint8_t* w,
+    const thread U* x_thread,
+    U scale,
+    U bias,
+    U sum,
+    int N) {
+  U accum = 0;
+  constant uint16_t* ws = (constant uint16_t*)w;
+  for (int i = 0; i < (N / 4); i++) {
+    accum +=
+        (x_thread[4 * i] * (ws[i] & 0x000f) +
+         x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
+         x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
+         x_thread[4 * i + 3] * (ws[i] & 0xf000));
+  }
+  return scale * accum + sum * bias;
+}
+
 // gather_qmv_fast: per-expert quantized GEMV for MoE.
 //
 // Same as qmv_fast but offsets w/scales/biases by expert_indices[tid.x]
@@ -179,6 +217,155 @@ inline U qdot_4bit(
 INSTANTIATE_GATHER_QMV_FAST(bfloat, 64);
 INSTANTIATE_GATHER_QMV_FAST(bfloat, 128);
 
+// gather_qmv_impl: generic-K fallback (handles any K, any N).
+// Same as qmv_impl in op_linear_4bit.mm but with expert index offset.
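+// Work decomposition: each threadgroup runs num_simdgroups simdgroups, each
+// producing results_per_simdgroup output rows; every lane consumes
+// values_per_thread consecutive K elements per iteration, so one pass of the
+// main loop covers block_size = values_per_thread * SIMD_SIZE inputs.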
+template <typename T, int group_size>
+[[kernel]] void gather_qmv_impl(
+    constant T* x [[buffer(0)]],
+    constant uchar* w [[buffer(1)]],
+    constant T* scales [[buffer(2)]],
+    constant T* biases [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    constant uint3 &sizes [[buffer(5)]],
+    constant uint32_t* expert_indices [[buffer(6)]],
+    constant uint3 &expert_strides [[buffer(7)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  const int in_vec_size = static_cast<int>(sizes.y); // K
+  const int out_vec_size = static_cast<int>(sizes.z); // N
+
+  constexpr int bits = 4;
+  constexpr int packs_per_thread = 2;
+  constexpr int num_simdgroups = 2;
+  constexpr int results_per_simdgroup = 4;
+  constexpr int pack_factor = 32 / bits; // 8
+  constexpr int bytes_per_pack = 4;
+  constexpr int values_per_thread = pack_factor * packs_per_thread; // 16
+  constexpr int block_size = values_per_thread * SIMD_SIZE;
+  constexpr int scale_step_per_thread = group_size / values_per_thread;
+
+  // Offset to this expert's weights
+  uint expert_idx = expert_indices[tid.x];
+  constant uint8_t* ws = (constant uint8_t*)w + expert_idx * expert_strides.x;
+  constant T* sc = scales + expert_idx * expert_strides.y;
+  constant T* bi = biases + expert_idx * expert_strides.z;
+
+  typedef float U;
+
+  thread U x_thread[values_per_thread];
+  thread U result[results_per_simdgroup] = {0};
+
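+  // in_vec_size_w: bytes of packed 4-bit weights per output row.
+  // in_vec_size_g: number of quantization groups (scale/bias pairs) per row.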
+  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
+  const int in_vec_size_g = (in_vec_size + group_size - 1) / group_size;
+  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
+      simd_gid * results_per_simdgroup;
+  const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
+
+  if (out_row >= out_vec_size) {
+    return;
+  }
+
+  // Small N path: fewer than 1 tile of output rows
+  if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
+    ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
+    sc += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+    bi += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+    x += tid.x * in_vec_size + simd_lid * values_per_thread;
+    y += tid.x * out_vec_size + out_row;
+
+    int k = 0;
+    for (; k < in_vec_size - block_size; k += block_size) {
+      U sum = load_vector_4bit<T, U, values_per_thread>(x, x_thread);
+      for (int row = 0; out_row + row < out_vec_size; row++) {
+        auto wl = (constant uint8_t*)(ws + row * in_vec_size_w);
+        constant T* sl = sc + row * in_vec_size_g;
+        constant T* bl = bi + row * in_vec_size_g;
+        result[row] += qdot_4bit<U, values_per_thread>(wl, x_thread, sl[0], bl[0], sum);
+      }
+      ws += block_size * bytes_per_pack / pack_factor;
+      sc += block_size / group_size;
+      bi += block_size / group_size;
+      x += block_size;
+    }
+    const int remaining = clamp(
+        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
+    if (remaining > 0) {
+      U sum = load_vector_safe_4bit<T, U, values_per_thread>(x, x_thread, remaining);
+      for (int row = 0; out_row + row < out_vec_size; row++) {
+        auto wl = (constant uint8_t*)(ws + row * in_vec_size_w);
+        constant T* sl = sc + row * in_vec_size_g;
+        constant T* bl = bi + row * in_vec_size_g;
+        result[row] += qdot_safe_4bit<U, values_per_thread>(wl, x_thread, sl[0], bl[0], sum, remaining);
+      }
+    }
+    for (int row = 0; out_row + row < out_vec_size; row++) {
+      result[row] = simd_sum(result[row]);
+      if (simd_lid == 0) { y[row] = static_cast<T>(result[row]); }
+    }
+  }
+  // Normal path: last tile may overlap with previous
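+  // used_out_row clamps the row base to out_vec_size - results_per_simdgroup,
+  // so the last tile recomputes a few already-covered rows instead of writing
+  // past the end of y.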
+  else {
+    ws += used_out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
+    sc += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+    bi += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+    x += tid.x * in_vec_size + simd_lid * values_per_thread;
+    y += tid.x * out_vec_size + used_out_row;
+
+    int k = 0;
+    for (; k < in_vec_size - block_size; k += block_size) {
+      U sum = load_vector_4bit<T, U, values_per_thread>(x, x_thread);
+      for (int row = 0; row < results_per_simdgroup; row++) {
+        auto wl = (constant uint8_t*)(ws + row * in_vec_size_w);
+        constant T* sl = sc + row * in_vec_size_g;
+        constant T* bl = bi + row * in_vec_size_g;
+        result[row] += qdot_4bit<U, values_per_thread>(wl, x_thread, sl[0], bl[0], sum);
+      }
+      ws += block_size * bytes_per_pack / pack_factor;
+      sc += block_size / group_size;
+      bi += block_size / group_size;
+      x += block_size;
+    }
+    const int remaining = clamp(
+        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
+    if (remaining > 0) {
+      U sum = load_vector_safe_4bit<T, U, values_per_thread>(x, x_thread, remaining);
+      for (int row = 0; row < results_per_simdgroup; row++) {
+        auto wl = (constant uint8_t*)(ws + row * in_vec_size_w);
+        constant T* sl = sc + row * in_vec_size_g;
+        constant T* bl = bi + row * in_vec_size_g;
+        result[row] += qdot_safe_4bit<U, values_per_thread>(wl, x_thread, sl[0], bl[0], sum, remaining);
+      }
+    }
+    for (int row = 0; row < results_per_simdgroup; row++) {
+      result[row] = simd_sum(result[row]);
+      if (simd_lid == 0) { y[row] = static_cast<T>(result[row]); }
+    }
+  }
+}
+
+#define INSTANTIATE_GATHER_QMV_IMPL(DTYPE, GSIZE) \
+  template [[host_name("gather_qmv_impl_4bit_" #GSIZE "_" #DTYPE)]] kernel void \
+  gather_qmv_impl<DTYPE, GSIZE>( \
+      constant DTYPE * x [[buffer(0)]], \
+      constant uchar * w [[buffer(1)]], \
+      constant DTYPE * scales [[buffer(2)]], \
+      constant DTYPE * biases [[buffer(3)]], \
+      device DTYPE * y [[buffer(4)]], \
+      constant uint3 & sizes [[buffer(5)]], \
+      constant uint32_t * expert_indices [[buffer(6)]], \
+      constant uint3 & expert_strides [[buffer(7)]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint simd_gid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]])
+
+INSTANTIATE_GATHER_QMV_IMPL(float, 32);
+INSTANTIATE_GATHER_QMV_IMPL(float, 64);
+INSTANTIATE_GATHER_QMV_IMPL(float, 128);
+INSTANTIATE_GATHER_QMV_IMPL(bfloat, 32);
+INSTANTIATE_GATHER_QMV_IMPL(bfloat, 64);
+INSTANTIATE_GATHER_QMV_IMPL(bfloat, 128);
+
 )";
 }
 
@@ -280,8 +467,11 @@ AOTITorchError aoti_torch_mps_gather_qmv(
     return Error::Internal;
   }
 
-  // Select kernel (M=1 GEMV path)
-  std::string kernel_name = "gather_qmv_fast_4bit_" + std::to_string(group_size) + "_" + type_str;
+  // Select kernel: fast path for aligned K, impl path for generic K
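+  // 512 = values_per_thread * SIMD_SIZE and 8 = num_simdgroups *
+  // results_per_simdgroup, matching the kernel tiling above.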
+  bool use_fast = (N % 8 == 0 && K % 512 == 0);
+  std::string kernel_name = use_fast
+      ? "gather_qmv_fast_4bit_" + std::to_string(group_size) + "_" + type_str
+      : "gather_qmv_impl_4bit_" + std::to_string(group_size) + "_" + type_str;
   ET_LOG(Debug, "aoti_torch_mps_gather_qmv: Using kernel: %s", kernel_name.c_str());
 
   auto kernel_func = library->getKernelFunction(kernel_name);