@@ -394,7 +394,7 @@ TestResult NNModuleTest::testModuleLinear() {
394394 try {
395395 // Test with bias
396396 spdlog::info (" Testing Linear module with bias (8->4 features)" );
397- infinicore::nn::Linear m1 (8 , 4 , true , infinicore::Device () );
397+ infinicore::nn::Linear m1 (8 , 4 , true );
398398 auto sd1 = m1.state_dict ();
399399 if (sd1.find (" weight" ) == sd1.end ()) {
400400 spdlog::error (" weight missing" );
@@ -440,7 +440,7 @@ TestResult NNModuleTest::testModuleLinear() {
440440
441441 // Test without bias
442442 spdlog::info (" Testing Linear module without bias (16->3 features)" );
443- infinicore::nn::Linear m2 (16 , 3 , false , infinicore::Device () );
443+ infinicore::nn::Linear m2 (16 , 3 , false );
444444 auto sd2 = m2.state_dict ();
445445 if (sd2.find (" weight" ) == sd2.end ()) {
446446 spdlog::error (" weight missing (no-bias)" );
@@ -834,7 +834,7 @@ TestResult NNModuleTest::testModuleRMSNorm() {
834834
835835 // Test 1: Basic RMSNorm creation
836836 spdlog::info (" Test 1: Basic RMSNorm creation (hidden_size=768)" );
837- infinicore::nn::RMSNorm norm1 (768 , 1e-6 , infinicore::Device () );
837+ infinicore::nn::RMSNorm norm1 (768 );
838838
839839 auto state1 = norm1.state_dict ();
840840 if (state1.find (" weight" ) == state1.end ()) {
@@ -925,8 +925,8 @@ TestResult NNModuleTest::testModuleRMSNorm() {
925925
926926 // Test 7: Different hidden sizes
927927 spdlog::info (" Test 7: Testing different hidden sizes" );
928- infinicore::nn::RMSNorm norm_small (128 , 1e-5 , infinicore::Device () );
929- infinicore::nn::RMSNorm norm_large (4096 , 1e-6 , infinicore::Device () );
928+ infinicore::nn::RMSNorm norm_small (128 , 1e-5 );
929+ infinicore::nn::RMSNorm norm_large (4096 );
930930
931931 auto input_small = infinicore::Tensor::ones ({2 , 128 }, infinicore::DataType::F32 , infinicore::Device ());
932932 auto output_small = norm_small.forward (input_small);
@@ -956,7 +956,130 @@ TestResult NNModuleTest::testModuleRMSNorm() {
956956 });
957957}
958958
959- // Test 8: Comprehensive Tiny-Llama model test (construction + weight loading + validation)
959+ // Test 8: Dtype assertion test
960+ TestResult NNModuleTest::testDtypeAssertion () {
961+ return measureTime (" DtypeAssertionTest" , [this ]() {
962+ try {
963+ spdlog::info (" Testing dtype assertions when loading parameters" );
964+
965+ // Test 1: Successful load with matching dtype
966+ spdlog::info (" Test 1: Successful load with matching dtype (F32)" );
967+ infinicore::nn::Linear linear1 (8 , 4 , true );
968+ auto matching_weight = infinicore::Tensor::ones ({4 , 8 }, infinicore::DataType::F32 , infinicore::Device ());
969+ auto matching_bias = infinicore::Tensor::ones ({4 }, infinicore::DataType::F32 , infinicore::Device ());
970+
971+ std::unordered_map<std::string, infinicore::Tensor> matching_state;
972+ matching_state.emplace (" weight" , matching_weight);
973+ matching_state.emplace (" bias" , matching_bias);
974+
975+ // This should succeed without throwing
976+ linear1.load_state_dict (matching_state);
977+ spdlog::debug (" ✓ Matching dtype load succeeded" );
978+
979+ // Test 2: Failed load with mismatched dtype (load_parameter)
980+ spdlog::info (" Test 2: Failed load_parameter with mismatched dtype" );
981+ infinicore::nn::Linear linear2 (8 , 4 , true );
982+ auto mismatched_weight = infinicore::Tensor::ones ({4 , 8 }, infinicore::DataType::BF16 , infinicore::Device ());
983+
984+ bool exception_thrown = false ;
985+ try {
986+ linear2.load_parameter (" weight" , mismatched_weight);
987+ } catch (const std::runtime_error &e) {
988+ exception_thrown = true ;
989+ std::string error_msg = e.what ();
990+ if (error_msg.find (" dtype mismatch" ) == std::string::npos) {
991+ spdlog::error (" Exception message doesn't contain 'dtype mismatch'" );
992+ return false ;
993+ }
994+ spdlog::debug (" ✓ Mismatched dtype exception caught: {}" , error_msg);
995+ }
996+
997+ if (!exception_thrown) {
998+ spdlog::error (" Expected exception for dtype mismatch in load_parameter" );
999+ return false ;
1000+ }
1001+
1002+ // Test 3: Failed load with mismatched dtype (load_state_dict)
1003+ spdlog::info (" Test 3: Failed load_state_dict with mismatched dtype" );
1004+ infinicore::nn::Embedding embedding1 (100 , 64 );
1005+ auto mismatched_embed_weight = infinicore::Tensor::ones ({100 , 64 }, infinicore::DataType::BF16 , infinicore::Device ());
1006+
1007+ std::unordered_map<std::string, infinicore::Tensor> mismatched_state;
1008+ mismatched_state.emplace (" weight" , mismatched_embed_weight);
1009+
1010+ exception_thrown = false ;
1011+ try {
1012+ embedding1.load_state_dict (mismatched_state);
1013+ } catch (const std::runtime_error &e) {
1014+ exception_thrown = true ;
1015+ std::string error_msg = e.what ();
1016+ if (error_msg.find (" dtype mismatch" ) == std::string::npos) {
1017+ spdlog::error (" Exception message doesn't contain 'dtype mismatch'" );
1018+ return false ;
1019+ }
1020+ if (error_msg.find (" weight" ) == std::string::npos) {
1021+ spdlog::error (" Exception message doesn't contain parameter name 'weight'" );
1022+ return false ;
1023+ }
1024+ spdlog::debug (" ✓ Mismatched dtype exception caught: {}" , error_msg);
1025+ }
1026+
1027+ if (!exception_thrown) {
1028+ spdlog::error (" Expected exception for dtype mismatch in load_state_dict" );
1029+ return false ;
1030+ }
1031+
1032+ // Test 4: Failed load with mismatched dtype (RMSNorm)
1033+ spdlog::info (" Test 4: Failed load_state_dict with mismatched dtype (RMSNorm)" );
1034+ infinicore::nn::RMSNorm norm1 (768 );
1035+ auto mismatched_norm_weight = infinicore::Tensor::ones ({768 }, infinicore::DataType::BF16 , infinicore::Device ());
1036+
1037+ std::unordered_map<std::string, infinicore::Tensor> mismatched_norm_state;
1038+ mismatched_norm_state.emplace (" weight" , mismatched_norm_weight);
1039+
1040+ exception_thrown = false ;
1041+ try {
1042+ norm1.load_state_dict (mismatched_norm_state);
1043+ } catch (const std::runtime_error &e) {
1044+ exception_thrown = true ;
1045+ std::string error_msg = e.what ();
1046+ if (error_msg.find (" dtype mismatch" ) == std::string::npos) {
1047+ spdlog::error (" Exception message doesn't contain 'dtype mismatch'" );
1048+ return false ;
1049+ }
1050+ spdlog::debug (" ✓ Mismatched dtype exception caught for RMSNorm: {}" , error_msg);
1051+ }
1052+
1053+ if (!exception_thrown) {
1054+ spdlog::error (" Expected exception for dtype mismatch in RMSNorm load_state_dict" );
1055+ return false ;
1056+ }
1057+
1058+ // Test 5: Successful load with different module dtypes
1059+ spdlog::info (" Test 5: Successful load with BF16 dtype (module created with BF16)" );
1060+ infinicore::nn::Linear linear3 (8 , 4 , true , infinicore::DataType::BF16 );
1061+ auto bf16_weight = infinicore::Tensor::ones ({4 , 8 }, infinicore::DataType::BF16 , infinicore::Device ());
1062+ auto bf16_bias = infinicore::Tensor::ones ({4 }, infinicore::DataType::BF16 , infinicore::Device ());
1063+
1064+ std::unordered_map<std::string, infinicore::Tensor> bf16_state;
1065+ bf16_state.emplace (" weight" , bf16_weight);
1066+ bf16_state.emplace (" bias" , bf16_bias);
1067+
1068+ // This should succeed
1069+ linear3.load_state_dict (bf16_state);
1070+ spdlog::debug (" ✓ BF16 dtype load succeeded" );
1071+
1072+ spdlog::info (" All dtype assertion tests passed!" );
1073+ return true ;
1074+
1075+ } catch (const std::exception &e) {
1076+ spdlog::error (" Exception in testDtypeAssertion: {}" , e.what ());
1077+ return false ;
1078+ }
1079+ });
1080+ }
1081+
1082+ // Test 9: Comprehensive Tiny-Llama model test (construction + weight loading + validation)
9601083TestResult NNModuleTest::testTinyLlamaConstruction () {
9611084 return measureTime (" TinyLlamaModelTest" , [this ]() {
9621085 try {
@@ -1007,10 +1130,10 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
10071130 INFINICORE_NN_MODULE (infinicore::nn::Linear, o_proj);
10081131
10091132 SelfAttn (size_t hidden_size, size_t kv_dim, const infinicore::Device &device) {
1010- INFINICORE_NN_MODULE_INIT (q_proj, hidden_size, hidden_size, false , device);
1011- INFINICORE_NN_MODULE_INIT (k_proj, hidden_size, kv_dim, false , device);
1012- INFINICORE_NN_MODULE_INIT (v_proj, hidden_size, kv_dim, false , device);
1013- INFINICORE_NN_MODULE_INIT (o_proj, hidden_size, hidden_size, false , device);
1133+ INFINICORE_NN_MODULE_INIT (q_proj, hidden_size, hidden_size, false , infinicore::DataType:: F32 , device);
1134+ INFINICORE_NN_MODULE_INIT (k_proj, hidden_size, kv_dim, false , infinicore::DataType:: F32 , device);
1135+ INFINICORE_NN_MODULE_INIT (v_proj, hidden_size, kv_dim, false , infinicore::DataType:: F32 , device);
1136+ INFINICORE_NN_MODULE_INIT (o_proj, hidden_size, hidden_size, false , infinicore::DataType:: F32 , device);
10141137 }
10151138 };
10161139
@@ -1021,9 +1144,9 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
10211144 INFINICORE_NN_MODULE (infinicore::nn::Linear, down_proj);
10221145
10231146 MLP (size_t hidden_size, size_t intermediate_size, const infinicore::Device &device) {
1024- INFINICORE_NN_MODULE_INIT (gate_proj, hidden_size, intermediate_size, false , device);
1025- INFINICORE_NN_MODULE_INIT (up_proj, hidden_size, intermediate_size, false , device);
1026- INFINICORE_NN_MODULE_INIT (down_proj, intermediate_size, hidden_size, false , device);
1147+ INFINICORE_NN_MODULE_INIT (gate_proj, hidden_size, intermediate_size, false , infinicore::DataType:: F32 , device);
1148+ INFINICORE_NN_MODULE_INIT (up_proj, hidden_size, intermediate_size, false , infinicore::DataType:: F32 , device);
1149+ INFINICORE_NN_MODULE_INIT (down_proj, intermediate_size, hidden_size, false , infinicore::DataType:: F32 , device);
10271150 }
10281151 };
10291152
@@ -1036,9 +1159,9 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
10361159
10371160 Block (const TinyLlamaConfig &cfg, const infinicore::Device &device) {
10381161 size_t kv_dim = cfg.hidden_size * cfg.num_key_value_heads / cfg.num_attention_heads ;
1039- INFINICORE_NN_MODULE_INIT (input_layernorm, cfg.hidden_size , cfg.rms_norm_eps , device);
1162+ INFINICORE_NN_MODULE_INIT (input_layernorm, cfg.hidden_size , cfg.rms_norm_eps , infinicore::DataType:: F32 , device);
10401163 INFINICORE_NN_MODULE_INIT (self_attn, cfg.hidden_size , kv_dim, device);
1041- INFINICORE_NN_MODULE_INIT (post_attention_layernorm, cfg.hidden_size , cfg.rms_norm_eps , device);
1164+ INFINICORE_NN_MODULE_INIT (post_attention_layernorm, cfg.hidden_size , cfg.rms_norm_eps , infinicore::DataType:: F32 , device);
10421165 INFINICORE_NN_MODULE_INIT (mlp, cfg.hidden_size , cfg.intermediate_size , device);
10431166 }
10441167 };
@@ -1051,7 +1174,7 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
10511174 TinyLlamaModel (const TinyLlamaConfig &config, const infinicore::Device &device) {
10521175 INFINICORE_NN_MODULE_INIT (embed_tokens, config.vocab_size , config.hidden_size , std::nullopt , infinicore::DataType::F32 , device);
10531176 INFINICORE_NN_MODULE_VEC_INIT (layers, config.num_hidden_layers , Block, config, device);
1054- INFINICORE_NN_MODULE_INIT (norm, config.hidden_size , config.rms_norm_eps , device);
1177+ INFINICORE_NN_MODULE_INIT (norm, config.hidden_size , config.rms_norm_eps , infinicore::DataType:: F32 , device);
10551178 }
10561179 };
10571180
@@ -1259,6 +1382,7 @@ TestResult NNModuleTest::run() {
12591382 results.push_back (testModuleLinear ()); // Linear module comprehensive test
12601383 results.push_back (testModuleEmbedding ()); // Embedding module test
12611384 results.push_back (testModuleRMSNorm ()); // RMSNorm module test
1385+ results.push_back (testDtypeAssertion ()); // Dtype assertion test
12621386 results.push_back (testTinyLlamaConstruction ()); // Comprehensive: TinyLlama model test
12631387
12641388 // Check if all tests passed
0 commit comments