Skip to content

Commit 304fe8a

Browse files
committed
feat: implement neural network module system with PyTorch-like API
- Implement core modules: Linear, Embedding, RMSNorm - Add PyTorch-like macros for module and parameter definition - INFINICORE_NN_MODULE for single module declaration - INFINICORE_NN_MODULE_VEC for module vectors - INFINICORE_NN_PARAMETER for parameter declaration - Corresponding INIT macros for initialization - Implement hierarchical module system with dynamic path generation - Add state_dict() and load_state_dict() support - Refactor module design: protected registration methods, removed path_ member - Add comprehensive test suite including TinyLlama integration - All parameters are protected with public accessors Files changed: 16 files, +1876/-264 lines
1 parent ba8258f commit 304fe8a

16 files changed

Lines changed: 1876 additions & 264 deletions

File tree

include/infinicore.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#pragma once
22

3+
#include "infinicore/nn.hpp"
34
#include "infinicore/ops.hpp"
45
#include "infinicore/tensor.hpp"

include/infinicore/nn.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once
2+
3+
#include "nn/embedding.hpp"
4+
#include "nn/linear.hpp"
5+
#include "nn/rmsnorm.hpp"
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#pragma once
2+
3+
#include "module.hpp"
4+
#include "../ops.hpp"
5+
#include <optional>
6+
7+
namespace infinicore::nn {
8+
9+
/**
10+
* @brief Embedding layer that maps indices to dense vectors
11+
*
12+
* A simple lookup table that stores embeddings of a fixed dictionary and size.
13+
* This module is often used to store word embeddings and retrieve them using indices.
14+
* The input to the module is a tensor of indices, and the output is the corresponding
15+
* embedding vectors.
16+
*
17+
* Similar to PyTorch's nn.Embedding:
18+
* https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
19+
*
20+
* Example:
21+
* @code
22+
* // Create embedding: 10000 words, 300-dimensional embeddings
23+
* auto embedding = Embedding(10000, 300);
24+
*
25+
* // Input: tensor of indices [batch_size, seq_len]
26+
* auto indices = Tensor::from_data({2, 5}, {3, 5, 12, 8, 99, 0, 1, 45, 67, 23});
27+
*
28+
* // Output: [batch_size, seq_len, embedding_dim] = [2, 5, 300]
29+
* auto embeddings = embedding.forward(indices);
30+
* @endcode
31+
*/
32+
class Embedding : public Module {
33+
public:
34+
/**
35+
* @brief Construct an Embedding layer
36+
*
37+
* @param num_embeddings Size of the dictionary of embeddings (vocabulary size)
38+
* @param embedding_dim The size of each embedding vector
39+
* @param padding_idx If specified, the entries at padding_idx do not contribute to gradient
40+
* and the embedding vector at padding_idx is not updated during training
41+
* @param device Device to create the embedding weight on
42+
*/
43+
Embedding(size_t num_embeddings,
44+
size_t embedding_dim,
45+
std::optional<int64_t> padding_idx = std::nullopt,
46+
const Device &device = Device());
47+
48+
/**
49+
* @brief Forward pass: lookup embeddings for given indices
50+
*
51+
* @param indices Tensor containing indices into the embedding matrix.
52+
* Can be any shape (*), typically [batch_size] or [batch_size, seq_len]
53+
* @return Tensor containing the embedding vectors.
54+
* Shape: (*, embedding_dim) where * matches the input shape
55+
*
56+
* Example:
57+
* Input shape: [2, 3] -> Output shape: [2, 3, embedding_dim]
58+
* Input shape: [10] -> Output shape: [10, embedding_dim]
59+
*/
60+
Tensor forward(const Tensor &indices) const;
61+
62+
// Module information
63+
size_t num_embeddings() const { return num_embeddings_; }
64+
size_t embedding_dim() const { return embedding_dim_; }
65+
std::optional<int64_t> padding_idx() const { return padding_idx_; }
66+
67+
// String representation
68+
std::string extra_repr() const;
69+
70+
// Accessors for parameters
71+
Tensor weight() const { return weight_; }
72+
73+
protected:
74+
// Parameters
75+
Parameter weight_;
76+
77+
private:
78+
size_t num_embeddings_; // Vocabulary size
79+
size_t embedding_dim_; // Embedding dimension
80+
std::optional<int64_t> padding_idx_; // Optional padding index
81+
};
82+
83+
} // namespace infinicore::nn

