Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 10 additions & 50 deletions include/infinicore/nn/rope.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
#include "../context/context.hpp"
#include "../tensor.hpp"
#include "module.hpp"
#include <memory>
#include "rope_scaling_configs.hpp"
#include <cmath>
#include <memory>

namespace infinicore::nn {

Expand All @@ -18,47 +19,6 @@ class RoPE : public Module {
GPT_NEOX = 1, // GPT-NeoX style RoPE algorithm (First half dimensions for sin, second half for cos)
};

// Identifies which RoPE scaling strategy a ScalingConfig represents.
enum class ScalingType {
    DEFAULT = 0,  // Plain RoPE, no scaling
    LONGROPE = 1, // Long-RoPE scaling
};

class ScalingConfig {
public:
virtual ~ScalingConfig() = default;
ScalingType type() const { return type_; }

protected:
ScalingType type_ = ScalingType::DEFAULT;
ScalingConfig(ScalingType type) : type_(type) {}
};

// longrope scaling
// Long-RoPE scaling configuration.
//
// Holds two sets of per-dimension rescale factors plus a magnitude factor that
// is applied to the sin/cos tables. The cache builder reads short_factor() for
// positions below original_max_position_embeddings() and long_factor() for the
// remaining positions.
class LongRopeConfig : public ScalingConfig {
protected:
    std::vector<float> short_factor_;         // Per-dimension factors for positions inside the original context
    std::vector<float> long_factor_;          // Per-dimension factors for positions beyond the original context
    size_t original_max_position_embeddings_; // Boundary between the short and long regimes
    float factor_;                            // Derived magnitude factor (NOT the raw constructor argument)

public:
    /**
     * @param short_factor Per-dimension factors used below the original context length
     * @param long_factor Per-dimension factors used at or above the original context length
     * @param original_max_position_embeddings Context length the model was originally trained with
     * @param factor Context extension ratio; values <= 1 mean "no extension"
     *
     * The stored magnitude factor is 1 when there is no extension, otherwise
     * sqrt(1 + log(factor) / log(original_max_position_embeddings)).
     * Comparing with `<=` instead of `==` avoids an exact float equality test
     * and keeps ratios below 1 from producing a shrinking (or NaN) magnitude.
     */
    LongRopeConfig(
        std::vector<float> short_factor,
        std::vector<float> long_factor,
        size_t original_max_position_embeddings,
        float factor = 1.0f)
        : ScalingConfig(ScalingType::LONGROPE),
          short_factor_(short_factor),
          long_factor_(long_factor),
          original_max_position_embeddings_(original_max_position_embeddings),
          factor_(factor <= 1.0f ? 1.0f : std::sqrt(1 + std::log(factor) / std::log(original_max_position_embeddings))) {}
    ~LongRopeConfig() override = default;

    size_t original_max_position_embeddings() const { return original_max_position_embeddings_; }
    const std::vector<float> &short_factor() const { return short_factor_; }
    const std::vector<float> &long_factor() const { return long_factor_; }
    // Returns the derived magnitude factor computed in the constructor.
    float factor() const { return factor_; }
};

/**
* @brief Construct a RoPE layer
*
Expand All @@ -68,15 +28,15 @@ class RoPE : public Module {
* @param algo RoPE algorithm type (default: Algo::GPT_J)
* @param dtype Data type for sin/cos cache (default: DataType::F32)
* @param device Device to create the cache on
* @param scaling RoPE scaling type (default: nullptr)
* @param scaling RoPE scaling configuration (default: nullptr)
*/
RoPE(size_t head_dim,
size_t max_seq_len,
double theta = 10000.0,
Algo algo = Algo::GPT_J,
const DataType &dtype = DataType::F32,
const Device &device = Device(),
std::shared_ptr<ScalingConfig> scaling = nullptr);
std::shared_ptr<RopeScalingConfig> scaling = nullptr);

/**
* @brief Forward pass: apply RoPE to a tensor
Expand Down Expand Up @@ -132,12 +92,12 @@ class RoPE : public Module {
private:
void initialize_cache();

size_t head_dim_; // Dimension of each attention head
size_t max_seq_len_; // Maximum sequence length
double theta_; // Base frequency for rotary embeddings
Algo algo_; // RoPE algorithm type
DataType dtype_; // Data type for cache tables
std::shared_ptr<ScalingConfig> scaling_; // RoPE scaling type
size_t head_dim_; // Dimension of each attention head
size_t max_seq_len_; // Maximum sequence length
double theta_; // Base frequency for rotary embeddings
Algo algo_; // RoPE algorithm type
DataType dtype_; // Data type for cache tables
std::shared_ptr<RopeScalingConfig> scaling_; // RoPE scaling configuration
};

} // namespace infinicore::nn
89 changes: 89 additions & 0 deletions include/infinicore/nn/rope_scaling_configs.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#pragma once
#include <memory>
#include <vector>

namespace infinicore::nn {

/**
 * @brief Polymorphic base for RoPE scaling strategies.
 *
 * The RoPE cache builder queries these two hooks for every table entry
 * instead of branching on a scaling type. Both hooks default to the identity
 * scale (1.0f); a derived strategy overrides only the ones it needs.
 */
