
Commit 2e4611c

refactor: flatten transformer spec and extract model presets
1 parent 388b4bb commit 2e4611c

18 files changed: 144 additions and 163 deletions

example/gpt2/config.h

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+#pragma once
+
+#include "infini_train/include/core/transformer/transformer_config.h"
+
+namespace infini_train::nn::gpt2 {
+inline nn::TransformerConfig GPT2Config() {
+    return {.model_type = ModelType::kGPT2,
+            .block_size = 1024,
+            .vocab_size = 50304,
+            .original_vocab_size = 50257,
+            .n_layer = 12,
+            .n_head = 12,
+            .n_kv_head = 12,
+            .n_embd = 768,
+            .attention_type = nn::AttentionType::kStandard,
+            .activation_type = nn::MLPType::kGELU,
+            .norm_type = nn::NormType::kLayerNorm,
+            .use_bias = true,
+            .tie_weights = true,
+            .ffn_expansion_ratio = 4.0f,
+            .ffn_dim_multiplier = std::nullopt,
+            .multiple_of = 1};
+}
+
+} // namespace infini_train::nn::gpt2
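
The preset carries the same defaults as the removed static TransformerConfig::GPT2() factory. A minimal usage sketch, assuming hypothetical override values and a hypothetical helper name; the same start-from-preset-then-override pattern appears in example/gpt2/net.cc below:

#include "example/gpt2/config.h"

namespace nn = infini_train::nn;

// Start from the GPT-2 preset, then override fields read from a checkpoint;
// the values below are placeholders, not ones taken from this commit.
nn::TransformerConfig MakeGPT2ConfigForCheckpoint() {
    nn::TransformerConfig config = nn::gpt2::GPT2Config();
    config.block_size = 2048;           // hypothetical max sequence length
    config.original_vocab_size = 50257; // hypothetical tokenizer vocab size
    return config;
}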

example/gpt2/main.cc

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,6 @@
 #include "infini_train/include/autocast.h"
 #include "infini_train/include/core/models/decode_only_transformer/model.h"
 #include "infini_train/include/core/runtime/device_guard.h"
-#include "infini_train/include/core/transformer/transformer_config.h"
 #include "infini_train/include/dataloader.h"
 #include "infini_train/include/device.h"
 #include "infini_train/include/nn/lora/lora_utils.h"
@@ -37,6 +36,7 @@
 
 #include "example/common/tiny_shakespeare_dataset.h"
 #include "example/common/tokenizer.h"
+#include "example/gpt2/config.h"
 
 // I/O
 DEFINE_string(input_bin, "", "input .bin to train on");
@@ -188,7 +188,7 @@ void Train(const nn::parallel::Rank &rank) {
     // ManualSeed(42);
 
     // init the model, either from scratch or from OpenAI pretrained checkpoint
-    nn::TransformerConfig model_config = nn::TransformerConfig::GPT2();
+    nn::TransformerConfig model_config = nn::gpt2::GPT2Config();
     std::shared_ptr<nn::Module> model = nullptr;
 
     if (!FLAGS_llmc_filepath.empty()) {

example/gpt2/net.cc

Lines changed: 3 additions & 2 deletions
@@ -11,8 +11,8 @@
 #include "glog/logging.h"
 
 #include "example/common/utils.h"
+#include "example/gpt2/config.h"
 #include "infini_train/include/core/models/decode_only_transformer/model.h"
-#include "infini_train/include/core/transformer/transformer_config.h"
 #include "infini_train/include/nn/modules/causal_self_attention.h"
 #include "infini_train/include/nn/modules/mlp.h"
 #include "infini_train/include/nn/modules/normalization.h"
@@ -21,6 +21,7 @@
 #include "infini_train/include/nn/parallel/global.h"
 #include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
 #include "infini_train/include/nn/parallel/tensor_parallel.h"
+#include "infini_train/include/tensor.h"
 
 using namespace infini_train;
 namespace nn = infini_train::nn;
@@ -76,7 +77,7 @@ std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_GPT2(co
     // NOTE(zbl): vocab_size needs to be padded to multiple of TP size
     const auto model_vocab_size = tp_size > 1 ? padded_vocab_size : vocab_size;
 
-    nn::TransformerConfig gpt2_config = nn::TransformerConfig::GPT2();
+    nn::TransformerConfig gpt2_config = nn::gpt2::GPT2Config();
     gpt2_config.block_size = block_size;
     gpt2_config.vocab_size = model_vocab_size;
     gpt2_config.original_vocab_size = vocab_size;

example/llama3/config.h

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+#pragma once
+
+#include "infini_train/include/core/transformer/transformer_config.h"
+
+namespace infini_train::nn::llama3 {
+inline nn::TransformerConfig LLaMA3Config() {
+    return {.model_type = ModelType::kLLaMA3,
+            .block_size = 8192,
+            .vocab_size = 128256,
+            .original_vocab_size = 128256,
+            .n_layer = 16,
+            .n_head = 32,
+            .n_kv_head = 8,
+            .n_embd = 2048,
+            .attention_type = nn::AttentionType::kRoPE,
+            .activation_type = nn::MLPType::kSwiGLU,
+            .norm_type = nn::NormType::kRMSNorm,
+            .use_bias = false,
+            .tie_weights = false,
+            .ffn_expansion_ratio = 4.0f,
+            .ffn_dim_multiplier = 1.5f,
+            .multiple_of = 256};
+}
+} // namespace infini_train::nn::llama3
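
Unlike the GPT-2 preset, this one sets n_kv_head (8) below n_head (32), so TransformerConfig::UseGQA() reports grouped-query attention. A small sketch of what the preset implies, with a hypothetical helper name and purely illustrative asserts:

#include <cassert>

#include "example/llama3/config.h"

namespace nn = infini_train::nn;

// With this preset, 8 KV heads serve 32 query heads, i.e. each KV head is
// shared by n_head / n_kv_head = 4 query heads.
void CheckLLaMA3Preset() {
    nn::TransformerConfig config = nn::llama3::LLaMA3Config();
    assert(config.UseGQA());
    assert(config.n_head / config.n_kv_head == 4);
}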

example/llama3/main.cc

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,6 @@
 #include "infini_train/include/autocast.h"
 #include "infini_train/include/core/models/decode_only_transformer/model.h"
 #include "infini_train/include/core/runtime/device_guard.h"
-#include "infini_train/include/core/transformer/transformer_config.h"
 #include "infini_train/include/dataloader.h"
 #include "infini_train/include/device.h"
 #include "infini_train/include/nn/lora/lora_utils.h"
@@ -36,6 +35,7 @@
 
 #include "example/common/tiny_shakespeare_dataset.h"
 #include "example/common/tokenizer.h"
+#include "example/llama3/config.h"
 
 // I/O
 DEFINE_string(input_bin, "", "input .bin to train on");
@@ -168,7 +168,7 @@ void Train(const nn::parallel::Rank &rank) {
     // rng / reproducibility
     // ManualSeed(42);
 
-    nn::TransformerConfig model_config = nn::TransformerConfig::LLaMA3();
+    nn::TransformerConfig model_config = nn::llama3::LLaMA3Config();
     std::shared_ptr<nn::Module> model = nullptr;
     if (!FLAGS_llmc_filepath.empty()) {
         model = DecoderOnlyTransformer::FromLLMC_LLaMA3(FLAGS_llmc_filepath);

example/llama3/net.cc

Lines changed: 3 additions & 5 deletions
@@ -2,7 +2,6 @@
 #include <cstdlib>
 #include <filesystem>
 #include <fstream>
-#include <map>
 #include <memory>
 #include <random>
 #include <string>
@@ -12,15 +11,14 @@
 #include "glog/logging.h"
 
 #include "example/common/utils.h"
+#include "example/llama3/config.h"
 #include "infini_train/include/core/models/decode_only_transformer/model.h"
-#include "infini_train/include/core/transformer/spec_utils.h"
-#include "infini_train/include/core/transformer/transformer_config.h"
-#include "infini_train/include/device.h"
 #include "infini_train/include/nn/modules/causal_self_attention.h"
 #include "infini_train/include/nn/modules/mlp.h"
 #include "infini_train/include/nn/modules/normalization.h"
 #include "infini_train/include/nn/modules/transformer.h"
 #include "infini_train/include/nn/parallel/tensor_parallel.h"
+#include "infini_train/include/tensor.h"
 
 using namespace infini_train;
 namespace nn = infini_train::nn;
@@ -65,7 +63,7 @@ std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_LLaMA3(
     const auto version_major = BytesToType<int32_t>(header, 56);
     const auto version_minor = BytesToType<int32_t>(header, 60);
 
-    nn::TransformerConfig llama3_config = nn::TransformerConfig::LLaMA3();
+    nn::TransformerConfig llama3_config = nn::llama3::LLaMA3Config();
     llama3_config.block_size = block_size;
     llama3_config.vocab_size = vocab_size;
     llama3_config.n_layer = n_layer;

infini_train/include/core/models/decode_only_transformer/layer_specs.h

Lines changed: 2 additions & 4 deletions
@@ -4,9 +4,7 @@
 #include "infini_train/include/core/transformer/transformer_config.h"
 
 namespace infini_train::nn {
-// Build GPT2 model spec: LayerNorm + GELU + standard attention
-ModuleSpec BuildGPT2Spec(const TransformerConfig &config);
 
-// Build LLaMA3 model spec: RMSNorm + SwiGLU + RoPE + GQA
-ModuleSpec BuildLLaMA3Spec(const TransformerConfig &config);
+ModuleSpec BuildDecoderOnlyTransformerSpec(const TransformerConfig &config, ModuleSpec first_stage, ModuleSpec chunk,
+                                           ModuleSpec last_stage);
 } // namespace infini_train::nn

infini_train/include/core/models/decode_only_transformer/model.h

Lines changed: 4 additions & 8 deletions
@@ -7,12 +7,11 @@
 #include "glog/logging.h"
 
 #include "infini_train/include/core/models/decode_only_transformer/layer_specs.h"
-#include "infini_train/include/core/transformer/spec_utils.h"
+#include "infini_train/include/core/transformer/transformer_builders.h"
 #include "infini_train/include/core/transformer/transformer_config.h"
 #include "infini_train/include/core/transformer/transformer_model.h"
 #include "infini_train/include/nn/parallel/global.h"
 #include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
-#include "infini_train/include/tensor.h"
 
 using namespace infini_train;
 namespace nn = infini_train::nn;
@@ -37,7 +36,9 @@ class DecoderOnlyTransformer : public nn::TransformerModel {
     };
 
     explicit DecoderOnlyTransformer(const nn::TransformerConfig &config)
-        : TransformerModel(config, BuildModelSpec(config)),
+        : TransformerModel(config, nn::BuildDecoderOnlyTransformerSpec(config, nn::BuildFirstStageSpec(config),
+                                                                       nn::BuildTransformerLayerSpec(config),
+                                                                       nn::BuildLastStageSpec(config))),
           stage_info_(nn::parallel::PipelineParallel::GetStageInfo(
               Config().n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank,
               nn::parallel::global::GetVirtualPipelineParallelSize())) {}
@@ -52,10 +53,5 @@ class DecoderOnlyTransformer : public nn::TransformerModel {
     int GetChunkSize() const;
 
 private:
-    static nn::ModuleSpec BuildModelSpec(const nn::TransformerConfig &config) {
-        return (config.model_type == nn::TransformerConfig::kGPT2Name) ? BuildGPT2Spec(config)
-                                                                       : BuildLLaMA3Spec(config);
-    }
-
     const infini_train::nn::parallel::StageInfo stage_info_;
 };
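
With the per-model BuildModelSpec dispatch removed, the constructor derives the whole spec from the config it is given, so a model can be built straight from one of the new presets. A minimal sketch with a hypothetical helper name (the full training flow lives in example/gpt2/main.cc):

#include <memory>

#include "example/gpt2/config.h"
#include "infini_train/include/core/models/decode_only_transformer/model.h"

// The constructor assembles the first-stage, per-layer, and last-stage specs
// from the config fields alone; no model-type branching remains here.
std::shared_ptr<nn::Module> BuildGPT2FromPreset() {
    return std::make_shared<DecoderOnlyTransformer>(nn::gpt2::GPT2Config());
}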

infini_train/include/core/transformer/transformer_builders.h

Lines changed: 9 additions & 4 deletions
@@ -1,7 +1,6 @@
 #pragma once
 
 #include <cstdint>
-#include <memory>
 
 #include "infini_train/include/core/transformer/spec_utils.h"
 #include "infini_train/include/core/transformer/transformer_config.h"
@@ -35,9 +34,6 @@ ModuleSpec BuildAttentionSpec(const TransformerConfig &config);
 // Build MLP spec (supports GELU and SwiGLU)
 ModuleSpec BuildMLPSpec(const TransformerConfig &config);
 
-// Build TransformerLayer spec
-ModuleSpec BuildTransformerLayerSpec(const TransformerConfig &config);
-
 // Build VocabParallelEmbedding spec for token embeddings
 ModuleSpec BuildVocabEmbeddingSpec(const TransformerConfig &config);
 
@@ -47,4 +43,13 @@ ModuleSpec BuildPositionEmbeddingSpec(int64_t num_embeddings, int64_t embedding_
 // Build ColumnParallelLinear spec for output projection (lm_head)
 ModuleSpec BuildOutputProjSpec(const TransformerConfig &config, int64_t output_size, bool use_bias);
 
+// Build TransformerFirstStage spec
+ModuleSpec BuildFirstStageSpec(const TransformerConfig &config);
+
+// Build TransformerLayer spec
+ModuleSpec BuildTransformerLayerSpec(const TransformerConfig &config);
+
+// Build TransformerLastStage spec
+ModuleSpec BuildLastStageSpec(const TransformerConfig &config);
+
 } // namespace infini_train::nn
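
These three stage builders are what BuildDecoderOnlyTransformerSpec now stitches together. A sketch of the composition, mirroring the DecoderOnlyTransformer constructor shown above; the wrapper name is hypothetical:

#include "infini_train/include/core/models/decode_only_transformer/layer_specs.h"
#include "infini_train/include/core/transformer/transformer_builders.h"

namespace nn = infini_train::nn;

// Flattened assembly: a first-stage spec, a per-layer chunk spec, and a
// last-stage spec, all built from the same config.
nn::ModuleSpec BuildFullSpec(const nn::TransformerConfig &config) {
    return nn::BuildDecoderOnlyTransformerSpec(config, nn::BuildFirstStageSpec(config),
                                               nn::BuildTransformerLayerSpec(config),
                                               nn::BuildLastStageSpec(config));
}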

infini_train/include/core/transformer/transformer_config.h

Lines changed: 13 additions & 48 deletions
@@ -1,30 +1,32 @@
 #pragma once
+
 #include <cstdint>
 #include <optional>
-#include <string>
 
 namespace infini_train::nn {
 
+enum class ModelType {
+    kGPT2,   // GPT-2
+    kLLaMA3, // LLaMA3
+};
+
 enum class AttentionType {
-    kStandard, // Standard attention (GPT2 style, no RoPE)
-    kRoPE      // Rotary Position Embedding (LLaMA3 style)
+    kStandard, // Standard attention
+    kRoPE      // Rotary Position Embedding
 };
 
 enum class MLPType {
-    kGELU,  // GELU activation (GPT2 style)
-    kSwiGLU // SwiGLU activation (LLaMA3 style)
+    kGELU,  // GELU activation
+    kSwiGLU // SwiGLU activation
 };
 
 enum class NormType {
-    kLayerNorm, // LayerNorm (GPT2 style)
-    kRMSNorm    // RMSNorm (LLaMA3 style)
+    kLayerNorm, // LayerNorm
+    kRMSNorm    // RMSNorm
 };
 
 struct TransformerConfig {
-    static constexpr char kGPT2Name[] = "GPT2";
-    static constexpr char kLLaMA3Name[] = "LLaMA3";
-
-    std::string model_type = "";
+    ModelType model_type = ModelType::kGPT2;
 
     int64_t block_size = 1024; // Max seq_len
     int64_t vocab_size = 50304; // Vocab size
@@ -59,42 +61,5 @@ struct TransformerConfig {
     int64_t max_gen_batch_size = 4; // max batch size during inference
 
     bool UseGQA() const { return n_kv_head < n_head; }
-
-    static TransformerConfig GPT2() {
-        return {.model_type = kGPT2Name,
-                .block_size = 1024,
-                .vocab_size = 50304,
-                .original_vocab_size = 50257,
-                .n_layer = 12,
-                .n_head = 12,
-                .n_kv_head = 12,
-                .n_embd = 768,
-                .attention_type = AttentionType::kStandard,
-                .activation_type = MLPType::kGELU,
-                .norm_type = NormType::kLayerNorm,
-                .use_bias = true,
-                .tie_weights = true,
-                .ffn_expansion_ratio = 4.0f,
-                .ffn_dim_multiplier = std::nullopt,
-                .multiple_of = 1};
-    }
-
-    static TransformerConfig LLaMA3() {
-        return {.model_type = kLLaMA3Name,
-                .block_size = 8192,
-                .vocab_size = 128256,
-                .n_layer = 16,
-                .n_head = 32,
-                .n_kv_head = 8,
-                .n_embd = 2048,
-                .attention_type = AttentionType::kRoPE,
-                .activation_type = MLPType::kSwiGLU,
-                .norm_type = NormType::kRMSNorm,
-                .use_bias = false,
-                .tie_weights = false,
-                .ffn_expansion_ratio = 4.0f,
-                .ffn_dim_multiplier = 1.5f,
-                .multiple_of = 256};
-    }
 };
 } // namespace infini_train::nn
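
Since model_type is now a ModelType enum rather than a string, call sites that used to compare against kGPT2Name can switch on it exhaustively. A sketch with a hypothetical ModelName helper that is not part of this commit:

#include <string>

#include "infini_train/include/core/transformer/transformer_config.h"

// Hypothetical helper: dispatch on the ModelType enum instead of comparing
// model_type strings, as the removed BuildModelSpec used to do.
std::string ModelName(const infini_train::nn::TransformerConfig &config) {
    switch (config.model_type) {
    case infini_train::nn::ModelType::kGPT2:
        return "GPT2";
    case infini_train::nn::ModelType::kLLaMA3:
        return "LLaMA3";
    }
    return "unknown";
}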
