Skip to content

Commit 374e2ea

Browse files
committed
issue/545 nn::module::Rope and nn::module::Swiglu
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent a5e20fc commit 374e2ea

12 files changed

Lines changed: 847 additions & 31 deletions

File tree

include/infinicore/nn/rope.hpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#pragma once
2+
3+
#include "module.hpp"
4+
#include "../context/context.hpp"
5+
#include "../tensor.hpp"
6+
#include <infiniop.h>
7+
#include <memory>
8+
9+
namespace infinicore::nn {
10+
11+
12+
class RoPE : public Module {
13+
public:
14+
/**
15+
* @brief Construct a RoPE (rotary position embedding) layer
16+
*
17+
* @param head_dim Dimension of each attention head (must be even; NOTE(review): the check itself is not visible in this header)
18+
* @param max_seq_len Maximum sequence length for pre-computed cache
19+
* @param theta Base frequency for rotary embeddings (default: 10000.0)
20+
* @param algo RoPE algorithm type (default: INFINIOP_ROPE_ALGO_GPT_J)
21+
* @param dtype Data type for sin/cos cache (default: DataType::F32)
22+
* @param device Device to create the cache on (default: a default-constructed Device)
23+
*/
24+
RoPE(size_t head_dim,
25+
size_t max_seq_len,
26+
double theta = 10000.0,
27+
infiniopRoPEAlgo_t algo = INFINIOP_ROPE_ALGO_GPT_J,
28+
const DataType &dtype = DataType::F32,
29+
const Device &device = Device());
30+
31+
/**
32+
* @brief Forward pass: apply RoPE to a tensor
33+
*
34+
* @param x Input tensor of shape (..., head_dim) where ... is any number of dimensions
35+
* @param pos Position IDs tensor, typically of shape [seq_len] or [batch, seq_len]
36+
* @return Rotated tensor with same shape as input
37+
*
38+
* Applies rotary position embeddings to the input tensor; the method is const.
39+
* For attention mechanisms, call this method separately for query and key tensors.
40+
*
41+
* Common input shapes (as documented by the author; not checked in this header):
42+
* - [batch, num_heads, seq_len, head_dim]
43+
* - [batch, seq_len, num_heads, head_dim]
44+
* - [seq_len, head_dim]
45+
*/
46+
Tensor forward(const Tensor &x, const Tensor &pos) const;
47+
48+
// Read-only accessors for the stored configuration fields (each returns its member by value)
49+
size_t head_dim() const { return head_dim_; }
50+
size_t max_seq_len() const { return max_seq_len_; }
51+
double theta() const { return theta_; }
52+
infiniopRoPEAlgo_t algo() const { return algo_; }
53+
DataType dtype() const { return dtype_; }
54+
55+
// String representation (defined out of line). NOTE(review): std::string is used here but <string> is not included directly by this header - presumably it arrives transitively; confirm
56+
std::string extra_repr() const;
57+
58+
// Accessors for the sin/cos cache tensors (return Tensor by value)
59+
Tensor sin_cache() const { return sin_cache_; }
60+
Tensor cos_cache() const { return cos_cache_; }
61+
62+
protected:
63+
// Parameters (sin and cos cache tables); presumably the macro declares the sin_cache_ / cos_cache_ members read by the accessors above - confirm against the INFINICORE_NN_PARAMETER definition
64+
INFINICORE_NN_PARAMETER(sin_cache);
65+
INFINICORE_NN_PARAMETER(cos_cache);
66+
67+
private:
68+
void initialize_cache(); // presumably populates the sin/cos cache tables - definition is not visible here; confirm
69+
70+
size_t head_dim_; // Dimension of each attention head
71+
size_t max_seq_len_; // Maximum sequence length
72+
double theta_; // Base frequency for rotary embeddings
73+
infiniopRoPEAlgo_t algo_; // RoPE algorithm type
74+
DataType dtype_; // Data type for cache tables
75+
};
76+
77+
} // namespace infinicore::nn

include/infinicore/nn/swiglu.hpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#pragma once
2+
3+
#include "module.hpp"
4+
#include "../ops.hpp"
5+
6+
namespace infinicore::nn {
7+
8+
class SwiGLU : public Module {
9+
public:
10+
/**
11+
* @brief Construct a SwiGLU module
12+
*
13+
* SwiGLU is a stateless activation function, so no parameters are needed and the constructor is defaulted.
14+
*/
15+
SwiGLU() = default;
16+
17+
/**
18+
* @brief Forward pass: apply SwiGLU activation
19+
*
20+
* @param up The "up" projection tensor
21+
* @param gate The "gate" projection tensor
22+
* @return Result tensor: up * gate * sigmoid(gate), i.e. up * SiLU(gate)
23+
*
24+
* Both input tensors must have the same shape and dtype (NOTE(review): the check is not visible in this header).
25+
* Common usage:
26+
* - Input: up from linear_up layer, gate from linear_gate layer
27+
* - Shapes: typically [batch, seq_len, hidden_size] or [batch, hidden_size]
28+
*/
29+
Tensor forward(const Tensor &up, const Tensor &gate) const;
30+
31+
// String representation (defined out of line). NOTE(review): std::string is used here but <string> is not included directly by this header - presumably it arrives transitively; confirm
32+
std::string extra_repr() const;
33+
};
34+
35+
} // namespace infinicore::nn

