Skip to content

Commit af361cc

Browse files
Follow-up for Fixing SSIM Test Method Names (#8750)
This fixes some other issues which were described in #8746 . ### Description This is a follow-up for #8746, thanks and credits to @ericspod! 👍 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Benedikt Johannes <benedikt.johannes.hofer@gmail.com> Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent d3d0209 commit af361cc

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tests/metrics/test_ssim_metric.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
class TestSSIMMetric(unittest.TestCase):
2323

24-
def test2d_gaussian(self):
24+
def test_2d_gaussian(self):
2525
set_determinism(0)
2626
preds = torch.abs(torch.randn(2, 3, 16, 16))
2727
target = torch.abs(torch.randn(2, 3, 16, 16))
@@ -32,9 +32,9 @@ def test2d_gaussian(self):
3232
metric(preds, target)
3333
result = metric.aggregate()
3434
expected_value = 0.045415
35-
self.assertTrue(expected_value - result.item() < 0.000001)
35+
self.assertTrue(abs(expected_value - result.item()) < 0.000001)
3636

37-
def test2d_uniform(self):
37+
def test_2d_uniform(self):
3838
set_determinism(0)
3939
preds = torch.abs(torch.randn(2, 3, 16, 16))
4040
target = torch.abs(torch.randn(2, 3, 16, 16))
@@ -45,9 +45,9 @@ def test2d_uniform(self):
4545
metric(preds, target)
4646
result = metric.aggregate()
4747
expected_value = 0.050103
48-
self.assertTrue(expected_value - result.item() < 0.000001)
48+
self.assertTrue(abs(expected_value - result.item()) < 0.000001)
4949

50-
def test3d_gaussian(self):
50+
def test_3d_gaussian(self):
5151
set_determinism(0)
5252
preds = torch.abs(torch.randn(2, 3, 16, 16, 16))
5353
target = torch.abs(torch.randn(2, 3, 16, 16, 16))

0 commit comments

Comments
 (0)