Skip to content

Commit 5957f77

Browse files
cavusmustafa wine99
authored and committed
Imrope support (ggml-org#126)
1 parent 77bd354 commit 5957f77

4 files changed

Lines changed: 101 additions & 41 deletions

File tree

ggml/src/ggml-openvino/ggml-openvino.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ static bool is_op_unsupported_case(const ggml_tensor * op) {
884884
const int32_t * op_params = op->op_params;
885885
const int n_dims = op_params[1];
886886
const int mode = op_params[2];
887-
if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
887+
if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_IMROPE) {
888888
// GGML_LOG_WARN("OpenVINO backend does not support ROPE with mode %d\n", mode);
889889
return true;
890890
}
@@ -906,6 +906,12 @@ static bool is_op_unsupported_case(const ggml_tensor * op) {
906906
return true;
907907
}
908908
}
909+
if (mode == GGML_ROPE_TYPE_IMROPE &&
910+
(op->src[2] != 0 || ((const float *) op_params)[6] != 1 || ((const float *) op_params)[7] != 0 ||
911+
((const float *) op_params)[8] != 1)) {
912+
// GGML_LOG_WARN("OpenVINO backend does not support IMROPE with freq_factors, freq_scale, ext_factor, and attn_factor\n");
913+
return true;
914+
}
909915
break;
910916
}
911917
default:

ggml/src/ggml-openvino/openvino/op/rope.cpp

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,17 @@
99
#include <openvino/op/add.hpp>
1010
#include <openvino/op/concat.hpp>
1111
#include <openvino/op/constant.hpp>
12+
#include <openvino/op/convert.hpp>
13+
#include <openvino/op/cos.hpp>
14+
#include <openvino/op/gather.hpp>
1215
#include <openvino/op/multiply.hpp>
1316
#include <openvino/op/reshape.hpp>
1417
#include <openvino/op/shape_of.hpp>
18+
#include <openvino/op/sin.hpp>
1519
#include <openvino/op/slice.hpp>
1620
#include <openvino/op/split.hpp>
1721
#include <openvino/op/subtract.hpp>
22+
#include <openvino/op/transpose.hpp>
1823
#include <openvino/op/unsqueeze.hpp>
1924
#include <vector>
2025

@@ -33,6 +38,11 @@ OutputVector translate_rope(const NodeContext & context) {
3338
auto data_node = context.get_input(0).get_node_shared_ptr();
3439
auto output_shape = context.get_output_shape().to_shape();
3540
int32_t * op_params = context.get_output_op_params();
41+
const int mode = op_params[2];
42+
43+
constexpr int ROPE_TYPE_NORMAL = 0;
44+
constexpr int ROPE_TYPE_NEOX = 2;
45+
constexpr int ROPE_TYPE_IMROPE = 40;
3646

3747
Output<Node> cos_theta_node;
3848
Output<Node> sin_theta_node;
@@ -45,7 +55,7 @@ OutputVector translate_rope(const NodeContext & context) {
4555
if (context.get_input_size() == 3) {
4656
rope_freqs_weight = context.get_input(2).get_node_shared_ptr();
4757
}
48-
auto sin_cos = make_sin_cos(op_params, inp_pos, rope_freqs_weight);
58+
auto sin_cos = make_sin_cos(op_params, inp_pos, rope_freqs_weight, mode == ROPE_TYPE_IMROPE);
4959
sin_theta_node = sin_cos.first;
5060
cos_theta_node = sin_cos.second;
5161
}
@@ -65,10 +75,6 @@ OutputVector translate_rope(const NodeContext & context) {
6575
}
6676
}
6777

68-
const int mode = op_params[2];
69-
constexpr int ROPE_TYPE_NORMAL = 0;
70-
constexpr int ROPE_TYPE_NEOX = 2;
71-
7278
if (mode == ROPE_TYPE_NORMAL) {
7379
auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
7480
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
@@ -112,6 +118,25 @@ OutputVector translate_rope(const NodeContext & context) {
112118
std::make_shared<ov::op::v1::Multiply>(slice_data_node_1, cos_theta_node));
113119

114120
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, -1);
121+
} else if (mode == ROPE_TYPE_IMROPE) {
122+
int64_t n_dims = data_node->get_shape()[3];
123+
auto cos_sin_shape = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{1,-1,1,(n_dims >> 1)});
124+
auto cos_reshaped = std::make_shared<ov::op::v1::Reshape>(cos_theta_node, cos_sin_shape, true);
125+
auto sin_reshaped = std::make_shared<ov::op::v1::Reshape>(sin_theta_node, cos_sin_shape, true);
126+
127+
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {3});
128+
auto split_a = std::make_shared<ov::op::v1::Split>(data_node, split_axis, 2);
129+
auto x0 = split_a->output(0);
130+
auto x1 = split_a->output(1);
131+
auto mul_a = std::make_shared<ov::op::v1::Multiply>(x0, cos_reshaped);
132+
auto mul_b = std::make_shared<ov::op::v1::Multiply>(x1, sin_reshaped);
133+
auto sub = std::make_shared<ov::op::v1::Subtract>(mul_a, mul_b);
134+
135+
auto mul_c = std::make_shared<ov::op::v1::Multiply>(x0, sin_reshaped);
136+
auto mul_d = std::make_shared<ov::op::v1::Multiply>(x1, cos_reshaped);
137+
auto add = std::make_shared<ov::op::v1::Add>(mul_c, mul_d);
138+
139+
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{sub, add}, 3);
115140
}
116141

