Skip to content

Commit 6b949d1

Browse files
authored
sycl : support nvfp4 type in mul_mat (ggml-org#21227)
1 parent 84f82e8 commit 6b949d1

File tree

6 files changed

+232
-1
lines changed

6 files changed

+232
-1
lines changed

ggml/src/ggml-sycl/common.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "ggml-impl.h"
2424
#include "ggml-sycl.h"
2525
#include "presets.hpp"
26+
#include "type.hpp"
2627
#include "sycl_hw.hpp"
2728

2829
namespace syclexp = sycl::ext::oneapi::experimental;
@@ -965,4 +966,10 @@ static T block_reduce(T val, T * shared_vals, int block_size_template) {
965966
return val;
966967
}
967968

969+
// Decode an NVFP4 scale byte (E4M3) to fp32, with the two E4M3 NaN
// encodings (0x7F, 0xFF) forced to zero so a corrupt scale cannot poison
// the output.  The decoded value is halved — presumably because the
// kvalues_mxfp4 lookup table stores the FP4 code points pre-scaled by 2
// (TODO confirm against the table definition).
static __dpct_inline__ float ggml_sycl_ue4m3_to_fp32(uint8_t x) {
    // Branch-free NaN squash: the comparison chain evaluates to 0 or 1.
    const uint8_t bits = x * (x != 0x7F && x != 0xFF);
    // A byte-wide bit_cast is endianness-safe, unlike the previous
    // reinterpret_cast of a uint32_t's address (little-endian only, and an
    // aliasing violation).
    const __nv_fp8_e4m3 xf = sycl::bit_cast<__nv_fp8_e4m3>(bits);
    return static_cast<float>(xf) / 2;
}
974+
968975
#endif // GGML_SYCL_COMMON_HPP

ggml/src/ggml-sycl/convert.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,18 @@ static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t
482482
});
483483
}
484484

485+
template <typename dst_t>
486+
static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
487+
GGML_ASSERT(k % QK_NVFP4 == 0);
488+
const int nb = k / QK_NVFP4;
489+
stream->parallel_for(
490+
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
491+
[=](sycl::nd_item<3> item_ct1) {
492+
dequantize_block_nvfp4(vx, y, k);
493+
});
494+
}
495+
496+
485497
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
486498
static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y,
487499
const int64_t ne00, const int64_t ne01, const int64_t ne02,
@@ -641,13 +653,16 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
641653
return dequantize_row_iq4_nl_sycl;
642654
case GGML_TYPE_MXFP4:
643655
return dequantize_row_mxfp4_sycl;
656+
case GGML_TYPE_NVFP4:
657+
return dequantize_row_nvfp4_sycl;
644658
case GGML_TYPE_F32:
645659
return convert_unary_sycl<float>;
646660
#ifdef GGML_SYCL_HAS_BF16
647661
case GGML_TYPE_BF16:
648662
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
649663
#endif
650664
default:
665+
GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type));
651666
return nullptr;
652667
}
653668
}
@@ -708,13 +723,16 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
708723
return dequantize_row_iq4_nl_sycl;
709724
case GGML_TYPE_MXFP4:
710725
return dequantize_row_mxfp4_sycl;
726+
case GGML_TYPE_NVFP4:
727+
return dequantize_row_nvfp4_sycl;
711728
case GGML_TYPE_F16:
712729
return convert_unary_sycl<sycl::half>;
713730
#ifdef GGML_SYCL_HAS_BF16
714731
case GGML_TYPE_BF16:
715732
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
716733
#endif
717734
default:
735+
GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type));
718736
return nullptr;
719737
}
720738
}

ggml/src/ggml-sycl/dequantize.hpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,4 +838,36 @@ static void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restr
838838
}
839839
}
840840

