55#include < filesystem>
66#include < fstream>
77#include < memory>
8+ #include < random>
89#include < string>
910#include < tuple>
1011#include < vector>
@@ -28,25 +29,35 @@ using namespace infini_train;
2829namespace nn = infini_train::nn;
2930
namespace {
// Fixed seed for `gen`, so the sequence it produces is the same on every run.
constexpr int kRandomSeed = 42;

// File-local Mersenne Twister engine seeded with kRandomSeed. Names in an unnamed namespace already
// have internal linkage, so no `static` is needed.
// NOTE(review): std::mt19937 is not synchronized; confirm it is only drawn from on one thread.
// TODO(dcj): make this rng generator compatible with torch later
std::mt19937 gen{kRandomSeed};
} // namespace
37+
38+ namespace {
39+ constexpr int32_t kHeaderMagic = 20240326 ;
40+ constexpr int32_t kHeaderFP32Version = 3 ;
41+ constexpr int32_t kHeaderBF16Version = 5 ;
42+
43+ std::tuple<int32_t , infini_train::DataType> DetermineAndCheckVersion (const std::vector<uint8_t > &header,
44+ size_t offset) {
3645 const auto version = BytesToType<uint32_t >(header, offset);
3746 switch (version) {
38- case kGPT2BF16Version :
39- return {version, DataType::kBFLOAT16 };
40- case kGPT2FP32Version :
41- return {version, DataType::kFLOAT32 };
47+ case kHeaderBF16Version :
48+ return {version, infini_train:: DataType::kBFLOAT16 };
49+ case kHeaderFP32Version :
50+ return {version, infini_train:: DataType::kFLOAT32 };
4251 default :
4352 LOG (FATAL ) << " Unsupported version: " << version << " at " << __FILE__ << " :" << __LINE__;
4453 return {}; // Unreachable, but keeps compiler happy
4554 }
4655}
4756} // namespace
4857
49- std::shared_ptr<nn::TransformerModel> gpt2::LoadFromLLMC (const std::string &filepath) {
58+ namespace gpt2 {
59+
60+ std::shared_ptr<nn::TransformerModel> LoadFromLLMC (const std::string &filepath) {
5061 if (!std::filesystem::exists (filepath)) {
5162 LOG (FATAL ) << " File not found: " << filepath;
5263 }
@@ -55,9 +66,9 @@ std::shared_ptr<nn::TransformerModel> gpt2::LoadFromLLMC(const std::string &file
5566 const auto header = ReadSeveralBytesFromIfstream (256 * sizeof (int32_t ), &ifs);
5667
5768 const auto magic = BytesToType<uint32_t >(header, 0 );
58- CHECK_EQ (magic, kGPT2Magic );
69+ CHECK_EQ (magic, kHeaderMagic );
5970 auto [version, dtype] = DetermineAndCheckVersion (header, 4 );
60- CHECK_EQ (version, kGPT2FP32Version );
71+ CHECK_EQ (version, kHeaderFP32Version );
6172
6273 auto tp_size = nn::parallel::global::GetTensorParallelSize ();
6374
@@ -418,127 +429,4 @@ std::shared_ptr<nn::TransformerModel> gpt2::LoadFromLLMC(const std::string &file
418429
419430 return local_gpt2;
420431}
421-
// NOTE(review): every line of this definition carries a '-' diff marker, i.e. it is the removed side of
// the hunk. The comments below describe what the removed code did.
//
// Writes `model` to `filepath` as a binary file: a 256-entry int32 header followed by the model's
// parameter tensors, each converted to FP32 when needed and written as raw bytes.
// A CHECK fails when the tensor-parallel or pipeline-parallel size is not 1, when the file cannot be
// opened, when an expected tensor name is missing from the state dict, or when the stream is not good
// after the final flush.
422- void gpt2::SaveAsLLMC (const std::shared_ptr<nn::TransformerModel> &model, const std::string &filepath) {
423- CHECK_EQ (nn::parallel::global::GetTensorParallelSize (), 1 ) << " SaveAsLLMC currently supports TP=1 only." ;
424- CHECK_EQ (nn::parallel::global::GetPipelineParallelSize (), 1 ) << " SaveAsLLMC currently supports PP=1 only." ;
425-
426- std::ofstream ofs (filepath, std::ios::binary);
427- CHECK (ofs.is_open ()) << " Failed to open model file for write: " << filepath;
428-
429- auto config = model->Config ();
// Header layout: [0] magic, [1] FP32 version, [2] block_size, [3] original_vocab_size, [4] n_layer,
// [5] n_head, [6] n_embd, [7] vocab_size. The remaining entries keep their initial value 0.
430- std::vector<int32_t > header (256 , 0 );
431- header[0 ] = kGPT2Magic ;
432- header[1 ] = kGPT2FP32Version ;
433- header[2 ] = static_cast <int32_t >(config.block_size );
434- header[3 ] = static_cast <int32_t >(config.original_vocab_size );
435- header[4 ] = static_cast <int32_t >(config.n_layer );
436- header[5 ] = static_cast <int32_t >(config.n_head );
437- header[6 ] = static_cast <int32_t >(config.n_embd );
438- header[7 ] = static_cast <int32_t >(config.vocab_size );
439- ofs.write (reinterpret_cast <const char *>(header.data ()),
440- static_cast <std::streamsize>(header.size () * sizeof (int32_t )));
441-
442- const auto state_dict = model->StateDict ();
// Looks a tensor up by its state-dict name; a CHECK fails if the name is absent.
443- auto get_tensor = [&](const std::string &name) -> std::shared_ptr<Tensor> {
444- CHECK (state_dict.contains (name)) << " Missing tensor in GPT2 state_dict: " << name;
445- return state_dict.at (name);
446- };
447-
// Copies the tensor with To(Device()) (presumably the default/CPU device, TODO confirm), converts it
// to FP32 when its dtype differs, then writes SizeInBytes() raw bytes to the stream.
448- auto write_tensor_fp32 = [&](const std::shared_ptr<Tensor> &tensor) {
449- Tensor cpu = tensor->To (Device ());
450- if (cpu.Dtype () != DataType::kFLOAT32 ) {
451- cpu = cpu.To (DataType::kFLOAT32 );
452- }
453- const auto bytes = static_cast <std::streamsize>(cpu.SizeInBytes ());
454- ofs.write (reinterpret_cast <const char *>(cpu.DataPtr ()), bytes);
455- };
456-
457- // transformer.wte.weight
458- write_tensor_fp32 (get_tensor (std::format (" {}.{}.{}" , nn::TransformerModel::kTransformerModelName ,
459- nn::TransformerFirstStage::kWTELayerName ,
460- nn::parallel::VocabParallelEmbedding::kParamWeightName )));
461-
462- // transformer.wpe.weight
463- write_tensor_fp32 (
464- get_tensor (std::format (" {}.{}.{}" , nn::TransformerModel::kTransformerModelName ,
465- nn::TransformerFirstStage::kWPELayerName , nn::Embedding::kParamWeightName )));
466-
// Per-layer tensors are written grouped by parameter kind, not by layer: each loop below emits one
// kind for all n_layer layers. Order: ln1 weight, ln1 bias, attn c_attn weight, attn c_attn bias,
// attn c_proj weight, attn c_proj bias, ln2 weight, ln2 bias, mlp c_fc weight, mlp c_fc bias,
// mlp c_proj weight, mlp c_proj bias.
467- for (int idx = 0 ; idx < config.n_layer ; ++idx) {
468- write_tensor_fp32 (get_tensor (std::format (
469- " {}.{}.{}.{}.{}" , nn::TransformerModel::kTransformerModelName , nn::TransformerChunk::kHLayerName , idx,
470- nn::TransformerLayer::kLn1LayerName , nn::LayerNorm::kParamWeightName )));
471- }
472- for (int idx = 0 ; idx < config.n_layer ; ++idx) {
473- write_tensor_fp32 (get_tensor (std::format (" {}.{}.{}.{}.{}" , nn::TransformerModel::kTransformerModelName ,
474- nn::TransformerChunk::kHLayerName , idx,
475- nn::TransformerLayer::kLn1LayerName , nn::LayerNorm::kParamBiasName )));
476- }
477- for (int idx = 0 ; idx < config.n_layer ; ++idx) {
478- write_tensor_fp32 (get_tensor (std::format (
479- " {}.{}.{}.{}.{}.{}" , nn::TransformerModel::kTransformerModelName , nn::TransformerChunk::kHLayerName , idx,
480- nn::TransformerLayer::kAttnLayerName , nn::CausalSelfAttention::kCAttnLayerName ,
481- nn::parallel::ColumnParallelLinear::kParamWeightName )));
482- }
483- for (int idx = 0 ; idx < config.n_layer ; ++idx) {
484- write_tensor_fp32 (get_tensor (
485- std::format (" {}.{}.{}.{}.{}.{}" , nn::TransformerModel::kTransformerModelName ,
486- nn::TransformerChunk::kHLayerName , idx, nn::TransformerLayer::kAttnLayerName ,
487- nn::CausalSelfAttention::kCAttnLayerName , nn::parallel::ColumnParallelLinear::kParamBiasName )));
488- }
489- for (int idx = 0 ; idx < config.n_layer ; ++idx) {
490- write_tensor_fp32 (get_tensor (
491- std::format (" {}.{}.{}.{}.{}.{}" , nn::TransformerModel::kTransformerModelName ,
492- nn::TransformerChunk::kHLayerName , idx, nn::TransformerLayer::kAttnLayerName ,
493- nn::CausalSelfAttention::kCProjLayerName , nn::parallel::RowParallelLinear::kParamWeightName )));
494- }
495- for (int idx = 0 ; idx < config.n_layer ; ++idx) {
496- write_tensor_fp32 (get_tensor (
497- std::format (" {}.{}.{}.{}.{}.{}" , nn::TransformerModel::kTransformerModelName ,
498- nn::TransformerChunk::kHLayerName , idx, nn::TransformerLayer::kAttnLayerName ,
499- nn::CausalSelfAttention::kCProjLayerName , nn::parallel::RowParallelLinear::kParamBiasName )));
500- }
501- for (int idx = 0 ; idx < config.n_layer ; ++idx) {
502- write_tensor_fp32 (get_tensor (std::format (
503- " {}.{}.{}.{}.{}" , nn::TransformerModel::kTransformerModelName , nn::TransformerChunk::kHLayerName , idx,
504- nn::TransformerLayer::kLn2LayerName , nn::LayerNorm::kParamWeightName )));
505- }
506- for (int idx = 0 ; idx < config.n_layer ; ++idx) {
507- write_tensor_fp32 (get_tensor (std::format (" {}.{}.{}.{}.{}" , nn::TransformerModel::kTransformerModelName ,
508- nn::TransformerChunk::kHLayerName , idx,
509- nn::TransformerLayer::kLn2LayerName , nn::LayerNorm::kParamBiasName )));
510- }
511- for (int idx = 0 ; idx < config.n_layer ; ++idx) {
512- write_tensor_fp32 (
513- get_tensor (std::format (" {}.{}.{}.{}.{}.{}" , nn::TransformerModel::kTransformerModelName ,
514- nn::TransformerChunk::kHLayerName , idx, nn::TransformerLayer::kMlpLayerName ,
515- nn::MLP ::kCFcLayerName , nn::parallel::ColumnParallelLinear::kParamWeightName )));
516- }
517- for (int idx = 0 ; idx < config.n_layer ; ++idx) {
518- write_tensor_fp32 (
519- get_tensor (std::format (" {}.{}.{}.{}.{}.{}" , nn::TransformerModel::kTransformerModelName ,
520- nn::TransformerChunk::kHLayerName , idx, nn::TransformerLayer::kMlpLayerName ,
521- nn::MLP ::kCFcLayerName , nn::parallel::ColumnParallelLinear::kParamBiasName )));
522- }
523- for (int idx = 0 ; idx < config.n_layer ; ++idx) {
524- write_tensor_fp32 (
525- get_tensor (std::format (" {}.{}.{}.{}.{}.{}" , nn::TransformerModel::kTransformerModelName ,
526- nn::TransformerChunk::kHLayerName , idx, nn::TransformerLayer::kMlpLayerName ,
527- nn::MLP ::kCProjLayerName , nn::parallel::RowParallelLinear::kParamWeightName )));
528- }
529- for (int idx = 0 ; idx < config.n_layer ; ++idx) {
530- write_tensor_fp32 (
531- get_tensor (std::format (" {}.{}.{}.{}.{}.{}" , nn::TransformerModel::kTransformerModelName ,
532- nn::TransformerChunk::kHLayerName , idx, nn::TransformerLayer::kMlpLayerName ,
533- nn::MLP ::kCProjLayerName , nn::parallel::RowParallelLinear::kParamBiasName )));
534- }
535-
// Last-stage layer norm (kLnFLayerName): weight first, then bias.
536- write_tensor_fp32 (
537- get_tensor (std::format (" {}.{}.{}" , nn::TransformerModel::kTransformerModelName ,
538- nn::TransformerLastStage::kLnFLayerName , nn::LayerNorm::kParamWeightName )));
539- write_tensor_fp32 (get_tensor (std::format (" {}.{}.{}" , nn::TransformerModel::kTransformerModelName ,
540- nn::TransformerLastStage::kLnFLayerName , nn::LayerNorm::kParamBiasName )));
541-
// The stream state is not inspected after the individual writes; it is checked once here, after flushing.
542- ofs.flush ();
543- CHECK (ofs.good ()) << " Failed to flush model file: " << filepath;
544- }
432+ } // namespace gpt2
0 commit comments