Skip to content

Commit fb3387e

Browse files
committed
add shape checking for MS-SSIM input and revise docstring as comment
1 parent 40072ed commit fb3387e

3 files changed

Lines changed: 34 additions & 28 deletions

File tree

mmeval/metrics/ms_ssim.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import numpy as np
3-
from typing import TYPE_CHECKING, Dict, List, Sequence, Tuple
3+
from scipy import signal
4+
from typing import Dict, List, Sequence, Tuple
45

56
from mmeval.core import BaseMetric
6-
from mmeval.utils import try_import
77
from .utils.image_transforms import reorder_image
88

9-
if TYPE_CHECKING:
10-
from scipy import signal
11-
else:
12-
signal = try_import('scipy.signal')
13-
149

1510
class MultiScaleStructureSimilarity(BaseMetric):
1611
"""MS-SSIM (Multi-Scale Structure Similarity) metric.
@@ -34,13 +29,13 @@ class MultiScaleStructureSimilarity(BaseMetric):
3429
between the maximum the and minimum allowed values).
3530
Defaults to 255.
3631
filter_size (int): Size of blur kernel to use (will be reduced for
37-
small images). Default to 11.
32+
small images). Defaults to 11.
3833
filter_sigma (float): Standard deviation for Gaussian blur kernel (will
39-
be reduced for small images). Default to 1.5.
34+
be reduced for small images). Defaults to 1.5.
4035
k1 (float): Constant used to maintain stability in the SSIM calculation
41-
(0.01 in the original paper). Default to 0.01.
36+
(0.01 in the original paper). Defaults to 0.01.
4237
k2 (float): Constant used to maintain stability in the SSIM calculation
43-
(0.03 in the original paper). Default to 0.03.
38+
(0.03 in the original paper). Defaults to 0.03.
4439
weights (List[float]): List of weights for each level. Defaults to
4540
[0.0448, 0.2856, 0.3001, 0.2363, 0.1333]. Noted that the default
4641
weights don't sum to 1.0 but do match the paper / matlab code.
@@ -84,13 +79,15 @@ def __init__(self,
8479
self.weights = np.array(weights)
8580

8681
def add(self, predictions: Sequence[np.ndarray]) -> None: # type: ignore # yapf: disable # noqa: E501
87-
"""Add PSNR score of batch to ``self._results``
82+
"""Add a bunch of images to calculate metric result.
8883
8984
Args:
9085
predictions (Sequence[np.ndarray]): Predictions of the model. The
91-
number of elements in the Sequence must be divisible by 2.
92-
The channel order of each element should align with
93-
`self.input_order` and the range should be [0, 255].
86+
number of elements in the Sequence must be divisible by 2, and
87+
the width and height of each element must be divisible by 2 **
88+
num_scale (`self.weights.size`). The channel order of each
89+
element should align with `self.input_order` and the range
90+
should be [0, 255].
9491
"""
9592

9693
num_samples = len(predictions)
@@ -103,6 +100,16 @@ def add(self, predictions: Sequence[np.ndarray]) -> None: # type: ignore # yapf
103100
reorder_image(pred, self.input_order) for pred in predictions[1::2]
104101
]
105102

103+
least_size = 2**self.weights.size
104+
assert all([
105+
sample.shape[0] % least_size == 0 for sample in half1
106+
]), ('The height and width of each sample must be divisible by '
107+
f'{least_size} (2 ** len(self.weights.size)).')
108+
assert all([
109+
sample.shape[0] % least_size == 0 for sample in half2
110+
]), ('The height and width of each sample must be divisible by '
111+
f'{least_size} (2 ** self.weights.size).')
112+
106113
half1 = np.stack(half1, axis=0).astype(np.uint8)
107114
half2 = np.stack(half2, axis=0).astype(np.uint8)
108115

@@ -131,7 +138,7 @@ def compute_ms_ssim(self, img1: np.array, img2: np.array) -> List[float]:
131138
img2 (ndarray): Images with range [0, 255] and order "NHWC".
132139
133140
Returns:
134-
np.ndarray: MS-SSIM score between `img1` and `img2`.
141+
np.ndarray: MS-SSIM score between `img1` and `img2` of shape (N, ).
135142
"""
136143
if img1.shape != img2.shape:
137144
raise RuntimeError(
@@ -227,15 +234,15 @@ def _ssim_for_multi_scale(
227234
img2 (np.ndarray): Images with range [0, 255] and order "NHWC".
228235
max_val (int): the dynamic range of the images (i.e., the
229236
difference between the maximum the and minimum allowed
230-
values). Default to 255.
237+
values). Defaults to 255.
231238
filter_size (int): Size of blur kernel to use (will be reduced for
232-
small images). Default to 11.
239+
small images). Defaults to 11.
233240
filter_sigma (float): Standard deviation for Gaussian blur kernel (
234-
will be reduced for small images). Default to 1.5.
241+
will be reduced for small images). Defaults to 1.5.
235242
k1 (float): Constant used to maintain stability in the SSIM
236-
calculation (0.01 in the original paper). Default to 0.01.
243+
calculation (0.01 in the original paper). Defaults to 0.01.
237244
k2 (float): Constant used to maintain stability in the SSIM
238-
calculation (0.03 in the original paper). Default to 0.03.
245+
calculation (0.03 in the original paper). Defaults to 0.03.
239246
240247
Returns:
241248
tuple: Pair containing the mean SSIM and contrast sensitivity

mmeval/metrics/swd.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import numpy as np
3-
from typing import TYPE_CHECKING, Any, Dict, List, Sequence
3+
from scipy import ndimage
4+
from typing import Any, Dict, List, Sequence
45

56
from mmeval.core import BaseMetric
6-
from mmeval.utils import try_import
7-
8-
if TYPE_CHECKING:
9-
from scipy import ndimage
10-
else:
11-
ndimage = try_import('scipy.ndimage')
127

138

149
class SlicedWassersteinDistance(BaseMetric):

tests/test_metrics/test_ms_ssim.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,7 @@ def test_raise_error():
5858
np.random.randint(0, 255, (64, 64, 3)),
5959
np.random.randint(0, 255, (64, 64, 3))
6060
)
61+
62+
with pytest.raises(AssertionError):
63+
inputs = [np.random.randint(0, 255, (3, 32, 32))] * 3
64+
ms_ssim(inputs)

0 commit comments

Comments
 (0)