Skip to content

Commit 5d5699a

Browse files
committed
feat(aggregation): Add getters and setters to UPGrad parameters
Expose pref_vector, norm_eps, and reg_eps as properties on UPGrad and UPGradWeighting so users can read and update them after instantiation. The norm_eps and reg_eps setters validate that the new value is non-negative.
1 parent 4597af8 commit 5d5699a

2 files changed

Lines changed: 68 additions & 9 deletions

File tree

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ changelog does not include internal changes that do not affect the user.
88

99
## [Unreleased]
1010

11+
### Added
12+
13+
- Added `pref_vector`, `norm_eps`, and `reg_eps` getters and setters to `UPGrad` and
14+
`UPGradWeighting`. The setters for `norm_eps` and `reg_eps` validate that the assigned value is
15+
non-negative.
16+
1117
## [0.10.0] - 2026-04-16
1218

1319
### Added

src/torchjd/aggregation/_upgrad.py

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,50 @@ def __init__(
3434
solver: SUPPORTED_SOLVER = "quadprog",
3535
) -> None:
3636
super().__init__()
37-
self._pref_vector = pref_vector
38-
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
37+
self.pref_vector = pref_vector
3938
self.norm_eps = norm_eps
4039
self.reg_eps = reg_eps
4140
self.solver: SUPPORTED_SOLVER = solver
4241

4342
def forward(self, gramian: PSDMatrix, /) -> Tensor:
4443
U = torch.diag(self.weighting(gramian))
45-
G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
44+
G = regularize(normalize(gramian, self._norm_eps), self._reg_eps)
4645
W = project_weights(U, G, self.solver)
4746
return torch.sum(W, dim=0)
4847

48+
@property
49+
def pref_vector(self) -> Tensor | None:
50+
return self._pref_vector
51+
52+
@pref_vector.setter
53+
def pref_vector(self, value: Tensor | None) -> None:
54+
self._pref_vector = value
55+
self.weighting = pref_vector_to_weighting(value, default=MeanWeighting())
56+
57+
@property
58+
def norm_eps(self) -> float:
59+
return self._norm_eps
60+
61+
@norm_eps.setter
62+
def norm_eps(self, value: float) -> None:
63+
64+
if value < 0:
65+
raise ValueError(f"norm_eps must be non-negative, but got {value}.")
66+
67+
self._norm_eps = value
68+
69+
@property
70+
def reg_eps(self) -> float:
71+
return self._reg_eps
72+
73+
@reg_eps.setter
74+
def reg_eps(self, value: float) -> None:
75+
76+
if value < 0:
77+
raise ValueError(f"reg_eps must be non-negative, but got {value}.")
78+
79+
self._reg_eps = value
80+
4981

5082
class UPGrad(GramianWeightedAggregator):
5183
r"""
@@ -73,9 +105,6 @@ def __init__(
73105
reg_eps: float = 0.0001,
74106
solver: SUPPORTED_SOLVER = "quadprog",
75107
) -> None:
76-
self._pref_vector = pref_vector
77-
self._norm_eps = norm_eps
78-
self._reg_eps = reg_eps
79108
self._solver: SUPPORTED_SOLVER = solver
80109

81110
super().__init__(
@@ -85,11 +114,35 @@ def __init__(
85114
# This prevents considering the computed weights as constant w.r.t. the matrix.
86115
self.register_full_backward_pre_hook(raise_non_differentiable_error)
87116

117+
@property
118+
def pref_vector(self) -> Tensor | None:
119+
return self.gramian_weighting.pref_vector
120+
121+
@pref_vector.setter
122+
def pref_vector(self, value: Tensor | None) -> None:
123+
self.gramian_weighting.pref_vector = value
124+
125+
@property
126+
def norm_eps(self) -> float:
127+
return self.gramian_weighting.norm_eps
128+
129+
@norm_eps.setter
130+
def norm_eps(self, value: float) -> None:
131+
self.gramian_weighting.norm_eps = value
132+
133+
@property
134+
def reg_eps(self) -> float:
135+
return self.gramian_weighting.reg_eps
136+
137+
@reg_eps.setter
138+
def reg_eps(self, value: float) -> None:
139+
self.gramian_weighting.reg_eps = value
140+
88141
def __repr__(self) -> str:
89142
return (
90-
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps="
91-
f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={repr(self._solver)})"
143+
f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps="
144+
f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})"
92145
)
93146

94147
def __str__(self) -> str:
95-
return f"UPGrad{pref_vector_to_str_suffix(self._pref_vector)}"
148+
return f"UPGrad{pref_vector_to_str_suffix(self.pref_vector)}"

0 commit comments

Comments
 (0)