-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
105 lines (73 loc) · 2.79 KB
/
utils.py
File metadata and controls
105 lines (73 loc) · 2.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import torch.nn as nn
import numpy as np
from dataset import Lab_Dataset
import matplotlib.pyplot as plt
import os
import warnings
from skimage.color import lab2rgb
def initilize_weights(model):
    """Re-initialize the weights of ``model`` in place.

    Uses the N(0, 0.02) scheme used in the PyTorch DCGAN tutorial:

    * ``Conv2d`` / ``ConvTranspose2d`` / ``ConvTranspose1d``: weights drawn from
      N(0, 0.02), biases (when present) set to 0.
    * ``BatchNorm2d``: weights drawn from N(1, 0.02), biases set to 0
      (skipped when ``affine=False``, where both are ``None``).

    Any other module type is left untouched. Returns ``None``.
    """
    # ConvTranspose2d was previously missing (ConvTranspose1d was listed instead),
    # so 2-D upsampling layers kept PyTorch's default init. ConvTranspose1d is
    # kept so behavior for any 1-D transposed conv is unchanged.
    conv_types = (nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d)
    for m in model.modules():
        if isinstance(m, conv_types):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.BatchNorm2d):
            # BatchNorm2d(affine=False) has weight/bias == None; nothing to init.
            if m.weight is not None:
                nn.init.normal_(m.weight, 1.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
def logits_to_ab(logits, pts_in_hull):
    """Decode per-pixel bin logits into ab values via the softmax expectation.

    Args:
        logits: 4-D tensor (B, Q, H, W); softmax is taken over dim 1 (the Q bins).
        pts_in_hull: (Q, 2) table of bin centers.

    Returns:
        (B, 2, H, W) tensor where each pixel is the probability-weighted mean
        of the bin centers.
    """
    # Put the bin axis last so one matmul against the (Q, 2) table yields the
    # expectation for every pixel at once: (B, H, W, Q) @ (Q, 2) -> (B, H, W, 2).
    probs_last = torch.softmax(logits, dim=1).permute(0, 2, 3, 1)
    expected = probs_last @ pts_in_hull
    # Back to channel-first layout.
    return expected.permute(0, 3, 1, 2)
def ab_to_bins(ab, mode, pts_in_hull, return_bin_index: bool = False):
    """Snap every ab pixel to its nearest bin center (squared Euclidean distance).

    Args:
        ab: 4-D tensor (B, C, H, W); channels 0 and 1 are read as the a and b
            coordinates. May be non-contiguous.
        mode: Not used by this function; kept for call-site compatibility.
        pts_in_hull: (Q, 2) tensor of bin centers.
        return_bin_index: If True, return bin indices instead of bin centers.

    Returns:
        If ``return_bin_index``: (B, H, W) integer tensor of nearest-bin indices.
        Otherwise: (B, 2, H, W) tensor holding the nearest bin center per pixel.

    ``ab`` is not modified.
    """
    # TODO make this work with copic
    B, _, H, W = ab.shape
    centers = pts_in_hull.detach()
    # Per-pixel coordinates shaped (B, H*W, 1) broadcast against the (Q,)
    # center coordinates, giving (B, H*W, Q) distances without materializing
    # a repeated copy of ab. reshape (not view) tolerates non-contiguous input.
    a = ab[:, 0].reshape(B, H * W, 1)
    b = ab[:, 1].reshape(B, H * W, 1)
    dist_sq = (a - centers[:, 0]) ** 2 + (b - centers[:, 1]) ** 2
    closest_idx = torch.argmin(dist_sq, dim=2)  # (B, H*W)
    if return_bin_index:
        return closest_idx.view(B, H, W)
    bins_ab = centers[closest_idx]  # (B, H*W, 2)
    return bins_ab.permute(0, 2, 1).reshape(B, 2, H, W)
def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters.

    Parameters with ``requires_grad=False`` (frozen) are excluded.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
def save_images(fixed_l_batch, render_batch, gen_mode, path="output.png"):
    """Colorize a fixed batch of L channels with the generator and save a grid.

    Args:
        fixed_l_batch: L-channel tensor; the generator output is concatenated
            to it on dim 1 and the result is permuted as 4-D, so presumably
            (N, 1, H, W) — TODO confirm against the caller.
        render_batch: (rows, cols) of the subplot grid. Images are drawn in
            row-major order; grid cells beyond the batch size are left blank.
        gen_mode: Generator module mapping ``fixed_l_batch`` to ab channels.
        path: File the grid is written to (defaults to "output.png").

    The generator is switched to eval mode while rendering. Its previous
    train/eval mode is restored afterwards, and the figure is closed, even
    if rendering fails.
    """
    was_training = gen_mode.training
    gen_mode.eval()
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so `.flat` works.
    fig, axes = plt.subplots(
        render_batch[0], render_batch[1], figsize=(15, 15), squeeze=False
    )
    try:
        with torch.no_grad():
            fake_ab = gen_mode(fixed_l_batch).detach()
            lab = torch.cat([fixed_l_batch, fake_ab], dim=1)
        # (N, C, H, W) -> (N, H, W, C): channel-last layout for lab2rgb.
        # (The former squeeze(0) calls were removed: they were no-ops for N > 1
        # and broke this 4-D permute for N == 1.)
        lab = lab.permute(0, 2, 3, 1).cpu().numpy()
        with warnings.catch_warnings():
            # Presumably silences lab2rgb's out-of-range color warnings.
            warnings.simplefilter("ignore")
            rgb_image = lab2rgb(lab)
        rgb_image = np.clip(rgb_image, 0, 1)
        n_images = len(rgb_image)
        for idx, ax in enumerate(axes.flat):
            if idx < n_images:
                ax.imshow(rgb_image[idx])
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(path)
    finally:
        plt.close(fig)
        gen_mode.train(was_training)
def r1_penalty(ab, disc_score):
    """Gradient penalty on the discriminator's input.

    Computes the gradient of ``disc_score.sum()`` with respect to ``ab``,
    takes its squared L2 norm per sample (over all non-batch dims), and
    returns the batch mean. ``create_graph=True`` keeps the graph so the
    penalty itself can be backpropagated. ``ab`` must require grad and be
    part of the graph that produced ``disc_score``.
    """
    total_score = disc_score.sum()
    grad = torch.autograd.grad(
        outputs=total_score,
        inputs=ab,
        create_graph=True,
    )[0]
    batch_size = grad.shape[0]
    per_sample_sq_norm = grad.square().reshape(batch_size, -1).sum(dim=1)
    return per_sample_sq_norm.mean()