Skip to content

Commit 6a6780a

Browse files
authored
vulkan: Support GGML_TYPE_NVFP4 (#21455)
This adds NVFP4 support for get_rows, dequant, and mul_mat(_id). For mul_mat, it does not add support for the dp4/q8_1 path; everything goes through fp16/fp32.
1 parent e489a5c commit 6a6780a

8 files changed

Lines changed: 171 additions & 3 deletions

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 28 additions & 0 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include "generic_unary_head.glsl"
55
#include "dequant_funcs.glsl"
66

7-
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4)
7+
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4)
88
// 16 invocations needed for init_iq_shmem
99
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
1010
#else

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,25 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
450450
}
451451
#endif
452452

453+
#if defined(DATA_A_NVFP4)
454+
// Dequantizes two NVFP4 values from block `a_offset + ib`.
//   ib       - block index (added to a_offset)
//   iqs      - element index inside the block; bits 4+ select the 16-element
//              sub-block, bit 3 selects the nibble, bits 0-2 select the byte
//   a_offset - base offset into data_a
// Returns the values decoded from the same nibble of two consecutive bytes,
// scaled by the sub-block scale.
// NOTE(review): reads byte j + 1, so this presumably expects an even iqs
// (j <= 6) to stay inside the sub-block's 8 bytes - confirm against callers.
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
455+
// One scale byte per 16 elements.
const uint sub = iqs >> 4;
456+
const float d = ue4m3_to_fp32(data_a[a_offset + ib].d[sub]);
457+
// Byte position within the sub-block's 8 packed bytes.
const uint j = iqs & 7;
458+
const uint shift = (iqs & 8) >> 1; // 0 or 4
459+
const uint vui0 = uint(data_a[a_offset + ib].qs[sub * 8u + j]);
460+
const uint vui1 = uint(data_a[a_offset + ib].qs[sub * 8u + j + 1]);
461+
const uint qs0 = (vui0 >> shift) & 0xF;
462+
const uint qs1 = (vui1 >> shift) & 0xF;
463+
// kvalues_mxfp4 holds integer entries; the 0.5 factor presumably undoes a
// 2x scaling of the FP4 value table - TODO confirm.
return vec2(float(kvalues_mxfp4[qs0]), float(kvalues_mxfp4[qs1])) * d * 0.5;
464+
}
465+
// Dequantizes four NVFP4 values by calling dequantize() at iqs and iqs + 2
// and concatenating the two vec2 results in that order.
// NOTE(review): presumably expects iqs to be a multiple of 4 so both calls
// stay within the same nibble group - confirm against callers.
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
466+
const vec2 v0 = dequantize(ib, iqs, a_offset);
467+
const vec2 v1 = dequantize(ib, iqs + 2u, a_offset);
468+
return vec4(v0.x, v0.y, v1.x, v1.y);
469+
}
470+
#endif
471+
453472
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
454473
vec2 get_dm(uint ib, uint a_offset) {
455474
return vec2(0, 0);
@@ -484,6 +503,12 @@ vec2 get_dm(uint ib, uint a_offset) {
484503
}
485504
#endif
486505

506+
#if defined(DATA_A_NVFP4)
507+
// NVFP4 variant of get_dm: always returns (1.0, 0.0) and ignores both
// arguments.
// NOTE(review): presumably (scale, offset) = identity because the NVFP4
// dequantize() already applies the per-sub-block scale - confirm with callers.
vec2 get_dm(uint ib, uint a_offset) {
508+
return vec2(1.0, 0.0);
509+
}
510+
#endif
511+
487512
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
488513
vec2 get_dm(uint ib, uint a_offset) {
489514
const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm);

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,24 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
697697
}
698698
#endif
699699

