@@ -642,6 +642,7 @@ METAL_FUNC void fp_qmm_t_impl(
642642 const constant int & K,
643643 const constant int & N,
644644 const constant int & M,
645+ const constant int & K_eff,
645646 uint3 tid [[threadgroup_position_in_grid]],
646647 uint lid [[thread_index_in_threadgroup]],
647648 uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -695,7 +696,7 @@ METAL_FUNC void fp_qmm_t_impl(
695696
696697 if (num_els < BM) {
697698 if (!aligned_N && num_outs < BN) {
698- for (int k = 0 ; k < K ; k += BK) {
699+ for (int k = 0 ; k < K_eff ; k += BK) {
699700 threadgroup_barrier (mem_flags::mem_threadgroup);
700701 loader_x.load_safe (short2 (BK, num_els));
701702 loader_w.load_safe (short2 (BK, num_outs));
@@ -705,7 +706,7 @@ METAL_FUNC void fp_qmm_t_impl(
705706 loader_w.next ();
706707 }
707708 } else {
708- for (int k = 0 ; k < K ; k += BK) {
709+ for (int k = 0 ; k < K_eff ; k += BK) {
709710 threadgroup_barrier (mem_flags::mem_threadgroup);
710711 loader_x.load_safe (short2 (BK, num_els));
711712 loader_w.load_unsafe ();
@@ -717,7 +718,7 @@ METAL_FUNC void fp_qmm_t_impl(
717718 }
718719 } else {
719720 if (!aligned_N && num_outs < BN) {
720- for (int k = 0 ; k < K ; k += BK) {
721+ for (int k = 0 ; k < K_eff ; k += BK) {
721722 threadgroup_barrier (mem_flags::mem_threadgroup);
722723 loader_x.load_unsafe ();
723724 loader_w.load_safe (short2 (BK, num_outs));
@@ -727,7 +728,7 @@ METAL_FUNC void fp_qmm_t_impl(
727728 loader_w.next ();
728729 }
729730 } else {
730- for (int k = 0 ; k < K ; k += BK) {
731+ for (int k = 0 ; k < K_eff ; k += BK) {
731732 threadgroup_barrier (mem_flags::mem_threadgroup);
732733 loader_x.load_unsafe ();
733734 loader_w.load_unsafe ();
@@ -1219,7 +1220,7 @@ template <
12191220 tid);
12201221 }
12211222 fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1222- w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1223+ w, scales, x, y, Xs, Ws, K, N, M, K, tid, lid, simd_gid, simd_lid);
12231224}
12241225
12251226template <
@@ -1486,7 +1487,61 @@ template <
14861487 s_strides,
14871488 tid);
14881489 fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1489- w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1490+ w, scales, x, y, Xs, Ws, K, N, M, K, tid, lid, simd_gid, simd_lid);
1491+ }
1492+
1493+ template <
1494+ typename T,
1495+ const int group_size,
1496+ const int bits,
1497+ const bool aligned_N,
1498+ const int BM = 32 ,
1499+ const int BK = 32 ,
1500+ const int BN = 32 >
1501+ [[kernel]] void fp_qmm_t_splitk (
1502+ const device uint32_t * w [[buffer(0 )]],
1503+ const device uint8_t* scales [[buffer(1 )]],
1504+ const device T* x [[buffer(2 )]],
1505+ device T* y [[buffer(3 )]],
1506+ const constant int& K [[buffer(4 )]],
1507+ const constant int& N [[buffer(5 )]],
1508+ const constant int& M [[buffer(6 )]],
1509+ const constant int& k_partition_size [[buffer(7 )]],
1510+ const constant int& split_k_partition_stride [[buffer(8 )]],
1511+ uint3 tid [[threadgroup_position_in_grid]],
1512+ uint lid [[thread_index_in_threadgroup]],
1513+ uint simd_gid [[simdgroup_index_in_threadgroup]],
1514+ uint simd_lid [[thread_index_in_simdgroup]]) {
1515+ (void )lid;
1516+
1517+ constexpr int BK_padded = (BK + 16 / sizeof (T));
1518+ constexpr int pack_factor = get_pack_factor<8 , bits>();
1519+ constexpr int bytes_per_pack = get_bytes_per_pack ();
1520+ threadgroup T Xs[BM * BK_padded];
1521+ threadgroup T Ws[BN * BK_padded];
1522+ const int k_start = tid.z * k_partition_size;
1523+ x += k_start;
1524+
1525+ auto wl = (const device uint8_t *)w;
1526+ wl += k_start * bytes_per_pack / pack_factor;
1527+ scales += k_start / group_size;
1528+ y += tid.z * static_cast <int64_t >(split_k_partition_stride);
1529+
1530+ fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1531+ (const device uint32_t *)wl,
1532+ scales,
1533+ x,
1534+ y,
1535+ Xs,
1536+ Ws,
1537+ K,
1538+ N,
1539+ M,
1540+ k_partition_size,
1541+ tid,
1542+ lid,
1543+ simd_gid,
1544+ simd_lid);
14901545}
14911546
14921547template <
0 commit comments