Skip to content

Commit 5033cbb

Browse files
Add StyleGAN, ESRGAN, and SRGAN code
1 parent a2ee927 commit 5033cbb

34 files changed

Lines changed: 1569 additions & 0 deletions

ML/Pytorch/GANs/ESRGAN/ESRGAN.png

131 KB
Loading
64 MB
Binary file not shown.

ML/Pytorch/GANs/ESRGAN/config.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import torch
import cv2
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Checkpointing
LOAD_MODEL = True
SAVE_MODEL = True
CHECKPOINT_GEN = "gen.pth"
CHECKPOINT_DISC = "disc.pth"

# Training hyperparameters
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
NUM_EPOCHS = 10000
BATCH_SIZE = 16
LAMBDA_GP = 10  # gradient-penalty weight
NUM_WORKERS = 4

# 4x super-resolution: 128 -> 32 patches
HIGH_RES = 128
LOW_RES = HIGH_RES // 4
IMG_CHANNELS = 3

# High-res target: mean 0 / std 1 normalisation keeps pixels in [0, 1].
highres_transform = A.Compose(
    [
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

# Low-res input: bicubic downscale of the shared crop.
# NOTE: albumentations expects OpenCV interpolation flags. The original code
# passed Image.BICUBIC (a PIL constant with value 3), which OpenCV interprets
# as INTER_AREA -- use cv2.INTER_CUBIC for a true bicubic downscale.
lowres_transform = A.Compose(
    [
        A.Resize(width=LOW_RES, height=LOW_RES, interpolation=cv2.INTER_CUBIC),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

# Shared augmentation applied before the low/high-res split so both views
# come from the same crop/flip/rotation.
both_transforms = A.Compose(
    [
        A.RandomCrop(width=HIGH_RES, height=HIGH_RES),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
)

# Evaluation transform: no augmentation, just normalise + to-tensor.
test_transform = A.Compose(
    [
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

ML/Pytorch/GANs/ESRGAN/dataset.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import torch
2+
from tqdm import tqdm
3+
import time
4+
import torch.nn
5+
import os
6+
from torch.utils.data import Dataset, DataLoader
7+
import numpy as np
8+
import config
9+
from PIL import Image
10+
import cv2
11+
12+
13+
class MyImageFolder(Dataset):
    """Image dataset for super-resolution training.

    Scans ``root_dir`` (one sub-directory per class) and, per item, returns a
    ``(low_res, high_res)`` tensor pair produced by the transforms declared in
    ``config.py``.
    """

    def __init__(self, root_dir):
        super().__init__()
        self.data = []  # list of (filename, class_index) pairs
        self.root_dir = root_dir
        self.class_names = os.listdir(root_dir)

        for index, name in enumerate(self.class_names):
            files = os.listdir(os.path.join(root_dir, name))
            self.data += list(zip(files, [index] * len(files)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img_file, label = self.data[index]
        root_and_dir = os.path.join(self.root_dir, self.class_names[label])
        img_path = os.path.join(root_and_dir, img_file)

        image = cv2.imread(img_path)
        # cv2.imread silently returns None for missing/unreadable files, which
        # would otherwise surface as a cryptic cvtColor error -- fail loudly.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # One shared augmentation, then the low/high-res views of the same crop.
        both_transform = config.both_transforms(image=image)["image"]
        low_res = config.lowres_transform(image=both_transform)["image"]
        high_res = config.highres_transform(image=both_transform)["image"]
        return low_res, high_res
37+
38+
39+
def test():
    """Smoke test: iterate the dataset and print low/high-res batch shapes."""
    dataset = MyImageFolder(root_dir="data/")
    loader = DataLoader(dataset, batch_size=8)

    for low_res, high_res in loader:
        print(low_res.shape)
        print(high_res.shape)


if __name__ == "__main__":
    test()

ML/Pytorch/GANs/ESRGAN/loss.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import torch.nn as nn
2+
from torchvision.models import vgg19
3+
import config
4+
5+
6+
class VGGLoss(nn.Module):
    """Perceptual loss: MSE between frozen VGG19 feature maps.

    Uses the first 35 layers of torchvision's pretrained VGG19 feature
    extractor (ending just before the final ReLU, i.e. pre-activation
    features, as in the ESRGAN paper).
    """

    def __init__(self):
        super().__init__()
        # Truncated, frozen feature extractor in eval mode on the target device.
        extractor = vgg19(pretrained=True).features[:35]
        self.vgg = extractor.eval().to(config.DEVICE)
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.loss = nn.MSELoss()

    def forward(self, input, target):
        return self.loss(self.vgg(input), self.vgg(target))

ML/Pytorch/GANs/ESRGAN/model.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import torch
2+
from torch import nn
3+
4+
5+
class ConvBlock(nn.Module):
    """Conv2d followed by an optional LeakyReLU(0.2) activation.

    Extra Conv2d keyword arguments (kernel_size, stride, padding, ...) are
    forwarded via **kwargs; bias is always enabled.
    """

    def __init__(self, in_channels, out_channels, use_act, **kwargs):
        super().__init__()
        self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias=True)
        # Identity when no activation is requested keeps forward() branch-free.
        if use_act:
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        out = self.cnn(x)
        return self.act(out)
18+
19+
20+
class UpsampleBlock(nn.Module):
    """Nearest-neighbour upsample by `scale_factor`, then 3x3 conv + LeakyReLU(0.2)."""

    def __init__(self, in_c, scale_factor=2):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode="nearest")
        self.conv = nn.Conv2d(in_c, in_c, 3, 1, 1, bias=True)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        upsampled = self.upsample(x)
        return self.act(self.conv(upsampled))
29+
30+
31+
class DenseResidualBlock(nn.Module):
    """Five densely connected ConvBlocks with a scaled residual connection.

    Each block sees the concatenation of the input and all previous block
    outputs. The last block projects back to `in_channels` without activation;
    its output, scaled by `residual_beta`, is added to the input.
    """

    def __init__(self, in_channels, channels=32, residual_beta=0.2):
        super().__init__()
        self.residual_beta = residual_beta
        self.blocks = nn.ModuleList(
            [
                ConvBlock(
                    in_channels + channels * i,
                    # Blocks 0-3 grow `channels` features; block 4 projects
                    # back to `in_channels` and skips the activation.
                    channels if i <= 3 else in_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    use_act=i <= 3,
                )
                for i in range(5)
            ]
        )

    def forward(self, x):
        features = x
        for layer in self.blocks:
            out = layer(features)
            features = torch.cat([features, out], dim=1)
        # `out` is the final projection back to in_channels.
        return self.residual_beta * out + x
55+
56+
57+
class RRDB(nn.Module):
    """Residual-in-Residual Dense Block: three DenseResidualBlocks plus an
    outer residual connection scaled by `residual_beta`."""

    def __init__(self, in_channels, residual_beta=0.2):
        super().__init__()
        self.residual_beta = residual_beta
        self.rrdb = nn.Sequential(*(DenseResidualBlock(in_channels) for _ in range(3)))

    def forward(self, x):
        residual = self.rrdb(x)
        return residual * self.residual_beta + x
65+
66+
67+
class Generator(nn.Module):
    """ESRGAN-style 4x super-resolution generator.

    Pipeline: initial conv -> `num_blocks` RRDBs -> conv (residual over the
    trunk) -> two 2x upsample blocks -> final conv head back to `in_channels`.
    """

    def __init__(self, in_channels=3, num_channels=64, num_blocks=23):
        super().__init__()
        self.initial = nn.Conv2d(
            in_channels, num_channels, kernel_size=3, stride=1, padding=1, bias=True
        )
        self.residuals = nn.Sequential(*(RRDB(num_channels) for _ in range(num_blocks)))
        self.conv = nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1)
        # Two nearest-neighbour 2x stages -> overall 4x upscaling.
        self.upsamples = nn.Sequential(
            UpsampleBlock(num_channels), UpsampleBlock(num_channels),
        )
        self.final = nn.Sequential(
            nn.Conv2d(num_channels, num_channels, 3, 1, 1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_channels, in_channels, 3, 1, 1, bias=True),
        )

    def forward(self, x):
        skip = self.initial(x)
        # Long skip connection around the entire RRDB trunk.
        trunk = self.conv(self.residuals(skip)) + skip
        return self.final(self.upsamples(trunk))
94+
95+
96+
class Discriminator(nn.Module):
    """SRGAN-style discriminator.

    Eight ConvBlocks alternating stride 1 and 2 (halving resolution on every
    odd-indexed block), followed by an adaptive-pooled MLP head that outputs
    one raw logit per image.

    Note: the `features` default is a tuple rather than a list to avoid the
    mutable-default-argument pitfall; any iterable of channel counts works.
    """

    def __init__(self, in_channels=3, features=(64, 64, 128, 128, 256, 256, 512, 512)):
        super().__init__()
        blocks = []
        for idx, feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    in_channels,
                    feature,
                    kernel_size=3,
                    stride=1 + idx % 2,  # stride 2 on odd blocks downsamples
                    padding=1,
                    use_act=True,
                ),
            )
            in_channels = feature

        self.blocks = nn.Sequential(*blocks)
        # Adaptive pooling makes the linear head independent of input size.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        x = self.blocks(x)
        return self.classifier(x)
125+
126+
def initialize_weights(model, scale=0.1):
    """Kaiming-normal init for every Conv2d/Linear weight, scaled by `scale`.

    The extra down-scaling (0.1 by default) matches the smaller initialization
    ESRGAN uses to stabilise training of its very deep residual generator.
    """
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(module.weight.data)
            module.weight.data *= scale
135+
136+
137+
def test():
    """Smoke test: push a batch through Generator and Discriminator, print shapes."""
    gen = Generator()
    disc = Discriminator()
    low_res = 24
    batch = torch.randn((5, 3, low_res, low_res))
    fake_high_res = gen(batch)
    verdict = disc(fake_high_res)

    print(fake_high_res.shape)
    print(verdict.shape)


if __name__ == "__main__":
    test()
150+
151+
152+
153+
154+
543 KB
Loading
375 KB
Loading
124 KB
Loading
199 KB
Loading

0 commit comments

Comments
 (0)