forked from multimediaeval/2022-Medico-Multimedia
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
114 lines (88 loc) · 3.61 KB
/
model.py
File metadata and controls
114 lines (88 loc) · 3.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.layers import AveragePooling2D, GlobalAveragePooling2D, UpSampling2D, Reshape, Dense, LayerNormalization, Dropout, Attention
def SqueezeAndExcite(inputs, ratio=8):
    """Reweight the channels of a feature map with a squeeze-and-excitation gate.

    The input is globally average-pooled to one value per channel. That vector
    goes through a bottleneck MLP: Dense(filters // ratio, ReLU), then
    Dense(filters, sigmoid). The resulting per-channel gates in (0, 1) then
    multiply ``inputs``.

    Args:
        inputs: 4-D feature map tensor; its last axis is treated as channels
            and must be statically known.
        ratio: Channel reduction factor of the bottleneck layer.

    Returns:
        Tensor with the same shape as ``inputs``, scaled channel-wise.
    """
    filters = inputs.shape[-1]
    # When filters < ratio, integer division gives a 0-unit Dense layer.
    # Recent Keras rejects that, and otherwise it degenerates into constant
    # sigmoid(0) = 0.5 gates. Keep at least one bottleneck unit.
    bottleneck = max(1, filters // ratio)

    se = GlobalAveragePooling2D()(inputs)
    se = Reshape((1, 1, filters))(se)
    se = Dense(bottleneck, activation="relu", kernel_initializer="he_normal", use_bias=False)(se)
    se = Dense(filters, activation="sigmoid", kernel_initializer="he_normal", use_bias=False)(se)
    # The (1, 1, C) gates broadcast over the spatial axes of ``inputs``.
    return inputs * se
def ASPP(inputs):
    """Atrous Spatial Pyramid Pooling head with self-attention on every branch.

    Five parallel branches are computed from ``inputs``:

    1. Image pooling: average pool over the full spatial extent, then
       1x1 conv, BN and ReLU, then bilinear upsampling back to the input size.
    2. 1x1 conv, BN, ReLU.
    3-5. 3x3 atrous conv with dilation 6, 12 and 18, each followed by BN and ReLU.

    Every branch then goes through a Keras ``Attention`` layer used as
    self-attention (query and value are the same tensor). The branches are
    concatenated and fused with a 1x1 conv, BN and ReLU.

    Args:
        inputs: 4-D feature map. ``inputs.shape[1]`` and ``inputs.shape[2]`` are
            used as pool/upsample sizes, so the spatial dims must be static.

    Returns:
        A 256-channel feature map with the same spatial size as ``inputs``.
    """
    height, width = inputs.shape[1], inputs.shape[2]

    def _conv_bn_relu(tensor, kernel_size, dilation_rate=1):
        # Conv2D(256, bias-free, 'same' padding) -> BatchNorm -> ReLU.
        tensor = Conv2D(
            256,
            kernel_size,
            padding="same",
            use_bias=False,
            dilation_rate=dilation_rate,
        )(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation("relu")(tensor)

    # NOTE(review): keras Attention is documented for 3-D [batch, T, dim]
    # inputs. On these 4-D maps the batched matmul presumably attends along
    # the width axis within each row; confirm this is intended.

    # Image-pooling branch: collapse to 1x1, project, then broadcast back.
    pooled = AveragePooling2D(pool_size=(height, width))(inputs)
    pooled = _conv_bn_relu(pooled, 1)
    pooled = UpSampling2D((height, width), interpolation="bilinear")(pooled)
    branches = [Attention()([pooled, pooled])]

    # The 1x1 branch, then the three dilated 3x3 branches, in that order.
    for kernel_size, rate in ((1, 1), (3, 6), (3, 12), (3, 18)):
        feat = _conv_bn_relu(inputs, kernel_size, rate)
        branches.append(Attention()([feat, feat]))

    fused = Concatenate()(branches)
    return _conv_bn_relu(fused, 1)
def build_DLV3SA(shape):
    """Build a DeepLabV3+-style single-channel segmentation model.

    The model adds squeeze-and-excitation and attention to the DeepLabV3+ layout.

    Encoder:
        ImageNet-pretrained ResNet50 (no top) wired to ``Input(shape)``.

    High-level branch:
        ``conv4_block6_out`` goes through ``ASPP`` and is upsampled 4x (bilinear).

    Low-level branch:
        ``conv2_block2_out`` is projected to 48 channels (1x1 conv, BN, ReLU).

    Decoder:
        1. Concatenate both branches and apply squeeze-and-excite.
        2. Two blocks of 3x3 conv(256), BN, ReLU.
        3. Squeeze-and-excite again.
        4. 4x bilinear upsampling, 1x1 conv to one channel, sigmoid.

    Args:
        shape: Input shape without the batch dimension, e.g. ``(256, 256, 3)``.

    Returns:
        An uncompiled ``keras.Model`` mapping images to a one-channel sigmoid map.
    """
    # Imported here so the builder works without a module-level ResNet50
    # import; the original referenced ResNet50 without importing it (NameError).
    from tensorflow.keras.applications import ResNet50

    inputs = Input(shape)

    # Encoder.
    encoder = ResNet50(weights="imagenet", include_top=False, input_tensor=inputs)

    # High-level branch.
    # NOTE(review): the 4x upsampling assumes conv4_block6_out is 4x coarser
    # than conv2_block2_out, so Concatenate lines up. Presumably input H/W
    # must be divisible by 16; confirm for other sizes.
    image_features = encoder.get_layer("conv4_block6_out").output
    x_a = ASPP(image_features)
    x_a = UpSampling2D((4, 4), interpolation="bilinear")(x_a)

    # Low-level branch.
    x_b = encoder.get_layer("conv2_block2_out").output
    x_b = Conv2D(filters=48, kernel_size=1, padding="same", use_bias=False)(x_b)
    x_b = BatchNormalization()(x_b)
    x_b = Activation("relu")(x_b)

    # Decoder.
    x = Concatenate()([x_a, x_b])
    x = SqueezeAndExcite(x)
    for _ in range(2):
        x = Conv2D(filters=256, kernel_size=3, padding="same", use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    x = SqueezeAndExcite(x)

    # Head: restore input resolution and emit a single sigmoid channel.
    x = UpSampling2D((4, 4), interpolation="bilinear")(x)
    x = Conv2D(1, 1)(x)
    x = Activation("sigmoid")(x)

    return Model(inputs, x)
if __name__ == "__main__":
    # Smoke test: build the network for 256x256 RGB inputs.
    model = build_DLV3SA(shape=(256, 256, 3))