@@ -34,18 +34,50 @@ def __init__(
3434 solver : SUPPORTED_SOLVER = "quadprog" ,
3535 ) -> None :
3636 super ().__init__ ()
37- self ._pref_vector = pref_vector
38- self .weighting = pref_vector_to_weighting (pref_vector , default = MeanWeighting ())
37+ self .pref_vector = pref_vector
3938 self .norm_eps = norm_eps
4039 self .reg_eps = reg_eps
4140 self .solver : SUPPORTED_SOLVER = solver
4241
4342 def forward (self , gramian : PSDMatrix , / ) -> Tensor :
4443 U = torch .diag (self .weighting (gramian ))
45- G = regularize (normalize (gramian , self .norm_eps ), self .reg_eps )
44+ G = regularize (normalize (gramian , self ._norm_eps ), self ._reg_eps )
4645 W = project_weights (U , G , self .solver )
4746 return torch .sum (W , dim = 0 )
4847
48+ @property
49+ def pref_vector (self ) -> Tensor | None :
50+ return self ._pref_vector
51+
52+ @pref_vector .setter
53+ def pref_vector (self , value : Tensor | None ) -> None :
54+ self ._pref_vector = value
55+ self .weighting = pref_vector_to_weighting (value , default = MeanWeighting ())
56+
57+ @property
58+ def norm_eps (self ) -> float :
59+ return self ._norm_eps
60+
61+ @norm_eps .setter
62+ def norm_eps (self , value : float ) -> None :
63+
64+ if value < 0 :
65+ raise ValueError (f"norm_eps must be non-negative, but got { value } ." )
66+
67+ self ._norm_eps = value
68+
69+ @property
70+ def reg_eps (self ) -> float :
71+ return self ._reg_eps
72+
73+ @reg_eps .setter
74+ def reg_eps (self , value : float ) -> None :
75+
76+ if value < 0 :
77+ raise ValueError (f"reg_eps must be non-negative, but got { value } ." )
78+
79+ self ._reg_eps = value
80+
4981
5082class UPGrad (GramianWeightedAggregator ):
5183 r"""
@@ -73,9 +105,6 @@ def __init__(
73105 reg_eps : float = 0.0001 ,
74106 solver : SUPPORTED_SOLVER = "quadprog" ,
75107 ) -> None :
76- self ._pref_vector = pref_vector
77- self ._norm_eps = norm_eps
78- self ._reg_eps = reg_eps
79108 self ._solver : SUPPORTED_SOLVER = solver
80109
81110 super ().__init__ (
@@ -85,11 +114,35 @@ def __init__(
85114 # This prevents considering the computed weights as constant w.r.t. the matrix.
86115 self .register_full_backward_pre_hook (raise_non_differentiable_error )
87116
    @property
    def pref_vector(self) -> Tensor | None:
        """Preference vector of the wrapped Gramian weighting.

        Delegates to ``self.gramian_weighting`` so there is a single source of
        truth (presumably set by the base-class ``__init__`` — not visible here).
        """
        return self.gramian_weighting.pref_vector

    @pref_vector.setter
    def pref_vector(self, value: Tensor | None) -> None:
        # Forward the assignment to the wrapped weighting, which also rebuilds
        # its internal weighting from the new preference vector.
        self.gramian_weighting.pref_vector = value
124+
    @property
    def norm_eps(self) -> float:
        """Normalization epsilon of the wrapped Gramian weighting (delegated)."""
        return self.gramian_weighting.norm_eps

    @norm_eps.setter
    def norm_eps(self, value: float) -> None:
        # Delegated; the wrapped weighting's setter validates non-negativity.
        self.gramian_weighting.norm_eps = value
132+
    @property
    def reg_eps(self) -> float:
        """Regularization epsilon of the wrapped Gramian weighting (delegated)."""
        return self.gramian_weighting.reg_eps

    @reg_eps.setter
    def reg_eps(self, value: float) -> None:
        # Delegated; the wrapped weighting's setter validates non-negativity.
        self.gramian_weighting.reg_eps = value
140+
88141 def __repr__ (self ) -> str :
89142 return (
90- f"{ self .__class__ .__name__ } (pref_vector={ repr (self ._pref_vector )} , norm_eps="
91- f"{ self ._norm_eps } , reg_eps={ self ._reg_eps } , solver={ repr (self ._solver )} )"
143+ f"{ self .__class__ .__name__ } (pref_vector={ repr (self .pref_vector )} , norm_eps="
144+ f"{ self .norm_eps } , reg_eps={ self .reg_eps } , solver={ repr (self ._solver )} )"
92145 )
93146
94147 def __str__ (self ) -> str :
95- return f"UPGrad{ pref_vector_to_str_suffix (self ._pref_vector )} "
148+ return f"UPGrad{ pref_vector_to_str_suffix (self .pref_vector )} "
0 commit comments