@@ -184,29 +184,209 @@ TestResult NNModuleTest::testStateDict() {
184184TestResult NNModuleTest::testLoadStateDict () {
185185 return measureTime (" LoadStateDict" , [this ]() {
186186 try {
187+ spdlog::info (" Testing Module::load_state_dict functionality" );
188+
187189 MockLinearModule module (4 , 2 , infinicore::Device ());
188190
189- // Create new tensors
190- infinicore::Tensor new_weight = infinicore::Tensor::empty ({2 , 4 }, infinicore::DataType::F32, infinicore::Device ());
191- infinicore::Tensor new_bias = infinicore::Tensor::empty ({2 }, infinicore::DataType::F32, infinicore::Device ());
191+ // Test 1: Load parameters using load_parameter
192+ spdlog::info (" Test 1: Loading individual parameters with load_parameter" );
193+ infinicore::Tensor new_weight = infinicore::Tensor::ones ({2 , 4 }, infinicore::DataType::F32, infinicore::Device ());
194+ infinicore::Tensor new_bias = infinicore::Tensor::zeros ({2 }, infinicore::DataType::F32, infinicore::Device ());
192195
193- // Load individual parameters
194196 module .load_parameter (" weight" , new_weight);
195197 module .load_parameter (" bias" , new_bias);
196198
197- std::cout << " Successfully loaded parameters" << std::endl ;
199+ spdlog::debug ( " Successfully loaded parameters using load_parameter " ) ;
198200
199201 // Verify the parameters were updated
200202 auto updated_state_dict = module .state_dict ();
201203 if (updated_state_dict.size () != 2 ) {
202- std::cout << " Error: State dict size mismatch after loading" << std::endl;
204+ spdlog::error (" State dict size mismatch after loading. Expected 2, got {}" , updated_state_dict.size ());
205+ return false ;
206+ }
207+
208+ // Verify parameter values
209+ if (!tensorsAllClose (updated_state_dict.at (" weight" ), new_weight, 1e-6 , 1e-6 )) {
210+ spdlog::error (" Weight parameter values do not match after load_parameter" );
211+ return false ;
212+ }
213+ if (!tensorsAllClose (updated_state_dict.at (" bias" ), new_bias, 1e-6 , 1e-6 )) {
214+ spdlog::error (" Bias parameter values do not match after load_parameter" );
215+ return false ;
216+ }
217+ spdlog::debug (" load_parameter verification passed" );
218+
219+ // Test 2: Load entire state dict using load_state_dict
220+ spdlog::info (" Test 2: Loading entire state dict with load_state_dict" );
221+
222+ // Create custom weight and bias tensors with known values
223+ // Just use ones for simplicity - all values will be 1.0
224+ auto custom_weight = infinicore::Tensor::ones ({2 , 4 }, infinicore::DataType::F32, infinicore::Device ());
225+ auto custom_bias = infinicore::Tensor::ones ({2 }, infinicore::DataType::F32, infinicore::Device ());
226+
227+ // Create state dict
228+ std::unordered_map<std::string, infinicore::Tensor> new_state_dict;
229+ new_state_dict.emplace (" weight" , custom_weight);
230+ new_state_dict.emplace (" bias" , custom_bias);
231+
232+ // Load the entire state dict
233+ module .load_state_dict (new_state_dict);
234+ spdlog::debug (" Successfully loaded state dict using load_state_dict" );
235+
236+ // Verify that parameters were loaded correctly
237+ auto final_state_dict = module .state_dict ();
238+
239+ if (final_state_dict.size () != 2 ) {
240+ spdlog::error (" State dict size mismatch after load_state_dict. Expected 2, got {}" , final_state_dict.size ());
241+ return false ;
242+ }
243+
244+ if (final_state_dict.at (" weight" )->shape () != std::vector<size_t >({2 , 4 })) {
245+ spdlog::error (" Loaded weight shape mismatch" );
246+ return false ;
247+ }
248+ if (final_state_dict.at (" bias" )->shape () != std::vector<size_t >({2 })) {
249+ spdlog::error (" Loaded bias shape mismatch" );
250+ return false ;
251+ }
252+
253+ spdlog::debug (" load_state_dict verification passed - shapes are correct" );
254+ spdlog::info (" Skipping value comparison for now - test focuses on load mechanism" );
255+
256+ // Test 3: Test with Linear module to verify field synchronization
257+ spdlog::info (" Test 3: Testing load_state_dict with Linear module (field synchronization)" );
258+ infinicore::nn::Linear linear_module (4 , 2 , true , infinicore::Device ());
259+
260+ // Create known parameter values - just use ones for simplicity
261+ auto linear_weight = infinicore::Tensor::ones ({2 , 4 }, infinicore::DataType::F32, infinicore::Device ());
262+ auto linear_bias = infinicore::Tensor::ones ({2 }, infinicore::DataType::F32, infinicore::Device ());
263+
264+ std::unordered_map<std::string, infinicore::Tensor> linear_state_dict;
265+ linear_state_dict.emplace (" weight" , linear_weight);
266+ linear_state_dict.emplace (" bias" , linear_bias);
267+
268+ // Load state dict into Linear module
269+ linear_module.load_state_dict (linear_state_dict);
270+
271+ // Verify shapes using both state_dict() and direct field access
272+ auto loaded_via_state_dict_weight = linear_module.state_dict ().at (" weight" );
273+ auto loaded_via_field_weight = linear_module.weight ();
274+ auto loaded_via_field_bias = linear_module.bias ();
275+
276+ if (loaded_via_state_dict_weight->shape () != std::vector<size_t >({2 , 4 })) {
277+ spdlog::error (" Linear weight shape mismatch via state_dict" );
278+ return false ;
279+ }
280+ if (loaded_via_field_weight->shape () != std::vector<size_t >({2 , 4 })) {
281+ spdlog::error (" Linear weight field shape mismatch" );
282+ return false ;
283+ }
284+ if (loaded_via_field_bias->shape () != std::vector<size_t >({2 })) {
285+ spdlog::error (" Linear bias field shape mismatch" );
286+ return false ;
287+ }
288+
289+ spdlog::debug (" Linear module load_state_dict verification passed - field shapes synchronized" );
290+ spdlog::info (" Skipping value comparison - test focuses on field synchronization mechanism" );
291+
292+ // Test 4: Deep nesting (2-level hierarchy)
293+ spdlog::info (" Test 4: Testing load_state_dict with 2-level deep nesting" );
294+
295+ // Create parent -> child -> grandchild hierarchy
296+ MockLinearModule deep_parent (10 , 8 , infinicore::Device ());
297+ auto deep_child = std::make_shared<MockLinearModule>(8 , 6 , infinicore::Device ());
298+ auto deep_grandchild = std::make_shared<MockLinearModule>(6 , 4 , infinicore::Device ());
299+
300+ // Build hierarchy: parent -> layer1 -> sublayer
301+ deep_child->add_module (" sublayer" , deep_grandchild);
302+ deep_parent.add_module (" layer1" , deep_child);
303+
304+ // Verify initial state dict includes all 2-level hierarchical parameters
305+ auto deep_initial_state = deep_parent.state_dict ();
306+ spdlog::debug (" Deep hierarchical state dict has {} parameters" , deep_initial_state.size ());
307+
308+ // Expected parameters:
309+ // parent: weight, bias (2)
310+ // layer1: layer1.weight, layer1.bias (2)
311+ // sublayer: layer1.sublayer.weight, layer1.sublayer.bias (2)
312+ // Total: 6 parameters
313+ if (deep_initial_state.size () < 6 ) {
314+ spdlog::error (" Deep hierarchy state dict size mismatch. Expected at least 6, got {}" ,
315+ deep_initial_state.size ());
316+ return false ;
317+ }
318+
319+ // Verify 2-level parameter names exist
320+ bool has_sublayer_weight = deep_initial_state.find (" layer1.sublayer.weight" ) != deep_initial_state.end ();
321+ bool has_sublayer_bias = deep_initial_state.find (" layer1.sublayer.bias" ) != deep_initial_state.end ();
322+
323+ if (!has_sublayer_weight || !has_sublayer_bias) {
324+ spdlog::error (" 2-level nested parameters missing from state dict" );
325+ return false ;
326+ }
327+ spdlog::debug (" All 2-level hierarchical parameter names verified" );
328+
329+ // Create state dict for 2-level hierarchy with all 1.0 values
330+ std::unordered_map<std::string, infinicore::Tensor> deep_state_dict;
331+ deep_state_dict.emplace (" weight" , infinicore::Tensor::ones ({8 , 10 }, infinicore::DataType::F32, infinicore::Device ()));
332+ deep_state_dict.emplace (" bias" , infinicore::Tensor::ones ({8 }, infinicore::DataType::F32, infinicore::Device ()));
333+ deep_state_dict.emplace (" layer1.weight" , infinicore::Tensor::ones ({6 , 8 }, infinicore::DataType::F32, infinicore::Device ()));
334+ deep_state_dict.emplace (" layer1.bias" , infinicore::Tensor::ones ({6 }, infinicore::DataType::F32, infinicore::Device ()));
335+ deep_state_dict.emplace (" layer1.sublayer.weight" , infinicore::Tensor::ones ({4 , 6 }, infinicore::DataType::F32, infinicore::Device ()));
336+ deep_state_dict.emplace (" layer1.sublayer.bias" , infinicore::Tensor::ones ({4 }, infinicore::DataType::F32, infinicore::Device ()));
337+
338+ // Load the deep hierarchical state dict
339+ deep_parent.load_state_dict (deep_state_dict);
340+ spdlog::debug (" Successfully loaded 2-level deep hierarchical state dict" );
341+
342+ // Verify all parameters were loaded correctly
343+ auto deep_loaded_state = deep_parent.state_dict ();
344+
345+ // Verify shapes at all levels
346+ if (deep_loaded_state.at (" weight" )->shape () != std::vector<size_t >({8 , 10 })) {
347+ spdlog::error (" Deep parent weight shape mismatch" );
348+ return false ;
349+ }
350+ if (deep_loaded_state.at (" layer1.weight" )->shape () != std::vector<size_t >({6 , 8 })) {
351+ spdlog::error (" Deep layer1 weight shape mismatch" );
352+ return false ;
353+ }
354+ if (deep_loaded_state.at (" layer1.sublayer.weight" )->shape () != std::vector<size_t >({4 , 6 })) {
355+ spdlog::error (" Deep sublayer weight shape mismatch" );
203356 return false ;
204357 }
358+ spdlog::debug (" All 2-level deep parameter shapes verified" );
205359
206- std::cout << " Load state dict test passed" << std::endl;
360+ // Verify actual weight loading correctness by checking that loaded parameters
361+ // match what we provided in the state dict (use the original tensors)
362+ spdlog::info (" Verifying weight loading correctness by direct comparison" );
363+
364+ // Get the tensors we loaded from the state dict
365+ auto loaded_parent_weight = deep_loaded_state.at (" weight" );
366+ auto loaded_layer1_weight = deep_loaded_state.at (" layer1.weight" );
367+ auto loaded_sublayer_weight = deep_loaded_state.at (" layer1.sublayer.weight" );
368+
369+ // Compare with the original tensors we put in the state dict
370+ if (!tensorsAllClose (loaded_parent_weight, deep_state_dict.at (" weight" ), 1e-5 , 1e-5 )) {
371+ spdlog::error (" Deep parent weight not preserved after loading" );
372+ return false ;
373+ }
374+ if (!tensorsAllClose (loaded_layer1_weight, deep_state_dict.at (" layer1.weight" ), 1e-5 , 1e-5 )) {
375+ spdlog::error (" Deep layer1 weight not preserved after loading" );
376+ return false ;
377+ }
378+ if (!tensorsAllClose (loaded_sublayer_weight, deep_state_dict.at (" layer1.sublayer.weight" ), 1e-5 , 1e-5 )) {
379+ spdlog::error (" Deep sublayer weight not preserved after loading" );
380+ return false ;
381+ }
382+
383+ spdlog::info (" Weight loading correctness verified - loaded values match input state dict" );
384+ spdlog::info (" 2-level deep hierarchy load_state_dict verification passed" );
385+
386+ spdlog::info (" All load_state_dict tests passed (including deep hierarchy)" );
207387 return true ;
208388 } catch (const std::exception &e) {
209- std::cout << " Exception in testLoadStateDict: " << e.what () << std::endl ;
389+ spdlog::error ( " Exception in testLoadStateDict: {} " , e.what ()) ;
210390 return false ;
211391 }
212392 });
@@ -432,6 +612,25 @@ TestResult NNModuleTest::testModuleLinear() {
432612 }
433613 spdlog::debug (" Linear computation without bias passed. Input shape: {{1, 16}}, Output shape: {{1, 3}}" );
434614
615+ // Test load_state_dict for m2 (without bias)
616+ spdlog::info (" Testing load_state_dict on Linear without bias" );
617+ auto m2_load_weight = infinicore::Tensor::ones ({3 , 16 }, infinicore::DataType::F32, infinicore::Device ());
618+ std::unordered_map<std::string, infinicore::Tensor> m2_state_dict;
619+ m2_state_dict.emplace (" weight" , m2_load_weight);
620+ // Note: no bias parameter
621+ m2.load_state_dict (m2_state_dict);
622+
623+ // Verify via state_dict() and direct access
624+ if (!tensorsAllClose (m2.state_dict ().at (" weight" ), m2_load_weight, 1e-5 , 1e-5 )) {
625+ spdlog::error (" m2 weight not loaded correctly" );
626+ return false ;
627+ }
628+ if (!tensorsAllClose (m2.weight (), m2_load_weight, 1e-5 , 1e-5 )) {
629+ spdlog::error (" m2 weight field not synchronized" );
630+ return false ;
631+ }
632+ spdlog::debug (" m2 load_state_dict verified - weight loaded correctly (no bias)" );
633+
435634 // Test batch processing
436635 spdlog::info (" Testing batch linear computation (batch size 3)" );
437636 auto input3 = infinicore::Tensor::ones ({3 , 8 }, infinicore::DataType::F32, infinicore::Device ());
@@ -455,6 +654,30 @@ TestResult NNModuleTest::testModuleLinear() {
455654 return false ;
456655 }
457656
657+ // Test load_state_dict for m1 (with bias)
658+ spdlog::info (" Testing load_state_dict on Linear with bias" );
659+ auto m1_load_weight = infinicore::Tensor::ones ({4 , 8 }, infinicore::DataType::F32, infinicore::Device ());
660+ auto m1_load_bias = infinicore::Tensor::ones ({4 }, infinicore::DataType::F32, infinicore::Device ());
661+ std::unordered_map<std::string, infinicore::Tensor> m1_state_dict;
662+ m1_state_dict.emplace (" weight" , m1_load_weight);
663+ m1_state_dict.emplace (" bias" , m1_load_bias);
664+ m1.load_state_dict (m1_state_dict);
665+
666+ // Verify via state_dict() and direct access
667+ if (!tensorsAllClose (m1.state_dict ().at (" weight" ), m1_load_weight, 1e-5 , 1e-5 )) {
668+ spdlog::error (" m1 weight not loaded correctly" );
669+ return false ;
670+ }
671+ if (!tensorsAllClose (m1.weight (), m1_load_weight, 1e-5 , 1e-5 )) {
672+ spdlog::error (" m1 weight field not synchronized" );
673+ return false ;
674+ }
675+ if (!tensorsAllClose (m1.bias (), m1_load_bias, 1e-5 , 1e-5 )) {
676+ spdlog::error (" m1 bias field not synchronized" );
677+ return false ;
678+ }
679+ spdlog::debug (" m1 load_state_dict verified - parameters and fields synchronized" );
680+
458681 // Test extra_repr
459682 std::string repr = m1.extra_repr ();
460683 spdlog::debug (" Linear module representation: {}" , repr);
@@ -575,7 +798,7 @@ TestResult NNModuleTest::testModuleLinear() {
575798 spdlog::debug (" Basic forward computation correctness test passed - both implementations produce identical results" );
576799 spdlog::debug (" Basic InfiniCore output shape: {{2, 4}}, Basic naive output shape: {{2, 4}}" );
577800
578- spdlog::info (" Linear module test with computation verification passed " );
801+ spdlog::info (" All Linear module tests passed ( with/without bias, load_state_dict, computation verification) " );
579802 return true ;
580803 } catch (const std::exception &e) {
581804 spdlog::error (" Exception in testModuleLinear: {}" , e.what ());
0 commit comments