2525# SOFTWARE.
2626
2727
28+ from typing import Literal
29+
2830import torch
2931from torch import Tensor
3032
@@ -44,18 +46,29 @@ class AlignedMTL(GramianWeightedAggregator):
4446
4547 :param pref_vector: The preference vector to use. If not provided, defaults to
4648 :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
49+ :param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses
50+ the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"``
51+ uses the mean eigenvalue (as in the original implementation).
4752
4853 .. note::
4954 This implementation was adapted from the `official implementation
5055 <https://github.com/SamsungLabs/MTL/tree/master/code/optim/aligned>`_.
5156 """
5257
53- def __init__ (self , pref_vector : Tensor | None = None ):
58+ def __init__ (
59+ self ,
60+ pref_vector : Tensor | None = None ,
61+ scale_mode : Literal ["min" , "median" , "rmse" ] = "min" ,
62+ ):
5463 self ._pref_vector = pref_vector
55- super ().__init__ (AlignedMTLWeighting (pref_vector ))
64+ self ._scale_mode = scale_mode
65+ super ().__init__ (AlignedMTLWeighting (pref_vector , scale_mode = scale_mode ))
5666
5767 def __repr__ (self ) -> str :
58- return f"{ self .__class__ .__name__ } (pref_vector={ repr (self ._pref_vector )} )"
68+ return (
69+ f"{ self .__class__ .__name__ } (pref_vector={ repr (self ._pref_vector )} , "
70+ f"scale_mode={ repr (self ._scale_mode )} )"
71+ )
5972
6073 def __str__ (self ) -> str :
6174 return f"AlignedMTL{ pref_vector_to_str_suffix (self ._pref_vector )} "
@@ -68,22 +81,32 @@ class AlignedMTLWeighting(Weighting[PSDMatrix]):
6881
6982 :param pref_vector: The preference vector to use. If not provided, defaults to
7083 :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
84+ :param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses
85+ the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"``
86+ uses the mean eigenvalue (as in the original implementation).
7187 """
7288
73- def __init__ (self , pref_vector : Tensor | None = None ):
89+ def __init__ (
90+ self ,
91+ pref_vector : Tensor | None = None ,
92+ scale_mode : Literal ["min" , "median" , "rmse" ] = "min" ,
93+ ):
7494 super ().__init__ ()
7595 self ._pref_vector = pref_vector
96+ self ._scale_mode = scale_mode
7697 self .weighting = pref_vector_to_weighting (pref_vector , default = MeanWeighting ())
7798
7899 def forward (self , gramian : PSDMatrix ) -> Tensor :
79100 w = self .weighting (gramian )
80- B = self ._compute_balance_transformation (gramian )
101+ B = self ._compute_balance_transformation (gramian , self . _scale_mode )
81102 alpha = B @ w
82103
83104 return alpha
84105
85106 @staticmethod
# NOTE(review): this span is a unified-diff rendering — the '+'/'-' prefixes and
# fused line numbers are diff artifacts, not code. New side: the method gains an
# optional scale_mode selecting which eigenvalue statistic scales the transform.
86- def _compute_balance_transformation (M : Tensor ) -> Tensor :
107+ def _compute_balance_transformation (
108+ M : Tensor , scale_mode : Literal ["min" , "median" , "rmse" ] = "min"
109+ ) -> Tensor :
# Eigendecomposition of the symmetric matrix M, reading its upper triangle.
87110 lambda_ , V = torch .linalg .eigh (M , UPLO = "U" ) # More modern equivalent to torch.symeig
# Numerical-rank tolerance: largest eigenvalue * n * machine epsilon.
88111 tol = torch .max (lambda_ ) * len (M ) * torch .finfo ().eps
89112 rank = sum (lambda_ > tol )
# NOTE(review): hunk break — original lines 90-95 / 113-118, which must define
# `order` (presumably a descending sort of the eigenvalues), are not visible in
# this view; do not guess their exact content.
@@ -96,6 +119,17 @@ def _compute_balance_transformation(M: Tensor) -> Tensor:
# Keep only the top-`rank` eigenpairs after reordering.
96119 lambda_ , V = lambda_ [order ][:rank ], V [:, order ][:, :rank ]
97120 
# Sigma^{-1/2} of the retained spectrum.
98121 sigma_inv = torch .diag (1 / lambda_ .sqrt ())
# Old behavior: always scale by lambda_[-1] (per the class docstring, the
# smallest retained eigenvalue).
99- lambda_R = lambda_ [- 1 ]
100- B = lambda_R .sqrt () * V @ sigma_inv @ V .T
122+
# New behavior: the scaling statistic is selectable.
123+ if scale_mode == "min" :
124+ scale = lambda_ [- 1 ]
125+ elif scale_mode == "median" :
126+ scale = torch .median (lambda_ )
# "rmse" maps to the mean eigenvalue, matching the original implementation per
# the class docstring.
127+ elif scale_mode == "rmse" :
128+ scale = lambda_ .mean ()
129+ else :
130+ raise ValueError (
131+ f"Invalid scale_mode={ scale_mode !r} . Expected 'min', 'median', or 'rmse'."
132+ )
133+
# B = sqrt(scale) * V @ Sigma^{-1/2} @ V^T.
134+ B = scale .sqrt () * V @ sigma_inv @ V .T
101135 return B
0 commit comments