@@ -57,6 +57,8 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b
 
 static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
     switch (type_x) {
+        case GGML_TYPE_Q1_0:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
             return MMQ_Q8_1_DS_LAYOUT_DS4;
@@ -185,6 +187,7 @@ static constexpr __device__ int get_mmq_y_device() {
 
 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
     switch (type) {
+        case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
         case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
         case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
@@ -229,6 +232,7 @@ static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
 
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     switch (type) {
+        case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
@@ -302,6 +306,87 @@ static constexpr __device__ int mmq_get_nwarps_device() {
 
 // ------------------------------------------------------------
 
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q1_0(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+    constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0;
+    constexpr int threads_per_row = blocks_per_iter * QI1_0;
+    constexpr int nrows = warp_size / threads_per_row;
+    constexpr int scale_entries_per_block = QK1_0 / QK8_1;
+    constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block;
+
+    const int txi  = threadIdx.x % threads_per_row;
+    const int kbx  = txi / QI1_0;
+    const int kqsx = txi % QI1_0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + kbx;
+        const int qs_offset = 4*kqsx;
+        const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] << 8) |
+                        (bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24);
+
+        int unpacked_bytes[8];
+#pragma unroll
+        for (int j = 0; j < 8; ++j) {
+            const int shift = j * 4;
+            const int bits4 = (qs0 >> shift) & 0x0F;
+            const int b0 = (bits4 & 0x01) ? 1 : -1;
+            const int b1 = (bits4 & 0x02) ? 1 : -1;
+            const int b2 = (bits4 & 0x04) ? 1 : -1;
+            const int b3 = (bits4 & 0x08) ? 1 : -1;
+            unpacked_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
+        }
+
+        const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0;
+#pragma unroll
+        for (int j = 0; j < 8; ++j) {
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j];
+#else
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j];
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        }
+    }
+
+    const int ksx = threadIdx.x % scale_entries_per_row;
+    const int scale_block = ksx / scale_entries_per_block;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + scale_block;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d;
+#else
+        x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    }
+}
+
 template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
     constexpr int nwarps = mmq_get_nwarps_device();
@@ -3290,6 +3375,14 @@ static __device__ __forceinline__ void mmq_write_back_mma(
 template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
 struct mmq_type_traits;
 
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q1_0> {
+    static constexpr int              vdr          = VDR_Q1_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q1_0<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
 template <int mmq_x, int mmq_y, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
     static constexpr int              vdr          = VDR_Q4_0_Q8_1_MMQ;
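
Note on the unpacking scheme: load_tiles_q1_0 maps each stored bit to a signed weight (bit set -> +1, bit clear -> -1), packed LSB-first within each byte, and the single per-block scale d is replicated into x_df once per QK8_1 sub-block so that the existing Q8_0-style dot kernels (vec_dot_q8_0_q8_1_mma / vec_dot_q8_0_q8_1_dp4a selected by the traits above) can be reused unchanged. The host-side sketch below mirrors that sign unpacking for one block. It is illustrative only: the concrete block_q1_0 layout and the QK1_0 value (256 assumed here) are defined outside this hunk, so the names block_q1_0_ref, QK1_0_ASSUMED and unpack_q1_0_signs are hypothetical.

// reference_q1_0_unpack.cpp -- illustrative sketch, not part of the patch.
// Assumed layout: QK1_0_ASSUMED weights per block, one sign bit per weight
// (LSB-first within each byte), plus a single per-block scale; only the sign
// unpacking done by load_tiles_q1_0 is reproduced here.
#include <cstdint>
#include <cstdio>

constexpr int QK1_0_ASSUMED = 256;

struct block_q1_0_ref {
    uint16_t d;                           // per-block scale (fp16 bits, unused below)
    uint8_t  qs[QK1_0_ASSUMED / 8];       // packed sign bits: bit k of the stream = weight k
};

// Expand packed sign bits into +1/-1 values, matching the kernel's convention:
// a set bit becomes +1, a clear bit becomes -1.
static void unpack_q1_0_signs(const block_q1_0_ref * b, int8_t * out) {
    for (int k = 0; k < QK1_0_ASSUMED; ++k) {
        const int bit = (b->qs[k / 8] >> (k % 8)) & 1;
        out[k] = bit ? 1 : -1;
    }
}

int main() {
    block_q1_0_ref b = {};
    b.qs[0] = 0xB2;                       // 0b10110010: example pattern for weights 0..7
    int8_t w[QK1_0_ASSUMED];
    unpack_q1_0_signs(&b, w);
    for (int k = 0; k < 8; ++k) {
        printf("%d ", w[k]);              // prints: -1 1 -1 -1 1 1 -1 1
    }
    printf("\n");
    return 0;
}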