33#include " parameter.hpp"
44
55#include < unordered_map>
6+ #include < type_traits>
67
78namespace infinicore ::nn {
89class Module {
910public:
11+ Module () : path_(" " ) {}
12+
1013 const std::unordered_map<std::string, Parameter> &state_dict () const ;
1114
1215 void load_state_dict (const std::unordered_map<std::string, Tensor> &_state_dict);
@@ -17,37 +20,106 @@ class Module {
1720
1821 Tensor register_parameter (const std::string &name, Parameter param);
1922
20- // Create a Linear submodule-like parameter set (weight and optional bias)
21- // Mirrors torch.nn.Linear shapes: weight [out_features, in_features], bias [out_features]
22- void linear (const std::string &name, size_t in_features, size_t out_features, bool bias = true );
23+ // Get the full path of this module in the hierarchy
24+ const std::string &path () const { return path_; }
2325
26+ // Add an existing submodule to this module's hierarchy
27+ // Template parameter M must be a type derived from Module
28+ // Returns the submodule for convenience (allows method chaining)
2429 template <typename M>
2530 std::shared_ptr<M> add_module (const std::string &name, std::shared_ptr<M> submodule) {
31+ // Ensure M is derived from Module (compile-time check)
32+ static_assert (std::is_base_of<Module, M>::value,
33+ " Template parameter M must be derived from infinicore::nn::Module" );
34+
35+ // Set the submodule's path based on this module's path
36+ std::string submodule_path = path_.empty () ? name : path_ + " ." + name;
37+ submodule->path_ = submodule_path;
38+
39+ // Recursively update paths of all descendants
40+ // This handles the case where submodules are added in any order
41+ submodule->update_submodule_paths ();
42+
43+ // Store in the submodules map (std::shared_ptr<M> automatically converts to std::shared_ptr<Module>)
2644 submodules_[name] = submodule;
45+
2746 return submodule;
2847 }
2948
49+ // Create and register a new submodule by constructing it with the given arguments
50+ // Template parameter M must be a type derived from Module
51+ // Args are forwarded to M's constructor
3052 template <typename M, typename ... Args>
3153 std::shared_ptr<M> register_module (const std::string &name, Args &&...args) {
54+ // Ensure M is derived from Module (compile-time check)
55+ static_assert (std::is_base_of<Module, M>::value,
56+ " Template parameter M must be derived from infinicore::nn::Module" );
57+
58+ // Construct the submodule
3259 auto submodule = std::make_shared<M>(std::forward<Args>(args)...);
60+
3361 return add_module (name, submodule);
3462 }
3563
64+ // Create and register multiple submodules of the same type
65+ // Each submodule is named as "name.0", "name.1", etc.
66+ // Template parameter M must be a type derived from Module
3667 template <typename M, typename ... Args>
3768 std::vector<std::shared_ptr<M>> register_modules (size_t layers, const std::string &name, Args &&...args) {
69+ static_assert (std::is_base_of<Module, M>::value,
70+ " Template parameter M must be derived from infinicore::nn::Module" );
71+
3872 auto submodules = std::vector<std::shared_ptr<M>>(layers);
3973 for (size_t i = 0 ; i < layers; i++) {
40- register_module<M>(name + " ." + std::to_string (i), std::forward<Args>(args)...);
74+ submodules[i] = register_module<M>(name + " ." + std::to_string (i), std::forward<Args>(args)...);
4175 }
4276 return submodules;
4377 }
4478
4579private:
46- void collect_all_parameters (const std::string &prefix, std::unordered_map<std::string, Parameter> &all_params) const ;
80+ void collect_all_parameters (std::unordered_map<std::string, Parameter> &all_params) const ;
81+ void update_submodule_paths ();
4782
4883protected:
84+ std::string path_; // Full path of this module in the hierarchy (e.g., "layer1.sublayer")
4985 Device device_;
5086 std::unordered_map<std::string, std::shared_ptr<Module>> submodules_;
5187 std::unordered_map<std::string, Parameter> parameters_;
5288};
89+
90+ // ============================================================================
91+ // PyTorch-like Macros for Convenient Module Registration
92+ // ============================================================================
93+
94+ /* *
95+ * @brief Register a submodule with automatic name inference from variable name
96+ *
97+ * Usage:
98+ * @code
99+ * class MyModel : public Module {
100+ * INFINI_MODULE(Linear, layer1);
101+ * INFINI_MODULE(Linear, layer2);
102+ *
103+ * public:
104+ * MyModel() {
105+ * INFINI_INIT_MODULE(layer1, 128, 64);
106+ * INFINI_INIT_MODULE(layer2, 64, 32);
107+ * }
108+ * };
109+ * @endcode
110+ */
111+
112+ // Declare a module member variable
113+ #define INFINI_MODULE (ModuleType, name ) \
114+ std::shared_ptr<ModuleType> name##_
115+
116+ // Initialize a module (use in constructor)
117+ #define INFINI_INIT_MODULE (name, ...) \
118+ name##_ = register_module<std::remove_reference<decltype (*name##_)>::type>(#name, ##__VA_ARGS__)
119+
120+ // Alternative: Combined declaration and initialization helper
121+ // Useful when you want to initialize inline (though less flexible)
122+ #define INFINI_SUBMODULE (ModuleType, name, ...) \
123+ name##_ = register_module<ModuleType>(#name, ##__VA_ARGS__)
124+
53125} // namespace infinicore::nn
0 commit comments