117142
return rename_outputs_with_suffix({res}, context.get_name());

ggml/src/ggml-openvino/openvino/utils.cpp

Lines changed: 63 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "ggml-impl.h"
44

5+
#include <cmath>
56
#include <cstddef>
67
#include <ctime>
78
#include <memory>
@@ -13,6 +14,7 @@
1314
#include <openvino/op/gather.hpp>
1415
#include <openvino/op/maximum.hpp>
1516
#include <openvino/op/multiply.hpp>
17+
#include <openvino/op/reshape.hpp>
1618
#include <openvino/op/shape_of.hpp>
1719
#include <openvino/op/sin.hpp>
1820
#include <openvino/op/squeeze.hpp>
@@ -118,13 +120,21 @@ void ggml_rope_yarn_corr_dims(int n_dims,
118120
std::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t * rope_params,
119121
std::shared_ptr<ov::Node> inp_pos,
120122
std::shared_ptr<ov::Node> rope_freqs_weight,
123+
bool imrope,
121124
bool stateful) {
122125
if (stateful) {
123126
inp_pos = std::make_shared<ov::op::v0::Squeeze>(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
124127
inp_pos = std::make_shared<ov::op::v0::Convert>(inp_pos, ov::element::f32);
125128
auto pos_perm =
126129
std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{2, 1, 0});
127130
inp_pos = std::make_shared<ov::op::v1::Transpose>(inp_pos, pos_perm);
131+
} else if (imrope) {
132+
inp_pos = std::make_shared<ov::op::v0::Convert>(inp_pos, ov::element::f32);
133+
auto pos_shape = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{5}, {0, 0, 0, 4, -1});
134+
inp_pos = std::make_shared<ov::op::v1::Reshape>(inp_pos, pos_shape, true);
135+
auto pos_transpose_shape =
136+
std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{5}, std::vector<int64_t>{0, 1, 2, 4, 3});
137+
inp_pos = std::make_shared<ov::op::v1::Transpose>(inp_pos, pos_transpose_shape);
128138
} else {
129139
inp_pos = std::make_shared<ov::op::v0::Convert>(inp_pos, ov::element::f32);
130140
auto pos_perm =
@@ -139,6 +149,7 @@ std::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t * rope_params
139149
float beta_fast;
140150
float beta_slow;
141151
const int n_dims = rope_params[1];
152+
const size_t n_dims_half = n_dims >> 1;
142153
const int n_ctx_orig = rope_params[4];
143154
memcpy(&freq_base, rope_params + 5, sizeof(float));
144155
memcpy(&freq_scale, rope_params + 6, sizeof(float));
@@ -149,57 +160,74 @@ std::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t * rope_params
149160

150161
const float theta_scale = powf(freq_base, -2.0f / n_dims);
151162

152-
float corr_dims[2];
153-
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
154-
155-
std::vector<float> factor(n_dims / 2);
156-
factor[0] = 1.0f;
157-
for (size_t i = 1; i < factor.size(); i++) {
158-
factor[i] = theta_scale * factor[i - 1];
159-
}
163+
std::vector<float> factor(n_dims_half);
160164

161165
Output<Node> freq_factors;
162-
if (stateful) {
163-
freq_factors =
164-
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor);
165-
} else {
166-
freq_factors =
167-
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor);
168-
}
169-
if (rope_freqs_weight) {
170-
freq_factors = std::make_shared<ov::op::v1::Divide>(freq_factors, rope_freqs_weight);
171-
}
172-
173-
auto theta_extrap = std::make_shared<ov::op::v1::Multiply>(freq_factors, inp_pos);
174-
auto theta_interp = std::make_shared<ov::op::v1::Multiply>(
175-
theta_extrap, ov::op::v0::Constant::create(ov::element::f32, {1}, {freq_scale}));
176166

