diff --git a/src/torchmetrics/functional/image/psnr.py b/src/torchmetrics/functional/image/psnr.py index e98ac0b14ab..1bb0ad3d64e 100644 --- a/src/torchmetrics/functional/image/psnr.py +++ b/src/torchmetrics/functional/image/psnr.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence from typing import Optional, Union import torch @@ -25,7 +26,7 @@ def _psnr_compute( num_obs: Tensor, data_range: Tensor, base: float = 10.0, - reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", + reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean", ) -> Tensor: """Compute peak signal-to-noise ratio. @@ -58,7 +59,7 @@ def _psnr_compute( def _psnr_update( preds: Tensor, target: Tensor, - dim: Optional[Union[int, tuple[int, ...]]] = None, + dim: Optional[Union[int, Sequence[int]]] = None, ) -> tuple[Tensor, Tensor]: """Update and return variables required to compute peak signal-to-noise ratio. @@ -95,10 +96,10 @@ def _psnr_update( def peak_signal_noise_ratio( preds: Tensor, target: Tensor, - data_range: Union[float, tuple[float, float]], + data_range: Union[float, Sequence[float]], base: float = 10.0, - reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - dim: Optional[Union[int, tuple[int, ...]]] = None, + reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean", + dim: Optional[Union[int, Sequence[int]]] = None, ) -> Tensor: """Compute the peak signal-to-noise ratio. @@ -106,7 +107,7 @@ def peak_signal_noise_ratio( preds: estimated signal target: groun truth signal data_range: - the range of the data. If a tuple is provided then the range is calculated as the difference and + the range of the data. If a Sequence is provided then the range is calculated as the difference and input is clamped between the values. base: a base of a logarithm to use reduction: a method to reduce metric score over labels. @@ -136,7 +137,7 @@ def peak_signal_noise_ratio( if dim is None and reduction != "elementwise_mean": rank_zero_warn(f"The `reduction={reduction}` will not have any effect when `dim` is None.") - if isinstance(data_range, tuple): + if isinstance(data_range, Sequence): preds = torch.clamp(preds, min=data_range[0], max=data_range[1]) target = torch.clamp(target, min=data_range[0], max=data_range[1]) data_range_val = tensor(data_range[1] - data_range[0]) diff --git a/src/torchmetrics/functional/image/psnrb.py b/src/torchmetrics/functional/image/psnrb.py index 88007d88635..c6e24ac5114 100644 --- a/src/torchmetrics/functional/image/psnrb.py +++ b/src/torchmetrics/functional/image/psnrb.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +from collections.abc import Sequence from typing import Union import torch @@ -102,7 +103,7 @@ def _psnrb_update(preds: Tensor, target: Tensor, block_size: int = 8) -> tuple[T def peak_signal_noise_ratio_with_blocked_effect( preds: Tensor, target: Tensor, - data_range: Union[float, tuple[float, float]], + data_range: Union[float, Sequence[float]], block_size: int = 8, ) -> Tensor: r"""Computes `Peak Signal to Noise Ratio With Blocked Effect` (PSNRB) metrics. @@ -115,7 +116,7 @@ def peak_signal_noise_ratio_with_blocked_effect( Args: preds: estimated signal target: ground truth signal - data_range: the range of the data. If a tuple is provided then the range is calculated as the difference and + data_range: the range of the data. If a Sequence is provided then the range is calculated as the difference and input is clamped between the values. block_size: integer indication the block size @@ -131,7 +132,7 @@ def peak_signal_noise_ratio_with_blocked_effect( tensor(7.8402) """ - if isinstance(data_range, tuple): + if isinstance(data_range, Sequence): preds = torch.clamp(preds, min=data_range[0], max=data_range[1]) target = torch.clamp(target, min=data_range[0], max=data_range[1]) data_range_val = tensor(data_range[1] - data_range[0]) diff --git a/src/torchmetrics/image/psnr.py b/src/torchmetrics/image/psnr.py index c3c33fcb5c1..ecf27ce1e2a 100644 --- a/src/torchmetrics/image/psnr.py +++ b/src/torchmetrics/image/psnr.py @@ -48,7 +48,7 @@ class PeakSignalNoiseRatio(Metric): Args: data_range: - the range of the data. If a tuple is provided, then the range is calculated as the difference and + the range of the data. If a Sequence is provided, then the range is calculated as the difference and input is clamped between the values. base: a base of a logarithm to use. reduction: a method to reduce metric score over labels. @@ -80,10 +80,10 @@ class PeakSignalNoiseRatio(Metric): def __init__( self, - data_range: Union[float, tuple[float, float]], + data_range: Union[float, Sequence[float]], base: float = 10.0, - reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", - dim: Optional[Union[int, tuple[int, ...]]] = None, + reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean", + dim: Optional[Union[int, Sequence[int]]] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -99,7 +99,7 @@ def __init__( self.add_state("total", default=[], dist_reduce_fx="cat") self.clamping_fn = None - if isinstance(data_range, tuple): + if isinstance(data_range, Sequence): self.add_state("data_range", default=tensor(data_range[1] - data_range[0]), dist_reduce_fx="mean") self.clamping_fn = partial(torch.clamp, min=data_range[0], max=data_range[1]) else: @@ -107,7 +107,7 @@ def __init__( self.base = base self.reduction = reduction - self.dim = tuple(dim) if isinstance(dim, Sequence) else dim + self.dim = dim def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" diff --git a/src/torchmetrics/image/psnrb.py b/src/torchmetrics/image/psnrb.py index 25fa99bbd5c..0f3f68a0840 100644 --- a/src/torchmetrics/image/psnrb.py +++ b/src/torchmetrics/image/psnrb.py @@ -48,7 +48,7 @@ class PeakSignalNoiseRatioWithBlockedEffect(Metric): - ``psnrb`` (:class:`~torch.Tensor`): float scalar tensor with aggregated PSNRB value Args: - data_range: the range of the data. If a tuple is provided then the range is calculated as the difference and + data_range: the range of the data. If a Sequence is provided then the range is calculated as the difference and input is clamped between the values. block_size: integer indication the block size kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -74,7 +74,7 @@ class PeakSignalNoiseRatioWithBlockedEffect(Metric): def __init__( self, - data_range: Union[float, tuple[float, float]], + data_range: Union[float, Sequence[float]], block_size: int = 8, **kwargs: Any, ) -> None: @@ -87,7 +87,7 @@ def __init__( self.add_state("total", default=tensor(0), dist_reduce_fx="sum") self.add_state("bef", default=tensor(0.0), dist_reduce_fx="sum") - if isinstance(data_range, tuple): + if isinstance(data_range, Sequence): self.add_state("data_range", default=tensor(data_range[1] - data_range[0]), dist_reduce_fx="mean") self.clamping_fn = lambda x: torch.clamp(x, min=data_range[0], max=data_range[1]) else: