Skip to content

Commit aa87f95

Browse files
refactor(aggregation) Explicit dependence on gramians (#280)
* Make dependence on gramian explicit in AlignedMTL
* Make dependence on gramian explicit in CAGrad
* Make dependence on gramian explicit in IMTL-G
* Make dependence on gramian explicit in Krum
* Make dependence on gramian explicit in MGDA
* Make dependence on gramian explicit in PCGrad
1 parent d202cfb commit aa87f95

File tree

6 files changed

+56
-30
lines changed

6 files changed

+56
-30
lines changed

src/torchjd/aggregation/aligned_mtl.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import torch
2929
from torch import Tensor
3030

31+
from ._gramian_utils import compute_gramian
3132
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
3233
from .bases import _WeightedAggregator, _Weighting
3334
from .mean import _MeanWeighting
@@ -91,16 +92,14 @@ def __init__(self, weighting: _Weighting):
9192
def forward(self, matrix: Tensor) -> Tensor:
9293
w = self.weighting(matrix)
9394

94-
G = matrix.T
95-
B = self._compute_balance_transformation(G)
95+
M = compute_gramian(matrix)
96+
B = self._compute_balance_transformation(M)
9697
alpha = B @ w
9798

9899
return alpha
99100

100101
@staticmethod
101-
def _compute_balance_transformation(G: Tensor) -> Tensor:
102-
M = G.T @ G
103-
102+
def _compute_balance_transformation(M: Tensor) -> Tensor:
104103
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
105104
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
106105
rank = sum(lambda_ > tol)

src/torchjd/aggregation/cagrad.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,16 @@ def __init__(self, c: float, norm_eps: float):
7272
self.norm_eps = norm_eps
7373

7474
def forward(self, matrix: Tensor) -> Tensor:
75-
gramian = normalize(compute_gramian(matrix), self.norm_eps)
76-
U, S, _ = torch.svd(gramian)
75+
gramian = compute_gramian(matrix)
76+
return self._compute_from_gramian(gramian)
77+
78+
def _compute_from_gramian(self, gramian: Tensor) -> Tensor:
79+
U, S, _ = torch.svd(normalize(gramian, self.norm_eps))
7780

7881
reduced_matrix = U @ S.sqrt().diag()
7982
reduced_array = reduced_matrix.cpu().detach().numpy().astype(np.float64)
8083

81-
dimension = matrix.shape[0]
84+
dimension = gramian.shape[0]
8285
reduced_g_0 = reduced_array.T @ np.ones(dimension) / dimension
8386
sqrt_phi = self.c * np.linalg.norm(reduced_g_0, 2)
8487

@@ -97,6 +100,6 @@ def forward(self, matrix: Tensor) -> Tensor:
97100
# We are approximately on the pareto front
98101
weight_array = np.zeros(dimension)
99102

100-
weights = torch.from_numpy(weight_array).to(device=matrix.device, dtype=matrix.dtype)
103+
weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype)
101104

102105
return weights

src/torchjd/aggregation/imtl_g.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
from torch import Tensor
33

4+
from ._gramian_utils import compute_gramian
45
from .bases import _WeightedAggregator, _Weighting
56

67

