@@ -17,6 +17,47 @@ class RoPE : public Module {
1717 GPT_NEOX = 1 , // GPT-NeoX style RoPE algorithm (First half dimensions for sin, second half for cos)
1818 };
1919
// Strategy used to rescale rotary position embeddings beyond the
// context length the model was originally trained with.
enum class ScalingType {
    DEFAULT = 0, // Plain RoPE, no rescaling
    LONGROPE = 1 // LongRoPE-style scaling (per-dimension factors)
};

// Abstract base for RoPE scaling configurations. Concrete configs
// (e.g. LongRopeConfig) identify themselves via type() so consumers
// can dispatch without RTTI.
class ScalingConfig {
public:
    virtual ~ScalingConfig() = default;
    ScalingType type() const { return type_; }

protected:
    ScalingType type_ = ScalingType::DEFAULT;
    // explicit: forbid implicit ScalingType -> ScalingConfig conversion.
    explicit ScalingConfig(ScalingType type) : type_(type) {}
};

// LongRoPE scaling: per-dimension frequency rescale factors for "short"
// vs "long" sequences, plus a global attention scaling factor derived
// from the context-extension ratio.
class LongRopeConfig : public ScalingConfig {
protected:
    std::vector<float> short_factor_;          // Per-dim factors for sequences within the original context
    std::vector<float> long_factor_;           // Per-dim factors for extended sequences
    size_t original_max_position_embeddings_;  // Pre-extension training context length
    float factor_;                             // Stored attention scale (already transformed, see ctor)

public:
    /**
     * @param short_factor Per-dimension rescale factors for short sequences
     * @param long_factor Per-dimension rescale factors for long sequences
     * @param original_max_position_embeddings Context length the model was trained with
     *        (NOTE(review): assumed > 1 — log(1) == 0 would divide by zero below)
     * @param factor Context-extension ratio. Values <= 1 mean no extension and
     *        store an attention scale of 1; otherwise the stored scale is
     *        sqrt(1 + log(factor) / log(original_max_position_embeddings)).
     */
    LongRopeConfig(
        std::vector<float> short_factor,
        std::vector<float> long_factor,
        size_t original_max_position_embeddings,
        float factor = 1.0f)
        : ScalingConfig(ScalingType::LONGROPE),
          // Sink parameters: move instead of copying the factor vectors.
          short_factor_(std::move(short_factor)),
          long_factor_(std::move(long_factor)),
          original_max_position_embeddings_(original_max_position_embeddings),
          // Use <= (not ==): a ratio below 1 would feed a negative value into
          // log(), making sqrt() return NaN. The reference formulation treats
          // any scale <= 1 as "no attention rescaling".
          factor_(factor <= 1.0f
                      ? 1.0f
                      : static_cast<float>(std::sqrt(
                            1.0 + std::log(static_cast<double>(factor)) / std::log(static_cast<double>(original_max_position_embeddings))))) {}
    ~LongRopeConfig() override = default;
    size_t original_max_position_embeddings() const { return original_max_position_embeddings_; }
    const std::vector<float> &short_factor() const { return short_factor_; }
    const std::vector<float> &long_factor() const { return long_factor_; }
    float factor() const { return factor_; }
};
60+
2061 /* *
2162 * @brief Construct a RoPE layer
2263 *
@@ -26,13 +67,15 @@ class RoPE : public Module {
     * @param algo RoPE algorithm type (default: Algo::GPT_J)
     * @param dtype Data type for sin/cos cache (default: DataType::F32)
     * @param device Device to create the cache on
     * @param scaling Optional RoPE scaling configuration (e.g. a
     *        LongRopeConfig); nullptr (the default) means standard,
     *        unscaled RoPE
     */
    RoPE(size_t head_dim,
         size_t max_seq_len,
         double theta = 10000.0,
         Algo algo = Algo::GPT_J,
         const DataType &dtype = DataType::F32,
         const Device &device = Device(),
         std::shared_ptr<ScalingConfig> scaling = nullptr);
3679
3780 /* *
3881 * @brief Forward pass: apply RoPE to a tensor
@@ -88,11 +131,12 @@ class RoPE : public Module {
private:
    // Precompute the sin/cos cache tables (implementation in the .cpp;
    // presumably honors scaling_ when set — confirm against the definition).
    void initialize_cache();

    size_t head_dim_;                        // Dimension of each attention head
    size_t max_seq_len_;                     // Maximum sequence length
    double theta_;                           // Base frequency for rotary embeddings
    Algo algo_;                              // RoPE algorithm type
    DataType dtype_;                         // Data type for cache tables
    std::shared_ptr<ScalingConfig> scaling_; // Optional scaling configuration; nullptr = default RoPE
};
97141
98142} // namespace infinicore::nn
0 commit comments