Skip to content

Commit fc2b005

Browse files
authored
ggml-cuda: Repost of 21896: Blackwell native NVFP4 support (#22196)
1 parent 7b8443a commit fc2b005

8 files changed

Lines changed: 321 additions & 133 deletions

File tree

ggml/src/ggml-cuda/common.cuh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,18 @@ static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) {
830830
#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
831831
}
832832

833+
static __device__ __forceinline__ uint8_t ggml_cuda_fp32_to_ue4m3(float x) {
834+
#if defined(BLACKWELL_MMA_AVAILABLE) // This is used for NVFP4 subblock scale quantizations only
835+
if (!(x > 0.0f)) {
836+
return 0;
837+
}
838+
const __nv_fp8_e4m3 xf(x);
839+
return xf.__x;
840+
#else
841+
NO_DEVICE_CODE; // Used only for NVFP4 Scales for Activations, only for Blackwell
842+
#endif // defined(BLACKWELL_MMA_AVAILABLE)
843+
}
844+
833845
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
834846
const uint8_t sign_bit = (x < 0.0f) << 3;
835847
float ax = fabsf(x) * e;

ggml/src/ggml-cuda/mma.cuh

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,25 +1015,35 @@ namespace ggml_cuda_mma {
10151015
#endif // AMD_MFMA_AVAILABLE
10161016
}
10171017

1018-
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
1019-
const tile<16, 8, int> & A,
1020-
const tile<8, 8, int> & B,
1021-
uint32_t a_scale,
1022-
uint32_t b_scale) {
1018+
template <ggml_type type>
1019+
static __device__ __forceinline__ void mma_block_scaled_fp4(tile<16, 8, float> & D,
1020+
const tile<16, 8, int> & A,
1021+
const tile<8, 8, int> & B,
1022+
uint32_t a_scale,
1023+
uint32_t b_scale) {
10231024
#ifdef BLACKWELL_MMA_AVAILABLE
10241025
const int * Axi = (const int *) A.x;
10251026
const int * Bxi = (const int *) B.x;
10261027
float * Dxi = (float *) D.x;
10271028

1028-
asm volatile(
1029-
"mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
1030-
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
1031-
"%10, {0, 0}, %11, {0, 0};"
1032-
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
1033-
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
1029+
if constexpr (type == GGML_TYPE_MXFP4) {
1030+
asm volatile(
1031+
"mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
1032+
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
1033+
"%10, {0, 0}, %11, {0, 0};"
1034+
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
1035+
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
1036+
} else {
1037+
asm volatile(
1038+
"mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 "
1039+
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
1040+
"%10, {0, 0}, %11, {0, 0};"
1041+
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
1042+
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
1043+
}
10341044
#else
10351045
GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
1036-
#endif // BLACKWELL_MMA_AVAILABLE
1046+
#endif // BLACKWELL_MMA_AVAILABLE
10371047
}
10381048

10391049
static __device__ __forceinline__ void mma(

ggml/src/ggml-cuda/mmq.cu

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ void ggml_cuda_mul_mat_q(
122122
|| GGML_CUDA_CC_IS_CDNA(cc);
123123

124124
// TODO: tighter pool buffer size vs q8 path
125-
const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
125+
const bool use_native_fp4 = blackwell_mma_available(cc) && (src0->type == GGML_TYPE_MXFP4 || src0->type == GGML_TYPE_NVFP4);
126126

127127
if (!ids) {
128128
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
@@ -133,9 +133,9 @@ void ggml_cuda_mul_mat_q(
133133
const int64_t s11 = src1->nb[1] / ts_src1;
134134
const int64_t s12 = src1->nb[2] / ts_src1;
135135
const int64_t s13 = src1->nb[3] / ts_src1;
136-
if (use_native_mxfp4) {
136+
if (use_native_fp4) {
137137
static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
138-
quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
138+
quantize_mmq_fp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
139139
ne11, ne12, ne13, stream);
140140

141141
} else {
@@ -146,10 +146,8 @@ void ggml_cuda_mul_mat_q(
146146
}
147147

148148
// Stride depends on quantization format
149-
const int64_t s12 = use_native_mxfp4 ?
150-
ne11 * ne10_padded * sizeof(block_fp4_mmq) /
151-
(8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32)
152-
:
149+
const int64_t s12 = use_native_fp4 ?
150+
ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) : // block_fp4_mmq holds 256 values
153151
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
154152
const int64_t s13 = ne12*s12;
155153

@@ -198,8 +196,8 @@ void ggml_cuda_mul_mat_q(
198196
const int64_t s12 = src1->nb[2] / ts_src1;
199197
const int64_t s13 = src1->nb[3] / ts_src1;
200198

201-
if (use_native_mxfp4) {
202-
quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
199+
if (use_native_fp4) {
200+
quantize_mmq_fp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
203201
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
204202
} else {
205203
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
@@ -208,8 +206,9 @@ void ggml_cuda_mul_mat_q(
208206
CUDA_CHECK(cudaGetLastError());
209207
}
210208

211-
const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
212-
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
209+
static_assert(QK_K == 8 * QK_MXFP4, "QK_K needs to be 8 * QK_MXFP4");
210+
const int64_t s12 = use_native_fp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) :
211+
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
213212
const int64_t s13 = ne12*s12;
214213

215214
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.

0 commit comments

Comments
 (0)