1+ import torch
2+ import torch .nn as nn
3+
4+
5+ # -----------------------------
6+ # Residual Conv Block
7+ # -----------------------------
8+
9+ class ResidualBlock (nn .Module ):
10+
11+ def __init__ (self , in_ch , out_ch ):
12+
13+ super ().__init__ ()
14+
15+ self .conv1 = nn .Conv3d (in_ch , out_ch , 3 , padding = 1 )
16+ self .norm1 = nn .InstanceNorm3d (out_ch )
17+ self .relu = nn .LeakyReLU (0.01 , inplace = True )
18+
19+ self .conv2 = nn .Conv3d (out_ch , out_ch , 3 , padding = 1 )
20+ self .norm2 = nn .InstanceNorm3d (out_ch )
21+
22+ self .skip = None
23+
24+ if in_ch != out_ch :
25+ self .skip = nn .Conv3d (in_ch , out_ch , 1 )
26+
27+ def forward (self , x ):
28+
29+ identity = x
30+
31+ out = self .conv1 (x )
32+ out = self .norm1 (out )
33+ out = self .relu (out )
34+
35+ out = self .conv2 (out )
36+ out = self .norm2 (out )
37+
38+ if self .skip is not None :
39+ identity = self .skip (identity )
40+
41+ out += identity
42+ out = self .relu (out )
43+
44+ return out
45+
46+
47+ # -----------------------------
48+ # Encoder Block
49+ # -----------------------------
50+
51+ class EncoderBlock (nn .Module ):
52+
53+ def __init__ (self , in_ch , out_ch ):
54+
55+ super ().__init__ ()
56+
57+ self .block = ResidualBlock (in_ch , out_ch )
58+ self .pool = nn .MaxPool3d (2 )
59+
60+ def forward (self , x ):
61+
62+ x = self .block (x )
63+ p = self .pool (x )
64+
65+ return x , p
66+
67+
68+ # -----------------------------
69+ # Decoder Block
70+ # -----------------------------
71+
72+ class DecoderBlock (nn .Module ):
73+
74+ def __init__ (self , in_ch , out_ch ):
75+
76+ super ().__init__ ()
77+
78+ self .up = nn .ConvTranspose3d (in_ch , out_ch , 2 , stride = 2 )
79+
80+ self .block = ResidualBlock (in_ch , out_ch )
81+
82+ def forward (self , x , skip ):
83+
84+ x = self .up (x )
85+
86+ x = torch .cat ([x , skip ], dim = 1 )
87+
88+ x = self .block (x )
89+
90+ return x
91+
92+
93+ # -----------------------------
94+ # UNet
95+ # -----------------------------
96+
97+ class UNet3D (nn .Module ):
98+
99+ def __init__ (self , in_channels = 4 , out_channels = 1 ):
100+
101+ super ().__init__ ()
102+
103+ # Encoder
104+ self .enc1 = EncoderBlock (in_channels , 32 )
105+ self .enc2 = EncoderBlock (32 , 64 )
106+ self .enc3 = EncoderBlock (64 , 128 )
107+ self .enc4 = EncoderBlock (128 , 256 )
108+
109+ # Bottleneck
110+ self .bottleneck = nn .Sequential (
111+ ResidualBlock (256 , 512 ),
112+ nn .Dropout3d (0.2 )
113+ )
114+
115+ # Decoder
116+ self .dec4 = DecoderBlock (512 , 256 )
117+ self .dec3 = DecoderBlock (256 , 128 )
118+ self .dec2 = DecoderBlock (128 , 64 )
119+ self .dec1 = DecoderBlock (64 , 32 )
120+
121+ # Output
122+ self .out_conv = nn .Conv3d (32 , out_channels , 1 )
123+
124+ def forward (self , x ):
125+
126+ s1 , p1 = self .enc1 (x )
127+ s2 , p2 = self .enc2 (p1 )
128+ s3 , p3 = self .enc3 (p2 )
129+ s4 , p4 = self .enc4 (p3 )
130+
131+ b = self .bottleneck (p4 )
132+
133+ d4 = self .dec4 (b , s4 )
134+ d3 = self .dec3 (d4 , s3 )
135+ d2 = self .dec2 (d3 , s2 )
136+ d1 = self .dec1 (d2 , s1 )
137+
138+ out = self .out_conv (d1 )
139+
140+ return out
141+
142+
143+ # -----------------------------
144+ # Builder
145+ # -----------------------------
146+
147+ def build_model ():
148+
149+ return UNet3D (
150+ in_channels = 1 , # only PET as input
151+ out_channels = 1
152+ )
0 commit comments