@@ -46,6 +46,33 @@ create_longrope_config(const std::shared_ptr<config::ModelConfig> &cfg) {
4646 factor);
4747}
4848
49+ /* *
50+ * @brief Creator function for Llama3 RoPE scaling configuration.
51+ * Extracts 'factor', 'low_freq_factor', 'high_freq_factor', and
52+ * 'original_max_position_embeddings' from the model config.
53+ */
54+ std::shared_ptr<infinicore::nn::RopeScalingConfig>
55+ create_llama3_scaling_config (const std::shared_ptr<config::ModelConfig> &cfg) {
56+ const auto &rope_scaling = cfg->get_config_json ()[" rope_scaling" ];
57+
58+ // Validate required fields for Llama3 scaling
59+ if (!rope_scaling.contains (" factor" ) || !rope_scaling.contains (" low_freq_factor" ) || !rope_scaling.contains (" high_freq_factor" ) || !rope_scaling.contains (" original_max_position_embeddings" )) {
60+ throw std::runtime_error (
61+ " Llama3RopeScalingConfig requires 'factor', 'low_freq_factor', 'high_freq_factor', and 'original_max_position_embeddings'" );
62+ }
63+
64+ float factor = rope_scaling[" factor" ].get <float >();
65+ float low_freq_factor = rope_scaling[" low_freq_factor" ].get <float >();
66+ float high_freq_factor = rope_scaling[" high_freq_factor" ].get <float >();
67+ size_t original_max_position_embeddings = rope_scaling[" original_max_position_embeddings" ].get <size_t >();
68+
69+ return std::make_shared<infinicore::nn::Llama3RopeScalingConfig>(
70+ factor,
71+ low_freq_factor,
72+ high_freq_factor,
73+ original_max_position_embeddings);
74+ }
75+
// Future scaling creators go here (e.g., create_linear)
5077
5178} // anonymous namespace
@@ -58,8 +85,7 @@ static bool _registered = []() {
5885 registry[" none" ] = create_default_scaling_config;
5986 registry[" dynamic" ] = create_default_scaling_config;
6087 registry[" longrope" ] = create_longrope_config;
61- // add new scaling
62- // registry["llama3"] = create_llama3_scaling;
88+ registry[" llama3" ] = create_llama3_scaling_config;
6389 return true ;
6490}();
6591
0 commit comments