33
44#include <stddef.h>
55#include <stdint.h>
6+ #include <stdlib.h>
7+ #include <math.h>
68
79#define Q4K_BLOCK_SIZE 256
810#define Q4K_SUB_BLOCK_SIZE 32
@@ -60,21 +62,48 @@ static inline void skainet_q4k_decode_scales(
6062 }
6163}
6264
65+ /*
66+ * Quantize one 256-float input block to symmetric int8 (Q8) with a single
67+ * per-block scale d_in = maxabs/127, q8[i] = round(in[i]/d_in). Returns d_in
68+ * (0 if the block is all-zero, with q8 zeroed). Mirrors ggml's block_q8_K
69+ * activation quantization — the source of the (small, well-understood) error
70+ * vs the exact float kernel, and what unlocks the int8 dot-product fast path.
71+ */
72+ static inline float skainet_q8_quantize_block (const float * SKAINET_RESTRICT in , int8_t * SKAINET_RESTRICT q8 ) {
73+ float maxabs = 0.0f ;
74+ for (int i = 0 ; i < Q4K_BLOCK_SIZE ; ++ i ) {
75+ const float a = fabsf (in [i ]);
76+ if (a > maxabs ) maxabs = a ;
77+ }
78+ if (maxabs == 0.0f ) {
79+ for (int i = 0 ; i < Q4K_BLOCK_SIZE ; ++ i ) q8 [i ] = 0 ;
80+ return 0.0f ;
81+ }
82+ const float d_in = maxabs / 127.0f ;
83+ const float inv = 127.0f / maxabs ;
84+ for (int i = 0 ; i < Q4K_BLOCK_SIZE ; ++ i ) {
85+ int v = (int ) lrintf (in [i ] * inv );
86+ if (v > 127 ) v = 127 ; else if (v < -127 ) v = -127 ;
87+ q8 [i ] = (int8_t ) v ;
88+ }
89+ return d_in ;
90+ }
91+
6392/*
6493 * Native Q4_K matrix-vector multiply matching the
65- * sk.ainet.backend.api.kernel.Q4KMatmulKernel SPI contract. Single
66- * input row times an `outputDim x inputDim` Q4_K-packed weight tensor
67- * laid out (blockIdx * outputDim + o) * 144 bytes.
68- *
69- * Lazy-dmin pattern: per sub-block accumulate
70- * codeSum[s] = sum_i input[i] * code[i]
71- * inputSum[s] = sum_i input[i]
72- * and combine once via
73- * acc += d * scaleIdx[s] * codeSum[s] - dMin * minIdx[s] * inputSum[s]
94+ * sk.ainet.backend.api.kernel.Q4KMatmulKernel SPI contract. Single input row
95+ * times an `outputDim x inputDim` Q4_K-packed weight laid out
96+ * (blockIdx * outputDim + o) * 144 bytes.
7497 *
75- * Scalar single-threaded for PR 2; the tight inner loop is
76- * straight-line FP arithmetic so -O3 auto-vectorizes the
77- * codeSum/inputSum accumulators on AVX2/NEON.
98+ * Fused int8 dot path (ggml-style): the input row is quantized to Q8 ONCE per
99+ * 256-block (reused across all output rows), then each weight sub-block is an
100+ * int8 dot-product against the Q8 activation:
101+ * acc += d_in[b] * ( d * Σ_s scaleIdx[s]*intDot[s] - dMin * Σ_s minIdx[s]*intSum[s] )
102+ * where intDot[s] = Σ q8[i]*code[i] and intSum[s] = Σ q8[i] over the sub-block.
103+ * On AArch64 with dotprod (asimddp) the inner dot uses vdotq_s32 (16 int8 MACs
104+ * per instruction); otherwise a scalar integer fallback (auto-vectorized).
105+ * The index mapping (groups, lo/hi sub-blocks, input alignment) is identical to
106+ * the previous float kernel, which was parity-checked against Panama.
78107 */
79108SKAINET_API void skainet_q4k_matmul (
80109 const float * SKAINET_RESTRICT input ,
@@ -92,86 +121,102 @@ SKAINET_API void skainet_q4k_matmul(
92121 const float * in_base = input + input_offset ;
93122 float * out_base = output + output_offset ;
94123
124+ /* Pre-quantize the whole input row to Q8 once (reused across all o). */
125+ int8_t * q8 = (int8_t * ) malloc ((size_t ) input_dim * sizeof (int8_t ));
126+ float * d_in = (float * ) malloc ((size_t ) blocks_per_input_dim * sizeof (float ));
127+ if (q8 == NULL || d_in == NULL ) { free (q8 ); free (d_in ); return ; }
128+ for (int32_t b = 0 ; b < blocks_per_input_dim ; ++ b ) {
129+ d_in [b ] = skainet_q8_quantize_block (in_base + (size_t ) b * Q4K_BLOCK_SIZE ,
130+ q8 + (size_t ) b * Q4K_BLOCK_SIZE );
131+ }
132+
95133 int scale_idx [Q4K_SUB_BLOCKS ];
96134 int min_idx [Q4K_SUB_BLOCKS ];
97135
98- for (int32_t o = 0 ; o < output_dim ; ++ o ) {
99- float acc = 0.0f ;
100-
101- for (int32_t block_idx = 0 ; block_idx < blocks_per_input_dim ; ++ block_idx ) {
102- const uint8_t * block = weight + weight_byte_offset
103- + (size_t )(block_idx * output_dim + o ) * Q4K_BYTES_PER_BLOCK ;
104-
105- /* d, dMin (FP16 LE -> FP32). */
136+ /*
137+ * Loop order: block OUTER, output row INNER. The weight is packed
138+ * block-major — (blockIdx * output_dim + o) * 144 — so for a fixed block,
139+ * consecutive `o` are exactly 144 bytes apart: the weight bytes are read
140+ * strictly sequentially (prefetch- and cache-line-friendly). The reverse
141+ * order (o outer) strides output_dim*144 bytes per step (~295 KB on the
142+ * down-proj), which on an in-order A55 with small caches makes every weight
143+ * read a cold miss and dominates runtime regardless of inner-loop compute.
144+ * out_base[o] is accumulated across blocks (output_dim*4 bytes stays hot in
145+ * cache); the accumulation order over blocks is unchanged, so this is
146+ * numerically identical to the o-outer form.
147+ */
148+ for (int32_t o = 0 ; o < output_dim ; ++ o ) out_base [o ] = 0.0f ;
149+
150+ for (int32_t block_idx = 0 ; block_idx < blocks_per_input_dim ; ++ block_idx ) {
151+ const int8_t * q8_block = q8 + (size_t ) block_idx * Q4K_BLOCK_SIZE ;
152+ const float di = d_in [block_idx ];
153+ const uint8_t * block = weight + weight_byte_offset
154+ + (size_t )(block_idx * output_dim ) * Q4K_BYTES_PER_BLOCK ;
155+
156+ for (int32_t o = 0 ; o < output_dim ; ++ o , block += Q4K_BYTES_PER_BLOCK ) {
106157 const uint16_t d_bits = (uint16_t ) block [0 ] | ((uint16_t ) block [1 ] << 8 );
107158 const uint16_t d_min_bits = (uint16_t ) block [2 ] | ((uint16_t ) block [3 ] << 8 );
108159 const float d = skainet_half_to_float (d_bits );
109160 const float d_min = skainet_half_to_float (d_min_bits );
110161
111- /* 12 bytes of packed (scaleIdx, minIdx) -> 8 ints each. */
112162 skainet_q4k_decode_scales (block + 4 , scale_idx , min_idx );
113163
114164 const uint8_t * qs = block + 16 ;
115- const float * in_block = in_base + (size_t ) block_idx * Q4K_BLOCK_SIZE ;
116165
117- /* 4 strided qs groups; group j carries sub-blocks 2j (lo) and 2j+1 (hi). */
166+ int64_t block_scale_dot = 0 ;
167+ int64_t block_min_sum = 0 ;
168+
118169 for (int group_j = 0 ; group_j < 4 ; ++ group_j ) {
119- const uint8_t * qs_group = qs + group_j * Q4K_SUB_BLOCK_SIZE ;
170+ const uint8_t * qs_group = qs + group_j * Q4K_SUB_BLOCK_SIZE ;
120171 const int sb_lo = 2 * group_j ;
121172 const int sb_hi = sb_lo + 1 ;
122- const float * in_lo = in_block + sb_lo * Q4K_SUB_BLOCK_SIZE ;
123- const float * in_hi = in_block + sb_hi * Q4K_SUB_BLOCK_SIZE ;
173+ const int8_t * q8_lo = q8_block + sb_lo * Q4K_SUB_BLOCK_SIZE ;
174+ const int8_t * q8_hi = q8_block + sb_hi * Q4K_SUB_BLOCK_SIZE ;
124175
125- float code_sum_lo = 0.0f , input_sum_lo = 0.0f ;
126- float code_sum_hi = 0.0f , input_sum_hi = 0.0f ;
176+ int32_t dot_lo = 0 , sum_lo = 0 , dot_hi = 0 , sum_hi = 0 ;
127177
128- #ifdef SKAINET_HAVE_NEON
129- float32x4_t cacc_lo = vdupq_n_f32 ( 0.0f ), iacc_lo = vdupq_n_f32 ( 0.0f );
130- float32x4_t cacc_hi = vdupq_n_f32 ( 0.0f ), iacc_hi = vdupq_n_f32 ( 0.0f ) ;
178+ #ifdef SKAINET_HAVE_DOTPROD
179+ int32x4_t acc_dot_lo = vdupq_n_s32 ( 0 ), acc_dot_hi = vdupq_n_s32 ( 0 );
180+ int32_t acc_sum_lo = 0 , acc_sum_hi = 0 ;
131181 for (int off = 0 ; off < Q4K_SUB_BLOCK_SIZE ; off += 16 ) {
132182 const uint8x16_t packed = vld1q_u8 (qs_group + off );
133- const uint8x16_t lo_nib = vandq_u8 (packed , vdupq_n_u8 (0x0F ));
134- const uint8x16_t hi_nib = vshrq_n_u8 (packed , 4 );
135- float32x4_t cl [4 ], ch [4 ];
136- skainet_neon_u8x16_to_f32x4x4 (lo_nib , cl );
137- skainet_neon_u8x16_to_f32x4x4 (hi_nib , ch );
138- for (int q = 0 ; q < 4 ; ++ q ) {
139- const float32x4_t v_lo = vld1q_f32 (in_lo + off + q * 4 );
140- const float32x4_t v_hi = vld1q_f32 (in_hi + off + q * 4 );
141- cacc_lo = vfmaq_f32 (cacc_lo , v_lo , cl [q ]);
142- iacc_lo = vaddq_f32 (iacc_lo , v_lo );
143- cacc_hi = vfmaq_f32 (cacc_hi , v_hi , ch [q ]);
144- iacc_hi = vaddq_f32 (iacc_hi , v_hi );
145- }
183+ const int8x16_t code_lo = vreinterpretq_s8_u8 (vandq_u8 (packed , vdupq_n_u8 (0x0F )));
184+ const int8x16_t code_hi = vreinterpretq_s8_u8 (vshrq_n_u8 (packed , 4 ));
185+ const int8x16_t a_lo = vld1q_s8 (q8_lo + off );
186+ const int8x16_t a_hi = vld1q_s8 (q8_hi + off );
187+ acc_dot_lo = vdotq_s32 (acc_dot_lo , code_lo , a_lo );
188+ acc_dot_hi = vdotq_s32 (acc_dot_hi , code_hi , a_hi );
189+ acc_sum_lo += vaddlvq_s8 (a_lo );
190+ acc_sum_hi += vaddlvq_s8 (a_hi );
146191 }
147- code_sum_lo = skainet_neon_hadd_f32 ( cacc_lo );
148- input_sum_lo = skainet_neon_hadd_f32 ( iacc_lo );
149- code_sum_hi = skainet_neon_hadd_f32 ( cacc_hi ) ;
150- input_sum_hi = skainet_neon_hadd_f32 ( iacc_hi ) ;
192+ dot_lo = vaddvq_s32 ( acc_dot_lo );
193+ dot_hi = vaddvq_s32 ( acc_dot_hi );
194+ sum_lo = acc_sum_lo ;
195+ sum_hi = acc_sum_hi ;
151196#else
152- /* 32 iterations — auto-vectorizes cleanly under -O3. */
153197 for (int i = 0 ; i < Q4K_SUB_BLOCK_SIZE ; ++ i ) {
154- const uint8_t b = qs_group [i ];
155- const float code_lo = (float )( b & 0x0F );
156- const float code_hi = (float )( b >> 4 );
157- const float v_lo = in_lo [i ];
158- const float v_hi = in_hi [i ];
159- code_sum_lo += v_lo * code_lo ;
160- input_sum_lo += v_lo ;
161- code_sum_hi += v_hi * code_hi ;
162- input_sum_hi += v_hi ;
198+ const uint8_t pb = qs_group [i ];
199+ const int code_lo = (int )( pb & 0x0F );
200+ const int code_hi = (int )( pb >> 4 );
201+ const int a_lo = ( int ) q8_lo [i ];
202+ const int a_hi = ( int ) q8_hi [i ];
203+ dot_lo += a_lo * code_lo ;
204+ sum_lo += a_lo ;
205+ dot_hi += a_hi * code_hi ;
206+ sum_hi += a_hi ;
163207 }
164208#endif
165209
166- const float scale_lo = d * (float ) scale_idx [sb_lo ];
167- const float offset_lo = d_min * (float ) min_idx [sb_lo ];
168- const float scale_hi = d * (float ) scale_idx [sb_hi ];
169- const float offset_hi = d_min * (float ) min_idx [sb_hi ];
170- acc += code_sum_lo * scale_lo - input_sum_lo * offset_lo ;
171- acc += code_sum_hi * scale_hi - input_sum_hi * offset_hi ;
210+ block_scale_dot += (int64_t ) scale_idx [sb_lo ] * dot_lo
211+ + (int64_t ) scale_idx [sb_hi ] * dot_hi ;
212+ block_min_sum += (int64_t ) min_idx [sb_lo ] * sum_lo
213+ + (int64_t ) min_idx [sb_hi ] * sum_hi ;
172214 }
173- }
174215
175- out_base [o ] = acc ;
216+ out_base [o ] += di * (d * (float ) block_scale_dot - d_min * (float ) block_min_sum );
217+ }
176218 }
219+
220+ free (q8 );
221+ free (d_in );
177222}
0 commit comments