Skip to content

Commit 9a8cce4

Browse files
committed
refactor: remove module registry and use direct construction

- Remove ModuleRegistry and the INFINI_TRAIN_REGISTER_MODULE macros
- Replace BuildModule() with direct constructor calls
- Simplify module instantiation in MLP, CausalSelfAttention, and TransformerLayer
1 parent 52eca96 commit 9a8cce4

File tree

19 files changed

+859
-1121
lines changed

19 files changed

+859
-1121
lines changed

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,5 +204,5 @@ link_infini_train_exe(test_precision_check)
204204
add_executable(test_lora test/lora/test_lora.cc)
205205
link_infini_train_exe(test_lora)
206206

207-
# add_executable(test_transformer_spec test/transformer_spec/test_transformer_spec.cc)
208-
# link_infini_train_exe(test_transformer_spec)
207+
add_executable(test_transformer_architecture test/transformer/test_transformer_architecture.cc)
208+
link_infini_train_exe(test_transformer_architecture)

example/gpt2/checkpoint_loader.cc

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include "infini_train/include/nn/modules/normalization.h"
1818
#include "infini_train/include/nn/modules/sparse.h"
1919
#include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
20-
#include "infini_train/include/nn/modules/transformer/layer_specs.h"
2120
#include "infini_train/include/nn/modules/transformer/mlp.h"
2221
#include "infini_train/include/nn/modules/transformer/transformer.h"
2322
#include "infini_train/include/nn/parallel/global.h"
@@ -96,10 +95,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
9695
gpt2_config.n_layer = n_layer;
9796
gpt2_config.n_head = n_head;
9897
gpt2_config.n_embd = n_embd;
99-
auto local_gpt2 = std::make_shared<nn::TransformerModel>(
100-
gpt2_config,
101-
nn::BuildTransformerSpec(gpt2_config, nn::BuildFirstStageSpec(gpt2_config),
102-
nn::BuildTransformerLayerSpec(gpt2_config), nn::BuildLastStageSpec(gpt2_config)));
98+
auto local_gpt2 = std::make_shared<nn::TransformerModel>(gpt2_config);
10399

104100
LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size
105101
<< " vocab_size: " << vocab_size << " n_layer: " << n_layer << " n_head: " << n_head
@@ -436,6 +432,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
436432
size_t ln_f_b_bytes = n_embd * sizeof(float);
437433
ifs.seekg(ln_f_w_bytes + ln_f_b_bytes, std::ios::cur);
438434
}
435+
439436
return local_gpt2;
440437
}
441438
} // namespace gpt2

example/gpt2/checkpoint_loader.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,9 @@
55

66
namespace infini_train::nn {
77
class TransformerModel;
8-
enum class ModelType;
98
} // namespace infini_train::nn
109

1110
namespace gpt2 {
1211
int GetChunkSize();
1312
std::shared_ptr<infini_train::nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
14-
std::shared_ptr<infini_train::nn::TransformerModel> FromPretrained(infini_train::nn::ModelType model_type);
1513
} // namespace gpt2

