Skip to content

Commit 51711b5

Browse files
committed
remove nn::functional
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent 8d83173 commit 51711b5

6 files changed

Lines changed: 0 additions & 258 deletions

File tree

include/infinicore/nn/functional.hpp

Lines changed: 0 additions & 45 deletions
This file was deleted.

include/infinicore/nn/swiglu.hpp

Lines changed: 0 additions & 35 deletions
This file was deleted.

src/infinicore-test/test_nn_module.cc

Lines changed: 0 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -1287,141 +1287,6 @@ TestResult NNModuleTest::testModuleRoPE() {
12871287
});
12881288
}
12891289

1290-
// Test 7.75: SwiGLU module test
1291-
TestResult NNModuleTest::testModuleSwiGLU() {
1292-
return measureTime("ModuleSwiGLU", [this]() {
1293-
try {
1294-
spdlog::info("==========================================");
1295-
spdlog::info("Testing SwiGLU module implementation");
1296-
spdlog::info("==========================================");
1297-
1298-
// Test 1: Basic SwiGLU creation
1299-
spdlog::info("Test 1: Basic SwiGLU creation");
1300-
infinicore::nn::SwiGLU swiglu1;
1301-
1302-
auto state1 = swiglu1.state_dict();
1303-
if (state1.size() != 0) {
1304-
spdlog::error("SwiGLU should have no parameters, got {}", state1.size());
1305-
return false;
1306-
}
1307-
1308-
spdlog::debug("Basic SwiGLU creation passed");
1309-
1310-
// Test 2: Forward pass - 2D input [batch, hidden]
1311-
spdlog::info("Test 2: Forward pass with 2D input [batch, hidden]");
1312-
auto up_2d = infinicore::Tensor::ones({4, 128}, infinicore::DataType::F32, infinicore::Device());
1313-
auto gate_2d = infinicore::Tensor::ones({4, 128}, infinicore::DataType::F32, infinicore::Device());
1314-
auto output_2d = swiglu1.forward(up_2d, gate_2d);
1315-
1316-
if (output_2d->shape() != std::vector<size_t>({4, 128})) {
1317-
spdlog::error("2D output shape mismatch. Expected {{4, 128}}");
1318-
return false;
1319-
}
1320-
1321-
spdlog::debug("2D forward pass passed. Output shape: {{4, 128}}");
1322-
1323-
// Test 3: Forward pass - 3D input [batch, seq_len, hidden]
1324-
spdlog::info("Test 3: Forward pass with 3D input [batch, seq_len, hidden]");
1325-
auto up_3d = infinicore::Tensor::ones({2, 10, 128}, infinicore::DataType::F32, infinicore::Device());
1326-
auto gate_3d = infinicore::Tensor::ones({2, 10, 128}, infinicore::DataType::F32, infinicore::Device());
1327-
auto output_3d = swiglu1.forward(up_3d, gate_3d);
1328-
1329-
if (output_3d->shape() != std::vector<size_t>({2, 10, 128})) {
1330-
spdlog::error("3D output shape mismatch. Expected {{2, 10, 128}}");
1331-
return false;
1332-
}
1333-
1334-
spdlog::debug("3D forward pass passed. Output shape: {{2, 10, 128}}");
1335-
1336-
// Test 4: Different tensor values
1337-
spdlog::info("Test 4: Testing with different tensor values");
1338-
auto up_test = infinicore::Tensor::ones({2, 64}, infinicore::DataType::F32, infinicore::Device());
1339-
auto gate_test = infinicore::Tensor::zeros({2, 64}, infinicore::DataType::F32, infinicore::Device());
1340-
auto output_test = swiglu1.forward(up_test, gate_test);
1341-
1342-
if (output_test->shape() != up_test->shape()) {
1343-
spdlog::error("Output shape doesn't match input shape");
1344-
return false;
1345-
}
1346-
1347-
spdlog::debug("Different tensor values test passed");
1348-
1349-
// Test 5: Shape mismatch validation (op layer handles this)
1350-
spdlog::info("Test 5: Testing shape mismatch validation");
1351-
auto up_shape = infinicore::Tensor::ones({4, 128}, infinicore::DataType::F32, infinicore::Device());
1352-
auto gate_shape = infinicore::Tensor::ones({4, 64}, infinicore::DataType::F32, infinicore::Device());
1353-
1354-
try {
1355-
swiglu1.forward(up_shape, gate_shape);
1356-
spdlog::error("Should have thrown exception for shape mismatch");
1357-
return false;
1358-
} catch (const std::exception &e) {
1359-
spdlog::debug("Correctly caught exception for shape mismatch (handled by op layer): {}", e.what());
1360-
} catch (...) {
1361-
spdlog::error("Caught unexpected exception type");
1362-
return false;
1363-
}
1364-
1365-
spdlog::debug("Shape mismatch validation test passed");
1366-
1367-
// Test 6: Dtype mismatch validation (op layer handles this)
1368-
spdlog::info("Test 6: Testing dtype mismatch validation");
1369-
auto up_dtype = infinicore::Tensor::ones({4, 128}, infinicore::DataType::F32, infinicore::Device());
1370-
auto gate_dtype = infinicore::Tensor::ones({4, 128}, infinicore::DataType::F16, infinicore::Device());
1371-
1372-
try {
1373-
swiglu1.forward(up_dtype, gate_dtype);
1374-
} catch (...) {
1375-
spdlog::error("Caught unexpected exception type");
1376-
return false;
1377-
}
1378-
1379-
spdlog::debug("Dtype mismatch validation test passed");
1380-
1381-
// Test 7: extra_repr
1382-
spdlog::info("Test 7: Testing extra_repr");
1383-
std::string repr = swiglu1.extra_repr();
1384-
spdlog::debug("SwiGLU repr: {}", repr);
1385-
1386-
if (repr.find("SwiGLU") == std::string::npos) {
1387-
spdlog::error("extra_repr should contain SwiGLU");
1388-
return false;
1389-
}
1390-
1391-
spdlog::debug("extra_repr test passed");
1392-
1393-
// Test 8: Different hidden sizes
1394-
spdlog::info("Test 8: Testing different hidden sizes");
1395-
auto up_small = infinicore::Tensor::ones({2, 64}, infinicore::DataType::F32, infinicore::Device());
1396-
auto gate_small = infinicore::Tensor::ones({2, 64}, infinicore::DataType::F32, infinicore::Device());
1397-
auto output_small = swiglu1.forward(up_small, gate_small);
1398-
1399-
auto up_large = infinicore::Tensor::ones({2, 4096}, infinicore::DataType::F32, infinicore::Device());
1400-
auto gate_large = infinicore::Tensor::ones({2, 4096}, infinicore::DataType::F32, infinicore::Device());
1401-
auto output_large = swiglu1.forward(up_large, gate_large);
1402-
1403-
if (output_small->shape() != std::vector<size_t>({2, 64})) {
1404-
spdlog::error("Small SwiGLU output shape mismatch");
1405-
return false;
1406-
}
1407-
1408-
if (output_large->shape() != std::vector<size_t>({2, 4096})) {
1409-
spdlog::error("Large SwiGLU output shape mismatch");
1410-
return false;
1411-
}
1412-
1413-
spdlog::debug("Different hidden sizes test passed");
1414-
1415-
spdlog::info("All SwiGLU module tests passed!");
1416-
return true;
1417-
1418-
} catch (const std::exception &e) {
1419-
spdlog::error("Exception in testModuleSwiGLU: {}", e.what());
1420-
return false;
1421-
}
1422-
});
1423-
}
1424-
14251290
// Test 8: Dtype assertion test
14261291
TestResult NNModuleTest::testDtypeAssertion() {
14271292
return measureTime("DtypeAssertionTest", [this]() {
@@ -1851,7 +1716,6 @@ TestResult NNModuleTest::run() {
18511716
results.push_back(testModuleEmbedding()); // Embedding module test
18521717
results.push_back(testModuleRMSNorm()); // RMSNorm module test
18531718
results.push_back(testModuleRoPE()); // RoPE module test
1854-
results.push_back(testModuleSwiGLU()); // SwiGLU module test
18551719
results.push_back(testDtypeAssertion()); // Dtype assertion test
18561720
results.push_back(testTinyLlamaConstruction()); // Comprehensive: TinyLlama model test
18571721

src/infinicore-test/test_nn_module.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#include "infinicore/nn/parameter.hpp"
88
#include "infinicore/nn/rmsnorm.hpp"
99
#include "infinicore/nn/rope.hpp"
10-
#include "infinicore/nn/swiglu.hpp"
1110
#include "test_runner.h"
1211
#include <algorithm>
1312
#include <cmath>
@@ -85,7 +84,6 @@ class NNModuleTest : public TestFramework {
8584
TestResult testModuleEmbedding(); // Embedding module test
8685
TestResult testModuleRMSNorm(); // RMSNorm module test
8786
TestResult testModuleRoPE(); // RoPE module test
88-
TestResult testModuleSwiGLU(); // SwiGLU module test
8987
TestResult testDtypeAssertion(); // Test dtype assertions when loading parameters
9088
TestResult testTinyLlamaConstruction(); // Comprehensive: construction + weight loading + validation
9189
};

src/infinicore/nn/functional.cc

Lines changed: 0 additions & 23 deletions
This file was deleted.

src/infinicore/nn/swiglu.cc

Lines changed: 0 additions & 17 deletions
This file was deleted.

0 commit comments

Comments
 (0)