Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit a887998

Browse files
Michael Marlenfacebook-github-bot
authored andcommitted
Multi Label Cross Entropy Loss
Summary: Creating a BinaryCrossEntropy loss function with logit based loss Differential Revision: D25440017 fbshipit-source-id: bad4b8a3cae8d82fc7fdff298090c0efcbf15839
1 parent 80bcfec commit a887998

2 files changed

Lines changed: 25 additions & 0 deletions

File tree

pytext/loss/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .loss import (
55
AUCPRHingeLoss,
66
BinaryCrossEntropyLoss,
7+
BinaryCrossEntropyWithLogitsLoss,
78
CosineEmbeddingLoss,
89
CrossEntropyLoss,
910
KLDivergenceBCELoss,
@@ -26,6 +27,7 @@
2627
"CrossEntropyLoss",
2728
"CosineEmbeddingLoss",
2829
"BinaryCrossEntropyLoss",
30+
"BinaryCrossEntropyWithLogitsLoss",
2931
"MultiLabelSoftMarginLoss",
3032
"KLDivergenceBCELoss",
3133
"KLDivergenceCELoss",

pytext/loss/loss.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,29 @@ def __call__(self, log_probs, targets, reduce=True):
6969
)
7070

7171

72+
class BinaryCrossEntropyWithLogitsLoss(Loss):
73+
class Config(ConfigBase):
74+
reduce: bool = True
75+
76+
def __call__(self, logits, targets, reduce=True):
77+
"""
78+
Computes 1-vs-all binary cross entropy loss for multiclass classification. However, unlike BinaryCrossEntropyLoss, we require targets to be a one-hot vector.
79+
"""
80+
81+
target_labels = targets[0].float()
82+
83+
"""
84+
`F.binary_cross_entropy_with_logits` requires the
85+
output of the previous function be already a FloatTensor.
86+
"""
87+
88+
loss = F.binary_cross_entropy_with_logits(
89+
precision.maybe_float(logits), target_labels, reduction="none"
90+
)
91+
92+
return loss.sum(-1).mean() if reduce else loss.sum(-1)
93+
94+
7295
class BinaryCrossEntropyLoss(Loss):
7396
class Config(ConfigBase):
7497
reweight_negative: bool = True

0 commit comments

Comments
 (0)