Skip to content

Commit 477f19f

Browse files
committed
Fix asserts
1 parent 62a269f commit 477f19f

1 file changed

Lines changed: 3 additions & 4 deletions

File tree

tests/unit/aggregation/_asserts.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import torch
2-
from numpy.ma.testutils import assert_allclose
32
from pytest import raises
43
from torch import Tensor
54
from torch.testing import assert_close
@@ -126,8 +125,8 @@ def assert_stateful(aggregator: Aggregator, matrix: Tensor) -> None:
126125
aggregator.reset()
127126
second_pair = (aggregator(matrix), aggregator(matrix))
128127

129-
assert_allclose(first_pair[0], second_pair[0], atol=0.0, rtol=0.0)
130-
assert_allclose(first_pair[1], second_pair[1], atol=0.0, rtol=0.0)
128+
assert_close(first_pair[0], second_pair[0], atol=0.0, rtol=0.0)
129+
assert_close(first_pair[1], second_pair[1], atol=0.0, rtol=0.0)
131130

132131

133132
def assert_stateless(aggregator: Aggregator, matrix: Tensor) -> None:
@@ -140,4 +139,4 @@ def assert_stateless(aggregator: Aggregator, matrix: Tensor) -> None:
140139
first = aggregator(matrix)
141140
second = aggregator(matrix)
142141

143-
assert_allclose(first, second, atol=0.0, rtol=0.0)
142+
assert_close(first, second, atol=0.0, rtol=0.0)

0 commit comments

Comments
 (0)