Skip to content
This repository was archived by the owner on Feb 8, 2025. It is now read-only.

Commit 3ea64a1

Browse files
committed
More work on autoencoder.
Also, use groupyr 0.3.3
1 parent 9522c27 commit 3ea64a1

2 files changed

Lines changed: 21 additions & 11 deletions

File tree

afqinsight/nn/tf_models.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -336,21 +336,25 @@ def fc_autoencoder(input_shape, encoding_dim=None, verbose=False):
336336
return model
337337

338338

339-
def cnn_autoencoder(input_shape, verbose=False):
339+
def cnn_autoencoder(input_shape, encoding_dim=8, verbose=False):
340340
"""
341341
Convolutional autoencoder
342342
"""
343343
ip = Input(shape=input_shape)
344344
# Encoder
345-
x = Conv1D(32, (3), activation="relu", padding="same")(ip)
346-
x = MaxPooling1D((2), padding="same")(x)
347-
x = Conv1D(32, (3), activation="relu", padding="same")(x)
348-
x = MaxPooling1D((2), padding="same")(x)
349-
345+
x = Conv1D(32, 3, activation="relu", padding="same")(ip)
346+
x = MaxPooling1D(2, padding="same")(x)
347+
x = Conv1D(16, 3, activation="relu", padding="same")(x)
348+
x = MaxPooling1D(2, padding="same")(x)
349+
shape = x.shape
350+
# Latent
351+
x = Flatten()(x)
352+
x = Dense(encoding_dim, activation="relu")(x)
350353
# Decoder
351-
x = Conv1DTranspose(32, (3), strides=2, activation="relu", padding="same")(x)
352-
x = Conv1DTranspose(32, (3), strides=2, activation="relu", padding="same")(x)
353-
x = Conv1D(1, (3), activation="sigmoid", padding="same")(x)
354+
x = Reshape(shape)(x)
355+
x = Conv1DTranspose(32, 3, strides=2, activation="relu", padding="same")(x)
356+
x = Conv1DTranspose(16, 3, strides=2, activation="relu", padding="same")(x)
357+
x = Conv1DTranspose(1, 3, activation="sigmoid", padding="same")(x)
354358

355359
model = Model([ip], [x])
356360
if verbose:
@@ -401,7 +405,8 @@ def _fc_vae_decoder(input_shape, encoding_dim=None, verbose=False):
401405
fc = Dense((input_shape[0] * input_shape[1]) // 4, activation="relu")(fc)
402406
fc = Dense((input_shape[0] * input_shape[1]) // 2, activation="relu")(fc)
403407
pre_out = Dense((input_shape[0] * input_shape[1]))(fc)
404-
return Reshape(input_shape)(pre_out)
408+
out = Reshape(input_shape)(pre_out)
409+
return Model([ip], [out], name="decoder")
405410

406411

407412
class _VAE(Model):
@@ -427,6 +432,11 @@ def metrics(self):
427432
self.kl_loss_tracker,
428433
]
429434

435+
def call(self, inputs):
436+
z_mean, z_log_var, z = self.encoder(inputs)
437+
reconstructed = self.decoder(z)
438+
return reconstructed
439+
430440
def train_step(self, data):
431441
with tf.GradientTape() as tape:
432442
z_mean, z_log_var, z = self.encoder(data)

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ setup_requires =
3333
python_requires = >=3.10
3434
install_requires =
3535
dipy>=1.0.0
36-
groupyr>=0.3.2
36+
groupyr>=0.3.3
3737
matplotlib
3838
numpy<2
3939
pandas==2.1.4

0 commit comments

Comments
 (0)