Skip to content

Commit 5cec43f

Browse files
committed
refactor: flatten transformer spec and extract model presets
1 parent 388b4bb commit 5cec43f

File tree

14 files changed

+140
-143
lines changed

14 files changed

+140
-143
lines changed

example/gpt2/config.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#pragma once

#include <optional> // std::nullopt — include explicitly instead of relying on a transitive include

#include "infini_train/include/core/transformer/transformer_config.h"

// NOTE(review): a global-scope namespace alias in a header leaks into every
// translation unit that includes it; kept here for source compatibility with
// existing includers — consider moving it inside the gpt2 namespace.
namespace nn = infini_train::nn;

namespace infini_train::nn::gpt2 {

/// @brief Preset TransformerConfig for the GPT-2 architecture:
///        learned position embeddings, LayerNorm, GELU MLP, standard attention.
/// @return A fully-populated config; callers may override individual fields
///         (e.g. block_size / vocab_size read from a checkpoint) afterwards,
///         as FromLLMC_GPT2 does.
inline nn::TransformerConfig GPT2Config() {
    return {.model_type = ModelType::kGPT2,
            .block_size = 1024,            // max sequence length
            .vocab_size = 50304,           // padded vocab (multiple of 128)
            .original_vocab_size = 50257,  // unpadded GPT-2 BPE vocab
            .n_layer = 12,
            .n_head = 12,
            .n_kv_head = 12,               // == n_head: no grouped-query attention
            .n_embd = 768,
            .attention_type = nn::AttentionType::kStandard,
            .activation_type = nn::MLPType::kGELU,
            .norm_type = nn::NormType::kLayerNorm,
            .use_bias = true,
            .tie_weights = true,           // share token-embedding weights with lm_head
            .ffn_expansion_ratio = 4.0f,   // FFN hidden dim = 4 * n_embd
            .ffn_dim_multiplier = std::nullopt,
            .multiple_of = 1};
}

} // namespace infini_train::nn::gpt2

example/gpt2/main.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include "infini_train/include/autocast.h"
1313
#include "infini_train/include/core/models/decode_only_transformer/model.h"
1414
#include "infini_train/include/core/runtime/device_guard.h"
15-
#include "infini_train/include/core/transformer/transformer_config.h"
1615
#include "infini_train/include/dataloader.h"
1716
#include "infini_train/include/device.h"
1817
#include "infini_train/include/nn/lora/lora_utils.h"
@@ -37,6 +36,7 @@
3736

3837
#include "example/common/tiny_shakespeare_dataset.h"
3938
#include "example/common/tokenizer.h"
39+
#include "example/gpt2/config.h"
4040

