Skip to content

Commit 5ab5778

Browse files
committed
CUDA: PoC for repacking mxfp4
1 parent b94050e commit 5ab5778

4 files changed

Lines changed: 233 additions & 12 deletions

File tree

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
#include "ggml-cuda/im2col.cuh"
2828
#include "ggml-cuda/mmf.cuh"
2929
#include "ggml-cuda/mmq.cuh"
30+
#ifdef GGML_CUDA_MXFP4_REPACK
31+
#include "ggml-cuda/mxfp4-repack.cuh"
32+
#endif
3033
#include "ggml-cuda/mmvf.cuh"
3134
#include "ggml-cuda/mmvq.cuh"
3235
#include "ggml-cuda/norm.cuh"
@@ -655,11 +658,51 @@ static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer,
655658
CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
656659
}
657660

661+
#ifdef GGML_CUDA_MXFP4_REPACK
662+
// In-place repack of a fully-uploaded MXFP4 tensor into the per-row SoA
663+
// layout expected by load_tiles_mxfp4_fp4_soa. Allocates a transient device
664+
// staging buffer, copies the current AoS bytes into it, launches the repack
665+
// kernel writing back over tensor->data, and frees. Synchronizes the provided
666+
// stream so staging is safe to free on return.
667+
static void ggml_cuda_mxfp4_repack_tensor_inplace(ggml_tensor * tensor, cudaStream_t stream) {
668+
const int64_t ne0 = tensor->ne[0];
669+
const int64_t nrow = ggml_nrows(tensor);
670+
const int B_src = (int) (ne0 / QK_MXFP4);
671+
constexpr int blocks_per_iter = MMQ_ITER_K_MXFP4_FP4 / QK_MXFP4; // 16
672+
const int B_dst = (B_src + blocks_per_iter - 1) / blocks_per_iter * blocks_per_iter;
673+
674+
const size_t src_bytes = (size_t) 17 * B_src * nrow;
675+
676+
void * staging = nullptr;
677+
CUDA_CHECK(cudaMallocAsync(&staging, src_bytes, stream));
678+
CUDA_CHECK(cudaMemcpyAsync(staging, tensor->data, src_bytes,
679+
cudaMemcpyDeviceToDevice, stream));
680+
ggml_cuda_mxfp4_repack_soa_launch(tensor->data, staging,
681+
(int) nrow, B_src, B_dst, stream);
682+
CUDA_CHECK(cudaGetLastError());
683+
CUDA_CHECK(cudaFreeAsync(staging, stream));
684+
}
685+
686+
static inline bool ggml_cuda_mxfp4_should_repack(const ggml_tensor * tensor,
687+
size_t offset, size_t size) {
688+
// Fires exactly once per tensor: when a write ends at ggml_nbytes(tensor)
689+
// the loader has finished uploading this tensor's data.
690+
return tensor->type == GGML_TYPE_MXFP4 &&
691+
ggml_n_dims(tensor) >= 2 &&
692+
offset + size == ggml_nbytes(tensor);
693+
}
694+
#endif // GGML_CUDA_MXFP4_REPACK
695+
658696
// Uploads `size` bytes from host memory `data` into `tensor` at byte `offset`
// on the per-thread stream, then blocks until that stream is idle. When
// GGML_CUDA_MXFP4_REPACK is defined and this write completes an MXFP4 tensor,
// the SoA repack is enqueued on the same stream before the synchronization.
static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    ggml_backend_cuda_buffer_context * buf_ctx = (ggml_backend_cuda_buffer_context *) buffer->context;
    ggml_cuda_set_device(buf_ctx->device);

    char * dst = (char *) tensor->data + offset;
    CUDA_CHECK(cudaMemcpyAsync(dst, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));

#ifdef GGML_CUDA_MXFP4_REPACK
    const bool upload_complete = ggml_cuda_mxfp4_should_repack(tensor, offset, size);
    if (upload_complete) {
        ggml_cuda_mxfp4_repack_tensor_inplace(tensor, cudaStreamPerThread);
    }
#endif

    // Wait for the copy (and the repack, if one was enqueued) to finish.
    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
}
665708

