@@ -218,53 +218,52 @@ def train_in_iter(self):
218218 self .model .eval ()
219219 logger .info (msg = "Start val mode." )
220220 val_pbar = tqdm (self .val_dataloader )
221- for i , ( lr_images , hr_images ) in enumerate ( val_pbar ):
222- # The images are all resized in val dataloader
223- lr_images = lr_images . to ( self . device )
224- hr_images = hr_images .to (self .device )
225- # Enable Automatic mixed precision training
226- # Automatic mixed precision training
227- with torch . no_grad ():
221+ with torch . no_grad ( ):
222+ for i , ( lr_images , hr_images ) in enumerate ( val_pbar ):
223+ # The images are all resized in val dataloader
224+ lr_images = lr_images .to (self .device )
225+ hr_images = hr_images . to ( self . device )
226+ # Enable Automatic mixed precision training
227+ # Automatic mixed precision training
228228 output = self .model (lr_images )
229229 # To calculate the MSE loss
230230 # You need to use the standard normal distribution of x at time t and the predicted noise
231231 val_loss = self .loss_func (output , hr_images )
232- # The optimizer clears the gradient of the model parameters
233- self .optimizer .zero_grad ()
234232
235- # TensorBoard logging
236- val_pbar .set_postfix (MSE = val_loss .item ())
237- self .tb_logger .add_scalar (tag = f"[{ self .device } ]: Val loss({ self .loss_func } )" , scalar_value = val_loss .item (),
238- global_step = self .epoch * self .len_val_dataloader + i )
239- val_loss_list .append (val_loss .item ())
233+ # TensorBoard logging
234+ val_pbar .set_postfix (MSE = val_loss .item ())
235+ self .tb_logger .add_scalar (tag = f"[{ self .device } ]: Val loss({ self .loss_func } )" ,
236+ scalar_value = val_loss .item (),
237+ global_step = self .epoch * self .len_val_dataloader + i )
238+ val_loss_list .append (val_loss .item ())
240239
241- # Metric
242- ssim_res = compute_ssim (image_outputs = output , image_sources = hr_images )
243- psnr_res = compute_psnr (mse = val_loss .item ())
244- self .tb_logger .add_scalar (tag = f"[{ self .device } ]: SSIM({ self .loss_func } )" , scalar_value = ssim_res ,
245- global_step = self .epoch * self .len_val_dataloader + i )
246- self .tb_logger .add_scalar (tag = f"[{ self .device } ]: PSNR({ self .loss_func } )" , scalar_value = psnr_res ,
247- global_step = self .epoch * self .len_val_dataloader + i )
248- ssim_list .append (ssim_res )
249- psnr_list .append (psnr_res )
240+ # Metric
241+ ssim_res = compute_ssim (image_outputs = output , image_sources = hr_images )
242+ psnr_res = compute_psnr (mse = val_loss .item ())
243+ self .tb_logger .add_scalar (tag = f"[{ self .device } ]: SSIM({ self .loss_func } )" , scalar_value = ssim_res ,
244+ global_step = self .epoch * self .len_val_dataloader + i )
245+ self .tb_logger .add_scalar (tag = f"[{ self .device } ]: PSNR({ self .loss_func } )" , scalar_value = psnr_res ,
246+ global_step = self .epoch * self .len_val_dataloader + i )
247+ ssim_list .append (ssim_res )
248+ psnr_list .append (psnr_res )
250249
251- # Save super resolution image and high resolution image
252- lr_images = post_image (lr_images , device = self .device )
253- sr_images = post_image (output , device = self .device )
254- hr_images = post_image (hr_images , device = self .device )
255- image_name = time .time ()
256- for lr_index , lr_image in enumerate (lr_images ):
257- save_images (images = lr_image ,
258- path = os .path .join (self .save_val_vis_dir ,
259- f"{ i } _{ image_name } _{ lr_index } _lr.{ self .image_format } " ))
260- for sr_index , sr_image in enumerate (sr_images ):
261- save_images (images = sr_image ,
262- path = os .path .join (self .save_val_vis_dir ,
263- f"{ i } _{ image_name } _{ sr_index } _sr.{ self .image_format } " ))
264- for hr_index , hr_image in enumerate (hr_images ):
265- save_images (images = hr_image ,
266- path = os .path .join (self .save_val_vis_dir ,
267- f"{ i } _{ image_name } _{ hr_index } _hr.{ self .image_format } " ))
250+ # Save super resolution image and high resolution image
251+ lr_images = post_image (lr_images , device = self .device )
252+ sr_images = post_image (output , device = self .device )
253+ hr_images = post_image (hr_images , device = self .device )
254+ image_name = time .time ()
255+ for lr_index , lr_image in enumerate (lr_images ):
256+ save_images (images = lr_image ,
257+ path = os .path .join (self .save_val_vis_dir ,
258+ f"{ i } _{ image_name } _{ lr_index } _lr.{ self .image_format } " ))
259+ for sr_index , sr_image in enumerate (sr_images ):
260+ save_images (images = sr_image ,
261+ path = os .path .join (self .save_val_vis_dir ,
262+ f"{ i } _{ image_name } _{ sr_index } _sr.{ self .image_format } " ))
263+ for hr_index , hr_image in enumerate (hr_images ):
264+ save_images (images = hr_image ,
265+ path = os .path .join (self .save_val_vis_dir ,
266+ f"{ i } _{ image_name } _{ hr_index } _hr.{ self .image_format } " ))
268267 # Loss, ssim and psnr per epoch
269268 self .avg_val_loss = sum (val_loss_list ) / len (val_loss_list )
270269 self .avg_ssim = sum (ssim_list ) / len (ssim_list )
@@ -274,6 +273,7 @@ def train_in_iter(self):
274273 self .tb_logger .add_scalar (tag = f"[{ self .device } ]: Avg ssim" , scalar_value = self .avg_ssim , global_step = self .epoch )
275274 self .tb_logger .add_scalar (tag = f"[{ self .device } ]: Avg psnr" , scalar_value = self .avg_psnr , global_step = self .epoch )
276275 logger .info (f"Val loss: { self .avg_val_loss } , SSIM: { self .avg_ssim } , PSNR: { self .avg_psnr } " )
276+ self .model .train ()
277277 logger .info (msg = "Finish val mode." )
278278
279279 def after_iter (self ):
0 commit comments