@@ -12,8 +12,10 @@
 
 #include "example/common/utils.h"
 #include "infini_train/include/core/models/decode_only_transformer/model.h"
-#include "infini_train/include/core/transformer/transformer_block.h"
+#include "infini_train/include/core/transformer/attention/causal_self_attention.h"
+#include "infini_train/include/core/transformer/mlp.h"
 #include "infini_train/include/core/transformer/transformer_config.h"
+#include "infini_train/include/core/transformer/transformer_layer.h"
 #include "infini_train/include/nn/modules/normalization.h"
 #include "infini_train/include/nn/modules/sparse.h"
 #include "infini_train/include/nn/parallel/global.h"
@@ -133,7 +135,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     // local: (vocab_size_per_partition, n_embd)
     if (is_first_stage) {
         auto &transformer_wte_weight
-            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerFirstStage::kWTELayerName,
+            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerFirstStage::kWTELayerName,
                                      nn::parallel::VocabParallelEmbedding::kParamWeightName)];
         ReadMatrixRowShardFloat(ifs, static_cast<float *>(transformer_wte_weight->DataPtr()), model_vocab_size, n_embd,
                                 v_start, vpp);
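ReadMatrixRowShardFloat lives in example/common/utils.h and is not shown in this diff. Below is a minimal sketch, under the assumption that the .bin file stores each parameter as a contiguous row-major float32 matrix, of the behavior the call sites imply: read only rows [v_start, v_start + vpp) and leave the stream positioned past the whole matrix. The name and body are illustrative, not the actual implementation.

#include <cstdint>
#include <fstream>

// Sketch only: read rows [row_start, row_start + row_count) of a row-major
// (rows x cols) float32 matrix from `ifs` into `dst`, skipping the rows owned
// by other tensor-parallel ranks and leaving the stream just past the matrix.
void ReadMatrixRowShardFloatSketch(std::ifstream &ifs, float *dst, int64_t rows, int64_t cols,
                                   int64_t row_start, int64_t row_count) {
    const std::streamoff row_bytes = cols * static_cast<std::streamoff>(sizeof(float));
    ifs.seekg(row_start * row_bytes, std::ios::cur);                      // skip rows before the shard
    ifs.read(reinterpret_cast<char *>(dst), row_count * row_bytes);       // read the local shard
    ifs.seekg((rows - row_start - row_count) * row_bytes, std::ios::cur); // skip rows after the shard
}

Leaving the stream at the end of the full matrix is what lets the loader walk the flat .bin file parameter by parameter, regardless of which shard each rank keeps.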
@@ -155,7 +157,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     if (is_first_stage) {
         // transformer.wpe.weight
         auto &transformer_wpe_weight
-            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerFirstStage::kWPELayerName,
+            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerFirstStage::kWPELayerName,
                                      nn::Embedding::kParamWeightName)];
         ReadMatrixAllFloat(ifs, static_cast<float *>(transformer_wpe_weight->DataPtr()), block_size, n_embd);
     } else {
@@ -168,9 +170,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor
-                = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName,
+                = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
                                          nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                         nn::TransformerBlock::kLn1LayerName, nn::LayerNorm::kParamWeightName)];
+                                         nn::TransformerLayer::kLn1LayerName, nn::LayerNorm::kParamWeightName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);
             ++local_layer_index;
         } else {
@@ -183,9 +185,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName,
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
                                                   nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                                  nn::TransformerBlock::kLn1LayerName, nn::LayerNorm::kParamBiasName)];
+                                                  nn::TransformerLayer::kLn1LayerName, nn::LayerNorm::kParamBiasName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);
             ++local_layer_index;
         } else {
@@ -199,8 +201,8 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor = state_dict[std::format(
-                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerChunk::kHLayerName,
-                std::to_string(local_layer_index), nn::TransformerBlock::kAttnLayerName,
+                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName,
                 nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)];
             // NOTE(zbl): In the .bin model file, Q/K/V is concated along last dim,
             // i.e. [Q|K|V].T = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn].T
@@ -242,8 +244,8 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor = state_dict[std::format(
-                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerChunk::kHLayerName,
-                std::to_string(local_layer_index), nn::TransformerBlock::kAttnLayerName,
+                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName,
                 nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)];
             // NOTE(zbl): Same as c_attn.weight, the bias for Q/K/V is concated
             // i.e. [Q|K|V] = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn]
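The two NOTE(zbl) comments above describe the fused layout that makes these reads non-trivial: c_attn.weight is stored as three (n_embd, n_embd) blocks stacked along the output dim, and the bias as three n_embd vectors. A hypothetical sketch of how a tensor-parallel rank would assemble its shard of the fused weight, assuming row-major float32 storage; the names qkv_start/qkv_pp are illustrative and are not the variables used in the diff:

#include <cstdint>
#include <fstream>

// Sketch only: each of the Q, K, V blocks is an (n_embd x n_embd) row-major
// matrix; this rank owns rows [qkv_start, qkv_start + qkv_pp) of each block,
// and the three local slices are stored back to back in `dst`.
void ReadFusedQKVShardSketch(std::ifstream &ifs, float *dst, int64_t n_embd,
                             int64_t qkv_start, int64_t qkv_pp) {
    const std::streamoff row_bytes = n_embd * static_cast<std::streamoff>(sizeof(float));
    for (int block = 0; block < 3; ++block) { // Q, then K, then V
        ifs.seekg(qkv_start * row_bytes, std::ios::cur);                     // skip other ranks' rows
        ifs.read(reinterpret_cast<char *>(dst), qkv_pp * row_bytes);         // this rank's slice
        dst += qkv_pp * n_embd;
        ifs.seekg((n_embd - qkv_start - qkv_pp) * row_bytes, std::ios::cur); // rest of this block
    }
}

The bias follows the same pattern with row_bytes replaced by sizeof(float), which is why the two NOTE comments are nearly identical.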
@@ -284,8 +286,8 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor = state_dict[std::format(
-                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerChunk::kHLayerName,
-                std::to_string(local_layer_index), nn::TransformerBlock::kAttnLayerName,
+                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName,
                 nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)];
             ReadMatrixColShardFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd, n_embd, tp_rank * in_pp,
                                     in_pp);
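ReadMatrixColShardFloat is the column-wise counterpart used for RowParallelLinear weights, where the input dim is the sharded one. A sketch of the access pattern the call site implies, again assuming row-major float32 storage; this is not the actual utils.h code:

#include <cstdint>
#include <fstream>

// Sketch only: copy columns [col_start, col_start + col_count) of a row-major
// (rows x cols) float32 matrix into a dense (rows x col_count) buffer, one row
// at a time, leaving the stream just past the full matrix.
void ReadMatrixColShardFloatSketch(std::ifstream &ifs, float *dst, int64_t rows, int64_t cols,
                                   int64_t col_start, int64_t col_count) {
    const std::streamoff elem = sizeof(float);
    for (int64_t r = 0; r < rows; ++r) {
        ifs.seekg(col_start * elem, std::ios::cur);                          // columns left of the shard
        ifs.read(reinterpret_cast<char *>(dst + r * col_count), col_count * elem);
        ifs.seekg((cols - col_start - col_count) * elem, std::ios::cur);     // columns right of the shard
    }
}

This matches the call site above: c_proj.weight is (n_embd, n_embd), and rank tp_rank reads columns [tp_rank * in_pp, (tp_rank + 1) * in_pp).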
@@ -301,8 +303,8 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor = state_dict[std::format(
-                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerChunk::kHLayerName,
-                std::to_string(local_layer_index), nn::TransformerBlock::kAttnLayerName,
+                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName,
                 nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);
             ++local_layer_index;
@@ -317,9 +319,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor
-                = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName,
+                = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
                                          nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                         nn::TransformerBlock::kLn2LayerName, nn::LayerNorm::kParamWeightName)];
+                                         nn::TransformerLayer::kLn2LayerName, nn::LayerNorm::kParamWeightName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);
             ++local_layer_index;
         } else {
@@ -332,9 +334,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName,
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
                                                   nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                                  nn::TransformerBlock::kLn2LayerName, nn::LayerNorm::kParamBiasName)];
+                                                  nn::TransformerLayer::kLn2LayerName, nn::LayerNorm::kParamBiasName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);
             ++local_layer_index;
         } else {
@@ -347,9 +349,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName,
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
                                                   nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                                  nn::TransformerBlock::kMlpLayerName, nn::MLP::kCFcLayerName,
+                                                  nn::TransformerLayer::kMlpLayerName, nn::MLP::kCFcLayerName,
                                                   nn::parallel::ColumnParallelLinear::kParamWeightName)];
             ReadMatrixRowShardFloat(ifs, static_cast<float *>(tensor->DataPtr()), fc_out, n_embd, fc_start, fc_pp);
             ++local_layer_index;
@@ -363,9 +365,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
    for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName,
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
                                                   nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                                  nn::TransformerBlock::kMlpLayerName, nn::MLP::kCFcLayerName,
+                                                  nn::TransformerLayer::kMlpLayerName, nn::MLP::kCFcLayerName,
                                                   nn::parallel::ColumnParallelLinear::kParamBiasName)];
             ReadVectorShardFloat(ifs, static_cast<float *>(tensor->DataPtr()), fc_out, fc_start, fc_pp);
             ++local_layer_index;
@@ -379,9 +381,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName,
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
                                                   nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                                  nn::TransformerBlock::kMlpLayerName, nn::MLP::kCProjLayerName,
+                                                  nn::TransformerLayer::kMlpLayerName, nn::MLP::kCProjLayerName,
                                                   nn::parallel::RowParallelLinear::kParamWeightName)];
             ReadMatrixColShardFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd, fc_out, tp_rank * in4_pp,
                                     in4_pp);
@@ -396,9 +398,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName,
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
                                                   nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                                  nn::TransformerBlock::kMlpLayerName, nn::MLP::kCProjLayerName,
+                                                  nn::TransformerLayer::kMlpLayerName, nn::MLP::kCProjLayerName,
                                                   nn::parallel::RowParallelLinear::kParamBiasName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);
             ++local_layer_index;
@@ -411,12 +413,12 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     if (is_last_stage) {
         // transformer.ln_f.weight
         auto &transformer_ln_f_weight
-            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerLastStage::kLnFLayerName,
+            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerLastStage::kLnFLayerName,
                                      nn::LayerNorm::kParamWeightName)];
         ReadVectorAllFloat(ifs, static_cast<float *>(transformer_ln_f_weight->DataPtr()), n_embd);
         // transformer.ln_f.bias
         auto &transformer_ln_f_bias
-            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerLastStage::kLnFLayerName,
+            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerLastStage::kLnFLayerName,
                                      nn::LayerNorm::kParamBiasName)];
         ReadVectorAllFloat(ifs, static_cast<float *>(transformer_ln_f_bias->DataPtr()), n_embd);
     } else {