Skip to content

Commit 998b5b4

Browse files
authored
Merge pull request #186 from chairc/dev
chore: Optimize the trainer training logic.
2 parents 4576f6a + 42f4c74 commit 998b5b4

3 files changed

Lines changed: 71 additions & 72 deletions

File tree

iddm/model/trainers/autoencoder.py

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -195,43 +195,41 @@ def train_in_iter(self):
195195
self.model.eval()
196196
logger.info(msg="Start val mode.")
197197
val_pbar = tqdm(self.val_dataloader)
198-
for i, (images, _) in enumerate(val_pbar):
199-
# Input images [B, C, H, W]
200-
images = images.to(self.device)
198+
with torch.no_grad():
199+
for i, (images, _) in enumerate(val_pbar):
200+
# Input images [B, C, H, W]
201+
images = images.to(self.device)
201202

202-
with autocast(device_type="cuda", enabled=self.amp):
203203
recon_images = self.model(images)
204204
# To calculate the MSE loss
205205
val_loss = self.loss_func(recon_images, images)
206206

207-
# The optimizer clears the gradient of the model parameters
208-
self.optimizer.zero_grad()
207+
# TensorBoard logging
208+
val_pbar.set_postfix(MSE=val_loss.item())
209+
self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss({self.loss_func})",
210+
scalar_value=val_loss.item(),
211+
global_step=self.epoch * self.len_val_dataloader + i)
212+
val_loss_list.append(val_loss.item())
209213

210-
# TensorBoard logging
211-
val_pbar.set_postfix(MSE=val_loss.item())
212-
self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss({self.loss_func})", scalar_value=val_loss.item(),
213-
global_step=self.epoch * self.len_val_dataloader + i)
214-
val_loss_list.append(val_loss.item())
215-
216-
# TODO: Metric
217-
score = 0
218-
self.tb_logger.add_scalar(tag=f"[{self.device}]: Score({self.loss_func})", scalar_value=score,
219-
global_step=self.epoch * self.len_val_dataloader + i)
220-
score_list.append(score)
214+
# TODO: Metric
215+
score = 0
216+
self.tb_logger.add_scalar(tag=f"[{self.device}]: Score({self.loss_func})", scalar_value=score,
217+
global_step=self.epoch * self.len_val_dataloader + i)
218+
score_list.append(score)
221219

222-
images = post_image(images=images, device=self.device)
223-
if self.loss_name == "mse_kl":
224-
recon_images = recon_images[0]
225-
recon_images = post_image(images=recon_images, device=self.device)
226-
image_name = time.time()
227-
for index, image in enumerate(images):
228-
save_images(images=image,
229-
path=os.path.join(self.save_val_vis_dir,
230-
f"{i}_{image_name}_{index}_origin.{self.image_format}"))
231-
for recon_index, recon_image in enumerate(recon_images):
232-
save_images(images=recon_image,
233-
path=os.path.join(self.save_val_vis_dir,
234-
f"{i}_{image_name}_{recon_index}_recon.{self.image_format}"))
220+
images = post_image(images=images, device=self.device)
221+
if self.loss_name == "mse_kl":
222+
recon_images = recon_images[0]
223+
recon_images = post_image(images=recon_images, device=self.device)
224+
image_name = time.time()
225+
for index, image in enumerate(images):
226+
save_images(images=image,
227+
path=os.path.join(self.save_val_vis_dir,
228+
f"{i}_{image_name}_{index}_origin.{self.image_format}"))
229+
for recon_index, recon_image in enumerate(recon_images):
230+
save_images(images=recon_image,
231+
path=os.path.join(self.save_val_vis_dir,
232+
f"{i}_{image_name}_{recon_index}_recon.{self.image_format}"))
235233
# Loss, score per epoch
236234
self.avg_val_loss = sum(val_loss_list) / len(val_loss_list)
237235
self.avg_score = sum(score_list) / len(score_list)
@@ -240,6 +238,7 @@ def train_in_iter(self):
240238
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg score", scalar_value=self.avg_score,
241239
global_step=self.epoch)
242240
logger.info(f"Val loss: {self.avg_val_loss}, Score: {self.avg_score}")
241+
self.model.train()
243242
logger.info(msg="Finish val mode.")
244243

245244
def after_iter(self):

iddm/model/trainers/sr.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -218,53 +218,52 @@ def train_in_iter(self):
218218
self.model.eval()
219219
logger.info(msg="Start val mode.")
220220
val_pbar = tqdm(self.val_dataloader)
221-
for i, (lr_images, hr_images) in enumerate(val_pbar):
222-
# The images are all resized in val dataloader
223-
lr_images = lr_images.to(self.device)
224-
hr_images = hr_images.to(self.device)
225-
# Enable Automatic mixed precision training
226-
# Automatic mixed precision training
227-
with torch.no_grad():
221+
with torch.no_grad():
222+
for i, (lr_images, hr_images) in enumerate(val_pbar):
223+
# The images are all resized in val dataloader
224+
lr_images = lr_images.to(self.device)
225+
hr_images = hr_images.to(self.device)
226+
# Enable Automatic mixed precision training
227+
# Automatic mixed precision training
228228
output = self.model(lr_images)
229229
# To calculate the MSE loss
230230
# You need to use the standard normal distribution of x at time t and the predicted noise
231231
val_loss = self.loss_func(output, hr_images)
232-
# The optimizer clears the gradient of the model parameters
233-
self.optimizer.zero_grad()
234232

