Skip to content

Commit 44ac936

Browse files
committed
added weights
1 parent ecd2546 commit 44ac936

3 files changed

Lines changed: 152 additions & 0 deletions

File tree

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
5+
# -----------------------------
6+
# Residual Conv Block
7+
# -----------------------------
8+
9+
class ResidualBlock(nn.Module):
10+
11+
def __init__(self, in_ch, out_ch):
12+
13+
super().__init__()
14+
15+
self.conv1 = nn.Conv3d(in_ch, out_ch, 3, padding=1)
16+
self.norm1 = nn.InstanceNorm3d(out_ch)
17+
self.relu = nn.LeakyReLU(0.01, inplace=True)
18+
19+
self.conv2 = nn.Conv3d(out_ch, out_ch, 3, padding=1)
20+
self.norm2 = nn.InstanceNorm3d(out_ch)
21+
22+
self.skip = None
23+
24+
if in_ch != out_ch:
25+
self.skip = nn.Conv3d(in_ch, out_ch, 1)
26+
27+
def forward(self, x):
28+
29+
identity = x
30+
31+
out = self.conv1(x)
32+
out = self.norm1(out)
33+
out = self.relu(out)
34+
35+
out = self.conv2(out)
36+
out = self.norm2(out)
37+
38+
if self.skip is not None:
39+
identity = self.skip(identity)
40+
41+
out += identity
42+
out = self.relu(out)
43+
44+
return out
45+
46+
47+
# -----------------------------
48+
# Encoder Block
49+
# -----------------------------
50+
51+
class EncoderBlock(nn.Module):
52+
53+
def __init__(self, in_ch, out_ch):
54+
55+
super().__init__()
56+
57+
self.block = ResidualBlock(in_ch, out_ch)
58+
self.pool = nn.MaxPool3d(2)
59+
60+
def forward(self, x):
61+
62+
x = self.block(x)
63+
p = self.pool(x)
64+
65+
return x, p
66+
67+
68+
# -----------------------------
69+
# Decoder Block
70+
# -----------------------------
71+
72+
class DecoderBlock(nn.Module):
73+
74+
def __init__(self, in_ch, out_ch):
75+
76+
super().__init__()
77+
78+
self.up = nn.ConvTranspose3d(in_ch, out_ch, 2, stride=2)
79+
80+
self.block = ResidualBlock(in_ch, out_ch)
81+
82+
def forward(self, x, skip):
83+
84+
x = self.up(x)
85+
86+
x = torch.cat([x, skip], dim=1)
87+
88+
x = self.block(x)
89+
90+
return x
91+
92+
93+
# -----------------------------
94+
# UNet
95+
# -----------------------------
96+
97+
class UNet3D(nn.Module):
98+
99+
def __init__(self, in_channels=4, out_channels=1):
100+
101+
super().__init__()
102+
103+
# Encoder
104+
self.enc1 = EncoderBlock(in_channels, 32)
105+
self.enc2 = EncoderBlock(32, 64)
106+
self.enc3 = EncoderBlock(64, 128)
107+
self.enc4 = EncoderBlock(128, 256)
108+
109+
# Bottleneck
110+
self.bottleneck = nn.Sequential(
111+
ResidualBlock(256, 512),
112+
nn.Dropout3d(0.2)
113+
)
114+
115+
# Decoder
116+
self.dec4 = DecoderBlock(512, 256)
117+
self.dec3 = DecoderBlock(256, 128)
118+
self.dec2 = DecoderBlock(128, 64)
119+
self.dec1 = DecoderBlock(64, 32)
120+
121+
# Output
122+
self.out_conv = nn.Conv3d(32, out_channels, 1)
123+
124+
def forward(self, x):
125+
126+
s1, p1 = self.enc1(x)
127+
s2, p2 = self.enc2(p1)
128+
s3, p3 = self.enc3(p2)
129+
s4, p4 = self.enc4(p3)
130+
131+
b = self.bottleneck(p4)
132+
133+
d4 = self.dec4(b, s4)
134+
d3 = self.dec3(d4, s3)
135+
d2 = self.dec2(d3, s2)
136+
d1 = self.dec1(d2, s1)
137+
138+
out = self.out_conv(d1)
139+
140+
return out
141+
142+
143+
# -----------------------------
144+
# Builder
145+
# -----------------------------
146+
147+
def build_model():
148+
149+
return UNet3D(
150+
in_channels=1, # only PET as input
151+
out_channels=1
152+
)
File renamed without changes.
87.5 MB
Binary file not shown.

0 commit comments

Comments
 (0)