@@ -782,7 +825,17 @@ static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_t
782825
if (ggml_is_quantized(tensor->type)) {
783826
if (ne0 % MATRIX_ROW_PADDING != 0) {
784827
GGML_ASSERT(tensor->nb[0] == ggml_element_size(tensor));
785-
size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
828+
const size_t pad_bytes_per_row = ggml_row_size(
829+
tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
830+
831+
// MXFP4 weights get repacked per-row into an SoA layout with each
832+
// row padded to MATRIX_ROW_PADDING elements, so we need padding
833+
// space for every row rather than only the tensor tail.
834+
if (tensor->type == GGML_TYPE_MXFP4 && ggml_n_dims(tensor) >= 2) {
835+
size += pad_bytes_per_row * ggml_nrows(tensor);
836+
} else {
837+
size += pad_bytes_per_row;
838+
}
786839
}
787840
}
788841

@@ -2959,6 +3012,11 @@ static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tens
29593012
GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
29603013

29613014
CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream()));
3015+
#ifdef GGML_CUDA_MXFP4_REPACK
3016+
if (ggml_cuda_mxfp4_should_repack(tensor, offset, size)) {
3017+
ggml_cuda_mxfp4_repack_tensor_inplace(tensor, cuda_ctx->stream());
3018+
}
3019+
#endif
29623020
}
29633021

29643022
static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,73 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
934934
}
935935
}
936936

937+
// SoA variant of load_tiles_mxfp4_fp4. Source tensor must already be repacked
938+
// per-row into [qs_0..qs_{B_dst-1} | e_0..e_{B_dst-1}] with
939+
// B_dst = GGML_PAD(B_src, iter_k/QK_MXFP4). qs region is 16B aligned so the
940+
// per-thread 16B qs load is a single coalesced 128-bit transaction.
941+
//
942+
// kbx0 is the flat AoS block index at which this tile starts (it already
943+
// includes sample*stride_sample_x + channel*stride_channel_x + tile_row*stride).
944+
// stride is the source-side block count per row (B_src). We decompose kbx0
945+
// into (flat_row_base, kb0_in_row) and advance by local tile row i.
946+
template <int mmq_y, bool need_check>
947+
static __device__ __forceinline__ void load_tiles_mxfp4_fp4_soa(const char * __restrict__ x,
948+
int * __restrict__ x_tile,
949+
const int kbx0,
950+
const int i_max,
951+
const int stride) {
952+
constexpr int nwarps = mmq_get_nwarps_device();
953+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
954+
955+
int * x_qs = (int *) x_tile;
956+
uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
957+
958+
const int txi = threadIdx.x;
959+
960+
constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4);
961+
962+
constexpr int threads_per_row = iter_k / QK_MXFP4; // 16 on Blackwell
963+
constexpr int rows_per_warp = warp_size / threads_per_row;
964+
const int kbx = txi % threads_per_row;
965+
const int row_in_warp = txi / threads_per_row;
966+
967+
// Derive padded blocks-per-row. threads_per_row equals blocks_per_iter for
968+
// MXFP4_FP4, so rounding stride to this multiple matches how the repack
969+
// kernel pads.
970+
constexpr int blocks_per_iter = threads_per_row;
971+
const int B_src = stride;
972+
const int B_dst = (B_src + blocks_per_iter - 1) / blocks_per_iter * blocks_per_iter;
973+
const int row_bytes = 17 * B_dst;
974+
975+
const int flat_row_base = kbx0 / B_src;
976+
const int kb0_in_row = kbx0 - flat_row_base * B_src;
977+
const int kbx_in_row = kb0_in_row + kbx;
978+
979+
#pragma unroll
980+
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
981+
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
982+
983+
if constexpr (need_check) {
984+
i = min(i, i_max);
985+
}
986+
987+
const uint8_t * row_base = reinterpret_cast<const uint8_t *>(x)
988+
+ (size_t) (flat_row_base + i) * row_bytes;
989+
const uint8_t * qs_base = row_base;
990+
const uint8_t * sc_base = row_base + 16 * B_dst;
991+
992+
const int k0 = kbx * 4;
993+
const uint4 q = reinterpret_cast<const uint4 *>(qs_base)[kbx_in_row];
994+
memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, &q, 16);
995+
996+
if (kbx % 2 == 0) {
997+
uint32_t e = sc_base[kbx_in_row];
998+
e |= ((uint32_t) sc_base[kbx_in_row + 1]) << 8;
999+
x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
1000+
}
1001+
}
1002+
}
1003+
9371004

