Skip to content

Commit bca0c0b

Browse files
committed
attempt to support dp4a
1 parent 84ab75f commit bca0c0b

File tree

3 files changed

+16
-21
lines changed

3 files changed

+16
-21
lines changed

ggml/src/ggml-cuda/dequantize.cuh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ static __device__ __forceinline__ void dequantize_q1_0(const void * vx, const in
44
const block_q1_0 * x = (const block_q1_0 *) vx;
55

66
const float d = x[ib].d;
7-
const float neg_d = -d;
87

98
const int bit_index_0 = iqs;
109
const int bit_index_1 = iqs + 1;

ggml/src/ggml-cuda/mmq.cu

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -305,11 +305,6 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
305305
return false;
306306
}
307307

308-
// Q1_0 requires MMA — no DP4A fallback path
309-
if (type == GGML_TYPE_Q1_0 && !turing_mma_available(cc) && !amd_mfma_available(cc) && !amd_wmma_available(cc)) {
310-
return false;
311-
}
312-
313308
if (turing_mma_available(cc)) {
314309
return true;
315310
}

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ static constexpr __device__ int get_mmq_y_device() {
187187

188188
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
189189
switch (type) {
190+
case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0;
190191
case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
191192
case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
192193
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
@@ -307,15 +308,17 @@ static constexpr __device__ int mmq_get_nwarps_device() {
307308

308309
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q1_0(
309310
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
310-
#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
311-
GGML_UNUSED_VARS(x, x_tile, kbx0, i_max, stride, mmq_y, need_check);
312-
NO_DEVICE_CODE;
313-
#else
314311
constexpr int nwarps = mmq_get_nwarps_device();
315312
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
316313

314+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
317315
int * x_qs = (int *) x_tile;
318316
float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
317+
#else
318+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
319+
int * x_qs = (int *) x_tile;
320+
float * x_df = (float *) (x_qs + txs.qs);
321+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
319322

320323
constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0;
321324
constexpr int threads_per_row = blocks_per_iter * QI1_0;
@@ -355,7 +358,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
355358
const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0;
356359
#pragma unroll
357360
for (int j = 0; j < 8; ++j) {
361+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
358362
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j];
363+
#else
364+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j];
365+
#endif
359366
}
360367
}
361368

@@ -372,18 +379,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
372379

373380
const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + scale_block;
374381

382+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
375383
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d;
384+
#else
385+
x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d;
386+
#endif
376387
}
377-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
378-
}
379-
380-
template <int mmq_x, int mmq_y>
381-
static __device__ __forceinline__ void vec_dot_q1_mmq_dp4a_disabled(
382-
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
383-
// Q1_0 intentionally targets the MMA path only.
384-
// If DP4A support is needed later for older GPUs, it should be reintroduced and validated separately.
385-
GGML_UNUSED_VARS(x, y, sum, k00, mmq_x, mmq_y);
386-
NO_DEVICE_CODE;
387388
}
388389

389390
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
@@ -3363,7 +3364,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q1_0> {
33633364
static constexpr int vdr = VDR_Q1_0_Q8_1_MMQ;
33643365
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q1_0<mmq_y, need_check>;
33653366
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3366-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q1_mmq_dp4a_disabled<mmq_x, mmq_y>;
3367+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
33673368
};
33683369

33693370
template <int mmq_x, int mmq_y, bool need_check>

0 commit comments

Comments
 (0)