class RopeScalingConfig {
public:
    virtual ~RopeScalingConfig() = default;

    /**
     * @brief Multiplier applied to the base inverse frequency of one table entry.
     *
     * Arguments, in order:
     *  - pos:      sequence position of the entry
     *  - j:        dimension index (0 to head_dim/2 - 1)
     *  - head_dim: total dimension of the attention head
     *  - theta:    base frequency (usually 10000.0)
     *
     * @return Frequency scale; the base implementation ignores every argument
     *         and yields 1.0f.
     */
    virtual float get_freq_scale(size_t /*pos*/, size_t /*j*/, size_t /*head_dim*/, float /*theta*/) const {
        constexpr float kIdentityScale = 1.0f;
        return kIdentityScale;
    }

    /**
     * @brief Multiplier applied to the sin/cos values (table factor) of one position.
     *
     * The single argument is the sequence position of the entry.
     *
     * @return Magnitude scale; the base implementation yields 1.0f.
     */
    virtual float get_magnitude_scale(size_t /*pos*/) const {
        constexpr float kIdentityScale = 1.0f;
        return kIdentityScale;
    }
};

/**
* @brief LongRoPE scaling configuration.
*
* Stores two per-dimension factor tables (short and long), the original
* context length that separates them, and a magnitude factor derived from the
* constructor's `factor` argument. Member functions are defined out of line.
*/
class LongRopeConfig : public RopeScalingConfig {
public:
/**
* @param short_factor Per-dimension factors used for positions below
*                     original_max_position_embeddings
* @param long_factor Per-dimension factors used for all other positions
* @param original_max_position_embeddings Position threshold between the two tables
* @param factor Input from which the stored magnitude factor is derived (default 1.0)
*/
LongRopeConfig(
std::vector<float> short_factor,
std::vector<float> long_factor,
size_t original_max_position_embeddings,
float factor = 1.0f);

// Returns the reciprocal of the selected table's entry for dimension j:
// the short table when pos < original_max_position_embeddings, else the long table.
float get_freq_scale(size_t pos, size_t j, size_t head_dim, float theta) const override;
// Returns the stored magnitude factor; it does not vary with pos.
float get_magnitude_scale(size_t pos) const override;

// Read-only accessors for the stored configuration.
size_t original_max_position_embeddings() const { return original_max_position_embeddings_; }
const std::vector<float> &short_factor() const { return short_factor_; }
const std::vector<float> &long_factor() const { return long_factor_; }
// NOTE: this is the derived magnitude factor, not the raw `factor` constructor argument.
float factor() const { return factor_; }

private:
// Declaration order is also the initialization order used by the constructor.
std::vector<float> short_factor_;
std::vector<float> long_factor_;
size_t original_max_position_embeddings_;
float factor_;
};

// TODO(rubik) implement in cpp
/**
* @brief Llama3 frequency-aware RoPE scaling configuration.
* Native support for Llama 3.1 RoPE scaling (smooth interpolation based on wavelength).
*
* Only the frequency hook is overridden; the magnitude hook is inherited from
* RopeScalingConfig. Member functions are defined out of line.
*/
class Llama3Config : public RopeScalingConfig {
public:
/**
* @brief Store the four scaling parameters unchanged.
*
* Parameter meanings below are presumed from the Llama 3.1 `rope_scaling`
* convention and should be confirmed against the model configuration loader.
*
* @param factor Overall context extension factor
* @param low_freq_factor Lower bound factor of the interpolation band
* @param high_freq_factor Upper bound factor of the interpolation band
* @param original_max_position_embeddings Context length before extension
*/
Llama3Config(
float factor,
float low_freq_factor,
float high_freq_factor,
size_t original_max_position_embeddings);

// Frequency scale for dimension j; see the out-of-line definition for the exact rule.
float get_freq_scale(size_t pos, size_t j, size_t head_dim, float theta) const override;

// Llama3 does not use magnitude scaling, so it inherits the default get_magnitude_scale() returning 1.0f

private:
// Declaration order is also the initialization order used by the constructor.
float factor_;
float low_freq_factor_;
float high_freq_factor_;
size_t original_max_position_embeddings_;
};

} // namespace infinicore::nn
40 changes: 13 additions & 27 deletions src/infinicore/nn/rope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ RoPE::RoPE(size_t head_dim,
Algo algo,
const DataType &dtype,
const Device &device,
std::shared_ptr<ScalingConfig> scaling)
std::shared_ptr<RopeScalingConfig> scaling)
: head_dim_(head_dim),
max_seq_len_(max_seq_len),
theta_(theta),
Expand Down Expand Up @@ -54,32 +54,18 @@ void RoPE::initialize_cache() {

for (size_t pos = 0; pos < max_seq_len_; pos++) {
for (size_t j = 0; j < cache_dim; j++) {
// GPT-J style inverse frequency: theta^(-2j/head_dim)
// Compute directly in float to avoid double->float casting
float inv_freq;
float table_factor = 1.0f;
if (scaling_ == nullptr) {
inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
} else if (scaling_->type() == ScalingType::LONGROPE) {
std::shared_ptr<LongRopeConfig> lr = std::dynamic_pointer_cast<LongRopeConfig>(scaling_);
table_factor = lr->factor();
float _ext;
if (pos < lr->original_max_position_embeddings()) {
_ext = lr->short_factor()[j];
} else {
_ext = lr->long_factor()[j];
}
inv_freq = 1.0f / (_ext * std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_)));
} else {
inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
}

// Compute angle: position * inverse_frequency
float angle = static_cast<float>(pos) * inv_freq;

// Compute sin and cos directly on float
sin_data[pos * cache_dim + j] = std::sin(angle) * table_factor;
cos_data[pos * cache_dim + j] = std::cos(angle) * table_factor;
// 1. Base inverse frequency (shared across all RoPE types)
float base_inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));

