Skip to content

Commit 84fb0bd

Browse files
committed
Implements a first version of projected gradient descent with decaying regularization.
1 parent a0f9173 commit 84fb0bd

4 files changed

Lines changed: 33 additions & 64 deletions

File tree

Lines changed: 17 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,30 @@
1-
from typing import Literal
2-
3-
import numpy as np
41
import torch
5-
from qpsolvers import solve_qp
62
from torch import Tensor
73

84

9-
def project_weights(U: Tensor, G: Tensor, max_iter: int = 200, eps: float = 1e-08) -> Tensor:
    r"""
    Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the
    dual cone of the rows of a matrix whose Gramian is provided.

    For each weight vector `u` (last axis of `U`), this solves the quadratic program
        minimize    v^T G v
        subject to  u \preceq v
    by projected gradient descent with a decaying regularization term in the step size: the
    projection onto the feasible set `{v : v >= u}` is an elementwise maximum, so each iteration
    is a single gradient step followed by `torch.maximum`.

    :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`.
    :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite.
    :param max_iter: Maximum number of projected gradient descent iterations.
    :param eps: Convergence threshold on the norm of the difference between consecutive iterates.
    :return: A tensor of projection weights with the same shape as `U`.
    """

    shape = U.shape
    m = shape[-1]
    # Flatten the leading dimensions: each column of V is one weight vector being optimized, so
    # all projections are carried out in parallel by a single matrix product per iteration.
    U_matrix = U.reshape([-1, m]).T
    V = U_matrix.clone()

    # G is required to be symmetric, so eigvalsh (real eigenvalues, ascending order) applies and
    # avoids the complex-valued general eigensolver. The .item() call synchronizes on the CPU.
    lambda_max = torch.linalg.eigvalsh(G)[-1].item()

    for t in range(1, max_iter + 1):
        # Decaying regularization: the step size increases towards 1 / lambda_max as t grows.
        step_size = 1.0 / (lambda_max + 1.0 / t)
        # Gradient step on v^T G v, then projection onto the feasible set {v : v >= u}.
        V_new = torch.maximum(V - step_size * (G @ V), U_matrix)
        gap = (V - V_new).norm()
        if gap < eps:
            break
        V = V_new

    return V.T.reshape(shape)

src/torchjd/aggregation/dualproj.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch import Tensor
44

55
from ._dual_cone_utils import project_weights
6-
from ._gramian_utils import compute_gramian, normalize, regularize
6+
from ._gramian_utils import compute_gramian, normalize
77
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
88
from .bases import _WeightedAggregator, _Weighting
99
from .mean import _MeanWeighting
@@ -100,6 +100,6 @@ def __init__(
100100

101101
def forward(self, matrix: Tensor) -> Tensor:
    """Computes the weight vector by projecting the weighting of `matrix` via its Gramian."""
    gramian = normalize(compute_gramian(matrix), self.norm_eps)
    return project_weights(self.weighting(matrix), gramian)

src/torchjd/aggregation/upgrad.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import Tensor
55

66
from ._dual_cone_utils import project_weights
7-
from ._gramian_utils import compute_gramian, normalize, regularize
7+
from ._gramian_utils import compute_gramian, normalize
88
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
99
from .bases import _WeightedAggregator, _Weighting
1010
from .mean import _MeanWeighting
@@ -96,6 +96,6 @@ def __init__(
9696

9797
def forward(self, matrix: Tensor) -> Tensor:
    """Computes the weight vector by projecting each row's weighting and summing the results."""
    gramian = normalize(compute_gramian(matrix), self.norm_eps)
    projected = project_weights(torch.diag(self.weighting(matrix)), gramian)
    return projected.sum(dim=0)

tests/unit/aggregation/test_dual_cone_utils.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
import numpy as np
21
import torch
32
from pytest import mark, raises
43
from torch.testing import assert_close
54

6-
from torchjd.aggregation._dual_cone_utils import _project_weight_vector, project_weights
5+
from torchjd.aggregation._dual_cone_utils import project_weights
76

87

98
@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
@@ -33,7 +32,7 @@ def test_solution_weights(shape: tuple[int, int]):
3332
G = J @ J.T
3433
u = torch.rand(shape[0])
3534

36-
w = project_weights(u, G, "quadprog")
35+
w = project_weights(u, G)
3736
dual_gap = w - u
3837

3938
# Dual feasibility
@@ -62,8 +61,8 @@ def test_scale_invariant(shape: tuple[int, int], scaling: float):
6261
G = J @ J.T
6362
u = torch.rand(shape[0])
6463

65-
w = project_weights(u, G, "quadprog")
66-
w_scaled = project_weights(u, scaling * G, "quadprog")
64+
w = project_weights(u, G)
65+
w_scaled = project_weights(u, scaling * G)
6766

6867
assert_close(w_scaled, w)
6968

@@ -81,16 +80,16 @@ def test_tensorization_shape(shape: tuple[int, ...]):
8180

8281
G = matrix @ matrix.T
8382

84-
W_tensor = project_weights(U_tensor, G, "quadprog")
85-
W_matrix = project_weights(U_matrix, G, "quadprog")
83+
W_tensor = project_weights(U_tensor, G)
84+
W_matrix = project_weights(U_matrix, G)
8685

8786
assert_close(W_matrix.reshape(shape), W_tensor)
8887

8988

90-
def test_project_weight_failure():
    """Tests that `project_weights` raises an error when the input G has too large values."""
    # NOTE(review): the projected-gradient implementation of `project_weights` does not obviously
    # raise ValueError for large G — confirm non-convergence is actually reported as an error.

    large_J = torch.randn(10, 100) * 1e5
    large_G = large_J @ large_J.T
    with raises(ValueError):
        project_weights(torch.ones(10), large_G)

0 commit comments

Comments
 (0)