Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/torchjd/aggregation/aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import torch
from torch import Tensor

from ._gramian_utils import compute_gramian
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
from .bases import _WeightedAggregator, _Weighting
from .mean import _MeanWeighting
Expand Down Expand Up @@ -91,16 +92,14 @@ def __init__(self, weighting: _Weighting):
def forward(self, matrix: Tensor) -> Tensor:
    """
    Returns the output of the wrapped weighting, transformed by the balance transformation
    computed from the Gramian of ``matrix``.

    :param matrix: The matrix passed both to the wrapped weighting and to ``compute_gramian``.
    """
    base_weights = self.weighting(matrix)
    gramian = compute_gramian(matrix)
    balance_transformation = self._compute_balance_transformation(gramian)
    return balance_transformation @ base_weights

@staticmethod
def _compute_balance_transformation(G: Tensor) -> Tensor:
M = G.T @ G

def _compute_balance_transformation(M: Tensor) -> Tensor:
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
rank = sum(lambda_ > tol)
Expand Down
11 changes: 7 additions & 4 deletions src/torchjd/aggregation/cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,16 @@ def __init__(self, c: float, norm_eps: float):
self.norm_eps = norm_eps

def forward(self, matrix: Tensor) -> Tensor:
    """
    Computes the Gramian of ``matrix`` and delegates the rest of the computation to
    ``_compute_from_gramian``.

    :param matrix: The matrix whose Gramian is used to compute the result.
    """
    return self._compute_from_gramian(compute_gramian(matrix))

def _compute_from_gramian(self, gramian: Tensor) -> Tensor:
U, S, _ = torch.svd(normalize(gramian, self.norm_eps))

reduced_matrix = U @ S.sqrt().diag()
reduced_array = reduced_matrix.cpu().detach().numpy().astype(np.float64)

dimension = matrix.shape[0]
dimension = gramian.shape[0]
reduced_g_0 = reduced_array.T @ np.ones(dimension) / dimension
sqrt_phi = self.c * np.linalg.norm(reduced_g_0, 2)

Expand All @@ -97,6 +100,6 @@ def forward(self, matrix: Tensor) -> Tensor:
# We are approximately on the pareto front
weight_array = np.zeros(dimension)

weights = torch.from_numpy(weight_array).to(device=matrix.device, dtype=matrix.dtype)
weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype)

return weights
10 changes: 8 additions & 2 deletions src/torchjd/aggregation/imtl_g.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from torch import Tensor

from ._gramian_utils import compute_gramian
from .bases import _WeightedAggregator, _Weighting


