Binary class usage #541
In which format does the library expect binary logit predictions for metric modules like Accuracy? I am a bit confused by the description here. My current binary segmentation model outputs have shape Nx1xHxW (sigmoids of predicted logits) and the targets have shape NxHxW. Currently, the following code snippet uses the [...] Moreover, because my output is single-channel sigmoid probabilities, the implied_classes are also assigned wrongly. Previously, with pytorch lightning 1.1, this wasn't an issue as the [...] Is there a specific reason to move away from the previous approach, and how are we supposed to use the Metrics for the above case now?
Replies: 3 comments
I just realized this seems to be a bug: as per the docstring below, the prediction and target should both be flattened arrays. The implied classes would then end up being wrong in any case, because only the very first value of the array is assigned, which in turn fails the class consistency check.
Is there an answer for this?
@tridivb: the shape handling issue (Nx1xHxW vs NxHxW) was a real pain point with the old unified API. It is completely resolved with the task-specific metrics in v1.9.0.

For binary segmentation with sigmoid output (Nx1xHxW):

from torchmetrics.classification import BinaryAccuracy, BinaryF1Score, BinaryJaccardIndex

# Squeeze the channel dim so preds match the target shape
preds = model(x).squeeze(1)  # (N, H, W), sigmoid probabilities
# target is already (N, H, W) with values 0 or 1

# "global" reduces over the batch and all spatial positions at once
acc = BinaryAccuracy(multidim_average="global")
f1 = BinaryF1Score(multidim_average="global")
iou = BinaryJaccardIndex()

acc(preds, target)
f1(preds, target)
iou(preds, target)

Key differences from the old API:
num_classes for binary tasks
The old [...]

Docs: Accuracy
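To make the squeeze and the "global" reduction concrete, here is a minimal pure-Python sketch of what they amount to on a tiny batch. It is only an illustration under stated assumptions, not torchmetrics' implementation: the helper names `squeeze_channel`, `flatten`, and `binary_accuracy_global` are hypothetical, nested lists stand in for tensors, and the 0.5 threshold is assumed.

```python
from itertools import chain


def squeeze_channel(preds):
    """Drop the singleton channel axis: (N, 1, H, W) -> (N, H, W)."""
    squeezed = []
    for sample in preds:
        if len(sample) != 1:
            raise ValueError("expected exactly one channel per sample")
        squeezed.append(sample[0])
    return squeezed


def flatten(batch):
    """Flatten an (N, H, W) nested list into one flat list of values."""
    return list(chain.from_iterable(chain.from_iterable(batch)))


def binary_accuracy_global(preds, target, threshold=0.5):
    """Threshold probabilities, then score every pixel of every sample together."""
    flat_preds = flatten(preds)
    flat_target = flatten(target)
    if len(flat_preds) != len(flat_target):
        raise ValueError("preds and target must contain the same number of elements")
    correct = sum(int(p > threshold) == t for p, t in zip(flat_preds, flat_target))
    return correct / len(flat_preds)


# Hypothetical toy batch: preds is (2, 1, 2, 2), target is (2, 2, 2).
preds = [[[[0.9, 0.2], [0.6, 0.4]]], [[[0.1, 0.8], [0.7, 0.3]]]]
target = [[[1, 0], [0, 0]], [[0, 1], [1, 1]]]

acc = binary_accuracy_global(squeeze_channel(preds), target)
print(acc)  # 6 of 8 pixels match -> 0.75
```

Once the channel axis is gone, predictions and targets line up element by element, which is why no class-count inference is needed for the binary case.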