Skip to content

Commit 8d83173

Browse files
committed
refactor rope module && refactor swiglu module as torch-like pattern
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent 2d0fc77 commit 8d83173

12 files changed

Lines changed: 143 additions & 76 deletions

File tree

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#pragma once
2+
3+
#include "../tensor.hpp"
4+
5+
namespace infinicore::nn {
6+
7+
/**
8+
* @brief Functional operations namespace
9+
*
10+
* Similar to torch.nn.functional, this namespace provides stateless functional operations
11+
* that don't require module instantiation. These are pure functions that operate on tensors.
12+
*
13+
* Use functional operations when:
14+
* - The operation has no learnable parameters
15+
* - The operation has no internal state or buffers
16+
* - You want lightweight, stateless operations
17+
*
18+
* For operations with parameters or state, use the corresponding Module classes.
19+
*/
20+
namespace functional {
21+
22+
/**
23+
* @brief SwiGLU activation function
24+
*
25+
* Applies SwiGLU (Swish-Gated Linear Unit) activation: output = up * gate * sigmoid(gate)
26+
*
27+
* This is the functional interface for SwiGLU. The module version (nn::SwiGLU) wraps
28+
* this function. Since SwiGLU has no parameters or buffers, you can use either:
29+
* - Functional: output = functional::swiglu(up, gate);
30+
* - Module: SwiGLU swiglu; output = swiglu.forward(up, gate);
31+
*
32+
* @param up The "up" projection tensor
33+
* @param gate The "gate" projection tensor
34+
* @return Result tensor: up * gate * sigmoid(gate)
35+
*
36+
* Both input tensors must have the same shape and dtype.
37+
* Common usage:
38+
* - Input: up from linear_up layer, gate from linear_gate layer
39+
* - Shapes: typically [batch, seq_len, hidden_size] or [batch, hidden_size]
40+
*/
41+
Tensor swiglu(const Tensor &up, const Tensor &gate);
42+
43+
} // namespace functional
44+
45+
} // namespace infinicore::nn