841+
842+
template <typename dst_t>
843+
static void dequantize_block_nvfp4(
844+
const void * __restrict__ vx,
845+
dst_t * __restrict__ yy,
846+
const int64_t ne) {
847+
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
848+
const int64_t i = item_ct1.get_group(2);
849+
const int tid = item_ct1.get_local_id(2);
850+
851+
const int64_t base = i * QK_NVFP4;
852+
if (base >= ne) {
853+
return;
854+
}
855+
856+
const block_nvfp4 * x = (const block_nvfp4 *) vx;
857+
const block_nvfp4 & xb = x[i];
858+
859+
const int sub = tid / (QK_NVFP4_SUB / 2);
860+
const int j = tid % (QK_NVFP4_SUB / 2);
861+
862+
const float d = ggml_sycl_ue4m3_to_fp32(xb.d[sub]);
863+
const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j];
864+
865+
const int64_t y0 = base + sub * QK_NVFP4_SUB + j;
866+
const int64_t y1 = y0 + QK_NVFP4_SUB / 2;
867+
868+
yy[y0] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q & 0x0F]);
869+
yy[y1] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q >> 4]);
870+
}
871+
872+
841873
#endif // GGML_SYCL_DEQUANTIZE_HPP

ggml/src/ggml-sycl/mmvq.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,23 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float
613613
}
614614
}
615615

616+
static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
617+
dpct::queue_ptr stream) {
618+
GGML_ASSERT(ncols % QK_NVFP4 == 0);
619+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
620+
const sycl::range<3> block_nums(1, 1, block_num_y);
621+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
622+
623+
{
624+
stream->submit([&](sycl::handler & cgh) {
625+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
626+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
627+
mul_mat_vec_q<QK_NVFP4, QI_NVFP4, block_nvfp4, VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1>(
628+
vx, vy, dst, ncols, nrows, item_ct1);
629+
});
630+
});
631+
}
632+
}
616633

617634
static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
618635
float *dst, const int ncols,
@@ -1145,8 +1162,11 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
11451162
case GGML_TYPE_MXFP4:
11461163
mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
11471164
break;
1165+
case GGML_TYPE_NVFP4:
1166+
mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1167+
break;
11481168
default:
1149-
GGML_ABORT("fatal error");
1169+
GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(src0->type));
11501170
}
11511171
}
11521172
GGML_UNUSED(src1);

ggml/src/ggml-sycl/type.hpp

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
#pragma once

#include <sycl/sycl.hpp>

