@@ -336,21 +336,25 @@ def fc_autoencoder(input_shape, encoding_dim=None, verbose=False):
336336 return model
337337
338338
339- def cnn_autoencoder (input_shape , verbose = False ):
339+ def cnn_autoencoder (input_shape , encoding_dim = 8 , verbose = False ):
340340 """
341341 Convolutional autoencoder
342342 """
343343 ip = Input (shape = input_shape )
344344 # Encoder
345- x = Conv1D (32 , (3 ), activation = "relu" , padding = "same" )(ip )
346- x = MaxPooling1D ((2 ), padding = "same" )(x )
347- x = Conv1D (32 , (3 ), activation = "relu" , padding = "same" )(x )
348- x = MaxPooling1D ((2 ), padding = "same" )(x )
349-
345+ x = Conv1D (32 , 3 , activation = "relu" , padding = "same" )(ip )
346+ x = MaxPooling1D (2 , padding = "same" )(x )
347+ x = Conv1D (16 , 3 , activation = "relu" , padding = "same" )(x )
348+ x = MaxPooling1D (2 , padding = "same" )(x )
349+ shape = x .shape
350+ # Latent
351+ x = Flatten ()(x )
352+ x = Dense (encoding_dim , activation = "relu" )(x )
350353 # Decoder
351- x = Conv1DTranspose (32 , (3 ), strides = 2 , activation = "relu" , padding = "same" )(x )
352- x = Conv1DTranspose (32 , (3 ), strides = 2 , activation = "relu" , padding = "same" )(x )
353- x = Conv1D (1 , (3 ), activation = "sigmoid" , padding = "same" )(x )
354+ x = Reshape (shape )(x )
355+ x = Conv1DTranspose (32 , 3 , strides = 2 , activation = "relu" , padding = "same" )(x )
356+ x = Conv1DTranspose (16 , 3 , strides = 2 , activation = "relu" , padding = "same" )(x )
357+ x = Conv1DTranspose (1 , 3 , activation = "sigmoid" , padding = "same" )(x )
354358
355359 model = Model ([ip ], [x ])
356360 if verbose :
@@ -401,7 +405,8 @@ def _fc_vae_decoder(input_shape, encoding_dim=None, verbose=False):
401405 fc = Dense ((input_shape [0 ] * input_shape [1 ]) // 4 , activation = "relu" )(fc )
402406 fc = Dense ((input_shape [0 ] * input_shape [1 ]) // 2 , activation = "relu" )(fc )
403407 pre_out = Dense ((input_shape [0 ] * input_shape [1 ]))(fc )
404- return Reshape (input_shape )(pre_out )
408+ out = Reshape (input_shape )(pre_out )
409+ return Model ([ip ], [out ], name = "decoder" )
405410
406411
407412class _VAE (Model ):
@@ -427,6 +432,11 @@ def metrics(self):
427432 self .kl_loss_tracker ,
428433 ]
429434
435+ def call (self , inputs ):
436+ z_mean , z_log_var , z = self .encoder (inputs )
437+ reconstructed = self .decoder (z )
438+ return reconstructed
439+
430440 def train_step (self , data ):
431441 with tf .GradientTape () as tape :
432442 z_mean , z_log_var , z = self .encoder (data )
0 commit comments