700+
#if defined(DATA_A_NVFP4)
701+
// Buffer-reference type wrapping a single block_nvfp4; this is the block
// handle type taken by dequantFuncNVFP4.
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVFP4 {
702+
block_nvfp4 block;
703+
};
704+
705+
// Decodes one element of an NVFP4 block to float16.
//   bl           - reference to the block being decoded
//   blockCoords  - not used by this function
//   coordInBlock - only coordInBlock[1] is read, as the element index
// Returns kvalues_mxfp4[nibble] * sub-block scale * 0.5 as float16_t.
float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
706+
{
707+
const uint idx = coordInBlock[1];
708+
// Bits 4-5 of the element index select one of the four scale bytes.
const uint sub = (idx & 0x30) >> 4;
709+
// Byte index: 8 bytes per sub-block plus the low three bits of idx.
const uint iqs = ((idx & 0x30) >> 1) + (idx & 0x7);
710+
// Bit 3 selects the nibble: 0 for the low nibble, 4 for the high nibble.
const uint shift = (idx & 0x8) >> 1;
711+
const float d = ue4m3_to_fp32(bl.block.d[sub]);
712+
uint qs = uint(bl.block.qs[iqs]);
713+
qs = (qs >> shift) & 0xF;
714+
return float16_t(kvalues_mxfp4[qs] * d * 0.5);
715+
}
716+
#endif
717+
700718
#if defined(DATA_A_Q1_0)
701719
#define dequantFuncA dequantFuncQ1_0
702720
#elif defined(DATA_A_Q4_0)
@@ -743,6 +761,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
743761
#define dequantFuncA dequantFuncIQ4_NL
744762
#elif defined(DATA_A_MXFP4)
745763
#define dequantFuncA dequantFuncMXFP4
764+
#elif defined(DATA_A_NVFP4)
765+
#define dequantFuncA dequantFuncNVFP4
746766
#elif defined(DATA_A_F32)
747767
#define dequantFuncA dequantFuncF32
748768
#endif
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#version 450
2+
3+
#include "dequant_head.glsl"
4+
5+
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
6+
7+
layout (binding = 0) readonly buffer A {block_nvfp4 data_a[];};
8+
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
9+
10+
// Dequantizes NVFP4 blocks from data_a into data_b.
// Each invocation handles one 16-element sub-block: the 256 invocations of a
// workgroup are split into 4 groups of 64, and each group covers 16 blocks
// with 4 sub-blocks each.
void main() {
11+
// Index of this invocation's group of 16 blocks.
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
12+
13+
// Called by every invocation before the bounds check below can return early.
init_iq_shmem(gl_WorkGroupSize);
14+
15+
const uint tid = gl_LocalInvocationID.x % 64;
16+
// Sub-block (0-3) within the block.
const uint sub = tid / 16;
17+
// Block (0-15) within the group.
const uint ir = tid % 16;
18+
const uint ib = 16 * i + ir;
19+
// p.nel / 64 is the number of 64-element blocks; skip invocations past the end.
if (ib >= p.nel / 64) {
20+
return;
21+
}
22+
23+
// First of the 8 packed bytes belonging to this sub-block.
const uint q_idx = 8 * sub;
24+
// Output position of the sub-block's first element (equals 64 * ib + 16 * sub).
const uint b_idx = 1024 * i + 64 * ir + 16 * sub;
25+
26+
const float d = ue4m3_to_fp32(data_a[ib].d[sub]);
27+
28+
// Low nibbles fill outputs 0-7 of the sub-block, high nibbles fill outputs 8-15.
[[unroll]] for (uint l = 0; l < 8; ++l) {
29+
data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
30+
data_b[b_idx + l + 8] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]));
31+
}
32+
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,23 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
501501
kvalues_mxfp4[vui2 & 0xF] * d);
502502
buf_a[buf_idx + 8] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d,
503503
kvalues_mxfp4[vui2 >> 4] * d);
504+
#elif defined(DATA_A_NVFP4)
505+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
506+
// lo and hi nibbles are 8 elements apart, which doesn't quite line up with
507+
// how the thread mapping and buf_idx calculation works for other types.
508+
const uint buf_idx = col * SHMEM_STRIDE + (row & 3) + (row & ~3) * 2;
509+
510+
const uint ib = idx / 16u;
511+
const uint sub = (idx & 0xC) >> 2;
512+
const uint iqs = (idx & 0xF) * 2;
513+
const float d = ue4m3_to_fp32(data_a[ib].d[sub]) * 0.5;
514+
const uint vui = uint(data_a[ib].qs[iqs]);
515+
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
516+
517+
buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d,
518+
kvalues_mxfp4[vui2 & 0xF] * d);
519+
buf_a[buf_idx + 4] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d,
520+
kvalues_mxfp4[vui2 >> 4] * d);
504521
#endif
505522
}
506523

ggml/src/ggml-vulkan/vulkan-shaders/types.glsl

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1713,6 +1713,22 @@ struct block_mxfp4
17131713
#define A_TYPE block_mxfp4
17141714
#endif
17151715

1716+
#define QUANT_K_NVFP4 64
1717+
#define QUANT_R_NVFP4 1
1718+
1719+
// One NVFP4 block of QUANT_K_NVFP4 elements.
//   d  - QUANT_K_NVFP4 / 16 scale bytes, i.e. one per 16 elements
//   qs - QUANT_K_NVFP4 / 2 bytes, i.e. two 4-bit values packed per byte
struct block_nvfp4
1720+
{
1721+
uint8_t d[QUANT_K_NVFP4 / 16];
1722+
uint8_t qs[QUANT_K_NVFP4 / 2];
1723+
};
1724+
1725+
#if defined(DATA_A_NVFP4)
1726+
#define QUANT_K QUANT_K_NVFP4
1727+
#define QUANT_R QUANT_R_NVFP4
1728+
#define QUANT_AUXF 1
1729+
#define A_TYPE block_nvfp4
1730+
#endif
1731+
17161732
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
17171733
const int8_t kvalues_iq4nl_const[16] = {
17181734
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
@@ -1732,21 +1748,44 @@ void init_iq_shmem(uvec3 wgsize)
17321748
}
17331749
#endif
17341750