9381005
template <int mmq_y, bool need_check>
9391006
static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
@@ -3427,7 +3494,11 @@ template <int mmq_x, int mmq_y, bool need_check>
34273494
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
34283495
static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
34293496
#ifdef BLACKWELL_MMA_AVAILABLE
3497+
#ifdef GGML_CUDA_MXFP4_REPACK
3498+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4_soa<mmq_y, need_check>;
3499+
#else
34303500
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
3501+
#endif // GGML_CUDA_MXFP4_REPACK
34313502
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
34323503
#else
34333504
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,20 @@
55

66
#include <cstdint>
77

8+
#ifdef GGML_CUDA_MXFP4_REPACK
9+
// Device-side mirror of init_fastdiv_values. Runs once per kernel (with a
10+
// uniform divisor across the thread block), so the while-loop and 64-bit
11+
// divide are cheap and hoisted out of hot code by the compiler.
12+
static __device__ __forceinline__ uint3 init_fastdiv_values_device(uint32_t d) {
13+
uint32_t L = 0;
14+
while (L < 32 && (uint32_t{ 1 } << L) < d) {
15+
L++;
16+
}
17+
const uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
18+
return make_uint3(mp, L, d);
19+
}
20+
#endif
21+
822
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
923

1024
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
@@ -413,6 +427,14 @@ static __global__ void mul_mat_vec_q(
413427
const int blocks_per_row_x = ncols_x / qk;
414428
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
415429

430+
#ifdef GGML_CUDA_MXFP4_REPACK
431+
// MXFP4 SoA: fastdiv values for (kbx / B_src, kbx % B_src) computed
432+
// once per thread (uniform across the block) and held in registers.
433+
const uint3 mxfp4_bsrc_fd = (type == GGML_TYPE_MXFP4)
434+
? init_fastdiv_values_device((uint32_t) blocks_per_row_x)
435+
: make_uint3(0, 0, 0);
436+
#endif
437+
416438
const uint32_t channel_dst = blockIdx.y;
417439

418440
uint32_t channel_x;
@@ -490,12 +512,27 @@ static __global__ void mul_mat_vec_q(
490512
for (int j = 0; j < ncols_dst; ++j) {
491513
#pragma unroll
492514
for (int i = 0; i < rows_per_cuda_block; ++i) {
493-
tmp[j][i] += vec_dot_q_cuda(
494-
vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
495-
if constexpr (has_fusion) {
496-
if (use_gate) {
497-
tmp_gate[j][i] += vec_dot_q_cuda(
498-
vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
515+
const int kbx_arg = kbx_offset + i*stride_row_x + kbx;
516+
#ifdef GGML_CUDA_MXFP4_REPACK
517+
if constexpr (type == GGML_TYPE_MXFP4) {
518+
tmp[j][i] += vec_dot_mxfp4_q8_1_soa(
519+
vx, &y[j*stride_col_y + kby], kbx_arg, kqs, mxfp4_bsrc_fd);
520+
if constexpr (has_fusion) {
521+
if (use_gate) {
522+
tmp_gate[j][i] += vec_dot_mxfp4_q8_1_soa(
523+
vgate, &y[j*stride_col_y + kby], kbx_arg, kqs, mxfp4_bsrc_fd);
524+
}
525+
}
526+
} else
527+
#endif
528+
{
529+
tmp[j][i] += vec_dot_q_cuda(
530+
vx, &y[j*stride_col_y + kby], kbx_arg, kqs);
531+
if constexpr (has_fusion) {
532+
if (use_gate) {
533+
tmp_gate[j][i] += vec_dot_q_cuda(
534+
vgate, &y[j*stride_col_y + kby], kbx_arg, kqs);
535+
}
499536
}
500537
}
501538
}
@@ -631,13 +668,27 @@ static __global__ void mul_mat_vec_q_moe(
631668
// partial sum for each thread
632669
float tmp[c_rows_per_block] = {0.0f};
633670

671+
#ifdef GGML_CUDA_MXFP4_REPACK
672+
const uint3 mxfp4_bsrc_fd = (type == GGML_TYPE_MXFP4)
673+
? init_fastdiv_values_device((uint32_t) blocks_per_row_x)
674+
: make_uint3(0, 0, 0);
675+
#endif
676+
634677
for (int kbx = threadIdx.x / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
635678
const int kby = kbx * (qk/QK8_1);
636679
const int kqs = vdr * (threadIdx.x % (qi/vdr));
637680

638681
#pragma unroll
639682
for (int i = 0; i < c_rows_per_block; ++i) {
640-
tmp[i] += vec_dot_q_cuda(vx, &y[kby], kbx_offset + i*stride_row_x + kbx, kqs);
683+
const int kbx_arg = kbx_offset + i*stride_row_x + kbx;
684+
#ifdef GGML_CUDA_MXFP4_REPACK
685+
if constexpr (type == GGML_TYPE_MXFP4) {
686+
tmp[i] += vec_dot_mxfp4_q8_1_soa(vx, &y[kby], kbx_arg, kqs, mxfp4_bsrc_fd);
687+
} else
688+
#endif
689+
{
690+
tmp[i] += vec_dot_q_cuda(vx, &y[kby], kbx_arg, kqs);
691+
}
641692
}
642693
}
643694

@@ -924,12 +975,12 @@ static void mul_mat_vec_q_switch_type(
924975
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
925976
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
926977
break;
927-
case GGML_TYPE_MXFP4:
978+
case GGML_TYPE_MXFP4: {
928979
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
929980
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
930981
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
931982
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
932-
break;
983+
} break;
933984
case GGML_TYPE_NVFP4:
934985
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_NVFP4>
935986
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,

ggml/src/ggml-cuda/vecdotq.cuh

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,23 +308,64 @@ static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
308308
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
309309

310310
const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;
311+
const uint8_t * qs_base = bq4->qs;
312+
const uint8_t e_byte = bq4->e;
311313

312314
const int * q8 = (const int *) bq8_1->qs + iqs;
313315

314316
int sumi = 0;
315317
#pragma unroll
316318
for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
317-
const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
319+
const int aux_q4 = get_int_b1(qs_base, iqs + l);
318320
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
319321

320322
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
321323
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
322324
}
323325

324-
const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds);
326+
const float d = ggml_cuda_e8m0_to_fp32(e_byte) * 0.5f * __low2float(bq8_1->ds);
325327
return d * sumi;
326328
}
327329

330+
#ifdef GGML_CUDA_MXFP4_REPACK
331+
// SoA variant: tensor is repacked per-row into
332+
// [qs_0..qs_{B_dst-1} | e_0..e_{B_dst-1}]
333+
// with B_dst = GGML_PAD(B_src, 16). bsrc_fd carries fastdiv values for
334+
// B_src = ncols_x / QK_MXFP4 so the caller can supply (row_idx, kbx_in_row)
335+
// via mulhi + shift instead of a hardware divide. Caller computes bsrc_fd
336+
// once per kernel and passes it in.
337+
static __device__ __forceinline__ float vec_dot_mxfp4_q8_1_soa(
338+
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
339+
const int kbx, const int iqs, const uint3 bsrc_fd) {
340+
341+
const int B_src = (int) bsrc_fd.z;
342+
const int B_dst = (B_src + 15) & ~15;
343+
const uint2 dm = fast_div_modulo((uint32_t) kbx, bsrc_fd);
344+
const int row_idx = (int) dm.x;
345+
const int kbx_in_row = (int) dm.y;
346+
347+
const uint8_t * row_base = (const uint8_t *) vbq + (size_t) row_idx * 17 * B_dst;
348+
const uint8_t * qs_base = row_base + (size_t) kbx_in_row * 16;
349+
const uint8_t * sc_base = row_base + 16 * B_dst;
350+
const uint8_t e_byte = sc_base[kbx_in_row];
351+
352+
const int * q8 = (const int *) bq8_1->qs + iqs;
353+
354+
int sumi = 0;
355+
#pragma unroll
356+
for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
357+
const int aux_q4 = get_int_b1(qs_base, iqs + l);
358+
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
359+
360+
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
361+
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
362+
}
363+
364+
const float d = ggml_cuda_e8m0_to_fp32(e_byte) * 0.5f * __low2float(bq8_1->ds);
365+
return d * sumi;
366+
}
367+
#endif // GGML_CUDA_MXFP4_REPACK
368+
328369
#define VDR_NVFP4_Q8_1_MMVQ 4
329370
#define VDR_NVFP4_Q8_1_MMQ 8
330371

0 commit comments

Comments
 (0)