diff --git a/src/torchmetrics/functional/image/psnrb.py b/src/torchmetrics/functional/image/psnrb.py index 88007d88635..bb91c63d03b 100644 --- a/src/torchmetrics/functional/image/psnrb.py +++ b/src/torchmetrics/functional/image/psnrb.py @@ -131,6 +131,9 @@ def peak_signal_noise_ratio_with_blocked_effect( tensor(7.8402) """ + if not isinstance(block_size, int) or block_size < 1: + raise ValueError("Argument ``block_size`` should be a positive integer") + if isinstance(data_range, tuple): preds = torch.clamp(preds, min=data_range[0], max=data_range[1]) target = torch.clamp(target, min=data_range[0], max=data_range[1]) diff --git a/src/torchmetrics/image/psnrb.py b/src/torchmetrics/image/psnrb.py index 25fa99bbd5c..168e204c5c0 100644 --- a/src/torchmetrics/image/psnrb.py +++ b/src/torchmetrics/image/psnrb.py @@ -79,7 +79,7 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(**kwargs) - if not isinstance(block_size, int) and block_size < 1: + if not isinstance(block_size, int) or block_size < 1: raise ValueError("Argument ``block_size`` should be a positive integer") self.block_size = block_size diff --git a/tests/unittests/image/test_psnrb.py b/tests/unittests/image/test_psnrb.py index 6ed65bfd3c3..893a70bb2eb 100644 --- a/tests/unittests/image/test_psnrb.py +++ b/tests/unittests/image/test_psnrb.py @@ -135,5 +135,21 @@ def test_psnr_half_gpu(self, preds, target): def test_error_on_color_images(): """Test that appropriate error is raised when color images are passed to PSNRB metric.""" - with pytest.raises(ValueError, match="`psnrb` metric expects grayscale images.*"): + with pytest.raises(ValueError, match=r"`psnrb` metric expects grayscale images.*"): peak_signal_noise_ratio_with_blocked_effect(torch.rand(1, 3, 16, 16), torch.rand(1, 3, 16, 16), data_range=1.0) + + +@pytest.mark.parametrize("block_size", [0, -5, 1.5, "foo"]) +def test_error_on_invalid_block_size_class(block_size): + """Test that ValueError is raised for non-positive or non-integer block_size in the class API.""" + with pytest.raises(ValueError, match="Argument ``block_size`` should be a positive integer"): + PeakSignalNoiseRatioWithBlockedEffect(data_range=1.0, block_size=block_size) + + +@pytest.mark.parametrize("block_size", [0, -5, 1.5, "foo"]) +def test_error_on_invalid_block_size_functional(block_size): + """Test that ValueError is raised for non-positive or non-integer block_size in the functional API.""" + with pytest.raises(ValueError, match="Argument ``block_size`` should be a positive integer"): + peak_signal_noise_ratio_with_blocked_effect( + torch.rand(1, 1, 16, 16), torch.rand(1, 1, 16, 16), data_range=1.0, block_size=block_size + )