Skip to content

Commit 21f6b74

Browse files
authored
Merge branch 'main' into feature/gradvac
2 parents 0ec471c + 8974877 commit 21f6b74

3 files changed

Lines changed: 18 additions & 5 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ changelog does not include internal changes that do not affect the user.
1010

1111
### Added
1212

13-
- Added `GradVac` and `GradVacWeighting` from [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874).
13+
- Added `GradVac` and `GradVacWeighting` from
14+
[Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874).
15+
- Added a fallback for when the inner optimization of `NashMTL` fails (which can happen for example
16+
on the matrix [[0., 0.], [0., 1.]]).
1417

1518
## [0.9.0] - 2026-02-24
1619

src/torchjd/aggregation/_nash_mtl.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,10 @@ def _solve_optimization(self, gtg: np.ndarray) -> np.ndarray:
158158

159159
try:
160160
self.prob.solve(solver=cp.ECOS, warm_start=True, max_iters=100)
161-
except SolverError:
162-
# On macOS, this can happen with: Solver 'ECOS' failed.
161+
except (SolverError, ValueError):
162+
# On macOS, SolverError can happen with: Solver 'ECOS' failed.
163163
# No idea why. The corresponding matrix is of shape [9, 11] with rank 5.
164+
# ValueError happens with for example matrix [[0., 0.], [0., 1.]].
164165
# Maybe other exceptions can happen in other cases.
165166
self.alpha_param.value = self.prvs_alpha_param.value
166167

tests/unit/aggregation/test_nash_mtl.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pytest import mark
22
from torch import Tensor
33
from torch.testing import assert_close
4-
from utils.tensors import ones_, randn_
4+
from utils.tensors import ones_, randn_, tensor_
55

66
try:
77
from torchjd.aggregation import NashMTL
@@ -19,6 +19,10 @@ def _make_aggregator(matrix: Tensor) -> NashMTL:
1919

2020

2121
standard_pairs = [(_make_aggregator(matrix), matrix) for matrix in nash_mtl_matrices]
22+
edge_case_matrices = [
23+
tensor_([[0.0, 0.0], [0.0, 1.0]]) # This leads to a (caught) ValueError in _solve_optimization.
24+
]
25+
edge_case_pairs = [(_make_aggregator(matrix), matrix) for matrix in edge_case_matrices]
2226
requires_grad_pairs = [(NashMTL(n_tasks=3), ones_(3, 5, requires_grad=True))]
2327

2428

@@ -27,8 +31,13 @@ def _make_aggregator(matrix: Tensor) -> NashMTL:
2731
@mark.filterwarnings(
2832
"ignore:Solution may be inaccurate.",
2933
"ignore:You are solving a parameterized problem that is not DPP.",
34+
"ignore:divide by zero encountered in divide",
35+
"ignore:divide by zero encountered in true_divide",
36+
"ignore:overflow encountered in divide",
37+
"ignore:overflow encountered in true_divide",
38+
"ignore:invalid value encountered in matmul",
3039
)
31-
@mark.parametrize(["aggregator", "matrix"], standard_pairs)
40+
@mark.parametrize(["aggregator", "matrix"], standard_pairs + edge_case_pairs)
3241
def test_expected_structure(aggregator: NashMTL, matrix: Tensor) -> None:
3342
assert_expected_structure(aggregator, matrix)
3443

0 commit comments

Comments
 (0)