@@ -1179,72 +1179,148 @@ static float fused_dot_iq2_xxs(const void* row, const float* x, int n) {
 }
 
 #if TQ_HAS_NEON
+
+/* Vectorized sign application helper: given 8 grid bytes and an 8-bit sign mask,
+ * produce a signed int8x8 with the negative signs applied. Uses a NEON bit test:
+ * broadcast the sign byte, AND with per-lane bit masks, compare to produce a
+ * negation mask, then apply it via (grid ^ neg) - neg (conditional negate). */
+static const uint8_t iq2_sign_bit_masks[8] = {1, 2, 4, 8, 16, 32, 64, 128};
+
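+/* Worked example (illustrative): signs = 0x85 = 0b10000101 flags lanes 0, 2 and 7.
+ * vtst sets exactly those lanes of the mask to 0xFF, so for a grid byte g:
+ *   (g ^ 0xFF) - 0xFF = ~g + 1 = -g   (two's-complement negate)
+ *   (g ^ 0x00) - 0x00 =  g            (unchanged)
+ * This assumes the grid magnitudes fit in int8, i.e. |g| <= 127. */
+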
 /* NEON-optimized fused IQ2_XXS dot product.
- * Processes 8 grid values at a time using vectorized sign application. */
+ * Optimizations over the baseline:
+ *   1. Vectorized sign expansion via a NEON bit test (replaces 8 scalar shifts)
+ *   2. Signs applied in the int8 domain before float conversion (fewer instructions)
+ *   3. Fully unrolled inner loop (4 groups per ib32)
+ *   4. Prefetch of the next block's weight data
+ *   5. Two-accumulator strategy to reduce FMA dependency chains */
 static float fused_dot_iq2_xxs_neon(const void* row, const float* x, int n) {
     const int nb = n / 256;
     const uint8_t* base = (const uint8_t*)row;
-    float32x4_t vtotal = vdupq_n_f32(0.0f);
+    float32x4_t vtotal0 = vdupq_n_f32(0.0f);
+
+    /* Preload sign bit masks into a NEON register */
+    const uint8x8_t vbit_masks = vld1_u8(iq2_sign_bit_masks);
 
     for (int b = 0; b < nb; b++) {
         const uint8_t* blk = base + b * 66;
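+        /* Block layout implied by the strides below: a 2-byte fp16 scale d,
+         * then 8 groups of 8 bytes, i.e. 66 bytes per 256-value block. */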
+
+        /* Prefetch next block */
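+        /* (__builtin_prefetch args: rw = 0 requests a read, locality = 3 asks
+         * for high temporal locality, i.e. keep the lines in cache.) */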
+        if (b + 1 < nb) {
+            __builtin_prefetch(blk + 66, 0, 3);
+            __builtin_prefetch(blk + 66 + 32, 0, 3);
+        }
+
         uint16_t d_raw;
         memcpy(&d_raw, blk, 2);
         const float d = fp16_to_fp32(d_raw);
-        const uint16_t* qs = (const uint16_t*)(blk + 2);
+        const uint8_t* qs_bytes = blk + 2;
         const float* xbase = x + b * 256;
 
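+        /* Each 8-byte group decodes as: the low word (aux32[0]) holds four grid
+         * indices, one per byte; the high word (aux32[1]) packs four 7-bit sign
+         * selectors in its low 28 bits plus a 4-bit scale in the top nibble. */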
         for (int ib32 = 0; ib32 < 8; ib32++) {
             uint32_t aux32[2];
-            memcpy(aux32, qs + 4 * ib32, 8);
+            memcpy(aux32, qs_bytes + 8 * ib32, 8);
             const uint8_t* aux8 = (const uint8_t*)aux32;
             const float db = d * (0.5f + (float)(aux32[1] >> 28)) * 0.25f;
             const float* xb = xbase + ib32 * 32;
 
-            float32x4_t vgroup = vdupq_n_f32(0.0f);
+            /* Accumulate across all 4 sub-groups before scaling by db.
+             * Use two accumulators to break FMA dependency chains. */
+            float32x4_t vacc0 = vdupq_n_f32(0.0f);
+            float32x4_t vacc1 = vdupq_n_f32(0.0f);
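+            /* Two independent chains typically let back-to-back FMAs issue
+             * without stalling on each other's results, since FMA latency
+             * spans several cycles. */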
 
-            for (int l = 0; l < 4; l++) {
-                const uint8_t* grid = (const uint8_t*)(iq2xxs_grid + aux8[l]);
-                const uint8_t signs = ksigns_iq2xs[(aux32[1] >> (7 * l)) & 127];
-                const float* xp = xb + l * 8;
+            /* --- Group 0 --- */
+            {
+                const uint8_t* grid = (const uint8_t*)(iq2xxs_grid + aux8[0]);
+                const uint8_t signs = ksigns_iq2xs[aux32[1] & 127];
+
+                uint8x8_t vgrid = vld1_u8(grid);
+                /* Vectorized sign expansion:
+                 * broadcast the sign byte to all lanes, AND with the bit masks;
+                 * compare != 0 produces 0xFF for negative lanes.
+                 * Then: signed = (grid ^ neg_mask) - neg_mask,
+                 * which is grid when neg_mask = 0x00, -grid when neg_mask = 0xFF. */
+                uint8x8_t vsign_bcast = vdup_n_u8(signs);
+                uint8x8_t vsign_bits = vtst_u8(vsign_bcast, vbit_masks);
+                /* vsign_bits is 0xFF where negative, 0x00 where positive */
+                int8x8_t vgrid_s = vreinterpret_s8_u8(vgrid);
+                int8x8_t vneg_mask = vreinterpret_s8_u8(vsign_bits);
+                int8x8_t vsigned = vsub_s8(veor_s8(vgrid_s, vneg_mask), vneg_mask);
+
+                /* Widen to int16, then int32, then float */
+                int16x8_t vs16 = vmovl_s8(vsigned);
+                float32x4_t vf_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vs16)));
+                float32x4_t vf_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vs16)));
+
+                vacc0 = vfmaq_f32(vacc0, vf_lo, vld1q_f32(xb));
+                vacc1 = vfmaq_f32(vacc1, vf_hi, vld1q_f32(xb + 4));
+            }
+
+            /* --- Group 1 --- */
+            {
+                const uint8_t* grid = (const uint8_t*)(iq2xxs_grid + aux8[1]);
+                const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7) & 127];
+
+                uint8x8_t vgrid = vld1_u8(grid);
+                uint8x8_t vsign_bcast = vdup_n_u8(signs);
+                uint8x8_t vsign_bits = vtst_u8(vsign_bcast, vbit_masks);
+                int8x8_t vgrid_s = vreinterpret_s8_u8(vgrid);
+                int8x8_t vneg_mask = vreinterpret_s8_u8(vsign_bits);
+                int8x8_t vsigned = vsub_s8(veor_s8(vgrid_s, vneg_mask), vneg_mask);
+
+                int16x8_t vs16 = vmovl_s8(vsigned);
+                float32x4_t vf_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vs16)));
+                float32x4_t vf_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vs16)));
+
+                vacc0 = vfmaq_f32(vacc0, vf_lo, vld1q_f32(xb + 8));
+                vacc1 = vfmaq_f32(vacc1, vf_hi, vld1q_f32(xb + 12));
+            }
+
+            /* --- Group 2 --- */
+            {
+                const uint8_t* grid = (const uint8_t*)(iq2xxs_grid + aux8[2]);
+                const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 14) & 127];
+
+                uint8x8_t vgrid = vld1_u8(grid);
+                uint8x8_t vsign_bcast = vdup_n_u8(signs);
+                uint8x8_t vsign_bits = vtst_u8(vsign_bcast, vbit_masks);
+                int8x8_t vgrid_s = vreinterpret_s8_u8(vgrid);
+                int8x8_t vneg_mask = vreinterpret_s8_u8(vsign_bits);
+                int8x8_t vsigned = vsub_s8(veor_s8(vgrid_s, vneg_mask), vneg_mask);
+
+                int16x8_t vs16 = vmovl_s8(vsigned);
+                float32x4_t vf_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vs16)));
+                float32x4_t vf_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vs16)));
+
+                vacc0 = vfmaq_f32(vacc0, vf_lo, vld1q_f32(xb + 16));
+                vacc1 = vfmaq_f32(vacc1, vf_hi, vld1q_f32(xb + 20));
+            }
+
+            /* --- Group 3 --- */
+            {
+                const uint8_t* grid = (const uint8_t*)(iq2xxs_grid + aux8[3]);
+                const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 21) & 127];
 
-                /* Load 8 grid bytes, expand to float */
                 uint8x8_t vgrid = vld1_u8(grid);
-                int16x8_t vgrid16 = vreinterpretq_s16_u16(vmovl_u8(vgrid));
-                int32x4_t vg_lo = vmovl_s16(vget_low_s16(vgrid16));
-                int32x4_t vg_hi = vmovl_s16(vget_high_s16(vgrid16));
-                float32x4_t vf_lo = vcvtq_f32_s32(vg_lo);
-                float32x4_t vf_hi = vcvtq_f32_s32(vg_hi);
-
-                /* Apply signs via float bit XOR: set sign bit where signs bit is 1.
-                 * Expand each sign bit to a 32-bit mask with only the float sign bit set. */
-                uint32x4_t sign_lo = {
-                    (uint32_t)((signs >> 0) & 1) << 31,
-                    (uint32_t)((signs >> 1) & 1) << 31,
-                    (uint32_t)((signs >> 2) & 1) << 31,
-                    (uint32_t)((signs >> 3) & 1) << 31
-                };
-                uint32x4_t sign_hi = {
-                    (uint32_t)((signs >> 4) & 1) << 31,
-                    (uint32_t)((signs >> 5) & 1) << 31,
-                    (uint32_t)((signs >> 6) & 1) << 31,
-                    (uint32_t)((signs >> 7) & 1) << 31
-                };
-                vf_lo = vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(vf_lo), sign_lo));
-                vf_hi = vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(vf_hi), sign_hi));
-
-                /* Dot with input */
-                float32x4_t vx_lo = vld1q_f32(xp);
-                float32x4_t vx_hi = vld1q_f32(xp + 4);
-
-                vgroup = vfmaq_f32(vgroup, vf_lo, vx_lo);
-                vgroup = vfmaq_f32(vgroup, vf_hi, vx_hi);
+                uint8x8_t vsign_bcast = vdup_n_u8(signs);
+                uint8x8_t vsign_bits = vtst_u8(vsign_bcast, vbit_masks);
+                int8x8_t vgrid_s = vreinterpret_s8_u8(vgrid);
+                int8x8_t vneg_mask = vreinterpret_s8_u8(vsign_bits);
+                int8x8_t vsigned = vsub_s8(veor_s8(vgrid_s, vneg_mask), vneg_mask);
+
+                int16x8_t vs16 = vmovl_s8(vsigned);
+                float32x4_t vf_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vs16)));
+                float32x4_t vf_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vs16)));
+
+                vacc0 = vfmaq_f32(vacc0, vf_lo, vld1q_f32(xb + 24));
+                vacc1 = vfmaq_f32(vacc1, vf_hi, vld1q_f32(xb + 28));
             }
-            /* Scale by db and accumulate */
-            vtotal = vfmaq_n_f32(vtotal, vgroup, db);
+
+            /* Combine accumulators, scale by db, accumulate to total */
+            float32x4_t vgroup = vaddq_f32(vacc0, vacc1);
+            vtotal0 = vfmaq_n_f32(vtotal0, vgroup, db);
         }
     }
-    return vaddvq_f32(vtotal);
+    return vaddvq_f32(vtotal0);
 }
 #endif /* TQ_HAS_NEON */
 