11#include " rotary_embedding.hpp"
2- #include < algorithm> // std::clamp
3- #include < cmath> // std::llround
2+ #include " ../../config/model_config.hpp"
3+ #include " rotary_embedding_factory.hpp"
4+ #include < memory>
45#include < string>
5- #include < unordered_map>
66
77namespace infinilm ::layers::rotary_embedding {
8- namespace {
9- thread_local std::unordered_map<std::string, std::shared_ptr<infinicore::nn::RoPE>> _ROPE_DICT;
10- } // namespace
118
12- size_t get_rotary_dim (size_t head_dim, double partial_rotary_factor) {
13- if (partial_rotary_factor <= 0.0 || partial_rotary_factor >= 1.0 ) {
14- return head_dim;
15- }
9+ // Cache dictionary to avoid redundant allocations of RoPE instances.
10+ // thread_local ensures it is only visible within this compilation unit.
11+ thread_local std::unordered_map<std::string, std::shared_ptr<infinicore::nn::RoPE>> _ROPE_DICT;
1612
17- size_t rotary_dim = static_cast < size_t >( std::llround (
18- static_cast < double >(head_dim) * partial_rotary_factor));
19- rotary_dim = std::clamp (rotary_dim, static_cast < size_t >( 2 ), head_dim);
13+ std::shared_ptr<infinicore::nn::RoPE>
14+ get_rope ( const std::shared_ptr<infinilm::config::ModelConfig> &model_config,
15+ const infinicore::Device &device) {
2016
21- // RoPE operates on complex pairs, so the rotary dimension must be even
22- if (rotary_dim % 2 != 0 ) {
23- rotary_dim -= 1 ;
24- }
25- return std::max (rotary_dim, static_cast <size_t >(2 ));
26- }
27-
28- std::shared_ptr<infinicore::nn::RoPE> get_rope (const std::shared_ptr<infinilm::config::ModelConfig> &model_config,
29- const infinicore::Device &device,
30- infinicore::nn::RoPE::Algo algo) {
31- // 1. Get head dimension
17+ // 1. Compute the actual rotary dimension
18+ size_t rotary_dim = model_config->get_rotary_dim ();
3219 size_t head_dim = model_config->get_head_dim ();
3320
34- // 2. Safely get partial_rotary_factor, defaulting to 1.0 (full rotation)
35- double partial_rotary_factor = model_config->get_or <double >(" partial_rotary_factor" , 1.0 );
36-
37- // 3. Compute the actual rotary dimension
38- size_t rotary_dim = get_rotary_dim (head_dim, partial_rotary_factor);
21+ // 2. Resolve scaling config via the internal factory
22+ auto scaling = make_scaling_config (model_config);
23+
24+ // 3. Cache key must include rotary_dim AND the actual scaling type
25+ // to avoid reusing the same RoPE instance across models with different settings
26+ // (Enhancement: dynamically determine scaling_type instead of hardcoding "default")
27+ std::string scaling_type_str = " default" ;
28+ if (scaling) {
29+ // Assuming we can get the type string from the JSON for cache key generation,
30+ // or ideally, ScalingConfig should have a virtual std::string type_name() const method.
31+ // Here we read it from JSON for the cache key purpose only, keeping it decoupled from InfiniCore.
32+ const auto &rope_scaling_json = model_config->get_config_json ()[" rope_scaling" ];
33+ if (rope_scaling_json.contains (" type" )) {
34+ scaling_type_str = rope_scaling_json[" type" ].get <std::string>();
35+ } else if (rope_scaling_json.contains (" rope_type" )) {
36+ scaling_type_str = rope_scaling_json[" rope_type" ].get <std::string>();
37+ }
38+ }
3939
40- // 4. Cache key must include rotary_dim to avoid reusing the same RoPE instance
41- // across models with different partial_rotary_factor values
42- const std::string scaling_type = " default" ;
43- std::string cache_key = scaling_type + " _rope_dim_" + std::to_string (rotary_dim);
40+ std::string cache_key = scaling_type_str + " _rope_dim_" + std::to_string (rotary_dim)
41+ + " _dev_" + device.toString ();
4442 auto it = _ROPE_DICT.find (cache_key);
4543 if (it != _ROPE_DICT.end ()) {
4644 return it->second ;
@@ -49,9 +47,10 @@ std::shared_ptr<infinicore::nn::RoPE> get_rope(const std::shared_ptr<infinilm::c
4947 const auto &dtype = model_config->get_dtype ();
5048 size_t max_position_embeddings = model_config->get <size_t >(" max_position_embeddings" );
5149 double rope_theta = model_config->get <double >(" rope_theta" );
52- auto rope = std::make_shared<infinicore::nn::RoPE>(rotary_dim, max_position_embeddings, rope_theta,
53- algo, dtype, device,
54- model_config->get_rope_scaling ());
50+
51+ infinicore::nn::RoPE::Algo algo = model_config->get_rope_algo ();
52+ auto rope = std::make_shared<infinicore::nn::RoPE>(head_dim, rotary_dim, max_position_embeddings, rope_theta,
53+ algo, dtype, device, scaling);
5554
5655 _ROPE_DICT.emplace (cache_key, rope);
5756 return rope;
0 commit comments