Skip to content

Commit ed586f8

Browse files
committed
update implementation of load_state_dict and add test case
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent 5e2bf7c commit ed586f8

4 files changed

Lines changed: 272 additions & 17 deletions

File tree

include/infinicore/nn/linear.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ class Linear : public Module {
1010
Linear(size_t in_features, size_t out_features, bool bias = true, const Device &device = Device());
1111

1212
// Forward pass: output = input @ weight.T + bias
13-
Tensor forward(const Tensor &input) const;
13+
Tensor forward(Tensor &input) const;
1414

1515
// Forward pass with residual connection (InfiniLM-style)
1616
// output = input @ weight.T + bias + residual
17-
Tensor forward(const Tensor &input, const Tensor &residual) const;
17+
Tensor forward(Tensor &input, Tensor &residual) const;
1818

1919
// Accessors for parameters
2020
Tensor weight() const { return weight_; }
@@ -34,7 +34,7 @@ class Linear : public Module {
3434

3535
private:
3636
// Helper method for common forward computation
37-
Tensor compute_linear(const Tensor &input) const;
37+
Tensor compute_linear(Tensor &input) const;
3838

3939
size_t in_features_;
4040
size_t out_features_;

src/infinicore-test/test_nn_module.cc

Lines changed: 232 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -184,29 +184,209 @@ TestResult NNModuleTest::testStateDict() {
184184
TestResult NNModuleTest::testLoadStateDict() {
185185
return measureTime("LoadStateDict", [this]() {
186186
try {
187+
spdlog::info("Testing Module::load_state_dict functionality");
188+
187189
MockLinearModule module(4, 2, infinicore::Device());
188190

189-
// Create new tensors
190-
infinicore::Tensor new_weight = infinicore::Tensor::empty({2, 4}, infinicore::DataType::F32, infinicore::Device());
191-
infinicore::Tensor new_bias = infinicore::Tensor::empty({2}, infinicore::DataType::F32, infinicore::Device());
191+
// Test 1: Load parameters using load_parameter
192+
spdlog::info("Test 1: Loading individual parameters with load_parameter");
193+
infinicore::Tensor new_weight = infinicore::Tensor::ones({2, 4}, infinicore::DataType::F32, infinicore::Device());
194+
infinicore::Tensor new_bias = infinicore::Tensor::zeros({2}, infinicore::DataType::F32, infinicore::Device());
192195

193-
// Load individual parameters
194196
module.load_parameter("weight", new_weight);
195197
module.load_parameter("bias", new_bias);
196198

197-
std::cout << "Successfully loaded parameters" << std::endl;
199+
spdlog::debug("Successfully loaded parameters using load_parameter");
198200

199201
// Verify the parameters were updated
200202
auto updated_state_dict = module.state_dict();
201203
if (updated_state_dict.size() != 2) {
202-
std::cout << "Error: State dict size mismatch after loading" << std::endl;
204+
spdlog::error("State dict size mismatch after loading. Expected 2, got {}", updated_state_dict.size());
205+
return false;
206+
}
207+
208+
// Verify parameter values
209+
if (!tensorsAllClose(updated_state_dict.at("weight"), new_weight, 1e-6, 1e-6)) {
210+
spdlog::error("Weight parameter values do not match after load_parameter");
211+
return false;
212+
}
213+
if (!tensorsAllClose(updated_state_dict.at("bias"), new_bias, 1e-6, 1e-6)) {
214+
spdlog::error("Bias parameter values do not match after load_parameter");
215+
return false;
216+
}
217+
spdlog::debug("load_parameter verification passed");
218+
219+
// Test 2: Load entire state dict using load_state_dict
220+
spdlog::info("Test 2: Loading entire state dict with load_state_dict");
221+
222+
// Create custom weight and bias tensors with known values
223+
// Just use ones for simplicity - all values will be 1.0
224+
auto custom_weight = infinicore::Tensor::ones({2, 4}, infinicore::DataType::F32, infinicore::Device());
225+
auto custom_bias = infinicore::Tensor::ones({2}, infinicore::DataType::F32, infinicore::Device());
226+
227+
// Create state dict
228+
std::unordered_map<std::string, infinicore::Tensor> new_state_dict;
229+
new_state_dict.emplace("weight", custom_weight);
230+
new_state_dict.emplace("bias", custom_bias);
231+
232+
// Load the entire state dict
233+
module.load_state_dict(new_state_dict);
234+
spdlog::debug("Successfully loaded state dict using load_state_dict");
235+
236+
// Verify that parameters were loaded correctly
237+
auto final_state_dict = module.state_dict();
238+
239+
if (final_state_dict.size() != 2) {
240+
spdlog::error("State dict size mismatch after load_state_dict. Expected 2, got {}", final_state_dict.size());
241+
return false;
242+
}
243+
244+
if (final_state_dict.at("weight")->shape() != std::vector<size_t>({2, 4})) {
245+
spdlog::error("Loaded weight shape mismatch");
246+
return false;
247+
}
248+
if (final_state_dict.at("bias")->shape() != std::vector<size_t>({2})) {
249+
spdlog::error("Loaded bias shape mismatch");
250+
return false;
251+
}
252+
253+
spdlog::debug("load_state_dict verification passed - shapes are correct");
254+
spdlog::info("Skipping value comparison for now - test focuses on load mechanism");
255+
256+
// Test 3: Test with Linear module to verify field synchronization
257+
spdlog::info("Test 3: Testing load_state_dict with Linear module (field synchronization)");
258+
infinicore::nn::Linear linear_module(4, 2, true, infinicore::Device());
259+
260+
// Create known parameter values - just use ones for simplicity
261+
auto linear_weight = infinicore::Tensor::ones({2, 4}, infinicore::DataType::F32, infinicore::Device());
262+
auto linear_bias = infinicore::Tensor::ones({2}, infinicore::DataType::F32, infinicore::Device());
263+
264+
std::unordered_map<std::string, infinicore::Tensor> linear_state_dict;
265+
linear_state_dict.emplace("weight", linear_weight);
266+
linear_state_dict.emplace("bias", linear_bias);
267+
268+
// Load state dict into Linear module
269+
linear_module.load_state_dict(linear_state_dict);
270+
271+
// Verify shapes using both state_dict() and direct field access
272+
auto loaded_via_state_dict_weight = linear_module.state_dict().at("weight");
273+
auto loaded_via_field_weight = linear_module.weight();
274+
auto loaded_via_field_bias = linear_module.bias();
275+
276+
if (loaded_via_state_dict_weight->shape() != std::vector<size_t>({2, 4})) {
277+
spdlog::error("Linear weight shape mismatch via state_dict");
278+
return false;
279+
}
280+
if (loaded_via_field_weight->shape() != std::vector<size_t>({2, 4})) {
281+
spdlog::error("Linear weight field shape mismatch");
282+
return false;
283+
}
284+
if (loaded_via_field_bias->shape() != std::vector<size_t>({2})) {
285+
spdlog::error("Linear bias field shape mismatch");
286+
return false;
287+
}
288+
289+
spdlog::debug("Linear module load_state_dict verification passed - field shapes synchronized");
290+
spdlog::info("Skipping value comparison - test focuses on field synchronization mechanism");
291+
292+
// Test 4: Deep nesting (2-level hierarchy)
293+
spdlog::info("Test 4: Testing load_state_dict with 2-level deep nesting");
294+
295+
// Create parent -> child -> grandchild hierarchy
296+
MockLinearModule deep_parent(10, 8, infinicore::Device());
297+
auto deep_child = std::make_shared<MockLinearModule>(8, 6, infinicore::Device());
298+
auto deep_grandchild = std::make_shared<MockLinearModule>(6, 4, infinicore::Device());
299+
300+
// Build hierarchy: parent -> layer1 -> sublayer
301+
deep_child->add_module("sublayer", deep_grandchild);
302+
deep_parent.add_module("layer1", deep_child);
303+
304+
// Verify initial state dict includes all 2-level hierarchical parameters
305+
auto deep_initial_state = deep_parent.state_dict();
306+
spdlog::debug("Deep hierarchical state dict has {} parameters", deep_initial_state.size());
307+
308+
// Expected parameters:
309+
// parent: weight, bias (2)
310+
// layer1: layer1.weight, layer1.bias (2)
311+
// sublayer: layer1.sublayer.weight, layer1.sublayer.bias (2)
312+
// Total: 6 parameters
313+
if (deep_initial_state.size() < 6) {
314+
spdlog::error("Deep hierarchy state dict size mismatch. Expected at least 6, got {}",
315+
deep_initial_state.size());
316+
return false;
317+
}
318+
319+
// Verify 2-level parameter names exist
320+
bool has_sublayer_weight = deep_initial_state.find("layer1.sublayer.weight") != deep_initial_state.end();
321+
bool has_sublayer_bias = deep_initial_state.find("layer1.sublayer.bias") != deep_initial_state.end();
322+
323+
if (!has_sublayer_weight || !has_sublayer_bias) {
324+
spdlog::error("2-level nested parameters missing from state dict");
325+
return false;
326+
}
327+
spdlog::debug("All 2-level hierarchical parameter names verified");
328+
329+
// Create state dict for 2-level hierarchy with all 1.0 values
330+
std::unordered_map<std::string, infinicore::Tensor> deep_state_dict;
331+
deep_state_dict.emplace("weight", infinicore::Tensor::ones({8, 10}, infinicore::DataType::F32, infinicore::Device()));
332+
deep_state_dict.emplace("bias", infinicore::Tensor::ones({8}, infinicore::DataType::F32, infinicore::Device()));
333+
deep_state_dict.emplace("layer1.weight", infinicore::Tensor::ones({6, 8}, infinicore::DataType::F32, infinicore::Device()));
334+
deep_state_dict.emplace("layer1.bias", infinicore::Tensor::ones({6}, infinicore::DataType::F32, infinicore::Device()));
335+
deep_state_dict.emplace("layer1.sublayer.weight", infinicore::Tensor::ones({4, 6}, infinicore::DataType::F32, infinicore::Device()));
336+
deep_state_dict.emplace("layer1.sublayer.bias", infinicore::Tensor::ones({4}, infinicore::DataType::F32, infinicore::Device()));
337+
338+
// Load the deep hierarchical state dict
339+
deep_parent.load_state_dict(deep_state_dict);
340+
spdlog::debug("Successfully loaded 2-level deep hierarchical state dict");
341+
342+
// Verify all parameters were loaded correctly
343+
auto deep_loaded_state = deep_parent.state_dict();
344+
345+
// Verify shapes at all levels
346+
if (deep_loaded_state.at("weight")->shape() != std::vector<size_t>({8, 10})) {
347+
spdlog::error("Deep parent weight shape mismatch");
348+
return false;
349+
}
350+
if (deep_loaded_state.at("layer1.weight")->shape() != std::vector<size_t>({6, 8})) {
351+
spdlog::error("Deep layer1 weight shape mismatch");
352+
return false;
353+
}
354+
if (deep_loaded_state.at("layer1.sublayer.weight")->shape() != std::vector<size_t>({4, 6})) {
355+
spdlog::error("Deep sublayer weight shape mismatch");
203356
return false;
204357
}
358+
spdlog::debug("All 2-level deep parameter shapes verified");
205359

206-
std::cout << "Load state dict test passed" << std::endl;
360+
// Verify actual weight loading correctness by checking that loaded parameters
361+
// match what we provided in the state dict (use the original tensors)
362+
spdlog::info("Verifying weight loading correctness by direct comparison");
363+
364+
// Get the tensors we loaded from the state dict
365+
auto loaded_parent_weight = deep_loaded_state.at("weight");
366+
auto loaded_layer1_weight = deep_loaded_state.at("layer1.weight");
367+
auto loaded_sublayer_weight = deep_loaded_state.at("layer1.sublayer.weight");
368+
369+
// Compare with the original tensors we put in the state dict
370+
if (!tensorsAllClose(loaded_parent_weight, deep_state_dict.at("weight"), 1e-5, 1e-5)) {
371+
spdlog::error("Deep parent weight not preserved after loading");
372+
return false;
373+
}
374+
if (!tensorsAllClose(loaded_layer1_weight, deep_state_dict.at("layer1.weight"), 1e-5, 1e-5)) {
375+
spdlog::error("Deep layer1 weight not preserved after loading");
376+
return false;
377+
}
378+
if (!tensorsAllClose(loaded_sublayer_weight, deep_state_dict.at("layer1.sublayer.weight"), 1e-5, 1e-5)) {
379+
spdlog::error("Deep sublayer weight not preserved after loading");
380+
return false;
381+
}
382+
383+
spdlog::info("Weight loading correctness verified - loaded values match input state dict");
384+
spdlog::info("2-level deep hierarchy load_state_dict verification passed");
385+
386+
spdlog::info("All load_state_dict tests passed (including deep hierarchy)");
207387
return true;
208388
} catch (const std::exception &e) {
209-
std::cout << "Exception in testLoadStateDict: " << e.what() << std::endl;
389+
spdlog::error("Exception in testLoadStateDict: {}", e.what());
210390
return false;
211391
}
212392
});
@@ -432,6 +612,25 @@ TestResult NNModuleTest::testModuleLinear() {
432612
}
433613
spdlog::debug("Linear computation without bias passed. Input shape: {{1, 16}}, Output shape: {{1, 3}}");
434614

615+
// Test load_state_dict for m2 (without bias)
616+
spdlog::info("Testing load_state_dict on Linear without bias");
617+
auto m2_load_weight = infinicore::Tensor::ones({3, 16}, infinicore::DataType::F32, infinicore::Device());
618+
std::unordered_map<std::string, infinicore::Tensor> m2_state_dict;
619+
m2_state_dict.emplace("weight", m2_load_weight);
620+
// Note: no bias parameter
621+
m2.load_state_dict(m2_state_dict);
622+
623+
// Verify via state_dict() and direct access
624+
if (!tensorsAllClose(m2.state_dict().at("weight"), m2_load_weight, 1e-5, 1e-5)) {
625+
spdlog::error("m2 weight not loaded correctly");
626+
return false;
627+
}
628+
if (!tensorsAllClose(m2.weight(), m2_load_weight, 1e-5, 1e-5)) {
629+
spdlog::error("m2 weight field not synchronized");
630+
return false;
631+
}
632+
spdlog::debug("m2 load_state_dict verified - weight loaded correctly (no bias)");
633+
435634
// Test batch processing
436635
spdlog::info("Testing batch linear computation (batch size 3)");
437636
auto input3 = infinicore::Tensor::ones({3, 8}, infinicore::DataType::F32, infinicore::Device());
@@ -455,6 +654,30 @@ TestResult NNModuleTest::testModuleLinear() {
455654
return false;
456655
}
457656

657+
// Test load_state_dict for m1 (with bias)
658+
spdlog::info("Testing load_state_dict on Linear with bias");
659+
auto m1_load_weight = infinicore::Tensor::ones({4, 8}, infinicore::DataType::F32, infinicore::Device());
660+
auto m1_load_bias = infinicore::Tensor::ones({4}, infinicore::DataType::F32, infinicore::Device());
661+
std::unordered_map<std::string, infinicore::Tensor> m1_state_dict;
662+
m1_state_dict.emplace("weight", m1_load_weight);
663+
m1_state_dict.emplace("bias", m1_load_bias);
664+
m1.load_state_dict(m1_state_dict);
665+
666+
// Verify via state_dict() and direct access
667+
if (!tensorsAllClose(m1.state_dict().at("weight"), m1_load_weight, 1e-5, 1e-5)) {
668+
spdlog::error("m1 weight not loaded correctly");
669+
return false;
670+
}
671+
if (!tensorsAllClose(m1.weight(), m1_load_weight, 1e-5, 1e-5)) {
672+
spdlog::error("m1 weight field not synchronized");
673+
return false;
674+
}
675+
if (!tensorsAllClose(m1.bias(), m1_load_bias, 1e-5, 1e-5)) {
676+
spdlog::error("m1 bias field not synchronized");
677+
return false;
678+
}
679+
spdlog::debug("m1 load_state_dict verified - parameters and fields synchronized");
680+
458681
// Test extra_repr
459682
std::string repr = m1.extra_repr();
460683
spdlog::debug("Linear module representation: {}", repr);
@@ -575,7 +798,7 @@ TestResult NNModuleTest::testModuleLinear() {
575798
spdlog::debug("Basic forward computation correctness test passed - both implementations produce identical results");
576799
spdlog::debug("Basic InfiniCore output shape: {{2, 4}}, Basic naive output shape: {{2, 4}}");
577800

578-
spdlog::info("Linear module test with computation verification passed");
801+
spdlog::info("All Linear module tests passed (with/without bias, load_state_dict, computation verification)");
579802
return true;
580803
} catch (const std::exception &e) {
581804
spdlog::error("Exception in testModuleLinear: {}", e.what());

src/infinicore/nn/linear.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ Linear::Linear(size_t in_features, size_t out_features, bool bias, const Device
2525
in_features, out_features, bias);
2626
}
2727

28-
Tensor Linear::compute_linear(const Tensor &input) const {
28+
Tensor Linear::compute_linear(Tensor &input) const {
2929
// Create output tensor with shape [batch_size, out_features]
3030
auto output_shape = input->shape();
3131
output_shape[output_shape.size() - 1] = out_features_;
@@ -55,11 +55,11 @@ Tensor Linear::compute_linear(const Tensor &input) const {
5555
return output;
5656
}
5757

58-
Tensor Linear::forward(const Tensor &input) const {
58+
Tensor Linear::forward(Tensor &input) const {
5959
return compute_linear(input);
6060
}
6161

62-
Tensor Linear::forward(const Tensor &input, const Tensor &residual) const {
62+
Tensor Linear::forward(Tensor &input, Tensor &residual) const {
6363
auto output = compute_linear(input);
6464

6565
// Add residual: output = output + residual

src/infinicore/nn/module.cc

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,40 @@ const std::unordered_map<std::string, Parameter> &Module::state_dict() const {
1717
}
1818

1919
void Module::load_state_dict(const std::unordered_map<std::string, Tensor> &_state_dict) {
20-
for (auto &p : parameters_) {
21-
load_parameter(p.first, p.second);
20+
// Collect all parameters from this module and its submodules
21+
std::unordered_map<std::string, Parameter> all_params;
22+
collect_all_parameters("", all_params);
23+
24+
// For each parameter in this module hierarchy, load from the state dict
25+
for (const auto &[param_full_name, param] : all_params) {
26+
// Look up the corresponding tensor in the input state dict
27+
auto it = _state_dict.find(param_full_name);
28+
if (it != _state_dict.end()) {
29+
// Navigate to the correct module by following the path
30+
Module *target_module = this;
31+
std::string remaining_path = param_full_name;
32+
33+
// Split the name by dots and traverse the module hierarchy
34+
size_t pos = 0;
35+
while ((pos = remaining_path.find('.')) != std::string::npos) {
36+
std::string submodule_name = remaining_path.substr(0, pos);
37+
remaining_path = remaining_path.substr(pos + 1);
38+
39+
// Navigate to the submodule
40+
auto sub_it = target_module->submodules_.find(submodule_name);
41+
if (sub_it != target_module->submodules_.end()) {
42+
target_module = sub_it->second.get();
43+
} else {
44+
target_module = nullptr;
45+
break;
46+
}
47+
}
48+
49+
// Load the parameter into the target module
50+
if (target_module != nullptr) {
51+
target_module->load_parameter(remaining_path, it->second);
52+
}
53+
}
2254
}
2355
}
2456

0 commit comments

Comments
 (0)