include/infinicore/nn/linear.hpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#pragma once
2+
3+
#include "module.hpp"
4+
#include "../ops.hpp"
5+
6+
namespace infinicore::nn {
7+
8+
class Linear : public Module {
9+
public:
10+
Linear(size_t in_features, size_t out_features, bool bias = true, const Device &device = Device());
11+
12+
// Forward pass: output = input @ weight.T + bias
13+
Tensor forward(Tensor &input) const;
14+
15+
// Forward pass with residual connection (InfiniLM-style)
16+
// output = input @ weight.T + bias + residual
17+
Tensor forward(Tensor &input, Tensor &residual) const;
18+
19+
// Module information
20+
size_t in_features() const { return in_features_; }
21+
size_t out_features() const { return out_features_; }
22+
bool has_bias() const { return has_bias_; }
23+
24+
// String representation
25+
std::string extra_repr() const;
26+
27+
// Accessors for parameters
28+
Tensor weight() const { return weight_; }
29+
Tensor bias() const { return bias_; }
30+
31+
protected:
32+
// Parameters
33+
Parameter weight_;
34+
Parameter bias_;
35+
36+
private:
37+
// Helper method for common forward computation
38+
Tensor compute_linear(Tensor &input) const;
39+
40+
size_t in_features_;
41+
size_t out_features_;
42+
bool has_bias_;
43+
};
44+
45+
} // namespace infinicore::nn

include/infinicore/nn/module.hpp

Lines changed: 96 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
#pragma once
22

33
#include "parameter.hpp"
4+
#include "../tensor.hpp"
45

56
#include <unordered_map>
7+
#include <type_traits>
8+
#include <vector>
69

