-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathscnp_example.py
More file actions
58 lines (44 loc) · 1.98 KB
/
Copy pathscnp_example.py
File metadata and controls
58 lines (44 loc) · 1.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
##########################
#### vvvvv SCNP vvvvv ####
##########################
import torch
# Copy-paste this in your code (yes, you only need this :))
class SCNP(torch.nn.Module):
    """Replace every logit with its least favourable same-label neighbour.

    Each class channel is processed independently:

    * Positions where ``target == 1`` (foreground of that channel) receive the
      minimum logit among the foreground positions inside their window.
    * Positions where ``target == 0`` (background of that channel) receive the
      maximum logit among the background positions inside their window.

    The window always contains the position itself, so each output value is
    one of the input logits. The output has the same shape as ``logits``.

    Args:
        dimensions: ``"2D"`` to pool with ``max_pool2d``, or ``"3D"`` to pool
            with ``max_pool3d``.
        neighborhood_size: Side length of the square/cubic window. It must be
            a positive odd integer. With stride 1 and padding
            ``neighborhood_size // 2``, only odd sizes preserve the spatial
            shape.
        fill_value: Finite sentinel written at the positions excluded from
            each pooling pass. It must be larger than the magnitude of every
            logit. A finite value (not ``inf``) is required because the
            sentinel is later multiplied by 0, and ``inf * 0`` would be NaN.

    Raises:
        ValueError: If ``dimensions`` is not ``"2D"``/``"3D"``, or if
            ``neighborhood_size`` is not a positive odd integer.
    """

    def __init__(self, dimensions, neighborhood_size=3, fill_value=9999.0):
        super().__init__()
        if dimensions == "2D":
            self.mp = torch.nn.functional.max_pool2d
            n_spatial = 2
        elif dimensions == "3D":
            self.mp = torch.nn.functional.max_pool3d
            n_spatial = 3
        else:
            raise ValueError("`dimensions` parameters must be either '2D' or '3D'")
        # An even window with stride 1 and padding size // 2 yields one extra
        # element per spatial axis, so the pooled maps would no longer line up
        # with `target` in `forward`.
        if neighborhood_size < 1 or neighborhood_size % 2 == 0:
            raise ValueError(
                "`neighborhood_size` must be a positive odd integer, "
                f"got {neighborhood_size!r}"
            )
        self.ns = (neighborhood_size,) * n_spatial
        self.st = (1,) * n_spatial
        self.pad = (neighborhood_size // 2,) * n_spatial
        self.fill_value = fill_value

    def forward(self, logits, target):
        """Return the SCNP-pooled logits.

        Args:
            logits: Network output, e.g. ``(B, N, H, W)`` for ``"2D"`` or
                ``(B, N, H, W, D)`` for ``"3D"``.
            target: One-hot labels with exactly the same shape as ``logits``.
                Float, integer or bool tensors are accepted; they are cast to
                ``logits.dtype``.

        Returns:
            Tensor with the same shape as ``logits``.

        Raises:
            ValueError: If ``logits`` and ``target`` shapes differ.
        """
        if logits.shape != target.shape:
            raise ValueError("`target` should be one-hot encoded and have the same shape as `logits`")
        # Casting lets bool masks (where `1 - mask` is unsupported) work and
        # keeps all arithmetic in the logits' dtype.
        foreground = target.to(logits.dtype)
        background = 1 - foreground
        fill = self.fill_value
        # MinPooling in the foreground: min(x) == -max(-x). Background
        # positions are lifted to +fill so they are never selected. The pool's
        # implicit -inf border padding is never selected either.
        t1 = -self.mp(-(logits * foreground + fill * background), self.ns, self.st, self.pad)
        # MaxPooling in the background: foreground positions are pushed to -fill.
        t2 = self.mp(logits * background - fill * foreground, self.ns, self.st, self.pad)
        # Keep the min-pooled value on foreground and the max-pooled value on
        # background. The sentinel entries are multiplied by 0 here.
        return t1 * foreground + t2 * background
#############################
#### vvvvv Example vvvvv ####
#############################
# Illustrative training step. It is guarded by ``__main__`` so that importing
# this file (e.g. to reuse ``SCNP``) does not execute it.
if __name__ == "__main__":
    # Data: toy stand-ins so the snippet runs end to end; replace them with
    # your own tensors.
    # X: torch.Tensor of size (B, C, H, W, (D))
    X = torch.randn(2, 1, 16, 16)
    # One-hot encoded labels, with N = number of classes
    # Y: torch.Tensor of size (B, N, H, W, (D))
    n_classes = 3
    Y = torch.nn.functional.one_hot(
        torch.randint(0, n_classes, (2, 16, 16)), n_classes
    ).permute(0, 3, 1, 2).float()
    # Toy segmentation model mapping C input channels to N logit channels;
    # replace with your own network.
    model = torch.nn.Conv2d(1, n_classes, kernel_size=3, padding=1)

    # 2D version of SCNP, with a neighborhood size of 3x3
    scnp = SCNP("2D", 3)
    # 3D version of SCNP, with a neighborhood size of 5x5x5
    # scnp = SCNP("3D", 5)

    z_logits = model(X)
    scnp_logits = scnp(z_logits, Y)
    # Intended usage with your own loss and activation:
    #   loss = CrossEntropyDiceLoss(SoftmaxOrSigmoid(scnp_logits), Y)
    # optionally also supervising the raw logits:
    #   loss = CrossEntropyDiceLoss(SoftmaxOrSigmoid(scnp_logits), Y) + CrossEntropyDiceLoss(SoftmaxOrSigmoid(z_logits), Y)
    # Here a plain soft-label cross-entropy stands in for that loss.
    loss = -(Y * torch.log_softmax(scnp_logits, dim=1)).sum(dim=1).mean()
    loss.backward()