-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels.py
More file actions
174 lines (139 loc) · 5.8 KB
/
models.py
File metadata and controls
174 lines (139 loc) · 5.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from wcode.net.CNN.VNet.VNet import VNet
class MeanTeacher(nn.Module):
    """Mean-Teacher wrapper holding a student VNet and an EMA teacher VNet.

    Both networks are built from the same ``params``. The teacher starts as an
    exact copy of the student, is frozen with ``requires_grad_(False)``, and is
    only changed through :meth:`update_ema_variables`,
    :meth:`init_teacher_with_student` or
    :meth:`exchange_params_teacher_and_student`.

    Args:
        params: constructor argument forwarded unchanged to ``VNet``.
        alpha: default EMA decay used by :meth:`update_ema_variables`.
    """

    def __init__(self, params, alpha=0.99):
        super().__init__()
        self.student = VNet(params)
        self.teacher = VNet(params)
        # Start the teacher from exactly the student's weights, then freeze it:
        # it is moved by EMA, never by the optimizer.
        self.init_teacher_with_student()
        self.teacher.requires_grad_(False)
        self.check_init_weight()
        self.alpha = alpha

    def forward(self, x, train_flag=False, weak_aug=False):
        """Run the teacher (always) and the student (only when training).

        Args:
            x: input batch. The augmentation helpers below handle 4-D
                (b, c, y, x) and 5-D (b, c, z, y, x) tensors.
            train_flag: when False the student is skipped and
                ``student_out`` is None.
            weak_aug: when True (and training), the student sees a randomly
                rotated, flipped and blurred-or-sharpened copy of ``x``. Its
                ``"pred"`` (and ``"feature"``, if returned) is mapped back to
                the original orientation so it lines up with the teacher.

        Returns:
            dict with ``"student_out"``, ``"teacher_out"`` and ``"pred"``
            (the teacher's ``"pred"``).
        """
        if not train_flag:
            student_out = None
        elif weak_aug:
            with torch.no_grad():
                rotate_time = random.randint(1, 3)
                # Flip along one of the last two (in-plane) axes.
                flip_axis = [int(np.random.choice(list(range(x.ndim))[-2:]))]
                x_s = x.clone().detach()
                x_s = self.rotate_img(x_s, rotate_time)
                x_s = self.flip_img(x_s, flip_axis)
                if torch.rand(1) < 0.5:
                    x_s = self.gaussian_blur(x_s)
                else:
                    x_s = self.sharpen(x_s)
            # Outside no_grad: gradients must flow through the student and
            # through the inverse spatial transform.
            student_out = self.student(x_s)
            student_out["pred"] = self._undo_spatial_aug(
                student_out["pred"], rotate_time, flip_axis
            )
            if "feature" in student_out:
                student_out["feature"] = self._undo_spatial_aug(
                    student_out["feature"], rotate_time, flip_axis
                )
        else:
            student_out = self.student(x)

        with torch.no_grad():
            teacher_out = self.teacher(x)
        return {
            "student_out": student_out,
            "teacher_out": teacher_out,
            "pred": teacher_out["pred"],
        }

    def _undo_spatial_aug(self, img, rotate_time, flip_axis):
        """Invert ``flip_img(rotate_img(img, rotate_time), flip_axis)``."""
        # Undo in reverse order: a flip is its own inverse, and
        # (4 - k) quarter turns cancel k quarter turns.
        return self.rotate_img(self.flip_img(img, flip_axis), 4 - rotate_time)

    def update_ema_variables(self, alpha=None):
        """Blend student parameters into the teacher in place.

        For each parameter pair: ``teacher = alpha * teacher + (1 - alpha) *
        student``. Only ``parameters()`` are blended; buffers are not touched.

        Args:
            alpha: EMA decay; defaults to ``self.alpha``.
        """
        if alpha is None:
            alpha = self.alpha
        for ema_param, param in zip(
            self.teacher.parameters(), self.student.parameters()
        ):
            ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

    def update_alpha(self, new_value):
        """Set the default EMA decay; must lie in [0, 1]."""
        assert 0 <= new_value <= 1
        self.alpha = new_value

    def init_teacher_with_student(self):
        """Overwrite the teacher's state (parameters and buffers) with the student's."""
        self.teacher.load_state_dict(self.student.state_dict())

    def exchange_params_teacher_and_student(self):
        """Swap the full state (parameters and buffers) of student and teacher."""
        # state_dict() returns tensors that share storage with the live
        # parameters, and load_state_dict() copies into that storage in place.
        # Without cloning, loading the teacher into the student would also
        # overwrite the saved student weights and both nets would end up equal.
        student_state = self.student.state_dict()
        for key in list(student_state):
            student_state[key] = student_state[key].clone()
        self.student.load_state_dict(self.teacher.state_dict())
        self.teacher.load_state_dict(student_state)

    def check_init_weight(self):
        """Assert that every student parameter equals its teacher counterpart."""
        for student_param, teacher_param in zip(
            self.student.parameters(), self.teacher.parameters()
        ):
            assert torch.equal(student_param.data, teacher_param.data)

    def rotate_img(self, img, rotate_time):
        """Rotate by ``rotate_time`` quarter turns in the last two spatial dims.

        4-D tensors rotate over dims (2, 3); all other ranks over (3, 4).
        A list or tuple of tensors is rotated element-wise and returned as a list.
        """
        if isinstance(img, (tuple, list)):
            return [self.rotate_img(item, rotate_time) for item in img]
        plane = (2, 3) if img.ndim == 4 else (3, 4)
        return torch.rot90(img, k=rotate_time, dims=plane)

    def flip_img(self, img, filp_axis):
        """Flip along the dims in ``filp_axis``; lists/tuples are flipped element-wise."""
        if isinstance(img, (tuple, list)):
            return [torch.flip(item, dims=filp_axis) for item in img]
        return torch.flip(img, dims=filp_axis)

    def gaussian_blur(self, img):
        """Blur every channel with a randomly sized Gaussian kernel (sigma=2.0).

        The nominal size ``k`` comes from ``random.randint(3, 7)``. The kernel
        spans ``2 * (k // 2) + 1`` pixels, so even draws (4, 6) give 5x5 and
        7x7 kernels. 5-D input is blurred slice-by-slice along z.
        """

        def gaussian_kernel(kernel_size=3, sigma=2.0):
            # Returns a normalized 2-D Gaussian and its matching F.pad spec.
            if isinstance(kernel_size, int):
                size = [int(kernel_size) // 2 for _ in range(2)]
            elif isinstance(kernel_size, (list, tuple)):
                size = [i // 2 for i in list(kernel_size)]
            else:
                raise ValueError("Unsupport type of kernel_size:", type(kernel_size))
            # Kernel axis 0 runs over rows (y), axis 1 over columns (x).
            rows, cols = np.mgrid[-size[0] : size[0] + 1, -size[1] : size[1] + 1]
            g = np.exp(-(rows**2 + cols**2) / (2 * sigma**2))
            # F.pad order is (x_left, x_right, y_top, y_bottom).
            return g / g.sum(), [size[1], size[1], size[0], size[0]]

        kernel, padding_size = gaussian_kernel(kernel_size=random.randint(3, 7))
        return self._apply_2d_kernel(img, kernel, padding_size)

    def sharpen(self, img):
        """Apply the fixed 3x3 sharpening kernel (identity minus 4-neighbour Laplacian).

        The kernel sums to 1, so flat regions are left unchanged. 5-D input is
        processed slice-by-slice along z.
        """
        sharpen_kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
        return self._apply_2d_kernel(img, sharpen_kernel, [1, 1, 1, 1])

    @staticmethod
    def _apply_2d_kernel(img, kernel, padding):
        """Convolve every channel of ``img`` with the same 2-D numpy ``kernel``.

        Args:
            img: 4-D (b, c, y, x) or 5-D (b, c, z, y, x) tensor.
            kernel: 2-D numpy array.
            padding: F.pad spec applied in reflect mode before the convolution.

        Raises:
            ValueError: if ``img`` is neither 4-D nor 5-D.
        """
        spatial_dims = img.ndim - 2
        if spatial_dims == 3:
            # b, c, z, y, x -> b, z, c, y, x -> b*z, c, y, x
            b, c, z, y, x = img.shape
            img = img.transpose(1, 2).reshape(b * z, c, y, x)
        elif spatial_dims == 2:
            c = img.shape[1]
        else:
            raise ValueError(f"Expected a 4-D or 5-D tensor, got {img.ndim}-D")
        # One copy of the kernel per channel -> depthwise conv (groups=c).
        weight = (
            torch.from_numpy(kernel)
            .to(device=img.device, dtype=img.dtype)
            .unsqueeze(0)
            .unsqueeze(0)
            .repeat(c, 1, 1, 1)
        )  # (c, 1, kh, kw)
        out = F.conv2d(F.pad(img, pad=list(padding), mode="reflect"), weight, groups=c)
        if spatial_dims == 3:
            out = out.reshape(b, z, c, y, x).transpose(1, 2)
        return out