@@ -662,6 +662,7 @@ class AverageEmbeddingInputLayer(Layer):
662662 inputs : input placeholder or tensor; entries equal to pad_value are paddings
663663 vocabulary_size : an integer, the size of vocabulary
664664 embedding_size : an integer, the dimension of embedding vectors
665+ pad_value : an integer, the scalar pad value used in inputs
665666 name : a string, the name of the layer
666667 embeddings_initializer : the initializer of the embedding matrix
667668 embeddings_kwargs : kwargs to get embedding matrix variable
@@ -673,6 +674,7 @@ class AverageEmbeddingInputLayer(Layer):
673674 """
674675 def __init__ (
675676 self , inputs , vocabulary_size , embedding_size ,
677+ pad_value = 0 ,
676678 name = 'average_embedding_layer' ,
677679 embeddings_initializer = tf .random_uniform_initializer (- 0.1 , 0.1 ),
678680 embeddings_kwargs = None ,
@@ -699,22 +701,17 @@ def __init__(
699701 self .embeddings , self .inputs ,
700702 name = 'word_embeddings' ,
701703 )
702-
703- # Masks used to ignore padding words
704- masks = tf .expand_dims (
705- tf .sign (self .inputs ),
706- axis = - 1 ,
707- name = 'masks' ,
708- )
709- sum_word_embeddings = tf .reduce_sum (
710- word_embeddings * tf .cast (masks , tf .float32 ),
711- axis = 1 ,
704+ # Zero out embeddings of pad value
705+ masks = tf .not_equal (self .inputs , pad_value , name = 'masks' )
706+ word_embeddings *= tf .cast (
707+ tf .expand_dims (masks , axis = - 1 ),
708+ tf .float32 ,
712709 )
710+ sum_word_embeddings = tf .reduce_sum (word_embeddings , axis = 1 )
713711
714712 # Count number of non-padding words in each sentence
715- # Used to commute average word embeddings in sentences
716713 sentence_lengths = tf .count_nonzero (
717- self . inputs ,
714+ masks ,
718715 axis = 1 ,
719716 keep_dims = True ,
720717 dtype = tf .float32 ,
0 commit comments