-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain.py
More file actions
79 lines (65 loc) · 2.02 KB
/
train.py
File metadata and controls
79 lines (65 loc) · 2.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
from options import TrainOptions
from dataset import dataset_unpair
from model import DRIT
from saver import Saver
def main():
    """Run DRIT training as configured by the parsed command-line options.

    Builds the unpaired dataset and its loader, creates the model (fresh, or
    resumed from ``opts.resume``), then iterates over epochs. After every
    epoch the result images and the network weights are handed to the saver.

    Training ends after ``opts.n_ep`` epochs or as soon as ``total_it``
    reaches ``max_it``, whichever comes first. In the latter case a final
    snapshot is written with epoch index -1 and the function returns.
    """
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')
    dataset = dataset_unpair(opts)
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opts.batch_size,
        shuffle=True,
        num_workers=opts.nThreads,
    )

    # model
    print('\n--- load model ---')
    # NOTE(review): anomaly detection is a debugging aid and slows down every
    # backward pass; confirm it is meant to stay enabled for regular training.
    torch.autograd.set_detect_anomaly(True)
    model = DRIT(opts)
    model.setgpu(opts.gpu)
    if opts.resume is None:
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        ep0, total_it = model.resume(opts.resume)
    # The scheduler is given the last finished epoch (-1 for a fresh run);
    # the training loop itself starts at the following one.
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d'%(ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    max_it = 500000
    # The number of batches per epoch does not change, so compute it once.
    num_batches = len(train_loader)
    for ep in range(ep0, opts.n_ep):
        for it, (images_a, images_b) in enumerate(train_loader):
            # skip incomplete batches (either domain short of batch_size)
            if images_a.size(0) != opts.batch_size or images_b.size(0) != opts.batch_size:
                continue

            # input data
            images_a = images_a.cuda(opts.gpu).detach()
            images_b = images_b.cuda(opts.gpu).detach()

            # update model: only every d_iter-th batch, plus the last two
            # batches of the epoch, runs update_D + update_EG; every other
            # batch only calls update_D_content and is not counted in total_it.
            if (it + 1) % opts.d_iter != 0 and it < num_batches - 2:
                model.update_D_content(images_a, images_b)
                continue
            model.update_D(images_a, images_b)
            model.update_EG()

            # save to display file (currently disabled)
            # if not opts.no_display_img:
            #     saver.write_display(total_it, model)

            print('total_it: %d (ep %d, it %d), lr %08f' % (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
            total_it += 1
            if total_it >= max_it:
                # Iteration budget reached: write the final snapshot and stop
                # training altogether. A bare `break` would only leave the
                # batch loop and the epoch loop would keep going.
                saver.write_img(-1, model)
                # Same (ep, total_it, model) argument order as the per-epoch
                # call below.
                saver.write_model(-1, total_it, model)
                return

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # save result image
        saver.write_img(ep, model)

        # save network weights
        saver.write_model(ep, total_it, model)
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()