diff --git a/DIN/modules.py b/DIN/modules.py
index 60a84b3..6c75542 100644
--- a/DIN/modules.py
+++ b/DIN/modules.py
@@ -17,7 +17,11 @@ def __init__(self, att_hidden_units, activation='prelu'):
         """
         """
         super(Attention_Layer, self).__init__()
-        self.att_dense = [Dense(unit, activation=activation) for unit in att_hidden_units]
+        if activation=="prelu":
+            self.att_dense = [Dense(unit, activation=tf.keras.layers.PReLU()) for unit in att_hidden_units]
+        else:
+            self.att_dense = [Dense(unit, activation=activation) for unit in att_hidden_units]
+
         self.att_final_dense = Dense(1)
 
     def call(self, inputs):