Skip to content

Commit 17d22a3

Browse files
authored
hexagon: add MROPE and IMROPE support in HTP rope op (ggml-org#23317)
1 parent 67ace02 commit 17d22a3

2 files changed

Lines changed: 98 additions & 19 deletions

File tree

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2661,7 +2661,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
26612661

26622662
int mode = op_params[2];
26632663

2664-
if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
2664+
if (mode == GGML_ROPE_TYPE_VISION) {
26652665
return false;
26662666
}
26672667
if (mode & 1) {

ggml/src/ggml-hexagon/htp/rope-ops.c

Lines changed: 97 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
#include "htp-ops.h"
1919
#include "htp-ops.h"
2020

21-
// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we can't include ggml.h
21+
// Redefined the rope type constants as we can't include ggml.h
2222
#define HTP_ROPE_TYPE_NORMAL 0
2323
#define HTP_ROPE_TYPE_NEOX 2
24+
#define HTP_ROPE_TYPE_MROPE 8
25+
#define HTP_ROPE_TYPE_IMROPE 40
2426

2527
#define HTP_ROPE_SPAD_NROWS 16
2628
#define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2)
@@ -82,6 +84,29 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
8284
return (1 - MIN(1, MAX(0, y)));
8385
}
8486

87+
// Write the rotation coefficients for a single dimension pair:
//   cache[i0]     = cos(angle) * magnitude
//   cache[i0 + 1] = sin(angle) * magnitude
//
// theta       : unscaled angle for this pair (used as the extrapolated angle)
// freq_scale  : multiplier that turns theta into the interpolated angle
// corr_dims   : two-element range handed to rope_yarn_ramp()
// i0          : even index of the pair inside cache
// ext_factor  : weight of the ramp blend; 0 disables the blend and the
//               magnitude correction entirely
// mscale      : base magnitude applied to both outputs
static inline void rope_yarn_one(float theta, float freq_scale, float * corr_dims,
                                 uint32_t i0, float ext_factor, float mscale,
                                 float * cache) {
    const float extrapolated = theta;
    const float interpolated = freq_scale * extrapolated;

    float angle     = interpolated;
    float magnitude = mscale;

    if (ext_factor != 0.0f) {
        // Blend interpolated and extrapolated angles according to the ramp.
        const float mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
        angle = interpolated * (1 - mix) + extrapolated * mix;

        // Magnitude correction that accompanies the interpolation.
        magnitude *= 1.0f + 0.1f * logf(1.0f / freq_scale);
    }

    float * out = cache + i0;
    out[0] = cosf(angle) * magnitude;
    out[1] = sinf(angle) * magnitude;
}
109+
85110
static void rope_cache_init(const float theta_base,
86111
const float freq_scale,
87112
const float * freq_factors,
@@ -96,26 +121,62 @@ static void rope_cache_init(const float theta_base,
96121

97122
for (uint32_t i0 = 0; i0 < ne0; i0 += 2) {
98123
const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;
124+
rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache);
99125

100-
float theta_extrap = theta / ff;
101-
102-
// Get n-d rotational scaling corrected for extrapolation
103-
float theta_interp = freq_scale * theta_extrap;
104-
float theta_final = theta_interp;
105-
float mscale_final = mscale;
126+
theta *= theta_scale;
127+
}
128+
}
106129

107-
if (ext_factor != 0.0f) {
108-
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
109-
theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
130+
// Fill the (cos, sin) cache for multi-section RoPE (MROPE / IMROPE).
//
// Every dimension pair (index i0 / 2) is mapped to a sector, i.e. its pair
// index modulo the total section count, and the sector decides which of the
// four position ids drives the rotation angle of that pair.
//
// pos_t/h/w/e  : the four position ids for this sequence step
//                (t=time, h=height, w=width, e=extra).
// sections[4]  : number of dimension pairs (not individual floats) assigned
//                to each position component, in t, h, w, e order.
// is_imrope    : true  -> interleaved layout, sector % 3 picks t / h / w
//                false -> contiguous layout [t | h | w | e]
// freq_factors : optional per-pair divisor for the angle (NULL means 1.0).
// cache        : output, ne0 floats; cos at even indices, sin at odd ones.
// theta_scale  : factor applied to all four running angles after each pair.
// freq_scale, corr_dims, ext_factor and mscale are forwarded unchanged to
// rope_yarn_one().
static void mrope_cache_init(const float     pos_t,
                             const float     pos_h,
                             const float     pos_w,
                             const float     pos_e,
                             const int32_t   sections[4],
                             const bool      is_imrope,
                             const float     freq_scale,
                             const float *   freq_factors,
                             float *         corr_dims,
                             const uint32_t  ne0,
                             const float     ext_factor,
                             const float     mscale,
                             float *         cache,
                             const float     theta_scale) {
    const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
    const int sec_w     = sections[0] + sections[1];
    const int sec_e     = sec_w + sections[2];

    float theta_t = pos_t;
    float theta_h = pos_h;
    float theta_w = pos_w;
    float theta_e = pos_e;

    for (uint32_t i0 = 0; i0 < ne0; i0 += 2) {
        const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;

        // A section total of zero (or less) would make the modulo undefined
        // behaviour, so fall back to sector 0 instead of dividing by it.
        const int sector = sect_dims > 0 ? (int) ((i0 / 2) % (uint32_t) sect_dims) : 0;

        float theta;
        if (is_imrope) {
            // Interleaved: sector mod 3 selects the component, as long as the
            // sector is still inside that component's share.
            if (sector % 3 == 0 && sector < 3 * sections[0]) {
                theta = theta_t;
            } else if (sector % 3 == 1 && sector < 3 * sections[1]) {
                theta = theta_h;
            } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
                theta = theta_w;
            } else {
                theta = theta_e;
            }
        } else {
            // Contiguous sections: [0, sections[0]) -> t, then h, then w, rest -> e.
            if (sector < sections[0]) {
                theta = theta_t;
            } else if (sector < sec_w) {
                theta = theta_h;
            } else if (sector < sec_e) {
                theta = theta_w;
            } else {
                theta = theta_e;
            }
        }

        rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache);

        // All four angles advance together so any component can be picked
        // for the next pair.
        theta_t *= theta_scale;
        theta_h *= theta_scale;
        theta_w *= theta_scale;
        theta_e *= theta_scale;
    }
}
121182