// 2. Polymorphic scaling resolution
float freq_scale = scaling_ ? scaling_->get_freq_scale(pos, j, head_dim_, static_cast<float>(theta_)) : 1.0f;
float mag_scale = scaling_ ? scaling_->get_magnitude_scale(pos) : 1.0f;

// 3. Compute final angle and sin/cos values
float angle = static_cast<float>(pos) * base_inv_freq * freq_scale;

sin_data[pos * cache_dim + j] = std::sin(angle) * mag_scale;
cos_data[pos * cache_dim + j] = std::cos(angle) * mag_scale;
}
}

Expand Down
46 changes: 46 additions & 0 deletions src/infinicore/nn/rope_scaling_configs.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include "infinicore/nn/rope_scaling_configs.hpp"
#include <cmath>
#include <stdexcept>

namespace infinicore::nn {

// LongRopeConfig Implementation
/**
 * @brief Build a LongRoPE configuration.
 *
 * The factor vectors are taken by value and moved into the object. The stored
 * magnitude factor is derived from @p factor:
 *   - factor <= 1 (no context extension): 1.0
 *   - otherwise: sqrt(1 + log(factor) / log(original_max_position_embeddings))
 *
 * Using `<=` rather than an exact `== 1.0f` comparison keeps ratios below 1
 * from producing a shrinking (or NaN) magnitude and avoids relying on exact
 * floating-point equality.
 *
 * @param short_factor Per-dimension factors for positions below the original context length
 * @param long_factor Per-dimension factors for the remaining positions
 * @param original_max_position_embeddings Original context length
 * @param factor Context extension ratio
 */
LongRopeConfig::LongRopeConfig(
    std::vector<float> short_factor,
    std::vector<float> long_factor,
    size_t original_max_position_embeddings,
    float factor)
    : short_factor_(std::move(short_factor)),
      long_factor_(std::move(long_factor)),
      original_max_position_embeddings_(original_max_position_embeddings),
      // The log ratio is evaluated in double (integer argument) and narrowed once at the end.
      factor_(factor <= 1.0f
                  ? 1.0f
                  : static_cast<float>(std::sqrt(1 + std::log(factor) / std::log(original_max_position_embeddings)))) {}

/**
 * @brief Frequency scale for one table entry under LongRoPE.
 *
 * Positions below the original context length use the short factor table,
 * all others use the long one. The caller multiplies the base inverse
 * frequency by the returned value, so the reciprocal of the table entry is
 * returned (equivalent to dividing the base inverse frequency by the entry).
 *
 * @param pos Sequence position of the entry
 * @param j Dimension index into the factor tables
 * @return 1 / factor[j] from the selected table
 * @throws std::out_of_range if the selected table has no entry for @p j
 *
 * The head dimension and theta arguments are not needed by this strategy.
 */
float LongRopeConfig::get_freq_scale(size_t pos, size_t j, size_t /*head_dim*/, float /*theta*/) const {
    const std::vector<float> &factors = (pos < original_max_position_embeddings_) ? short_factor_ : long_factor_;
    // at() instead of operator[]: a factor table shorter than the number of
    // rotary dimensions must fail loudly rather than read out of bounds.
    return 1.0f / factors.at(j);
}

// Magnitude scale for the sin/cos tables: the factor derived at construction,
// identical for every position (the position argument is not consulted).
float LongRopeConfig::get_magnitude_scale(size_t /*pos*/) const {
    const float magnitude = factor_;
    return magnitude;
}

// TODO(rubik) llama3 implement here
// Llama3Config Implementation
Llama3Config::Llama3Config(
float factor,
float low_freq_factor,
float high_freq_factor,
size_t original_max_position_embeddings)
: factor_(factor),
low_freq_factor_(low_freq_factor),
high_freq_factor_(high_freq_factor),
original_max_position_embeddings_(original_max_position_embeddings) {}

float Llama3Config::get_freq_scale(size_t pos, size_t j, size_t head_dim, float theta) const {
return 1.0f;
}

} // namespace infinicore::nn