From 6dcab377822534eee2cb193469938842a43fade1 Mon Sep 17 00:00:00 2001 From: rubik Date: Thu, 28 May 2026 16:51:04 +0800 Subject: [PATCH] issue/1180 refactor(nn): decouple RoPE scaling logic with polymorphic interfaces - Extract `RopeScalingConfig` and `LongRopeConfig` to dedicated `rope_scaling_configs.hpp/.cc` files. - Introduce `get_freq_scale` and `get_magnitude_scale` virtual methods to eliminate type-checking branches in the core `initialize_cache` loop. - Add `Llama3Config` skeleton for future Llama 3/3.1 support. --- include/infinicore/nn/rope.hpp | 60 +++---------- .../infinicore/nn/rope_scaling_configs.hpp | 89 +++++++++++++++++++ src/infinicore/nn/rope.cc | 40 +++------ src/infinicore/nn/rope_scaling_configs.cc | 46 ++++++++++ 4 files changed, 158 insertions(+), 77 deletions(-) create mode 100644 include/infinicore/nn/rope_scaling_configs.hpp create mode 100644 src/infinicore/nn/rope_scaling_configs.cc diff --git a/include/infinicore/nn/rope.hpp b/include/infinicore/nn/rope.hpp index 80568a0d1..acf5a82cc 100644 --- a/include/infinicore/nn/rope.hpp +++ b/include/infinicore/nn/rope.hpp @@ -3,8 +3,9 @@ #include "../context/context.hpp" #include "../tensor.hpp" #include "module.hpp" -#include +#include "rope_scaling_configs.hpp" #include +#include namespace infinicore::nn { @@ -18,47 +19,6 @@ class RoPE : public Module { GPT_NEOX = 1, // GPT-NeoX style RoPE algorithm (First half dimensions for sin, second half for cos) }; - enum class ScalingType { - DEFAULT = 0, // Default RoPE - LONGROPE = 1 // Long-RoPE - }; - - class ScalingConfig { - public: - virtual ~ScalingConfig() = default; - ScalingType type() const { return type_; } - - protected: - ScalingType type_ = ScalingType::DEFAULT; - ScalingConfig(ScalingType type) : type_(type) {} - }; - - // longrope scaling - class LongRopeConfig : public ScalingConfig { - protected: - std::vector short_factor_; - std::vector long_factor_; - size_t original_max_position_embeddings_; - float factor_; - - public: - LongRopeConfig( - std::vector short_factor, - std::vector long_factor, - size_t original_max_position_embeddings, - float factor = 1.0f) - : ScalingConfig(ScalingType::LONGROPE), - short_factor_(short_factor), - long_factor_(long_factor), - original_max_position_embeddings_(original_max_position_embeddings), - factor_(factor == 1.0f ? 1.0f : std::sqrt(1 + std::log(factor) / std::log(original_max_position_embeddings))) {} - ~LongRopeConfig() override = default; - size_t original_max_position_embeddings() const { return original_max_position_embeddings_; } - const std::vector &short_factor() const { return short_factor_; } - const std::vector &long_factor() const { return long_factor_; } - float factor() const { return factor_; } - }; - /** * @brief Construct a RoPE layer * @@ -68,7 +28,7 @@ class RoPE : public Module { * @param algo RoPE algorithm type (default: Algo::GPT_J) * @param dtype Data type for sin/cos cache (default: DataType::F32) * @param device Device to create the cache on - * @param scaling RoPE scaling type (default: nullptr) + * @param scaling RoPE scaling configuration (default: nullptr) */ RoPE(size_t head_dim, size_t max_seq_len, @@ -76,7 +36,7 @@ class RoPE : public Module { Algo algo = Algo::GPT_J, const DataType &dtype = DataType::F32, const Device &device = Device(), - std::shared_ptr scaling = nullptr); + std::shared_ptr scaling = nullptr); /** * @brief Forward pass: apply RoPE to a tensor @@ -132,12 +92,12 @@ class RoPE : public Module { private: void initialize_cache(); - size_t head_dim_; // Dimension of each attention head - size_t max_seq_len_; // Maximum sequence length - double theta_; // Base frequency for rotary embeddings - Algo algo_; // RoPE algorithm type - DataType dtype_; // Data type for cache tables - std::shared_ptr scaling_; // RoPE scaling type + size_t head_dim_; // Dimension of each attention head + size_t max_seq_len_; // Maximum sequence length + double theta_; // Base frequency for rotary embeddings + Algo algo_; // RoPE algorithm type + DataType dtype_; // Data type for cache tables + std::shared_ptr scaling_; // RoPE scaling configuration }; } // namespace infinicore::nn diff --git a/include/infinicore/nn/rope_scaling_configs.hpp b/include/infinicore/nn/rope_scaling_configs.hpp new file mode 100644 index 000000000..b4bb649cc --- /dev/null +++ b/include/infinicore/nn/rope_scaling_configs.hpp @@ -0,0 +1,89 @@ +#pragma once +#include +#include + +namespace infinicore::nn { + +/** + * @brief Abstract base class for RoPE scaling strategies. + * Uses polymorphism to eliminate type checking (if-else) in the core RoPE loop. + */ +class RopeScalingConfig { +public: + virtual ~RopeScalingConfig() = default; + + /** + * @brief Calculate the frequency scaling factor for a specific position and dimension. + * + * @param pos Current sequence position + * @param j Current dimension index (0 to head_dim/2 - 1) + * @param head_dim Total dimension of the attention head + * @param theta Base frequency (usually 10000.0) + * @return Frequency scaling factor (default 1.0) + */ + virtual float get_freq_scale(size_t pos, size_t j, size_t head_dim, float theta) const { + return 1.0f; + } + + /** + * @brief Calculate the magnitude scaling factor (table_factor) for a specific position. + * + * @param pos Current sequence position + * @return Magnitude scaling factor (default 1.0) + */ + virtual float get_magnitude_scale(size_t pos) const { + return 1.0f; + } +}; + +/** + * @brief LongRoPE scaling configuration. + */ +class LongRopeConfig : public RopeScalingConfig { +public: + LongRopeConfig( + std::vector short_factor, + std::vector long_factor, + size_t original_max_position_embeddings, + float factor = 1.0f); + + float get_freq_scale(size_t pos, size_t j, size_t head_dim, float theta) const override; + float get_magnitude_scale(size_t pos) const override; + + size_t original_max_position_embeddings() const { return original_max_position_embeddings_; } + const std::vector &short_factor() const { return short_factor_; } + const std::vector &long_factor() const { return long_factor_; } + float factor() const { return factor_; } + +private: + std::vector short_factor_; + std::vector long_factor_; + size_t original_max_position_embeddings_; + float factor_; +}; + +// TODO(rubik) implement in cpp +/** + * @brief Llama3 frequency-aware RoPE scaling configuration. + * Native support for Llama 3.1 RoPE scaling (smooth interpolation based on wavelength). + */ +class Llama3Config : public RopeScalingConfig { +public: + Llama3Config( + float factor, + float low_freq_factor, + float high_freq_factor, + size_t original_max_position_embeddings); + + float get_freq_scale(size_t pos, size_t j, size_t head_dim, float theta) const override; + + // Llama3 does not use magnitude scaling, so it inherits the default get_magnitude_scale() returning 1.0f + +private: + float factor_; + float low_freq_factor_; + float high_freq_factor_; + size_t original_max_position_embeddings_; +}; + +} // namespace infinicore::nn diff --git a/src/infinicore/nn/rope.cc b/src/infinicore/nn/rope.cc index 26403a31c..a35bb08b2 100644 --- a/src/infinicore/nn/rope.cc +++ b/src/infinicore/nn/rope.cc @@ -17,7 +17,7 @@ RoPE::RoPE(size_t head_dim, Algo algo, const DataType &dtype, const Device &device, - std::shared_ptr scaling) + std::shared_ptr scaling) : head_dim_(head_dim), max_seq_len_(max_seq_len), theta_(theta), @@ -54,32 +54,18 @@ void RoPE::initialize_cache() { for (size_t pos = 0; pos < max_seq_len_; pos++) { for (size_t j = 0; j < cache_dim; j++) { - // GPT-J style inverse frequency: theta^(-2j/head_dim) - // Compute directly in float to avoid double->float casting - float inv_freq; - float table_factor = 1.0f; - if (scaling_ == nullptr) { - inv_freq = 1.0f / std::pow(static_cast(theta_), 2.0f * static_cast(j) / static_cast(head_dim_)); - } else if (scaling_->type() == ScalingType::LONGROPE) { - std::shared_ptr lr = std::dynamic_pointer_cast(scaling_); - table_factor = lr->factor(); - float _ext; - if (pos < lr->original_max_position_embeddings()) { - _ext = lr->short_factor()[j]; - } else { - _ext = lr->long_factor()[j]; - } - inv_freq = 1.0f / (_ext * std::pow(static_cast(theta_), 2.0f * static_cast(j) / static_cast(head_dim_))); - } else { - inv_freq = 1.0f / std::pow(static_cast(theta_), 2.0f * static_cast(j) / static_cast(head_dim_)); - } - - // Compute angle: position * inverse_frequency - float angle = static_cast(pos) * inv_freq; - - // Compute sin and cos directly on float - sin_data[pos * cache_dim + j] = std::sin(angle) * table_factor; - cos_data[pos * cache_dim + j] = std::cos(angle) * table_factor; + // 1. Base inverse frequency (shared across all RoPE types) + float base_inv_freq = 1.0f / std::pow(static_cast(theta_), 2.0f * static_cast(j) / static_cast(head_dim_)); + + // 2. Polymorphic scaling resolution + float freq_scale = scaling_ ? scaling_->get_freq_scale(pos, j, head_dim_, static_cast(theta_)) : 1.0f; + float mag_scale = scaling_ ? scaling_->get_magnitude_scale(pos) : 1.0f; + + // 3. Compute final angle and sin/cos values + float angle = static_cast(pos) * base_inv_freq * freq_scale; + + sin_data[pos * cache_dim + j] = std::sin(angle) * mag_scale; + cos_data[pos * cache_dim + j] = std::cos(angle) * mag_scale; } } diff --git a/src/infinicore/nn/rope_scaling_configs.cc b/src/infinicore/nn/rope_scaling_configs.cc new file mode 100644 index 000000000..7a127acf5 --- /dev/null +++ b/src/infinicore/nn/rope_scaling_configs.cc @@ -0,0 +1,46 @@ +#include "infinicore/nn/rope_scaling_configs.hpp" +#include +#include + +namespace infinicore::nn { + +// LongRopeConfig Implementation +LongRopeConfig::LongRopeConfig( + std::vector short_factor, + std::vector long_factor, + size_t original_max_position_embeddings, + float factor) + : short_factor_(std::move(short_factor)), + long_factor_(std::move(long_factor)), + original_max_position_embeddings_(original_max_position_embeddings), + factor_(factor == 1.0f ? 1.0f : std::sqrt(1 + std::log(factor) / std::log(original_max_position_embeddings))) {} + +float LongRopeConfig::get_freq_scale(size_t pos, size_t j, size_t head_dim, float theta) const { + float _ext = (pos < original_max_position_embeddings_) ? short_factor_[j] : long_factor_[j]; + // The base inv_freq is multiplied by this scale. + // Original: inv_freq = 1.0f / (_ext * pow(theta, 2j/head_dim)) + // New: inv_freq = base_inv_freq * (1.0f / _ext) + return 1.0f / _ext; +} + +float LongRopeConfig::get_magnitude_scale(size_t pos) const { + return factor_; +} + +// TODO(rubik) llama3 implement here +// Llama3Config Implementation +Llama3Config::Llama3Config( + float factor, + float low_freq_factor, + float high_freq_factor, + size_t original_max_position_embeddings) + : factor_(factor), + low_freq_factor_(low_freq_factor), + high_freq_factor_(high_freq_factor), + original_max_position_embeddings_(original_max_position_embeddings) {} + +float Llama3Config::get_freq_scale(size_t pos, size_t j, size_t head_dim, float theta) const { + return 1.0f; +} + +} // namespace infinicore::nn