11from pytest import mark
22from torch import Tensor
33from torch .testing import assert_close
4- from utils .tensors import ones_ , randn_
4+ from utils .tensors import ones_ , randn_ , tensor_
55
66try :
77 from torchjd .aggregation import NashMTL
@@ -19,6 +19,10 @@ def _make_aggregator(matrix: Tensor) -> NashMTL:
1919
2020
2121standard_pairs = [(_make_aggregator (matrix ), matrix ) for matrix in nash_mtl_matrices ]
22+ edge_case_matrices = [
23+ tensor_ ([[0.0 , 0.0 ], [0.0 , 1.0 ]]) # This leads to a (caught) ValueError in _solve_optimization.
24+ ]
25+ edge_case_pairs = [(_make_aggregator (matrix ), matrix ) for matrix in edge_case_matrices ]
2226requires_grad_pairs = [(NashMTL (n_tasks = 3 ), ones_ (3 , 5 , requires_grad = True ))]
2327
2428
@@ -27,8 +31,13 @@ def _make_aggregator(matrix: Tensor) -> NashMTL:
2731@mark .filterwarnings (
2832 "ignore:Solution may be inaccurate." ,
2933 "ignore:You are solving a parameterized problem that is not DPP." ,
34+ "ignore:divide by zero encountered in divide" ,
35+ "ignore:divide by zero encountered in true_divide" ,
36+ "ignore:overflow encountered in divide" ,
37+ "ignore:overflow encountered in true_divide" ,
38+ "ignore:invalid value encountered in matmul" ,
3039)
31- @mark .parametrize (["aggregator" , "matrix" ], standard_pairs )
40+ @mark .parametrize (["aggregator" , "matrix" ], standard_pairs + edge_case_pairs )
3241def test_expected_structure (aggregator : NashMTL , matrix : Tensor ) -> None :
3342 assert_expected_structure (aggregator , matrix )
3443
0 commit comments