Skip to content

Commit d82731e

Browse files
committed
Support cutom pad value in AverageEmbeddingInputLayer
1 parent 877e7ac commit d82731e

1 file changed

Lines changed: 9 additions & 12 deletions

File tree

tensorlayer/layers.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,7 @@ class AverageEmbeddingInputLayer(Layer):
662662
inputs : input placeholder or tensor; zeros are paddings
663663
vocabulary_size : an integer, the size of vocabulary
664664
embedding_size : an integer, the dimension of embedding vectors
665+
pad_value : an integer, the scalar pad value used in inputs
665666
name : a string, the name of the layer
666667
embeddings_initializer : the initializer of the embedding matrix
667668
embeddings_kwargs : kwargs to get embedding matrix variable
@@ -673,6 +674,7 @@ class AverageEmbeddingInputLayer(Layer):
673674
"""
674675
def __init__(
675676
self, inputs, vocabulary_size, embedding_size,
677+
pad_value=0,
676678
name='average_embedding_layer',
677679
embeddings_initializer=tf.random_uniform_initializer(-0.1, 0.1),
678680
embeddings_kwargs=None,
@@ -699,22 +701,17 @@ def __init__(
699701
self.embeddings, self.inputs,
700702
name='word_embeddings',
701703
)
702-
703-
# Masks used to ignore padding words
704-
masks = tf.expand_dims(
705-
tf.sign(self.inputs),
706-
axis=-1,
707-
name='masks',
708-
)
709-
sum_word_embeddings = tf.reduce_sum(
710-
word_embeddings * tf.cast(masks, tf.float32),
711-
axis=1,
704+
# Zero out embeddings of pad value
705+
masks = tf.not_equal(self.inputs, pad_value, name='masks')
706+
word_embeddings *= tf.cast(
707+
tf.expand_dims(masks, axis=-1),
708+
tf.float32,
712709
)
710+
sum_word_embeddings = tf.reduce_sum(word_embeddings, axis=1)
713711

714712
# Count number of non-padding words in each sentence
715-
# Used to commute average word embeddings in sentences
716713
sentence_lengths = tf.count_nonzero(
717-
self.inputs,
714+
masks,
718715
axis=1,
719716
keep_dims=True,
720717
dtype=tf.float32,

0 commit comments

Comments
 (0)