Skip to content

Commit 8653929

Browse files
feat(aggregation): Add scale_mode parameter to AlignedMTL (#527)
* Add scale_mode parameter to AlignedMTL and AlignedMTLWeighting
* Add test coverage for this
* Add changelog entry
1 parent 2d7bf7f commit 8653929

File tree

3 files changed

+65
-12
lines changed

3 files changed

+65
-12
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user.
88

99
## [Unreleased]
1010

11+
### Added
12+
13+
- Added a `scale_mode` parameter to `AlignedMTL` and `AlignedMTLWeighting`, allowing users to choose
14+
between `"min"`, `"median"`, and `"rmse"` scaling.
15+
1116
### Changed
1217

1318
- **BREAKING**: Removed from `backward` and `mtl_backward` the responsibility to aggregate the

src/torchjd/aggregation/_aligned_mtl.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
# SOFTWARE.
2626

2727

28+
from typing import Literal
29+
2830
import torch
2931
from torch import Tensor
3032

@@ -44,18 +46,29 @@ class AlignedMTL(GramianWeightedAggregator):
4446
4547
:param pref_vector: The preference vector to use. If not provided, defaults to
4648
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
49+
:param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses
50+
the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"``
51+
uses the mean eigenvalue (as in the original implementation).
4752
4853
.. note::
4954
This implementation was adapted from the `official implementation
5055
<https://github.com/SamsungLabs/MTL/tree/master/code/optim/aligned>`_.
5156
"""
5257

53-
def __init__(self, pref_vector: Tensor | None = None):
58+
def __init__(
59+
self,
60+
pref_vector: Tensor | None = None,
61+
scale_mode: Literal["min", "median", "rmse"] = "min",
62+
):
5463
self._pref_vector = pref_vector
55-
super().__init__(AlignedMTLWeighting(pref_vector))
64+
self._scale_mode = scale_mode
65+
super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode))
5666

5767
def __repr__(self) -> str:
58-
return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"
68+
return (
69+
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, "
70+
f"scale_mode={repr(self._scale_mode)})"
71+
)
5972

6073
def __str__(self) -> str:
6174
return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}"
@@ -68,22 +81,32 @@ class AlignedMTLWeighting(Weighting[PSDMatrix]):
6881
6982
:param pref_vector: The preference vector to use. If not provided, defaults to
7083
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
84+
:param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses
85+
the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"``
86+
uses the mean eigenvalue (as in the original implementation).
7187
"""
7288

73-
def __init__(self, pref_vector: Tensor | None = None):
89+
def __init__(
90+
self,
91+
pref_vector: Tensor | None = None,
92+
scale_mode: Literal["min", "median", "rmse"] = "min",
93+
):
7494
super().__init__()
7595
self._pref_vector = pref_vector
96+
self._scale_mode = scale_mode
7697
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
7798

7899
def forward(self, gramian: PSDMatrix) -> Tensor:
79100
w = self.weighting(gramian)
80-
B = self._compute_balance_transformation(gramian)
101+
B = self._compute_balance_transformation(gramian, self._scale_mode)
81102
alpha = B @ w
82103

83104
return alpha
84105

85106
@staticmethod
86-
def _compute_balance_transformation(M: Tensor) -> Tensor:
107+
def _compute_balance_transformation(
108+
M: Tensor, scale_mode: Literal["min", "median", "rmse"] = "min"
109+
) -> Tensor:
87110
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
88111
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
89112
rank = sum(lambda_ > tol)
@@ -96,6 +119,17 @@ def _compute_balance_transformation(M: Tensor) -> Tensor:
96119
lambda_, V = lambda_[order][:rank], V[:, order][:, :rank]
97120

98121
sigma_inv = torch.diag(1 / lambda_.sqrt())
99-
lambda_R = lambda_[-1]
100-
B = lambda_R.sqrt() * V @ sigma_inv @ V.T
122+
123+
if scale_mode == "min":
124+
scale = lambda_[-1]
125+
elif scale_mode == "median":
126+
scale = torch.median(lambda_)
127+
elif scale_mode == "rmse":
128+
scale = lambda_.mean()
129+
else:
130+
raise ValueError(
131+
f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'."
132+
)
133+
134+
B = scale.sqrt() * V @ sigma_inv @ V.T
101135
return B

tests/unit/aggregation/test_aligned_mtl.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
import torch
2-
from pytest import mark
2+
from pytest import mark, raises
33
from torch import Tensor
4+
from utils.tensors import ones_
45

56
from torchjd.aggregation import AlignedMTL
67

78
from ._asserts import assert_expected_structure, assert_permutation_invariant
89
from ._inputs import scaled_matrices, typical_matrices
910

10-
scaled_pairs = [(AlignedMTL(), matrix) for matrix in scaled_matrices]
11+
aggregators = [
12+
AlignedMTL(),
13+
AlignedMTL(scale_mode="median"),
14+
AlignedMTL(scale_mode="rmse"),
15+
]
16+
scaled_pairs = [(aggregator, matrix) for aggregator in aggregators for matrix in scaled_matrices]
17+
# Only the default scale_mode is used below: test_permutation_invariant seems to fail on GPU for scale_mode="median" or scale_mode="rmse".
1118
typical_pairs = [(AlignedMTL(), matrix) for matrix in typical_matrices]
1219

1320

@@ -23,9 +30,16 @@ def test_permutation_invariant(aggregator: AlignedMTL, matrix: Tensor):
2330

2431
def test_representations():
2532
A = AlignedMTL(pref_vector=None)
26-
assert repr(A) == "AlignedMTL(pref_vector=None)"
33+
assert repr(A) == "AlignedMTL(pref_vector=None, scale_mode='min')"
2734
assert str(A) == "AlignedMTL"
2835

2936
A = AlignedMTL(pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"))
30-
assert repr(A) == "AlignedMTL(pref_vector=tensor([1., 2., 3.]))"
37+
assert repr(A) == "AlignedMTL(pref_vector=tensor([1., 2., 3.]), scale_mode='min')"
3138
assert str(A) == "AlignedMTL([1., 2., 3.])"
39+
40+
41+
def test_invalid_scale_mode():
42+
aggregator = AlignedMTL(scale_mode="test") # type: ignore[arg-type]
43+
matrix = ones_(3, 4)
44+
with raises(ValueError, match=r"Invalid scale_mode=.*Expected"):
45+
aggregator(matrix)

0 commit comments

Comments
 (0)