include/infinicore/nn/module.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ class Module {
2323
protected:
2424
Tensor register_parameter(const std::string &name, Parameter param);
2525

26+
Tensor register_buffer(const std::string &name, Parameter buffer);
27+
2628
// Add an existing submodule to this module's hierarchy
2729
// Template parameter M must be a type derived from Module
2830
// Returns the submodule for convenience (allows method chaining)
@@ -72,6 +74,7 @@ class Module {
7274
protected:
7375
Device device_;
7476
std::unordered_map<std::string, std::shared_ptr<Module>> submodules_;
77+
std::unordered_map<std::string, Parameter> buffers_;
7578
std::unordered_map<std::string, Parameter> parameters_;
7679

7780
private:
@@ -134,4 +137,15 @@ class Module {
134137
name##_ = infinicore::nn::Parameter args; \
135138
this->register_parameter(#name, name##_)
136139

140+
// Declare a buffer member variable.
// Expands to a member named `<name>_` of type infinicore::nn::Parameter.
#define INFINICORE_NN_BUFFER(name) \
    infinicore::nn::Parameter name##_

// Initialize a buffer in constructor
// Usage: INFINICORE_NN_BUFFER_INIT(name, (shape, dtype, device));
// Example: INFINICORE_NN_BUFFER_INIT(cache, ({max_seq_len, head_dim}, DataType::F32, device));
//
// `args` must be a parenthesized constructor argument list; it is pasted
// directly after `infinicore::nn::Parameter`.
//
// The expansion is wrapped in do { ... } while (0) so that both the assignment
// and the register_buffer() call behave as a single statement. Without the
// wrapper, using the macro as the body of an unbraced if/else/for would guard
// only the assignment and always run the registration.
#define INFINICORE_NN_BUFFER_INIT(name, args)     \
    do {                                          \
        name##_ = infinicore::nn::Parameter args; \
        this->register_buffer(#name, name##_);    \
    } while (0)
150+
137151
} // namespace infinicore::nn

include/infinicore/nn/rope.hpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,17 @@
33
#include "module.hpp"
44
#include "../context/context.hpp"
55
#include "../tensor.hpp"
6-
#include <infiniop.h>
76
#include <memory>
87

98
namespace infinicore::nn {
109

10+
/**
11+
* @brief RoPE algorithm type
12+
*/
13+
enum class RoPEAlgo {
14+
GPT_J = 0, // GPT-J style: cache entry j rotates the interleaved pair (2j, 2j+1), i.e. adjacent even/odd dimensions
15+
GPT_NEOX = 1, // GPT-NeoX style: cache entry j rotates the split-half pair (j, j + head_dim/2)
16+
};
1117

1218
class RoPE : public Module {
1319
public:
@@ -17,14 +23,14 @@ class RoPE : public Module {
1723
* @param head_dim Dimension of each attention head (must be even)
1824
* @param max_seq_len Maximum sequence length for pre-computed cache
1925
* @param theta Base frequency for rotary embeddings (default: 10000.0)
20-
* @param algo RoPE algorithm type (default: INFINIOP_ROPE_ALGO_GPT_J)
26+
* @param algo RoPE algorithm type (default: RoPEAlgo::GPT_J)
2127
* @param dtype Data type for sin/cos cache (default: DataType::F32)
2228
* @param device Device to create the cache on
2329
*/
2430
RoPE(size_t head_dim,
2531
size_t max_seq_len,
2632
double theta = 10000.0,
27-
infiniopRoPEAlgo_t algo = INFINIOP_ROPE_ALGO_GPT_J,
33+
RoPEAlgo algo = RoPEAlgo::GPT_J,
2834
const DataType &dtype = DataType::F32,
2935
const Device &device = Device());
3036

@@ -49,28 +55,24 @@ class RoPE : public Module {
4955
size_t head_dim() const { return head_dim_; }
5056
size_t max_seq_len() const { return max_seq_len_; }
5157
double theta() const { return theta_; }
52-
infiniopRoPEAlgo_t algo() const { return algo_; }
58+
RoPEAlgo algo() const { return algo_; }
5359
DataType dtype() const { return dtype_; }
5460

5561
// String representation
5662
std::string extra_repr() const;
5763

58-
// Accessors for parameters
59-
Tensor sin_cache() const { return sin_cache_; }
60-
Tensor cos_cache() const { return cos_cache_; }
61-
6264
protected:
63-
// Parameters (sin and cos cache tables)
64-
INFINICORE_NN_PARAMETER(sin_cache);
65-
INFINICORE_NN_PARAMETER(cos_cache);
65+
// Buffers (sin and cos cache tables) - not exposed in state_dict
66+
INFINICORE_NN_BUFFER(sin_cache);
67+
INFINICORE_NN_BUFFER(cos_cache);
6668

6769
private:
6870
void initialize_cache();
6971

7072
size_t head_dim_; // Dimension of each attention head
7173
size_t max_seq_len_; // Maximum sequence length
7274
double theta_; // Base frequency for rotary embeddings
73-
infiniopRoPEAlgo_t algo_; // RoPE algorithm type
75+
RoPEAlgo algo_; // RoPE algorithm type
7476
DataType dtype_; // Data type for cache tables
7577
};
7678

include/infinicore/nn/swiglu.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#pragma once
22

3+
#include "functional.hpp"
34
#include "module.hpp"
4-
#include "../ops.hpp"
55

66
namespace infinicore::nn {
77

include/infinicore/ops/rope.hpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,20 @@
22

33
#include "../device.hpp"
44
#include "../tensor.hpp"
5+
#include "../nn/rope.hpp"
56
#include "common/op.hpp"
6-
#include <infiniop.h>
77

88
namespace infinicore::op {
99
class RoPE {
1010
public:
11-
using schema = void (*)(Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infiniopRoPEAlgo_t);
12-
static void execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infiniopRoPEAlgo_t algo);
11+
using schema = void (*)(Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infinicore::nn::RoPEAlgo);
12+
static void execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPEAlgo algo);
1313
static common::OpDispatcher<schema> &dispatcher();
1414
};
1515

16-
void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infiniopRoPEAlgo_t algo);
17-
Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infiniopRoPEAlgo_t algo);
16+
// Internal function
17+
void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPEAlgo algo);
18+
19+
// Public API that uses infinicore::nn::RoPEAlgo
20+
Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPEAlgo algo);
1821
} // namespace infinicore::op

src/infinicore-test/test_nn_module.cc

Lines changed: 5 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -997,24 +997,6 @@ TestResult NNModuleTest::testModuleRoPE() {
997997
infinicore::nn::RoPE rope1(128, 2048);
998998

999999
auto state1 = rope1.state_dict();
1000-
if (state1.find("sin_cache") == state1.end()) {
1001-
spdlog::error("RoPE sin_cache not found in state dict");
1002-
return false;
1003-
}
1004-
if (state1.find("cos_cache") == state1.end()) {
1005-
spdlog::error("RoPE cos_cache not found in state dict");
1006-
return false;
1007-
}
1008-
1009-
if (rope1.sin_cache()->shape() != std::vector<size_t>({2048, 64})) {
1010-
spdlog::error("RoPE sin_cache shape mismatch. Expected {{2048, 64}}");
1011-
return false;
1012-
}
1013-
1014-
if (rope1.cos_cache()->shape() != std::vector<size_t>({2048, 64})) {
1015-
spdlog::error("RoPE cos_cache shape mismatch. Expected {{2048, 64}}");
1016-
return false;
1017-
}
10181000

10191001
if (rope1.head_dim() != 128) {
10201002
spdlog::error("head_dim mismatch. Expected 128, got {}", rope1.head_dim());
@@ -1050,15 +1032,15 @@ TestResult NNModuleTest::testModuleRoPE() {
10501032

10511033
// Test 3: Different algorithms
10521034
spdlog::info("Test 3: Testing different algorithms");
1053-
infinicore::nn::RoPE rope_gptj(64, 1024, 10000.0, INFINIOP_ROPE_ALGO_GPT_J);
1054-
infinicore::nn::RoPE rope_gptneox(64, 1024, 10000.0, INFINIOP_ROPE_ALGO_GPT_NEOX);
1035+
infinicore::nn::RoPE rope_gptj(64, 1024, 10000.0, infinicore::nn::RoPEAlgo::GPT_J);
1036+
infinicore::nn::RoPE rope_gptneox(64, 1024, 10000.0, infinicore::nn::RoPEAlgo::GPT_NEOX);
10551037

1056-
if (rope_gptj.algo() != INFINIOP_ROPE_ALGO_GPT_J) {
1038+
if (rope_gptj.algo() != infinicore::nn::RoPEAlgo::GPT_J) {
10571039
spdlog::error("GPT_J algorithm not set correctly");
10581040
return false;
10591041
}
10601042

1061-
if (rope_gptneox.algo() != INFINIOP_ROPE_ALGO_GPT_NEOX) {
1043+
if (rope_gptneox.algo() != infinicore::nn::RoPEAlgo::GPT_NEOX) {
10621044
spdlog::error("GPT_NEOX algorithm not set correctly");
10631045
return false;
10641046
}
@@ -1104,27 +1086,9 @@ TestResult NNModuleTest::testModuleRoPE() {
11041086
spdlog::debug("Different theta values test passed");
11051087

11061088
// Test 5: load_state_dict
1107-
spdlog::info("Test 5: Testing load_state_dict for RoPE");
1108-
auto new_sin_cache = infinicore::Tensor::ones({2048, 64}, infinicore::DataType::F32, infinicore::Device());
1109-
auto new_cos_cache = infinicore::Tensor::ones({2048, 64}, infinicore::DataType::F32, infinicore::Device());
1110-
11111089
std::unordered_map<std::string, infinicore::Tensor> new_state;
1112-
new_state.emplace("sin_cache", new_sin_cache);
1113-
new_state.emplace("cos_cache", new_cos_cache);
1114-
11151090
rope1.load_state_dict(new_state);
1116-
1117-
if (!tensorsAllClose(rope1.sin_cache(), new_sin_cache, 1e-7, 1e-7)) {
1118-
spdlog::error("RoPE sin_cache not loaded correctly");
1119-
return false;
1120-
}
1121-
1122-
if (!tensorsAllClose(rope1.cos_cache(), new_cos_cache, 1e-7, 1e-7)) {
1123-
spdlog::error("RoPE cos_cache not loaded correctly");
1124-
return false;
1125-
}
1126-
1127-
spdlog::debug("load_state_dict for RoPE passed");
1091+
spdlog::debug("load_state_dict for RoPE passed (no parameters to load)");
11281092

11291093
// Test 6: extra_repr
11301094
spdlog::info("Test 6: Testing extra_repr");

src/infinicore/nn/functional.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#include "infinicore/nn/functional.hpp"
2+
#include "infinicore/ops.hpp"
3+
#include <spdlog/spdlog.h>
4+
5+
namespace infinicore::nn {
6+
namespace functional {
7+
8+
// SiLU activation: output = input * sigmoid(input).
// Thin forwarding wrapper around the InfiniCore op (backed by InfiniRT/InfiniOP);
// argument validation is left to the op layer.
Tensor silu(const Tensor &input) {
    Tensor activated = op::silu(input);
    return activated;
}
14+
15+
// SwiGLU activation: output = up * gate * sigmoid(gate).
// Thin forwarding wrapper around the InfiniCore op (backed by InfiniRT/InfiniOP);
// argument validation is left to the op layer.
Tensor swiglu(const Tensor &up, const Tensor &gate) {
    Tensor gated = op::swiglu(up, gate);
    return gated;
}
21+
22+
} // namespace functional
23+
} // namespace infinicore::nn

src/infinicore/nn/module.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ Tensor Module::register_parameter(const std::string &name, Parameter param) {
5555
return param;
5656
}
5757

58+
// Records `buffer` in this module's buffers_ map under `name`, replacing any
// entry already stored under that name, and hands the same buffer back to the
// caller. Buffers are kept in buffers_, separate from parameters_.
Tensor Module::register_buffer(const std::string &name, Parameter buffer) {
    auto &slot = buffers_[name];
    slot = buffer;
    return buffer;
}
62+
5863
void Module::collect_all_parameters(std::unordered_map<std::string, Parameter> &all_params, const std::string &prefix) const {
5964
// Add direct parameters with the given prefix
6065
for (const auto &[param_name, param] : parameters_) {

src/infinicore/nn/rope.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace infinicore::nn {
1212
RoPE::RoPE(size_t head_dim,
1313
size_t max_seq_len,
1414
double theta,
15-
infiniopRoPEAlgo_t algo,
15+
RoPEAlgo algo,
1616
const DataType &dtype,
1717
const Device &device)
1818
: head_dim_(head_dim),
@@ -38,8 +38,8 @@ void RoPE::initialize_cache() {
3838
size_t cache_dim = head_dim_ / 2;
3939

4040
// Create sin and cos cache tables: [max_seq_len, cache_dim]
41-
INFINICORE_NN_PARAMETER_INIT(sin_cache, ({max_seq_len_, cache_dim}, dtype_, device_));
42-
INFINICORE_NN_PARAMETER_INIT(cos_cache, ({max_seq_len_, cache_dim}, dtype_, device_));
41+
INFINICORE_NN_BUFFER_INIT(sin_cache, ({max_seq_len_, cache_dim}, dtype_, device_));
42+
INFINICORE_NN_BUFFER_INIT(cos_cache, ({max_seq_len_, cache_dim}, dtype_, device_));
4343

4444
// Pre-compute sin and cos values
4545
// The frequency calculation differs based on algorithm:
@@ -58,11 +58,11 @@ void RoPE::initialize_cache() {
5858
// Compute inverse frequency based on algorithm
5959
double inv_freq;
6060

61-
if (algo_ == INFINIOP_ROPE_ALGO_GPT_J) {
61+
if (algo_ == RoPEAlgo::GPT_J) {
6262
// GPT_J: pairs are (2j, 2j+1) for cache entry j
6363
// Frequency for pair j: theta^(-2j/head_dim)
6464
inv_freq = 1.0 / std::pow(theta_, 2.0 * static_cast<double>(j) / static_cast<double>(head_dim_));
65-
} else if (algo_ == INFINIOP_ROPE_ALGO_GPT_NEOX) {
65+
} else if (algo_ == RoPEAlgo::GPT_NEOX) {
6666
// GPT_NEOX: pairs are (j, j+head_dim/2) for cache entry j
6767
// Frequency for pair j (corresponding to dimension j): theta^(-j/head_dim)
6868
inv_freq = 1.0 / std::pow(theta_, static_cast<double>(j) / static_cast<double>(head_dim_));
@@ -107,7 +107,7 @@ Tensor RoPE::forward(const Tensor &x, const Tensor &pos) const {
107107
}
108108

109109
std::string RoPE::extra_repr() const {
110-
std::string algo_str = (algo_ == INFINIOP_ROPE_ALGO_GPT_J) ? "GPT_J" : "GPT_NEOX";
110+
std::string algo_str = (algo_ == RoPEAlgo::GPT_J) ? "GPT_J" : "GPT_NEOX";
111111
return "RoPE(head_dim=" + std::to_string(head_dim_) + ", max_seq_len=" + std::to_string(max_seq_len_) + ", theta=" + std::to_string(theta_) + ", algo=" + algo_str + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
112112
}
113113

src/infinicore/nn/swiglu.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
#include "infinicore/nn/swiglu.hpp"
2-
#include "infinicore/ops.hpp"
2+
#include "infinicore/nn/functional.hpp"
33
#include <spdlog/spdlog.h>
44
#include <stdexcept>
55

66
namespace infinicore::nn {
77

88
Tensor SwiGLU::forward(const Tensor &up, const Tensor &gate) const {
9-
// Delegate to InfiniCore op (backed by InfiniRT/InfiniOP)
10-
// Validation is handled by the op layer
11-
// output = up * gate * sigmoid(gate)
12-
return op::swiglu(up, gate);
9+
// Delegate to functional::swiglu
10+
return functional::swiglu(up, gate);
1311
}
1412

1513
std::string SwiGLU::extra_repr() const {

0 commit comments

Comments
 (0)