example/gpt2/main.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ void Train(const nn::parallel::Rank &rank) {
370370
y = std::make_shared<Tensor>(y->To(device));
371371

372372
LOG(INFO) << "Rank " << rank.GlobalRank() << ": start forward";
373+
373374
// (bs, seq_len, vocab_size)
374375
auto logits = (*model)({x, y})[0];
375376
LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish model forward, start loss forward";

example/llama3/checkpoint_loader.cc

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#include "example/llama3/config.h"
1717
#include "infini_train/include/nn/modules/normalization.h"
1818
#include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
19-
#include "infini_train/include/nn/modules/transformer/layer_specs.h"
2019
#include "infini_train/include/nn/modules/transformer/mlp.h"
2120
#include "infini_train/include/nn/modules/transformer/transformer.h"
2221
#include "infini_train/include/nn/parallel/global.h"
@@ -90,10 +89,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
9089
llama3_config.use_scaled_rope = static_cast<bool>(use_scaled_rope);
9190
llama3_config.norm_eps = norm_eps;
9291
llama3_config.max_gen_batch_size = max_gen_bs;
93-
auto llama3 = std::make_shared<nn::TransformerModel>(
94-
llama3_config,
95-
nn::BuildTransformerSpec(llama3_config, nn::BuildFirstStageSpec(llama3_config),
96-
nn::BuildTransformerLayerSpec(llama3_config), nn::BuildLastStageSpec(llama3_config)));
92+
auto llama3 = std::make_shared<nn::TransformerModel>(llama3_config);
9793

9894
// ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ==========
9995
int pp_size = nn::parallel::global::GetPipelineParallelSize();

infini_train/include/nn/modules/transformer/causal_self_attention.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#include <vector>
66

77
#include "infini_train/include/nn/modules/module.h"
8-
#include "infini_train/include/nn/modules/transformer/spec_utils.h"
98
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
109

1110
namespace infini_train::nn {
@@ -18,7 +17,7 @@ class CausalSelfAttention : public infini_train::nn::CloneableModule<CausalSelfA
1817

1918
static constexpr char kParamBiasName[] = "bias";
2019

21-
explicit CausalSelfAttention(const TransformerConfig &config, const ModuleSpec &spec = {});
20+
explicit CausalSelfAttention(const TransformerConfig &config);
2221

2322
std::vector<std::shared_ptr<infini_train::Tensor>>
2423
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

infini_train/include/nn/modules/transformer/layer_specs.h

Lines changed: 0 additions & 55 deletions
This file was deleted.

infini_train/include/nn/modules/transformer/mlp.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include <vector>
44

55
#include "infini_train/include/nn/modules/module.h"
6-
#include "infini_train/include/nn/modules/transformer/spec_utils.h"
76
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
87

98
namespace infini_train::nn {
@@ -17,7 +16,7 @@ class MLP : public infini_train::nn::CloneableModule<MLP> {
1716
static constexpr char kCFc2LayerName[] = "c_fc2";
1817
static constexpr char kSiluLayerName[] = "silu";
1918

20-
explicit MLP(const TransformerConfig &config, const ModuleSpec &spec = {});
19+
explicit MLP(const TransformerConfig &config);
2120

2221
std::vector<std::shared_ptr<infini_train::Tensor>>
2322
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

infini_train/include/nn/modules/transformer/spec_utils.h

Lines changed: 0 additions & 93 deletions
This file was deleted.

infini_train/include/nn/modules/transformer/transformer.h

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include <vector>
44

55
#include "infini_train/include/nn/modules/module.h"
6-
#include "infini_train/include/nn/modules/transformer/spec_utils.h"
76
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
87
#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
98

@@ -16,7 +15,7 @@ class TransformerLayer : public CloneableModule<TransformerLayer> {
1615
static constexpr char kLn2LayerName[] = "ln_2";
1716
static constexpr char kMlpLayerName[] = "mlp";
1817

19-
explicit TransformerLayer(const TransformerConfig &config, const ModuleSpec &spec = {});
18+
explicit TransformerLayer(const TransformerConfig &config);
2019

2120
std::vector<std::shared_ptr<infini_train::Tensor>>
2221
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
@@ -28,14 +27,13 @@ class TransformerFirstStage : public CloneableModule<TransformerFirstStage> {
2827
static constexpr char kWTELayerName[] = "wte";
2928
static constexpr char kWPELayerName[] = "wpe";
3029

31-
explicit TransformerFirstStage(const TransformerConfig &config, const ModuleSpec &spec = {});
30+
explicit TransformerFirstStage(const TransformerConfig &config);
3231

3332
std::vector<std::shared_ptr<infini_train::Tensor>>
3433
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
3534

3635
private:
3736
TransformerConfig config_;
38-
ModuleSpec spec_;
3937
};
4038

4139
class TransformerChunk : public CloneableModule<TransformerChunk> {
@@ -44,14 +42,13 @@ class TransformerChunk : public CloneableModule<TransformerChunk> {
4442
static constexpr char kHLayerName[] = "h";
4543
static constexpr char kFreqsCisName[] = "freqs_cis";
4644

47-
TransformerChunk(const TransformerConfig &config, int start_layer, int end_layer, const ModuleSpec &spec = {});
45+
TransformerChunk(const TransformerConfig &config, int start_layer, int end_layer);
4846

4947
std::vector<std::shared_ptr<infini_train::Tensor>>
5048
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
5149

5250
private:
5351
const TransformerConfig config_;
54-
ModuleSpec spec_;
5552
};
5653

5754
class TransformerLastStage : public CloneableModule<TransformerLastStage> {
@@ -60,22 +57,21 @@ class TransformerLastStage : public CloneableModule<TransformerLastStage> {
6057
static constexpr char kLnFLayerName[] = "ln_f";
6158
static constexpr char kLMHeadLayerName[] = "lm_head";
6259

63-
explicit TransformerLastStage(const TransformerConfig &config, const ModuleSpec &spec = {});
60+
explicit TransformerLastStage(const TransformerConfig &config);
6461

6562
std::vector<std::shared_ptr<infini_train::Tensor>>
6663
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
6764

6865
private:
6966
const TransformerConfig config_;
70-
ModuleSpec spec_;
7167
};
7268

7369
class TransformerModel : public CloneableModule<TransformerModel> {
7470
public:
7571
static constexpr char kType[] = "Transformer";
7672
static constexpr char kTransformerModelName[] = "transformer";
7773

78-
explicit TransformerModel(const TransformerConfig config, const ModuleSpec &spec = {});
74+
explicit TransformerModel(const TransformerConfig config);
7975

8076
std::vector<std::shared_ptr<infini_train::Tensor>>
8177
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
@@ -85,7 +81,6 @@ class TransformerModel : public CloneableModule<TransformerModel> {
8581
private:
8682
const TransformerConfig config_;
8783
const infini_train::nn::parallel::StageInfo stage_info_;
88-
ModuleSpec spec_;
8984
};
9085

9186
} // namespace infini_train::nn

0 commit comments

Comments (0)