Skip to content

Commit c9bd609

Browse files
committed
EffNet Spectre + Dual added
1 parent 3cdcd83 commit c9bd609

5 files changed

Lines changed: 385 additions & 28 deletions

File tree

classification_models_1D/models/efficientnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ def round_repeats(repeats):
372372
x = layers.GlobalAveragePooling1D(name='avg_pool')(x)
373373
if dropout_rate > 0:
374374
x = layers.Dropout(dropout_rate, name='top_dropout')(x)
375-
imagenet_utils.validate_activation(classifier_activation, weights)
375+
# imagenet_utils.validate_activation(classifier_activation, weights)
376376
x = layers.Dense(
377377
classes,
378378
activation=classifier_activation,
Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
# pylint: disable=invalid-name
16+
# pylint: disable=missing-docstring
17+
"""EfficientNet models for Keras.
18+
19+
Reference:
20+
- [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](
21+
https://arxiv.org/abs/1905.11946) (ICML 2019)
22+
"""
23+
24+
from .. import get_submodules_from_kwargs
25+
from ..weights import load_model_weights
26+
import tensorflow.compat.v2 as tf
27+
28+
import os
29+
import copy
30+
import math
31+
32+
from keras import backend
33+
from keras.applications import imagenet_utils
34+
from keras.applications.efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, \
35+
EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7
36+
from kapre.composed import get_perfectly_reconstructing_stft_istft
37+
from kapre import Magnitude, MagnitudeToDecibel
38+
39+
from keras.engine import training
40+
from keras.layers import VersionAwareLayers
41+
from keras.utils import data_utils
42+
from keras.utils import layer_utils
43+
from tensorflow.python.util.tf_export import keras_export
44+
45+
46+
# Keras submodule handles. EfficientNet_dual rebinds all four on every call
# from the result of get_submodules_from_kwargs(kwargs).
backend = layers = models = keras_utils = None

# Until the first EfficientNet_dual call, `layers` is the version-aware Keras
# layers namespace.
layers = VersionAwareLayers()
52+
53+
# Serialized initializer config passed as `kernel_initializer` to the final
# `predictions` Dense layer built in EfficientNet_dual.
DENSE_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {'scale': 1.0 / 3.0, 'mode': 'fan_out', 'distribution': 'uniform'},
}
61+
62+
# Docstring template shared by the EfficientNetB*_dual builders; `name` is
# filled in with str.format, so the text must not contain any other braces.
BASE_DOCSTRING = """Instantiates the {name} architecture.

  Reference:
  - [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](
      https://arxiv.org/abs/1905.11946) (ICML 2019)

  The model applies two EfficientNet backbones of the same size to one shared
  input: the 1D variant from `.efficientnet` and the "spectre" variant from
  `.efficientnet_spectre`. Their outputs are concatenated and, optionally,
  followed by a dropout layer and a dense classification layer.

  Both backbones are always created with `include_top=False` and
  `weights='audioset'`.

  All keyword arguments are passed on to `EfficientNet_dual`.

  Args:
    include_top: Whether to add the dense classification layer (and the
      optional dropout layer) on top of the concatenated features.
      Defaults to True.
    weights: Defaults to 'imagenet'. This value is not forwarded to the
      backbones, which are always created with `weights='audioset'`.
    input_shape: Shape tuple used for the shared `Input` layer and passed to
      both backbones.
    pooling: Pooling mode passed unchanged to both backbones.
      Defaults to None.
    classes: Number of units of the dense classification layer. Only used
      when `include_top` is True. Defaults to 527.
    win_length: Defaults to 2048. Accepted by `EfficientNet_dual` but not
      referenced in its body.
    hop_length: Defaults to 1024. Accepted by `EfficientNet_dual` but not
      referenced in its body.
    n_fft: Defaults to 1024. Accepted by `EfficientNet_dual` but not
      referenced in its body.
    align_32: Defaults to False. Accepted by `EfficientNet_dual` but not
      referenced in its body.
    dropout_val: Dropout rate applied before the classification layer when
      `include_top` is True and the rate is greater than 0. Defaults to 0.0.
    classifier_activation: A `str` or callable. The activation function of
      the classification layer. Ignored unless `include_top=True`.
      Defaults to 'softmax'.
    **kwargs: Passed to `get_submodules_from_kwargs` and forwarded to both
      backbones.

  Returns:
    A Keras `Model` instance.
"""
123+
124+
125+
def EfficientNet_dual(
        type=0,
        model_name='efficientnet',
        include_top=True,
        weights='imagenet',
        input_shape=None,
        pooling=None,
        classes=527,
        win_length=2048,
        hop_length=1024,
        n_fft=1024,
        align_32=False,
        dropout_val=0.0,
        classifier_activation='softmax',
        **kwargs
):
    """Build a model that joins a 1D and a "spectre" EfficientNet backbone.

    Both backbones receive the same input tensor. Their outputs are
    concatenated and, when `include_top` is true, followed by an optional
    dropout layer and a dense `predictions` layer.

    Args:
        type: Index (0-7 for B0-B7) into the backbone lists. The name shadows
            the builtin but is kept because callers pass it by keyword.
        model_name: Name given to the returned model.
        include_top: Whether to add the dropout/dense head.
        weights: Kept for interface compatibility and not used here; both
            backbones are always created with `weights='audioset'`.
        input_shape: Shape passed to `layers.Input` and to both backbones.
        pooling: Passed unchanged to both backbones.
        classes: Number of units of the `predictions` layer.
        win_length: Not used in this function.
        hop_length: Not used in this function.
        n_fft: Not used in this function.
        align_32: Not used in this function.
        dropout_val: Dropout rate of the head; applied only when > 0.
        classifier_activation: Activation of the `predictions` layer.
        **kwargs: Passed to `get_submodules_from_kwargs` and forwarded to
            both backbone constructors.

    Returns:
        A `models.Model` mapping the shared input to the concatenated
        features, or to the class predictions when `include_top` is true.
    """
    global backend, layers, models, keras_utils
    # Local imports: the 1D builders imported here take precedence, inside
    # this function, over the same-named keras.applications builders imported
    # at module level.
    from .efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, \
        EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7
    from .efficientnet_spectre import EfficientNetB0_spectre, EfficientNetB1_spectre, EfficientNetB2_spectre, \
        EfficientNetB3_spectre, EfficientNetB4_spectre, EfficientNetB5_spectre, EfficientNetB6_spectre, \
        EfficientNetB7_spectre

    backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)

    inp = layers.Input(input_shape)

    effnet_1D = [EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3,
                 EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7]
    effnet_2D = [EfficientNetB0_spectre, EfficientNetB1_spectre, EfficientNetB2_spectre, EfficientNetB3_spectre,
                 EfficientNetB4_spectre, EfficientNetB5_spectre, EfficientNetB6_spectre, EfficientNetB7_spectre]

    # NOTE(review): win_length / hop_length / n_fft / align_32 are not passed
    # to the spectre backbone, so it runs with whatever it takes from kwargs
    # or its own defaults - confirm whether they should be forwarded.
    x1 = effnet_1D[type](
        include_top=False,
        weights='audioset',
        input_shape=input_shape,
        pooling=pooling,
        **kwargs
    )(inp)

    x2 = effnet_2D[type](
        include_top=False,
        weights='audioset',
        input_shape=input_shape,
        pooling=pooling,
        **kwargs
    )(inp)

    # NOTE(review): the two outputs are joined along the default axis, so
    # they presumably need compatible shapes (e.g. one pooled feature vector
    # each) - confirm which `pooling` values are valid here.
    x = layers.concatenate([x1, x2])

    if include_top:
        if dropout_val > 0:
            x = layers.Dropout(dropout_val, name='top_dropout')(x)
        # No imagenet_utils.validate_activation() call: with the default
        # weights='imagenet' it only admits softmax/None as the head
        # activation, although the backbones load 'audioset' weights. The 1D
        # efficientnet module has this check disabled as well.
        x = layers.Dense(
            classes,
            activation=classifier_activation,
            kernel_initializer=DENSE_KERNEL_INITIALIZER,
            name='predictions'
        )(x)

    model = models.Model(inputs=inp, outputs=x, name=model_name)
    return model
