Question on the code of masked cross entropy loss #11

@LuckyJinging

When I use the Player2Vec algorithm, I am confused by the masked cross-entropy loss. Dividing by `tf.reduce_sum(mask)` already normalizes the mask so that the entries equal to 1 sum to 1, so summing the masked loss would already give the average over those items. Why does the function then take another global average with `tf.reduce_mean(loss)` instead of summing with `tf.reduce_sum(loss)`?
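In symbols (notation chosen here for illustration, not taken from the repository): let $N$ be the number of entries, $\ell_i$ the per-entry cross-entropy and $m_i \in \{0, 1\}$ the mask. The function quoted below then returns

```math
\begin{aligned}
\text{returned loss}
&= \frac{1}{N}\sum_{i=1}^{N} \ell_i \,\frac{m_i}{\max\!\big(\sum_{j=1}^{N} m_j,\; 1\big)} \\
&= \frac{1}{N}\cdot\frac{\sum_{i=1}^{N} m_i \ell_i}{\sum_{j=1}^{N} m_j}
\qquad \text{(when at least one } m_j = 1\text{)}
\end{aligned}
```

The second factor is exactly the mean loss over the masked items, which is what `tf.reduce_sum(loss)` would return; `tf.reduce_mean(loss)` multiplies it by an extra $1/N$.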

```python
import tensorflow as tf


def masked_softmax_cross_entropy(preds: tf.Tensor, labels: tf.Tensor,
                                 mask: tf.Tensor) -> tf.Tensor:
    """
    Softmax cross-entropy loss with masking.
    :param preds: the last layer logits of the input data
    :param labels: the labels of the input data
    :param mask: the mask for train/val/test data
    """
    # Per-entry cross-entropy, shape [N]
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=labels)
    mask = tf.cast(mask, dtype=tf.float32)
    # Normalize the mask so its entries sum to 1 (guarding against an all-zero mask)
    mask /= tf.maximum(tf.reduce_sum(mask), tf.constant([1.]))
    loss *= mask
    # reduce_mean divides by N, the total number of entries, not by the mask count
    return tf.reduce_mean(loss)
```
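To make the scaling concrete, here is a minimal pure-Python sketch (no TensorFlow) that mirrors the arithmetic of the function above on hypothetical toy values. The helper names and numbers are invented for illustration: `masked_loss_mean` reproduces the `tf.reduce_mean` version, `masked_loss_sum` swaps in `tf.reduce_sum`, and `mean_normalized_variant` divides the mask by its mean rather than its sum before averaging.

```python
# Pure-Python mirror of the masked loss arithmetic (no TensorFlow needed).
# The loss values and mask below are hypothetical toy data.


def normalized_mask(mask):
    # Mirrors: mask /= tf.maximum(tf.reduce_sum(mask), 1.)
    total = max(sum(mask), 1.0)
    return [m / total for m in mask]


def masked_loss_mean(losses, mask):
    # Quoted version: weight by the sum-normalized mask, then average over all N entries
    weighted = [value * w for value, w in zip(losses, normalized_mask(mask))]
    return sum(weighted) / len(weighted)


def masked_loss_sum(losses, mask):
    # Same weighting, but sum instead of average (like tf.reduce_sum)
    return sum(value * w for value, w in zip(losses, normalized_mask(mask)))


def mean_normalized_variant(losses, mask):
    # Alternative: divide the mask by its mean, then average over all N entries
    mean = sum(mask) / len(mask)
    weights = [m / mean if mean > 0 else 0.0 for m in mask]
    weighted = [value * w for value, w in zip(losses, weights)]
    return sum(weighted) / len(weighted)


def reference_masked_mean(losses, mask):
    # Plain average over the entries whose mask is 1
    selected = [value for value, m in zip(losses, mask) if m > 0]
    return sum(selected) / max(len(selected), 1)


losses = [0.5, 1.0, 2.0, 4.0]  # per-entry cross-entropy, N = 4
mask = [1.0, 0.0, 1.0, 0.0]    # two entries selected

print(masked_loss_mean(losses, mask))         # 0.3125
print(masked_loss_sum(losses, mask))          # 1.25
print(mean_normalized_variant(losses, mask))  # 1.25
print(reference_masked_mean(losses, mask))    # 1.25
```

With four entries and two of them selected, the quoted version returns 1.25 / 4 = 0.3125, a quarter of the masked mean of 1.25. Either consistent pairing (sum-normalized mask with a sum, or mean-normalized mask with a mean) recovers the masked mean; mixing them scales the result by 1/N.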
