Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 29 additions & 30 deletions iddm/model/trainers/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,43 +195,41 @@ def train_in_iter(self):
self.model.eval()
logger.info(msg="Start val mode.")
val_pbar = tqdm(self.val_dataloader)
for i, (images, _) in enumerate(val_pbar):
# Input images [B, C, H, W]
images = images.to(self.device)
with torch.no_grad():
for i, (images, _) in enumerate(val_pbar):
# Input images [B, C, H, W]
images = images.to(self.device)

with autocast(device_type="cuda", enabled=self.amp):
recon_images = self.model(images)
# To calculate the MSE loss
val_loss = self.loss_func(recon_images, images)

# The optimizer clears the gradient of the model parameters
self.optimizer.zero_grad()
# TensorBoard logging
val_pbar.set_postfix(MSE=val_loss.item())
self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss({self.loss_func})",
scalar_value=val_loss.item(),
global_step=self.epoch * self.len_val_dataloader + i)
val_loss_list.append(val_loss.item())

# TensorBoard logging
val_pbar.set_postfix(MSE=val_loss.item())
self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss({self.loss_func})", scalar_value=val_loss.item(),
global_step=self.epoch * self.len_val_dataloader + i)
val_loss_list.append(val_loss.item())

# TODO: Metric
score = 0
self.tb_logger.add_scalar(tag=f"[{self.device}]: Score({self.loss_func})", scalar_value=score,
global_step=self.epoch * self.len_val_dataloader + i)
score_list.append(score)
# TODO: Metric
score = 0
self.tb_logger.add_scalar(tag=f"[{self.device}]: Score({self.loss_func})", scalar_value=score,
global_step=self.epoch * self.len_val_dataloader + i)
score_list.append(score)

images = post_image(images=images, device=self.device)
if self.loss_name == "mse_kl":
recon_images = recon_images[0]
recon_images = post_image(images=recon_images, device=self.device)
image_name = time.time()
for index, image in enumerate(images):
save_images(images=image,
path=os.path.join(self.save_val_vis_dir,
f"{i}_{image_name}_{index}_origin.{self.image_format}"))
for recon_index, recon_image in enumerate(recon_images):
save_images(images=recon_image,
path=os.path.join(self.save_val_vis_dir,
f"{i}_{image_name}_{recon_index}_recon.{self.image_format}"))
images = post_image(images=images, device=self.device)
if self.loss_name == "mse_kl":
recon_images = recon_images[0]
recon_images = post_image(images=recon_images, device=self.device)
image_name = time.time()
for index, image in enumerate(images):
save_images(images=image,
path=os.path.join(self.save_val_vis_dir,
f"{i}_{image_name}_{index}_origin.{self.image_format}"))
for recon_index, recon_image in enumerate(recon_images):
save_images(images=recon_image,
path=os.path.join(self.save_val_vis_dir,
f"{i}_{image_name}_{recon_index}_recon.{self.image_format}"))
# Loss, score per epoch
self.avg_val_loss = sum(val_loss_list) / len(val_loss_list)
self.avg_score = sum(score_list) / len(score_list)
Expand All @@ -240,6 +238,7 @@ def train_in_iter(self):
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg score", scalar_value=self.avg_score,
global_step=self.epoch)
logger.info(f"Val loss: {self.avg_val_loss}, Score: {self.avg_score}")
self.model.train()
logger.info(msg="Finish val mode.")

def after_iter(self):
Expand Down
80 changes: 40 additions & 40 deletions iddm/model/trainers/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,53 +218,52 @@ def train_in_iter(self):
self.model.eval()
logger.info(msg="Start val mode.")
val_pbar = tqdm(self.val_dataloader)
for i, (lr_images, hr_images) in enumerate(val_pbar):
# The images are all resized in val dataloader
lr_images = lr_images.to(self.device)
hr_images = hr_images.to(self.device)
# Enable Automatic mixed precision training
# Automatic mixed precision training
with torch.no_grad():
with torch.no_grad():
for i, (lr_images, hr_images) in enumerate(val_pbar):
# The images are all resized in val dataloader
lr_images = lr_images.to(self.device)
hr_images = hr_images.to(self.device)
# Enable Automatic mixed precision training
# Automatic mixed precision training
output = self.model(lr_images)
# To calculate the MSE loss
# You need to use the standard normal distribution of x at time t and the predicted noise
val_loss = self.loss_func(output, hr_images)
# The optimizer clears the gradient of the model parameters
self.optimizer.zero_grad()

