@@ -270,7 +270,7 @@ def get_group_ids(self):
270270
271271
272272class ShardingIO :
273- def __init__ (self , args , model , optimizer = None , hcg = None , remap_parameter_name = False ):
273+ def __init__ (self , args , model , optimizer = None , hcg = None , remap_parameter_name = False , is_ema = False ):
274274 self .args = args
275275 self .model = model
276276 self .optimizer = optimizer
@@ -282,6 +282,7 @@ def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=F
282282
283283 self .remap_parameter_name = remap_parameter_name
284284 self .remapper = None
285+ self .is_ema = is_ema
285286
286287 def _get_remapper (self , checkpoint ):
287288 if not self .remap_parameter_name :
@@ -395,24 +396,33 @@ def _load_one_state_dict_from_checkpoint(self, resume_from_checkpoint, base_weig
395396 """
396397 load state_dict of one shard from_checkpoint, Only load model state dict.
397398 """
399+ if self .is_ema :
400+ base_weight_name = base_weight_name .replace ("model_state" , "ema" ).replace ("pdparams" , "pdopt" )
398401 file_path = os .path .join (resume_from_checkpoint , _add_variant (base_weight_name , weight_name_suffix ))
399402 if not os .path .isfile (file_path ):
400403 raise ValueError (f"Can't find a valid checkpoint at { resume_from_checkpoint } , no { file_path } " )
401404
402405 logger .info (f"Loading model from { resume_from_checkpoint } ." )
403406 # We load the model state dict on the CPU to avoid an OOM error.
404407 state_dict = paddle .load (file_path , return_numpy = True )
408+ if self .is_ema :
409+ state_dict .pop ("master_weights" , None )
405410 state_dict = self ._remap_parameter_name (resume_from_checkpoint , state_dict , is_opt = False )
406411 return state_dict
407412
def _load_optimizer_state_of_one_shard(self, checkpoint, base_opt_name, optimizer_name_suffix, group_getter=None):
    """
    load the optimizer state of one shard from checkpoint.

    When `self.is_ema` is set, "optimizer" in `base_opt_name` is replaced by "ema" before
    the file name is built, and only the "master_weights" entry of the loaded file is kept
    (an empty dict is used when that entry is absent).

    The loaded state goes through `_modify_ckpt_for_compatibility` and then
    `_remap_parameter_name(..., is_opt=True)`; that result is returned. If the file does
    not exist, a message is logged and nothing is loaded.

    NOTE(review): `group_getter` is not used in the visible body -- confirm against callers.
    """
    # The EMA state is stored under its own file name; derive it from the optimizer name.
    resolved_base_name = base_opt_name.replace("optimizer", "ema") if self.is_ema else base_opt_name
    shard_file_name = _add_variant(resolved_base_name, optimizer_name_suffix)
    path = os.path.join(checkpoint, shard_file_name)
    logger.info(f"load optimizer state from {path}")
    if os.path.isfile(path):
        loaded_state = paddlenlp_load(path, map_location="cpu")
        if self.is_ema:
            # Only the master weights are kept when loading the EMA state.
            loaded_state = {"master_weights": loaded_state.get("master_weights", {})}
        compatible_state = self._modify_ckpt_for_compatibility(loaded_state)
        return self._remap_parameter_name(checkpoint, compatible_state, is_opt=True)
    logger.info(f"{path} not exists")
0 commit comments