From 42f4c746dd38a84264f916a7c69c0b96276b54b4 Mon Sep 17 00:00:00 2001 From: cheny Date: Wed, 31 Dec 2025 14:31:46 +0800 Subject: [PATCH] chore: Optimize the trainer training logic. --- iddm/model/trainers/autoencoder.py | 59 +++++++++++----------- iddm/model/trainers/sr.py | 80 +++++++++++++++--------------- webui/web.py | 4 +- 3 files changed, 71 insertions(+), 72 deletions(-) diff --git a/iddm/model/trainers/autoencoder.py b/iddm/model/trainers/autoencoder.py index 00eb4d5..65b1a79 100644 --- a/iddm/model/trainers/autoencoder.py +++ b/iddm/model/trainers/autoencoder.py @@ -195,43 +195,41 @@ def train_in_iter(self): self.model.eval() logger.info(msg="Start val mode.") val_pbar = tqdm(self.val_dataloader) - for i, (images, _) in enumerate(val_pbar): - # Input images [B, C, H, W] - images = images.to(self.device) + with torch.no_grad(): + for i, (images, _) in enumerate(val_pbar): + # Input images [B, C, H, W] + images = images.to(self.device) - with autocast(device_type="cuda", enabled=self.amp): recon_images = self.model(images) # To calculate the MSE loss val_loss = self.loss_func(recon_images, images) - # The optimizer clears the gradient of the model parameters - self.optimizer.zero_grad() + # TensorBoard logging + val_pbar.set_postfix(MSE=val_loss.item()) + self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss({self.loss_func})", + scalar_value=val_loss.item(), + global_step=self.epoch * self.len_val_dataloader + i) + val_loss_list.append(val_loss.item()) - # TensorBoard logging - val_pbar.set_postfix(MSE=val_loss.item()) - self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss({self.loss_func})", scalar_value=val_loss.item(), - global_step=self.epoch * self.len_val_dataloader + i) - val_loss_list.append(val_loss.item()) - - # TODO: Metric - score = 0 - self.tb_logger.add_scalar(tag=f"[{self.device}]: Score({self.loss_func})", scalar_value=score, - global_step=self.epoch * self.len_val_dataloader + i) - score_list.append(score) + # TODO: Metric + score = 0 + self.tb_logger.add_scalar(tag=f"[{self.device}]: Score({self.loss_func})", scalar_value=score, + global_step=self.epoch * self.len_val_dataloader + i) + score_list.append(score) - images = post_image(images=images, device=self.device) - if self.loss_name == "mse_kl": - recon_images = recon_images[0] - recon_images = post_image(images=recon_images, device=self.device) - image_name = time.time() - for index, image in enumerate(images): - save_images(images=image, - path=os.path.join(self.save_val_vis_dir, - f"{i}_{image_name}_{index}_origin.{self.image_format}")) - for recon_index, recon_image in enumerate(recon_images): - save_images(images=recon_image, - path=os.path.join(self.save_val_vis_dir, - f"{i}_{image_name}_{recon_index}_recon.{self.image_format}")) + images = post_image(images=images, device=self.device) + if self.loss_name == "mse_kl": + recon_images = recon_images[0] + recon_images = post_image(images=recon_images, device=self.device) + image_name = time.time() + for index, image in enumerate(images): + save_images(images=image, + path=os.path.join(self.save_val_vis_dir, + f"{i}_{image_name}_{index}_origin.{self.image_format}")) + for recon_index, recon_image in enumerate(recon_images): + save_images(images=recon_image, + path=os.path.join(self.save_val_vis_dir, + f"{i}_{image_name}_{recon_index}_recon.{self.image_format}")) # Loss, score per epoch self.avg_val_loss = sum(val_loss_list) / len(val_loss_list) self.avg_score = sum(score_list) / len(score_list) @@ -240,6 +238,7 @@ def train_in_iter(self): self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg score", scalar_value=self.avg_score, global_step=self.epoch) logger.info(f"Val loss: {self.avg_val_loss}, Score: {self.avg_score}") + self.model.train() logger.info(msg="Finish val mode.") def after_iter(self): diff --git a/iddm/model/trainers/sr.py b/iddm/model/trainers/sr.py index 229a59c..e8e7519 100644 --- a/iddm/model/trainers/sr.py +++ b/iddm/model/trainers/sr.py @@ -218,53 +218,52 @@ def train_in_iter(self): self.model.eval() logger.info(msg="Start val mode.") val_pbar = tqdm(self.val_dataloader) - for i, (lr_images, hr_images) in enumerate(val_pbar): - # The images are all resized in val dataloader - lr_images = lr_images.to(self.device) - hr_images = hr_images.to(self.device) - # Enable Automatic mixed precision training - # Automatic mixed precision training - with torch.no_grad(): + with torch.no_grad(): + for i, (lr_images, hr_images) in enumerate(val_pbar): + # The images are all resized in val dataloader + lr_images = lr_images.to(self.device) + hr_images = hr_images.to(self.device) + # Enable Automatic mixed precision training + # Automatic mixed precision training output = self.model(lr_images) # To calculate the MSE loss # You need to use the standard normal distribution of x at time t and the predicted noise val_loss = self.loss_func(output, hr_images) - # The optimizer clears the gradient of the model parameters - self.optimizer.zero_grad() - # TensorBoard logging - val_pbar.set_postfix(MSE=val_loss.item()) - self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss({self.loss_func})", scalar_value=val_loss.item(), - global_step=self.epoch * self.len_val_dataloader + i) - val_loss_list.append(val_loss.item()) + # TensorBoard logging + val_pbar.set_postfix(MSE=val_loss.item()) + self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss({self.loss_func})", + scalar_value=val_loss.item(), + global_step=self.epoch * self.len_val_dataloader + i) + val_loss_list.append(val_loss.item()) - # Metric - ssim_res = compute_ssim(image_outputs=output, image_sources=hr_images) - psnr_res = compute_psnr(mse=val_loss.item()) - self.tb_logger.add_scalar(tag=f"[{self.device}]: SSIM({self.loss_func})", scalar_value=ssim_res, - global_step=self.epoch * self.len_val_dataloader + i) - self.tb_logger.add_scalar(tag=f"[{self.device}]: PSNR({self.loss_func})", scalar_value=psnr_res, - global_step=self.epoch * self.len_val_dataloader + i) - ssim_list.append(ssim_res) - psnr_list.append(psnr_res) + # Metric + ssim_res = compute_ssim(image_outputs=output, image_sources=hr_images) + psnr_res = compute_psnr(mse=val_loss.item()) + self.tb_logger.add_scalar(tag=f"[{self.device}]: SSIM({self.loss_func})", scalar_value=ssim_res, + global_step=self.epoch * self.len_val_dataloader + i) + self.tb_logger.add_scalar(tag=f"[{self.device}]: PSNR({self.loss_func})", scalar_value=psnr_res, + global_step=self.epoch * self.len_val_dataloader + i) + ssim_list.append(ssim_res) + psnr_list.append(psnr_res) - # Save super resolution image and high resolution image - lr_images = post_image(lr_images, device=self.device) - sr_images = post_image(output, device=self.device) - hr_images = post_image(hr_images, device=self.device) - image_name = time.time() - for lr_index, lr_image in enumerate(lr_images): - save_images(images=lr_image, - path=os.path.join(self.save_val_vis_dir, - f"{i}_{image_name}_{lr_index}_lr.{self.image_format}")) - for sr_index, sr_image in enumerate(sr_images): - save_images(images=sr_image, - path=os.path.join(self.save_val_vis_dir, - f"{i}_{image_name}_{sr_index}_sr.{self.image_format}")) - for hr_index, hr_image in enumerate(hr_images): - save_images(images=hr_image, - path=os.path.join(self.save_val_vis_dir, - f"{i}_{image_name}_{hr_index}_hr.{self.image_format}")) + # Save super resolution image and high resolution image + lr_images = post_image(lr_images, device=self.device) + sr_images = post_image(output, device=self.device) + hr_images = post_image(hr_images, device=self.device) + image_name = time.time() + for lr_index, lr_image in enumerate(lr_images): + save_images(images=lr_image, + path=os.path.join(self.save_val_vis_dir, + f"{i}_{image_name}_{lr_index}_lr.{self.image_format}")) + for sr_index, sr_image in enumerate(sr_images): + save_images(images=sr_image, + path=os.path.join(self.save_val_vis_dir, + f"{i}_{image_name}_{sr_index}_sr.{self.image_format}")) + for hr_index, hr_image in enumerate(hr_images): + save_images(images=hr_image, + path=os.path.join(self.save_val_vis_dir, + f"{i}_{image_name}_{hr_index}_hr.{self.image_format}")) # Loss, ssim and psnr per epoch self.avg_val_loss = sum(val_loss_list) / len(val_loss_list) self.avg_ssim = sum(ssim_list) / len(ssim_list) @@ -274,6 +273,7 @@ def train_in_iter(self): self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg ssim", scalar_value=self.avg_ssim, global_step=self.epoch) self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg psnr", scalar_value=self.avg_psnr, global_step=self.epoch) logger.info(f"Val loss: {self.avg_val_loss}, SSIM: {self.avg_ssim}, PSNR: {self.avg_psnr}") + self.model.train() logger.info(msg="Finish val mode.") def after_iter(self): diff --git a/webui/web.py b/webui/web.py index 41b213e..2c518cb 100644 --- a/webui/web.py +++ b/webui/web.py @@ -90,8 +90,8 @@ def train(self, seed, conditional, sample, network, run_name, epochs, batch_size 29: cfg_scale """ gradio.Info(message="Start training...") - logger = CustomLogger(name=__name__, is_webui=True, is_save_log=True, - log_path=os.path.join(result_path, run_name)) + logger = WebUILogger(name=__name__, is_save_log=True, log_path=str(os.path.join(result_path, run_name)), + rank=main_gpu) self.KILL_FLAG = False yield logger.info(msg="[Note]: Start parameters setting.") yield logger.info(