Skip to content

Commit f077244

Browse files
authored
[perf]: native path for MXFP4 MoE on AVX512F (#2006)
* [perf]: native path for MXFP4 MoE on AVX512F * [perf]: move inline static constants outside structs
1 parent 95e20f9 commit f077244

1 file changed

Lines changed: 104 additions & 61 deletions

File tree

kt-kernel/operators/amx/fp4-moe.hpp

Lines changed: 104 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -85,24 +85,76 @@ struct GemmKernel224MXFP4SmallKGroup {
8585
return _mm512_inserti64x4(_mm512_castsi256_si512(q0), q1, 1);
8686
}
8787

88+
struct ActivationBF16 {
89+
__m512bh a;
90+
#if !defined(__AVX512BF16__)
91+
__m512 a_even;
92+
__m512 a_odd;
93+
inline static const __m512i odd_mask = _mm512_set1_epi32(0xFFFF0000);
94+
#endif
95+
96+
__attribute__((always_inline)) ActivationBF16(__m512bh a_) : a(a_) {
97+
#if !defined(__AVX512BF16__)
98+
a_even = _mm512_castsi512_ps(_mm512_slli_epi32((__m512i)a_, 16));
99+
a_odd = _mm512_castsi512_ps(_mm512_and_si512((__m512i)a_, odd_mask));
100+
#endif
101+
}
102+
};
103+
104+
struct DequantizedWeight {
105+
#if defined(__AVX512BF16__)
106+
__m512bh d;
107+
#else
108+
__m512 w_even;
109+
__m512 w_odd;
110+
inline static const __m128i lo_mask = _mm_set1_epi8(0x0F);
111+
inline static const __m512 lut = _mm512_setr_ps(0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, -0.0f, -0.5f, -1.0f,
112+
-1.5f, -2.0f, -3.0f, -4.0f, -6.0f);
113+
#endif
114+
115+
__attribute__((always_inline)) DequantizedWeight(__m128i w) {
116+
#if defined(__AVX512BF16__)
117+
d = (__m512bh)mxfp4_to_bf16_32(w);
118+
#else
119+
__m128i lo = _mm_and_si128(w, lo_mask);
120+
__m128i hi = _mm_and_si128(_mm_srli_epi16(w, 4), lo_mask);
121+
122+
__m512i lo_32 = _mm512_cvtepu8_epi32(lo);
123+
__m512i hi_32 = _mm512_cvtepu8_epi32(hi);
124+
125+
w_even = _mm512_permutexvar_ps(lo_32, lut);
126+
w_odd = _mm512_permutexvar_ps(hi_32, lut);
127+
#endif
128+
}
129+
};
130+
131+
__attribute__((always_inline)) static inline __m512 mxfp4_dot_bf16(const DequantizedWeight& w,
132+
const ActivationBF16& act) {
133+
#if defined(__AVX512BF16__)
134+
return _mm512_dpbf16_ps(_mm512_setzero_ps(), act.a, w.d);
135+
#else
136+
__m512 dot = _mm512_mul_ps(act.a_odd, w.w_odd);
137+
return _mm512_fmadd_ps(act.a_even, w.w_even, dot);
138+
#endif
139+
}
140+
88141
// Buffers
89142
using BufferA = BufferABF16Impl<GemmKernel224MXFP4SmallKGroup>; // raw BF16, no quant
90143
using BufferB = BufferBInt4KGroupImpl<GemmKernel224MXFP4SmallKGroup>; // nibble-packed FP4
91144
using BufferC = BufferCReduceImpl<GemmKernel224MXFP4SmallKGroup>; // FP32 reduce
92145

93146
// 4 个 zmm 的 horizontal reduce → 4 个连续 fp32。
94147
// 4 次 reduce_add_ps 之间无依赖,编译器/CPU 可并行调度。
95-
__attribute__((always_inline)) static inline void
96-
reduce4(__m512 s0, __m512 s1, __m512 s2, __m512 s3, float* dst) {
148+
__attribute__((always_inline)) static inline void reduce4(__m512 s0, __m512 s1, __m512 s2, __m512 s3, float* dst) {
97149
dst[0] = _mm512_reduce_add_ps(s0);
98150
dst[1] = _mm512_reduce_add_ps(s1);
99151
dst[2] = _mm512_reduce_add_ps(s2);
100152
dst[3] = _mm512_reduce_add_ps(s3);
101153
}
102154

103155
// mat-vec: M 个独立 token,N 维 4 行一组累加,摊销 horizontal reduce。
104-
static void fp4_mat_vec_kgroup(int m, int n, int k, int k_group_size, BufferA* ba, BufferB* bb, BufferC* bc,
105-
int ith, int nth) {
156+
static void fp4_mat_vec_kgroup(int m, int n, int k, int k_group_size, BufferA* ba, BufferB* bb, BufferC* bc, int ith,
157+
int nth) {
106158
auto [n_start, n_end] = split_range_n(n, ith, nth);
107159
if (n_start >= n_end) return;
108160
const int kg_count = k / 32;
@@ -129,19 +181,15 @@ struct GemmKernel224MXFP4SmallKGroup {
129181
__m512 acc3 = _mm512_setzero_ps();
130182

131183
for (int g = 0; g < kg_count; g++) {
132-
const __m512bh a = a_row[g];
133-
const __m512bh d0 = (__m512bh)mxfp4_to_bf16_32(w0[g]);
134-
const __m512bh d1 = (__m512bh)mxfp4_to_bf16_32(w1[g]);
135-
const __m512bh d2 = (__m512bh)mxfp4_to_bf16_32(w2[g]);
136-
const __m512bh d3 = (__m512bh)mxfp4_to_bf16_32(w3[g]);
137-
acc0 = _mm512_fmadd_ps(_mm512_set1_ps(s0[g]),
138-
_mm512_dpbf16_ps(_mm512_setzero_ps(), a, d0), acc0);
139-
acc1 = _mm512_fmadd_ps(_mm512_set1_ps(s1[g]),
140-
_mm512_dpbf16_ps(_mm512_setzero_ps(), a, d1), acc1);
141-
acc2 = _mm512_fmadd_ps(_mm512_set1_ps(s2[g]),
142-
_mm512_dpbf16_ps(_mm512_setzero_ps(), a, d2), acc2);
143-
acc3 = _mm512_fmadd_ps(_mm512_set1_ps(s3[g]),
144-
_mm512_dpbf16_ps(_mm512_setzero_ps(), a, d3), acc3);
184+
const ActivationBF16 a(a_row[g]);
185+
const DequantizedWeight d0(w0[g]);
186+
const DequantizedWeight d1(w1[g]);
187+
const DequantizedWeight d2(w2[g]);
188+
const DequantizedWeight d3(w3[g]);
189+
acc0 = _mm512_fmadd_ps(_mm512_set1_ps(s0[g]), mxfp4_dot_bf16(d0, a), acc0);
190+
acc1 = _mm512_fmadd_ps(_mm512_set1_ps(s1[g]), mxfp4_dot_bf16(d1, a), acc1);
191+
acc2 = _mm512_fmadd_ps(_mm512_set1_ps(s2[g]), mxfp4_dot_bf16(d2, a), acc2);
192+
acc3 = _mm512_fmadd_ps(_mm512_set1_ps(s3[g]), mxfp4_dot_bf16(d3, a), acc3);
145193
}
146194
reduce4(acc0, acc1, acc2, acc3, c_row + (n_pos - n_start));
147195
}
@@ -151,10 +199,9 @@ struct GemmKernel224MXFP4SmallKGroup {
151199
const float* s = bb->get_scale(n, n_pos, k, 0);
152200
__m512 acc = _mm512_setzero_ps();
153201
for (int g = 0; g < kg_count; g++) {
154-
const __m512bh a = a_row[g];
155-
const __m512bh d = (__m512bh)mxfp4_to_bf16_32(w[g]);
156-
acc = _mm512_fmadd_ps(_mm512_set1_ps(s[g]),
157-
_mm512_dpbf16_ps(_mm512_setzero_ps(), a, d), acc);
202+
const ActivationBF16 a(a_row[g]);
203+
const DequantizedWeight d(w[g]);
204+
acc = _mm512_fmadd_ps(_mm512_set1_ps(s[g]), mxfp4_dot_bf16(d, a), acc);
158205
}
159206
c_row[n_pos - n_start] = _mm512_reduce_add_ps(acc);
160207
}
@@ -164,8 +211,8 @@ struct GemmKernel224MXFP4SmallKGroup {
164211
// mat-mat: 4×4 register tile (M_TILE=4, N_TILE=4 → 16 累加器)。
165212
// 每 K-group 解码 4 行 N 一次, 被 4 个 token 共享 → PSHUFB 解码开销 / 4。
166213
// M / N 尾巴回退到 mat-vec 单 token 内层 (V4 chunked-prefill 16/32/64 整数倍, 极少触发)。
167-
static void fp4_mat_mat_kgroup(int m, int n, int k, int k_group_size, BufferA* ba, BufferB* bb, BufferC* bc,
168-
int ith, int nth) {
214+
static void fp4_mat_mat_kgroup(int m, int n, int k, int k_group_size, BufferA* ba, BufferB* bb, BufferC* bc, int ith,
215+
int nth) {
169216
auto [n_start, n_end] = split_range_n(n, ith, nth);
170217
if (n_start >= n_end) return;
171218
const int kg_count = k / 32;
@@ -198,27 +245,28 @@ struct GemmKernel224MXFP4SmallKGroup {
198245

199246
for (int g = 0; g < kg_count; g++) {
200247
// 4 行权重解码一次, MB 个 token 共享
201-
const __m512bh d0 = (__m512bh)mxfp4_to_bf16_32(w0[g]);
202-
const __m512bh d1 = (__m512bh)mxfp4_to_bf16_32(w1[g]);
203-
const __m512bh d2 = (__m512bh)mxfp4_to_bf16_32(w2[g]);
204-
const __m512bh d3 = (__m512bh)mxfp4_to_bf16_32(w3[g]);
205-
const __m512 sv0 = _mm512_set1_ps(s0[g]);
206-
const __m512 sv1 = _mm512_set1_ps(s1[g]);
207-
const __m512 sv2 = _mm512_set1_ps(s2[g]);
208-
const __m512 sv3 = _mm512_set1_ps(s3[g]);
209-
210-
#define V_FMA_ROW(M_I) do { \
211-
const __m512bh a = a_rows[M_I][g]; \
212-
acc[M_I][0] = _mm512_fmadd_ps(sv0, _mm512_dpbf16_ps(_mm512_setzero_ps(), a, d0), acc[M_I][0]); \
213-
acc[M_I][1] = _mm512_fmadd_ps(sv1, _mm512_dpbf16_ps(_mm512_setzero_ps(), a, d1), acc[M_I][1]); \
214-
acc[M_I][2] = _mm512_fmadd_ps(sv2, _mm512_dpbf16_ps(_mm512_setzero_ps(), a, d2), acc[M_I][2]); \
215-
acc[M_I][3] = _mm512_fmadd_ps(sv3, _mm512_dpbf16_ps(_mm512_setzero_ps(), a, d3), acc[M_I][3]); \
216-
} while (0)
248+
const DequantizedWeight d0(w0[g]);
249+
const DequantizedWeight d1(w1[g]);
250+
const DequantizedWeight d2(w2[g]);
251+
const DequantizedWeight d3(w3[g]);
252+
const __m512 sv0 = _mm512_set1_ps(s0[g]);
253+
const __m512 sv1 = _mm512_set1_ps(s1[g]);
254+
const __m512 sv2 = _mm512_set1_ps(s2[g]);
255+
const __m512 sv3 = _mm512_set1_ps(s3[g]);
256+
257+
#define V_FMA_ROW(M_I) \
258+
do { \
259+
const ActivationBF16 a(a_rows[M_I][g]); \
260+
acc[M_I][0] = _mm512_fmadd_ps(sv0, mxfp4_dot_bf16(d0, a), acc[M_I][0]); \
261+
acc[M_I][1] = _mm512_fmadd_ps(sv1, mxfp4_dot_bf16(d1, a), acc[M_I][1]); \
262+
acc[M_I][2] = _mm512_fmadd_ps(sv2, mxfp4_dot_bf16(d2, a), acc[M_I][2]); \
263+
acc[M_I][3] = _mm512_fmadd_ps(sv3, mxfp4_dot_bf16(d3, a), acc[M_I][3]); \
264+
} while (0)
217265
V_FMA_ROW(0);
218266
V_FMA_ROW(1);
219267
V_FMA_ROW(2);
220268
V_FMA_ROW(3);
221-
#undef V_FMA_ROW
269+
#undef V_FMA_ROW
222270
}
223271
for (int i = 0; i < MB; i++) {
224272
float* c_row = bc->get_submat(m, n, m_pos + i, n_start);
@@ -233,11 +281,9 @@ struct GemmKernel224MXFP4SmallKGroup {
233281
float* c_row = bc->get_submat(m, n, m_pos + i, n_start);
234282
__m512 acc = _mm512_setzero_ps();
235283
for (int g = 0; g < kg_count; g++) {
236-
acc = _mm512_fmadd_ps(_mm512_set1_ps(s[g]),
237-
_mm512_dpbf16_ps(_mm512_setzero_ps(),
238-
a_rows[i][g],
239-
(__m512bh)mxfp4_to_bf16_32(w[g])),
240-
acc);
284+
const ActivationBF16 a(a_rows[i][g]);
285+
const DequantizedWeight d(w[g]);
286+
acc = _mm512_fmadd_ps(_mm512_set1_ps(s[g]), mxfp4_dot_bf16(d, a), acc);
241287
}
242288
c_row[n_pos - n_start] = _mm512_reduce_add_ps(acc);
243289
}
@@ -257,18 +303,17 @@ struct GemmKernel224MXFP4SmallKGroup {
257303
const float* s1 = bb->get_scale(n, n_pos + 1, k, 0);
258304
const float* s2 = bb->get_scale(n, n_pos + 2, k, 0);
259305
const float* s3 = bb->get_scale(n, n_pos + 3, k, 0);
260-
__m512 a0 = _mm512_setzero_ps(), a1 = _mm512_setzero_ps(),
261-
a2 = _mm512_setzero_ps(), a3 = _mm512_setzero_ps();
306+
__m512 a0 = _mm512_setzero_ps(), a1 = _mm512_setzero_ps(), a2 = _mm512_setzero_ps(), a3 = _mm512_setzero_ps();
262307
for (int g = 0; g < kg_count; g++) {
263-
const __m512bh a = a_row[g];
264-
a0 = _mm512_fmadd_ps(_mm512_set1_ps(s0[g]),
265-
_mm512_dpbf16_ps(_mm512_setzero_ps(), a, (__m512bh)mxfp4_to_bf16_32(w0[g])), a0);
266-
a1 = _mm512_fmadd_ps(_mm512_set1_ps(s1[g]),
267-
_mm512_dpbf16_ps(_mm512_setzero_ps(), a, (__m512bh)mxfp4_to_bf16_32(w1[g])), a1);
268-
a2 = _mm512_fmadd_ps(_mm512_set1_ps(s2[g]),
269-
_mm512_dpbf16_ps(_mm512_setzero_ps(), a, (__m512bh)mxfp4_to_bf16_32(w2[g])), a2);
270-
a3 = _mm512_fmadd_ps(_mm512_set1_ps(s3[g]),
271-
_mm512_dpbf16_ps(_mm512_setzero_ps(), a, (__m512bh)mxfp4_to_bf16_32(w3[g])), a3);
308+
const ActivationBF16 a(a_row[g]);
309+
const DequantizedWeight d0(w0[g]);
310+
const DequantizedWeight d1(w1[g]);
311+
const DequantizedWeight d2(w2[g]);
312+
const DequantizedWeight d3(w3[g]);
313+
a0 = _mm512_fmadd_ps(_mm512_set1_ps(s0[g]), mxfp4_dot_bf16(d0, a), a0);
314+
a1 = _mm512_fmadd_ps(_mm512_set1_ps(s1[g]), mxfp4_dot_bf16(d1, a), a1);
315+
a2 = _mm512_fmadd_ps(_mm512_set1_ps(s2[g]), mxfp4_dot_bf16(d2, a), a2);
316+
a3 = _mm512_fmadd_ps(_mm512_set1_ps(s3[g]), mxfp4_dot_bf16(d3, a), a3);
272317
}
273318
reduce4(a0, a1, a2, a3, c_row + (n_pos - n_start));
274319
}
@@ -277,11 +322,9 @@ struct GemmKernel224MXFP4SmallKGroup {
277322
const float* s = bb->get_scale(n, n_pos, k, 0);
278323
__m512 acc = _mm512_setzero_ps();
279324
for (int g = 0; g < kg_count; g++) {
280-
acc = _mm512_fmadd_ps(_mm512_set1_ps(s[g]),
281-
_mm512_dpbf16_ps(_mm512_setzero_ps(),
282-
a_row[g],
283-
(__m512bh)mxfp4_to_bf16_32(w[g])),
284-
acc);
325+
const ActivationBF16 a(a_row[g]);
326+
const DequantizedWeight d(w[g]);
327+
acc = _mm512_fmadd_ps(_mm512_set1_ps(s[g]), mxfp4_dot_bf16(d, a), acc);
285328
}
286329
c_row[n_pos - n_start] = _mm512_reduce_add_ps(acc);
287330
}

0 commit comments

Comments
 (0)