Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/torchmetrics/functional/image/psnrb.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ def peak_signal_noise_ratio_with_blocked_effect(
tensor(7.8402)

"""
if not isinstance(block_size, int) or block_size < 1:
raise ValueError("Argument ``block_size`` should be a positive integer")

if isinstance(data_range, tuple):
preds = torch.clamp(preds, min=data_range[0], max=data_range[1])
target = torch.clamp(target, min=data_range[0], max=data_range[1])
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/image/psnrb.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if not isinstance(block_size, int) and block_size < 1:
if not isinstance(block_size, int) or block_size < 1:
raise ValueError("Argument ``block_size`` should be a positive integer")
self.block_size = block_size

Expand Down
18 changes: 17 additions & 1 deletion tests/unittests/image/test_psnrb.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,5 +135,21 @@ def test_psnr_half_gpu(self, preds, target):

def test_error_on_color_images():
"""Test that appropriate error is raised when color images are passed to PSNRB metric."""
with pytest.raises(ValueError, match="`psnrb` metric expects grayscale images.*"):
with pytest.raises(ValueError, match=r"`psnrb` metric expects grayscale images.*"):
peak_signal_noise_ratio_with_blocked_effect(torch.rand(1, 3, 16, 16), torch.rand(1, 3, 16, 16), data_range=1.0)


@pytest.mark.parametrize("block_size", [0, -5, 1.5, "foo"])
def test_error_on_invalid_block_size_class(block_size):
"""Test that ValueError is raised for non-positive or non-integer block_size in the class API."""
with pytest.raises(ValueError, match="Argument ``block_size`` should be a positive integer"):
PeakSignalNoiseRatioWithBlockedEffect(data_range=1.0, block_size=block_size)


@pytest.mark.parametrize("block_size", [0, -5, 1.5, "foo"])
def test_error_on_invalid_block_size_functional(block_size):
"""Test that ValueError is raised for non-positive or non-integer block_size in the functional API."""
with pytest.raises(ValueError, match="Argument ``block_size`` should be a positive integer"):
peak_signal_noise_ratio_with_blocked_effect(
torch.rand(1, 1, 16, 16), torch.rand(1, 1, 16, 16), data_range=1.0, block_size=block_size
)
Loading