# TensorBoard logging
val_pbar.set_postfix(MSE=val_loss.item())
self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss({self.loss_func})", scalar_value=val_loss.item(),
global_step=self.epoch * self.len_val_dataloader + i)
val_loss_list.append(val_loss.item())
# TensorBoard logging
val_pbar.set_postfix(MSE=val_loss.item())
self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss({self.loss_func})",
scalar_value=val_loss.item(),
global_step=self.epoch * self.len_val_dataloader + i)
val_loss_list.append(val_loss.item())

# Metric
ssim_res = compute_ssim(image_outputs=output, image_sources=hr_images)
psnr_res = compute_psnr(mse=val_loss.item())
self.tb_logger.add_scalar(tag=f"[{self.device}]: SSIM({self.loss_func})", scalar_value=ssim_res,
global_step=self.epoch * self.len_val_dataloader + i)
self.tb_logger.add_scalar(tag=f"[{self.device}]: PSNR({self.loss_func})", scalar_value=psnr_res,
global_step=self.epoch * self.len_val_dataloader + i)
ssim_list.append(ssim_res)
psnr_list.append(psnr_res)
# Metric
ssim_res = compute_ssim(image_outputs=output, image_sources=hr_images)
psnr_res = compute_psnr(mse=val_loss.item())
self.tb_logger.add_scalar(tag=f"[{self.device}]: SSIM({self.loss_func})", scalar_value=ssim_res,
global_step=self.epoch * self.len_val_dataloader + i)
self.tb_logger.add_scalar(tag=f"[{self.device}]: PSNR({self.loss_func})", scalar_value=psnr_res,
global_step=self.epoch * self.len_val_dataloader + i)
ssim_list.append(ssim_res)
psnr_list.append(psnr_res)

# Save super resolution image and high resolution image
lr_images = post_image(lr_images, device=self.device)
sr_images = post_image(output, device=self.device)
hr_images = post_image(hr_images, device=self.device)
image_name = time.time()
for lr_index, lr_image in enumerate(lr_images):
save_images(images=lr_image,
path=os.path.join(self.save_val_vis_dir,
f"{i}_{image_name}_{lr_index}_lr.{self.image_format}"))
for sr_index, sr_image in enumerate(sr_images):
save_images(images=sr_image,
path=os.path.join(self.save_val_vis_dir,
f"{i}_{image_name}_{sr_index}_sr.{self.image_format}"))
for hr_index, hr_image in enumerate(hr_images):
save_images(images=hr_image,
path=os.path.join(self.save_val_vis_dir,
f"{i}_{image_name}_{hr_index}_hr.{self.image_format}"))
# Save super resolution image and high resolution image
lr_images = post_image(lr_images, device=self.device)
sr_images = post_image(output, device=self.device)
hr_images = post_image(hr_images, device=self.device)
image_name = time.time()
for lr_index, lr_image in enumerate(lr_images):
save_images(images=lr_image,
path=os.path.join(self.save_val_vis_dir,
f"{i}_{image_name}_{lr_index}_lr.{self.image_format}"))
for sr_index, sr_image in enumerate(sr_images):
save_images(images=sr_image,
path=os.path.join(self.save_val_vis_dir,
f"{i}_{image_name}_{sr_index}_sr.{self.image_format}"))
for hr_index, hr_image in enumerate(hr_images):
save_images(images=hr_image,
path=os.path.join(self.save_val_vis_dir,
f"{i}_{image_name}_{hr_index}_hr.{self.image_format}"))
# Loss, ssim and psnr per epoch
self.avg_val_loss = sum(val_loss_list) / len(val_loss_list)
self.avg_ssim = sum(ssim_list) / len(ssim_list)
Expand All @@ -274,6 +273,7 @@ def train_in_iter(self):
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg ssim", scalar_value=self.avg_ssim, global_step=self.epoch)
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg psnr", scalar_value=self.avg_psnr, global_step=self.epoch)
logger.info(f"Val loss: {self.avg_val_loss}, SSIM: {self.avg_ssim}, PSNR: {self.avg_psnr}")
self.model.train()
logger.info(msg="Finish val mode.")

def after_iter(self):
Expand Down
4 changes: 2 additions & 2 deletions webui/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def train(self, seed, conditional, sample, network, run_name, epochs, batch_size
29: cfg_scale
"""
gradio.Info(message="Start training...")
logger = CustomLogger(name=__name__, is_webui=True, is_save_log=True,
log_path=os.path.join(result_path, run_name))
logger = WebUILogger(name=__name__, is_save_log=True, log_path=str(os.path.join(result_path, run_name)),
rank=main_gpu)
self.KILL_FLAG = False
yield logger.info(msg="[Note]: Start parameters setting.")
yield logger.info(
Expand Down