-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathtest_reg_loss_integration.py
More file actions
122 lines (99 loc) · 4.68 KB
/
test_reg_loss_integration.py
File metadata and controls
122 lines (99 loc) · 4.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
import torch
import torch.nn as nn
import torch.optim as optim
from parameterized import parameterized
from monai.losses import BendingEnergyLoss, GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
from monai.utils import set_determinism
# Parameter sets for ``TestRegLossIntegration.test_convergence``. Each entry is
#   [loss class, loss constructor kwargs, names of the loss forward() arguments,
#    optional number of network output channels (the test defaults it to 1)]
# The forward() argument names are looked up in {"pred": network output, "target": target}.
TEST_CASES = [
    # BendingEnergyLoss only receives "pred"; the network emits 3 channels
    # (presumably one displacement component per axis of the 3D volume — confirm).
    [BendingEnergyLoss, {}, ["pred"], 3],
    # LNCC with every kernel shape, including a gaussian kernel wider than 3.
    *[
        [LocalNormalizedCrossCorrelationLoss, {"kernel_size": size, "kernel_type": shape}, ["pred", "target"]]
        for size, shape in ((7, "rectangular"), (5, "triangular"), (3, "gaussian"), (7, "gaussian"))
    ],
    # Global mutual information with the default kernel and with a b-spline kernel.
    [GlobalMutualInformationLoss, {"num_bins": 10}, ["pred", "target"]],
    [GlobalMutualInformationLoss, {"kernel_type": "b-spline", "num_bins": 10}, ["pred", "target"]],
]
class TestRegLossIntegration(unittest.TestCase):
    """Integration tests for MONAI registration losses.

    ``test_convergence`` trains a tiny convolutional network with each loss in
    ``TEST_CASES``; ``test_lncc_gaussian_kernel_gt3_identical_images`` checks the
    LNCC value on identical inputs.
    """

    def setUp(self):
        # Seed MONAI/torch randomness so inputs and weight initialisation are repeatable.
        set_determinism(0)
        # Use the first GPU when one is available, otherwise the CPU.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0")

    def tearDown(self):
        # Undo the seeding done in setUp so later tests are not affected.
        set_determinism(None)

    @parameterized.expand(TEST_CASES)
    def test_convergence(self, loss_type, loss_args, forward_args, pred_channels=1):
        """Check that a loss provides a usable gradient for training.

        A small two-convolution network is trained with Adam for ``max_iter``
        steps on a single random 3D volume. The test passes when the loss value
        computed in the last iteration is strictly lower than the value computed
        in the first iteration, which indicates the loss gradient drives the
        optimisation in a descending direction.

        Args:
            loss_type: loss class to instantiate.
            loss_args: keyword arguments passed to the loss constructor.
            forward_args: names of the loss ``forward`` arguments to supply, taken
                from ``{"pred": network output, "target": target}``.
            pred_channels: number of output channels produced by the network.
        """
        learning_rate = 0.001
        max_iter = 100

        # A simple 3D example: the network input is an affine rescaling of the
        # target. Both tensors are created on self.device, so no extra .to() is needed.
        target = torch.rand((1, 1, 5, 5, 5), device=self.device)
        image = 12 * target + 27

        class OnelayerNet(nn.Module):
            """Two 3x3x3 convolutions with a ReLU in between; padding=1 keeps the spatial size."""

            def __init__(self):
                super().__init__()
                self.layer = nn.Sequential(
                    nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1),
                    nn.ReLU(),
                    nn.Conv3d(in_channels=1, out_channels=pred_channels, kernel_size=3, padding=1),
                )

            def forward(self, x):
                return self.layer(x)

        net = OnelayerNet().to(self.device)
        loss = loss_type(**loss_args).to(self.device)
        # Adam (not plain SGD) is used as the optimizer.
        optimizer = optim.Adam(net.parameters(), lr=learning_rate)

        init_loss = None  # declared first for pylint; set on the first iteration
        for it in range(max_iter):
            optimizer.zero_grad()
            output = net(image)
            loss_input = {"pred": output, "target": target}
            loss_val = loss(**{k: loss_input[k] for k in forward_args})
            if it == 0:
                # Keep a plain float rather than the first iteration's tensor, so its
                # autograd history is not held for the whole loop.
                init_loss = loss_val.item()
            loss_val.backward()
            optimizer.step()

        # loss_val is the value computed in the final iteration (before its update step).
        final_loss = loss_val.item()
        self.assertGreater(
            init_loss,
            final_loss,
            f"loss did not decrease: initial {init_loss:.6f}, final {final_loss:.6f}",
        )

    def test_lncc_gaussian_kernel_gt3_identical_images(self):
        """
        Regression test for make_gaussian_kernel truncated parameter bug.
        LNCC on identical inputs must be close to -1.0 for gaussian kernel_size > 3.
        """
        for kernel_size in [5, 7]:
            with self.subTest(kernel_size=kernel_size):
                loss_fn = LocalNormalizedCrossCorrelationLoss(
                    spatial_dims=2, kernel_size=kernel_size, kernel_type="gaussian"
                ).to(self.device)
                x = torch.rand(2, 1, 32, 32, device=self.device)
                y = x.clone()
                loss = loss_fn(x, y)
                self.assertTrue(
                    torch.allclose(loss, torch.tensor(-1.0, device=self.device, dtype=loss.dtype), atol=1e-3),
                    f"LNCC of identical images should be -1.0, got {loss.item():.6f} (kernel_size={kernel_size})",
                )
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()