Skip to content

Commit fbb9d25

Browse files
committed
refractor module
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent ba89edc commit fbb9d25

2 files changed

Lines changed: 45 additions & 58 deletions

File tree

include/infinicore/nn/module.hpp

Lines changed: 34 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44

55
#include <unordered_map>
66
#include <type_traits>
7+
#include <vector>
78

89
namespace infinicore::nn {
910
class Module {
1011
public:
11-
Module() : path_("") {}
12+
Module() = default;
1213

1314
const std::unordered_map<std::string, Parameter> &state_dict() const;
1415

@@ -18,11 +19,9 @@ class Module {
1819

1920
void load_parameter_from_blob(const std::string &name, const void *data);
2021

22+
protected:
2123
Tensor register_parameter(const std::string &name, Parameter param);
2224

23-
// Get the full path of this module in the hierarchy
24-
const std::string &path() const { return path_; }
25-
2625
// Add an existing submodule to this module's hierarchy
2726
// Template parameter M must be a type derived from Module
2827
// Returns the submodule for convenience (allows method chaining)
@@ -32,14 +31,6 @@ class Module {
3231
static_assert(std::is_base_of<Module, M>::value,
3332
"Template parameter M must be derived from infinicore::nn::Module");
3433

35-
// Set the submodule's path based on this module's path
36-
std::string submodule_path = path_.empty() ? name : path_ + "." + name;
37-
submodule->path_ = submodule_path;
38-
39-
// Recursively update paths of all descendants
40-
// This handles the case where submodules are added in any order
41-
submodule->update_submodule_paths();
42-
4334
// Store in the submodules map (std::shared_ptr<M> automatically converts to std::shared_ptr<Module>)
4435
submodules_[name] = submodule;
4536

@@ -65,61 +56,68 @@ class Module {
6556
// Each submodule is named as "name.0", "name.1", etc.
6657
// Template parameter M must be a type derived from Module
6758
template <typename M, typename... Args>
68-
std::vector<std::shared_ptr<M>> register_modules(size_t layers, const std::string &name, Args &&...args) {
59+
std::vector<std::shared_ptr<M>> register_modules(size_t count, const std::string &name, Args &&...args) {
6960
static_assert(std::is_base_of<Module, M>::value,
7061
"Template parameter M must be derived from infinicore::nn::Module");
7162

72-
auto submodules = std::vector<std::shared_ptr<M>>(layers);
73-
for (size_t i = 0; i < layers; i++) {
74-
submodules[i] = register_module<M>(name + "." + std::to_string(i), std::forward<Args>(args)...);
63+
std::vector<std::shared_ptr<M>> modules;
64+
modules.reserve(count);
65+
for (size_t i = 0; i < count; i++) {
66+
modules.push_back(register_module<M>(name + "." + std::to_string(i), std::forward<Args>(args)...));
7567
}
76-
return submodules;
68+
return modules;
7769
}
7870

79-
private:
80-
void collect_all_parameters(std::unordered_map<std::string, Parameter> &all_params) const;
81-
void update_submodule_paths();
82-
8371
protected:
84-
std::string path_; // Full path of this module in the hierarchy (e.g., "layer1.sublayer")
8572
Device device_;
8673
std::unordered_map<std::string, std::shared_ptr<Module>> submodules_;
8774
std::unordered_map<std::string, Parameter> parameters_;
75+
76+
private:
77+
void collect_all_parameters(std::unordered_map<std::string, Parameter> &all_params, const std::string &prefix = "") const;
8878
};
8979

9080
// ============================================================================
9181
// PyTorch-like Macros for Convenient Module Registration
9282
// ============================================================================
9383

9484
/**
95-
* @brief Register a submodule with automatic name inference from variable name
85+
* @brief Register submodules with automatic name inference from variable name
9686
*
9787
* Usage:
9888
* @code
9989
* class MyModel : public Module {
100-
* INFINI_MODULE(Linear, layer1);
101-
* INFINI_MODULE(Linear, layer2);
90+
* protected:
91+
* INFINICORE_NN_MODULE(Linear, layer1);
92+
* INFINICORE_NN_MODULE(Linear, layer2);
93+
* INFINICORE_NN_MODULE_VEC(Linear, layers);
10294
*
10395
* public:
10496
* MyModel() {
105-
* INFINI_INIT_MODULE(layer1, 128, 64);
106-
* INFINI_INIT_MODULE(layer2, 64, 32);
97+
* INFINICORE_NN_MODULE_INIT(layer1, 128, 64);
98+
* INFINICORE_NN_MODULE_INIT(layer2, 64, 32);
99+
* layers_ = register_modules<Linear>(3, "layers", 32, 16);
107100
* }
108101
* };
109102
* @endcode
110103
*/
111104

112-
// Declare a module member variable
113-
#define INFINI_MODULE(ModuleType, name) \
105+
// Declare a single module member variable
106+
#define INFINICORE_NN_MODULE(ModuleType, name) \
114107
std::shared_ptr<ModuleType> name##_
115108

116-
// Initialize a module (use in constructor)
117-
#define INFINI_INIT_MODULE(name, ...) \
118-
name##_ = register_module<std::remove_reference<decltype(*name##_)>::type>(#name, ##__VA_ARGS__)
109+
// Declare a vector of modules member variable
110+
#define INFINICORE_NN_MODULE_VEC(ModuleType, name) \
111+
std::vector<std::shared_ptr<ModuleType>> name##_
112+
113+
// Initialize a module in constructor
114+
#define INFINICORE_NN_MODULE_INIT(name, ...) \
115+
name##_ = this->register_module<std::remove_reference<decltype(*name##_)>::type>(#name, ##__VA_ARGS__)
119116

120-
// Alternative: Combined declaration and initialization helper
121-
// Useful when you want to initialize inline (though less flexible)
122-
#define INFINI_SUBMODULE(ModuleType, name, ...) \
123-
name##_ = register_module<ModuleType>(#name, ##__VA_ARGS__)
117+
// Initialize a vector of modules in constructor
118+
// Usage: INFINICORE_NN_MODULE_VEC_INIT(layers, count, ModuleType, ctor_args...)
119+
// Example: INFINICORE_NN_MODULE_VEC_INIT(layers, 3, Linear, 128, 64)
120+
#define INFINICORE_NN_MODULE_VEC_INIT(name, count, ModuleType, ...) \
121+
name##_ = this->register_modules<ModuleType>(count, #name, ##__VA_ARGS__)
124122

125123
} // namespace infinicore::nn

src/infinicore/nn/module.cc

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@ const std::unordered_map<std::string, Parameter> &Module::state_dict() const {
55
static std::unordered_map<std::string, Parameter> result;
66
result.clear();
77

8-
collect_all_parameters(result);
8+
collect_all_parameters(result, "");
99

1010
return result;
1111
}
1212

1313
void Module::load_state_dict(const std::unordered_map<std::string, Tensor> &_state_dict) {
1414
// Collect all parameters from this module and its submodules with their full hierarchical names
1515
std::unordered_map<std::string, Parameter> all_params;
16-
collect_all_parameters(all_params);
16+
collect_all_parameters(all_params, "");
1717

1818
// For each parameter in this module hierarchy, load from the state dict
1919
for (auto &[param_full_name, param] : all_params) {
@@ -39,28 +39,17 @@ Tensor Module::register_parameter(const std::string &name, Parameter param) {
3939
return param;
4040
}
4141

42-
void Module::collect_all_parameters(std::unordered_map<std::string, Parameter> &all_params) const {
43-
// Add direct parameters with this module's path as prefix
44-
for (const auto &p : parameters_) {
45-
std::string param_name = path_.empty() ? p.first : path_ + "." + p.first;
46-
all_params[param_name] = p.second;
42+
void Module::collect_all_parameters(std::unordered_map<std::string, Parameter> &all_params, const std::string &prefix) const {
43+
// Add direct parameters with the given prefix
44+
for (const auto &[param_name, param] : parameters_) {
45+
std::string full_name = prefix.empty() ? param_name : prefix + "." + param_name;
46+
all_params[full_name] = param;
4747
}
4848

49-
// Recursively collect parameters from submodules
50-
// Each submodule already knows its own path, so no need to pass prefix
51-
for (const auto &sub : submodules_) {
52-
sub.second->collect_all_parameters(all_params);
53-
}
54-
}
55-
56-
void Module::update_submodule_paths() {
57-
// Recursively update the paths of all submodules based on this module's current path
58-
for (auto &[name, submodule] : submodules_) {
59-
// Update submodule's path based on this module's current path
60-
std::string submodule_path = path_.empty() ? name : path_ + "." + name;
61-
submodule->path_ = submodule_path;
62-
// Recursively update descendants
63-
submodule->update_submodule_paths();
49+
// Recursively collect parameters from submodules with extended prefix
50+
for (const auto &[sub_name, submodule] : submodules_) {
51+
std::string sub_prefix = prefix.empty() ? sub_name : prefix + "." + sub_name;
52+
submodule->collect_all_parameters(all_params, sub_prefix);
6453
}
6554
}
6655

0 commit comments

Comments
 (0)