33#include " ggml-cpu.h"
44#include " ggml-impl.h"
55#include " binary-ops.h"
6+ #include " simd-gemm.h"
67#include " ggml.h"
78#include " unary-ops.h"
89#include " vec.h"
@@ -8326,10 +8327,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
83268327 GGML_ASSERT (k->type == v->type );
83278328 const ggml_type kv_type = k->type ;
83288329
8329- const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu (kv_type);
8330- const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float ;
8331- const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot ;
8332- const size_t kv_type_size = ggml_type_size (kv_type);
83338330
83348331 // broadcast factors
83358332 const int64_t rk2 = neq2/nek2;
@@ -8360,8 +8357,9 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
83608357
83618358 static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
83628359 static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV ;
8363-
8364- GGML_ASSERT (nek1 % KV_TILE_SZ == 0 && " KV sequence length must be divisible by KV_TILE_SZ" );
8360+ #ifdef GGML_SIMD
8361+ GGML_ASSERT (DV % GGML_F32_EPR == 0 );
8362+ #endif
83658363
83668364 int ir = ir0;
83678365 while (ir < ir1) {
@@ -8389,18 +8387,20 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
83898387 }
83908388
83918389 // Per-thread scratch layout:
8392- // Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
8390+ // Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar )
83938391 // KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
83948392 // mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
83958393 // VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
8396- // V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion)
8397- float * base = (float *) params->wdata + ith*(Q_TILE_SZ *DK + 2 *Q_TILE_SZ *KV_TILE_SZ + Q_TILE_SZ *DV + KV_TILE_SZ *DV + CACHE_LINE_SIZE_F32 );
8394+ // V32: KV_TILE_SZ * DV (F32 buffer for V tile)
8395+ // K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
8396+ float * base = (float *) params->wdata + ith*(Q_TILE_SZ *DK + 2 *Q_TILE_SZ *KV_TILE_SZ + Q_TILE_SZ *DV + KV_TILE_SZ *DV + KV_TILE_SZ *DK + CACHE_LINE_SIZE_F32 );
83988397
83998398 void * Q_q = base;
84008399 float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof (float ));
84018400 float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ ;
84028401 float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ ;
8403- float * V32 = VKQ32 + Q_TILE_SZ * DV ; // F32 buffer for V tile
8402+ float * V32 = VKQ32 + Q_TILE_SZ * DV ;
8403+ float * K_f32 = V32 + KV_TILE_SZ * DV ;
84048404
84058405 memset (VKQ32 , 0 , Q_TILE_SZ * DV * sizeof (float ));
84068406 memset (mask32, 0 , Q_TILE_SZ * KV_TILE_SZ * sizeof (float ));
@@ -8413,42 +8413,69 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
84138413 const int iv3 = iq3 / rv3;
84148414 const int iv2 = iq2 / rv2;
84158415
8416- for (int tq = 0 ; tq < tile_rows; tq++) {
8417- const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
8418- kv_from_float (pq, (char *)Q_q + tq * DK * kv_type_size, DK );
8419- }
8420- // Zero-pad remaining rows
8421- for (int tq = tile_rows; tq < Q_TILE_SZ ; tq++) {
8422- memset ((char *)Q_q + tq * DK * kv_type_size, 0 , DK * kv_type_size);
8416+ {
8417+ float * Q_f32 = (float *)Q_q;
8418+ for (int tq = 0 ; tq < tile_rows; tq++) {
8419+ const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
8420+ memcpy (Q_f32 + tq * DK , pq, DK * sizeof (float ));
8421+ }
8422+ for (int tq = tile_rows; tq < Q_TILE_SZ ; tq++) {
8423+ memset (Q_f32 + tq * DK , 0 , DK * sizeof (float ));
8424+ }
84238425 }
84248426
84258427 for (int64_t ic = 0 ; ic < nek1; ic += KV_TILE_SZ ) {
8428+ const int kv_tile = (int )std::min ((int64_t )KV_TILE_SZ , nek1 - ic);
84268429
84278430 // skip the tile entirely if all the masks are -inf
84288431 if (mask) {
84298432 bool can_skip = true ;
84308433 for (int tq = 0 ; tq < tile_rows; tq++) {
84318434 const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb [1 ] + (iq2%mask->ne [2 ])*mask->nb [2 ] + (iq3%mask->ne [3 ])*mask->nb [3 ]);
8432- for (int tk = 0 ; tk < KV_TILE_SZ ; tk++) {
8435+ for (int tk = 0 ; tk < kv_tile ; tk++) {
84338436 mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32 (mp_row[ic + tk]);
84348437 if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY ) {
84358438 can_skip = false ;
84368439 }
84378440 }
8441+ // Pad remaining mask entries with -inf
8442+ for (int tk = kv_tile; tk < KV_TILE_SZ ; tk++) {
8443+ mask32[tq * KV_TILE_SZ + tk] = -INFINITY ;
8444+ }
84388445 }
84398446
84408447 if (can_skip) {
84418448 continue ;
84428449 }
84438450 }
84448451
8445- for (int tq = 0 ; tq < Q_TILE_SZ ; tq++) {
8446- const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
8447- for (int tk = 0 ; tk < KV_TILE_SZ ; tk++) {
8448- const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
8449- float s;
8450- kv_vec_dot (DK , &s, 0 , k_row, 0 , q_row, 0 , 1 );
8451- KQ [tq * KV_TILE_SZ + tk] = s * scale;
8452+ // Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
8453+ // Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
8454+ memset (K_f32, 0 , DK * KV_TILE_SZ * sizeof (float ));
8455+ for (int tk = 0 ; tk < kv_tile; tk++) {
8456+ const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
8457+ if (kv_type == GGML_TYPE_F16 ) {
8458+ const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
8459+ for (int64_t dk = 0 ; dk < DK ; dk++) {
8460+ K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32 (k_f16[dk]);
8461+ }
8462+ } else {
8463+ const float * k_f32_src = (const float *)k_data;
8464+ for (int64_t dk = 0 ; dk < DK ; dk++) {
8465+ K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
8466+ }
8467+ }
8468+ }
8469+ memset (KQ , 0 , Q_TILE_SZ * KV_TILE_SZ * sizeof (float ));
8470+ simd_gemm (KQ , (const float *)Q_q, K_f32, Q_TILE_SZ , DK , KV_TILE_SZ );
8471+ ggml_vec_scale_f32 (Q_TILE_SZ * KV_TILE_SZ , KQ , scale);
8472+
8473+ // Set padded KQ entries to -inf so softmax gives them zero weight
8474+ if (kv_tile < KV_TILE_SZ ) {
8475+ for (int tq = 0 ; tq < Q_TILE_SZ ; tq++) {
8476+ for (int tk = kv_tile; tk < KV_TILE_SZ ; tk++) {
8477+ KQ [tq * KV_TILE_SZ + tk] = -INFINITY ;
8478+ }
84528479 }
84538480 }
84548481
@@ -8488,33 +8515,23 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
84888515 S[tq] += ggml_vec_soft_max_f32 (KV_TILE_SZ , kq_row, kq_row, Mnew);
84898516 }
84908517
8491- // Convert V tile to F32 first (if F16), then do MAD
8492- // On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster.
8493- // TODO: on ARM, native f16 should be faster
8494- if (kv_type == GGML_TYPE_F16 ) {
8495- for (int tk = 0 ; tk < KV_TILE_SZ ; tk++) {
8496- const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
8497- ggml_fp16_to_fp32_row (v_row, V32 + tk * DV , DV );
8498- }
8499- for (int tq = 0 ; tq < Q_TILE_SZ ; tq++) {
8500- if (skip[tq]) continue ;
8501- float * vkq_row = VKQ32 + tq * DV ;
8502- for (int tk = 0 ; tk < KV_TILE_SZ ; tk++) {
8503- const float p = KQ [tq * KV_TILE_SZ + tk];
8504- ggml_vec_mad_f32 (DV , vkq_row, V32 + tk * DV , p);
8505- }
8518+ // V accumulation: VKQ32 += softmax(KQ) * V
8519+ // Pack V tile to contiguous F32, zero-padded
8520+ memset (V32 , 0 , KV_TILE_SZ * DV * sizeof (float ));
8521+ for (int tk = 0 ; tk < kv_tile; tk++) {
8522+ const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
8523+ if (kv_type == GGML_TYPE_F16 ) {
8524+ ggml_fp16_to_fp32_row ((const ggml_fp16_t *)v_data, V32 + tk * DV , DV );
8525+ } else {
8526+ memcpy (V32 + tk * DV , v_data, DV * sizeof (float ));
85068527 }
8507- } else {
8508- for (int tq = 0 ; tq < Q_TILE_SZ ; tq++) {
8509- if (skip[tq]) continue ;
8510- float * vkq_row = VKQ32 + tq * DV ;
8511- for (int tk = 0 ; tk < KV_TILE_SZ ; tk++) {
8512- const float p = KQ [tq * KV_TILE_SZ + tk];
8513- const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
8514- ggml_vec_mad_f32 (DV , vkq_row, v_row, p);
8515- }
8528+ }
8529+ for (int tq = 0 ; tq < Q_TILE_SZ ; tq++) {
8530+ if (skip[tq]) {
8531+ memset (KQ + tq * KV_TILE_SZ , 0 , KV_TILE_SZ * sizeof (float ));
85168532 }
85178533 }
8534+ simd_gemm (VKQ32 , KQ , V32 , Q_TILE_SZ , KV_TILE_SZ , DV );
85188535 }
85198536
85208537 // sinks (apply only to valid rows in the tile)
@@ -8731,13 +8748,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
87318748
87328749 const int64_t dr = (nr + nchunk - 1 ) / nchunk;
87338750
8734- static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV ;
87358751 static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
87368752 const bool use_tiled = !use_ref &&
87378753 (q->type == GGML_TYPE_F32 &&
87388754 kv_is_f32_or_f16 &&
87398755 k->type == v->type &&
8740- nek1 % KV_TILE_SZ == 0 &&
87418756 neq1 >= Q_TILE_SZ );
87428757
87438758 int current_chunk = ith;
0 commit comments