#include <cstdint>
#include <cstring>
#include <limits>
6+
7+
// Convert fp32 to OCP E4M3 (1 sign, 4 exponent, 3 mantissa; bias 7; no
// infinities) with round-to-nearest-even.
//
// Fixes over the previous version, which mis-rounded and rejected the whole
// exponent-8 range:
//   * values in (256, 448] now encode correctly — the E4M3 maximum finite
//     value is 448 (code 0x7E), not 2^7;
//   * out-of-range finite values and infinities saturate to +/-448,
//     matching CUDA's __NV_SATFINITE conversion behavior;
//   * rounding is exact round-to-nearest-even in both the normal and the
//     subnormal range (the old code sampled the round bit at a fixed
//     position and had a broken carry step that corrupted the exponent).
// NaN maps to the canonical NaN encoding 0x7F with the sign preserved.
inline uint8_t float_to_e4m3(float f)
{
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits)); // type-pun without aliasing UB
    const uint8_t sign = static_cast<uint8_t>((bits >> 31) << 7);

    // NaN test on the raw bits: robust even under fast-math flags.
    if ((bits & 0x7F800000u) == 0x7F800000u && (bits & 0x7FFFFFu) != 0) {
        return sign | 0x7F;
    }
    if ((bits & 0x7FFFFFFFu) == 0) {
        return sign; // +/- zero
    }

    const int      e      = static_cast<int>((bits >> 23) & 0xFFu) - 127; // unbiased exponent
    const uint32_t mant24 = 0x800000u | (bits & 0x7FFFFFu);               // 1.m as a 24-bit integer

    if (e >= -6) {
        // Normal E4M3 candidate: keep the leading 1 plus the top 3 mantissa
        // bits, rounding to nearest-even on the 20 discarded bits.
        uint32_t kept = mant24 >> 20; // 4 bits: 1mmm
        const uint32_t rem = mant24 & 0xFFFFFu;
        int exp_e4 = e;
        if (rem > 0x80000u || (rem == 0x80000u && (kept & 1u))) {
            kept++;
        }
        if (kept == 16u) { // mantissa overflowed 1.111 -> 10.000: carry into exponent
            kept = 8u;
            exp_e4++;
        }
        // Saturate: the largest finite value is 1.110 * 2^8 = 448 (0x7E);
        // 1.111 * 2^8 would collide with the NaN encoding 0x7F.
        if (exp_e4 > 8 || (exp_e4 == 8 && (kept & 7u) == 7u)) {
            return sign | 0x7E;
        }
        return sign | static_cast<uint8_t>(((exp_e4 + 7) << 3) | (kept & 7u));
    }

    // Subnormal range: because the E4M3 encoding is monotonic in the integer
    // code (code 8 is exactly the smallest normal, 2^-6), the result code is
    // simply round(|f| / 2^-9).  |f| = mant24 * 2^(e-23), so the unit count
    // is mant24 >> (14 - e), rounded to nearest-even.
    const int shift = 14 - e;
    if (shift > 31) {
        return sign; // underflows to zero even after rounding (incl. fp32 denormals)
    }
    uint32_t q = mant24 >> shift;
    const uint32_t rem  = mant24 & ((1u << shift) - 1u);
    const uint32_t half = 1u << (shift - 1);
    if (rem > half || (rem == half && (q & 1u))) {
        q++;
    }
    return sign | static_cast<uint8_t>(q);
}
63+
64+
// Decode an OCP E4M3 byte (1-4-3, bias 7, no infinities) to fp32.
// NaN is encoded only as S.1111.111 (0x7F / 0xFF); every other code with
// exponent field 1111 is a finite value, up to +/-448.  The previous
// version wrongly decoded the whole exponent-15 range (256..448) as NaN.
inline float e4m3_to_float(uint8_t x)
{
    const uint8_t sign = (x >> 7) & 0x1u;
    const uint8_t exp  = (x >> 3) & 0xFu;
    const uint8_t mant = x & 0x7u;

    // Canonical NaN encodings only (0x7F and 0xFF).
    if ((x & 0x7Fu) == 0x7Fu) {
        return std::numeric_limits<float>::quiet_NaN();
    }

    // 2^(exp - 7) for exp = 1..15; index 0 unused (subnormal path below).
    // All values are exact, replacing the transcendental sycl::pow call.
    static constexpr float kScale[16] = {
        0.0f,   0x1p-6f, 0x1p-5f, 0x1p-4f, 0x1p-3f, 0x1p-2f, 0x1p-1f, 0x1p0f,
        0x1p1f, 0x1p2f,  0x1p3f,  0x1p4f,  0x1p5f,  0x1p6f,  0x1p7f,  0x1p8f,
    };

    float val;
    if (exp == 0) {
        // Subnormal (also covers +/-0): mant/8 * 2^-6 = mant * 2^-9.
        val = static_cast<float>(mant) * 0x1p-9f;
    } else {
        // Normal: implicit leading 1, bias 7.
        val = (1.0f + static_cast<float>(mant) * 0.125f) * kScale[exp];
    }

    return sign ? -val : val;
}
92+
93+
// Host/device stand-in for CUDA's __nv_fp8_e4m3: a single E4M3 byte with
// float/half conversions, so the NVFP4 code can mirror the CUDA path.
struct __nv_fp8_e4m3 {
    uint8_t raw; // E4M3 bit pattern: 1 sign, 4 exponent (bias 7), 3 mantissa

    // Trivial default ctor: leaves `raw` uninitialized, keeping the type
    // trivially copyable for bit_cast / vector loads.
    __nv_fp8_e4m3() = default;

    // Quantizing constructors (explicit to avoid silent precision loss).
    explicit __nv_fp8_e4m3(float f) : raw(float_to_e4m3(f)) {}
    explicit __nv_fp8_e4m3(sycl::half h) : raw(float_to_e4m3(static_cast<float>(h))) {}

    // Implicit widening conversions for use in arithmetic.
    // NOTE(review): combining implicit float/half conversions with the
    // uint8_t conversions below can make overload resolution ambiguous in
    // mixed expressions — confirm call sites only rely on one of them.
    operator float() const { return e4m3_to_float(raw); }
    operator sycl::half() const { return static_cast<sycl::half>(static_cast<float>(*this)); }

    // Allow direct access for vector loads/stores
    operator uint8_t&() { return raw; }
    operator uint8_t() const { return raw; }
};

