Skip to content

Commit dcb6dc5

Browse files
committed
chore: Optimize the trainer output information and PSNR.
1 parent 42f4c74 commit dcb6dc5

3 files changed

Lines changed: 25 additions & 22 deletions

File tree

iddm/model/trainers/autoencoder.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from iddm.utils.utils import check_and_create_dir, save_images
4040
from iddm.utils.dataset import get_dataset, post_image
4141
from iddm.utils.logger import get_logger
42+
from iddm.utils.metrics import compute_psnr
4243

4344
logger = get_logger(name=__name__)
4445

@@ -185,10 +186,9 @@ def train_in_iter(self):
185186
train_loss_list.append(train_loss.item())
186187
# Loss per epoch
187188
self.avg_train_loss = sum(train_loss_list) / len(train_loss_list)
188-
self.tb_logger.add_scalar(tag=f"[{self.device}]: Train loss",
189-
scalar_value=self.avg_train_loss,
190-
global_step=self.epoch)
191-
logger.info(f"Train loss: {self.avg_train_loss}")
189+
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg train loss({self.loss_func})",
190+
scalar_value=self.avg_train_loss, global_step=self.epoch)
191+
logger.info(f"[{self.device}]: Train loss: {self.avg_train_loss}")
192192
logger.info(msg="Finish train mode.")
193193

194194
# Val
@@ -211,9 +211,9 @@ def train_in_iter(self):
211211
global_step=self.epoch * self.len_val_dataloader + i)
212212
val_loss_list.append(val_loss.item())
213213

214-
# TODO: Metric
215-
score = 0
216-
self.tb_logger.add_scalar(tag=f"[{self.device}]: Score({self.loss_func})", scalar_value=score,
214+
# Metric PSNR
215+
score = compute_psnr(mse=val_loss.item())
216+
self.tb_logger.add_scalar(tag=f"[{self.device}]: Score(PSNR)", scalar_value=score,
217217
global_step=self.epoch * self.len_val_dataloader + i)
218218
score_list.append(score)
219219

@@ -233,12 +233,11 @@ def train_in_iter(self):
233233
# Loss, score per epoch
234234
self.avg_val_loss = sum(val_loss_list) / len(val_loss_list)
235235
self.avg_score = sum(score_list) / len(score_list)
236-
self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss", scalar_value=self.avg_val_loss,
237-
global_step=self.epoch)
238-
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg score", scalar_value=self.avg_score,
236+
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg val loss({self.loss_func})",
237+
scalar_value=self.avg_val_loss, global_step=self.epoch)
238+
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg score(PSNR)", scalar_value=self.avg_score,
239239
global_step=self.epoch)
240-
logger.info(f"Val loss: {self.avg_val_loss}, Score: {self.avg_score}")
241-
self.model.train()
240+
logger.info(f"[{self.device}]: Val loss: {self.avg_val_loss}, Score(PSNR): {self.avg_score}")
242241
logger.info(msg="Finish val mode.")
243242

244243
def after_iter(self):

iddm/model/trainers/dm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ def train_in_iter(self):
220220
"""
221221
# Initialize images and labels
222222
images, labels, loss_list = None, None, []
223+
# Train mode
223224
for i, (images, labels) in enumerate(self.pbar):
224225
# The images are all resized in dataloader
225226
images = images.to(self.device)
@@ -266,12 +267,15 @@ def train_in_iter(self):
266267

267268
# TensorBoard logging
268269
self.pbar.set_postfix(MSE=loss.item())
269-
self.tb_logger.add_scalar(tag=f"[{self.device}]: MSE", scalar_value=loss.item(),
270+
self.tb_logger.add_scalar(tag=f"[{self.device}]: Train loss({self.loss_func})", scalar_value=loss.item(),
270271
global_step=self.epoch * self.len_dataloader + i)
271272
loss_list.append(loss.item())
272273
# Loss per epoch
273-
self.tb_logger.add_scalar(tag=f"[{self.device}]: Loss", scalar_value=sum(loss_list) / len(loss_list),
274+
avg_train_loss = sum(loss_list) / len(loss_list)
275+
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg train loss", scalar_value=avg_train_loss,
274276
global_step=self.epoch)
277+
logger.info(msg=f"[{self.device}]: Train loss: {avg_train_loss}.")
278+
logger.info(msg="Finish train mode.")
275279

276280
def after_iter(self):
277281
"""

iddm/model/trainers/sr.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,10 @@ def train_in_iter(self):
209209
global_step=self.epoch * self.len_train_dataloader + i)
210210
train_loss_list.append(train_loss.item())
211211
# Loss per epoch
212-
self.tb_logger.add_scalar(tag=f"[{self.device}]: Train loss",
213-
scalar_value=sum(train_loss_list) / len(train_loss_list),
212+
avg_train_loss = sum(train_loss_list) / len(train_loss_list)
213+
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg train loss({self.loss_func})", scalar_value=avg_train_loss,
214214
global_step=self.epoch)
215+
logger.info(msg=f"[{self.device}]: Train loss:{avg_train_loss}")
215216
logger.info(msg="Finish train mode.")
216217

217218
# Val
@@ -240,9 +241,9 @@ def train_in_iter(self):
240241
# Metric
241242
ssim_res = compute_ssim(image_outputs=output, image_sources=hr_images)
242243
psnr_res = compute_psnr(mse=val_loss.item())
243-
self.tb_logger.add_scalar(tag=f"[{self.device}]: SSIM({self.loss_func})", scalar_value=ssim_res,
244+
self.tb_logger.add_scalar(tag=f"[{self.device}]: SSIM", scalar_value=ssim_res,
244245
global_step=self.epoch * self.len_val_dataloader + i)
245-
self.tb_logger.add_scalar(tag=f"[{self.device}]: PSNR({self.loss_func})", scalar_value=psnr_res,
246+
self.tb_logger.add_scalar(tag=f"[{self.device}]: PSNR", scalar_value=psnr_res,
246247
global_step=self.epoch * self.len_val_dataloader + i)
247248
ssim_list.append(ssim_res)
248249
psnr_list.append(psnr_res)
@@ -268,12 +269,11 @@ def train_in_iter(self):
268269
self.avg_val_loss = sum(val_loss_list) / len(val_loss_list)
269270
self.avg_ssim = sum(ssim_list) / len(ssim_list)
270271
self.avg_psnr = sum(psnr_list) / len(psnr_list)
271-
self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss", scalar_value=self.avg_val_loss,
272-
global_step=self.epoch)
272+
self.tb_logger.add_scalar(tag=f"[{self.device}]: Val avg loss({self.loss_func})",
273+
scalar_value=self.avg_val_loss, global_step=self.epoch)
273274
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg ssim", scalar_value=self.avg_ssim, global_step=self.epoch)
274275
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg psnr", scalar_value=self.avg_psnr, global_step=self.epoch)
275-
logger.info(f"Val loss: {self.avg_val_loss}, SSIM: {self.avg_ssim}, PSNR: {self.avg_psnr}")
276-
self.model.train()
276+
logger.info(f"[{self.device}]: Val loss: {self.avg_val_loss}, SSIM: {self.avg_ssim}, PSNR: {self.avg_psnr}")
277277
logger.info(msg="Finish val mode.")
278278

279279
def after_iter(self):

0 commit comments

Comments
 (0)