@@ -20,6 +20,14 @@ class FedProxOptimizer(Optimizer):
2020 It introduces a proximal term to the federated averaging algorithm to
2121 reduce the impact of devices with outlying updates.
2222
23+ IMPORTANT: This optimizer requires a reference to the original (global) model parameters
24+ to calculate the proximal term. These must be set explicitly using the set_old_weights()
25+ method before training begins. The old weights (w_old) must match the order and structure
26+ of the model's parameters. Typically, w_old should be set to the initial global model
27+ parameters received from the aggregator at the beginning of each round.
28+
29+ If mu > 0 and w_old is not set, the optimizer will raise a ValueError.
30+
2331 Paper: https://arxiv.org/pdf/1812.06127.pdf
2432
2533 Attributes:
@@ -67,7 +75,12 @@ def __init__(
6775 if weight_decay < 0.0 :
6876 raise ValueError (f"Invalid weight_decay value: { weight_decay } " )
6977 if mu < 0.0 :
70- raise ValueError (f"Invalid mu value: { mu } " )
78+ import warnings
79+ warnings .warn (
80+ f"Negative mu value ({ mu } ) will cause the proximal term to reward "
81+ f"deviations from global weights, which may be counterintuitive." ,
82+ UserWarning ,
83+ )
7184 defaults = {
7285 "dampening" : dampening ,
7386 "lr" : lr ,
@@ -117,8 +130,15 @@ def step(self, closure=None):
117130 mu = group ["mu" ]
118131 w_old = group ["w_old" ]
119132
120- # Skip FedProx regularization if w_old is not set or mu is 0
121- apply_proximal = w_old is not None and mu > 0
133+ # Check if FedProx regularization should be applied (mu > 0)
134+ if mu > 0 and w_old is None :
135+ raise ValueError (
136+ "FedProx requires old weights to be set when mu > 0. "
137+ "Please call set_old_weights() before optimization step."
138+ )
139+
140+ # Apply proximal term when mu != 0
141+ apply_proximal = w_old is not None and mu != 0
122142
123143 for i , p in enumerate (group ["params" ]):
124144 if p .grad is None :
@@ -147,9 +167,20 @@ def step(self, closure=None):
147167
148168 def set_old_weights (self , old_weights ):
149169 """Set the global weights parameter to `old_weights` value.
170+
171+ This method must be called before training begins to set the reference point for
172+ calculating the proximal term in FedProx. Typically, this should be set to the
173+ initial global model parameters received from the aggregator at the beginning
174+ of each federated learning round.
175+
176+ If mu > 0 and this method is not called, the optimizer will raise a ValueError
177+ during the optimization step.
150178
151179 Args:
152- old_weights: The old weights to be set.
180+ old_weights: List of parameter tensors representing the global model weights.
181+ Must match the order and structure of the model's parameters
182+ being optimized (typically obtained by calling
183+ [p.clone().detach() for p in model.parameters()]).
153184 """
154185 for param_group in self .param_groups :
155186 param_group ["w_old" ] = old_weights
@@ -160,6 +191,14 @@ class FedProxAdam(Optimizer):
160191
161192 Implements the FedProx optimization algorithm with Adam optimizer.
162193
194+ IMPORTANT: This optimizer requires a reference to the original (global) model parameters
195+ to calculate the proximal term. These must be set explicitly using the set_old_weights()
196+ method before training begins. The old weights (w_old) must match the order and structure
197+ of the model's parameters. Typically, w_old should be set to the initial global model
198+ parameters received from the aggregator at the beginning of each round.
199+
200+ If mu > 0 and w_old is not set, the optimizer will raise a ValueError.
201+
163202 Attributes:
164203 params: Parameters to be stored for optimization.
165204 mu: Proximal term coefficient.
@@ -211,7 +250,12 @@ def __init__(
211250 if not 0.0 <= weight_decay :
212251 raise ValueError (f"Invalid weight_decay value: { weight_decay } " )
213252 if mu < 0.0 :
214- raise ValueError (f"Invalid mu value: { mu } " )
253+ import warnings
254+ warnings .warn (
255+ f"Negative mu value ({ mu } ) will cause the proximal term to reward "
256+ f"deviations from global weights, which may be counterintuitive." ,
257+ UserWarning ,
258+ )
215259 defaults = {
216260 "lr" : lr ,
217261 "betas" : betas ,
@@ -231,9 +275,20 @@ def __setstate__(self, state):
231275
232276 def set_old_weights (self , old_weights ):
233277 """Set the global weights parameter to `old_weights` value.
278+
279+ This method must be called before training begins to set the reference point for
280+ calculating the proximal term in FedProx. Typically, this should be set to the
281+ initial global model parameters received from the aggregator at the beginning
282+ of each federated learning round.
283+
284+ If mu > 0 and this method is not called, the optimizer will raise a ValueError
285+ during the optimization step.
234286
235287 Args:
236- old_weights: The old weights to be set.
288+ old_weights: List of parameter tensors representing the global model weights.
289+ Must match the order and structure of the model's parameters
290+ being optimized (typically obtained by calling
291+ [p.clone().detach() for p in model.parameters()]).
237292 """
238293 for param_group in self .param_groups :
239294 param_group ["w_old" ] = old_weights
@@ -356,8 +411,15 @@ def adam(
356411 mu (float): Proximal term coefficient.
357412 w_old: The old weights.
358413 """
359- # Skip FedProx regularization if w_old is not set or mu is 0
360- apply_proximal = w_old is not None and mu > 0
414+ # Check if FedProx regularization should be applied (mu > 0)
415+ if mu > 0 and w_old is None :
416+ raise ValueError (
417+ "FedProx requires old weights to be set when mu > 0. "
418+ "Please call set_old_weights() before optimization step."
419+ )
420+
421+ # Apply proximal term when mu != 0
422+ apply_proximal = w_old is not None and mu != 0
361423
362424 for i , param in enumerate (params ):
363425 grad = grads [i ]