Skip to content

Commit 28dce8c

Browse files
author
rubik
committed
issue/401 refactor(rope): add scaling factory and TP-safe RoPE cache
- Decouple scaling config instantiation from ModelConfig via factory and registry pattern. - Add thread-local RoPE cache with device-scoped keys to reduce VRAM usage and ensure TP safety. - Centralize rotary dimension calculation into ModelConfig.
1 parent c7e8420 commit 28dce8c

13 files changed

Lines changed: 240 additions & 103 deletions

csrc/config/model_config.cpp

Lines changed: 18 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -25,56 +25,6 @@ ModelConfig::get_quant_scheme() const {
2525
}
2626
}
2727

28-
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig>
29-
ModelConfig::get_rope_scaling() const {
30-
if (!config_json.contains("rope_scaling") || config_json["rope_scaling"].is_null()) {
31-
return nullptr;
32-
}
33-
34-
const auto &rope_scaling = config_json["rope_scaling"];
35-
if (!rope_scaling.is_object()) {
36-
throw std::runtime_error("rope_scaling must be an object");
37-
}
38-
39-
std::string type_str;
40-
if (rope_scaling.contains("type")) {
41-
type_str = rope_scaling["type"].get<std::string>();
42-
} else if (rope_scaling.contains("rope_type")) {
43-
type_str = rope_scaling["rope_type"].get<std::string>();
44-
} else {
45-
throw std::runtime_error("rope_scaling must contain 'type' or 'rope_type' field");
46-
}
47-
48-
if (type_str == "longrope") {
49-
// Required fields for LongRopeConfig
50-
if (!rope_scaling.contains("short_factor") || !rope_scaling.contains("long_factor") || !rope_scaling.contains("original_max_position_embeddings")) {
51-
throw std::runtime_error(
52-
"LongRopeConfig requires 'short_factor', 'long_factor', and 'original_max_position_embeddings'");
53-
}
54-
55-
auto short_factor = rope_scaling["short_factor"].get<std::vector<float>>();
56-
auto long_factor = rope_scaling["long_factor"].get<std::vector<float>>();
57-
size_t original_max_position_embeddings = rope_scaling["original_max_position_embeddings"].get<size_t>();
58-
59-
float factor = 1.0f;
60-
if (rope_scaling.contains("factor")) {
61-
factor = rope_scaling["factor"].get<float>();
62-
}
63-
64-
return std::make_shared<infinicore::nn::RoPE::LongRopeConfig>(
65-
std::move(short_factor),
66-
std::move(long_factor),
67-
original_max_position_embeddings,
68-
factor);
69-
} else if (type_str == "default" || type_str == "none" || type_str == "dynamic") {
70-
// Default scaling, no scaling applied
71-
// Currently not handling extended sequence lengths for dynamic scaling. Add specific branches when needed.
72-
return nullptr;
73-
} else {
74-
throw std::runtime_error("Unsupported rope_scaling type: " + type_str);
75-
}
76-
}
77-
7828
infinicore::DataType ModelConfig::get_dtype() const {
7929
std::string dtype_str;
8030
if (config_json.contains("dtype")) {
@@ -88,6 +38,24 @@ infinicore::DataType ModelConfig::get_dtype() const {
8838
return parse_dtype(dtype_str);
8939
}
9040

41+
size_t ModelConfig::get_rotary_dim() const {
42+
size_t head_dim = get_head_dim();
43+
double partial_rotary_factor = get_or<double>("partial_rotary_factor", 1.0);
44+
45+
if (partial_rotary_factor <= 0.0 || partial_rotary_factor >= 1.0) {
46+
return head_dim;
47+
}
48+
49+
size_t rotary_dim = static_cast<size_t>(std::llround(
50+
static_cast<double>(head_dim) * partial_rotary_factor));
51+
rotary_dim = std::clamp(rotary_dim, static_cast<size_t>(2), head_dim);
52+
53+
if (rotary_dim % 2 != 0) {
54+
rotary_dim -= 1;
55+
}
56+
return std::max(rotary_dim, static_cast<size_t>(2));
57+
}
58+
9159
std::ostream &operator<<(std::ostream &os, const ModelConfig &config) {
9260
os << config.config_json.dump(4);
9361
return os;

csrc/config/model_config.hpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ class ModelConfig {
5858
return get<size_t>("hidden_size") / get<size_t>("num_attention_heads");
5959
}
6060

61+
// Compute the actual rotary dimension based on partial rotation factor
62+
size_t get_rotary_dim() const;
63+
6164
QuantConfig get_quant_config() const {
6265
return quant_config;
6366
}
@@ -68,7 +71,7 @@ class ModelConfig {
6871

6972
infinicore::DataType get_dtype() const;
7073
infinilm::quantization::QuantScheme get_quant_scheme() const;
71-
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> get_rope_scaling() const;
74+
7275
void set_kv_quant_scheme(infinicore::DataType kv_cache_dtype) {
7376
this->quant_config.set_kv_quant_scheme(kv_cache_dtype);
7477
}
@@ -102,8 +105,18 @@ class ModelConfig {
102105
// Stream output operator
103106
friend std::ostream &operator<<(std::ostream &os, const ModelConfig &config);
104107

108+
infinicore::nn::RoPE::Algo get_rope_algo() const {
109+
return rope_algo_;
110+
}
111+
112+
void set_rope_algo(infinicore::nn::RoPE::Algo algo) {
113+
rope_algo_ = algo;
114+
}
115+
105116
private:
106117
nlohmann::json config_json;
107118
QuantConfig quant_config;
119+
120+
infinicore::nn::RoPE::Algo rope_algo_ = infinicore::nn::RoPE::Algo::GPT_NEOX;
108121
};
109122
} // namespace infinilm::config
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#include "../../config/model_config.hpp"
2+
#include "infinicore/nn/rope_scaling_configs.hpp"
3+
#include "rotary_embedding_factory.hpp"
4+
#include <vector>
5+
6+
namespace infinilm::layers::rotary_embedding {
7+
namespace {
8+
/**
9+
* @brief Default creator for types that apply no scaling.
10+
* Returns nullptr, which the InfiniCore RoPE layer interprets as a 1.0x pass-through.
11+
*/
12+
std::shared_ptr<infinicore::nn::RopeScalingConfig>
13+
create_default_scaling_config(const std::shared_ptr<config::ModelConfig> &) {
14+
return nullptr;
15+
}
16+
17+
// TODO(rubik) create_dynamic_scaling
18+
19+
/**
20+
* @brief Creator function for LongRoPE scaling configuration.
21+
* Extracts 'short_factor', 'long_factor', etc., from the model config.
22+
*/
23+
std::shared_ptr<infinicore::nn::RopeScalingConfig>
24+
create_longrope_config(const std::shared_ptr<config::ModelConfig> &cfg) {
25+
const auto &rope_scaling = cfg->get_config_json()["rope_scaling"];
26+
27+
// Required fields for LongRopeConfig
28+
if (!rope_scaling.contains("short_factor") || !rope_scaling.contains("long_factor") || !rope_scaling.contains("original_max_position_embeddings")) {
29+
throw std::runtime_error(
30+
"LongRopeConfig requires 'short_factor', 'long_factor', and 'original_max_position_embeddings'");
31+
}
32+
33+
auto short_factor = rope_scaling["short_factor"].get<std::vector<float>>();
34+
auto long_factor = rope_scaling["long_factor"].get<std::vector<float>>();
35+
size_t original_max_position_embeddings = rope_scaling["original_max_position_embeddings"].get<size_t>();
36+
37+
float factor = 1.0f;
38+
if (rope_scaling.contains("factor")) {
39+
factor = rope_scaling["factor"].get<float>();
40+
}
41+
42+
return std::make_shared<infinicore::nn::LongRopeScalingConfig>(
43+
std::move(short_factor),
44+
std::move(long_factor),
45+
original_max_position_embeddings,
46+
factor);
47+
}
48+
49+
// Future scaling creators go here (e.g., create_llama3, create_linear)
50+
51+
} // anonymous namespace
52+
53+
// Static self-registration block
54+
// Registers creator functions into the factory registry upon program startup.
55+
static bool _registered = []() {
56+
auto &registry = get_scaling_registry();
57+
registry["default"] = create_default_scaling_config;
58+
registry["none"] = create_default_scaling_config;
59+
registry["dynamic"] = create_default_scaling_config;
60+
registry["longrope"] = create_longrope_config;
61+
// add new scaling
62+
// registry["llama3"] = create_llama3_scaling;
63+
return true;
64+
}();
65+
66+
} // namespace infinilm::layers::rotary_embedding

csrc/layers/rotary_embedding/rotary_embedding.cpp

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,44 @@
11
#include "rotary_embedding.hpp"
2-
#include <algorithm> // std::clamp
3-
#include <cmath> // std::llround
2+
#include "../../config/model_config.hpp"
3+
#include "rotary_embedding_factory.hpp"
4+
#include <memory>
45
#include <string>
5-
#include <unordered_map>
66

77
namespace infinilm::layers::rotary_embedding {
8-
namespace {
9-
thread_local std::unordered_map<std::string, std::shared_ptr<infinicore::nn::RoPE>> _ROPE_DICT;
10-
} // namespace
118

12-
size_t get_rotary_dim(size_t head_dim, double partial_rotary_factor) {
13-
if (partial_rotary_factor <= 0.0 || partial_rotary_factor >= 1.0) {
14-
return head_dim;
15-
}
9+
// Cache dictionary to avoid redundant allocations of RoPE instances.
10+
// thread_local gives each thread its own independent copy of the cache; it does not by itself limit visibility to this compilation unit.
11+
thread_local std::unordered_map<std::string, std::shared_ptr<infinicore::nn::RoPE>> _ROPE_DICT;
1612

17-
size_t rotary_dim = static_cast<size_t>(std::llround(
18-
static_cast<double>(head_dim) * partial_rotary_factor));
19-
rotary_dim = std::clamp(rotary_dim, static_cast<size_t>(2), head_dim);
13+
std::shared_ptr<infinicore::nn::RoPE>
14+
get_rope(const std::shared_ptr<infinilm::config::ModelConfig> &model_config,
15+
const infinicore::Device &device) {
2016

21-
// RoPE operates on complex pairs, so the rotary dimension must be even
22-
if (rotary_dim % 2 != 0) {
23-
rotary_dim -= 1;
24-
}
25-
return std::max(rotary_dim, static_cast<size_t>(2));
26-
}
27-
28-
std::shared_ptr<infinicore::nn::RoPE> get_rope(const std::shared_ptr<infinilm::config::ModelConfig> &model_config,
29-
const infinicore::Device &device,
30-
infinicore::nn::RoPE::Algo algo) {
31-
// 1. Get head dimension
17+
// 1. Compute the actual rotary dimension
18+
size_t rotary_dim = model_config->get_rotary_dim();
3219
size_t head_dim = model_config->get_head_dim();
3320

34-
// 2. Safely get partial_rotary_factor, defaulting to 1.0 (full rotation)
35-
double partial_rotary_factor = model_config->get_or<double>("partial_rotary_factor", 1.0);
36-
37-
// 3. Compute the actual rotary dimension
38-
size_t rotary_dim = get_rotary_dim(head_dim, partial_rotary_factor);
21+
// 2. Resolve scaling config via the internal factory
22+
auto scaling = make_scaling_config(model_config);
23+
24+
// 3. Cache key must include rotary_dim AND the actual scaling type
25+
// to avoid reusing the same RoPE instance across models with different settings
26+
// (Enhancement: dynamically determine scaling_type instead of hardcoding "default")
27+
std::string scaling_type_str = "default";
28+
if (scaling) {
29+
// Assuming we can get the type string from the JSON for cache key generation,
30+
// or ideally, ScalingConfig should have a virtual std::string type_name() const method.
31+
// Here we read it from JSON for the cache key purpose only, keeping it decoupled from InfiniCore.
32+
const auto &rope_scaling_json = model_config->get_config_json()["rope_scaling"];
33+
if (rope_scaling_json.contains("type")) {
34+
scaling_type_str = rope_scaling_json["type"].get<std::string>();
35+
} else if (rope_scaling_json.contains("rope_type")) {
36+
scaling_type_str = rope_scaling_json["rope_type"].get<std::string>();
37+
}
38+
}
3939

40-
// 4. Cache key must include rotary_dim to avoid reusing the same RoPE instance
41-
// across models with different partial_rotary_factor values
42-
const std::string scaling_type = "default";
43-
std::string cache_key = scaling_type + "_rope_dim_" + std::to_string(rotary_dim);
40+
std::string cache_key = scaling_type_str + "_rope_dim_" + std::to_string(rotary_dim)
41+
+ "_dev_" + device.toString();
4442
auto it = _ROPE_DICT.find(cache_key);
4543
if (it != _ROPE_DICT.end()) {
4644
return it->second;
@@ -49,9 +47,10 @@ std::shared_ptr<infinicore::nn::RoPE> get_rope(const std::shared_ptr<infinilm::c
4947
const auto &dtype = model_config->get_dtype();
5048
size_t max_position_embeddings = model_config->get<size_t>("max_position_embeddings");
5149
double rope_theta = model_config->get<double>("rope_theta");
52-
auto rope = std::make_shared<infinicore::nn::RoPE>(rotary_dim, max_position_embeddings, rope_theta,
53-
algo, dtype, device,
54-
model_config->get_rope_scaling());
50+
51+
infinicore::nn::RoPE::Algo algo = model_config->get_rope_algo();
52+
auto rope = std::make_shared<infinicore::nn::RoPE>(head_dim, rotary_dim, max_position_embeddings, rope_theta,
53+
algo, dtype, device, scaling);
5554

5655
_ROPE_DICT.emplace(cache_key, rope);
5756
return rope;
Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
#pragma once
22

3-
#include "../../config/model_config.hpp"
43
#include "infinicore/nn/rope.hpp"
54
#include <memory>
65

7-
namespace infinilm::layers::rotary_embedding {
6+
namespace infinilm::config {
7+
class ModelConfig; // Forward declaration
8+
}
89

9-
// Compute the actual number of dimensions involved in rotary position embedding.
10-
// For partial rotation, the dimension is clamped to [2, head_dim] and must be even.
11-
size_t get_rotary_dim(size_t head_dim, double partial_rotary_factor);
10+
namespace infinilm::layers::rotary_embedding {
1211

13-
std::shared_ptr<infinicore::nn::RoPE> get_rope(const std::shared_ptr<infinilm::config::ModelConfig> &model_config,
14-
const infinicore::Device &device,
15-
infinicore::nn::RoPE::Algo algo = infinicore::nn::RoPE::Algo::GPT_NEOX);
12+
/**
13+
* @brief Public API to assemble and construct a complete RoPE module.
14+
*
15+
* @param model_config Model configuration.
16+
* @param device Device to create the cache on.
17+
* @note The RoPE algorithm is not a parameter here; it is taken from model_config->get_rope_algo().
18+
*/
19+
std::shared_ptr<infinicore::nn::RoPE>
20+
get_rope(const std::shared_ptr<infinilm::config::ModelConfig> &model_config,
21+
const infinicore::Device &device);
1622

1723
} // namespace infinilm::layers::rotary_embedding
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#include "rotary_embedding_factory.hpp"
2+
#include "../../config/model_config.hpp"
3+
#include <stdexcept>
4+
5+
namespace infinilm::layers::rotary_embedding {
6+
7+
std::unordered_map<std::string, ScalingCreator> &get_scaling_registry() {
8+
static std::unordered_map<std::string, ScalingCreator> registry;
9+
return registry;
10+
}
11+
12+
std::shared_ptr<infinicore::nn::RopeScalingConfig>
13+
make_scaling_config(const std::shared_ptr<config::ModelConfig> &model_config) {
14+
if (!model_config || !model_config->get_config_json().contains("rope_scaling") || model_config->get_config_json()["rope_scaling"].is_null()) {
15+
return nullptr;
16+
}
17+
18+
const auto &rope_scaling = model_config->get_config_json()["rope_scaling"];
19+
if (!rope_scaling.is_object()) {
20+
throw std::runtime_error("rope_scaling must be an object");
21+
}
22+
23+
std::string scaling_type;
24+
if (rope_scaling.contains("type")) {
25+
scaling_type = rope_scaling["type"].get<std::string>();
26+
} else if (rope_scaling.contains("rope_type")) {
27+
scaling_type = rope_scaling["rope_type"].get<std::string>();
28+
} else {
29+
throw std::runtime_error("rope_scaling must contain 'type' or 'rope_type' field");
30+
}
31+
32+
// Registry routing: delegate construction to the specific creator
33+
auto &registry = get_scaling_registry();
34+
auto it = registry.find(scaling_type);
35+
if (it != registry.end()) {
36+
return it->second(model_config);
37+
}
38+
39+
throw std::runtime_error("Unsupported rope_scaling_type: " + scaling_type);
40+
}
41+
42+
} // namespace infinilm::layers::rotary_embedding
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#pragma once
2+
3+
#include "infinicore/nn/rope.hpp"
4+
#include "infinicore/nn/rope_scaling_configs.hpp"
5+
#include <functional>
6+
#include <memory>
7+
#include <string>
8+
#include <unordered_map>
9+
10+
namespace infinilm::config {
11+
class ModelConfig; // Forward declaration
12+
}
13+
14+
namespace infinilm::layers::rotary_embedding {
15+
16+
/**
17+
* @brief Callable type (a std::function) for creating specific RopeScalingConfig instances.
18+
* Implementations should extract parameters from ModelConfig and construct the corresponding Config object.
19+
*/
20+
using ScalingCreator = std::function<std::shared_ptr<infinicore::nn::RopeScalingConfig>(
21+
const std::shared_ptr<infinilm::config::ModelConfig> &)>;
22+
23+
/**
24+
* @brief Get the singleton registry mapping scaling type strings to their creator functions.
25+
*/
26+
std::unordered_map<std::string, ScalingCreator> &get_scaling_registry();
27+
28+
/**
29+
* @brief Factory method to create a RopeScalingConfig based on the ModelConfig.
30+
* Routes the scaling type string (the "type" or "rope_type" field of "rope_scaling") to the corresponding registered creator.
31+
*/
32+
std::shared_ptr<infinicore::nn::RopeScalingConfig>
33+
make_scaling_config(const std::shared_ptr<infinilm::config::ModelConfig> &model_config);
34+
35+
} // namespace infinilm::layers::rotary_embedding

csrc/models/chatglm/chatglm_for_causal_lm.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ std::shared_ptr<infinilm::config::ModelConfig> create_chatglm_model_config(
3737
config_json["rope_theta"] = 10000.0;
3838
}
3939

40+
// Use GPT-J style RoPE (interleaved dimensions) for chatglm, chatglm/GLM4 share same attention layer
41+
model_config->set_rope_algo(infinicore::nn::RoPE::Algo::GPT_J);
42+
4043
return model_config;
4144
}
4245

0 commit comments

Comments
 (0)