66import keras
77import numpy as np
88import pandas as pd
9+ import tensorflow as tf
910from keras import backend as K
10- from keras .layers import Input
11- from keras .layers .merge import _Merge
11+ from keras .layers import Input , Layer
1212from keras .models import Model
1313from scipy import integrate , stats
1414
1818LOGGER = logging .getLogger (__name__ )
1919
2020
21- class RandomWeightedAverage (_Merge ):
22- def _merge_function (self , inputs ):
23- alpha = K .random_uniform ((64 , 1 , 1 ))
21+ class RandomWeightedAverage (Layer ):
22+ def __init__ (self , batch_size ):
23+ super ().__init__ ()
24+ self .batch_size = batch_size
25+
26+ def call (self , inputs , ** kwargs ):
27+ alpha = tf .random_uniform ((self .batch_size , 1 , 1 , 1 ))
2428 return (alpha * inputs [0 ]) + ((1 - alpha ) * inputs [1 ])
2529
30+ def compute_output_shape (self , input_shape ):
31+ return input_shape [0 ]
32+
2633
2734class CycleGAN ():
2835 """CycleGAN class"""
@@ -130,7 +137,7 @@ def _build_cyclegan(self, **kwargs):
130137 z_ = self .encoder (x )
131138 fake_x = self .critic_x (x_ )
132139 valid_x = self .critic_x (x )
133- interpolated_x = RandomWeightedAverage ()([x , x_ ])
140+ interpolated_x = RandomWeightedAverage (self . batch_size )([x , x_ ])
134141
135142 validity_interpolated_x = self .critic_x (interpolated_x )
136143 partial_gp_loss_x = partial (self ._gradient_penalty_loss , averaged_samples = interpolated_x )
@@ -143,7 +150,7 @@ def _build_cyclegan(self, **kwargs):
143150
144151 fake_z = self .critic_z (z_ )
145152 valid_z = self .critic_z (z )
146- interpolated_z = RandomWeightedAverage ()([z , z_ ])
153+ interpolated_z = RandomWeightedAverage (self . batch_size )([z , z_ ])
147154 validity_interpolated_z = self .critic_z (interpolated_z )
148155 partial_gp_loss_z = partial (self ._gradient_penalty_loss , averaged_samples = interpolated_z )
149156 partial_gp_loss_z .__name__ = 'gradient_penalty'
@@ -210,10 +217,11 @@ def predict(self, X):
210217 N-dimensional array containing the input sequences for the model.
211218
212219 Returns:
213- ndarray:
214- N-dimensional array containing the reconstructions for each input sequence.
215- ndarray:
216- N-dimensional array containing the critic scores for each input sequence.
220+ typle:
221+ ndarray:
222+ N-dimensional array containing the reconstructions for each input sequence.
223+ ndarray:
224+ N-dimensional array containing the critic scores for each input sequence.
217225 """
218226 X = X .reshape ((- 1 , self .shape [0 ], 1 ))
219227 z_ = self .encoder .predict (X )
0 commit comments