188+
189+
190+
def EfficientNetB0_dual(**kwargs):
    # Index 0 selects the B0 entries of the backbone lists in EfficientNet_dual.
    return EfficientNet_dual(type=0, model_name='EfficientNetB0_dual', **kwargs)
198+
199+
200+
def EfficientNetB1_dual(**kwargs):
    # Index 1 selects the B1 entries of the backbone lists in EfficientNet_dual.
    return EfficientNet_dual(type=1, model_name='EfficientNetB1_dual', **kwargs)
208+
209+
210+
def EfficientNetB2_dual(**kwargs):
    # Index 2 selects the B2 entries of the backbone lists in EfficientNet_dual.
    return EfficientNet_dual(type=2, model_name='EfficientNetB2_dual', **kwargs)
218+
219+
220+
def EfficientNetB3_dual(**kwargs):
    # Index 3 selects the B3 entries of the backbone lists in EfficientNet_dual.
    return EfficientNet_dual(type=3, model_name='EfficientNetB3_dual', **kwargs)
228+
229+
230+
def EfficientNetB4_dual(**kwargs):
    # Index 4 selects the B4 entries of the backbone lists in EfficientNet_dual.
    return EfficientNet_dual(type=4, model_name='EfficientNetB4_dual', **kwargs)
238+
239+
240+
def EfficientNetB5_dual(**kwargs):
    # Index 5 selects the B5 entries of the backbone lists in EfficientNet_dual.
    return EfficientNet_dual(type=5, model_name='EfficientNetB5_dual', **kwargs)
