1818#include "htp-ops.h"
1919#include "htp-ops.h"
2020
21- // Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we can't include ggml.h
21+ // Redefined the rope type constants as we can't include ggml.h
2222#define HTP_ROPE_TYPE_NORMAL 0
2323#define HTP_ROPE_TYPE_NEOX 2
24+ #define HTP_ROPE_TYPE_MROPE 8
25+ #define HTP_ROPE_TYPE_IMROPE 40
2426
2527#define HTP_ROPE_SPAD_NROWS 16
2628#define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2)
@@ -82,6 +84,29 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
8284 return (1 - MIN (1 , MAX (0 , y )));
8385}
8486
87+ // Compute one (cos, sin) pair into cache[i0], cache[i0+1] applying YaRN scaling.
88+ static inline void rope_yarn_one (float theta , float freq_scale , float * corr_dims ,
89+ uint32_t i0 , float ext_factor , float mscale ,
90+ float * cache ) {
91+ float theta_extrap = theta ;
92+
93+ // Get n-d rotational scaling corrected for extrapolation
94+ float theta_interp = freq_scale * theta_extrap ;
95+ float theta_final = theta_interp ;
96+ float mscale_final = mscale ;
97+
98+ if (ext_factor != 0.0f ) {
99+ float ramp_mix = rope_yarn_ramp (corr_dims [0 ], corr_dims [1 ], i0 ) * ext_factor ;
100+ theta_final = theta_interp * (1 - ramp_mix ) + theta_extrap * ramp_mix ;
101+
102+ // Get n-d magnitude scaling corrected for interpolation
103+ mscale_final *= 1.0f + 0.1f * logf (1.0f / freq_scale );
104+ }
105+
106+ cache [i0 + 0 ] = cosf (theta_final ) * mscale_final ;
107+ cache [i0 + 1 ] = sinf (theta_final ) * mscale_final ;
108+ }
109+
85110static void rope_cache_init (const float theta_base ,
86111 const float freq_scale ,
87112 const float * freq_factors ,
@@ -96,26 +121,62 @@ static void rope_cache_init(const float theta_base,
96121
97122 for (uint32_t i0 = 0 ; i0 < ne0 ; i0 += 2 ) {
98123 const float ff = freq_factors ? freq_factors [i0 / 2 ] : 1.0f ;
124+ rope_yarn_one (theta / ff , freq_scale , corr_dims , i0 , ext_factor , mscale , cache );
99125
100- float theta_extrap = theta / ff ;
101-
102- // Get n-d rotational scaling corrected for extrapolation
103- float theta_interp = freq_scale * theta_extrap ;
104- float theta_final = theta_interp ;
105- float mscale_final = mscale ;
126+ theta *= theta_scale ;
127+ }
128+ }
106129
107- if (ext_factor != 0.0f ) {
108- float ramp_mix = rope_yarn_ramp (corr_dims [0 ], corr_dims [1 ], i0 ) * ext_factor ;
109- theta_final = theta_interp * (1 - ramp_mix ) + theta_extrap * ramp_mix ;
130+ // pos_t/h/w/e: the four position ids for this sequence step (t=time, h=height, w=width, e=extra).
131+ // sections[4]: number of head dims assigned to each position component.
132+ static void mrope_cache_init (const float pos_t ,
133+ const float pos_h ,
134+ const float pos_w ,
135+ const float pos_e ,
136+ const int32_t sections [4 ],
137+ const bool is_imrope ,
138+ const float freq_scale ,
139+ const float * freq_factors ,
140+ float * corr_dims ,
141+ const uint32_t ne0 ,
142+ const float ext_factor ,
143+ const float mscale ,
144+ float * cache ,
145+ const float theta_scale ) {
146+ const int sect_dims = sections [0 ] + sections [1 ] + sections [2 ] + sections [3 ];
147+ const int sec_w = sections [0 ] + sections [1 ];
148+ const int sec_e = sec_w + sections [2 ];
149+
150+ float theta_t = pos_t ;
151+ float theta_h = pos_h ;
152+ float theta_w = pos_w ;
153+ float theta_e = pos_e ;
110154
111- // Get n-d magnitude scaling corrected for interpolation
112- mscale_final *= 1.0f + 0.1f * logf (1.0f / freq_scale );
155+ for (uint32_t i0 = 0 ; i0 < ne0 ; i0 += 2 ) {
156+ const float ff = freq_factors ? freq_factors [i0 / 2 ] : 1.0f ;
157+ const int sector = (i0 / 2 ) % sect_dims ;
158+
159+ float theta ;
160+ if (is_imrope ) {
161+ // Interleaved: sector mod 3 selects component
162+ if (sector % 3 == 0 && sector < 3 * sections [0 ]) { theta = theta_t ; }
163+ else if (sector % 3 == 1 && sector < 3 * sections [1 ]) { theta = theta_h ; }
164+ else if (sector % 3 == 2 && sector < 3 * sections [2 ]) { theta = theta_w ; }
165+ else { theta = theta_e ; }
166+ } else {
167+ // Contiguous sections
168+ if (sector < sections [0 ]) { theta = theta_t ; }
169+ else if (sector < sec_w ) { theta = theta_h ; }
170+ else if (sector < sec_e ) { theta = theta_w ; }
171+ else { theta = theta_e ; }
113172 }
114173
115- cache [i0 + 0 ] = cosf (theta_final ) * mscale_final ;
116- cache [i0 + 1 ] = sinf (theta_final ) * mscale_final ;
174+ rope_yarn_one (theta / ff , freq_scale , corr_dims , i0 , ext_factor , mscale , cache );
117175
118- theta *= theta_scale ;
176+ theta_t *= theta_scale ;
177+ theta_h *= theta_scale ;
178+ theta_w *= theta_scale ;
179+ theta_e *= theta_scale ;
119180 }
120181}
121182
@@ -274,7 +335,8 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
274335 uint64_t tt = HAP_perf_get_qtimer_count ();
275336
276337 const int32_t mode = rctx -> mode ;
277- const bool is_neox = mode & HTP_ROPE_TYPE_NEOX ;
338+ // MROPE and IMROPE use NEOX-style pairing for the rotation
339+ const bool is_neox = (mode & HTP_ROPE_TYPE_NEOX ) || (mode & HTP_ROPE_TYPE_MROPE );
278340
279341 // VTCM setup
280342 uint8_t * src0_spad_base = octx -> src0_spad .data + (ith * octx -> src0_spad .size_per_thread );
@@ -326,8 +388,25 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
326388 if (i2 != prev_i2 ) {
327389 prev_i2 = i2 ;
328390
329- const int32_t p = pos [i2 ];
330- rope_cache_init (p , rctx -> freq_scale , freq_factors , rctx -> corr_dims , ne0 , rctx -> ext_factor , rctx -> attn_factor , theta_cache , rctx -> theta_scale );
391+ const bool is_mrope = (rctx -> mode & HTP_ROPE_TYPE_MROPE ) != 0 ;
392+ if (is_mrope ) {
393+ // src1 holds four position arrays stacked along ne0:
394+ // pos[i2], pos[i2+ne2], pos[i2+ne2*2], pos[i2+ne2*3]
395+ const bool is_imrope = (rctx -> mode == HTP_ROPE_TYPE_IMROPE );
396+ mrope_cache_init (
397+ (float ) pos [i2 ],
398+ (float ) pos [i2 + ne2 ],
399+ (float ) pos [i2 + ne2 * 2 ],
400+ (float ) pos [i2 + ne2 * 3 ],
401+ rctx -> sections , is_imrope ,
402+ rctx -> freq_scale , freq_factors , rctx -> corr_dims ,
403+ ne0 , rctx -> ext_factor , rctx -> attn_factor ,
404+ theta_cache , rctx -> theta_scale );
405+ } else {
406+ rope_cache_init (pos [i2 ], rctx -> freq_scale , freq_factors , rctx -> corr_dims ,
407+ ne0 , rctx -> ext_factor , rctx -> attn_factor ,
408+ theta_cache , rctx -> theta_scale );
409+ }
331410
332411 // FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache,
333412 // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
0 commit comments