|
1 | | -from typing import Literal |
2 | | - |
3 | | -import numpy as np |
4 | 1 | import torch |
5 | | -from qpsolvers import solve_qp |
6 | 2 | from torch import Tensor |
7 | 3 |
|
8 | 4 |
|
def project_weights(U: Tensor, G: Tensor, max_iter: int = 200, eps: float = 1e-08) -> Tensor:
    """
    Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the
    rows of a matrix whose Gramian is provided.

    Each weight vector `u` (a slice of `U` along its last dimension) is mapped to an approximate
    solution `v` of the quadratic program

        minimize    v^T G v
        subject to  u <= v  (element-wise)

    obtained by projected gradient descent, run on all the vectors of `U` simultaneously.

    :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`.
    :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite.
    :param max_iter: The maximum number of projected gradient iterations to run.
    :param eps: The convergence tolerance. Iterations stop as soon as the norm of the difference
        between two consecutive iterates (taken over all vectors at once) is smaller than `eps`.
    :return: A tensor of projection weights with the same shape as `U`, with the dtype and on the
        device of `G`.
    """

    shape = U.shape
    m = shape[-1]

    # Work with the dtype and on the device of G so that `G @ V` is always well defined. This is a
    # no-op when U and G already agree. One column of `U_matrix` per vector to project.
    U_matrix = U.to(device=G.device, dtype=G.dtype).reshape([-1, m]).T
    V = U_matrix.clone()

    # G is symmetric, so the dedicated symmetric solver is used: it returns real eigenvalues
    # directly. `.item()` brings the scalar back to the host once, before the loop.
    lambda_max = torch.linalg.eigvalsh(G).max().item()

    for t in range(1, max_iter + 1):
        # Step slightly smaller than 1 / lambda_max, approaching it as t grows.
        step_size = 1.0 / (lambda_max + 1.0 / t)

        # Gradient step followed by the projection onto the feasible set {v : u <= v}.
        V_new = torch.maximum(V - step_size * (G @ V), U_matrix)
        gap = (V_new - V).norm()

        # Always keep the most recent iterate, including on the iteration that converges.
        V = V_new
        if gap < eps:
            break

    return V.T.reshape(shape)
0 commit comments