Skip to content

Commit 7d30162

Browse files
authored
typing(aggregation): Fix Literal attribute typing (#549)
* Add SUPPORTED_SOLVER type alias
* Add SUPPORTED_SCALE_MODE type alias
* Fix Literal automatic type widening
1 parent 1e03d34 commit 7d30162

File tree

4 files changed

+23
-23
lines changed

4 files changed

+23
-23
lines changed

src/torchjd/aggregation/_aligned_mtl.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
# SOFTWARE.
2626

2727

28-
from typing import Literal
28+
from typing import Literal, TypeAlias
2929

3030
import torch
3131
from torch import Tensor
@@ -37,6 +37,8 @@
3737
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
3838
from ._weighting_bases import Weighting
3939

40+
SUPPORTED_SCALE_MODE: TypeAlias = Literal["min", "median", "rmse"]
41+
4042

4143
class AlignedMTL(GramianWeightedAggregator):
4244
r"""
@@ -58,10 +60,10 @@ class AlignedMTL(GramianWeightedAggregator):
5860
def __init__(
5961
self,
6062
pref_vector: Tensor | None = None,
61-
scale_mode: Literal["min", "median", "rmse"] = "min",
63+
scale_mode: SUPPORTED_SCALE_MODE = "min",
6264
):
6365
self._pref_vector = pref_vector
64-
self._scale_mode = scale_mode
66+
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
6567
super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode))
6668

6769
def __repr__(self) -> str:
@@ -89,11 +91,11 @@ class AlignedMTLWeighting(Weighting[PSDMatrix]):
8991
def __init__(
9092
self,
9193
pref_vector: Tensor | None = None,
92-
scale_mode: Literal["min", "median", "rmse"] = "min",
94+
scale_mode: SUPPORTED_SCALE_MODE = "min",
9395
):
9496
super().__init__()
9597
self._pref_vector = pref_vector
96-
self._scale_mode = scale_mode
98+
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
9799
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
98100

99101
def forward(self, gramian: PSDMatrix) -> Tensor:
@@ -105,7 +107,7 @@ def forward(self, gramian: PSDMatrix) -> Tensor:
105107

106108
@staticmethod
107109
def _compute_balance_transformation(
108-
M: Tensor, scale_mode: Literal["min", "median", "rmse"] = "min"
110+
M: Tensor, scale_mode: SUPPORTED_SCALE_MODE = "min"
109111
) -> Tensor:
110112
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
111113
tol = torch.max(lambda_) * len(M) * torch.finfo().eps

src/torchjd/aggregation/_dualproj.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
from typing import Literal
2-
31
from torch import Tensor
42

53
from torchjd._linalg import PSDMatrix, normalize, regularize
64

75
from ._aggregator_bases import GramianWeightedAggregator
86
from ._mean import MeanWeighting
9-
from ._utils.dual_cone import project_weights
7+
from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
108
from ._utils.non_differentiable import raise_non_differentiable_error
119
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
1210
from ._weighting_bases import Weighting
@@ -34,12 +32,12 @@ def __init__(
3432
pref_vector: Tensor | None = None,
3533
norm_eps: float = 0.0001,
3634
reg_eps: float = 0.0001,
37-
solver: Literal["quadprog"] = "quadprog",
35+
solver: SUPPORTED_SOLVER = "quadprog",
3836
):
3937
self._pref_vector = pref_vector
4038
self._norm_eps = norm_eps
4139
self._reg_eps = reg_eps
42-
self._solver = solver
40+
self._solver: SUPPORTED_SOLVER = solver
4341

4442
super().__init__(
4543
DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver)
@@ -78,14 +76,14 @@ def __init__(
7876
pref_vector: Tensor | None = None,
7977
norm_eps: float = 0.0001,
8078
reg_eps: float = 0.0001,
81-
solver: Literal["quadprog"] = "quadprog",
79+
solver: SUPPORTED_SOLVER = "quadprog",
8280
):
8381
super().__init__()
8482
self._pref_vector = pref_vector
8583
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
8684
self.norm_eps = norm_eps
8785
self.reg_eps = reg_eps
88-
self.solver = solver
86+
self.solver: SUPPORTED_SOLVER = solver
8987

9088
def forward(self, gramian: PSDMatrix) -> Tensor:
9189
u = self.weighting(gramian)

src/torchjd/aggregation/_upgrad.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
from typing import Literal
2-
31
import torch
42
from torch import Tensor
53

64
from torchjd._linalg import PSDMatrix, normalize, regularize
75

86
from ._aggregator_bases import GramianWeightedAggregator
97
from ._mean import MeanWeighting
10-
from ._utils.dual_cone import project_weights
8+
from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
119
from ._utils.non_differentiable import raise_non_differentiable_error
1210
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
1311
from ._weighting_bases import Weighting
@@ -35,12 +33,12 @@ def __init__(
3533
pref_vector: Tensor | None = None,
3634
norm_eps: float = 0.0001,
3735
reg_eps: float = 0.0001,
38-
solver: Literal["quadprog"] = "quadprog",
36+
solver: SUPPORTED_SOLVER = "quadprog",
3937
):
4038
self._pref_vector = pref_vector
4139
self._norm_eps = norm_eps
4240
self._reg_eps = reg_eps
43-
self._solver = solver
41+
self._solver: SUPPORTED_SOLVER = solver
4442

4543
super().__init__(
4644
UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver)
@@ -79,14 +77,14 @@ def __init__(
7977
pref_vector: Tensor | None = None,
8078
norm_eps: float = 0.0001,
8179
reg_eps: float = 0.0001,
82-
solver: Literal["quadprog"] = "quadprog",
80+
solver: SUPPORTED_SOLVER = "quadprog",
8381
):
8482
super().__init__()
8583
self._pref_vector = pref_vector
8684
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
8785
self.norm_eps = norm_eps
8886
self.reg_eps = reg_eps
89-
self.solver = solver
87+
self.solver: SUPPORTED_SOLVER = solver
9088

9189
def forward(self, gramian: PSDMatrix) -> Tensor:
9290
U = torch.diag(self.weighting(gramian))

src/torchjd/aggregation/_utils/dual_cone.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
from typing import Literal
1+
from typing import Literal, TypeAlias
22

33
import numpy as np
44
import torch
55
from qpsolvers import solve_qp
66
from torch import Tensor
77

8+
SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"]
89

9-
def project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor:
10+
11+
def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor:
1012
"""
1113
Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the
1214
rows of a matrix whose Gramian is provided.
@@ -25,7 +27,7 @@ def project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor
2527
return torch.as_tensor(W, device=G.device, dtype=G.dtype)
2628

2729

28-
def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: Literal["quadprog"]) -> np.ndarray:
30+
def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: SUPPORTED_SOLVER) -> np.ndarray:
2931
r"""
3032
Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`,
3133
given `G = J J^T` and `u`. In other words, this computes the `w` that satisfies

0 commit comments

Comments (0)