@@ -1287,141 +1287,6 @@ TestResult NNModuleTest::testModuleRoPE() {
12871287 });
12881288}
12891289
1290- // Test 7.75: SwiGLU module test
1291- TestResult NNModuleTest::testModuleSwiGLU () {
1292- return measureTime (" ModuleSwiGLU" , [this ]() {
1293- try {
1294- spdlog::info (" ==========================================" );
1295- spdlog::info (" Testing SwiGLU module implementation" );
1296- spdlog::info (" ==========================================" );
1297-
1298- // Test 1: Basic SwiGLU creation
1299- spdlog::info (" Test 1: Basic SwiGLU creation" );
1300- infinicore::nn::SwiGLU swiglu1;
1301-
1302- auto state1 = swiglu1.state_dict ();
1303- if (state1.size () != 0 ) {
1304- spdlog::error (" SwiGLU should have no parameters, got {}" , state1.size ());
1305- return false ;
1306- }
1307-
1308- spdlog::debug (" Basic SwiGLU creation passed" );
1309-
1310- // Test 2: Forward pass - 2D input [batch, hidden]
1311- spdlog::info (" Test 2: Forward pass with 2D input [batch, hidden]" );
1312- auto up_2d = infinicore::Tensor::ones ({4 , 128 }, infinicore::DataType::F32, infinicore::Device ());
1313- auto gate_2d = infinicore::Tensor::ones ({4 , 128 }, infinicore::DataType::F32, infinicore::Device ());
1314- auto output_2d = swiglu1.forward (up_2d, gate_2d);
1315-
1316- if (output_2d->shape () != std::vector<size_t >({4 , 128 })) {
1317- spdlog::error (" 2D output shape mismatch. Expected {{4, 128}}" );
1318- return false ;
1319- }
1320-
1321- spdlog::debug (" 2D forward pass passed. Output shape: {{4, 128}}" );
1322-
1323- // Test 3: Forward pass - 3D input [batch, seq_len, hidden]
1324- spdlog::info (" Test 3: Forward pass with 3D input [batch, seq_len, hidden]" );
1325- auto up_3d = infinicore::Tensor::ones ({2 , 10 , 128 }, infinicore::DataType::F32, infinicore::Device ());
1326- auto gate_3d = infinicore::Tensor::ones ({2 , 10 , 128 }, infinicore::DataType::F32, infinicore::Device ());
1327- auto output_3d = swiglu1.forward (up_3d, gate_3d);
1328-
1329- if (output_3d->shape () != std::vector<size_t >({2 , 10 , 128 })) {
1330- spdlog::error (" 3D output shape mismatch. Expected {{2, 10, 128}}" );
1331- return false ;
1332- }
1333-
1334- spdlog::debug (" 3D forward pass passed. Output shape: {{2, 10, 128}}" );
1335-
1336- // Test 4: Different tensor values
1337- spdlog::info (" Test 4: Testing with different tensor values" );
1338- auto up_test = infinicore::Tensor::ones ({2 , 64 }, infinicore::DataType::F32, infinicore::Device ());
1339- auto gate_test = infinicore::Tensor::zeros ({2 , 64 }, infinicore::DataType::F32, infinicore::Device ());
1340- auto output_test = swiglu1.forward (up_test, gate_test);
1341-
1342- if (output_test->shape () != up_test->shape ()) {
1343- spdlog::error (" Output shape doesn't match input shape" );
1344- return false ;
1345- }
1346-
1347- spdlog::debug (" Different tensor values test passed" );
1348-
1349- // Test 5: Shape mismatch validation (op layer handles this)
1350- spdlog::info (" Test 5: Testing shape mismatch validation" );
1351- auto up_shape = infinicore::Tensor::ones ({4 , 128 }, infinicore::DataType::F32, infinicore::Device ());
1352- auto gate_shape = infinicore::Tensor::ones ({4 , 64 }, infinicore::DataType::F32, infinicore::Device ());
1353-
1354- try {
1355- swiglu1.forward (up_shape, gate_shape);
1356- spdlog::error (" Should have thrown exception for shape mismatch" );
1357- return false ;
1358- } catch (const std::exception &e) {
1359- spdlog::debug (" Correctly caught exception for shape mismatch (handled by op layer): {}" , e.what ());
1360- } catch (...) {
1361- spdlog::error (" Caught unexpected exception type" );
1362- return false ;
1363- }
1364-
1365- spdlog::debug (" Shape mismatch validation test passed" );
1366-
1367- // Test 6: Dtype mismatch validation (op layer handles this)
1368- spdlog::info (" Test 6: Testing dtype mismatch validation" );
1369- auto up_dtype = infinicore::Tensor::ones ({4 , 128 }, infinicore::DataType::F32, infinicore::Device ());
1370- auto gate_dtype = infinicore::Tensor::ones ({4 , 128 }, infinicore::DataType::F16, infinicore::Device ());
1371-
1372- try {
1373- swiglu1.forward (up_dtype, gate_dtype);
1374- } catch (...) {
1375- spdlog::error (" Caught unexpected exception type" );
1376- return false ;
1377- }
1378-
1379- spdlog::debug (" Dtype mismatch validation test passed" );
1380-
1381- // Test 7: extra_repr
1382- spdlog::info (" Test 7: Testing extra_repr" );
1383- std::string repr = swiglu1.extra_repr ();
1384- spdlog::debug (" SwiGLU repr: {}" , repr);
1385-
1386- if (repr.find (" SwiGLU" ) == std::string::npos) {
1387- spdlog::error (" extra_repr should contain SwiGLU" );
1388- return false ;
1389- }
1390-
1391- spdlog::debug (" extra_repr test passed" );
1392-
1393- // Test 8: Different hidden sizes
1394- spdlog::info (" Test 8: Testing different hidden sizes" );
1395- auto up_small = infinicore::Tensor::ones ({2 , 64 }, infinicore::DataType::F32, infinicore::Device ());
1396- auto gate_small = infinicore::Tensor::ones ({2 , 64 }, infinicore::DataType::F32, infinicore::Device ());
1397- auto output_small = swiglu1.forward (up_small, gate_small);
1398-
1399- auto up_large = infinicore::Tensor::ones ({2 , 4096 }, infinicore::DataType::F32, infinicore::Device ());
1400- auto gate_large = infinicore::Tensor::ones ({2 , 4096 }, infinicore::DataType::F32, infinicore::Device ());
1401- auto output_large = swiglu1.forward (up_large, gate_large);
1402-
1403- if (output_small->shape () != std::vector<size_t >({2 , 64 })) {
1404- spdlog::error (" Small SwiGLU output shape mismatch" );
1405- return false ;
1406- }
1407-
1408- if (output_large->shape () != std::vector<size_t >({2 , 4096 })) {
1409- spdlog::error (" Large SwiGLU output shape mismatch" );
1410- return false ;
1411- }
1412-
1413- spdlog::debug (" Different hidden sizes test passed" );
1414-
1415- spdlog::info (" All SwiGLU module tests passed!" );
1416- return true ;
1417-
1418- } catch (const std::exception &e) {
1419- spdlog::error (" Exception in testModuleSwiGLU: {}" , e.what ());
1420- return false ;
1421- }
1422- });
1423- }
1424-
14251290// Test 8: Dtype assertion test
14261291TestResult NNModuleTest::testDtypeAssertion () {
14271292 return measureTime (" DtypeAssertionTest" , [this ]() {
@@ -1851,7 +1716,6 @@ TestResult NNModuleTest::run() {
18511716 results.push_back (testModuleEmbedding ()); // Embedding module test
18521717 results.push_back (testModuleRMSNorm ()); // RMSNorm module test
18531718 results.push_back (testModuleRoPE ()); // RoPE module test
1854- results.push_back (testModuleSwiGLU ()); // SwiGLU module test
18551719 results.push_back (testDtypeAssertion ()); // Dtype assertion test
18561720 results.push_back (testTinyLlamaConstruction ()); // Comprehensive: TinyLlama model test
18571721
0 commit comments