Skip to content

Commit d6fc7fa

Browse files
committed
apply macro at module definition
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent 7dc1f6c commit d6fc7fa

4 files changed

Lines changed: 17 additions & 12 deletions

File tree

include/infinicore/nn/embedding.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class Embedding : public Module {
7575

7676
protected:
7777
// Parameters
78-
Parameter weight_;
78+
INFINICORE_NN_PARAMETER(weight);
7979

8080
private:
8181
size_t num_embeddings_; // Vocabulary size

include/infinicore/nn/linear.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ class Linear : public Module {
3131

3232
protected:
3333
// Parameters
34-
Parameter weight_;
35-
Parameter bias_;
34+
INFINICORE_NN_PARAMETER(weight);
35+
INFINICORE_NN_PARAMETER(bias);
3636

3737
private:
3838
// Helper method for common forward computation

include/infinicore/nn/rmsnorm.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class RMSNorm : public Module {
7070

7171
protected:
7272
// Parameters
73-
Parameter weight_;
73+
INFINICORE_NN_PARAMETER(weight);
7474

7575
private:
7676
size_t normalized_shape_; // Size of the feature dimension

src/infinicore-test/test_nn_module.h

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,21 @@ namespace infinicore::test {
2121
// Simple test module that mimics torch.nn.Linear
2222
class MockLinearModule : public infinicore::nn::Module {
2323
public:
24+
// Declare parameters using macros (torch-like style)
25+
INFINICORE_NN_PARAMETER(weight);
26+
INFINICORE_NN_PARAMETER(bias);
27+
2428
MockLinearModule(int input_size, int output_size, const infinicore::Device &device)
2529
: input_size_(input_size), output_size_(output_size), device_(device) {
26-
27-
// Initialize weight parameter (similar to torch.nn.Linear.weight)
28-
register_parameter("weight",
29-
infinicore::nn::Parameter({static_cast<size_t>(output_size), static_cast<size_t>(input_size)}, infinicore::DataType::F32, device));
30-
31-
// Initialize bias parameter (similar to torch.nn.Linear.bias)
32-
register_parameter("bias",
33-
infinicore::nn::Parameter({static_cast<size_t>(output_size)}, infinicore::DataType::F32, device));
30+
// Initialize parameters using macros
31+
INFINICORE_NN_PARAMETER_INIT(weight,
32+
({static_cast<size_t>(output_size), static_cast<size_t>(input_size)},
33+
infinicore::DataType::F32,
34+
device));
35+
INFINICORE_NN_PARAMETER_INIT(bias,
36+
({static_cast<size_t>(output_size)},
37+
infinicore::DataType::F32,
38+
device));
3439
}
3540

3641
// Simple forward pass (conceptual - would need actual matrix operations)

0 commit comments

Comments
 (0)