Skip to content

Commit 7c97894

Browse files
authored
Merge pull request #921 from InfiniTensor/issue/920
issue/920 RoPE supports longrope
2 parents 97da993 + 06dcc06 commit 7c97894

File tree

2 files changed

+73
-11
lines changed

2 files changed

+73
-11
lines changed

include/infinicore/nn/rope.hpp

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,47 @@ class RoPE : public Module {
1717
GPT_NEOX = 1, // GPT-NeoX style RoPE algorithm (First half dimensions for sin, second half for cos)
1818
};
1919

20+
// Selects the position-scaling strategy used when the RoPE sin/cos cache
// is built. DEFAULT applies no scaling; LONGROPE rescales per-dimension
// frequencies via a LongRopeConfig.
enum class ScalingType {
    DEFAULT = 0, // Default RoPE
    LONGROPE = 1 // Long-RoPE
};
24+
25+
class ScalingConfig {
26+
public:
27+
virtual ~ScalingConfig() = default;
28+
ScalingType type() const { return type_; }
29+
30+
protected:
31+
ScalingType type_ = ScalingType::DEFAULT;
32+
ScalingConfig(ScalingType type) : type_(type) {}
33+
};
34+
35+
// longrope scaling
36+
class LongRopeConfig : public ScalingConfig {
37+
protected:
38+
std::vector<float> short_factor_;
39+
std::vector<float> long_factor_;
40+
size_t original_max_position_embeddings_;
41+
float factor_;
42+
43+
public:
44+
LongRopeConfig(
45+
std::vector<float> short_factor,
46+
std::vector<float> long_factor,
47+
size_t original_max_position_embeddings,
48+
float factor = 1.0f)
49+
: ScalingConfig(ScalingType::LONGROPE),
50+
short_factor_(short_factor),
51+
long_factor_(long_factor),
52+
original_max_position_embeddings_(original_max_position_embeddings),
53+
factor_(factor == 1.0f ? 1.0f : std::sqrt(1 + std::log(factor) / std::log(original_max_position_embeddings))) {}
54+
~LongRopeConfig() override = default;
55+
size_t original_max_position_embeddings() const { return original_max_position_embeddings_; }
56+
const std::vector<float> &short_factor() const { return short_factor_; }
57+
const std::vector<float> &long_factor() const { return long_factor_; }
58+
float factor() const { return factor_; }
59+
};
60+
2061
/**
2162
* @brief Construct a RoPE layer
2263
*
@@ -26,13 +67,15 @@ class RoPE : public Module {
2667
* @param algo RoPE algorithm type (default: Algo::GPT_J)
2768
* @param dtype Data type for sin/cos cache (default: DataType::F32)
2869
* @param device Device to create the cache on
70+
* @param scaling RoPE scaling type (default: nullptr)
2971
*/
3072
RoPE(size_t head_dim,
3173
size_t max_seq_len,
3274
double theta = 10000.0,
3375
Algo algo = Algo::GPT_J,
3476
const DataType &dtype = DataType::F32,
35-
const Device &device = Device());
77+
const Device &device = Device(),
78+
std::shared_ptr<ScalingConfig> scaling = nullptr);
3679

3780
/**
3881
* @brief Forward pass: apply RoPE to a tensor
@@ -88,11 +131,12 @@ class RoPE : public Module {
88131
private:
89132
void initialize_cache();
90133

91-
size_t head_dim_; // Dimension of each attention head
92-
size_t max_seq_len_; // Maximum sequence length
93-
double theta_; // Base frequency for rotary embeddings
94-
Algo algo_; // RoPE algorithm type
95-
DataType dtype_; // Data type for cache tables
134+
size_t head_dim_; // Dimension of each attention head
135+
size_t max_seq_len_; // Maximum sequence length
136+
double theta_; // Base frequency for rotary embeddings
137+
Algo algo_; // RoPE algorithm type
138+
DataType dtype_; // Data type for cache tables
139+
std::shared_ptr<ScalingConfig> scaling_; // RoPE scaling type
96140
};
97141

98142
} // namespace infinicore::nn

src/infinicore/nn/rope.cc

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@ RoPE::RoPE(size_t head_dim,
1616
double theta,
1717
Algo algo,
1818
const DataType &dtype,
19-
const Device &device)
19+
const Device &device,
20+
std::shared_ptr<ScalingConfig> scaling)
2021
: head_dim_(head_dim),
2122
max_seq_len_(max_seq_len),
2223
theta_(theta),
2324
algo_(algo),
24-
dtype_(dtype) {
25+
dtype_(dtype),
26+
scaling_(scaling) {
2527
if (head_dim % 2 != 0) {
2628
throw std::invalid_argument("head_dim must be even for RoPE, got " + std::to_string(head_dim));
2729
}
@@ -54,14 +56,30 @@ void RoPE::initialize_cache() {
5456
for (size_t j = 0; j < cache_dim; j++) {
5557
// GPT-J style inverse frequency: theta^(-2j/head_dim)
5658
// Compute directly in float to avoid double->float casting
57-
float inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
59+
float inv_freq;
60+
float table_factor = 1.0f;
61+
if (scaling_ == nullptr) {
62+
inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
63+
} else if (scaling_->type() == ScalingType::LONGROPE) {
64+
std::shared_ptr<LongRopeConfig> lr = std::dynamic_pointer_cast<LongRopeConfig>(scaling_);
65+
table_factor = lr->factor();
66+
float _ext;
67+
if (pos < lr->original_max_position_embeddings()) {
68+
_ext = lr->short_factor()[j];
69+
} else {
70+
_ext = lr->long_factor()[j];
71+
}
72+
inv_freq = 1.0f / (_ext * std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_)));
73+
} else {
74+
inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
75+
}
5876

5977
// Compute angle: position * inverse_frequency
6078
float angle = static_cast<float>(pos) * inv_freq;
6179

6280
// Compute sin and cos directly on float
63-
sin_data[pos * cache_dim + j] = std::sin(angle);
64-
cos_data[pos * cache_dim + j] = std::cos(angle);
81+
sin_data[pos * cache_dim + j] = std::sin(angle) * table_factor;
82+
cos_data[pos * cache_dim + j] = std::cos(angle) * table_factor;
6583
}
6684
}
6785

0 commit comments

Comments
 (0)