Skip to content

Commit c9c5790

Browse files
fix: problem of type with BCEwithLogitsLoss
1 parent d1987ef commit c9c5790

2 files changed

Lines changed: 5 additions & 1 deletion

File tree

torchTextClassifiers/dataset/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def collate_fn(self, batch):
100100
labels_tensor[rows, cols] = 1
101101

102102
else:
103-
labels_tensor = torch.tensor(labels, dtype=torch.long)
103+
labels_tensor = torch.tensor(labels)
104104
else:
105105
labels_tensor = None
106106

torchTextClassifiers/model/lightning.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ def training_step(self, batch, batch_idx: int) -> torch.Tensor:
7676
targets = batch["labels"]
7777

7878
outputs = self.forward(batch)
79+
80+
if isinstance(self.loss, torch.nn.BCEWithLogitsLoss):
81+
targets = targets.float()
82+
7983
loss = self.loss(outputs, targets)
8084
self.log("train_loss", loss, on_epoch=True, on_step=True, prog_bar=True)
8185
accuracy = self.accuracy_fn(outputs, targets)

0 commit comments

Comments
 (0)