Skip to content

Commit 8974877

Browse files
authored
fix(aggregation): Add fallback in NashMTL (#640)
1 parent e2f2b18 commit 8974877

File tree

3 files changed

+19
-4
lines changed

3 files changed

+19
-4
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user.
88

99
## [Unreleased]
1010

11+
### Fixed
12+
13+
- Added a fallback for when the inner optimization of `NashMTL` fails (which can happen, for example,
14+
on the matrix [[0., 0.], [0., 1.]]).
15+
1116
## [0.9.0] - 2026-02-24
1217

1318
### Added

src/torchjd/aggregation/_nash_mtl.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,10 @@ def _solve_optimization(self, gtg: np.ndarray) -> np.ndarray:
158158

159159
try:
160160
self.prob.solve(solver=cp.ECOS, warm_start=True, max_iters=100)
161-
except SolverError:
162-
# On macOS, this can happen with: Solver 'ECOS' failed.
161+
except (SolverError, ValueError):
162+
# On macOS, SolverError can happen with: Solver 'ECOS' failed.
163163
# No idea why. The corresponding matrix is of shape [9, 11] with rank 5.
164+
# ValueError happens with, for example, the matrix [[0., 0.], [0., 1.]].
164165
# Maybe other exceptions can happen in other cases.
165166
self.alpha_param.value = self.prvs_alpha_param.value
166167

tests/unit/aggregation/test_nash_mtl.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pytest import mark
22
from torch import Tensor
33
from torch.testing import assert_close
4-
from utils.tensors import ones_, randn_
4+
from utils.tensors import ones_, randn_, tensor_
55

66
try:
77
from torchjd.aggregation import NashMTL
@@ -19,6 +19,10 @@ def _make_aggregator(matrix: Tensor) -> NashMTL:
1919

2020

2121
standard_pairs = [(_make_aggregator(matrix), matrix) for matrix in nash_mtl_matrices]
22+
edge_case_matrices = [
23+
tensor_([[0.0, 0.0], [0.0, 1.0]]) # This leads to a (caught) ValueError in _solve_optimization.
24+
]
25+
edge_case_pairs = [(_make_aggregator(matrix), matrix) for matrix in edge_case_matrices]
2226
requires_grad_pairs = [(NashMTL(n_tasks=3), ones_(3, 5, requires_grad=True))]
2327

2428

@@ -27,8 +31,13 @@ def _make_aggregator(matrix: Tensor) -> NashMTL:
2731
@mark.filterwarnings(
2832
"ignore:Solution may be inaccurate.",
2933
"ignore:You are solving a parameterized problem that is not DPP.",
34+
"ignore:divide by zero encountered in divide",
35+
"ignore:divide by zero encountered in true_divide",
36+
"ignore:overflow encountered in divide",
37+
"ignore:overflow encountered in true_divide",
38+
"ignore:invalid value encountered in matmul",
3039
)
31-
@mark.parametrize(["aggregator", "matrix"], standard_pairs)
40+
@mark.parametrize(["aggregator", "matrix"], standard_pairs + edge_case_pairs)
3241
def test_expected_structure(aggregator: NashMTL, matrix: Tensor) -> None:
3342
assert_expected_structure(aggregator, matrix)
3443

0 commit comments

Comments
 (0)