@@ -38,8 +39,13 @@ class _IMTLGWeighting(_Weighting):
3839
"""
3940

4041
def forward(self, matrix: Tensor) -> Tensor:
41-
d = torch.linalg.norm(matrix, dim=1)
42-
v = torch.linalg.pinv(matrix @ matrix.T) @ d
42+
gramian = compute_gramian(matrix)
43+
return self._compute_from_gramian(gramian)
44+
45+
@staticmethod
46+
def _compute_from_gramian(gramian: Tensor) -> Tensor:
47+
d = torch.sqrt(torch.diagonal(gramian))
48+
v = torch.linalg.pinv(gramian) @ d
4349
v_sum = v.sum()
4450

4551
if v_sum.abs() < 1e-12:

src/torchjd/aggregation/krum.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from torch import Tensor
33
from torch.nn import functional as F
44

5+
from ._gramian_utils import compute_gramian
56
from .bases import _WeightedAggregator, _Weighting
67

78

@@ -80,16 +81,24 @@ def __init__(self, n_byzantine: int, n_selected: int):
8081

8182
def forward(self, matrix: Tensor) -> Tensor:
8283
self._check_matrix_shape(matrix)
84+
gramian = compute_gramian(matrix)
85+
return self._compute_from_gramian(gramian)
8386

84-
distances = torch.cdist(matrix, matrix, compute_mode="donot_use_mm_for_euclid_dist")
85-
n_closest = matrix.shape[0] - self.n_byzantine - 2
87+
def _compute_from_gramian(self, gramian: Tensor) -> Tensor:
88+
gradient_norms_squared = torch.diagonal(gramian)
89+
distances_squared = (
90+
gradient_norms_squared.unsqueeze(0) + gradient_norms_squared.unsqueeze(1) - 2 * gramian
91+
)
92+
distances = torch.sqrt(distances_squared)
93+
94+
n_closest = gramian.shape[0] - self.n_byzantine - 2
8695
smallest_distances, _ = torch.topk(distances, k=n_closest + 1, largest=False)
8796
smallest_distances_excluding_self = smallest_distances[:, 1:]
8897
scores = smallest_distances_excluding_self.sum(dim=1)
8998

9099
_, selected_indices = torch.topk(scores, k=self.n_selected, largest=False)
91-
one_hot_selected_indices = F.one_hot(selected_indices, num_classes=matrix.shape[0])
92-
weights = one_hot_selected_indices.sum(dim=0).to(dtype=matrix.dtype) / self.n_selected
100+
one_hot_selected_indices = F.one_hot(selected_indices, num_classes=gramian.shape[0])
101+
weights = one_hot_selected_indices.sum(dim=0).to(dtype=gramian.dtype) / self.n_selected
93102

94103
return weights
95104

src/torchjd/aggregation/mgda.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,19 @@ def __init__(self, epsilon: float, max_iters: int):
5656
self.epsilon = epsilon
5757
self.max_iters = max_iters
5858

59-
def _frank_wolfe_solver(self, matrix: Tensor) -> Tensor:
60-
gramian = compute_gramian(matrix)
61-
device = matrix.device
62-
dtype = matrix.dtype
63-
64-
alpha = torch.ones(matrix.shape[0], device=device, dtype=dtype) / matrix.shape[0]
59+
def _compute_from_gramian(self, gramian: Tensor) -> Tensor:
60+
"""
61+
This is the Frank-Wolfe solver in Algorithm 2 of `Multi-Task Learning as Multi-Objective
62+
Optimization
63+
<https://proceedings.neurips.cc/paper_files/paper/2018/file/432aca3a1e345e339f35a30c8f65edce-Paper.pdf>`_.
64+
"""
65+
device = gramian.device
66+
dtype = gramian.dtype
67+
68+
alpha = torch.ones(gramian.shape[0], device=device, dtype=dtype) / gramian.shape[0]
6569
for i in range(self.max_iters):
6670
t = torch.argmin(gramian @ alpha)
67-
e_t = torch.zeros(matrix.shape[0], device=device, dtype=dtype)
71+
e_t = torch.zeros(gramian.shape[0], device=device, dtype=dtype)
6872
e_t[t] = 1.0
6973
a = alpha @ (gramian @ e_t)
7074
b = alpha @ (gramian @ alpha)
@@ -81,5 +85,6 @@ def _frank_wolfe_solver(self, matrix: Tensor) -> Tensor:
8185
return alpha
8286

8387
def forward(self, matrix: Tensor) -> Tensor:
84-
weights = self._frank_wolfe_solver(matrix)
88+
gramian = compute_gramian(matrix)
89+
weights = self._compute_from_gramian(gramian)
8590
return weights

src/torchjd/aggregation/pcgrad.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
from torch import Tensor
33

4+
from ._gramian_utils import compute_gramian
45
from .bases import _WeightedAggregator, _Weighting
56

67

@@ -41,15 +42,18 @@ class _PCGradWeighting(_Weighting):
4142

4243
def forward(self, matrix: Tensor) -> Tensor:
4344
# Pre-compute the inner products
44-
inner_products = matrix @ matrix.T
45+
gramian = compute_gramian(matrix)
46+
return self._compute_from_gramian(gramian)
4547

48+
@staticmethod
49+
def _compute_from_gramian(gramian: Tensor) -> Tensor:
4650
# Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration
47-
device = matrix.device
48-
dtype = matrix.dtype
51+
device = gramian.device
52+
dtype = gramian.dtype
4953
cpu = torch.device("cpu")
50-
inner_products = inner_products.to(device=cpu)
54+
gramian = gramian.to(device=cpu)
5155

52-
dimension = inner_products.shape[0]
56+
dimension = gramian.shape[0]
5357
weights = torch.zeros(dimension, device=cpu, dtype=dtype)
5458

5559
for i in range(dimension):
@@ -62,10 +66,10 @@ def forward(self, matrix: Tensor) -> Tensor:
6266
continue
6367

6468
# Compute the inner product between g_i^{PC} and g_j
65-
inner_product = inner_products[j] @ current_weights
69+
inner_product = gramian[j] @ current_weights
6670

6771
if inner_product < 0.0:
68-
current_weights[j] -= inner_product / (inner_products[j, j])
72+
current_weights[j] -= inner_product / (gramian[j, j])
6973

7074
weights = weights + current_weights
7175

0 commit comments

Comments (0)