|
| 1 | + |
| 2 | +from tensorflow.keras.models import Model |
| 3 | +from tensorflow.keras.layers import Conv2D, MaxPool2D, Dense, Concatenate, UpSampling2D, Input, Activation, add, \ |
| 4 | + BatchNormalization, Dropout, Softmax, LeakyReLU |
| 5 | + |
| 6 | + |
| 7 | +# Create a context_module |
| 8 | +def context_module(input_image, filters, kernel_size=(3, 3), padding="same", strides=1): |
| 9 | + block = Conv2D(filters, kernel_size, strides, padding, activation=LeakyReLU(alpha=0.02))(input_image) |
| 10 | + block = BatchNormalization()(block) |
| 11 | + block = Conv2D(filters, kernel_size, strides, padding, activation=LeakyReLU(alpha=0.02))(block) |
| 12 | + block = BatchNormalization()(block) |
| 13 | + block = Dropout(rate=0.3)(block) |
| 14 | + return block |
| 15 | + |
| 16 | + |
| 17 | +# Create a localization_module |
| 18 | +def localization_module(input_image, filters): |
| 19 | + block = Conv2D(filters, kernel_size=(3, 3), padding="same", strides=1, activation=LeakyReLU(alpha=0.02))(input_image) |
| 20 | + block = BatchNormalization()(block) |
| 21 | + block = Conv2D(filters, kernel_size=(1, 1), padding="same", strides=1, activation=LeakyReLU(alpha=0.02))(block) |
| 22 | + block = BatchNormalization()(block) |
| 23 | + return block |
| 24 | + |
| 25 | + |
| 26 | +# Create an upsampling_module |
| 27 | +def upsampling_module(input_image, filters): |
| 28 | + block = UpSampling2D((2, 2))(input_image) |
| 29 | + block = Conv2D(filters, kernel_size=(3, 3), strides=1, padding="same", activation=LeakyReLU(alpha=0.02))(block) |
| 30 | + block = BatchNormalization()(block) |
| 31 | + return block |
| 32 | + |
| 33 | + |
| 34 | +# improved_Unet |
| 35 | +def improved_Unet(input_image): |
| 36 | + # Contracting path |
| 37 | + enc1_1 = Conv2D(filters=16, kernel_size=(3, 3), padding="same", strides=1, activation=LeakyReLU(alpha=0.02))(input_image) |
| 38 | + enc1_1 = BatchNormalization()(enc1_1) |
| 39 | + enc1_2 = context_module(enc1_1, filters=16) |
| 40 | + enc1 = add([enc1_1, enc1_2]) |
| 41 | + |
| 42 | + enc2_1 = Conv2D(filters=32, kernel_size=(3, 3), padding="same", strides=2, activation=LeakyReLU(alpha=0.02))(enc1) |
| 43 | + enc2_1 = BatchNormalization()(enc2_1) |
| 44 | + enc2_2 = context_module(enc2_1, filters=32) |
| 45 | + enc2 = add([enc2_1, enc2_2]) |
| 46 | + |
| 47 | + enc3_1 = Conv2D(filters=64, kernel_size=(3, 3), padding="same", strides=2, activation=LeakyReLU(alpha=0.02))(enc2) |
| 48 | + enc3_1 = BatchNormalization()(enc3_1) |
| 49 | + enc3_2 = context_module(enc3_1, filters=64) |
| 50 | + enc3 = add([enc3_1, enc3_2]) |
| 51 | + |
| 52 | + enc4_1 = Conv2D(filters=128, kernel_size=(3, 3), padding="same", strides=2, activation=LeakyReLU(alpha=0.02))(enc3) |
| 53 | + enc4_1 = BatchNormalization()(enc4_1) |
| 54 | + enc4_2 = context_module(enc4_1, filters=128) |
| 55 | + enc4 = add([enc4_1, enc4_2]) |
| 56 | + |
| 57 | + enc5_1 = Conv2D(filters=256, kernel_size=(3, 3), padding="same", strides=2, activation=LeakyReLU(alpha=0.02))(enc4) |
| 58 | + enc5_1 = BatchNormalization()(enc5_1) |
| 59 | + enc5_2 = context_module(enc5_1, filters=256) |
| 60 | + enc5_3 = add([enc5_1, enc5_2]) |
| 61 | + enc5 = upsampling_module(enc5_3, filters=128) |
| 62 | + |
| 63 | + # Expansive path |
| 64 | + dec1_1 = Concatenate()([enc4, enc5]) |
| 65 | + dec1_2 = localization_module(dec1_1, filters=128) |
| 66 | + dec1 = upsampling_module(dec1_2, filters=64) |
| 67 | + |
| 68 | + dec2_1 = Concatenate()([enc3, dec1]) |
| 69 | + dec2_2 = localization_module(dec2_1, filters=64) |
| 70 | + dec2 = upsampling_module(dec2_2, filters=32) |
| 71 | + |
| 72 | + dec3_1 = Concatenate()([enc2_2, dec2]) |
| 73 | + dec3_2 = localization_module(dec3_1, filters=32) |
| 74 | + dec3 = upsampling_module(dec3_2, filters=16) |
| 75 | + |
| 76 | + dec4_1 = Concatenate()([enc1_2, dec3]) |
| 77 | + dec4_2 = Conv2D(filters=32, kernel_size=(3, 3), strides=1, padding="same", activation=LeakyReLU(alpha=0.02))(dec4_1) |
| 78 | + dec4_2 = BatchNormalization()(dec4_2) |
| 79 | + dec4 = Conv2D(filters=32, kernel_size=(1, 1), strides=1, padding="same", activation=LeakyReLU(alpha=0.02))(dec4_2) |
| 80 | + dec4 = BatchNormalization()(dec4) |
| 81 | + |
| 82 | + # Element-wise sum between segmentation layers |
| 83 | + seg1 = Conv2D(filters=32,kernel_size=(1, 1), strides=1, padding="same", activation=LeakyReLU(alpha=0.02))(dec2_2) |
| 84 | + seg1 = BatchNormalization()(seg1) |
| 85 | + seg1 = UpSampling2D((2, 2))(seg1) |
| 86 | + seg2 = Conv2D(filters=32,kernel_size=(1, 1), strides=1, padding="same", activation=LeakyReLU(alpha=0.02))(dec3_2) |
| 87 | + seg2 = BatchNormalization()(seg2) |
| 88 | + seg3 = add([seg1, seg2]) |
| 89 | + seg3 = UpSampling2D((2, 2))(seg3) |
| 90 | + seg4 = Concatenate()([dec4, seg3]) |
| 91 | + |
| 92 | + output = Conv2D(3, (1, 1), padding="same", activation="softmax")(seg4) |
| 93 | + model = Model(input_image, output) |
| 94 | + return model |
| 95 | + |
| 96 | + |
0 commit comments