|
17 | 17 | #include "infini_train/include/nn/modules/normalization.h" |
18 | 18 | #include "infini_train/include/nn/modules/sparse.h" |
19 | 19 | #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" |
20 | | -#include "infini_train/include/nn/modules/transformer/layer_specs.h" |
21 | 20 | #include "infini_train/include/nn/modules/transformer/mlp.h" |
22 | 21 | #include "infini_train/include/nn/modules/transformer/transformer.h" |
23 | 22 | #include "infini_train/include/nn/parallel/global.h" |
@@ -56,14 +55,6 @@ std::tuple<int32_t, infini_train::DataType> DetermineAndCheckVersion(const std:: |
56 | 55 | } // namespace |
57 | 56 |
|
58 | 57 | namespace gpt2 { |
59 | | -int GetChunkSize() { |
60 | | - nn::TransformerConfig gpt2_config = GPT2Config(); |
61 | | - |
62 | | - auto stage_info = nn::parallel::PipelineParallel::GetStageInfo( |
63 | | - gpt2_config.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank, |
64 | | - nn::parallel::global::GetVirtualPipelineParallelSize()); |
65 | | - return stage_info.layer_ranges_per_chunk.size(); |
66 | | -} |
67 | 58 |
|
68 | 59 | std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath) { |
69 | 60 | if (!std::filesystem::exists(filepath)) { |
@@ -96,10 +87,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath) |
96 | 87 | gpt2_config.n_layer = n_layer; |
97 | 88 | gpt2_config.n_head = n_head; |
98 | 89 | gpt2_config.n_embd = n_embd; |
99 | | - auto local_gpt2 = std::make_shared<nn::TransformerModel>( |
100 | | - gpt2_config, |
101 | | - nn::BuildTransformerSpec(gpt2_config, nn::BuildFirstStageSpec(gpt2_config), |
102 | | - nn::BuildTransformerLayerSpec(gpt2_config), nn::BuildLastStageSpec(gpt2_config))); |
| 90 | + auto local_gpt2 = std::make_shared<nn::TransformerModel>(gpt2_config); |
103 | 91 |
|
104 | 92 | LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size |
105 | 93 | << " vocab_size: " << vocab_size << " n_layer: " << n_layer << " n_head: " << n_head |
@@ -436,6 +424,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath) |
436 | 424 | size_t ln_f_b_bytes = n_embd * sizeof(float); |
437 | 425 | ifs.seekg(ln_f_w_bytes + ln_f_b_bytes, std::ios::cur); |
438 | 426 | } |
| 427 | + |
439 | 428 | return local_gpt2; |
440 | 429 | } |
441 | 430 | } // namespace gpt2 |
0 commit comments