@@ -274,7 +335,8 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
274335
uint64_t tt = HAP_perf_get_qtimer_count();
275336

276337
const int32_t mode = rctx->mode;
277-
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
338+
// MROPE and IMROPE use NEOX-style pairing for the rotation
339+
const bool is_neox = (mode & HTP_ROPE_TYPE_NEOX) || (mode & HTP_ROPE_TYPE_MROPE);
278340

279341
// VTCM setup
280342
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
@@ -326,8 +388,25 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
326388
if (i2 != prev_i2) {
327389
prev_i2 = i2;
328390

329-
const int32_t p = pos[i2];
330-
rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale);
391+
const bool is_mrope = (rctx->mode & HTP_ROPE_TYPE_MROPE) != 0;
392+
if (is_mrope) {
393+
// src1 holds four position arrays stacked along ne0:
394+
// pos[i2], pos[i2+ne2], pos[i2+ne2*2], pos[i2+ne2*3]
395+
const bool is_imrope = (rctx->mode == HTP_ROPE_TYPE_IMROPE);
396+
mrope_cache_init(
397+
(float) pos[i2],
398+
(float) pos[i2 + ne2],
399+
(float) pos[i2 + ne2 * 2],
400+
(float) pos[i2 + ne2 * 3],
401+
rctx->sections, is_imrope,
402+
rctx->freq_scale, freq_factors, rctx->corr_dims,
403+
ne0, rctx->ext_factor, rctx->attn_factor,
404+
theta_cache, rctx->theta_scale);
405+
} else {
406+
rope_cache_init(pos[i2], rctx->freq_scale, freq_factors, rctx->corr_dims,
407+
ne0, rctx->ext_factor, rctx->attn_factor,
408+
theta_cache, rctx->theta_scale);
409+
}
331410

332411
// FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache,
333412
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));

0 commit comments

Comments
 (0)