11import torch
2- from numpy .ma .testutils import assert_allclose
32from pytest import raises
43from torch import Tensor
54from torch .testing import assert_close
@@ -126,8 +125,8 @@ def assert_stateful(aggregator: Aggregator, matrix: Tensor) -> None:
126125 aggregator .reset ()
127126 second_pair = (aggregator (matrix ), aggregator (matrix ))
128127
129- assert_allclose (first_pair [0 ], second_pair [0 ], atol = 0.0 , rtol = 0.0 )
130- assert_allclose (first_pair [1 ], second_pair [1 ], atol = 0.0 , rtol = 0.0 )
128+ assert_close (first_pair [0 ], second_pair [0 ], atol = 0.0 , rtol = 0.0 )
129+ assert_close (first_pair [1 ], second_pair [1 ], atol = 0.0 , rtol = 0.0 )
131130
132131
133132def assert_stateless (aggregator : Aggregator , matrix : Tensor ) -> None :
@@ -140,4 +139,4 @@ def assert_stateless(aggregator: Aggregator, matrix: Tensor) -> None:
140139 first = aggregator (matrix )
141140 second = aggregator (matrix )
142141
143- assert_allclose (first , second , atol = 0.0 , rtol = 0.0 )
142+ assert_close (first , second , atol = 0.0 , rtol = 0.0 )
0 commit comments