Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions iddm/model/trainers/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from iddm.utils.utils import check_and_create_dir, save_images
from iddm.utils.dataset import get_dataset, post_image
from iddm.utils.logger import get_logger
from iddm.utils.metrics import compute_psnr

logger = get_logger(name=__name__)

Expand Down Expand Up @@ -185,10 +186,9 @@ def train_in_iter(self):
train_loss_list.append(train_loss.item())
# Loss per epoch
self.avg_train_loss = sum(train_loss_list) / len(train_loss_list)
self.tb_logger.add_scalar(tag=f"[{self.device}]: Train loss",
scalar_value=self.avg_train_loss,
global_step=self.epoch)
logger.info(f"Train loss: {self.avg_train_loss}")
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg train loss({self.loss_func})",
scalar_value=self.avg_train_loss, global_step=self.epoch)
logger.info(f"[{self.device}]: Train loss: {self.avg_train_loss}")
logger.info(msg="Finish train mode.")

# Val
Expand All @@ -211,9 +211,9 @@ def train_in_iter(self):
global_step=self.epoch * self.len_val_dataloader + i)
val_loss_list.append(val_loss.item())

# TODO: Metric
score = 0
self.tb_logger.add_scalar(tag=f"[{self.device}]: Score({self.loss_func})", scalar_value=score,
# Metric PSNR
score = compute_psnr(mse=val_loss.item())
self.tb_logger.add_scalar(tag=f"[{self.device}]: Score(PSNR)", scalar_value=score,
global_step=self.epoch * self.len_val_dataloader + i)
score_list.append(score)

Expand All @@ -233,12 +233,11 @@ def train_in_iter(self):
# Loss, score per epoch
self.avg_val_loss = sum(val_loss_list) / len(val_loss_list)
self.avg_score = sum(score_list) / len(score_list)
self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss", scalar_value=self.avg_val_loss,
global_step=self.epoch)
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg score", scalar_value=self.avg_score,
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg val loss({self.loss_func})",
scalar_value=self.avg_val_loss, global_step=self.epoch)
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg score(PSNR)", scalar_value=self.avg_score,
global_step=self.epoch)
logger.info(f"Val loss: {self.avg_val_loss}, Score: {self.avg_score}")
self.model.train()
logger.info(f"[{self.device}]: Val loss: {self.avg_val_loss}, Score(PSNR): {self.avg_score}")
logger.info(msg="Finish val mode.")

def after_iter(self):
Expand Down
8 changes: 6 additions & 2 deletions iddm/model/trainers/dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def train_in_iter(self):
"""
# Initialize images and labels
images, labels, loss_list = None, None, []
# Train mode
for i, (images, labels) in enumerate(self.pbar):
# The images are all resized in dataloader
images = images.to(self.device)
Expand Down Expand Up @@ -266,12 +267,15 @@ def train_in_iter(self):

# TensorBoard logging
self.pbar.set_postfix(MSE=loss.item())
self.tb_logger.add_scalar(tag=f"[{self.device}]: MSE", scalar_value=loss.item(),
self.tb_logger.add_scalar(tag=f"[{self.device}]: Train loss({self.loss_func})", scalar_value=loss.item(),
global_step=self.epoch * self.len_dataloader + i)
loss_list.append(loss.item())
# Loss per epoch
self.tb_logger.add_scalar(tag=f"[{self.device}]: Loss", scalar_value=sum(loss_list) / len(loss_list),
avg_train_loss = sum(loss_list) / len(loss_list)
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg train loss", scalar_value=avg_train_loss,
global_step=self.epoch)
logger.info(msg=f"[{self.device}]: Train loss: {avg_train_loss}.")
logger.info(msg="Finish train mode.")

def after_iter(self):
"""
Expand Down
16 changes: 8 additions & 8 deletions iddm/model/trainers/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,10 @@ def train_in_iter(self):
global_step=self.epoch * self.len_train_dataloader + i)
train_loss_list.append(train_loss.item())
# Loss per epoch
self.tb_logger.add_scalar(tag=f"[{self.device}]: Train loss",
scalar_value=sum(train_loss_list) / len(train_loss_list),
avg_train_loss = sum(train_loss_list) / len(train_loss_list)
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg train loss({self.loss_func})", scalar_value=avg_train_loss,
global_step=self.epoch)
logger.info(msg=f"[{self.device}]: Train loss:{avg_train_loss}")
logger.info(msg="Finish train mode.")

# Val
Expand Down Expand Up @@ -240,9 +241,9 @@ def train_in_iter(self):
# Metric
ssim_res = compute_ssim(image_outputs=output, image_sources=hr_images)
psnr_res = compute_psnr(mse=val_loss.item())
self.tb_logger.add_scalar(tag=f"[{self.device}]: SSIM({self.loss_func})", scalar_value=ssim_res,
self.tb_logger.add_scalar(tag=f"[{self.device}]: SSIM", scalar_value=ssim_res,
global_step=self.epoch * self.len_val_dataloader + i)
self.tb_logger.add_scalar(tag=f"[{self.device}]: PSNR({self.loss_func})", scalar_value=psnr_res,
self.tb_logger.add_scalar(tag=f"[{self.device}]: PSNR", scalar_value=psnr_res,
global_step=self.epoch * self.len_val_dataloader + i)
ssim_list.append(ssim_res)
psnr_list.append(psnr_res)
Expand All @@ -268,12 +269,11 @@ def train_in_iter(self):
self.avg_val_loss = sum(val_loss_list) / len(val_loss_list)
self.avg_ssim = sum(ssim_list) / len(ssim_list)
self.avg_psnr = sum(psnr_list) / len(psnr_list)
self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss", scalar_value=self.avg_val_loss,
global_step=self.epoch)
self.tb_logger.add_scalar(tag=f"[{self.device}]: Val avg loss({self.loss_func})",
scalar_value=self.avg_val_loss, global_step=self.epoch)
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg ssim", scalar_value=self.avg_ssim, global_step=self.epoch)
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg psnr", scalar_value=self.avg_psnr, global_step=self.epoch)
logger.info(f"Val loss: {self.avg_val_loss}, SSIM: {self.avg_ssim}, PSNR: {self.avg_psnr}")
self.model.train()
logger.info(f"[{self.device}]: Val loss: {self.avg_val_loss}, SSIM: {self.avg_ssim}, PSNR: {self.avg_psnr}")
logger.info(msg="Finish val mode.")

def after_iter(self):
Expand Down