
Commit 1034bbf

Merge branch 'main' into feature/gradvac (21f6b74)

2 parents: 315c264 + 21f6b74

5 files changed (42 additions, 41 deletions):

- CHANGELOG.md
- README.md
- src/torchjd/aggregation/_gradvac.py
- tests/unit/aggregation/test_gradvac.py
- tests/unit/aggregation/test_values.py

CHANGELOG.md

Lines changed: 3 additions & 1 deletion

@@ -8,8 +8,10 @@ changelog does not include internal changes that do not affect the user.
 
 ## [Unreleased]
 
-### Fixed
+### Added
 
+- Added `GradVac` and `GradVacWeighting` from
+  [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874).
 - Added a fallback for when the inner optimization of `NashMTL` fails (which can happen for example
   on the matrix [[0., 0.], [0., 1.]]).
 
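To make the new entry concrete, here is a minimal usage sketch of `GradVac`, assembled purely from the tests in this commit (it assumes an installed `torchjd`; the import path follows the documentation links in the README table below):

```python
import torch

from torchjd.aggregation import GradVac

# One gradient row per task; these two task gradients are orthogonal.
J = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

torch.manual_seed(0)  # the tests below seed the RNG, so GradVac's output depends on it
A = GradVac(beta=0.5)  # beta must lie in [0, 1]; eps defaults to 1e-08
print(A(J))  # the aggregated update vector

A.reset()  # clears the internal state, restoring first-step behavior
```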

README.md

Lines changed: 1 addition & 0 deletions

@@ -281,6 +281,7 @@ TorchJD provides many existing aggregators from the literature, listed in the following table.
 | [Constant](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.Constant) | [ConstantWeighting](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.ConstantWeighting) | - |
 | [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProj) | [DualProjWeighting](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProjWeighting) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) |
 | [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop#torchjd.aggregation.GradDrop) | - | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) |
+| [GradVac](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVac) | [GradVacWeighting](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVacWeighting) | [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874) |
 | [IMTLG](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLG) | [IMTLGWeighting](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLGWeighting) | [Towards Impartial Multi-task Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) |
 | [Krum](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.Krum) | [KrumWeighting](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.KrumWeighting) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) |
 | [Mean](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.Mean) | [MeanWeighting](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.MeanWeighting) | - |
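Every aggregator in this table is called the same way, so `GradVac` can be swapped in for a baseline with a one-line change. A hedged sketch of that swap, assuming `Mean` takes no required constructor arguments:

```python
import torch

from torchjd.aggregation import GradVac, Mean  # both listed in the table above

J = torch.randn(4, 10)  # Jacobian: one gradient row per objective

torch.manual_seed(0)
print(Mean()(J))     # uniform-averaging baseline
print(GradVac()(J))  # gradient-vaccine adjusted aggregation
```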

src/torchjd/aggregation/_gradvac.py

Lines changed: 1 addition & 0 deletions

@@ -151,6 +151,7 @@ def reset(self) -> None:
         self._state_key = None
 
     def forward(self, gramian: PSDMatrix, /) -> Tensor:
+        # Do all computations on cpu to avoid moving memory between cpu and gpu at each iteration
         device = gramian.device
         dtype = gramian.dtype
         cpu = torch.device("cpu")
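The added comment describes a device round-trip pattern rather than new logic. Below is a self-contained sketch of that pattern; `forward_on_cpu` and its body are illustrative placeholders, not the library's actual weighting code:

```python
import torch
from torch import Tensor


def forward_on_cpu(gramian: Tensor) -> Tensor:
    # Remember where (and in which dtype) the caller's Gramian lives.
    device = gramian.device
    dtype = gramian.dtype
    cpu = torch.device("cpu")

    # Do the per-iteration work on cpu so intermediate tensors never
    # bounce between cpu and gpu.
    g = gramian.to(device=cpu)
    weights = g.sum(dim=1) / g.sum().abs().clamp_min(1e-8)  # placeholder weighting

    # Move the (small) result back to the caller's device exactly once.
    return weights.to(device=device, dtype=dtype)
```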

tests/unit/aggregation/test_gradvac.py

Lines changed: 37 additions & 37 deletions

@@ -15,9 +15,9 @@
 
 
 def test_representations() -> None:
-    g = GradVac()
-    assert repr(g) == "GradVac(beta=0.5, eps=1e-08)"
-    assert str(g) == "GradVac"
+    A = GradVac()
+    assert repr(A) == "GradVac(beta=0.5, eps=1e-08)"
+    assert str(A) == "GradVac"
 
 
 def test_beta_out_of_range() -> None:
@@ -28,17 +28,17 @@ def test_beta_out_of_range() -> None:
 
 
 def test_beta_setter_out_of_range() -> None:
-    g = GradVac()
+    A = GradVac()
     with raises(ValueError, match="beta"):
-        g.beta = -0.1
+        A.beta = -0.1
     with raises(ValueError, match="beta"):
-        g.beta = 1.1
+        A.beta = 1.1
 
 
 def test_beta_setter_updates_value() -> None:
-    g = GradVac()
-    g.beta = 0.25
-    assert g.beta == 0.25
+    A = GradVac()
+    A.beta = 0.25
+    assert A.beta == 0.25
 
 
 def test_eps_rejects_negative() -> None:
@@ -47,19 +47,19 @@ def test_eps_rejects_negative() -> None:
 
 
 def test_eps_setter_rejects_negative() -> None:
-    g = GradVac()
+    A = GradVac()
    with raises(ValueError, match="eps"):
-        g.eps = -1e-9
+        A.eps = -1e-9
 
 
 def test_eps_can_be_changed_between_steps() -> None:
-    j = tensor_([[1.0, 0.0], [0.0, 1.0]])
-    agg = GradVac()
-    agg.eps = 1e-6
-    assert agg(j).isfinite().all()
-    agg.reset()
-    agg.eps = 1e-10
-    assert agg(j).isfinite().all()
+    J = tensor_([[1.0, 0.0], [0.0, 1.0]])
+    A = GradVac()
+    A.eps = 1e-6
+    assert A(J).isfinite().all()
+    A.reset()
+    A.eps = 1e-10
+    assert A(J).isfinite().all()
 
 
 def test_zero_rows_returns_zero_vector() -> None:
@@ -73,25 +73,25 @@ def test_zero_columns_returns_zero_vector() -> None:
 
 
 def test_reproducible_with_manual_seed() -> None:
-    j = randn_((3, 8))
+    J = randn_((3, 8))
     torch.manual_seed(12345)
-    a1 = GradVac(beta=0.3)
-    out1 = a1(j)
+    A1 = GradVac(beta=0.3)
+    out1 = A1(J)
     torch.manual_seed(12345)
-    a2 = GradVac(beta=0.3)
-    out2 = a2(j)
+    A2 = GradVac(beta=0.3)
+    out2 = A2(J)
     assert_close(out1, out2)
 
 
 @mark.parametrize("matrix", typical_matrices_2_plus_rows)
 def test_reset_restores_first_step_behavior(matrix: Tensor) -> None:
     torch.manual_seed(7)
-    agg = GradVac(beta=0.5)
-    first = agg(matrix)
-    agg(matrix)
-    agg.reset()
+    A = GradVac(beta=0.5)
+    first = A(matrix)
+    A(matrix)
+    A.reset()
     torch.manual_seed(7)
-    assert_close(first, agg(matrix))
+    assert_close(first, A(matrix))
 
 
 @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
@@ -117,8 +117,8 @@ def test_weighting_eps_rejects_negative() -> None:
 
 
 def test_weighting_reset_restores_first_step_behavior() -> None:
-    j = randn_((3, 8))
-    G = j @ j.T
+    J = randn_((3, 8))
+    G = J @ J.T
     torch.manual_seed(7)
     w = GradVacWeighting(beta=0.5)
     first = w(G)
@@ -131,16 +131,16 @@ def test_weighting_reset_restores_first_step_behavior() -> None:
 def test_aggregator_and_weighting_agree() -> None:
     """GradVac()(J) == GradVacWeighting()(J @ J.T) @ J for any matrix J."""
 
-    j = randn_((3, 8))
-    G = j @ j.T
+    J = randn_((3, 8))
+    G = J @ J.T
 
     torch.manual_seed(42)
-    agg = GradVac(beta=0.3)
-    expected = agg(j)
+    A = GradVac(beta=0.3)
+    expected = A(J)
 
     torch.manual_seed(42)
-    weighting = GradVacWeighting(beta=0.3)
-    weights = weighting(G)
-    result = weights @ j
+    W = GradVacWeighting(beta=0.3)
+    weights = W(G)
+    result = weights @ J
 
     assert_close(result, expected, rtol=1e-4, atol=1e-4)
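The docstring of `test_aggregator_and_weighting_agree` states the key contract: `GradVac()(J) == GradVacWeighting()(J @ J.T) @ J`. A standalone version of that check, assuming an installed `torchjd` (the double seeding mirrors the test, since both objects must see identical randomness):

```python
import torch
from torch.testing import assert_close

from torchjd.aggregation import GradVac, GradVacWeighting

J = torch.randn(3, 8)  # 3 objectives, 8 parameters

torch.manual_seed(42)
expected = GradVac(beta=0.3)(J)  # aggregate the Jacobian directly

torch.manual_seed(42)  # re-seed so the weighting draws the same random numbers
weights = GradVacWeighting(beta=0.3)(J @ J.T)  # weight from the Gramian instead

assert_close(weights @ J, expected, rtol=1e-4, atol=1e-4)
```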

tests/unit/aggregation/test_values.py

Lines changed: 0 additions & 3 deletions

@@ -1,4 +1,3 @@
-import torch
 from pytest import mark, param
 from torch import Tensor, tensor
 from torch.testing import assert_close
@@ -118,8 +117,6 @@
 def test_aggregator_output(A: Aggregator, J: Tensor, expected_output: Tensor) -> None:
     """Test that the output values of an aggregator are fixed (on cpu)."""
 
-    if isinstance(A, GradVac):
-        torch.manual_seed(0)
     assert_close(A(J), expected_output, rtol=0, atol=1e-4)
 
 
