@@ -1,3 +1,5 @@
+#include "example/llama3/checkpoint_loader.h"
+
 #include <cmath>
 #include <cstdlib>
 #include <filesystem>
@@ -12,11 +14,12 @@
 
 #include "example/common/utils.h"
 #include "example/llama3/config.h"
-#include "infini_train/include/core/models/decode_only_transformer/model.h"
-#include "infini_train/include/nn/modules/causal_self_attention.h"
-#include "infini_train/include/nn/modules/mlp.h"
 #include "infini_train/include/nn/modules/normalization.h"
-#include "infini_train/include/nn/modules/transformer.h"
+#include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
+#include "infini_train/include/nn/modules/transformer/layer_specs.h"
+#include "infini_train/include/nn/modules/transformer/mlp.h"
+#include "infini_train/include/nn/modules/transformer/transformer.h"
+#include "infini_train/include/nn/parallel/global.h"
 #include "infini_train/include/nn/parallel/tensor_parallel.h"
 #include "infini_train/include/tensor.h"
 
@@ -35,7 +38,18 @@ constexpr int32_t kLLaMA3Magic = 20240803;
 constexpr int32_t kLLaMA3FP32Version = 3;
 } // namespace
 
-std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_LLaMA3(const std::string &filepath) {
+namespace llama3 {
+
+int GetChunkSize() {
+    nn::TransformerConfig llama3_config = llama3::LLaMA3Config();
+
+    auto stage_info = nn::parallel::PipelineParallel::GetStageInfo(
+        llama3_config.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank,
+        nn::parallel::global::GetVirtualPipelineParallelSize());
+    return stage_info.layer_ranges_per_chunk.size();
+}
+
+std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath) {
     if (!std::filesystem::exists(filepath)) {
         LOG(FATAL) << "File not found: " << filepath;
     }
@@ -63,7 +77,7 @@ std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_LLaMA3(
     const auto version_major = BytesToType<int32_t>(header, 56);
     const auto version_minor = BytesToType<int32_t>(header, 60);
 
-    nn::TransformerConfig llama3_config = nn::llama3::LLaMA3Config();
+    nn::TransformerConfig llama3_config = llama3::LLaMA3Config();
     llama3_config.block_size = block_size;
     llama3_config.vocab_size = vocab_size;
     llama3_config.n_layer = n_layer;
@@ -76,7 +90,10 @@ std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_LLaMA3(
     llama3_config.use_scaled_rope = static_cast<bool>(use_scaled_rope);
     llama3_config.norm_eps = norm_eps;
     llama3_config.max_gen_batch_size = max_gen_bs;
-    auto llama3 = std::make_shared<DecoderOnlyTransformer>(llama3_config);
+    auto llama3 = std::make_shared<nn::TransformerModel>(
+        llama3_config,
+        nn::BuildTransformerSpec(llama3_config, nn::BuildFirstStageSpec(llama3_config),
+                                 nn::BuildTransformerLayerSpec(llama3_config), nn::BuildLastStageSpec(llama3_config)));
 
     // ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ==========
     int pp_size = nn::parallel::global::GetPipelineParallelSize();
@@ -164,7 +181,8 @@ std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_LLaMA3(
     // ========== Read Sharded Params ==========
     // transformer.wte.weight : (vocab_size, n_embd) -> local tp_rank: rows of [v_start : v_start+vpp)
     if (is_first_stage) {
-        auto &wte = state_dict[std::format("{}.{}.{}", kTransformerModelName, nn::TransformerFirstStage::kWTELayerName,
+        auto &wte = state_dict[std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName,
+                                           nn::TransformerFirstStage::kWTELayerName,
                                            nn::parallel::VocabParallelEmbedding::kParamWeightName)];
         ReadMatrixRowShardFloat(ifs, static_cast<float *>(wte->DataPtr()),
                                 /*rows=*/vocab_size, /*cols=*/n_embd,
@@ -178,7 +196,7 @@ std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_LLaMA3(
     int local_layer_index = 0;
     for (int i = 0; i < static_cast<int>(n_layer); ++i) {
         if (owned_layers[i]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerModelName,
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
                                                   nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
                                                   nn::TransformerLayer::kLn1LayerName, nn::RMSNorm::kParamWeightName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);
@@ -195,7 +213,7 @@ std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_LLaMA3(
     for (int i = 0; i < static_cast<int>(n_layer); ++i) {
         if (owned_layers[i]) {
             auto &tensor = state_dict[std::format(
-                "{}.{}.{}.{}.{}.{}", kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                "{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName,
                 std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName,
                 nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)];
 
@@ -235,7 +253,7 @@ std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_LLaMA3(
     for (int i = 0; i < static_cast<int>(n_layer); ++i) {
         if (owned_layers[i]) {
             auto &tensor = state_dict[std::format(
-                "{}.{}.{}.{}.{}.{}", kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                "{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName,
                 std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName,
                 nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)];
             ReadMatrixColShardFloat(ifs, static_cast<float *>(tensor->DataPtr()),
@@ -252,7 +270,7 @@ std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_LLaMA3(
     local_layer_index = 0;
     for (int i = 0; i < static_cast<int>(n_layer); ++i) {
         if (owned_layers[i]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerModelName,
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
                                                   nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
                                                   nn::TransformerLayer::kLn2LayerName, nn::RMSNorm::kParamWeightName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);
@@ -267,10 +285,10 @@ std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_LLaMA3(
     local_layer_index = 0;
     for (int i = 0; i < static_cast<int>(n_layer); ++i) {
         if (owned_layers[i]) {
-            auto &tensor
-                = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerModelName, nn::TransformerChunk::kHLayerName,
-                                         std::to_string(local_layer_index), nn::TransformerLayer::kMlpLayerName,
-                                         nn::MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)];
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
+                                                  nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
+                                                  nn::TransformerLayer::kMlpLayerName, nn::MLP::kCFcLayerName,
+                                                  nn::parallel::ColumnParallelLinear::kParamWeightName)];
             ReadMatrixRowShardFloat(ifs, static_cast<float *>(tensor->DataPtr()),
                                     /*rows=*/fc_out, /*cols=*/n_embd,
                                     /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp);
@@ -285,7 +303,7 @@ std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_LLaMA3(
     local_layer_index = 0;
     for (int i = 0; i < static_cast<int>(n_layer); ++i) {
         if (owned_layers[i]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerModelName,
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
                                                   nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
                                                   nn::TransformerLayer::kMlpLayerName, nn::MLP::kCFc2LayerName,
                                                   nn::parallel::ColumnParallelLinear::kParamWeightName)];
@@ -303,10 +321,10 @@ std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_LLaMA3(
     local_layer_index = 0;
     for (int i = 0; i < static_cast<int>(n_layer); ++i) {
         if (owned_layers[i]) {
-            auto &tensor
-                = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerModelName, nn::TransformerChunk::kHLayerName,
-                                         std::to_string(local_layer_index), nn::TransformerLayer::kMlpLayerName,
-                                         nn::MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)];
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName,
+                                                  nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
+                                                  nn::TransformerLayer::kMlpLayerName, nn::MLP::kCProjLayerName,
+                                                  nn::parallel::RowParallelLinear::kParamWeightName)];
             ReadMatrixColShardFloat(ifs, static_cast<float *>(tensor->DataPtr()),
                                     /*rows=*/n_embd, /*cols=*/fc_out,
                                     /*col_start=*/tp_rank * in_fc_pp, /*col_cnt=*/in_fc_pp);
@@ -322,8 +340,8 @@ std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_LLaMA3(
     {
         if (is_last_stage) {
             auto &ln_f
-                = state_dict[std::format("{}.{}.{}", kTransformerModelName, nn::TransformerLastStage::kLnFLayerName,
-                                         nn::RMSNorm::kParamWeightName)];
+                = state_dict[std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName,
+                                         nn::TransformerLastStage::kLnFLayerName, nn::RMSNorm::kParamWeightName)];
             auto &lm_head = state_dict[std::format("{}.{}", nn::TransformerLastStage::kLMHeadLayerName,
                                                    nn::parallel::ColumnParallelLinear::kParamWeightName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(ln_f->DataPtr()), n_embd);
@@ -339,3 +357,4 @@ std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_LLaMA3(
 
     return llama3;
 }
+} // namespace llama3