@@ -85,24 +85,76 @@ struct GemmKernel224MXFP4SmallKGroup {
8585 return _mm512_inserti64x4 (_mm512_castsi256_si512 (q0), q1, 1 );
8686 }
8787
88+ struct ActivationBF16 {
89+ __m512bh a;
90+ #if !defined(__AVX512BF16__)
91+ __m512 a_even;
92+ __m512 a_odd;
93+ inline static const __m512i odd_mask = _mm512_set1_epi32(0xFFFF0000 );
94+ #endif
95+
96+ __attribute__ ((always_inline)) ActivationBF16(__m512bh a_) : a(a_) {
97+ #if !defined(__AVX512BF16__)
98+ a_even = _mm512_castsi512_ps (_mm512_slli_epi32 ((__m512i)a_, 16 ));
99+ a_odd = _mm512_castsi512_ps (_mm512_and_si512 ((__m512i)a_, odd_mask));
100+ #endif
101+ }
102+ };
103+
104+ struct DequantizedWeight {
105+ #if defined(__AVX512BF16__)
106+ __m512bh d;
107+ #else
108+ __m512 w_even;
109+ __m512 w_odd;
110+ inline static const __m128i lo_mask = _mm_set1_epi8(0x0F );
111+ inline static const __m512 lut = _mm512_setr_ps(0 .0f , 0 .5f , 1 .0f , 1 .5f , 2 .0f , 3 .0f , 4 .0f , 6 .0f , -0 .0f , -0 .5f , -1 .0f ,
112+ -1 .5f , -2 .0f , -3 .0f , -4 .0f , -6 .0f );
113+ #endif
114+
115+ __attribute__ ((always_inline)) DequantizedWeight(__m128i w) {
116+ #if defined(__AVX512BF16__)
117+ d = (__m512bh)mxfp4_to_bf16_32 (w);
118+ #else
119+ __m128i lo = _mm_and_si128 (w, lo_mask);
120+ __m128i hi = _mm_and_si128 (_mm_srli_epi16 (w, 4 ), lo_mask);
121+
122+ __m512i lo_32 = _mm512_cvtepu8_epi32 (lo);
123+ __m512i hi_32 = _mm512_cvtepu8_epi32 (hi);
124+
125+ w_even = _mm512_permutexvar_ps (lo_32, lut);
126+ w_odd = _mm512_permutexvar_ps (hi_32, lut);
127+ #endif
128+ }
129+ };
130+
131+ __attribute__ ((always_inline)) static inline __m512 mxfp4_dot_bf16 (const DequantizedWeight& w,
132+ const ActivationBF16& act) {
133+ #if defined(__AVX512BF16__)
134+ return _mm512_dpbf16_ps (_mm512_setzero_ps (), act.a , w.d );
135+ #else
136+ __m512 dot = _mm512_mul_ps (act.a_odd , w.w_odd );
137+ return _mm512_fmadd_ps (act.a_even , w.w_even , dot);
138+ #endif
139+ }
140+
88141 // Buffers
89142 using BufferA = BufferABF16Impl<GemmKernel224MXFP4SmallKGroup>; // raw BF16, no quant
90143 using BufferB = BufferBInt4KGroupImpl<GemmKernel224MXFP4SmallKGroup>; // nibble-packed FP4
91144 using BufferC = BufferCReduceImpl<GemmKernel224MXFP4SmallKGroup>; // FP32 reduce
92145
93146 // 4 个 zmm 的 horizontal reduce → 4 个连续 fp32。
94147 // 4 次 reduce_add_ps 之间无依赖,编译器/CPU 可并行调度。
95- __attribute__ ((always_inline)) static inline void
96- reduce4 (__m512 s0, __m512 s1, __m512 s2, __m512 s3, float * dst) {
148+ __attribute__ ((always_inline)) static inline void reduce4 (__m512 s0, __m512 s1, __m512 s2, __m512 s3, float * dst) {
97149 dst[0 ] = _mm512_reduce_add_ps (s0);
98150 dst[1 ] = _mm512_reduce_add_ps (s1);
99151 dst[2 ] = _mm512_reduce_add_ps (s2);
100152 dst[3 ] = _mm512_reduce_add_ps (s3);
101153 }
102154
103155 // mat-vec: M 个独立 token,N 维 4 行一组累加,摊销 horizontal reduce。
104- static void fp4_mat_vec_kgroup (int m, int n, int k, int k_group_size, BufferA* ba, BufferB* bb, BufferC* bc,
105- int ith, int nth) {
156+ static void fp4_mat_vec_kgroup (int m, int n, int k, int k_group_size, BufferA* ba, BufferB* bb, BufferC* bc, int ith,
157+ int nth) {
106158 auto [n_start, n_end] = split_range_n (n, ith, nth);
107159 if (n_start >= n_end) return ;
108160 const int kg_count = k / 32 ;
@@ -129,19 +181,15 @@ struct GemmKernel224MXFP4SmallKGroup {
129181 __m512 acc3 = _mm512_setzero_ps ();
130182
131183 for (int g = 0 ; g < kg_count; g++) {
132- const __m512bh a = a_row[g];
133- const __m512bh d0 = (__m512bh)mxfp4_to_bf16_32 (w0[g]);
134- const __m512bh d1 = (__m512bh)mxfp4_to_bf16_32 (w1[g]);
135- const __m512bh d2 = (__m512bh)mxfp4_to_bf16_32 (w2[g]);
136- const __m512bh d3 = (__m512bh)mxfp4_to_bf16_32 (w3[g]);
137- acc0 = _mm512_fmadd_ps (_mm512_set1_ps (s0[g]),
138- _mm512_dpbf16_ps (_mm512_setzero_ps (), a, d0), acc0);
139- acc1 = _mm512_fmadd_ps (_mm512_set1_ps (s1[g]),
140- _mm512_dpbf16_ps (_mm512_setzero_ps (), a, d1), acc1);
141- acc2 = _mm512_fmadd_ps (_mm512_set1_ps (s2[g]),
142- _mm512_dpbf16_ps (_mm512_setzero_ps (), a, d2), acc2);
143- acc3 = _mm512_fmadd_ps (_mm512_set1_ps (s3[g]),
144- _mm512_dpbf16_ps (_mm512_setzero_ps (), a, d3), acc3);
184+ const ActivationBF16 a (a_row[g]);
185+ const DequantizedWeight d0 (w0[g]);
186+ const DequantizedWeight d1 (w1[g]);
187+ const DequantizedWeight d2 (w2[g]);
188+ const DequantizedWeight d3 (w3[g]);
189+ acc0 = _mm512_fmadd_ps (_mm512_set1_ps (s0[g]), mxfp4_dot_bf16 (d0, a), acc0);
190+ acc1 = _mm512_fmadd_ps (_mm512_set1_ps (s1[g]), mxfp4_dot_bf16 (d1, a), acc1);
191+ acc2 = _mm512_fmadd_ps (_mm512_set1_ps (s2[g]), mxfp4_dot_bf16 (d2, a), acc2);
192+ acc3 = _mm512_fmadd_ps (_mm512_set1_ps (s3[g]), mxfp4_dot_bf16 (d3, a), acc3);
145193 }
146194 reduce4 (acc0, acc1, acc2, acc3, c_row + (n_pos - n_start));
147195 }
@@ -151,10 +199,9 @@ struct GemmKernel224MXFP4SmallKGroup {
151199 const float * s = bb->get_scale (n, n_pos, k, 0 );
152200 __m512 acc = _mm512_setzero_ps ();
153201 for (int g = 0 ; g < kg_count; g++) {
154- const __m512bh a = a_row[g];
155- const __m512bh d = (__m512bh)mxfp4_to_bf16_32 (w[g]);
156- acc = _mm512_fmadd_ps (_mm512_set1_ps (s[g]),
157- _mm512_dpbf16_ps (_mm512_setzero_ps (), a, d), acc);
202+ const ActivationBF16 a (a_row[g]);
203+ const DequantizedWeight d (w[g]);
204+ acc = _mm512_fmadd_ps (_mm512_set1_ps (s[g]), mxfp4_dot_bf16 (d, a), acc);
158205 }
159206 c_row[n_pos - n_start] = _mm512_reduce_add_ps (acc);
160207 }
@@ -164,8 +211,8 @@ struct GemmKernel224MXFP4SmallKGroup {
164211 // mat-mat: 4×4 register tile (M_TILE=4, N_TILE=4 → 16 累加器)。
165212 // 每 K-group 解码 4 行 N 一次, 被 4 个 token 共享 → PSHUFB 解码开销 / 4。
166213 // M / N 尾巴回退到 mat-vec 单 token 内层 (V4 chunked-prefill 16/32/64 整数倍, 极少触发)。
167- static void fp4_mat_mat_kgroup (int m, int n, int k, int k_group_size, BufferA* ba, BufferB* bb, BufferC* bc,
168- int ith, int nth) {
214+ static void fp4_mat_mat_kgroup (int m, int n, int k, int k_group_size, BufferA* ba, BufferB* bb, BufferC* bc, int ith,
215+ int nth) {
169216 auto [n_start, n_end] = split_range_n (n, ith, nth);
170217 if (n_start >= n_end) return ;
171218 const int kg_count = k / 32 ;
@@ -198,27 +245,28 @@ struct GemmKernel224MXFP4SmallKGroup {
198245
199246 for (int g = 0 ; g < kg_count; g++) {
200247 // 4 行权重解码一次, MB 个 token 共享
201- const __m512bh d0 = (__m512bh)mxfp4_to_bf16_32 (w0[g]);
202- const __m512bh d1 = (__m512bh)mxfp4_to_bf16_32 (w1[g]);
203- const __m512bh d2 = (__m512bh)mxfp4_to_bf16_32 (w2[g]);
204- const __m512bh d3 = (__m512bh)mxfp4_to_bf16_32 (w3[g]);
205- const __m512 sv0 = _mm512_set1_ps (s0[g]);
206- const __m512 sv1 = _mm512_set1_ps (s1[g]);
207- const __m512 sv2 = _mm512_set1_ps (s2[g]);
208- const __m512 sv3 = _mm512_set1_ps (s3[g]);
209-
210- #define V_FMA_ROW (M_I ) do { \
211- const __m512bh a = a_rows[M_I][g]; \
212- acc[M_I][0 ] = _mm512_fmadd_ps (sv0, _mm512_dpbf16_ps (_mm512_setzero_ps (), a, d0), acc[M_I][0 ]); \
213- acc[M_I][1 ] = _mm512_fmadd_ps (sv1, _mm512_dpbf16_ps (_mm512_setzero_ps (), a, d1), acc[M_I][1 ]); \
214- acc[M_I][2 ] = _mm512_fmadd_ps (sv2, _mm512_dpbf16_ps (_mm512_setzero_ps (), a, d2), acc[M_I][2 ]); \
215- acc[M_I][3 ] = _mm512_fmadd_ps (sv3, _mm512_dpbf16_ps (_mm512_setzero_ps (), a, d3), acc[M_I][3 ]); \
216- } while (0 )
248+ const DequantizedWeight d0 (w0[g]);
249+ const DequantizedWeight d1 (w1[g]);
250+ const DequantizedWeight d2 (w2[g]);
251+ const DequantizedWeight d3 (w3[g]);
252+ const __m512 sv0 = _mm512_set1_ps (s0[g]);
253+ const __m512 sv1 = _mm512_set1_ps (s1[g]);
254+ const __m512 sv2 = _mm512_set1_ps (s2[g]);
255+ const __m512 sv3 = _mm512_set1_ps (s3[g]);
256+
257+ #define V_FMA_ROW (M_I ) \
258+ do { \
259+ const ActivationBF16 a (a_rows[M_I][g]); \
260+ acc[M_I][0 ] = _mm512_fmadd_ps (sv0, mxfp4_dot_bf16 (d0, a), acc[M_I][0 ]); \
261+ acc[M_I][1 ] = _mm512_fmadd_ps (sv1, mxfp4_dot_bf16 (d1, a), acc[M_I][1 ]); \
262+ acc[M_I][2 ] = _mm512_fmadd_ps (sv2, mxfp4_dot_bf16 (d2, a), acc[M_I][2 ]); \
263+ acc[M_I][3 ] = _mm512_fmadd_ps (sv3, mxfp4_dot_bf16 (d3, a), acc[M_I][3 ]); \
264+ } while (0 )
217265 V_FMA_ROW (0 );
218266 V_FMA_ROW (1 );
219267 V_FMA_ROW (2 );
220268 V_FMA_ROW (3 );
221- #undef V_FMA_ROW
269+ #undef V_FMA_ROW
222270 }
223271 for (int i = 0 ; i < MB; i++) {
224272 float * c_row = bc->get_submat (m, n, m_pos + i, n_start);
@@ -233,11 +281,9 @@ struct GemmKernel224MXFP4SmallKGroup {
233281 float * c_row = bc->get_submat (m, n, m_pos + i, n_start);
234282 __m512 acc = _mm512_setzero_ps ();
235283 for (int g = 0 ; g < kg_count; g++) {
236- acc = _mm512_fmadd_ps (_mm512_set1_ps (s[g]),
237- _mm512_dpbf16_ps (_mm512_setzero_ps (),
238- a_rows[i][g],
239- (__m512bh)mxfp4_to_bf16_32 (w[g])),
240- acc);
284+ const ActivationBF16 a (a_rows[i][g]);
285+ const DequantizedWeight d (w[g]);
286+ acc = _mm512_fmadd_ps (_mm512_set1_ps (s[g]), mxfp4_dot_bf16 (d, a), acc);
241287 }
242288 c_row[n_pos - n_start] = _mm512_reduce_add_ps (acc);
243289 }
@@ -257,18 +303,17 @@ struct GemmKernel224MXFP4SmallKGroup {
257303 const float * s1 = bb->get_scale (n, n_pos + 1 , k, 0 );
258304 const float * s2 = bb->get_scale (n, n_pos + 2 , k, 0 );
259305 const float * s3 = bb->get_scale (n, n_pos + 3 , k, 0 );
260- __m512 a0 = _mm512_setzero_ps (), a1 = _mm512_setzero_ps (),
261- a2 = _mm512_setzero_ps (), a3 = _mm512_setzero_ps ();
306+ __m512 a0 = _mm512_setzero_ps (), a1 = _mm512_setzero_ps (), a2 = _mm512_setzero_ps (), a3 = _mm512_setzero_ps ();
262307 for (int g = 0 ; g < kg_count; g++) {
263- const __m512bh a = a_row[g];
264- a0 = _mm512_fmadd_ps ( _mm512_set1_ps (s0 [g]),
265- _mm512_dpbf16_ps ( _mm512_setzero_ps (), a, (__m512bh) mxfp4_to_bf16_32 (w0 [g])), a0 );
266- a1 = _mm512_fmadd_ps ( _mm512_set1_ps (s1 [g]),
267- _mm512_dpbf16_ps ( _mm512_setzero_ps (), a, (__m512bh) mxfp4_to_bf16_32 (w1 [g])), a1 );
268- a2 = _mm512_fmadd_ps (_mm512_set1_ps (s2 [g]),
269- _mm512_dpbf16_ps ( _mm512_setzero_ps ( ), a, (__m512bh) mxfp4_to_bf16_32 (w2[g])), a2 );
270- a3 = _mm512_fmadd_ps (_mm512_set1_ps (s3 [g]),
271- _mm512_dpbf16_ps ( _mm512_setzero_ps ( ), a, (__m512bh) mxfp4_to_bf16_32 (w3[g]) ), a3);
308+ const ActivationBF16 a ( a_row[g]) ;
309+ const DequantizedWeight d0 (w0 [g]);
310+ const DequantizedWeight d1 (w1 [g]);
311+ const DequantizedWeight d2 (w2 [g]);
312+ const DequantizedWeight d3 (w3 [g]);
313+ a0 = _mm512_fmadd_ps (_mm512_set1_ps (s0 [g]), mxfp4_dot_bf16 (d0, a), a0);
314+ a1 = _mm512_fmadd_ps ( _mm512_set1_ps (s1[g] ), mxfp4_dot_bf16 (d1, a), a1 );
315+ a2 = _mm512_fmadd_ps (_mm512_set1_ps (s2 [g]), mxfp4_dot_bf16 (d2, a), a2);
316+ a3 = _mm512_fmadd_ps ( _mm512_set1_ps (s3[g] ), mxfp4_dot_bf16 (d3, a ), a3);
272317 }
273318 reduce4 (a0, a1, a2, a3, c_row + (n_pos - n_start));
274319 }
@@ -277,11 +322,9 @@ struct GemmKernel224MXFP4SmallKGroup {
277322 const float * s = bb->get_scale (n, n_pos, k, 0 );
278323 __m512 acc = _mm512_setzero_ps ();
279324 for (int g = 0 ; g < kg_count; g++) {
280- acc = _mm512_fmadd_ps (_mm512_set1_ps (s[g]),
281- _mm512_dpbf16_ps (_mm512_setzero_ps (),
282- a_row[g],
283- (__m512bh)mxfp4_to_bf16_32 (w[g])),
284- acc);
325+ const ActivationBF16 a (a_row[g]);
326+ const DequantizedWeight d (w[g]);
327+ acc = _mm512_fmadd_ps (_mm512_set1_ps (s[g]), mxfp4_dot_bf16 (d, a), acc);
285328 }
286329 c_row[n_pos - n_start] = _mm512_reduce_add_ps (acc);
287330 }
0 commit comments