Skip to content

Commit 777b323

Browse files
committed
do assertion at load_parameter && update Module definition with macros
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent 69c1c35 commit 777b323

9 files changed

Lines changed: 202 additions & 51 deletions

File tree

include/infinicore/nn/embedding.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ class Embedding : public Module {
7575

7676
protected:
7777
// Parameters
78-
Parameter weight_;
78+
INFINICORE_NN_PARAMETER(weight);
7979

8080
private:
8181
size_t num_embeddings_; // Vocabulary size

include/infinicore/nn/linear.hpp

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ namespace infinicore::nn {
77

88
class Linear : public Module {
99
public:
10-
Linear(size_t in_features, size_t out_features, bool bias = true, const Device &device = Device());
10+
Linear(size_t in_features, size_t out_features, bool bias = true, const DataType &dtype = DataType::F32, const Device &device = Device());
1111

1212
// Forward pass: output = input @ weight.T + bias
1313
Tensor forward(Tensor &input) const;
@@ -20,6 +20,7 @@ class Linear : public Module {
2020
size_t in_features() const { return in_features_; }
2121
size_t out_features() const { return out_features_; }
2222
bool has_bias() const { return has_bias_; }
23+
DataType dtype() const { return dtype_; }
2324

2425
// String representation
2526
std::string extra_repr() const;
@@ -30,8 +31,8 @@ class Linear : public Module {
3031

3132
protected:
3233
// Parameters
33-
Parameter weight_;
34-
Parameter bias_;
34+
INFINICORE_NN_PARAMETER(weight);
35+
INFINICORE_NN_PARAMETER(bias);
3536

3637
private:
3738
// Helper method for common forward computation
@@ -40,6 +41,7 @@ class Linear : public Module {
4041
size_t in_features_;
4142
size_t out_features_;
4243
bool has_bias_;
44+
DataType dtype_;
4345
};
4446

4547
} // namespace infinicore::nn

include/infinicore/nn/module.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -125,13 +125,13 @@ class Module {
125125

126126
// Declare a parameter member variable
127127
#define INFINICORE_NN_PARAMETER(name) \
128-
Parameter name##_
128+
infinicore::nn::Parameter name##_
129129

130130
// Initialize a parameter in constructor
131131
// Usage: INFINICORE_NN_PARAMETER_INIT(name, (shape, dtype, device))
132132
// Example: INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, DataType::F32, device))
133133
#define INFINICORE_NN_PARAMETER_INIT(name, args) \
134-
name##_ = Parameter args; \
134+
name##_ = infinicore::nn::Parameter args; \
135135
this->register_parameter(#name, name##_)
136136

137137
} // namespace infinicore::nn

include/infinicore/nn/rmsnorm.hpp

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,10 +36,12 @@ class RMSNorm : public Module {
3636
*
3737
* @param normalized_shape Size of the feature dimension to normalize (typically hidden_size)
3838
* @param eps Small constant for numerical stability (default: 1e-6)
39+
* @param dtype Data type for the weight (default: DataType::F32)
3940
* @param device Device to create the weight on
4041
*/
4142
RMSNorm(size_t normalized_shape,
4243
double eps = 1e-6,
44+
const DataType &dtype = DataType::F32,
4345
const Device &device = Device());
4446

4547
/**
@@ -58,6 +60,7 @@ class RMSNorm : public Module {
5860
// Module information
5961
size_t normalized_shape() const { return normalized_shape_; }
6062
double eps() const { return eps_; }
63+
DataType dtype() const { return dtype_; }
6164

6265
// String representation
6366
std::string extra_repr() const;
@@ -67,11 +70,12 @@ class RMSNorm : public Module {
6770

6871
protected:
6972
// Parameters
70-
Parameter weight_;
73+
INFINICORE_NN_PARAMETER(weight);
7174

7275
private:
7376
size_t normalized_shape_; // Size of the feature dimension
7477
double eps_; // Epsilon for numerical stability
78+
DataType dtype_; // Data type for weight
7579
};
7680

7781
} // namespace infinicore::nn

src/infinicore-test/test_nn_module.cc

Lines changed: 140 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -394,7 +394,7 @@ TestResult NNModuleTest::testModuleLinear() {
394394
try {
395395
// Test with bias
396396
spdlog::info("Testing Linear module with bias (8->4 features)");
397-
infinicore::nn::Linear m1(8, 4, true, infinicore::Device());
397+
infinicore::nn::Linear m1(8, 4, true);
398398
auto sd1 = m1.state_dict();
399399
if (sd1.find("weight") == sd1.end()) {
400400
spdlog::error("weight missing");
@@ -440,7 +440,7 @@ TestResult NNModuleTest::testModuleLinear() {
440440

441441
// Test without bias
442442
spdlog::info("Testing Linear module without bias (16->3 features)");
443-
infinicore::nn::Linear m2(16, 3, false, infinicore::Device());
443+
infinicore::nn::Linear m2(16, 3, false);
444444
auto sd2 = m2.state_dict();
445445
if (sd2.find("weight") == sd2.end()) {
446446
spdlog::error("weight missing (no-bias)");
@@ -834,7 +834,7 @@ TestResult NNModuleTest::testModuleRMSNorm() {
834834

835835
// Test 1: Basic RMSNorm creation
836836
spdlog::info("Test 1: Basic RMSNorm creation (hidden_size=768)");
837-
infinicore::nn::RMSNorm norm1(768, 1e-6, infinicore::Device());
837+
infinicore::nn::RMSNorm norm1(768);
838838

839839
auto state1 = norm1.state_dict();
840840
if (state1.find("weight") == state1.end()) {
@@ -925,8 +925,8 @@ TestResult NNModuleTest::testModuleRMSNorm() {
925925

926926
// Test 7: Different hidden sizes
927927
spdlog::info("Test 7: Testing different hidden sizes");
928-
infinicore::nn::RMSNorm norm_small(128, 1e-5, infinicore::Device());
929-
infinicore::nn::RMSNorm norm_large(4096, 1e-6, infinicore::Device());
928+
infinicore::nn::RMSNorm norm_small(128, 1e-5);
929+
infinicore::nn::RMSNorm norm_large(4096);
930930

931931
auto input_small = infinicore::Tensor::ones({2, 128}, infinicore::DataType::F32, infinicore::Device());
932932
auto output_small = norm_small.forward(input_small);
@@ -956,7 +956,130 @@ TestResult NNModuleTest::testModuleRMSNorm() {
956956
});
957957
}
958958

959-
// Test 8: Comprehensive Tiny-Llama model test (construction + weight loading + validation)
959+
// Test 8: Dtype assertion test
960+
TestResult NNModuleTest::testDtypeAssertion() {
961+
return measureTime("DtypeAssertionTest", [this]() {
962+
try {
963+
spdlog::info("Testing dtype assertions when loading parameters");
964+
965+
// Test 1: Successful load with matching dtype
966+
spdlog::info("Test 1: Successful load with matching dtype (F32)");
967+
infinicore::nn::Linear linear1(8, 4, true);
968+
auto matching_weight = infinicore::Tensor::ones({4, 8}, infinicore::DataType::F32, infinicore::Device());
969+
auto matching_bias = infinicore::Tensor::ones({4}, infinicore::DataType::F32, infinicore::Device());
970+
971+
std::unordered_map<std::string, infinicore::Tensor> matching_state;
972+
matching_state.emplace("weight", matching_weight);
973+
matching_state.emplace("bias", matching_bias);
974+
975+
// This should succeed without throwing
976+
linear1.load_state_dict(matching_state);
977+
spdlog::debug("✓ Matching dtype load succeeded");
978+
979+
// Test 2: Failed load with mismatched dtype (load_parameter)
980+
spdlog::info("Test 2: Failed load_parameter with mismatched dtype");
981+
infinicore::nn::Linear linear2(8, 4, true);
982+
auto mismatched_weight = infinicore::Tensor::ones({4, 8}, infinicore::DataType::BF16, infinicore::Device());
983+
984+
bool exception_thrown = false;
985+
try {
986+
linear2.load_parameter("weight", mismatched_weight);
987+
} catch (const std::runtime_error &e) {
988+
exception_thrown = true;
989+
std::string error_msg = e.what();
990+
if (error_msg.find("dtype mismatch") == std::string::npos) {
991+
spdlog::error("Exception message doesn't contain 'dtype mismatch'");
992+
return false;
993+
}
994+
spdlog::debug("✓ Mismatched dtype exception caught: {}", error_msg);
995+
}
996+
997+
if (!exception_thrown) {
998+
spdlog::error("Expected exception for dtype mismatch in load_parameter");
999+
return false;
1000+
}
1001+
1002+
// Test 3: Failed load with mismatched dtype (load_state_dict)
1003+
spdlog::info("Test 3: Failed load_state_dict with mismatched dtype");
1004+
infinicore::nn::Embedding embedding1(100, 64);
1005+
auto mismatched_embed_weight = infinicore::Tensor::ones({100, 64}, infinicore::DataType::BF16, infinicore::Device());
1006+
1007+
std::unordered_map<std::string, infinicore::Tensor> mismatched_state;
1008+
mismatched_state.emplace("weight", mismatched_embed_weight);
1009+
1010+
exception_thrown = false;
1011+
try {
1012+
embedding1.load_state_dict(mismatched_state);
1013+
} catch (const std::runtime_error &e) {
1014+
exception_thrown = true;
1015+
std::string error_msg = e.what();
1016+
if (error_msg.find("dtype mismatch") == std::string::npos) {
1017+
spdlog::error("Exception message doesn't contain 'dtype mismatch'");
1018+
return false;
1019+
}
1020+
if (error_msg.find("weight") == std::string::npos) {
1021+
spdlog::error("Exception message doesn't contain parameter name 'weight'");
1022+
return false;
1023+
}
1024+
spdlog::debug("✓ Mismatched dtype exception caught: {}", error_msg);
1025+
}
1026+
1027+
if (!exception_thrown) {
1028+
spdlog::error("Expected exception for dtype mismatch in load_state_dict");
1029+
return false;
1030+
}
1031+
1032+
// Test 4: Failed load with mismatched dtype (RMSNorm)
1033+
spdlog::info("Test 4: Failed load_state_dict with mismatched dtype (RMSNorm)");
1034+
infinicore::nn::RMSNorm norm1(768);
1035+
auto mismatched_norm_weight = infinicore::Tensor::ones({768}, infinicore::DataType::BF16, infinicore::Device());
1036+
1037+
std::unordered_map<std::string, infinicore::Tensor> mismatched_norm_state;
1038+
mismatched_norm_state.emplace("weight", mismatched_norm_weight);
1039+
1040+
exception_thrown = false;
1041+
try {
1042+
norm1.load_state_dict(mismatched_norm_state);
1043+
} catch (const std::runtime_error &e) {
1044+
exception_thrown = true;
1045+
std::string error_msg = e.what();
1046+
if (error_msg.find("dtype mismatch") == std::string::npos) {
1047+
spdlog::error("Exception message doesn't contain 'dtype mismatch'");
1048+
return false;
1049+
}
1050+
spdlog::debug("✓ Mismatched dtype exception caught for RMSNorm: {}", error_msg);
1051+
}
1052+
1053+
if (!exception_thrown) {
1054+
spdlog::error("Expected exception for dtype mismatch in RMSNorm load_state_dict");
1055+
return false;
1056+
}
1057+
1058+
// Test 5: Successful load with different module dtypes
1059+
spdlog::info("Test 5: Successful load with BF16 dtype (module created with BF16)");
1060+
infinicore::nn::Linear linear3(8, 4, true, infinicore::DataType::BF16);
1061+
auto bf16_weight = infinicore::Tensor::ones({4, 8}, infinicore::DataType::BF16, infinicore::Device());
1062+
auto bf16_bias = infinicore::Tensor::ones({4}, infinicore::DataType::BF16, infinicore::Device());
1063+
1064+
std::unordered_map<std::string, infinicore::Tensor> bf16_state;
1065+
bf16_state.emplace("weight", bf16_weight);
1066+
bf16_state.emplace("bias", bf16_bias);
1067+
1068+
// This should succeed
1069+
linear3.load_state_dict(bf16_state);
1070+
spdlog::debug("✓ BF16 dtype load succeeded");
1071+
1072+
spdlog::info("All dtype assertion tests passed!");
1073+
return true;
1074+
1075+
} catch (const std::exception &e) {
1076+
spdlog::error("Exception in testDtypeAssertion: {}", e.what());
1077+
return false;
1078+
}
1079+
});
1080+
}
1081+
1082+
// Test 9: Comprehensive Tiny-Llama model test (construction + weight loading + validation)
9601083
TestResult NNModuleTest::testTinyLlamaConstruction() {
9611084
return measureTime("TinyLlamaModelTest", [this]() {
9621085
try {
@@ -1007,10 +1130,10 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
10071130
INFINICORE_NN_MODULE(infinicore::nn::Linear, o_proj);
10081131

10091132
SelfAttn(size_t hidden_size, size_t kv_dim, const infinicore::Device &device) {
1010-
INFINICORE_NN_MODULE_INIT(q_proj, hidden_size, hidden_size, false, device);
1011-
INFINICORE_NN_MODULE_INIT(k_proj, hidden_size, kv_dim, false, device);
1012-
INFINICORE_NN_MODULE_INIT(v_proj, hidden_size, kv_dim, false, device);
1013-
INFINICORE_NN_MODULE_INIT(o_proj, hidden_size, hidden_size, false, device);
1133+
INFINICORE_NN_MODULE_INIT(q_proj, hidden_size, hidden_size, false, infinicore::DataType::F32, device);
1134+
INFINICORE_NN_MODULE_INIT(k_proj, hidden_size, kv_dim, false, infinicore::DataType::F32, device);
1135+
INFINICORE_NN_MODULE_INIT(v_proj, hidden_size, kv_dim, false, infinicore::DataType::F32, device);
1136+
INFINICORE_NN_MODULE_INIT(o_proj, hidden_size, hidden_size, false, infinicore::DataType::F32, device);
10141137
}
10151138
};
10161139

@@ -1021,9 +1144,9 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
10211144
INFINICORE_NN_MODULE(infinicore::nn::Linear, down_proj);
10221145

10231146
MLP(size_t hidden_size, size_t intermediate_size, const infinicore::Device &device) {
1024-
INFINICORE_NN_MODULE_INIT(gate_proj, hidden_size, intermediate_size, false, device);
1025-
INFINICORE_NN_MODULE_INIT(up_proj, hidden_size, intermediate_size, false, device);
1026-
INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size, hidden_size, false, device);
1147+
INFINICORE_NN_MODULE_INIT(gate_proj, hidden_size, intermediate_size, false, infinicore::DataType::F32, device);
1148+
INFINICORE_NN_MODULE_INIT(up_proj, hidden_size, intermediate_size, false, infinicore::DataType::F32, device);
1149+
INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size, hidden_size, false, infinicore::DataType::F32, device);
10271150
}
10281151
};
10291152

@@ -1036,9 +1159,9 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
10361159

10371160
Block(const TinyLlamaConfig &cfg, const infinicore::Device &device) {
10381161
size_t kv_dim = cfg.hidden_size * cfg.num_key_value_heads / cfg.num_attention_heads;
1039-
INFINICORE_NN_MODULE_INIT(input_layernorm, cfg.hidden_size, cfg.rms_norm_eps, device);
1162+
INFINICORE_NN_MODULE_INIT(input_layernorm, cfg.hidden_size, cfg.rms_norm_eps, infinicore::DataType::F32, device);
10401163
INFINICORE_NN_MODULE_INIT(self_attn, cfg.hidden_size, kv_dim, device);
1041-
INFINICORE_NN_MODULE_INIT(post_attention_layernorm, cfg.hidden_size, cfg.rms_norm_eps, device);
1164+
INFINICORE_NN_MODULE_INIT(post_attention_layernorm, cfg.hidden_size, cfg.rms_norm_eps, infinicore::DataType::F32, device);
10421165
INFINICORE_NN_MODULE_INIT(mlp, cfg.hidden_size, cfg.intermediate_size, device);
10431166
}
10441167
};
@@ -1051,7 +1174,7 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
10511174
TinyLlamaModel(const TinyLlamaConfig &config, const infinicore::Device &device) {
10521175
INFINICORE_NN_MODULE_INIT(embed_tokens, config.vocab_size, config.hidden_size, std::nullopt, infinicore::DataType::F32, device);
10531176
INFINICORE_NN_MODULE_VEC_INIT(layers, config.num_hidden_layers, Block, config, device);
1054-
INFINICORE_NN_MODULE_INIT(norm, config.hidden_size, config.rms_norm_eps, device);
1177+
INFINICORE_NN_MODULE_INIT(norm, config.hidden_size, config.rms_norm_eps, infinicore::DataType::F32, device);
10551178
}
10561179
};
10571180

@@ -1259,6 +1382,7 @@ TestResult NNModuleTest::run() {
12591382
results.push_back(testModuleLinear()); // Linear module comprehensive test
12601383
results.push_back(testModuleEmbedding()); // Embedding module test
12611384
results.push_back(testModuleRMSNorm()); // RMSNorm module test
1385+
results.push_back(testDtypeAssertion()); // Dtype assertion test
12621386
results.push_back(testTinyLlamaConstruction()); // Comprehensive: TinyLlama model test
12631387

12641388
// Check if all tests passed

src/infinicore-test/test_nn_module.h

Lines changed: 14 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -21,16 +21,21 @@ namespace infinicore::test {
2121
// Simple test module that mimics torch.nn.Linear
2222
class MockLinearModule : public infinicore::nn::Module {
2323
public:
24+
// Declare parameters using macros (torch-like style)
25+
INFINICORE_NN_PARAMETER(weight);
26+
INFINICORE_NN_PARAMETER(bias);
27+
2428
MockLinearModule(int input_size, int output_size, const infinicore::Device &device)
2529
: input_size_(input_size), output_size_(output_size), device_(device) {
26-
27-
// Initialize weight parameter (similar to torch.nn.Linear.weight)
28-
register_parameter("weight",
29-
infinicore::nn::Parameter({static_cast<size_t>(output_size), static_cast<size_t>(input_size)}, infinicore::DataType::F32, device));
30-
31-
// Initialize bias parameter (similar to torch.nn.Linear.bias)
32-
register_parameter("bias",
33-
infinicore::nn::Parameter({static_cast<size_t>(output_size)}, infinicore::DataType::F32, device));
30+
// Initialize parameters using macros
31+
INFINICORE_NN_PARAMETER_INIT(weight,
32+
({static_cast<size_t>(output_size), static_cast<size_t>(input_size)},
33+
infinicore::DataType::F32,
34+
device));
35+
INFINICORE_NN_PARAMETER_INIT(bias,
36+
({static_cast<size_t>(output_size)},
37+
infinicore::DataType::F32,
38+
device));
3439
}
3540

3641
// Simple forward pass (conceptual - would need actual matrix operations)
@@ -77,6 +82,7 @@ class NNModuleTest : public TestFramework {
7782
TestResult testModuleLinear(); // Comprehensive Linear module test
7883
TestResult testModuleEmbedding(); // Embedding module test
7984
TestResult testModuleRMSNorm(); // RMSNorm module test
85+
TestResult testDtypeAssertion(); // Test dtype assertions when loading parameters
8086
TestResult testTinyLlamaConstruction(); // Comprehensive: construction + weight loading + validation
8187
};
8288

0 commit comments

Comments
 (0)