Skip to content

Commit 2974f4b

Browse files
committed
ggml-ve: vectorize the YaRN NeoX rope kernel (prompt-eval 1.48x)
ftrace on a 330-token prompt showed the rope was 32% of prompt-eval and fully SCALAR (V.OP 0%): the per-element ve_rope_yarn() function call blocked vectorization, and the theta *= theta_scale recurrence serialized the inner loop (~30M scalar calls). Restructure ve_rope_neox_hbm_omp_nocache: precompute theta_scale^i once (breaks the recurrence so theta[i] is independent), fold YaRN into a branchless form (ramp_mix == 0 when ext_factor == 0, magnitude scale precomputed), and drop the function call. The inner loop is then call-free and branch-free, so NCC vectorizes cos/sin via libsysve. Result: the rope goes V.OP 0% -> 98.4%, from ~32% of prompt-eval to 0.1%; interpreter prompt-eval 8.54 -> 12.67 tok/s (1.48x) on Bonsai-8B-VEBP, output still correct (YaRN math unchanged, just reorganized). Helps every prompt and will carry into the compiled N>1 path. Not pushed.
1 parent 0fb4f29 commit 2974f4b

1 file changed

Lines changed: 24 additions & 9 deletions

File tree

ggml/src/ggml-ve/kernels-veda/ve_sgemv_wrapper.c

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11647,6 +11647,18 @@ uint64_t ve_rope_neox_hbm_omp_nocache(VEDAdeviceptr y_hbm,
1164711647
float corr_lo, corr_hi;
1164811648
ve_rope_corr_dims(nd, (int)n_ctx_orig, freq_base, beta_fast, beta_slow, &corr_lo, &corr_hi);
1164911649

11650+
/* Vectorisation prep. The old per-element ve_rope_yarn() call + the
11651+
* theta *= theta_scale recurrence forced the inner loop fully SCALAR
11652+
* (V.OP 0%, ~30M calls = 32% of prompt-eval). Precompute theta_scale^i
11653+
* once (breaks the recurrence -> theta[i] independent) and fold YaRN into
11654+
* a branchless form (ramp_mix == 0 when ext_factor == 0), so the inner
11655+
* loop is call-free + branch-free and NCC vectorises cos/sin (libsysve). */
11656+
float ts_pow[512];
11657+
{ float p = 1.0f; for (int i = 0; i < half_nd && i < 512; i++) { ts_pow[i] = p; p *= theta_scale; } }
11658+
const float ms_eff = (ext_factor != 0.0f) ? mscale * (1.0f + 0.1f * logf(1.0f / freq_scale)) : mscale;
11659+
float denom = corr_hi - corr_lo; if (denom < 0.001f) denom = 0.001f;
11660+
const float inv_denom = 1.0f / denom;
11661+
1165011662
int total_rows = batch * ctx * heads;
1165111663

1165211664
#pragma omp parallel for
@@ -11656,28 +11668,31 @@ uint64_t ve_rope_neox_hbm_omp_nocache(VEDAdeviceptr y_hbm,
1165611668
int i2 = rem / heads;
1165711669
int i1 = rem % heads;
1165811670

11659-
float theta_extrap = (float)pos[i2]; /* NOT freq-scaled — YaRN needs extrap */
11671+
const float posf = (float)pos[i2];
1166011672

1166111673
size_t src_offset = i3 * nb3 + i2 * nb2 + i1 * nb1;
1166211674
size_t dst_offset = src_offset;
1166311675

1166411676
const float* src = (const float*)((const char*)x + src_offset);
1166511677
float* dst = (float*)((char*)y + dst_offset);
1166611678

11667-
/* NeoX style: rotate first half with second half */
11679+
/* NeoX style: rotate first half with second half. Branchless YaRN. */
1166811680
#pragma _NEC ivdep
1166911681
for (int i0 = 0; i0 < half_nd; i0++) {
11670-
float cos_val, sin_val;
11671-
ve_rope_yarn(theta_extrap, freq_scale, corr_lo, corr_hi, i0,
11672-
ext_factor, mscale, &cos_val, &sin_val);
11682+
float theta_extrap = posf * ts_pow[i0];
11683+
float theta_interp = freq_scale * theta_extrap;
11684+
float yv = ((float) i0 - corr_lo) * inv_denom;
11685+
yv = yv < 0.0f ? 0.0f : (yv > 1.0f ? 1.0f : yv);
11686+
float rmix = (1.0f - yv) * ext_factor; /* 0 when ext_factor==0 */
11687+
float theta = theta_interp * (1.0f - rmix) + theta_extrap * rmix;
11688+
float c = cosf(theta) * ms_eff;
11689+
float s = sinf(theta) * ms_eff;
1167311690

1167411691
float x0 = src[i0];
1167511692
float x1 = src[i0 + half_nd];
1167611693

11677-
dst[i0] = x0 * cos_val - x1 * sin_val;
11678-
dst[i0 + half_nd] = x0 * sin_val + x1 * cos_val;
11679-
11680-
theta_extrap *= theta_scale;
11694+
dst[i0] = x0 * c - x1 * s;
11695+
dst[i0 + half_nd] = x0 * s + x1 * c;
1168111696
}
1168211697

1168311698
/* Copy remaining elements unchanged */

0 commit comments

Comments
 (0)