@@ -173,10 +173,11 @@ def ema_reset(self):
173173 self .ema_buffer_modele_params = None
174174
175175 @imperative_base .no_grad ()
176- def ema_accumulate (self ):
176+ def ema_accumulate (self , global_step , loss , zcc_ema_loss_threshold ):
177177 """
178178 perform ema update : ` \a lpha * EMA + (1-\a lpha) + model`
179- build `self.ema_buffer` if necessary
179+ buid `self.ema_buffer` if necessary
180+ when loss < threshold, do ema update
180181 """
181182 # logger.info(f'[ZCC EMA] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}')
182183 # do update: ema = alpha * ema + (1-alpha) * model
@@ -185,14 +186,19 @@ def ema_accumulate(self):
185186 cpu_master_weights = self .optimizer_fusion_storage_helper .cpu_buffer ._slice (
186187 self .master_min_offset , self .master_max_offset
187188 ).cpu ()
188- self .ema_buffer = self .ema_coef * self .ema_buffer + (1 - self .ema_coef ) * cpu_master_weights
189- # logger.info(f'[ZCC EMA2] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}')
190- for index , ema_buf in self .ema_buffer_model_params .items ():
191- _ , cpu_buf = self .param_fusion_storage_helper .inited_buffers [index ]
192- updated_ema = self .ema_coef * ema_buf + (1 - self .ema_coef ) * cpu_buf
193- self .ema_buffer_model_params [index ] = updated_ema
194-
195- logger .info (f"[ZCC EMA] accumulating, buffer type:{ self .ema_buffer .place } { self .ema_buffer .dtype } , done" )
189+ if zcc_ema_loss_threshold is None or loss < zcc_ema_loss_threshold :
190+ self .ema_buffer = self .ema_coef * self .ema_buffer + (1 - self .ema_coef ) * cpu_master_weights
191+ for index , ema_buf in self .ema_buffer_model_params .items ():
192+ _ , cpu_buf = self .param_fusion_storage_helper .inited_buffers [index ]
193+ updated_ema = self .ema_coef * ema_buf + (1 - self .ema_coef ) * cpu_buf
194+ self .ema_buffer_model_params [index ] = updated_ema
195+ logger .info (
196+ f"[ZCC EMA] accmulating, buffer type:{ self .ema_buffer .place } { self .ema_buffer .dtype } , done"
197+ )
198+ else :
199+ logger .info (
200+ f"[ZCC EMA] accmulating SKIP for global_step:{ global_step } , because loss:{ loss } > threshold:{ zcc_ema_loss_threshold } "
201+ )
196202
197203 @imperative_base .no_grad ()
198204 def ema_state_dict (self ):
@@ -790,7 +796,11 @@ def process_offload_task(self, dump, global_step):
790796 self .global_step .value = global_step
791797
792798 if self .ema_coef is not None :
793- self .zcc_ema_processor .ema_accumulate ()
799+ self .zcc_ema_processor .ema_accumulate (
800+ self .trainer_state .global_step ,
801+ self .trainer_state .loss ,
802+ self .training_args_content .zcc_ema_loss_threshold ,
803+ )
794804
795805 # continue to process dumping task at the last chunk
796806 if self .offloaded_numels == self .all_numel :
0 commit comments