|
13 | 13 |
|
14 | 14 | if has_tf: |
15 | 15 | from tensorflow.keras.models import Model |
16 | | - from tensorflow.keras.layers import Dense, Flatten, Dropout, Input |
| 16 | + from tensorflow.keras.layers import Dense, Flatten, Dropout, Input, Reshape |
17 | 17 | from tensorflow.keras.layers import MaxPooling1D, Conv1D |
18 | 18 | from tensorflow.keras.layers import LSTM, Bidirectional |
19 | 19 | from tensorflow.keras.layers import ( |
@@ -309,22 +309,23 @@ def cnn_resnet(input_shape, n_classes, output_activation="softmax", verbose=Fals |
309 | 309 | return model |
310 | 310 |
|
311 | 311 |
|
312 | | -def autoencoder(input_shape, n_hidden=None, verbose=False): |
| 312 | +def autoencoder(input_shape, encoding_dim=None, verbose=False): |
313 | 313 | """ |
314 | 314 | Fully connected autoencoder |
315 | 315 | """ |
316 | 316 | ip = Input(shape=input_shape) |
317 | | - if n_hidden is None: |
318 | | - n_hidden = input_shape[0] // 8 |
| 317 | + if encoding_dim is None: |
| 318 | + encoding_dim = (input_shape[0] * input_shape[1]) // 8 |
319 | 319 |
|
320 | 320 | fc = Flatten()(ip) |
321 | | - fc = Dense(input_shape, activation="relu")(fc) |
322 | | - fc = Dense(input_shape // 2, activation="relu")(fc) |
323 | | - fc = Dense(input_shape // 4, activation="relu")(fc) |
324 | | - fc = Dense(n_hidden, activation="relu")(fc) |
325 | | - fc = Dense(input_shape // 4, activation="relu")(fc) |
326 | | - fc = Dense(input_shape // 2, activation="relu")(fc) |
327 | | - out = Dense(input_shape)(fc) |
| 321 | + fc = Dense(input_shape[0] * input_shape[1], activation="relu")(fc) |
| 322 | + fc = Dense((input_shape[0] * input_shape[1]) // 2, activation="relu")(fc) |
| 323 | + fc = Dense((input_shape[0] * input_shape[1]) // 4, activation="relu")(fc) |
| 324 | + fc = Dense(encoding_dim, activation="relu")(fc) |
| 325 | + fc = Dense((input_shape[0] * input_shape[1]) // 4, activation="relu")(fc) |
| 326 | + fc = Dense((input_shape[0] * input_shape[1]) // 2, activation="relu")(fc) |
| 327 | + pre_out = Dense((input_shape[0] * input_shape[1]))(fc) |
| 328 | + out = Reshape(input_shape)(pre_out) |
328 | 329 |
|
329 | 330 | model = Model([ip], [out]) |
330 | 331 | if verbose: |
|
0 commit comments