11# Copyright (c) OpenMMLab. All rights reserved.
22import numpy as np
3- from typing import TYPE_CHECKING , Dict , List , Sequence , Tuple
3+ from scipy import signal
4+ from typing import Dict , List , Sequence , Tuple
45
56from mmeval .core import BaseMetric
6- from mmeval .utils import try_import
77from .utils .image_transforms import reorder_image
88
9- if TYPE_CHECKING :
10- from scipy import signal
11- else :
12- signal = try_import ('scipy.signal' )
13-
149
1510class MultiScaleStructureSimilarity (BaseMetric ):
1611 """MS-SSIM (Multi-Scale Structure Similarity) metric.
@@ -34,13 +29,13 @@ class MultiScaleStructureSimilarity(BaseMetric):
3429 between the maximum the and minimum allowed values).
3530 Defaults to 255.
3631 filter_size (int): Size of blur kernel to use (will be reduced for
37- small images). Default to 11.
32+ small images). Defaults to 11.
3833 filter_sigma (float): Standard deviation for Gaussian blur kernel (will
39- be reduced for small images). Default to 1.5.
34+ be reduced for small images). Defaults to 1.5.
4035 k1 (float): Constant used to maintain stability in the SSIM calculation
41- (0.01 in the original paper). Default to 0.01.
36+ (0.01 in the original paper). Defaults to 0.01.
4237 k2 (float): Constant used to maintain stability in the SSIM calculation
43- (0.03 in the original paper). Default to 0.03.
38+ (0.03 in the original paper). Defaults to 0.03.
4439 weights (List[float]): List of weights for each level. Defaults to
4540 [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]. Noted that the default
4641 weights don't sum to 1.0 but do match the paper / matlab code.
@@ -84,13 +79,15 @@ def __init__(self,
8479 self .weights = np .array (weights )
8580
8681 def add (self , predictions : Sequence [np .ndarray ]) -> None : # type: ignore # yapf: disable # noqa: E501
87- """Add PSNR score of batch to ``self._results``
82+ """Add a bunch of images to calculate metric result.
8883
8984 Args:
9085 predictions (Sequence[np.ndarray]): Predictions of the model. The
91- number of elements in the Sequence must be divisible by 2.
92- The channel order of each element should align with
93- `self.input_order` and the range should be [0, 255].
86+ number of elements in the Sequence must be divisible by 2, and
87+ the width and height of each element must be divisible by 2 **
88+ num_scale (`self.weights.size`). The channel order of each
89+ element should align with `self.input_order` and the range
90+ should be [0, 255].
9491 """
9592
9693 num_samples = len (predictions )
@@ -103,6 +100,16 @@ def add(self, predictions: Sequence[np.ndarray]) -> None: # type: ignore # yapf
103100 reorder_image (pred , self .input_order ) for pred in predictions [1 ::2 ]
104101 ]
105102
103+ least_size = 2 ** self .weights .size
104+ assert all ([
105+ sample .shape [0 ] % least_size == 0 for sample in half1
106+ ]), ('The height and width of each sample must be divisible by '
107+ f'{ least_size } (2 ** len(self.weights.size)).' )
108+ assert all ([
109+ sample .shape [0 ] % least_size == 0 for sample in half2
110+ ]), ('The height and width of each sample must be divisible by '
111+ f'{ least_size } (2 ** self.weights.size).' )
112+
106113 half1 = np .stack (half1 , axis = 0 ).astype (np .uint8 )
107114 half2 = np .stack (half2 , axis = 0 ).astype (np .uint8 )
108115
@@ -131,7 +138,7 @@ def compute_ms_ssim(self, img1: np.array, img2: np.array) -> List[float]:
131138 img2 (ndarray): Images with range [0, 255] and order "NHWC".
132139
133140 Returns:
134- np.ndarray: MS-SSIM score between `img1` and `img2`.
141+ np.ndarray: MS-SSIM score between `img1` and `img2` of shape (N, ) .
135142 """
136143 if img1 .shape != img2 .shape :
137144 raise RuntimeError (
@@ -227,15 +234,15 @@ def _ssim_for_multi_scale(
227234 img2 (np.ndarray): Images with range [0, 255] and order "NHWC".
228235 max_val (int): the dynamic range of the images (i.e., the
229236 difference between the maximum the and minimum allowed
230- values). Default to 255.
237+ values). Defaults to 255.
231238 filter_size (int): Size of blur kernel to use (will be reduced for
232- small images). Default to 11.
239+ small images). Defaults to 11.
233240 filter_sigma (float): Standard deviation for Gaussian blur kernel (
234- will be reduced for small images). Default to 1.5.
241+ will be reduced for small images). Defaults to 1.5.
235242 k1 (float): Constant used to maintain stability in the SSIM
236- calculation (0.01 in the original paper). Default to 0.01.
243+ calculation (0.01 in the original paper). Defaults to 0.01.
237244 k2 (float): Constant used to maintain stability in the SSIM
238- calculation (0.03 in the original paper). Default to 0.03.
245+ calculation (0.03 in the original paper). Defaults to 0.03.
239246
240247 Returns:
241248 tuple: Pair containing the mean SSIM and contrast sensitivity
0 commit comments