710
namespace infinicore::nn {
811
class Module {
912
public:
13+
Module() = default;
14+
1015
const std::unordered_map<std::string, Parameter> &state_dict() const;
1116

1217
void load_state_dict(const std::unordered_map<std::string, Tensor> &_state_dict);
@@ -15,35 +20,118 @@ class Module {
1520

1621
void load_parameter_from_blob(const std::string &name, const void *data);
1722

23+
protected:
1824
Tensor register_parameter(const std::string &name, Parameter param);
1925

26+
// Add an existing submodule to this module's hierarchy
27+
// Template parameter M must be a type derived from Module
28+
// Returns the submodule for convenience (allows method chaining)
2029
template <typename M>
2130
std::shared_ptr<M> add_module(const std::string &name, std::shared_ptr<M> submodule) {
31+
// Ensure M is derived from Module (compile-time check)
32+
static_assert(std::is_base_of<Module, M>::value,
33+
"Template parameter M must be derived from infinicore::nn::Module");
34+
35+
// Store in the submodules map (std::shared_ptr<M> automatically converts to std::shared_ptr<Module>)
2236
submodules_[name] = submodule;
37+
2338
return submodule;
2439
}
2540

41+
// Create and register a new submodule by constructing it with the given arguments
42+
// Template parameter M must be a type derived from Module
43+
// Args are forwarded to M's constructor
2644
template <typename M, typename... Args>
2745
std::shared_ptr<M> register_module(const std::string &name, Args &&...args) {
46+
// Ensure M is derived from Module (compile-time check)
47+
static_assert(std::is_base_of<Module, M>::value,
48+
"Template parameter M must be derived from infinicore::nn::Module");
49+
50+
// Construct the submodule
2851
auto submodule = std::make_shared<M>(std::forward<Args>(args)...);
52+
2953
return add_module(name, submodule);
3054
}
3155

56+
// Create and register multiple submodules of the same type
57+
// Each submodule is named as "name.0", "name.1", etc.
58+
// Template parameter M must be a type derived from Module
3259
template <typename M, typename... Args>
33-
std::vector<std::shared_ptr<M>> register_modules(size_t layers, const std::string &name, Args &&...args) {
34-
auto submodules = std::vector<std::shared_ptr<M>>(layers);
35-
for (size_t i = 0; i < layers; i++) {
36-
register_module<M>(name + "." + std::to_string(i), std::forward<Args>(args)...);
60+
std::vector<std::shared_ptr<M>> register_modules(size_t count, const std::string &name, Args &&...args) {
61+
static_assert(std::is_base_of<Module, M>::value,
62+
"Template parameter M must be derived from infinicore::nn::Module");
63+
64+
std::vector<std::shared_ptr<M>> modules;
65+
modules.reserve(count);
66+
for (size_t i = 0; i < count; i++) {
67+
modules.push_back(register_module<M>(name + "." + std::to_string(i), std::forward<Args>(args)...));
3768
}
38-
return submodules;
69+
return modules;
3970
}
4071

41-
private:
42-
void collect_all_parameters(const std::string &prefix, std::unordered_map<std::string, Parameter> &all_params) const;
43-
4472
protected:
4573
Device device_;
4674
std::unordered_map<std::string, std::shared_ptr<Module>> submodules_;
4775
std::unordered_map<std::string, Parameter> parameters_;
76+
77+
private:
78+
void collect_all_parameters(std::unordered_map<std::string, Parameter> &all_params, const std::string &prefix = "") const;
4879
};
80+
81+
// ============================================================================
82+
// PyTorch-like Macros for Convenient Module Registration
83+
// ============================================================================
84+
85+
/**
86+
* @brief Register submodules with automatic name inference from variable name
87+
*
88+
* Usage:
89+
* @code
90+
* class MyModel : public Module {
91+
* protected:
92+
* INFINICORE_NN_MODULE(Linear, layer1);
93+
* INFINICORE_NN_MODULE(Linear, layer2);
94+
* INFINICORE_NN_MODULE_VEC(Linear, layers);
95+
* INFINICORE_NN_PARAMETER(scaling_factor);
96+
*
97+
* public:
98+
* MyModel() {
99+
* INFINICORE_NN_MODULE_INIT(layer1, 128, 64);
100+
* INFINICORE_NN_MODULE_INIT(layer2, 64, 32);
101+
* INFINICORE_NN_MODULE_VEC_INIT(layers, 3, Linear, 32, 16);
102+
* INFINICORE_NN_PARAMETER_INIT(scaling_factor, ({1}, DataType::F32, Device()));
103+
* }
104+
* };
105+
* @endcode
106+
*/
107+
108+
// Declare a single module member variable
109+
#define INFINICORE_NN_MODULE(ModuleType, name) \
110+
std::shared_ptr<ModuleType> name##_
111+
112+
// Declare a vector of modules member variable
113+
#define INFINICORE_NN_MODULE_VEC(ModuleType, name) \
114+
std::vector<std::shared_ptr<ModuleType>> name##_
115+
116+
// Initialize a module in constructor
117+
#define INFINICORE_NN_MODULE_INIT(name, ...) \
118+
name##_ = this->register_module<std::remove_reference<decltype(*name##_)>::type>(#name, ##__VA_ARGS__)
119+
120+
// Initialize a vector of modules in constructor
121+
// Usage: INFINICORE_NN_MODULE_VEC_INIT(layers, count, ModuleType, ctor_args...)
122+
// Example: INFINICORE_NN_MODULE_VEC_INIT(layers, 3, Linear, 128, 64)
123+
#define INFINICORE_NN_MODULE_VEC_INIT(name, count, ModuleType, ...) \
124+
name##_ = this->register_modules<ModuleType>(count, #name, ##__VA_ARGS__)
125+
126+
// Declare a parameter member variable
127+
#define INFINICORE_NN_PARAMETER(name) \
128+
Parameter name##_
129+
130+
// Initialize a parameter in constructor
131+
// Usage: INFINICORE_NN_PARAMETER_INIT(name, (shape, dtype, device))
132+
// Example: INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, DataType::F32, device))
133+
#define INFINICORE_NN_PARAMETER_INIT(name, args) \
134+
name##_ = Parameter args; \
135+
this->register_parameter(#name, name##_)
136+
49137
} // namespace infinicore::nn

include/infinicore/nn/rmsnorm.hpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#pragma once
2+
3+
#include "module.hpp"
4+
#include "../ops.hpp"
5+
6+
namespace infinicore::nn {
7+
8+
/**
9+
* @brief Root Mean Square Layer Normalization (RMSNorm)
10+
*
11+
* Applies Root Mean Square Layer Normalization over the last dimension.
12+
* Unlike LayerNorm, RMSNorm doesn't subtract mean and doesn't use bias.
13+
*
14+
* Formula: y = (x / RMS(x)) * weight
15+
* where RMS(x) = sqrt(mean(x^2) + eps)
16+
*
17+
* Used in LLaMA, Galactica, and other modern language models as a
18+
* simpler and faster alternative to LayerNorm.
19+
*
20+
* Example:
21+
* @code
22+
* // Create RMSNorm for hidden size 4096
23+
* auto norm = RMSNorm(4096);
24+
*
25+
* // Input: [batch, seq_len, hidden_size]
26+
* auto input = Tensor::randn({2, 10, 4096});
27+
*
28+
* // Output: [batch, seq_len, hidden_size]
29+
* auto output = norm.forward(input);
30+
* @endcode
31+
*/
32+
class RMSNorm : public Module {
33+
public:
34+
/**
35+
* @brief Construct a RMSNorm layer
36+
*
37+
* @param normalized_shape Size of the feature dimension to normalize (typically hidden_size)
38+
* @param eps Small constant for numerical stability (default: 1e-6)
39+
* @param device Device to create the weight on
40+
*/
41+
RMSNorm(size_t normalized_shape,
42+
double eps = 1e-6,
43+
const Device &device = Device());
44+
45+
/**
46+
* @brief Forward pass: apply RMSNorm
47+
*
48+
* @param x Input tensor of shape (*, normalized_shape) where * is any number of dimensions
49+
* @return Normalized tensor with same shape as input
50+
*
51+
* The normalization is applied over the last dimension.
52+
* For example:
53+
* Input: [batch, seq_len, hidden_size] -> normalize over hidden_size
54+
* Input: [batch, hidden_size] -> normalize over hidden_size
55+
*/
56+
Tensor forward(const Tensor &x) const;
57+
58+
// Module information
59+
size_t normalized_shape() const { return normalized_shape_; }
60+
double eps() const { return eps_; }
61+
62+
// String representation
63+
std::string extra_repr() const;
64+
65+
// Accessors for parameters
66+
Tensor weight() const { return weight_; }
67+
68+
protected:
69+
// Parameters
70+
Parameter weight_;
71+
72+
private:
73+
size_t normalized_shape_; // Size of the feature dimension
74+
double eps_; // Epsilon for numerical stability
75+
};
76+
77+
} // namespace infinicore::nn

0 commit comments

Comments
 (0)