-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathepi_irv2_augment_l1.py
More file actions
64 lines (56 loc) · 2.28 KB
/
epi_irv2_augment_l1.py
File metadata and controls
64 lines (56 loc) · 2.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os

from pandas import DataFrame as df
from tensorflow.keras import preprocessing as preproc
from tensorflow.keras.applications import InceptionResNetV2
from tensorflow.keras.callbacks import ModelCheckpoint, TerminateOnNaN
from tensorflow.keras.layers import Input, GlobalAveragePooling2D, Dense
from tensorflow.keras.metrics import Accuracy, AUC, Precision, Recall, SpecificityAtSensitivity
from tensorflow.keras.models import Sequential as seq, load_model as load
from tensorflow.keras.optimizers import Adadelta
# Fraction of each class directory held out for validation. Both generators
# below MUST use the same value: the Keras split is taken from the sorted file
# list, so equal splits give the same, disjoint training/validation subsets.
_VALIDATION_SPLIT = 0.18

# Training-time generator: rescales to [0, 1] and applies random augmentation.
datagen = preproc.image.ImageDataGenerator(
    validation_split=_VALIDATION_SPLIT,
    rescale=1. / 255,
    # brightness_range holds multiplicative factors (1.0 = unchanged). The
    # previous [25.5, 65.5] multiplied pixels by 25-65x, saturating images to
    # near-white. +/-10% is assumed here -- tune as needed.
    brightness_range=[0.9, 1.1],
    #shear_range=0.3,
    zoom_range=0.2,
    #horizontal_flip=True,
)

# Validation generator: rescaling only. Augmenting the validation subset would
# make validation metrics random and unrepresentative of the raw images.
val_datagen = preproc.image.ImageDataGenerator(
    validation_split=_VALIDATION_SPLIT,
    rescale=1. / 255,
)

# Iterator options shared by both subsets; only the generator, the subset and
# shuffling differ between them.
_flow_options = dict(
    directory="data/plot_epi/train/",
    class_mode="categorical",
    color_mode="rgb",
    target_size=(299, 299),
    interpolation="bilinear",
    seed=42,
)
train = datagen.flow_from_directory(
    subset="training", shuffle=True, **_flow_options)
# Validation batches keep a fixed order (shuffle=False).
val = val_datagen.flow_from_directory(
    subset="validation", shuffle=False, **_flow_options)
# EPI classifier: ImageNet-pretrained InceptionResNetV2 backbone without its
# classification head, followed by global average pooling and a 2-way softmax
# whose kernel carries Keras' default L1 penalty (regularizer string "l1").
# The backbone is not frozen, so all of its weights are trained as well.
# NOTE(review): the ImageNet weights expect inputs scaled to [-1, 1]
# (inception_resnet_v2.preprocess_input), while this script rescales pixels to
# [0, 1] -- confirm that mismatch is intended.
_image_shape = (299, 299, 3)
_classifier_layers = [
    Input(shape=_image_shape),
    InceptionResNetV2(
        include_top=False,
        weights="imagenet",
        input_shape=_image_shape,
    ),
    GlobalAveragePooling2D(),
    Dense(2, activation="softmax", kernel_regularizer="l1"),
]
epi_InceptionResNetV2_model = seq(_classifier_layers, name="EPI_InceptionResNetV2")
# Optimise categorical cross-entropy with Adadelta at an explicit lr of 1e-2.
_optimizer = Adadelta(learning_rate=1e-2)
# Tracked metrics: accuracy, precision/recall at a 0.51 decision threshold,
# specificity at 50% sensitivity, and ROC AUC.
# NOTE(review): with one-hot 2-column softmax outputs and no class_id,
# Precision/Recall are pooled over both columns and so largely mirror
# accuracy; consider class_id=<positive class index> once confirmed.
_metrics = [
    "acc",
    Precision(thresholds=.51),
    Recall(thresholds=.51),
    SpecificityAtSensitivity(sensitivity=.5),
    AUC(),
]
epi_InceptionResNetV2_model.compile(
    loss="categorical_crossentropy",
    optimizer=_optimizer,
    metrics=_metrics,
)
# Checkpoint file pattern; epoch number and validation accuracy are filled in
# by ModelCheckpoint when it writes a file.
_checkpoint_pattern = "ckpt/checkpoint-augment-l1-inceptionresnetv2-epi-{epoch:02d}-{val_acc:.3f}.h5"

# A checkpoint is written only when val_acc improves; since the filename varies
# per epoch, every improvement produces a new file rather than overwriting.
_best_checkpoint = ModelCheckpoint(
    filepath=_checkpoint_pattern,
    monitor="val_acc",
    mode="max",
    save_best_only=True,
)
# Stop training as soon as the loss becomes NaN.
_nan_guard = TerminateOnNaN()

callbacks = [_best_checkpoint, _nan_guard]
# Create every output directory BEFORE the long training run: a missing
# folder would otherwise only raise once all epochs have finished (pandas'
# to_csv, for one, does not create parent directories), losing the results.
for _out_dir in ("ckpt", "model", "result"):
    os.makedirs(_out_dir, exist_ok=True)

# Train for 30 epochs, validating on `val` after each epoch.
epi_InceptionResNetV2_model_result = epi_InceptionResNetV2_model.fit(
    x=train, validation_data=val, epochs=30, callbacks=callbacks)

# Persist the final (last-epoch) model and weights; this is not necessarily
# the best-scoring epoch.
epi_InceptionResNetV2_model.save("model/augment_l1_epi_InceptionResNetV2_model.h5")
epi_InceptionResNetV2_model.save_weights("model/augment_l1_epi_InceptionResNetV2_weights.h5")

# Per-epoch training history, one CSV column per logged loss/metric.
df.from_dict(epi_InceptionResNetV2_model_result.history).to_csv(
    'result/augment_l1_epi_InceptionResNetV2_model_result.csv', index=False)