Expand Down Expand Up @@ -38,8 +39,13 @@ class _IMTLGWeighting(_Weighting):
"""

def forward(self, matrix: Tensor) -> Tensor:
    """
    Computes the Gramian of ``matrix`` and delegates the rest of the computation to
    ``_compute_from_gramian``.

    :param matrix: The matrix whose Gramian is used to compute the result.
    """
    return self._compute_from_gramian(compute_gramian(matrix))

@staticmethod
def _compute_from_gramian(gramian: Tensor) -> Tensor:
d = torch.sqrt(torch.diagonal(gramian))
v = torch.linalg.pinv(gramian) @ d
v_sum = v.sum()

if v_sum.abs() < 1e-12:
Expand Down
17 changes: 13 additions & 4 deletions src/torchjd/aggregation/krum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torch import Tensor
from torch.nn import functional as F

from ._gramian_utils import compute_gramian
from .bases import _WeightedAggregator, _Weighting


Expand Down Expand Up @@ -80,16 +81,24 @@ def __init__(self, n_byzantine: int, n_selected: int):

def forward(self, matrix: Tensor) -> Tensor:
    """
    Validates the shape of ``matrix``, then computes its Gramian and delegates the rest of the
    computation to ``_compute_from_gramian``.

    :param matrix: The matrix to check and whose Gramian is used to compute the result.
    """
    self._check_matrix_shape(matrix)
    return self._compute_from_gramian(compute_gramian(matrix))

def _compute_from_gramian(self, gramian: Tensor) -> Tensor:
    """
    Computes the selection weights from a Gramian.

    Pairwise distances are recovered from the Gramian through the identity
    ``||g_i - g_j||^2 = <g_i, g_i> + <g_j, g_j> - 2 <g_i, g_j>``. Each index is scored by the sum
    of its smallest distances to the others, and the ``n_selected`` indices with the lowest
    scores share the weight equally.

    :param gramian: Square matrix whose entry ``[i, j]`` is read as the inner product between
        items ``i`` and ``j``.
    :return: Vector of length ``gramian.shape[0]`` holding ``1 / n_selected`` at each selected
        index and ``0`` elsewhere, in the dtype of ``gramian``.
    """
    gradient_norms_squared = torch.diagonal(gramian)
    distances_squared = (
        gradient_norms_squared.unsqueeze(0) + gradient_norms_squared.unsqueeze(1) - 2 * gramian
    )
    # Round-off can make an entry slightly negative when two rows are (almost) identical. The
    # square root of such an entry would be NaN and corrupt the top-k selections below, so clamp
    # at zero first.
    distances = torch.sqrt(torch.clamp(distances_squared, min=0.0))

    n_rows = gramian.shape[0]
    n_closest = n_rows - self.n_byzantine - 2
    smallest_distances, _ = torch.topk(distances, k=n_closest + 1, largest=False)
    # topk returns sorted values; the first column is taken to be the zero distance of each item
    # to itself and is dropped.
    smallest_distances_excluding_self = smallest_distances[:, 1:]
    scores = smallest_distances_excluding_self.sum(dim=1)

    _, selected_indices = torch.topk(scores, k=self.n_selected, largest=False)
    one_hot_selected_indices = F.one_hot(selected_indices, num_classes=n_rows)
    weights = one_hot_selected_indices.sum(dim=0).to(dtype=gramian.dtype) / self.n_selected

    return weights

Expand Down
21 changes: 13 additions & 8 deletions src/torchjd/aggregation/mgda.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,19 @@ def __init__(self, epsilon: float, max_iters: int):
self.epsilon = epsilon
self.max_iters = max_iters

def _frank_wolfe_solver(self, matrix: Tensor) -> Tensor:
gramian = compute_gramian(matrix)
device = matrix.device
dtype = matrix.dtype

alpha = torch.ones(matrix.shape[0], device=device, dtype=dtype) / matrix.shape[0]
def _compute_from_gramian(self, gramian: Tensor) -> Tensor:
"""
This is the Frank-Wolfe solver in Algorithm 2 of `Multi-Task Learning as Multi-Objective
Optimization
<https://proceedings.neurips.cc/paper_files/paper/2018/file/432aca3a1e345e339f35a30c8f65edce-Paper.pdf>`_.
"""
device = gramian.device
dtype = gramian.dtype

alpha = torch.ones(gramian.shape[0], device=device, dtype=dtype) / gramian.shape[0]
for i in range(self.max_iters):
t = torch.argmin(gramian @ alpha)
e_t = torch.zeros(matrix.shape[0], device=device, dtype=dtype)
e_t = torch.zeros(gramian.shape[0], device=device, dtype=dtype)
e_t[t] = 1.0
a = alpha @ (gramian @ e_t)
b = alpha @ (gramian @ alpha)
Expand All @@ -81,5 +85,6 @@ def _frank_wolfe_solver(self, matrix: Tensor) -> Tensor:
return alpha

def forward(self, matrix: Tensor) -> Tensor:
    """
    Computes the Gramian of ``matrix`` and delegates the rest of the computation to
    ``_compute_from_gramian``.

    :param matrix: The matrix whose Gramian is used to compute the result.
    """
    return self._compute_from_gramian(compute_gramian(matrix))
18 changes: 11 additions & 7 deletions src/torchjd/aggregation/pcgrad.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from torch import Tensor

from ._gramian_utils import compute_gramian
from .bases import _WeightedAggregator, _Weighting


Expand Down Expand Up @@ -41,15 +42,18 @@ class _PCGradWeighting(_Weighting):

def forward(self, matrix: Tensor) -> Tensor:
    """
    Computes the Gramian of ``matrix`` (its pre-computed inner products) and delegates the rest
    of the computation to ``_compute_from_gramian``.

    :param matrix: The matrix whose Gramian is used to compute the result.
    """
    return self._compute_from_gramian(compute_gramian(matrix))

@staticmethod
def _compute_from_gramian(gramian: Tensor) -> Tensor:
# Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration
device = matrix.device
dtype = matrix.dtype
device = gramian.device
dtype = gramian.dtype
cpu = torch.device("cpu")
inner_products = inner_products.to(device=cpu)
gramian = gramian.to(device=cpu)

dimension = inner_products.shape[0]
dimension = gramian.shape[0]
weights = torch.zeros(dimension, device=cpu, dtype=dtype)

for i in range(dimension):
Expand All @@ -62,10 +66,10 @@ def forward(self, matrix: Tensor) -> Tensor:
continue

# Compute the inner product between g_i^{PC} and g_j
inner_product = inner_products[j] @ current_weights
inner_product = gramian[j] @ current_weights

if inner_product < 0.0:
current_weights[j] -= inner_product / (inner_products[j, j])
current_weights[j] -= inner_product / (gramian[j, j])

weights = weights + current_weights

Expand Down