Skip to content

Commit c02ebfb

Browse files
committed
feat: Add scale_mode options to AlignedMTL
Expose the "median" and "rmse" scaling modes for the balance transformation to match the behavior of the original implementation, while keeping "min" as the default.
1 parent 2d7bf7f commit c02ebfb

File tree

2 files changed

+44
-10
lines changed

2 files changed

+44
-10
lines changed

src/torchjd/aggregation/_aligned_mtl.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
# SOFTWARE.
2626

2727

28+
from typing import Literal
29+
2830
import torch
2931
from torch import Tensor
3032

@@ -44,18 +46,29 @@ class AlignedMTL(GramianWeightedAggregator):
4446
4547
:param pref_vector: The preference vector to use. If not provided, defaults to
4648
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
49+
:param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses
50+
the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"``
51+
uses the mean eigenvalue, so that the applied scale factor is the root mean square of the
square roots of the eigenvalues (as in the original implementation).
4752
4853
.. note::
4954
This implementation was adapted from the `official implementation
5055
<https://github.com/SamsungLabs/MTL/tree/master/code/optim/aligned>`_.
5156
"""
5257

53-
def __init__(self, pref_vector: Tensor | None = None):
58+
def __init__(
59+
self,
60+
pref_vector: Tensor | None = None,
61+
scale_mode: Literal["min", "median", "rmse"] = "min",
62+
):
5463
self._pref_vector = pref_vector
55-
super().__init__(AlignedMTLWeighting(pref_vector))
64+
self._scale_mode = scale_mode
65+
super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode))
5666

5767
def __repr__(self) -> str:
58-
return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"
68+
return (
69+
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, "
70+
f"scale_mode={repr(self._scale_mode)})"
71+
)
5972

6073
def __str__(self) -> str:
6174
return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}"
@@ -68,22 +81,32 @@ class AlignedMTLWeighting(Weighting[PSDMatrix]):
6881
6982
:param pref_vector: The preference vector to use. If not provided, defaults to
7083
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
84+
:param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses
85+
the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"``
86+
uses the mean eigenvalue, so that the applied scale factor is the root mean square of the
square roots of the eigenvalues (as in the original implementation).
7187
"""
7288

73-
def __init__(self, pref_vector: Tensor | None = None):
89+
def __init__(
90+
self,
91+
pref_vector: Tensor | None = None,
92+
scale_mode: Literal["min", "median", "rmse"] = "min",
93+
):
7494
super().__init__()
7595
self._pref_vector = pref_vector
96+
self._scale_mode = scale_mode
7697
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
7798

7899
def forward(self, gramian: PSDMatrix) -> Tensor:
79100
w = self.weighting(gramian)
80-
B = self._compute_balance_transformation(gramian)
101+
B = self._compute_balance_transformation(gramian, self._scale_mode)
81102
alpha = B @ w
82103

83104
return alpha
84105

85106
@staticmethod
86-
def _compute_balance_transformation(M: Tensor) -> Tensor:
107+
def _compute_balance_transformation(
108+
M: Tensor, scale_mode: Literal["min", "median", "rmse"] = "min"
109+
) -> Tensor:
87110
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
88111
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
89112
rank = sum(lambda_ > tol)
@@ -96,6 +119,17 @@ def _compute_balance_transformation(M: Tensor) -> Tensor:
96119
lambda_, V = lambda_[order][:rank], V[:, order][:, :rank]
97120

98121
sigma_inv = torch.diag(1 / lambda_.sqrt())
99-
lambda_R = lambda_[-1]
100-
B = lambda_R.sqrt() * V @ sigma_inv @ V.T
122+
123+
if scale_mode == "min":
124+
scale = lambda_[-1]
125+
elif scale_mode == "median":
126+
scale = torch.median(lambda_)
127+
elif scale_mode == "rmse":
128+
scale = lambda_.mean()
129+
else:
130+
raise ValueError(
131+
f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'."
132+
)
133+
134+
B = scale.sqrt() * V @ sigma_inv @ V.T
101135
return B

tests/unit/aggregation/test_aligned_mtl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ def test_permutation_invariant(aggregator: AlignedMTL, matrix: Tensor):
2323

2424
def test_representations():
2525
A = AlignedMTL(pref_vector=None)
26-
assert repr(A) == "AlignedMTL(pref_vector=None)"
26+
assert repr(A) == "AlignedMTL(pref_vector=None, scale_mode='min')"
2727
assert str(A) == "AlignedMTL"
2828

2929
A = AlignedMTL(pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"))
30-
assert repr(A) == "AlignedMTL(pref_vector=tensor([1., 2., 3.]))"
30+
assert repr(A) == "AlignedMTL(pref_vector=tensor([1., 2., 3.]), scale_mode='min')"
3131
assert str(A) == "AlignedMTL([1., 2., 3.])"

0 commit comments

Comments
 (0)