248+
249+
250+
def EfficientNetB6_dual(**kwargs):
    # Index 6 selects the B6 entries of the backbone lists in EfficientNet_dual.
    return EfficientNet_dual(type=6, model_name='EfficientNetB6_dual', **kwargs)
258+
259+
def EfficientNetB7_dual(**kwargs):
    # Index 7 selects the B7 entries of the backbone lists in EfficientNet_dual.
    return EfficientNet_dual(type=7, model_name='EfficientNetB7_dual', **kwargs)
267+
268+
269+
# Give every dual builder the shared docstring, filled in with its own name.
for _builder in (
        EfficientNetB0_dual,
        EfficientNetB1_dual,
        EfficientNetB2_dual,
        EfficientNetB3_dual,
        EfficientNetB4_dual,
        EfficientNetB5_dual,
        EfficientNetB6_dual,
        EfficientNetB7_dual,
):
    _builder.__doc__ = BASE_DOCSTRING.format(name=_builder.__name__)
del _builder  # do not leave the loop variable behind as a module attribute
277+
278+
279+
280+
def preprocess_input(x, data_format=None, **kwargs):  # pylint: disable=unused-argument
    """Return `x` unchanged.

    Placeholder kept so that this module exposes a `preprocess_input`
    function; it performs no normalization of the input.

    Args:
        x: The input data. It is returned as-is, without being copied.
        data_format: Ignored.
        **kwargs: Ignored.

    Returns:
        The same object that was passed in as `x`.
    """
    return x
299+
300+
301+
def decode_predictions(preds, top=5, **kwargs):
    # Extra keyword arguments are accepted and ignored.
    # NOTE(review): this delegates to the ImageNet decoder, while the dual
    # model defaults to 527 classes - confirm it is usable for these outputs.
    decoded = imagenet_utils.decode_predictions(preds, top=top)
    return decoded


# Reuse the docstring of the Keras helper that does the actual work.
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

0 commit comments

Comments
 (0)