@@ -32,12 +32,6 @@ constexpr int kRandomSeed = 42;
3232static std::mt19937 gen{kRandomSeed };
3333} // namespace
3434
35- std::shared_ptr<GPT2> GPT2::FromPretrained (ModelType model_type) {
36- // TODO(dcj): implement this later
37- LOG (FATAL) << " Not implemented yet" ;
38- return nullptr ;
39- }
40-
4135namespace {
4236constexpr int32_t kHeaderMagic = 20240326 ;
4337constexpr int32_t kHeaderFP32Version = 3 ;
@@ -58,7 +52,7 @@ std::tuple<int32_t, infini_train::DataType> DetermineAndCheckVersion(const std::
5852}
5953} // namespace
6054
61- std::shared_ptr<GPT2> GPT2::FromLLMC (const std::string &filepath) {
55+ std::shared_ptr<DecoderOnlyTransformer> DecoderOnlyTransformer::FromLLMC_GPT2 (const std::string &filepath) {
6256 if (!std::filesystem::exists (filepath)) {
6357 LOG (FATAL) << " File not found: " << filepath;
6458 }
@@ -89,7 +83,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
8983 gpt2_config.n_layer = n_layer;
9084 gpt2_config.n_head = n_head;
9185 gpt2_config.n_embd = n_embd;
92- auto local_gpt2 = std::make_shared<GPT2 >(gpt2_config);
86+ auto local_gpt2 = std::make_shared<DecoderOnlyTransformer >(gpt2_config);
9387
9488 LOG (INFO) << " magic: " << magic << " version: " << version << " block_size: " << block_size
9589 << " vocab_size: " << vocab_size << " n_layer: " << n_layer << " n_head: " << n_head
@@ -135,7 +129,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
135129 // local: (vocab_size_per_partition, n_embd)
136130 if (is_first_stage) {
137131 auto &transformer_wte_weight
138- = state_dict[std::format (" {}.{}.{}" , GPT2:: kTransformerModelName , nn::TransformerFirstStage::kWTELayerName ,
132+ = state_dict[std::format (" {}.{}.{}" , kTransformerModelName , nn::TransformerFirstStage::kWTELayerName ,
139133 nn::parallel::VocabParallelEmbedding::kParamWeightName )];
140134 ReadMatrixRowShardFloat (ifs, static_cast <float *>(transformer_wte_weight->DataPtr ()), model_vocab_size, n_embd,
141135 v_start, vpp);
@@ -157,7 +151,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
157151 if (is_first_stage) {
158152 // transformer.wpe.weight
159153 auto &transformer_wpe_weight
160- = state_dict[std::format (" {}.{}.{}" , GPT2:: kTransformerModelName , nn::TransformerFirstStage::kWPELayerName ,
154+ = state_dict[std::format (" {}.{}.{}" , kTransformerModelName , nn::TransformerFirstStage::kWPELayerName ,
161155 nn::Embedding::kParamWeightName )];
162156 ReadMatrixAllFloat (ifs, static_cast <float *>(transformer_wpe_weight->DataPtr ()), block_size, n_embd);
163157 } else {
@@ -170,9 +164,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
170164 for (int idx = 0 ; idx < n_layer; ++idx) {
171165 if (owned_layers[idx]) {
172166 auto &tensor
173- = state_dict[std::format (" {}.{}.{}.{}.{}" , GPT2:: kTransformerModelName ,
174- nn::TransformerChunk:: kHLayerName , std::to_string (local_layer_index),
175- nn::TransformerLayer:: kLn1LayerName , nn:: LayerNorm::kParamWeightName )];
167+ = state_dict[std::format (" {}.{}.{}.{}.{}" , kTransformerModelName , nn::TransformerChunk:: kHLayerName ,
168+ std::to_string (local_layer_index), nn::TransformerLayer:: kLn1LayerName ,
169+ nn::LayerNorm::kParamWeightName )];
176170 ReadVectorAllFloat (ifs, static_cast <float *>(tensor->DataPtr ()), n_embd);
177171 ++local_layer_index;
178172 } else {
@@ -185,7 +179,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
185179 local_layer_index = 0 ;
186180 for (int idx = 0 ; idx < n_layer; ++idx) {
187181 if (owned_layers[idx]) {
188- auto &tensor = state_dict[std::format (" {}.{}.{}.{}.{}" , GPT2:: kTransformerModelName ,
182+ auto &tensor = state_dict[std::format (" {}.{}.{}.{}.{}" , kTransformerModelName ,
189183 nn::TransformerChunk::kHLayerName , std::to_string (local_layer_index),
190184 nn::TransformerLayer::kLn1LayerName , nn::LayerNorm::kParamBiasName )];
191185 ReadVectorAllFloat (ifs, static_cast <float *>(tensor->DataPtr ()), n_embd);
@@ -201,7 +195,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
201195 for (int idx = 0 ; idx < n_layer; ++idx) {
202196 if (owned_layers[idx]) {
203197 auto &tensor = state_dict[std::format (
204- " {}.{}.{}.{}.{}.{}" , GPT2:: kTransformerModelName , nn::TransformerChunk::kHLayerName ,
198+ " {}.{}.{}.{}.{}.{}" , kTransformerModelName , nn::TransformerChunk::kHLayerName ,
205199 std::to_string (local_layer_index), nn::TransformerLayer::kAttnLayerName ,
206200 nn::CausalSelfAttention::kCAttnLayerName , nn::parallel::ColumnParallelLinear::kParamWeightName )];
207201 // NOTE(zbl): In the .bin model file, Q/K/V is concated along last dim,
@@ -244,7 +238,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
244238 for (int idx = 0 ; idx < n_layer; ++idx) {
245239 if (owned_layers[idx]) {
246240 auto &tensor = state_dict[std::format (
247- " {}.{}.{}.{}.{}.{}" , GPT2:: kTransformerModelName , nn::TransformerChunk::kHLayerName ,
241+ " {}.{}.{}.{}.{}.{}" , kTransformerModelName , nn::TransformerChunk::kHLayerName ,
248242 std::to_string (local_layer_index), nn::TransformerLayer::kAttnLayerName ,
249243 nn::CausalSelfAttention::kCAttnLayerName , nn::parallel::ColumnParallelLinear::kParamBiasName )];
250244 // NOTE(zbl): Same as c_attn.weight, the bias for Q/K/V is concated
@@ -286,7 +280,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
286280 for (int idx = 0 ; idx < n_layer; ++idx) {
287281 if (owned_layers[idx]) {
288282 auto &tensor = state_dict[std::format (
289- " {}.{}.{}.{}.{}.{}" , GPT2:: kTransformerModelName , nn::TransformerChunk::kHLayerName ,
283+ " {}.{}.{}.{}.{}.{}" , kTransformerModelName , nn::TransformerChunk::kHLayerName ,
290284 std::to_string (local_layer_index), nn::TransformerLayer::kAttnLayerName ,
291285 nn::CausalSelfAttention::kCProjLayerName , nn::parallel::RowParallelLinear::kParamWeightName )];
292286 ReadMatrixColShardFloat (ifs, static_cast <float *>(tensor->DataPtr ()), n_embd, n_embd, tp_rank * in_pp,
@@ -303,7 +297,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
303297 for (int idx = 0 ; idx < n_layer; ++idx) {
304298 if (owned_layers[idx]) {
305299 auto &tensor = state_dict[std::format (
306- " {}.{}.{}.{}.{}.{}" , GPT2:: kTransformerModelName , nn::TransformerChunk::kHLayerName ,
300+ " {}.{}.{}.{}.{}.{}" , kTransformerModelName , nn::TransformerChunk::kHLayerName ,
307301 std::to_string (local_layer_index), nn::TransformerLayer::kAttnLayerName ,
308302 nn::CausalSelfAttention::kCProjLayerName , nn::parallel::RowParallelLinear::kParamBiasName )];
309303 ReadVectorAllFloat (ifs, static_cast <float *>(tensor->DataPtr ()), n_embd);
@@ -319,9 +313,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
319313 for (int idx = 0 ; idx < n_layer; ++idx) {
320314 if (owned_layers[idx]) {
321315 auto &tensor
322- = state_dict[std::format (" {}.{}.{}.{}.{}" , GPT2:: kTransformerModelName ,
323- nn::TransformerChunk:: kHLayerName , std::to_string (local_layer_index),
324- nn::TransformerLayer:: kLn2LayerName , nn:: LayerNorm::kParamWeightName )];
316+ = state_dict[std::format (" {}.{}.{}.{}.{}" , kTransformerModelName , nn::TransformerChunk:: kHLayerName ,
317+ std::to_string (local_layer_index), nn::TransformerLayer:: kLn2LayerName ,
318+ nn::LayerNorm::kParamWeightName )];
325319 ReadVectorAllFloat (ifs, static_cast <float *>(tensor->DataPtr ()), n_embd);
326320 ++local_layer_index;
327321 } else {
@@ -334,7 +328,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
334328 local_layer_index = 0 ;
335329 for (int idx = 0 ; idx < n_layer; ++idx) {
336330 if (owned_layers[idx]) {
337- auto &tensor = state_dict[std::format (" {}.{}.{}.{}.{}" , GPT2:: kTransformerModelName ,
331+ auto &tensor = state_dict[std::format (" {}.{}.{}.{}.{}" , kTransformerModelName ,
338332 nn::TransformerChunk::kHLayerName , std::to_string (local_layer_index),
339333 nn::TransformerLayer::kLn2LayerName , nn::LayerNorm::kParamBiasName )];
340334 ReadVectorAllFloat (ifs, static_cast <float *>(tensor->DataPtr ()), n_embd);
@@ -349,10 +343,10 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
349343 local_layer_index = 0 ;
350344 for (int idx = 0 ; idx < n_layer; ++idx) {
351345 if (owned_layers[idx]) {
352- auto &tensor = state_dict[ std::format ( " {}.{}.{}.{}.{}.{} " , GPT2:: kTransformerModelName ,
353- nn::TransformerChunk::kHLayerName , std::to_string (local_layer_index) ,
354- nn::TransformerLayer:: kMlpLayerName , nn::MLP:: kCFcLayerName ,
355- nn::parallel::ColumnParallelLinear::kParamWeightName )];
346+ auto &tensor
347+ = state_dict[ std::format ( " {}.{}.{}.{}.{}.{} " , kTransformerModelName , nn::TransformerChunk::kHLayerName ,
348+ std::to_string (local_layer_index) , nn::TransformerLayer:: kMlpLayerName ,
349+ nn::MLP:: kCFcLayerName , nn::parallel::ColumnParallelLinear::kParamWeightName )];
356350 ReadMatrixRowShardFloat (ifs, static_cast <float *>(tensor->DataPtr ()), fc_out, n_embd, fc_start, fc_pp);
357351 ++local_layer_index;
358352 } else {
@@ -365,10 +359,10 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
365359 local_layer_index = 0 ;
366360 for (int idx = 0 ; idx < n_layer; ++idx) {
367361 if (owned_layers[idx]) {
368- auto &tensor = state_dict[ std::format ( " {}.{}.{}.{}.{}.{} " , GPT2:: kTransformerModelName ,
369- nn::TransformerChunk::kHLayerName , std::to_string (local_layer_index) ,
370- nn::TransformerLayer:: kMlpLayerName , nn::MLP:: kCFcLayerName ,
371- nn::parallel::ColumnParallelLinear::kParamBiasName )];
362+ auto &tensor
363+ = state_dict[ std::format ( " {}.{}.{}.{}.{}.{} " , kTransformerModelName , nn::TransformerChunk::kHLayerName ,
364+ std::to_string (local_layer_index) , nn::TransformerLayer:: kMlpLayerName ,
365+ nn::MLP:: kCFcLayerName , nn::parallel::ColumnParallelLinear::kParamBiasName )];
372366 ReadVectorShardFloat (ifs, static_cast <float *>(tensor->DataPtr ()), fc_out, fc_start, fc_pp);
373367 ++local_layer_index;
374368 } else {
@@ -381,10 +375,10 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
381375 local_layer_index = 0 ;
382376 for (int idx = 0 ; idx < n_layer; ++idx) {
383377 if (owned_layers[idx]) {
384- auto &tensor = state_dict[ std::format ( " {}.{}.{}.{}.{}.{} " , GPT2:: kTransformerModelName ,
385- nn::TransformerChunk::kHLayerName , std::to_string (local_layer_index) ,
386- nn::TransformerLayer:: kMlpLayerName , nn::MLP:: kCProjLayerName ,
387- nn::parallel::RowParallelLinear::kParamWeightName )];
378+ auto &tensor
379+ = state_dict[ std::format ( " {}.{}.{}.{}.{}.{} " , kTransformerModelName , nn::TransformerChunk::kHLayerName ,
380+ std::to_string (local_layer_index) , nn::TransformerLayer:: kMlpLayerName ,
381+ nn::MLP:: kCProjLayerName , nn::parallel::RowParallelLinear::kParamWeightName )];
388382 ReadMatrixColShardFloat (ifs, static_cast <float *>(tensor->DataPtr ()), n_embd, fc_out, tp_rank * in4_pp,
389383 in4_pp);
390384 ++local_layer_index;
@@ -398,10 +392,10 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
398392 local_layer_index = 0 ;
399393 for (int idx = 0 ; idx < n_layer; ++idx) {
400394 if (owned_layers[idx]) {
401- auto &tensor = state_dict[ std::format ( " {}.{}.{}.{}.{}.{} " , GPT2:: kTransformerModelName ,
402- nn::TransformerChunk::kHLayerName , std::to_string (local_layer_index) ,
403- nn::TransformerLayer:: kMlpLayerName , nn::MLP:: kCProjLayerName ,
404- nn::parallel::RowParallelLinear::kParamBiasName )];
395+ auto &tensor
396+ = state_dict[ std::format ( " {}.{}.{}.{}.{}.{} " , kTransformerModelName , nn::TransformerChunk::kHLayerName ,
397+ std::to_string (local_layer_index) , nn::TransformerLayer:: kMlpLayerName ,
398+ nn::MLP:: kCProjLayerName , nn::parallel::RowParallelLinear::kParamBiasName )];
405399 ReadVectorAllFloat (ifs, static_cast <float *>(tensor->DataPtr ()), n_embd);
406400 ++local_layer_index;
407401 } else {
@@ -413,13 +407,12 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
413407 if (is_last_stage) {
414408 // transformer.ln_f.weight
415409 auto &transformer_ln_f_weight
416- = state_dict[std::format (" {}.{}.{}" , GPT2:: kTransformerModelName , nn::TransformerLastStage::kLnFLayerName ,
410+ = state_dict[std::format (" {}.{}.{}" , kTransformerModelName , nn::TransformerLastStage::kLnFLayerName ,
417411 nn::LayerNorm::kParamWeightName )];
418412 ReadVectorAllFloat (ifs, static_cast <float *>(transformer_ln_f_weight->DataPtr ()), n_embd);
419413 // transformer.ln_f.bias
420- auto &transformer_ln_f_bias
421- = state_dict[std::format (" {}.{}.{}" , GPT2::kTransformerModelName , nn::TransformerLastStage::kLnFLayerName ,
422- nn::LayerNorm::kParamBiasName )];
414+ auto &transformer_ln_f_bias = state_dict[std::format (
415+ " {}.{}.{}" , kTransformerModelName , nn::TransformerLastStage::kLnFLayerName , nn::LayerNorm::kParamBiasName )];
423416 ReadVectorAllFloat (ifs, static_cast <float *>(transformer_ln_f_bias->DataPtr ()), n_embd);
424417 } else {
425418 size_t ln_f_w_bytes = n_embd * sizeof (float );
@@ -428,5 +421,3 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
428421 }
429422 return local_gpt2;
430423}
431-
432- int GPT2::GetChunkSize () const { return stage_info_.layer_ranges_per_chunk .size (); }
0 commit comments