1735-
#if defined(DATA_A_MXFP4)
1751+
#if defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4)
17361752
const int8_t kvalues_mxfp4_const[16] = {
17371753
int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12),
17381754
int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12),
17391755
};
17401756

17411757
shared int8_t kvalues_mxfp4[16];
17421758

1759+
#if defined(DATA_A_NVFP4)
1760+
// UE4M3 scales in NVFP4 blocks use only 7 bits; the sign bit (bit 7) is always zero.
1761+
shared float ue4m3_fp32_lut[128];
1762+
1763+
// Converts a 7-bit scale code (4 exponent bits, 3 mantissa bits) to float.
// Used to fill the ue4m3_fp32_lut table.
//   u - code to convert; only bits 0-6 are examined
// Returns 0.0 for codes 0 and 127, man / 512 when the exponent field is zero,
// and otherwise (1 + man / 8) * 2^(exp - 7).
float ue4m3_to_fp32_build(uint u) {
1764+
// NOTE(review): 127 has all exponent and mantissa bits set, presumably the
// format's NaN encoding, and is mapped to zero here - confirm.
if (u == 0u || u == 127u) {
1765+
return 0.0;
1766+
}
1767+
const uint exp = (u >> 3) & 15u;
1768+
const uint man = u & 7u;
1769+
// Zero exponent field: no implicit leading one, value is man * 2^-9.
if (exp == 0u) {
1770+
return float(man) * (1.0 / 512.0);
1771+
}
1772+
// Assemble the float32 bit pattern directly: exp + 120 rebiases the exponent
// from 7 to 127, and the 3 mantissa bits go to the top of the 23-bit fraction.
const uint bits = (exp + 120u) << 23 | (man << 20);
1773+
return uintBitsToFloat(bits);
1774+
}
1775+
#endif
1776+
17431777
#define NEEDS_INIT_IQ_SHMEM
17441778
// Fills the shared lookup tables used for dequantization and then issues a
// barrier.
//   wgsize - workgroup size; only wgsize.x is used, as the stride of the
//            copy loops
// NOTE(review): because this ends in barrier(), presumably every invocation
// of the workgroup has to call it - confirm against the GLSL barrier rules.
void init_iq_shmem(uvec3 wgsize)
17451779
{
17461780
// copy the table into shared memory and sync
17471781
for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) {
17481782
kvalues_mxfp4[i] = kvalues_mxfp4_const[i];
17491783
}
1784+
#if defined(DATA_A_NVFP4)
1785+
// Also build the 128-entry scale table, one entry per 7-bit code.
for (uint i = gl_LocalInvocationIndex.x; i < 128u; i += wgsize.x) {
1786+
ue4m3_fp32_lut[i] = ue4m3_to_fp32_build(i);
1787+
}
1788+
#endif
17501789
barrier();
17511790
}
17521791
#endif
@@ -1783,6 +1822,12 @@ float e8m0_to_fp32(uint8_t x) {
17831822
return uintBitsToFloat(bits);
17841823
}
17851824

1825+
#if defined(DATA_A_NVFP4)
1826+
// Converts a UE4M3 scale byte to float via the shared ue4m3_fp32_lut table.
//   x - scale byte; only bits 0-6 are used
// Returns the table entry for the low 7 bits of x.
// The table has 128 entries, so the index is masked to 7 bits: a byte with
// bit 7 set (not expected in valid data) would otherwise read past the end of
// the shared array.
float ue4m3_to_fp32(uint8_t x) {
1827+
return ue4m3_fp32_lut[uint(x) & 127u];
1828+
}
1829+
#endif
1830+
17861831
#if BDA
17871832

17881833
#extension GL_EXT_buffer_reference : enable

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ const std::vector<std::string> type_names = {
6666
"iq4_xs",
6767
"iq4_nl",
6868
"mxfp4",
69+
"nvfp4",
6970
"bf16",
7071
};
7172

@@ -556,7 +557,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
556557
std::string load_vec_quant = "2";
557558
if ((tname == "q1_0") || (tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
558559
load_vec_quant = "8";
559-
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4"))
560+
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4") || (tname == "nvfp4"))
560561
load_vec_quant = "4";
561562

562563
if (tname == "bf16") {

0 commit comments

Comments
 (0)