235-
# TensorBoard logging
236-
val_pbar.set_postfix(MSE=val_loss.item())
237-
self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss({self.loss_func})", scalar_value=val_loss.item(),
238-
global_step=self.epoch * self.len_val_dataloader + i)
239-
val_loss_list.append(val_loss.item())
233+
# TensorBoard logging
234+
val_pbar.set_postfix(MSE=val_loss.item())
235+
self.tb_logger.add_scalar(tag=f"[{self.device}]: Val loss({self.loss_func})",
236+
scalar_value=val_loss.item(),
237+
global_step=self.epoch * self.len_val_dataloader + i)
238+
val_loss_list.append(val_loss.item())
240239

241-
# Metric
242-
ssim_res = compute_ssim(image_outputs=output, image_sources=hr_images)
243-
psnr_res = compute_psnr(mse=val_loss.item())
244-
self.tb_logger.add_scalar(tag=f"[{self.device}]: SSIM({self.loss_func})", scalar_value=ssim_res,
245-
global_step=self.epoch * self.len_val_dataloader + i)
246-
self.tb_logger.add_scalar(tag=f"[{self.device}]: PSNR({self.loss_func})", scalar_value=psnr_res,
247-
global_step=self.epoch * self.len_val_dataloader + i)
248-
ssim_list.append(ssim_res)
249-
psnr_list.append(psnr_res)
240+
# Metric
241+
ssim_res = compute_ssim(image_outputs=output, image_sources=hr_images)
242+
psnr_res = compute_psnr(mse=val_loss.item())
243+
self.tb_logger.add_scalar(tag=f"[{self.device}]: SSIM({self.loss_func})", scalar_value=ssim_res,
244+
global_step=self.epoch * self.len_val_dataloader + i)
245+
self.tb_logger.add_scalar(tag=f"[{self.device}]: PSNR({self.loss_func})", scalar_value=psnr_res,
246+
global_step=self.epoch * self.len_val_dataloader + i)
247+
ssim_list.append(ssim_res)
248+
psnr_list.append(psnr_res)
250249

251-
# Save super resolution image and high resolution image
252-
lr_images = post_image(lr_images, device=self.device)
253-
sr_images = post_image(output, device=self.device)
254-
hr_images = post_image(hr_images, device=self.device)
255-
image_name = time.time()
256-
for lr_index, lr_image in enumerate(lr_images):
257-
save_images(images=lr_image,
258-
path=os.path.join(self.save_val_vis_dir,
259-
f"{i}_{image_name}_{lr_index}_lr.{self.image_format}"))
260-
for sr_index, sr_image in enumerate(sr_images):
261-
save_images(images=sr_image,
262-
path=os.path.join(self.save_val_vis_dir,
263-
f"{i}_{image_name}_{sr_index}_sr.{self.image_format}"))
264-
for hr_index, hr_image in enumerate(hr_images):
265-
save_images(images=hr_image,
266-
path=os.path.join(self.save_val_vis_dir,
267-
f"{i}_{image_name}_{hr_index}_hr.{self.image_format}"))
250+
# Save super resolution image and high resolution image
251+
lr_images = post_image(lr_images, device=self.device)
252+
sr_images = post_image(output, device=self.device)
253+
hr_images = post_image(hr_images, device=self.device)
254+
image_name = time.time()
255+
for lr_index, lr_image in enumerate(lr_images):
256+
save_images(images=lr_image,
257+
path=os.path.join(self.save_val_vis_dir,
258+
f"{i}_{image_name}_{lr_index}_lr.{self.image_format}"))
259+
for sr_index, sr_image in enumerate(sr_images):
260+
save_images(images=sr_image,
261+
path=os.path.join(self.save_val_vis_dir,
262+
f"{i}_{image_name}_{sr_index}_sr.{self.image_format}"))
263+
for hr_index, hr_image in enumerate(hr_images):
264+
save_images(images=hr_image,
265+
path=os.path.join(self.save_val_vis_dir,
266+
f"{i}_{image_name}_{hr_index}_hr.{self.image_format}"))
268267
# Loss, ssim and psnr per epoch
269268
self.avg_val_loss = sum(val_loss_list) / len(val_loss_list)
270269
self.avg_ssim = sum(ssim_list) / len(ssim_list)
@@ -274,6 +273,7 @@ def train_in_iter(self):
274273
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg ssim", scalar_value=self.avg_ssim, global_step=self.epoch)
275274
self.tb_logger.add_scalar(tag=f"[{self.device}]: Avg psnr", scalar_value=self.avg_psnr, global_step=self.epoch)
276275
logger.info(f"Val loss: {self.avg_val_loss}, SSIM: {self.avg_ssim}, PSNR: {self.avg_psnr}")
276+
self.model.train()
277277
logger.info(msg="Finish val mode.")
278278

279279
def after_iter(self):

webui/web.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ def train(self, seed, conditional, sample, network, run_name, epochs, batch_size
9090
29: cfg_scale
9191
"""
9292
gradio.Info(message="Start training...")
93-
logger = CustomLogger(name=__name__, is_webui=True, is_save_log=True,
94-
log_path=os.path.join(result_path, run_name))
93+
logger = WebUILogger(name=__name__, is_save_log=True, log_path=str(os.path.join(result_path, run_name)),
94+
rank=main_gpu)
9595
self.KILL_FLAG = False
9696
yield logger.info(msg="[Note]: Start parameters setting.")
9797
yield logger.info(

0 commit comments

Comments
 (0)