177167
Output<Node> theta;
178168
float mscale = attn_factor;
179-
if (ext_factor == 0.0f) {
180-
theta = theta_interp;
169+
if (imrope) {
170+
std::vector<int64_t> gather_indices(n_dims_half);
171+
for (size_t j = 0; j < n_dims_half; j++) {
172+
gather_indices[j] = j % 3;
173+
factor[j] = std::pow(theta_scale, j);
174+
}
175+
auto gather_indices_const =
176+
std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{n_dims_half}, gather_indices);
177+
auto gather_axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {4});
178+
inp_pos = std::make_shared<ov::op::v8::Gather>(inp_pos, gather_indices_const, gather_axis);
179+
auto factor_const = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{n_dims_half}, factor);
180+
theta = std::make_shared<ov::op::v1::Multiply>(inp_pos, factor_const);
181181
} else {
182-
auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor);
183-
Output<Node> one;
182+
float corr_dims[2];
183+
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
184+
factor[0] = 1.0f;
185+
for (size_t i = 1; i < factor.size(); i++) {
186+
factor[i] = theta_scale * factor[i - 1];
187+
}
184188
if (stateful) {
185-
one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f});
189+
freq_factors =
190+
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor);
186191
} else {
187-
one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f});
192+
freq_factors =
193+
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor);
194+
}
195+
if (rope_freqs_weight) {
196+
freq_factors = std::make_shared<ov::op::v1::Divide>(freq_factors, rope_freqs_weight);
188197
}
189-
auto one_minus_ramp = std::make_shared<ov::op::v1::Subtract>(one, ramp_mix);
190198

191-
theta = std::make_shared<ov::op::v1::Add>(std::make_shared<ov::op::v1::Multiply>(theta_interp, one_minus_ramp),
192-
std::make_shared<ov::op::v1::Multiply>(theta_extrap, ramp_mix));
193-
mscale *= (1.0f + 0.1f * std::log(1.0f / freq_scale));
199+
auto theta_extrap = std::make_shared<ov::op::v1::Multiply>(freq_factors, inp_pos);
200+
auto theta_interp = std::make_shared<ov::op::v1::Multiply>(
201+
theta_extrap, ov::op::v0::Constant::create(ov::element::f32, {1}, {freq_scale}));
202+
203+
if (ext_factor == 0.0f) {
204+
theta = theta_interp;
205+
} else {
206+
auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor);
207+
Output<Node> one;
208+
if (stateful) {
209+
one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f});
210+
} else {
211+
one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f});
212+
}
213+
auto one_minus_ramp = std::make_shared<ov::op::v1::Subtract>(one, ramp_mix);
214+
215+
theta = std::make_shared<ov::op::v1::Add>(std::make_shared<ov::op::v1::Multiply>(theta_interp, one_minus_ramp),
216+
std::make_shared<ov::op::v1::Multiply>(theta_extrap, ramp_mix));
217+
mscale *= (1.0f + 0.1f * std::log(1.0f / freq_scale));
218+
}
194219
}
195220

196221
Output<Node> cos_theta = std::make_shared<ov::op::v0::Cos>(theta);
197222
Output<Node> sin_theta = std::make_shared<ov::op::v0::Sin>(theta);
198223

199-
auto mscale_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {mscale});
224+
if (!imrope) {
225+
auto mscale_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {mscale});
226+
227+
cos_theta = std::make_shared<ov::op::v1::Multiply>(cos_theta, mscale_node);
228+
sin_theta = std::make_shared<ov::op::v1::Multiply>(sin_theta, mscale_node);
229+
}
200230

201-
cos_theta = std::make_shared<ov::op::v1::Multiply>(cos_theta, mscale_node);
202-
sin_theta = std::make_shared<ov::op::v1::Multiply>(sin_theta, mscale_node);
203231
return std::make_pair(sin_theta, cos_theta);
204232
}
205233

ggml/src/ggml-openvino/openvino/utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ OutputVector rename_outputs_with_suffix(const OutputVector& outputs, const std::
6767
std::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t* rope_params,
6868
std::shared_ptr<ov::Node> inp_pos,
6969
std::shared_ptr<ov::Node> rope_freqs_weight = nullptr,
70+
bool imrope = false,
7071
bool stateful = false);
7172

7273
ov::Output<ov::Node> process_view_input(const NodeContext& context, int input_index, int slice_len = 0);

0 commit comments

Comments
 (0)