-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
113 lines (88 loc) · 3.45 KB
/
train.py
File metadata and controls
113 lines (88 loc) · 3.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from model import Critic, Generator, initialize_weights
from utils import gradient_penalty
import os
# WGAN-GP training script: trains `Critic` and `Generator` (from model.py) on
# the images under dataset/, checkpointing weights to Generator.pth/Critic.pth
# and writing sample images to samples/.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Hyperparameters --------------------------------------------------------
z_dim = 100              # channels of the (N, z_dim, 1, 1) noise fed to the generator
image_dim = 64           # side length images are resized / center-cropped to
batch_size = 64
num_epochs = 50
disc_features = 128      # forwarded to Critic(...)
gen_features = 128       # forwarded to Generator(...)
critic_iterations = 5    # critic updates performed per generator update
Lambda_GP = 10           # weight of the gradient-penalty term in the critic loss

# --- Models -----------------------------------------------------------------
critic = Critic(disc_features).to(device)
generator = Generator(z_dim, gen_features).to(device)

# Resume from an existing checkpoint when present, otherwise start from
# freshly initialised weights. The models already live on `device`, and
# load_state_dict copies into the existing parameters, so no further .to()
# call is needed afterwards.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True where the installed PyTorch
# supports it).
if os.path.exists("Generator.pth"):
    generator.load_state_dict(torch.load("Generator.pth", map_location=device))
else:
    initialize_weights(generator)

if os.path.exists("Critic.pth"):
    critic.load_state_dict(torch.load("Critic.pth", map_location=device))
else:
    initialize_weights(critic)

# --- Data -------------------------------------------------------------------
# Normalising with mean 0.5 / std 0.5 per channel; the sampling code below
# undoes this with `* 0.5 + 0.5` before writing the image.
transform = transforms.Compose([
    transforms.Resize(image_dim),
    transforms.CenterCrop(image_dim),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = ImageFolder(root='dataset/', transform=transform)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# --- Optimisers -------------------------------------------------------------
opt_critic = optim.Adam(critic.parameters(), lr=1e-4, betas=(0.5, 0.999))
opt_gen = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))

# Noise reused for every saved sample so successive samples are comparable.
# Created directly on the target device to avoid a host-to-device copy.
fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)

critic.train()
generator.train()

sample_iter = 0  # running count of sample images written so far
os.makedirs('samples', exist_ok=True)

# --- Training loop ----------------------------------------------------------
for epoch in range(num_epochs):
    for i, (real, _) in enumerate(loader):
        real_image = real.to(device)
        # The last batch of an epoch may be smaller than batch_size.
        current_batch_size = real_image.size(0)

        ### Train Critic: min -(E[critic(real)] - E[critic(fake)]) + Lambda_GP * gp ###
        for _ in range(critic_iterations):
            z_noise = torch.randn(current_batch_size, z_dim, 1, 1, device=device)
            fake_image = generator(z_noise)
            critic_real = critic(real_image).reshape(-1)
            critic_fake = critic(fake_image.detach()).reshape(-1)
            # NOTE(review): fake_image is passed undetached here, so the
            # penalty's backward pass presumably also reaches the generator
            # (any such grads are cleared by generator.zero_grad() below).
            # Left as-is because utils.gradient_penalty is not visible from
            # this file and may rely on its input requiring grad -- confirm
            # before detaching.
            gp = gradient_penalty(critic, real_image, fake_image, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + (Lambda_GP * gp)
            )
            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

        ### Train Generator: min -E[critic(gen_fake)] ###
        z_noise = torch.randn(current_batch_size, z_dim, 1, 1, device=device)
        fake_image = generator(z_noise)
        output = critic(fake_image).reshape(-1)
        loss_gen = -torch.mean(output)
        generator.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Checkpoint every 100 batches; report only after both files are written.
        if i % 100 == 0:
            torch.save(generator.state_dict(), "Generator.pth")
            torch.save(critic.state_dict(), "Critic.pth")
            print("saved model for epoch :", epoch+1)

        # Losses are logged on every batch (critic loss is from its last inner step).
        print(f" Generator Loss: {loss_gen.item()}, Critic Loss: {loss_critic.item()}")

        # Every 200 batches, write one sample image generated from fixed_noise.
        if i % 200 == 0:
            generator.eval()
            with torch.no_grad():
                sample_iter += 1
                # Only the first image of the batch is saved, so only that
                # one is moved to the host.
                fake = generator(fixed_noise)[0].cpu()
                # CHW -> HWC, then undo the 0.5/0.5 normalisation. imsave
                # rejects float RGB values outside [0, 1], so clamp in case
                # the generator output is not strictly bounded to [-1, 1].
                fake = (fake.permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1)
                plt.imsave(f"samples/fake_images_epoch_{sample_iter}.png", fake.numpy())
                print(f"saved sample {sample_iter}")
            generator.train()