// Packed pairs/quads of E4M3 values.
// NOTE(review): sycl::vec is specified for arithmetic element types only —
// verify these aliases instantiate on all supported SYCL implementations.
using __nv_fp8x2_e4m3 = sycl::vec<__nv_fp8_e4m3, 2>;
using __nv_fp8x4_e4m3 = sycl::vec<__nv_fp8_e4m3, 4>;
112+

ggml/src/ggml-sycl/vecdotq.hpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "dpct/helper.hpp"
1717
#include "ggml.h"
18+
#include "type.hpp"
1819
#include "quants.hpp"
1920

2021
typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
@@ -31,6 +32,18 @@ static __dpct_inline__ int get_int_b1(const void * x, const int & i32) {
3132
return x32;
3233
}
3334

35+
// Load the i32-th 32-bit word from `x` as two little-endian 16-bit halves;
// the source only needs 2-byte alignment.
static __dpct_inline__ int get_int_b2(const void * x, const int & i32) {
    const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment

    const uint32_t lo = x16[2*i32 + 0];
    const uint32_t hi = x16[2*i32 + 1];

    return static_cast<int>(lo | (hi << 16));
}
43+
44+
// Load the i32-th 32-bit word from `x` with a direct aligned access;
// the source must be at least 4-byte aligned.
static __dpct_inline__ int get_int_b4(const void * x, const int & i32) {
    const int * x32 = (const int *) x;
    return x32[i32];
}
3447

3548
static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
3649
const uint16_t* x16 =
@@ -755,6 +768,35 @@ static __dpct_inline__ float vec_dot_mxfp4_q8_1(const void * __restrict__ vbq,
755768
return d * sumi;
756769
}
757770

771+
// Packed int32 loads processed per thread per call in the MMVQ path.
#define VDR_NVFP4_Q8_1_MMVQ 4
// Same, for the MMQ path (not referenced in this file's MMVQ launcher).
#define VDR_NVFP4_Q8_1_MMQ 8

// Dot product of NVFP4 quants against the matching Q8_1 quants.
// vbq    NVFP4 block data
// bq8_1  Q8_1 blocks of the activation row
// iqs    starting int32 index into the block's qs for this lane
static __dpct_inline__ float vec_dot_nvfp4_q8_1(const void * __restrict__ vbq,
                                                const block_q8_1 * __restrict__ bq8_1,
                                                const int32_t & iqs) {
    const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq;
    float sum = 0.0f;
#pragma unroll
    for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) {
        // Two consecutive int32 loads = 8 packed bytes = 16 FP4 values.
        const int32_t iqs0 = iqs + 2*i;
        const int32_t iqs1 = iqs0 + 1;
        // Sub-block index: one E4M3 scale covers each pair of int32 loads.
        const int32_t is = iqs0 >> 1;
        // Expand nibbles through the FP4 value table (shared with MXFP4);
        // presumably .x() holds the low nibbles and .y() the high nibbles —
        // see get_int_from_table_16.
        const sycl::int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4);
        const sycl::int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4);
        // Matching Q8_1 block and int32 offset within its quants.
        const block_q8_1 * bq8 = bq8_1 + (is >> 1);
        const int32_t i8 = ((is & 1) << 2);

        // Pairing of nibble halves with q8 words follows the half-split
        // layout (low nibbles -> first half of the sub-block, high nibbles
        // -> second half), accumulated with 4-wide dot products.
        int sumi = ggml_sycl_dp4a(v0.x(), get_int_b4(bq8->qs, i8 + 0), 0);
        sumi = ggml_sycl_dp4a(v0.y(), get_int_b4(bq8->qs, i8 + 2), sumi);
        sumi = ggml_sycl_dp4a(v1.x(), get_int_b4(bq8->qs, i8 + 1), sumi);
        sumi = ggml_sycl_dp4a(v1.y(), get_int_b4(bq8->qs, i8 + 3), sumi);

        // Combined scale: NVFP4 sub-block scale times the Q8_1 scale
        // (ds[0] is presumably the `d` component of the half2 — confirm
        // against block_q8_1's definition).
        const float d = ggml_sycl_ue4m3_to_fp32(bq4->d[is]) * (bq8->ds)[0];
        sum += d * float(sumi);
    }

    return sum;
}
758800

759801
static __dpct_inline__ float
760802
vec_dot_q5_0_q8_1(const void *__restrict__ vbq,

0 commit comments

Comments
 (0)