Skip to content

Commit ba89edc

Browse files
committed
fix module hierarchy traverse && impl torch-like MACROs
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent ed586f8 commit ba89edc

8 files changed

Lines changed: 563 additions & 216 deletions

File tree

include/infinicore/nn/module.hpp

Lines changed: 77 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
#include "parameter.hpp"
44

55
#include <unordered_map>
6+
#include <type_traits>
67

78
namespace infinicore::nn {
89
class Module {
910
public:
11+
Module() : path_("") {}
12+
1013
const std::unordered_map<std::string, Parameter> &state_dict() const;
1114

1215
void load_state_dict(const std::unordered_map<std::string, Tensor> &_state_dict);
@@ -17,37 +20,106 @@ class Module {
1720

1821
Tensor register_parameter(const std::string &name, Parameter param);
1922

20-
// Create a Linear submodule-like parameter set (weight and optional bias)
21-
// Mirrors torch.nn.Linear shapes: weight [out_features, in_features], bias [out_features]
22-
void linear(const std::string &name, size_t in_features, size_t out_features, bool bias = true);
23+
// Get the full path of this module in the hierarchy
24+
const std::string &path() const { return path_; }
2325

26+
// Add an existing submodule to this module's hierarchy
27+
// Template parameter M must be a type derived from Module
28+
// Returns the submodule for convenience (allows method chaining)
2429
template <typename M>
2530
std::shared_ptr<M> add_module(const std::string &name, std::shared_ptr<M> submodule) {
31+
// Ensure M is derived from Module (compile-time check)
32+
static_assert(std::is_base_of<Module, M>::value,
33+
"Template parameter M must be derived from infinicore::nn::Module");
34+
35+
// Set the submodule's path based on this module's path
36+
std::string submodule_path = path_.empty() ? name : path_ + "." + name;
37+
submodule->path_ = submodule_path;
38+
39+
// Recursively update paths of all descendants
40+
// This handles the case where submodules are added in any order
41+
submodule->update_submodule_paths();
42+
43+
// Store in the submodules map (std::shared_ptr<M> automatically converts to std::shared_ptr<Module>)
2644
submodules_[name] = submodule;
45+
2746
return submodule;
2847
}
2948

49+
// Create and register a new submodule by constructing it with the given arguments
50+
// Template parameter M must be a type derived from Module
51+
// Args are forwarded to M's constructor
3052
template <typename M, typename... Args>
3153
std::shared_ptr<M> register_module(const std::string &name, Args &&...args) {
54+
// Ensure M is derived from Module (compile-time check)
55+
static_assert(std::is_base_of<Module, M>::value,
56+
"Template parameter M must be derived from infinicore::nn::Module");
57+
58+
// Construct the submodule
3259
auto submodule = std::make_shared<M>(std::forward<Args>(args)...);
60+
3361
return add_module(name, submodule);
3462
}
3563

64+
// Create and register multiple submodules of the same type
65+
// Each submodule is named as "name.0", "name.1", etc.
66+
// Template parameter M must be a type derived from Module
3667
template <typename M, typename... Args>
3768
std::vector<std::shared_ptr<M>> register_modules(size_t layers, const std::string &name, Args &&...args) {
69+
static_assert(std::is_base_of<Module, M>::value,
70+
"Template parameter M must be derived from infinicore::nn::Module");
71+
3872
auto submodules = std::vector<std::shared_ptr<M>>(layers);
3973
for (size_t i = 0; i < layers; i++) {
40-
register_module<M>(name + "." + std::to_string(i), std::forward<Args>(args)...);
74+
submodules[i] = register_module<M>(name + "." + std::to_string(i), std::forward<Args>(args)...);
4175
}
4276
return submodules;
4377
}
4478

4579
private:
46-
void collect_all_parameters(const std::string &prefix, std::unordered_map<std::string, Parameter> &all_params) const;
80+
void collect_all_parameters(std::unordered_map<std::string, Parameter> &all_params) const;
81+
void update_submodule_paths();
4782

4883
protected:
84+
std::string path_; // Full path of this module in the hierarchy (e.g., "layer1.sublayer")
4985
Device device_;
5086
std::unordered_map<std::string, std::shared_ptr<Module>> submodules_;
5187
std::unordered_map<std::string, Parameter> parameters_;
5288
};
89+
90+
// ============================================================================
91+
// PyTorch-like Macros for Convenient Module Registration
92+
// ============================================================================
93+
94+
/**
95+
* @brief Register a submodule with automatic name inference from variable name
96+
*
97+
* Usage:
98+
* @code
99+
* class MyModel : public Module {
100+
* INFINI_MODULE(Linear, layer1);
101+
* INFINI_MODULE(Linear, layer2);
102+
*
103+
* public:
104+
* MyModel() {
105+
* INFINI_INIT_MODULE(layer1, 128, 64);
106+
* INFINI_INIT_MODULE(layer2, 64, 32);
107+
* }
108+
* };
109+
* @endcode
110+
*/
111+
112+
// Declare a module member variable
113+
#define INFINI_MODULE(ModuleType, name) \
114+
std::shared_ptr<ModuleType> name##_
115+
116+
// Initialize a module (use in constructor)
117+
#define INFINI_INIT_MODULE(name, ...) \
118+
name##_ = register_module<std::remove_reference<decltype(*name##_)>::type>(#name, ##__VA_ARGS__)
119+
120+
// Alternative: Combined declaration and initialization helper
121+
// Useful when you want to initialize inline (though less flexible)
122+
#define INFINI_SUBMODULE(ModuleType, name, ...) \
123+
name##_ = register_module<ModuleType>(#name, ##__VA_ARGS__)
124+
53125
} // namespace infinicore::nn

src/infinicore-test/main.cc

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,13 +208,29 @@ int main(int argc, char *argv[]) {
208208
auto results = runner.runAllTests();
209209
spdlog::debug("All tests completed");
210210

211-
// Count results
211+
// Count results and collect failed tests
212212
size_t passed = 0, failed = 0;
213+
std::vector<infinicore::test::TestResult> failed_tests;
213214
for (const auto &result : results) {
214215
if (result.passed) {
215216
passed++;
216217
} else {
217218
failed++;
219+
failed_tests.push_back(result);
220+
}
221+
}
222+
223+
// Print list of failed tests if any
224+
if (!failed_tests.empty()) {
225+
std::cout << "\n==============================================\n"
226+
<< "❌ FAILED TESTS\n"
227+
<< "==============================================" << std::endl;
228+
for (const auto &test : failed_tests) {
229+
std::cout << "" << test.test_name;
230+
if (!test.error_message.empty()) {
231+
std::cout << "\n Error: " << test.error_message;
232+
}
233+
std::cout << "\n Duration: " << test.duration.count() << "μs" << std::endl;
218234
}
219235
}
220236

@@ -229,7 +245,7 @@ int main(int argc, char *argv[]) {
229245

230246
// Exit with appropriate code
231247
if (failed > 0) {
232-
std::cout << "\n❌ Some tests failed. Please review the output above." << std::endl;
248+
std::cout << "\n❌ Some tests failed. Please review the failed tests list above." << std::endl;
233249
return EXIT_FAILURE;
234250
} else {
235251
std::cout << "\n✅ All tests passed!" << std::endl;

src/infinicore-test/memory_test.h

Lines changed: 7 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2,72 +2,17 @@
22
#define __INFINICORE_MEMORY_TEST_H__
33

44
#include "../infinicore/context/allocators/memory_allocator.hpp"
5+
#include "test_runner.h"
56
#include <atomic>
67
#include <cassert>
7-
#include <chrono>
8-
#include <exception>
98
#include <future>
10-
#include <infinicore.hpp>
11-
#include <iostream>
12-
#include <memory>
139
#include <mutex>
1410
#include <queue>
15-
#include <spdlog/spdlog.h>
1611
#include <thread>
1712
#include <unordered_map>
18-
#include <vector>
1913

2014
namespace infinicore::test {
2115

22-
// Test result structure
23-
struct TestResult {
24-
std::string test_name;
25-
bool passed;
26-
std::string error_message;
27-
std::chrono::microseconds duration;
28-
29-
TestResult(const std::string &name, bool pass, const std::string &error = "",
30-
std::chrono::microseconds dur = std::chrono::microseconds(0))
31-
: test_name(name), passed(pass), error_message(error), duration(dur) {}
32-
};
33-
34-
// Test framework base class
35-
class MemoryTestFramework {
36-
public:
37-
virtual ~MemoryTestFramework() = default;
38-
virtual TestResult run() = 0;
39-
virtual std::string getName() const = 0;
40-
41-
protected:
42-
void logTestStart(const std::string &test_name) {
43-
std::cout << "[TEST] Starting: " << test_name << std::endl;
44-
}
45-
46-
void logTestResult(const TestResult &result) {
47-
std::cout << "[TEST] " << (result.passed ? "PASSED" : "FAILED")
48-
<< ": " << result.test_name;
49-
if (!result.passed && !result.error_message.empty()) {
50-
std::cout << " - " << result.error_message;
51-
}
52-
std::cout << " (Duration: " << result.duration.count() << "μs)" << std::endl;
53-
}
54-
55-
template <typename Func>
56-
TestResult measureTime(const std::string &test_name, Func &&func) {
57-
auto start = std::chrono::high_resolution_clock::now();
58-
try {
59-
bool result = func();
60-
auto end = std::chrono::high_resolution_clock::now();
61-
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
62-
return TestResult(test_name, result, "", duration);
63-
} catch (const std::exception &e) {
64-
auto end = std::chrono::high_resolution_clock::now();
65-
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
66-
return TestResult(test_name, false, e.what(), duration);
67-
}
68-
}
69-
};
70-
7116
// Mock allocator for testing exception safety
7217
class MockAllocator : public infinicore::MemoryAllocator {
7318
public:
@@ -149,13 +94,13 @@ class MemoryLeakDetector {
14994
};
15095

15196
// Test categories
152-
class BasicMemoryTest : public MemoryTestFramework {
97+
class BasicMemoryTest : public TestFramework {
15398
public:
15499
TestResult run() override;
155100
std::string getName() const override { return "BasicMemoryTest"; }
156101
};
157102

158-
class ConcurrencyTest : public MemoryTestFramework {
103+
class ConcurrencyTest : public TestFramework {
159104
public:
160105
TestResult run() override;
161106
std::string getName() const override { return "ConcurrencyTest"; }
@@ -166,7 +111,7 @@ class ConcurrencyTest : public MemoryTestFramework {
166111
TestResult testMemoryAllocationRace();
167112
};
168113

169-
class ExceptionSafetyTest : public MemoryTestFramework {
114+
class ExceptionSafetyTest : public TestFramework {
170115
public:
171116
TestResult run() override;
172117
std::string getName() const override { return "ExceptionSafetyTest"; }
@@ -177,7 +122,7 @@ class ExceptionSafetyTest : public MemoryTestFramework {
177122
TestResult testContextSwitchException();
178123
};
179124

180-
class MemoryLeakTest : public MemoryTestFramework {
125+
class MemoryLeakTest : public TestFramework {
181126
public:
182127
TestResult run() override;
183128
std::string getName() const override { return "MemoryLeakTest"; }
@@ -188,7 +133,7 @@ class MemoryLeakTest : public MemoryTestFramework {
188133
TestResult testExceptionLeakDetection();
189134
};
190135

191-
class PerformanceTest : public MemoryTestFramework {
136+
class PerformanceTest : public TestFramework {
192137
public:
193138
TestResult run() override;
194139
std::string getName() const override { return "PerformanceTest"; }
@@ -199,7 +144,7 @@ class PerformanceTest : public MemoryTestFramework {
199144
TestResult testMemoryCopyPerformance();
200145
};
201146

202-
class StressTest : public MemoryTestFramework {
147+
class StressTest : public TestFramework {
203148
public:
204149
TestResult run() override;
205150
std::string getName() const override { return "StressTest"; }

0 commit comments

Comments
 (0)