@@ -84,24 +84,20 @@ template <typename T> T ExtractNumberField(const std::string &content, const std
8484}
8585} // namespace
8686
87- void Checkpoint::Save (const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer &optimizer,
88- const TrainerState &state, const CheckpointOptions &options) {
89- CHECK (options.format == " bin" || options.format == " ckpt" ) << " Unsupported checkpoint format: " << options.format ;
87+ void Checkpoint::Save (const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer *optimizer,
88+ const TrainerState &state, bool no_save_optim) {
9089 std::filesystem::create_directories (checkpoint_dir);
91- LOG (ERROR) << " [CKPT] Save begin: dir=" << checkpoint_dir << " , format=" << options.format
92- << " , global_step=" << state.global_step ;
90+ LOG (ERROR) << " [CKPT] Save begin: dir=" << checkpoint_dir << " , global_step=" << state.global_step ;
9391
94- const auto model_path = checkpoint_dir / (options.format == " ckpt" ? " model.ckpt" : " model.bin" );
95- if (options.format == " bin" && options.model_bin_writer ) {
96- options.model_bin_writer (model, model_path);
97- } else {
98- SaveStateDictBinary (model_path, model.StateDict ());
99- }
92+ const auto model_path = checkpoint_dir / (" model.ckpt" );
10093
101- if (options.no_save_optim ) {
102- auto opt_state = optimizer.StateDict ();
94+ SaveStateDictBinary (model_path, model.StateDict ());
95+
96+ if (!no_save_optim) {
97+ CHECK (optimizer != nullptr ) << " Optimizer pointer is null, cannot save optimizer state." ;
98+ auto opt_state = optimizer->StateDict ();
10399 if (!opt_state.empty ()) {
104- const auto opt_path = checkpoint_dir / (options. format == " ckpt " ? " optimizer.ckpt" : " optimizer.bin " ) ;
100+ const auto opt_path = checkpoint_dir / " optimizer.ckpt" ;
105101 SaveStateDictBinary (opt_path, opt_state);
106102 }
107103 }
@@ -110,48 +106,32 @@ void Checkpoint::Save(const std::filesystem::path &checkpoint_dir, const nn::Mod
110106 LOG (ERROR) << " [CKPT] Save done: dir=" << checkpoint_dir;
111107}
112108
113- void Checkpoint::Load (const std::filesystem::path &checkpoint_dir, nn::Module *model, Optimizer *optimizer,
114- TrainerState *state, const CheckpointLoadOptions &options) {
115- CHECK (model != nullptr );
116- CHECK (state != nullptr );
117-
118- const std::string format = InferFormat (checkpoint_dir);
119- const auto model_path = checkpoint_dir / (format == " ckpt" ? " model.ckpt" : " model.bin" );
120- LOG (ERROR) << " [CKPT] Load begin: dir=" << checkpoint_dir << " , format=" << format;
109+ void Checkpoint::Load (const std::filesystem::path &checkpoint_dir, nn::Module &model, Optimizer *optimizer,
110+ TrainerState &state, bool load_optimizer_state) {
111+ const auto model_path = checkpoint_dir / " model.ckpt" ;
121112 LOG (ERROR) << " [CKPT] Loading model: " << model_path;
122- if (format == " bin" && options.model_bin_loader ) {
123- const uint32_t magic = PeekMagic (model_path);
124- if (magic == kCkptMagic ) {
125- LOG (ERROR) << " [CKPT] Model format detected: native checkpoint binary." ;
126- model->LoadStateDict (LoadStateDictBinary (model_path));
127- } else {
128- LOG (ERROR) << " [CKPT] Model format detected: external model.bin (magic=" << magic
129- << " ), use model_bin_loader callback." ;
130- options.model_bin_loader (model, model_path);
131- }
132- } else {
133- model->LoadStateDict (LoadStateDictBinary (model_path));
134- }
135113
136- if (optimizer != nullptr && options.load_optimizer_state ) {
137- const auto opt_path = checkpoint_dir / (format == " ckpt" ? " optimizer.ckpt" : " optimizer.bin" );
114+ model.LoadStateDict (LoadStateDictBinary (model_path));
115+
116+ if (optimizer == nullptr ) {
117+ LOG (ERROR) << " [CKPT] No optimizer instance, skip optimizer state loading." ;
118+ } else if (load_optimizer_state) {
119+ const auto opt_path = checkpoint_dir / " optimizer.ckpt" ;
138120 if (std::filesystem::exists (opt_path)) {
139121 LOG (ERROR) << " [CKPT] Loading optimizer: " << opt_path;
140122 optimizer->LoadStateDict (LoadStateDictBinary (opt_path));
141123 } else {
142124 LOG (ERROR) << " [CKPT] Optimizer state not found, skip: " << opt_path;
143125 }
144- } else if (optimizer == nullptr ) {
145- LOG (ERROR) << " [CKPT] No optimizer instance, skip optimizer state loading." ;
146126 } else {
147127 LOG (ERROR) << " [CKPT] load_optimizer_state=false, skip optimizer state loading." ;
148128 }
149129
150- * state = LoadTrainerState (checkpoint_dir / " trainer_state.json" );
151- LOG (ERROR) << " [CKPT] Load done: global_step=" << state-> global_step << " , data_batch_idx=" << state-> data_batch_idx
152- << " , data_batch_stride=" << state-> data_batch_stride << " , last_lr=" << state-> last_lr
153- << " , optimizer_type=" << state-> optimizer_type << " , topology(ddp,tp,sp,pp)=(" << state-> ddp_size << " ,"
154- << state-> tp_size << " ," << state-> sp_size << " ," << state-> pp_size << " )" ;
130+ state = LoadTrainerState (checkpoint_dir / " trainer_state.json" );
131+ LOG (ERROR) << " [CKPT] Load done: global_step=" << state. global_step << " , data_batch_idx=" << state. data_batch_idx
132+ << " , data_batch_stride=" << state. data_batch_stride << " , last_lr=" << state. last_lr
133+ << " , optimizer_type=" << state. optimizer_type << " , topology(ddp,tp,sp,pp)=(" << state. ddp_size << " ,"
134+ << state. tp_size << " ," << state. sp_size << " ," << state. pp_size << " )" ;
155135}
156136
157137void Checkpoint::SaveStateDictBinary (const std::filesystem::path &path,
@@ -233,7 +213,6 @@ void Checkpoint::SaveTrainerState(const std::filesystem::path &path, const Train
233213 ofs << " \" data_batch_stride\" : " << state.data_batch_stride << " ,\n " ;
234214 ofs << " \" last_lr\" : " << state.last_lr << " ,\n " ;
235215 ofs << " \" optimizer_type\" : \" " << state.optimizer_type << " \" ,\n " ;
236- ofs << " \" checkpoint_file_format\" : \" " << state.checkpoint_file_format << " \" ,\n " ;
237216 ofs << " \" ddp_size\" : " << state.ddp_size << " ,\n " ;
238217 ofs << " \" tp_size\" : " << state.tp_size << " ,\n " ;
239218 ofs << " \" sp_size\" : " << state.sp_size << " ,\n " ;
@@ -252,23 +231,10 @@ TrainerState Checkpoint::LoadTrainerState(const std::filesystem::path &path) {
252231 state.data_batch_stride = ExtractNumberField<int64_t >(content, " data_batch_stride" , 1 );
253232 state.last_lr = ExtractNumberField<double >(content, " last_lr" , 0.0 );
254233 state.optimizer_type = ExtractStringField (content, " optimizer_type" , " unknown" );
255- state.checkpoint_file_format = ExtractStringField (content, " checkpoint_file_format" , " bin" );
256234 state.ddp_size = ExtractNumberField<int >(content, " ddp_size" , 1 );
257235 state.tp_size = ExtractNumberField<int >(content, " tp_size" , 1 );
258236 state.sp_size = ExtractNumberField<int >(content, " sp_size" , 1 );
259237 state.pp_size = ExtractNumberField<int >(content, " pp_size" , 1 );
260238 return state;
261239}
262-
263- std::string Checkpoint::InferFormat (const std::filesystem::path &checkpoint_dir) {
264- if (std::filesystem::exists (checkpoint_dir / " model.ckpt" )) {
265- return " ckpt" ;
266- }
267- if (std::filesystem::exists (checkpoint_dir / " model.bin" )) {
268- return " bin" ;
269- }
270- LOG (FATAL) << " Failed to infer checkpoint format from path: " << checkpoint_dir;
271- return " bin" ;
272- }
273-
274240} // namespace infini_train
0 commit comments