4141
// I/O
4242
DEFINE_string(input_bin, "", "input .bin to train on");
@@ -188,7 +188,7 @@ void Train(const nn::parallel::Rank &rank) {
188188
// ManualSeed(42);
189189

190190
// init the model, either from scratch or from OpenAI pretrained checkpoint
191-
nn::TransformerConfig model_config = nn::TransformerConfig::GPT2();
191+
nn::TransformerConfig model_config = nn::gpt2::GPT2Config();
192192
std::shared_ptr<nn::Module> model = nullptr;
193193

194194
if (!FLAGS_llmc_filepath.empty()) {

example/gpt2/net.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
#include "glog/logging.h"
1212

1313
#include "example/common/utils.h"
14+
#include "example/gpt2/config.h"
1415
#include "infini_train/include/core/models/decode_only_transformer/model.h"
15-
#include "infini_train/include/core/transformer/transformer_config.h"
1616
#include "infini_train/include/nn/modules/causal_self_attention.h"
1717
#include "infini_train/include/nn/modules/mlp.h"
1818
#include "infini_train/include/nn/modules/normalization.h"
@@ -76,7 +76,7 @@ std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_GPT2(co
7676
// NOTE(zbl): vocab_size needs to be padded to multiple of TP size
7777
const auto model_vocab_size = tp_size > 1 ? padded_vocab_size : vocab_size;
7878

79-
nn::TransformerConfig gpt2_config = nn::TransformerConfig::GPT2();
79+
nn::TransformerConfig gpt2_config = nn::gpt2::GPT2Config();
8080
gpt2_config.block_size = block_size;
8181
gpt2_config.vocab_size = model_vocab_size;
8282
gpt2_config.original_vocab_size = vocab_size;

example/llama3/config.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#pragma once

#include "infini_train/include/core/transformer/transformer_config.h"

namespace nn = infini_train::nn;

namespace infini_train::nn::llama3 {

/// @brief Preset TransformerConfig for the LLaMA3 architecture:
///        RoPE attention with GQA, RMSNorm, SwiGLU MLP, untied embeddings.
/// @return A fully-populated config; callers may override individual fields
///         (e.g. sizes read from a checkpoint header) afterwards, as
///         FromLLMC_LLaMA3 does.
inline nn::TransformerConfig LLaMA3Config() {
    nn::TransformerConfig config{};
    config.model_type = ModelType::kLLaMA3;
    config.block_size = 8192;            // max sequence length
    config.vocab_size = 128256;          // LLaMA3 tokenizer vocab
    config.original_vocab_size = 128256; // no padding applied
    config.n_layer = 16;
    config.n_head = 32;
    config.n_kv_head = 8;                // < n_head: grouped-query attention
    config.n_embd = 2048;
    config.attention_type = nn::AttentionType::kRoPE;
    config.activation_type = nn::MLPType::kSwiGLU;
    config.norm_type = nn::NormType::kRMSNorm;
    config.use_bias = false;
    config.tie_weights = false;          // separate embedding and lm_head weights
    config.ffn_expansion_ratio = 4.0f;
    config.ffn_dim_multiplier = 1.5f;    // scales the SwiGLU hidden dim
    config.multiple_of = 256;            // round FFN hidden dim up to this multiple
    return config;
}

} // namespace infini_train::nn::llama3

example/llama3/main.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include "infini_train/include/autocast.h"
1111
#include "infini_train/include/core/models/decode_only_transformer/model.h"
1212
#include "infini_train/include/core/runtime/device_guard.h"
13-
#include "infini_train/include/core/transformer/transformer_config.h"
1413
#include "infini_train/include/dataloader.h"
1514
#include "infini_train/include/device.h"
1615
#include "infini_train/include/nn/lora/lora_utils.h"
@@ -36,6 +35,7 @@
3635

3736
#include "example/common/tiny_shakespeare_dataset.h"
3837
#include "example/common/tokenizer.h"
38+
#include "example/llama3/config.h"
3939

4040
// I/O
4141
DEFINE_string(input_bin, "", "input .bin to train on");
@@ -168,7 +168,7 @@ void Train(const nn::parallel::Rank &rank) {
168168
// rng / reproducibility
169169
// ManualSeed(42);
170170

171-
nn::TransformerConfig model_config = nn::TransformerConfig::LLaMA3();
171+
nn::TransformerConfig model_config = nn::llama3::LLaMA3Config();
172172
std::shared_ptr<nn::Module> model = nullptr;
173173
if (!FLAGS_llmc_filepath.empty()) {
174174
model = DecoderOnlyTransformer::FromLLMC_LLaMA3(FLAGS_llmc_filepath);

example/llama3/net.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
#include "glog/logging.h"
1313

1414
#include "example/common/utils.h"
15+
#include "example/llama3/config.h"
1516
#include "infini_train/include/core/models/decode_only_transformer/model.h"
1617
#include "infini_train/include/core/transformer/spec_utils.h"
17-
#include "infini_train/include/core/transformer/transformer_config.h"
1818
#include "infini_train/include/device.h"
1919
#include "infini_train/include/nn/modules/causal_self_attention.h"
2020
#include "infini_train/include/nn/modules/mlp.h"
@@ -65,7 +65,7 @@ std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_LLaMA3(
6565
const auto version_major = BytesToType<int32_t>(header, 56);
6666
const auto version_minor = BytesToType<int32_t>(header, 60);
6767

68-
nn::TransformerConfig llama3_config = nn::TransformerConfig::LLaMA3();
68+
nn::TransformerConfig llama3_config = nn::llama3::LLaMA3Config();
6969
llama3_config.block_size = block_size;
7070
llama3_config.vocab_size = vocab_size;
7171
llama3_config.n_layer = n_layer;

infini_train/include/core/models/decode_only_transformer/layer_specs.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44
#include "infini_train/include/core/transformer/transformer_config.h"
55

66
namespace infini_train::nn {
7-
// Build GPT2 model spec: LayerNorm + GELU + standard attention
8-
ModuleSpec BuildGPT2Spec(const TransformerConfig &config);
97

10-
// Build LLaMA3 model spec: RMSNorm + SwiGLU + RoPE + GQA
11-
ModuleSpec BuildLLaMA3Spec(const TransformerConfig &config);
8+
ModuleSpec BuildDecoderOnlyTransformerSpec(const TransformerConfig &config, ModuleSpec first_stage, ModuleSpec chunk,
9+
ModuleSpec last_stage);
1210
} // namespace infini_train::nn

infini_train/include/core/models/decode_only_transformer/model.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "infini_train/include/core/models/decode_only_transformer/layer_specs.h"
1010
#include "infini_train/include/core/transformer/spec_utils.h"
11+
#include "infini_train/include/core/transformer/transformer_builders.h"
1112
#include "infini_train/include/core/transformer/transformer_config.h"
1213
#include "infini_train/include/core/transformer/transformer_model.h"
1314
#include "infini_train/include/nn/parallel/global.h"
@@ -37,7 +38,9 @@ class DecoderOnlyTransformer : public nn::TransformerModel {
3738
};
3839

3940
explicit DecoderOnlyTransformer(const nn::TransformerConfig &config)
40-
: TransformerModel(config, BuildModelSpec(config)),
41+
: TransformerModel(config, nn::BuildDecoderOnlyTransformerSpec(config, nn::BuildFirstStageSpec(config),
42+
nn::BuildTransformerLayerSpec(config),
43+
nn::BuildLastStageSpec(config))),
4144
stage_info_(nn::parallel::PipelineParallel::GetStageInfo(
4245
Config().n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank,
4346
nn::parallel::global::GetVirtualPipelineParallelSize())) {}
@@ -52,10 +55,5 @@ class DecoderOnlyTransformer : public nn::TransformerModel {
5255
int GetChunkSize() const;
5356

5457
private:
55-
static nn::ModuleSpec BuildModelSpec(const nn::TransformerConfig &config) {
56-
return (config.model_type == nn::TransformerConfig::kGPT2Name) ? BuildGPT2Spec(config)
57-
: BuildLLaMA3Spec(config);
58-
}
59-
6058
const infini_train::nn::parallel::StageInfo stage_info_;
6159
};

infini_train/include/core/transformer/transformer_builders.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,10 @@ ModuleSpec BuildPositionEmbeddingSpec(int64_t num_embeddings, int64_t embedding_
4747
// Build ColumnParallelLinear spec for output projection (lm_head)
4848
ModuleSpec BuildOutputProjSpec(const TransformerConfig &config, int64_t output_size, bool use_bias);
4949

50+
ModuleSpec BuildFirstStageSpec(const TransformerConfig &config);
51+
52+
ModuleSpec BuildChunkSpec(const TransformerConfig &config, int start_layer, int end_layer);
53+
54+
ModuleSpec BuildLastStageSpec(const TransformerConfig &config);
55+
5056
} // namespace infini_train::nn

infini_train/include/core/transformer/transformer_config.h

Lines changed: 12 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,28 @@
55

66
namespace infini_train::nn {
77

8+
enum class ModelType {
9+
kGPT2, // GPT-2
10+
kLLaMA3, // LLaMA3
11+
};
12+
813
enum class AttentionType {
9-
kStandard, // Standard attention (GPT2 style, no RoPE)
10-
kRoPE // Rotary Position Embedding (LLaMA3 style)
14+
kStandard, // Standard attention
15+
kRoPE // Rotary Position Embedding
1116
};
1217

1318
enum class MLPType {
14-
kGELU, // GELU activation (GPT2 style)
15-
kSwiGLU // SwiGLU activation (LLaMA3 style)
19+
kGELU, // GELU activation
20+
kSwiGLU // SwiGLU activation
1621
};
1722

1823
enum class NormType {
19-
kLayerNorm, // LayerNorm (GPT2 style)
20-
kRMSNorm // RMSNorm (LLaMA3 style)
24+
kLayerNorm, // LayerNorm
25+
kRMSNorm // RMSNorm
2126
};
2227

2328
struct TransformerConfig {
24-
static constexpr char kGPT2Name[] = "GPT2";
25-
static constexpr char kLLaMA3Name[] = "LLaMA3";
26-
27-
std::string model_type = "";
29+
ModelType model_type = ModelType::kGPT2;
2830

2931
int64_t block_size = 1024; // Max seq_len
3032
int64_t vocab_size = 50304; // Vocab size
@@ -59,42 +61,5 @@ struct TransformerConfig {
5961
int64_t max_gen_batch_size = 4; // max batch size during inference
6062

6163
bool UseGQA() const { return n_kv_head < n_head; }
62-
63-
static TransformerConfig GPT2() {
64-
return {.model_type = kGPT2Name,
65-
.block_size = 1024,
66-
.vocab_size = 50304,
67-
.original_vocab_size = 50257,
68-
.n_layer = 12,
69-
.n_head = 12,
70-
.n_kv_head = 12,
71-
.n_embd = 768,
72-
.attention_type = AttentionType::kStandard,
73-
.activation_type = MLPType::kGELU,
74-
.norm_type = NormType::kLayerNorm,
75-
.use_bias = true,
76-
.tie_weights = true,
77-
.ffn_expansion_ratio = 4.0f,
78-
.ffn_dim_multiplier = std::nullopt,
79-
.multiple_of = 1};
80-
}
81-
82-
static TransformerConfig LLaMA3() {
83-
return {.model_type = kLLaMA3Name,
84-
.block_size = 8192,
85-
.vocab_size = 128256,
86-
.n_layer = 16,
87-
.n_head = 32,
88-
.n_kv_head = 8,
89-
.n_embd = 2048,
90-
.attention_type = AttentionType::kRoPE,
91-
.activation_type = MLPType::kSwiGLU,
92-
.norm_type = NormType::kRMSNorm,
93-
.use_bias = false,
94-
.tie_weights = false,
95-
.ffn_expansion_ratio = 4.0f,
96-
.ffn_dim_multiplier = 1.5f,
97-
.multiple_of = 256};
98-
}
9964
};
10065
} // namespace infini_train::nn

0 commit comments

Comments
 (0)