
Commit 8174510

refactor: remove module registry and use direct construction
- Remove ModuleRegistry and INFINI_TRAIN_REGISTER_MODULE macros
- Replace BuildModule() with direct constructor calls
- Simplify module instantiation in MLP, CausalSelfAttention, and TransformerLayer
1 parent d283738 commit 8174510

25 files changed

Lines changed: 898 additions & 1147 deletions
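
The refactor reduces to one pattern change, sketched below. This is an illustrative reconstruction assembled from the hunks in this commit (header paths and the new constructor call are taken from the diffs; the removed BuildModule()/ModuleSpec API is only summarized in comments, and BuildModel is a hypothetical helper, not repository code):

#include <memory>

#include "infini_train/include/nn/modules/transformer/transformer.h"
#include "infini_train/include/nn/modules/transformer/transformer_config.h"

// Before: callers built a spec tree (BuildTransformerSpec, BuildFirstStageSpec, ...) and the
// ModuleRegistry resolved INFINI_TRAIN_REGISTER_MODULE entries via BuildModule().
// After: the TransformerConfig alone drives construction, so callers invoke constructors directly.
std::shared_ptr<infini_train::nn::TransformerModel>
BuildModel(const infini_train::nn::TransformerConfig &config) {
    return std::make_shared<infini_train::nn::TransformerModel>(config);
}

The same shift shows up in every file below: spec headers and registry lookups disappear, and per-example helpers such as the GetChunkSize() free functions fold into the config.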

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -204,5 +204,5 @@ link_infini_train_exe(test_precision_check)
 add_executable(test_lora test/lora/test_lora.cc)
 link_infini_train_exe(test_lora)
 
-# add_executable(test_transformer_spec test/transformer_spec/test_transformer_spec.cc)
-# link_infini_train_exe(test_transformer_spec)
+add_executable(test_transformer_architecture test/transformer/test_transformer_architecture.cc)
+link_infini_train_exe(test_transformer_architecture)

example/gpt2/checkpoint_loader.cc

Lines changed: 2 additions & 13 deletions
@@ -17,7 +17,6 @@
 #include "infini_train/include/nn/modules/normalization.h"
 #include "infini_train/include/nn/modules/sparse.h"
 #include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
-#include "infini_train/include/nn/modules/transformer/layer_specs.h"
 #include "infini_train/include/nn/modules/transformer/mlp.h"
 #include "infini_train/include/nn/modules/transformer/transformer.h"
 #include "infini_train/include/nn/parallel/global.h"
@@ -56,14 +55,6 @@ std::tuple<int32_t, infini_train::DataType> DetermineAndCheckVersion(const std::
 } // namespace
 
 namespace gpt2 {
-int GetChunkSize() {
-    nn::TransformerConfig gpt2_config = GPT2Config();
-
-    auto stage_info = nn::parallel::PipelineParallel::GetStageInfo(
-        gpt2_config.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank,
-        nn::parallel::global::GetVirtualPipelineParallelSize());
-    return stage_info.layer_ranges_per_chunk.size();
-}
 
 std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath) {
     if (!std::filesystem::exists(filepath)) {
@@ -96,10 +87,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
     gpt2_config.n_layer = n_layer;
     gpt2_config.n_head = n_head;
     gpt2_config.n_embd = n_embd;
-    auto local_gpt2 = std::make_shared<nn::TransformerModel>(
-        gpt2_config,
-        nn::BuildTransformerSpec(gpt2_config, nn::BuildFirstStageSpec(gpt2_config),
-                                 nn::BuildTransformerLayerSpec(gpt2_config), nn::BuildLastStageSpec(gpt2_config)));
+    auto local_gpt2 = std::make_shared<nn::TransformerModel>(gpt2_config);
 
     LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size
               << " vocab_size: " << vocab_size << " n_layer: " << n_layer << " n_head: " << n_head
@@ -436,6 +424,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
         size_t ln_f_b_bytes = n_embd * sizeof(float);
         ifs.seekg(ln_f_w_bytes + ln_f_b_bytes, std::ios::cur);
     }
+
     return local_gpt2;
 }
 } // namespace gpt2

example/gpt2/checkpoint_loader.h

Lines changed: 0 additions & 3 deletions
@@ -5,11 +5,8 @@
 
 namespace infini_train::nn {
 class TransformerModel;
-enum class ModelType;
 } // namespace infini_train::nn
 
 namespace gpt2 {
-int GetChunkSize();
 std::shared_ptr<infini_train::nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
-std::shared_ptr<infini_train::nn::TransformerModel> FromPretrained(infini_train::nn::ModelType model_type);
 } // namespace gpt2

example/gpt2/config.h

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,8 @@ inline nn::TransformerConfig GPT2Config() {
         .attention_type = nn::AttentionType::kStandard,
         .activation_type = nn::MLPType::kGELU,
         .norm_type = nn::NormType::kLayerNorm,
-        .use_bias = true,
+        .add_bias_linear = true,
+        .add_bias_lm_head = false,
         .tie_weights = true,
         .ffn_expansion_ratio = 4.0f,
         .ffn_dim_multiplier = std::nullopt,
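
The only config-facing change is the bias split: the single use_bias switch becomes two flags so linear-layer bias and LM-head bias can be set independently. A minimal sketch of the affected fields, assuming the rest of TransformerConfig matches the designated initializers above (the full struct definition is not part of this commit, and the fragment name below is hypothetical):

#include <optional>

struct TransformerConfigFragment {
    bool add_bias_linear = true;    // replaces use_bias: bias on the model's linear projections
    bool add_bias_lm_head = false;  // separate control for the LM head's bias
    bool tie_weights = true;
    float ffn_expansion_ratio = 4.0f;
    std::optional<float> ffn_dim_multiplier = std::nullopt;
};

GPT-2 keeps bias on its linear layers but not on the LM head; LLaMA 3 disables both (see example/llama3/config.h below).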

example/gpt2/main.cc

Lines changed: 2 additions & 1 deletion
@@ -250,7 +250,7 @@ void Train(const nn::parallel::Rank &rank) {
         {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};
 
     model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
-                                                             pp_rank, device, gpt2::GetChunkSize());
+                                                             pp_rank, device, model_config.GetChunkSize());
     if (ddp_world_size > 1) {
         auto ddp_config
             = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
@@ -370,6 +370,7 @@ void Train(const nn::parallel::Rank &rank) {
     y = std::make_shared<Tensor>(y->To(device));
 
     LOG(INFO) << "Rank " << rank.GlobalRank() << ": start forward";
+
     // (bs, seq_len, vocab_size)
     auto logits = (*model)({x, y})[0];
     LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish model forward, start loss forward";
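
Both training loops now ask the model config for the pipeline chunk count instead of calling a per-example free function. The member's body is not part of this commit; the sketch below is a hypothetical reconstruction that mirrors the GetChunkSize() free function removed from example/gpt2/checkpoint_loader.cc above:

namespace infini_train::nn {

// Hypothetical TransformerConfig::GetChunkSize(): number of layer chunks owned by this
// pipeline rank, derived from the pipeline / virtual-pipeline layout (as in the removed free function).
int TransformerConfig::GetChunkSize() const {
    auto stage_info = parallel::PipelineParallel::GetStageInfo(
        n_layer, parallel::global::GetPipelineParallelSize(), parallel::pp_rank,
        parallel::global::GetVirtualPipelineParallelSize());
    return stage_info.layer_ranges_per_chunk.size();
}

} // namespace infini_train::nn

Folding this into TransformerConfig is what lets both gpt2::GetChunkSize() and llama3::GetChunkSize() be deleted while main.cc passes model_config.GetChunkSize() straight into the PipelineParallel constructor.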

example/llama3/checkpoint_loader.cc

Lines changed: 1 addition & 14 deletions
@@ -16,7 +16,6 @@
 #include "example/llama3/config.h"
 #include "infini_train/include/nn/modules/normalization.h"
 #include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
-#include "infini_train/include/nn/modules/transformer/layer_specs.h"
 #include "infini_train/include/nn/modules/transformer/mlp.h"
 #include "infini_train/include/nn/modules/transformer/transformer.h"
 #include "infini_train/include/nn/parallel/global.h"
@@ -40,15 +39,6 @@ constexpr int32_t kLLaMA3FP32Version = 3;
 
 namespace llama3 {
 
-int GetChunkSize() {
-    nn::TransformerConfig llama3_config = llama3::LLaMA3Config();
-
-    auto stage_info = nn::parallel::PipelineParallel::GetStageInfo(
-        llama3_config.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank,
-        nn::parallel::global::GetVirtualPipelineParallelSize());
-    return stage_info.layer_ranges_per_chunk.size();
-}
-
 std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath) {
     if (!std::filesystem::exists(filepath)) {
         LOG(FATAL) << "File not found: " << filepath;
@@ -90,10 +80,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
     llama3_config.use_scaled_rope = static_cast<bool>(use_scaled_rope);
    llama3_config.norm_eps = norm_eps;
     llama3_config.max_gen_batch_size = max_gen_bs;
-    auto llama3 = std::make_shared<nn::TransformerModel>(
-        llama3_config,
-        nn::BuildTransformerSpec(llama3_config, nn::BuildFirstStageSpec(llama3_config),
-                                 nn::BuildTransformerLayerSpec(llama3_config), nn::BuildLastStageSpec(llama3_config)));
+    auto llama3 = std::make_shared<nn::TransformerModel>(llama3_config);
 
     // ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ==========
     int pp_size = nn::parallel::global::GetPipelineParallelSize();

example/llama3/checkpoint_loader.h

Lines changed: 0 additions & 1 deletion
@@ -8,6 +8,5 @@ class TransformerModel;
 } // namespace infini_train::nn
 
 namespace llama3 {
-int GetChunkSize();
 std::shared_ptr<infini_train::nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
 } // namespace llama3

example/llama3/config.h

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,8 @@ inline nn::TransformerConfig LLaMA3Config() {
         .attention_type = nn::AttentionType::kRoPE,
         .activation_type = nn::MLPType::kSwiGLU,
         .norm_type = nn::NormType::kRMSNorm,
-        .use_bias = false,
+        .add_bias_linear = false,
+        .add_bias_lm_head = false,
         .tie_weights = false,
         .ffn_expansion_ratio = 4.0f,
         .ffn_dim_multiplier = 1.5f,

example/llama3/main.cc

Lines changed: 1 addition & 1 deletion
@@ -220,7 +220,7 @@ void Train(const nn::parallel::Rank &rank) {
         {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};
 
     model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
-                                                             pp_rank, device, llama3::GetChunkSize());
+                                                             pp_rank, device, model_config.GetChunkSize());
    if (ddp_world_size > 1) {
         auto ddp_config
             = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};

infini_train/include/nn/modules/transformer/causal_self_attention.h

Lines changed: 1 addition & 2 deletions
@@ -5,7 +5,6 @@
 #include <vector>
 
 #include "infini_train/include/nn/modules/module.h"
-#include "infini_train/include/nn/modules/transformer/spec_utils.h"
 #include "infini_train/include/nn/modules/transformer/transformer_config.h"
 
 namespace infini_train::nn {
@@ -18,7 +17,7 @@ class CausalSelfAttention : public infini_train::nn::CloneableModule<CausalSelfA
 
     static constexpr char kParamBiasName[] = "bias";
 
-    explicit CausalSelfAttention(const TransformerConfig &config, const ModuleSpec &spec = {});
+    explicit CausalSelfAttention(const TransformerConfig &config);
 
     std::vector<std::shared_ptr<infini_train::Tensor>>
     Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
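
With ModuleSpec gone from the signature, constructing the attention block needs nothing beyond a populated config. A usage sketch under that assumption (MakeAttention is a hypothetical helper, not repository code):

#include <memory>

#include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
#include "infini_train/include/nn/modules/transformer/transformer_config.h"

// Build one attention block straight from the config; no spec object or registry lookup involved.
std::shared_ptr<infini_train::nn::CausalSelfAttention>
MakeAttention(const infini_train::nn::TransformerConfig &config) {
    return std::make_shared<infini_train::nn::CausalSelfAttention>(config);
}

Per the commit message, MLP and TransformerLayer are instantiated the same way, so a stage can be assembled with plain constructor calls.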