include/infinicore/ops.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@
77
#include "ops/ones.hpp"
88
#include "ops/rearrange.hpp"
99
#include "ops/rms_norm.hpp"
10+
#include "ops/rope.hpp"
1011
#include "ops/silu.hpp"
1112
#include "ops/swiglu.hpp"

include/infinicore/ops/rope.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "../tensor.hpp"
5+
#include "common/op.hpp"
6+
#include <infiniop.h>
7+
8+
namespace infinicore::op {
9+
class RoPE { // Operator-level RoPE declaration: a function-pointer schema, a static execute(), and a dispatcher accessor
10+
public:
11+
using schema = void (*)(Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infiniopRoPEAlgo_t); // Function-pointer type with the same parameter list as execute()
12+
static void execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infiniopRoPEAlgo_t algo); // Takes the output tensor x_out first and returns void; presumably routes through dispatcher() - definition not visible, confirm
13+
static common::OpDispatcher<schema> &dispatcher(); // Returns a reference to an OpDispatcher specialised on `schema`
14+
};
15+
16+
void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infiniopRoPEAlgo_t algo); // Variant that takes the output tensor x_out as its first argument and returns void
17+
Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infiniopRoPEAlgo_t algo); // Variant that returns a Tensor instead of taking an output argument; presumably allocates the result itself - confirm in the definition
18+
} // namespace infinicore::op

src/infinicore-test/main.cc

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,8 @@ ParsedArgs parseArgs(int argc, char *argv[]) {
141141

142142
int main(int argc, char *argv[]) {
143143
try {
144-
// Initialize spdlog for debugging
145-
spdlog::set_level(spdlog::level::debug);
146-
spdlog::info("Starting InfiniCore Memory Management Test Suite");
147-
148144
ParsedArgs args = parseArgs(argc, argv);
149-
spdlog::debug("Arguments parsed successfully");
145+
spdlog::info("Arguments parsed successfully");
150146

151147
std::cout << "==============================================\n"
152148
<< "InfiniCore Memory Management Test Suite\n"
@@ -156,31 +152,25 @@ int main(int argc, char *argv[]) {
156152
<< "Iterations: " << args.iterations << "\n"
157153
<< "==============================================" << std::endl;
158154

159-
spdlog::debug("About to initialize InfiniCore context");
155+
spdlog::info("About to initialize InfiniCore context");
160156
// Initialize InfiniCore context
161157
infinicore::context::setDevice(infinicore::Device(static_cast<infinicore::Device::Type>(args.device_type), 0));
162-
spdlog::debug("InfiniCore context initialized successfully");
158+
spdlog::info("InfiniCore context initialized successfully");
163159

164-
spdlog::debug("Creating test runner");
160+
spdlog::info("Creating test runner");
165161
// Create test runner
166162
infinicore::test::InfiniCoreTestRunner runner;
167-
spdlog::debug("Test runner created successfully");
163+
spdlog::info("Test runner created successfully");
168164

169165
// Add tests based on arguments
170166
if (args.run_basic) {
171-
spdlog::debug("Adding BasicMemoryTest");
172167
runner.addTest(std::make_unique<infinicore::test::BasicMemoryTest>());
173-
spdlog::debug("BasicMemoryTest added successfully");
174168

175-
spdlog::debug("Adding TensorDestructorTest");
176169
runner.addTest(std::make_unique<infinicore::test::TensorDestructorTest>());
177-
spdlog::debug("TensorDestructorTest added successfully");
178170
}
179171

180172
if (args.run_module) {
181-
spdlog::debug("Adding NNModuleTest");
182173
runner.addTest(std::make_unique<infinicore::test::NNModuleTest>());
183-
spdlog::debug("NNModuleTest added successfully");
184174
}
185175

186176
if (args.run_concurrency) {
@@ -203,10 +193,10 @@ int main(int argc, char *argv[]) {
203193
runner.addTest(std::make_unique<infinicore::test::StressTest>());
204194
}
205195

206-
spdlog::debug("About to run all tests");
196+
spdlog::info("About to run all tests");
207197
// Run all tests
208198
auto results = runner.runAllTests();
209-
spdlog::debug("All tests completed");
199+
spdlog::info("All tests completed");
210200

211201
// Count results and collect failed tests
212202
size_t passed = 0, failed = 0;

0 commit comments

Comments
 (0)