44
55#include < unordered_map>
66#include < type_traits>
7+ #include < vector>
78
89namespace infinicore ::nn {
910class Module {
1011public:
11- Module () : path_( " " ) {}
12+ Module () = default ;
1213
1314 const std::unordered_map<std::string, Parameter> &state_dict () const ;
1415
@@ -18,11 +19,9 @@ class Module {
1819
1920 void load_parameter_from_blob (const std::string &name, const void *data);
2021
22+ protected:
2123 Tensor register_parameter (const std::string &name, Parameter param);
2224
23- // Get the full path of this module in the hierarchy
24- const std::string &path () const { return path_; }
25-
2625 // Add an existing submodule to this module's hierarchy
2726 // Template parameter M must be a type derived from Module
2827 // Returns the submodule for convenience (allows method chaining)
@@ -32,14 +31,6 @@ class Module {
3231 static_assert (std::is_base_of<Module, M>::value,
3332 " Template parameter M must be derived from infinicore::nn::Module" );
3433
35- // Set the submodule's path based on this module's path
36- std::string submodule_path = path_.empty () ? name : path_ + " ." + name;
37- submodule->path_ = submodule_path;
38-
39- // Recursively update paths of all descendants
40- // This handles the case where submodules are added in any order
41- submodule->update_submodule_paths ();
42-
4334 // Store in the submodules map (std::shared_ptr<M> automatically converts to std::shared_ptr<Module>)
4435 submodules_[name] = submodule;
4536
@@ -65,61 +56,68 @@ class Module {
6556 // Each submodule is named as "name.0", "name.1", etc.
6657 // Template parameter M must be a type derived from Module
6758 template <typename M, typename ... Args>
68- std::vector<std::shared_ptr<M>> register_modules (size_t layers , const std::string &name, Args &&...args) {
59+ std::vector<std::shared_ptr<M>> register_modules (size_t count , const std::string &name, Args &&...args) {
6960 static_assert (std::is_base_of<Module, M>::value,
7061 " Template parameter M must be derived from infinicore::nn::Module" );
7162
72- auto submodules = std::vector<std::shared_ptr<M>>(layers);
73- for (size_t i = 0 ; i < layers; i++) {
74- submodules[i] = register_module<M>(name + " ." + std::to_string (i), std::forward<Args>(args)...);
63+ std::vector<std::shared_ptr<M>> modules;
64+ modules.reserve (count);
65+ for (size_t i = 0 ; i < count; i++) {
66+ modules.push_back (register_module<M>(name + " ." + std::to_string (i), std::forward<Args>(args)...));
7567 }
76- return submodules ;
68+ return modules ;
7769 }
7870
79- private:
80- void collect_all_parameters (std::unordered_map<std::string, Parameter> &all_params) const ;
81- void update_submodule_paths ();
82-
8371protected:
84- std::string path_; // Full path of this module in the hierarchy (e.g., "layer1.sublayer")
8572 Device device_;
8673 std::unordered_map<std::string, std::shared_ptr<Module>> submodules_;
8774 std::unordered_map<std::string, Parameter> parameters_;
75+
76+ private:
77+ void collect_all_parameters (std::unordered_map<std::string, Parameter> &all_params, const std::string &prefix = " " ) const ;
8878};
8979
9080// ============================================================================
9181// PyTorch-like Macros for Convenient Module Registration
9282// ============================================================================
9383
9484/* *
95- * @brief Register a submodule with automatic name inference from variable name
85+ * @brief Register submodules with automatic name inference from variable name
9686 *
9787 * Usage:
9888 * @code
9989 * class MyModel : public Module {
100- * INFINI_MODULE(Linear, layer1);
101- * INFINI_MODULE(Linear, layer2);
90+ * protected:
91+ * INFINICORE_NN_MODULE(Linear, layer1);
92+ * INFINICORE_NN_MODULE(Linear, layer2);
93+ * INFINICORE_NN_MODULE_VEC(Linear, layers);
10294 *
10395 * public:
10496 * MyModel() {
105- * INFINI_INIT_MODULE(layer1, 128, 64);
106- * INFINI_INIT_MODULE(layer2, 64, 32);
97+ * INFINICORE_NN_MODULE_INIT(layer1, 128, 64);
98+ * INFINICORE_NN_MODULE_INIT(layer2, 64, 32);
99+ * layers_ = register_modules<Linear>(3, "layers", 32, 16);
107100 * }
108101 * };
109102 * @endcode
110103 */
111104
112- // Declare a module member variable
113- #define INFINI_MODULE (ModuleType, name ) \
105+ // Declare a single module member variable
106+ #define INFINICORE_NN_MODULE (ModuleType, name ) \
114107 std::shared_ptr<ModuleType> name##_
115108
116- // Initialize a module (use in constructor)
117- #define INFINI_INIT_MODULE (name, ...) \
118- name##_ = register_module<std::remove_reference<decltype (*name##_)>::type>(#name, ##__VA_ARGS__)
109+ // Declare a vector of modules member variable
110+ #define INFINICORE_NN_MODULE_VEC (ModuleType, name ) \
111+ std::vector<std::shared_ptr<ModuleType>> name##_
112+
113+ // Initialize a module in constructor
114+ #define INFINICORE_NN_MODULE_INIT (name, ...) \
115+ name##_ = this ->register_module<std::remove_reference<decltype (*name##_)>::type>(#name, ##__VA_ARGS__)
119116
120- // Alternative: Combined declaration and initialization helper
121- // Useful when you want to initialize inline (though less flexible)
122- #define INFINI_SUBMODULE (ModuleType, name, ...) \
123- name##_ = register_module<ModuleType>(#name, ##__VA_ARGS__)
117+ // Initialize a vector of modules in constructor
118+ // Usage: INFINICORE_NN_MODULE_VEC_INIT(layers, count, ModuleType, ctor_args...)
119+ // Example: INFINICORE_NN_MODULE_VEC_INIT(layers, 3, Linear, 128, 64)
120+ #define INFINICORE_NN_MODULE_VEC_INIT (name, count, ModuleType, ...) \
121+ name##_ = this ->register_modules<ModuleType>(count, #name, ##__VA_ARGS__)
124122
125123} // namespace infinicore::nn
0 commit comments