Skip to content
This repository was archived by the owner on Feb 8, 2025. It is now read-only.

Commit d8fe72c

Browse files
committed
A bit more progress on autoencoder.
1 parent 1b1ad34 commit d8fe72c

1 file changed

Lines changed: 12 additions & 11 deletions

File tree

afqinsight/nn/tf_models.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
if has_tf:
1515
from tensorflow.keras.models import Model
16-
from tensorflow.keras.layers import Dense, Flatten, Dropout, Input
16+
from tensorflow.keras.layers import Dense, Flatten, Dropout, Input, Reshape
1717
from tensorflow.keras.layers import MaxPooling1D, Conv1D
1818
from tensorflow.keras.layers import LSTM, Bidirectional
1919
from tensorflow.keras.layers import (
@@ -309,22 +309,23 @@ def cnn_resnet(input_shape, n_classes, output_activation="softmax", verbose=Fals
309309
return model
310310

311311

312-
def autoencoder(input_shape, n_hidden=None, verbose=False):
312+
def autoencoder(input_shape, encoding_dim=None, verbose=False):
313313
"""
314314
Fully connected autoencoder
315315
"""
316316
ip = Input(shape=input_shape)
317-
if n_hidden is None:
318-
n_hidden = input_shape[0] // 8
317+
if encoding_dim is None:
318+
encoding_dim = (input_shape[0] * input_shape[1]) // 8
319319

320320
fc = Flatten()(ip)
321-
fc = Dense(input_shape, activation="relu")(fc)
322-
fc = Dense(input_shape // 2, activation="relu")(fc)
323-
fc = Dense(input_shape // 4, activation="relu")(fc)
324-
fc = Dense(n_hidden, activation="relu")(fc)
325-
fc = Dense(input_shape // 4, activation="relu")(fc)
326-
fc = Dense(input_shape // 2, activation="relu")(fc)
327-
out = Dense(input_shape)(fc)
321+
fc = Dense(input_shape[0] * input_shape[1], activation="relu")(fc)
322+
fc = Dense((input_shape[0] * input_shape[1]) // 2, activation="relu")(fc)
323+
fc = Dense((input_shape[0] * input_shape[1]) // 4, activation="relu")(fc)
324+
fc = Dense(encoding_dim, activation="relu")(fc)
325+
fc = Dense((input_shape[0] * input_shape[1]) // 4, activation="relu")(fc)
326+
fc = Dense((input_shape[0] * input_shape[1]) // 2, activation="relu")(fc)
327+
pre_out = Dense((input_shape[0] * input_shape[1]))(fc)
328+
out = Reshape(input_shape)(pre_out)
328329

329330
model = Model([ip], [out])
330331
if verbose:

0 commit comments

Comments
 (0)