diff --git a/src/torchjd/aggregation/aligned_mtl.py b/src/torchjd/aggregation/aligned_mtl.py index 6f54a5ad..e8df27fe 100644 --- a/src/torchjd/aggregation/aligned_mtl.py +++ b/src/torchjd/aggregation/aligned_mtl.py @@ -28,6 +28,7 @@ import torch from torch import Tensor +from ._gramian_utils import compute_gramian from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting from .bases import _WeightedAggregator, _Weighting from .mean import _MeanWeighting @@ -91,16 +92,14 @@ def __init__(self, weighting: _Weighting): def forward(self, matrix: Tensor) -> Tensor: w = self.weighting(matrix) - G = matrix.T - B = self._compute_balance_transformation(G) + M = compute_gramian(matrix) + B = self._compute_balance_transformation(M) alpha = B @ w return alpha @staticmethod - def _compute_balance_transformation(G: Tensor) -> Tensor: - M = G.T @ G - + def _compute_balance_transformation(M: Tensor) -> Tensor: lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig tol = torch.max(lambda_) * len(M) * torch.finfo().eps rank = sum(lambda_ > tol) diff --git a/src/torchjd/aggregation/cagrad.py b/src/torchjd/aggregation/cagrad.py index c6c66788..93075fc0 100644 --- a/src/torchjd/aggregation/cagrad.py +++ b/src/torchjd/aggregation/cagrad.py @@ -72,13 +72,16 @@ def __init__(self, c: float, norm_eps: float): self.norm_eps = norm_eps def forward(self, matrix: Tensor) -> Tensor: - gramian = normalize(compute_gramian(matrix), self.norm_eps) - U, S, _ = torch.svd(gramian) + gramian = compute_gramian(matrix) + return self._compute_from_gramian(gramian) + + def _compute_from_gramian(self, gramian: Tensor) -> Tensor: + U, S, _ = torch.svd(normalize(gramian, self.norm_eps)) reduced_matrix = U @ S.sqrt().diag() reduced_array = reduced_matrix.cpu().detach().numpy().astype(np.float64) - dimension = matrix.shape[0] + dimension = gramian.shape[0] reduced_g_0 = reduced_array.T @ np.ones(dimension) / dimension sqrt_phi = self.c * 
np.linalg.norm(reduced_g_0, 2) @@ -97,6 +100,6 @@ def forward(self, matrix: Tensor) -> Tensor: # We are approximately on the pareto front weight_array = np.zeros(dimension) - weights = torch.from_numpy(weight_array).to(device=matrix.device, dtype=matrix.dtype) + weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype) return weights diff --git a/src/torchjd/aggregation/imtl_g.py b/src/torchjd/aggregation/imtl_g.py index f4fb8232..b807df6d 100644 --- a/src/torchjd/aggregation/imtl_g.py +++ b/src/torchjd/aggregation/imtl_g.py @@ -1,6 +1,7 @@ import torch from torch import Tensor +from ._gramian_utils import compute_gramian from .bases import _WeightedAggregator, _Weighting @@ -38,8 +39,13 @@ class _IMTLGWeighting(_Weighting): """ def forward(self, matrix: Tensor) -> Tensor: - d = torch.linalg.norm(matrix, dim=1) - v = torch.linalg.pinv(matrix @ matrix.T) @ d + gramian = compute_gramian(matrix) + return self._compute_from_gramian(gramian) + + @staticmethod + def _compute_from_gramian(gramian: Tensor) -> Tensor: + d = torch.sqrt(torch.diagonal(gramian)) + v = torch.linalg.pinv(gramian) @ d v_sum = v.sum() if v_sum.abs() < 1e-12: diff --git a/src/torchjd/aggregation/krum.py b/src/torchjd/aggregation/krum.py index acbe5410..8e39045e 100644 --- a/src/torchjd/aggregation/krum.py +++ b/src/torchjd/aggregation/krum.py @@ -2,6 +2,7 @@ from torch import Tensor from torch.nn import functional as F +from ._gramian_utils import compute_gramian from .bases import _WeightedAggregator, _Weighting @@ -80,16 +81,24 @@ def __init__(self, n_byzantine: int, n_selected: int): def forward(self, matrix: Tensor) -> Tensor: self._check_matrix_shape(matrix) + gramian = compute_gramian(matrix) + return self._compute_from_gramian(gramian) - distances = torch.cdist(matrix, matrix, compute_mode="donot_use_mm_for_euclid_dist") - n_closest = matrix.shape[0] - self.n_byzantine - 2 + def _compute_from_gramian(self, gramian: Tensor) -> Tensor: + gradient_norms_squared = 
torch.diagonal(gramian) + distances_squared = ( + gradient_norms_squared.unsqueeze(0) + gradient_norms_squared.unsqueeze(1) - 2 * gramian + ) + distances = torch.sqrt(distances_squared) + + n_closest = gramian.shape[0] - self.n_byzantine - 2 smallest_distances, _ = torch.topk(distances, k=n_closest + 1, largest=False) smallest_distances_excluding_self = smallest_distances[:, 1:] scores = smallest_distances_excluding_self.sum(dim=1) _, selected_indices = torch.topk(scores, k=self.n_selected, largest=False) - one_hot_selected_indices = F.one_hot(selected_indices, num_classes=matrix.shape[0]) - weights = one_hot_selected_indices.sum(dim=0).to(dtype=matrix.dtype) / self.n_selected + one_hot_selected_indices = F.one_hot(selected_indices, num_classes=gramian.shape[0]) + weights = one_hot_selected_indices.sum(dim=0).to(dtype=gramian.dtype) / self.n_selected return weights diff --git a/src/torchjd/aggregation/mgda.py b/src/torchjd/aggregation/mgda.py index f2548ce2..2a2f05e1 100644 --- a/src/torchjd/aggregation/mgda.py +++ b/src/torchjd/aggregation/mgda.py @@ -56,15 +56,19 @@ def __init__(self, epsilon: float, max_iters: int): self.epsilon = epsilon self.max_iters = max_iters - def _frank_wolfe_solver(self, matrix: Tensor) -> Tensor: - gramian = compute_gramian(matrix) - device = matrix.device - dtype = matrix.dtype - - alpha = torch.ones(matrix.shape[0], device=device, dtype=dtype) / matrix.shape[0] + def _compute_from_gramian(self, gramian: Tensor) -> Tensor: + """ + This is the Frank-Wolfe solver in Algorithm 2 of `Multi-Task Learning as Multi-Objective + Optimization + <https://arxiv.org/abs/1810.04650>`_. 
+ """ + device = gramian.device + dtype = gramian.dtype + + alpha = torch.ones(gramian.shape[0], device=device, dtype=dtype) / gramian.shape[0] for i in range(self.max_iters): t = torch.argmin(gramian @ alpha) - e_t = torch.zeros(matrix.shape[0], device=device, dtype=dtype) + e_t = torch.zeros(gramian.shape[0], device=device, dtype=dtype) e_t[t] = 1.0 a = alpha @ (gramian @ e_t) b = alpha @ (gramian @ alpha) @@ -81,5 +85,6 @@ def _frank_wolfe_solver(self, matrix: Tensor) -> Tensor: return alpha def forward(self, matrix: Tensor) -> Tensor: - weights = self._frank_wolfe_solver(matrix) + gramian = compute_gramian(matrix) + weights = self._compute_from_gramian(gramian) return weights diff --git a/src/torchjd/aggregation/pcgrad.py b/src/torchjd/aggregation/pcgrad.py index c110aa98..1b448fe3 100644 --- a/src/torchjd/aggregation/pcgrad.py +++ b/src/torchjd/aggregation/pcgrad.py @@ -1,6 +1,7 @@ import torch from torch import Tensor +from ._gramian_utils import compute_gramian from .bases import _WeightedAggregator, _Weighting @@ -41,15 +42,18 @@ class _PCGradWeighting(_Weighting): def forward(self, matrix: Tensor) -> Tensor: # Pre-compute the inner products - inner_products = matrix @ matrix.T + gramian = compute_gramian(matrix) + return self._compute_from_gramian(gramian) + @staticmethod + def _compute_from_gramian(gramian: Tensor) -> Tensor: # Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration - device = matrix.device - dtype = matrix.dtype + device = gramian.device + dtype = gramian.dtype cpu = torch.device("cpu") - inner_products = inner_products.to(device=cpu) + gramian = gramian.to(device=cpu) - dimension = inner_products.shape[0] + dimension = gramian.shape[0] weights = torch.zeros(dimension, device=cpu, dtype=dtype) for i in range(dimension): @@ -62,10 +66,10 @@ def forward(self, matrix: Tensor) -> Tensor: continue # Compute the inner product between g_i^{PC} and g_j - inner_product = inner_products[j] @ current_weights + 
inner_product = gramian[j] @ current_weights if inner_product < 0.0: - current_weights[j] -= inner_product / (inner_products[j, j]) + current_weights[j] -= inner_product / (gramian[j, j]) weights = weights + current_weights