Skip to content

Commit 4c3e266

Browse files
authored
Merge pull request #413 from rubik-hua/llama3_rope
issue/392 [Feature](rope): wire up Llama3 RoPE scaling config creator
2 parents ebe20ae + 05aa611 commit 4c3e266

1 file changed

Lines changed: 28 additions & 2 deletions

File tree

csrc/layers/rotary_embedding/rope_scaling_creators.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,33 @@ create_longrope_config(const std::shared_ptr<config::ModelConfig> &cfg) {
4646
factor);
4747
}
4848

49+
/**
50+
* @brief Creator function for Llama3 RoPE scaling configuration.
51+
* Extracts 'factor', 'low_freq_factor', 'high_freq_factor', and
52+
* 'original_max_position_embeddings' from the model config.
53+
*/
54+
std::shared_ptr<infinicore::nn::RopeScalingConfig>
55+
create_llama3_scaling_config(const std::shared_ptr<config::ModelConfig> &cfg) {
56+
const auto &rope_scaling = cfg->get_config_json()["rope_scaling"];
57+
58+
// Validate required fields for Llama3 scaling
59+
if (!rope_scaling.contains("factor") || !rope_scaling.contains("low_freq_factor") || !rope_scaling.contains("high_freq_factor") || !rope_scaling.contains("original_max_position_embeddings")) {
60+
throw std::runtime_error(
61+
"Llama3RopeScalingConfig requires 'factor', 'low_freq_factor', 'high_freq_factor', and 'original_max_position_embeddings'");
62+
}
63+
64+
float factor = rope_scaling["factor"].get<float>();
65+
float low_freq_factor = rope_scaling["low_freq_factor"].get<float>();
66+
float high_freq_factor = rope_scaling["high_freq_factor"].get<float>();
67+
size_t original_max_position_embeddings = rope_scaling["original_max_position_embeddings"].get<size_t>();
68+
69+
return std::make_shared<infinicore::nn::Llama3RopeScalingConfig>(
70+
factor,
71+
low_freq_factor,
72+
high_freq_factor,
73+
original_max_position_embeddings);
74+
}
75+
4976
// Future scaling creators go here (e.g., create_linear)
5077

5178
} // anonymous namespace
@@ -58,8 +85,7 @@ static bool _registered = []() {
5885
registry["none"] = create_default_scaling_config;
5986
registry["dynamic"] = create_default_scaling_config;
6087
registry["longrope"] = create_longrope_config;
61-
// add new scaling
62-
// registry["llama3"] = create_llama3_scaling;
88+
registry["llama3"] = create_llama3_scaling_config;
6389
return true;
6490
}();
6591

0 commit comments

Comments
 (0)