Skip to content

Commit cf6dff1

Browse files
Merge pull request #546 from InfiniTensor/issue/545
issue/545 nn::module::Rope and nn::module::Swiglu
2 parents 48b3afd + 289d400 commit cf6dff1

13 files changed

Lines changed: 698 additions & 41 deletions

File tree

include/infinicore/nn/module.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ class Module {
2323
protected:
2424
Tensor register_parameter(const std::string &name, Parameter param);
2525

26+
Tensor register_buffer(const std::string &name, Parameter buffer);
27+
2628
// Add an existing submodule to this module's hierarchy
2729
// Template parameter M must be a type derived from Module
2830
// Returns the submodule for convenience (allows method chaining)
@@ -72,6 +74,7 @@ class Module {
7274
protected:
7375
Device device_;
7476
std::unordered_map<std::string, std::shared_ptr<Module>> submodules_;
77+
std::unordered_map<std::string, Parameter> buffers_;
7578
std::unordered_map<std::string, Parameter> parameters_;
7679

7780
private:
@@ -134,4 +137,15 @@ class Module {
134137
name##_ = infinicore::nn::Parameter args; \
135138
this->register_parameter(#name, name##_)
136139

140+
// Declare a buffer member variable
141+
#define INFINICORE_NN_BUFFER(name) \
142+
infinicore::nn::Parameter name##_
143+
144+
// Initialize a buffer in constructor
145+
// Usage: INFINICORE_NN_BUFFER_INIT(name, (shape, dtype, device))
146+
// Example: INFINICORE_NN_BUFFER_INIT(cache, ({max_seq_len, head_dim}, DataType::F32, device))
147+
#define INFINICORE_NN_BUFFER_INIT(name, args) \
148+
name##_ = infinicore::nn::Parameter args; \
149+
this->register_buffer(#name, name##_)
150+
137151
} // namespace infinicore::nn

include/infinicore/nn/rope.hpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#pragma once
2+
3+
#include "module.hpp"
4+
#include "../context/context.hpp"
5+
#include "../tensor.hpp"
6+
#include <memory>
7+
8+
namespace infinicore::nn {
9+
10+
class RoPE : public Module {
11+
public:
12+
/**
13+
* @brief RoPE algorithm type
14+
*/
15+
enum class Algo {
16+
GPT_J = 0, // GPT-J style RoPE algorithm (Interleave even and odd dimensions)
17+
GPT_NEOX = 1, // GPT-NeoX style RoPE algorithm (First half dimensions for sin, second half for cos)
18+
};
19+
20+
/**
21+
* @brief Construct a RoPE layer
22+
*
23+
* @param head_dim Dimension of each attention head (must be even)
24+
* @param max_seq_len Maximum sequence length for pre-computed cache
25+
* @param theta Base frequency for rotary embeddings (default: 10000.0)
26+
* @param algo RoPE algorithm type (default: Algo::GPT_J)
27+
* @param dtype Data type for sin/cos cache (default: DataType::F32)
28+
* @param device Device to create the cache on
29+
*/
30+
RoPE(size_t head_dim,
31+
size_t max_seq_len,
32+
double theta = 10000.0,
33+
Algo algo = Algo::GPT_J,
34+
const DataType &dtype = DataType::F32,
35+
const Device &device = Device());
36+
37+
/**
38+
* @brief Forward pass: apply RoPE to a tensor
39+
*
40+
* @param x Input tensor of shape (..., head_dim) where ... is any number of dimensions
41+
* @param pos Position IDs tensor of shape (*,) typically [seq_len] or [batch, seq_len]
42+
* @return Rotated tensor with same shape as input
43+
*
44+
* Applies rotary position embeddings to the input tensor.
45+
* For attention mechanisms, call this method separately for query and key tensors.
46+
*
47+
* Common input shapes:
48+
* - [batch, num_heads, seq_len, head_dim]
49+
* - [batch, seq_len, num_heads, head_dim]
50+
* - [seq_len, head_dim]
51+
*/
52+
Tensor forward(const Tensor &x, const Tensor &pos) const;
53+
54+
// Module information
55+
size_t head_dim() const { return head_dim_; }
56+
size_t max_seq_len() const { return max_seq_len_; }
57+
double theta() const { return theta_; }
58+
Algo algo() const { return algo_; }
59+
DataType dtype() const { return dtype_; }
60+
61+
// String representation
62+
std::string extra_repr() const;
63+
64+
protected:
65+
// Buffers (sin and cos cache tables) - not exposed in state_dict
66+
INFINICORE_NN_BUFFER(sin_cache);
67+
INFINICORE_NN_BUFFER(cos_cache);
68+
69+
private:
70+
void initialize_cache();
71+
72+
size_t head_dim_; // Dimension of each attention head
73+
size_t max_seq_len_; // Maximum sequence length
74+
double theta_; // Base frequency for rotary embeddings
75+
Algo algo_; // RoPE algorithm type
76+
DataType dtype_; // Data type for cache tables
77+
};
78+
79+
} // namespace infinicore::nn

include/infinicore/ops.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@
77
#include "ops/ones.hpp"
88
#include "ops/rearrange.hpp"
99
#include "ops/rms_norm.hpp"
10+
#include "ops/rope.hpp"
1011
#include "ops/silu.hpp"
1112
#include "ops/swiglu.hpp"

include/infinicore/ops/rope.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "../tensor.hpp"
5+
#include "../nn/rope.hpp"
6+
#include "common/op.hpp"
7+
8+
namespace infinicore::op {
9+
class RoPE {
10+
public:
11+
using schema = void (*)(Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infinicore::nn::RoPE::Algo);
12+
static void execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo);
13+
static common::OpDispatcher<schema> &dispatcher();
14+
};
15+
16+
// Internal function
17+
void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo);
18+
19+
// Public API that uses infinicore::nn::RoPE::Algo
20+
Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo);
21+
} // namespace infinicore::op

src/infinicore-test/main.cc

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,8 @@ ParsedArgs parseArgs(int argc, char *argv[]) {
141141

142142
int main(int argc, char *argv[]) {
143143
try {
144-
// Initialize spdlog for debugging
145-
spdlog::set_level(spdlog::level::debug);
146-
spdlog::info("Starting InfiniCore Memory Management Test Suite");
147-
148144
ParsedArgs args = parseArgs(argc, argv);
149-
spdlog::debug("Arguments parsed successfully");
145+
spdlog::info("Arguments parsed successfully");
150146

151147
std::cout << "==============================================\n"
152148
<< "InfiniCore Memory Management Test Suite\n"
@@ -156,31 +152,25 @@ int main(int argc, char *argv[]) {
156152
<< "Iterations: " << args.iterations << "\n"
157153
<< "==============================================" << std::endl;
158154

159-
spdlog::debug("About to initialize InfiniCore context");
155+
spdlog::info("About to initialize InfiniCore context");
160156
// Initialize InfiniCore context
161157
infinicore::context::setDevice(infinicore::Device(static_cast<infinicore::Device::Type>(args.device_type), 0));
162-
spdlog::debug("InfiniCore context initialized successfully");
158+
spdlog::info("InfiniCore context initialized successfully");
163159

164-
spdlog::debug("Creating test runner");
160+
spdlog::info("Creating test runner");
165161
// Create test runner
166162
infinicore::test::InfiniCoreTestRunner runner;
167-
spdlog::debug("Test runner created successfully");
163+
spdlog::info("Test runner created successfully");
168164

169165
// Add tests based on arguments
170166
if (args.run_basic) {
171-
spdlog::debug("Adding BasicMemoryTest");
172167
runner.addTest(std::make_unique<infinicore::test::BasicMemoryTest>());
173-
spdlog::debug("BasicMemoryTest added successfully");
174168

175-
spdlog::debug("Adding TensorDestructorTest");
176169
runner.addTest(std::make_unique<infinicore::test::TensorDestructorTest>());
177-
spdlog::debug("TensorDestructorTest added successfully");
178170
}
179171

180172
if (args.run_module) {
181-
spdlog::debug("Adding NNModuleTest");
182173
runner.addTest(std::make_unique<infinicore::test::NNModuleTest>());
183-
spdlog::debug("NNModuleTest added successfully");
184174
}
185175

186176
if (args.run_concurrency) {
@@ -203,10 +193,10 @@ int main(int argc, char *argv[]) {
203193
runner.addTest(std::make_unique<infinicore::test::StressTest>());
204194
}
205195

206-
spdlog::debug("About to run all tests");
196+
spdlog::info("About to run all tests");
207197
// Run all tests
208198
auto results = runner.runAllTests();
209-
spdlog::debug("All tests completed");
199+
spdlog::info("All tests completed");
210200

211201
// Count results and collect failed tests
212202
size_t passed = 0, failed = 0;

0 commit comments

Comments
 (0)