11#pragma once
22
33#include " parameter.hpp"
4+ #include " ../tensor.hpp"
45
56#include < unordered_map>
7+ #include < type_traits>
8+ #include < vector>
69
710namespace infinicore ::nn {
811class Module {
912public:
13+ Module () = default ;
14+
1015 const std::unordered_map<std::string, Parameter> &state_dict () const ;
1116
1217 void load_state_dict (const std::unordered_map<std::string, Tensor> &_state_dict);
@@ -15,35 +20,118 @@ class Module {
1520
1621 void load_parameter_from_blob (const std::string &name, const void *data);
1722
23+ protected:
1824 Tensor register_parameter (const std::string &name, Parameter param);
1925
26+ // Add an existing submodule to this module's hierarchy
27+ // Template parameter M must be a type derived from Module
28+ // Returns the submodule for convenience (allows method chaining)
2029 template <typename M>
2130 std::shared_ptr<M> add_module (const std::string &name, std::shared_ptr<M> submodule) {
31+ // Ensure M is derived from Module (compile-time check)
32+ static_assert (std::is_base_of<Module, M>::value,
33+ " Template parameter M must be derived from infinicore::nn::Module" );
34+
35+ // Store in the submodules map (std::shared_ptr<M> automatically converts to std::shared_ptr<Module>)
2236 submodules_[name] = submodule;
37+
2338 return submodule;
2439 }
2540
41+ // Create and register a new submodule by constructing it with the given arguments
42+ // Template parameter M must be a type derived from Module
43+ // Args are forwarded to M's constructor
2644 template <typename M, typename ... Args>
2745 std::shared_ptr<M> register_module (const std::string &name, Args &&...args) {
46+ // Ensure M is derived from Module (compile-time check)
47+ static_assert (std::is_base_of<Module, M>::value,
48+ " Template parameter M must be derived from infinicore::nn::Module" );
49+
50+ // Construct the submodule
2851 auto submodule = std::make_shared<M>(std::forward<Args>(args)...);
52+
2953 return add_module (name, submodule);
3054 }
3155
56+ // Create and register multiple submodules of the same type
57+ // Each submodule is named as "name.0", "name.1", etc.
58+ // Template parameter M must be a type derived from Module
3259 template <typename M, typename ... Args>
33- std::vector<std::shared_ptr<M>> register_modules (size_t layers, const std::string &name, Args &&...args) {
34- auto submodules = std::vector<std::shared_ptr<M>>(layers);
35- for (size_t i = 0 ; i < layers; i++) {
36- register_module<M>(name + " ." + std::to_string (i), std::forward<Args>(args)...);
60+ std::vector<std::shared_ptr<M>> register_modules (size_t count, const std::string &name, Args &&...args) {
61+ static_assert (std::is_base_of<Module, M>::value,
62+ " Template parameter M must be derived from infinicore::nn::Module" );
63+
64+ std::vector<std::shared_ptr<M>> modules;
65+ modules.reserve (count);
66+ for (size_t i = 0 ; i < count; i++) {
67+ modules.push_back (register_module<M>(name + " ." + std::to_string (i), std::forward<Args>(args)...));
3768 }
38- return submodules ;
69+ return modules ;
3970 }
4071
41- private:
42- void collect_all_parameters (const std::string &prefix, std::unordered_map<std::string, Parameter> &all_params) const ;
43-
4472protected:
4573 Device device_;
4674 std::unordered_map<std::string, std::shared_ptr<Module>> submodules_;
4775 std::unordered_map<std::string, Parameter> parameters_;
76+
77+ private:
78+ void collect_all_parameters (std::unordered_map<std::string, Parameter> &all_params, const std::string &prefix = " " ) const ;
4879};
80+
81+ // ============================================================================
82+ // PyTorch-like Macros for Convenient Module Registration
83+ // ============================================================================
84+
85+ /* *
86+ * @brief Register submodules with automatic name inference from variable name
87+ *
88+ * Usage:
89+ * @code
90+ * class MyModel : public Module {
91+ * protected:
92+ * INFINICORE_NN_MODULE(Linear, layer1);
93+ * INFINICORE_NN_MODULE(Linear, layer2);
94+ * INFINICORE_NN_MODULE_VEC(Linear, layers);
95+ * INFINICORE_NN_PARAMETER(scaling_factor);
96+ *
97+ * public:
98+ * MyModel() {
99+ * INFINICORE_NN_MODULE_INIT(layer1, 128, 64);
100+ * INFINICORE_NN_MODULE_INIT(layer2, 64, 32);
101+ * INFINICORE_NN_MODULE_VEC_INIT(layers, 3, Linear, 32, 16);
102+ * INFINICORE_NN_PARAMETER_INIT(scaling_factor, ({1}, DataType::F32, Device()));
103+ * }
104+ * };
105+ * @endcode
106+ */
107+
108+ // Declare a single module member variable
109+ #define INFINICORE_NN_MODULE (ModuleType, name ) \
110+ std::shared_ptr<ModuleType> name##_
111+
112+ // Declare a vector of modules member variable
113+ #define INFINICORE_NN_MODULE_VEC (ModuleType, name ) \
114+ std::vector<std::shared_ptr<ModuleType>> name##_
115+
116+ // Initialize a module in constructor
117+ #define INFINICORE_NN_MODULE_INIT (name, ...) \
118+ name##_ = this ->register_module<std::remove_reference<decltype (*name##_)>::type>(#name, ##__VA_ARGS__)
119+
120+ // Initialize a vector of modules in constructor
121+ // Usage: INFINICORE_NN_MODULE_VEC_INIT(layers, count, ModuleType, ctor_args...)
122+ // Example: INFINICORE_NN_MODULE_VEC_INIT(layers, 3, Linear, 128, 64)
123+ #define INFINICORE_NN_MODULE_VEC_INIT (name, count, ModuleType, ...) \
124+ name##_ = this ->register_modules<ModuleType>(count, #name, ##__VA_ARGS__)
125+
126+ // Declare a parameter member variable
127+ #define INFINICORE_NN_PARAMETER (name ) \
128+ Parameter name##_
129+
130+ // Initialize a parameter in constructor
131+ // Usage: INFINICORE_NN_PARAMETER_INIT(name, (shape, dtype, device))
132+ // Example: INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, DataType::F32, device))
133+ #define INFINICORE_NN_PARAMETER_INIT (name, args ) \
134+ name##_ = Parameter args; \
135+ this ->register_parameter (#name, name##_)
136+
49137} // namespace infinicore::nn
0 commit comments