
Commit b5a2cf7

change_for_mu_values_comment
1 parent 722ed1c commit b5a2cf7

1 file changed

Lines changed: 70 additions & 8 deletions

File tree

openfl/utilities/optimizers/torch/fedprox.py

@@ -20,6 +20,14 @@ class FedProxOptimizer(Optimizer):
     It introduces a proximal term to the federated averaging algorithm to
     reduce the impact of devices with outlying updates.
 
+    IMPORTANT: This optimizer requires a reference to the original (global) model parameters
+    to calculate the proximal term. These must be set explicitly using the set_old_weights()
+    method before training begins. The old weights (w_old) must match the order and structure
+    of the model's parameters. Typically, w_old should be set to the initial global model
+    parameters received from the aggregator at the beginning of each round.
+
+    If mu > 0 and w_old is not set, the optimizer will raise a ValueError.
+
     Paper: https://arxiv.org/pdf/1812.06127.pdf
 
     Attributes:
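
A minimal usage sketch of the contract described above (the model, data, and hyperparameter values here are placeholders, not part of this commit): snapshot the detached global weights before the first optimization step.

import torch
import torch.nn as nn
from openfl.utilities.optimizers.torch.fedprox import FedProxOptimizer

model = nn.Linear(10, 2)
optimizer = FedProxOptimizer(model.parameters(), lr=0.01, mu=0.1)

# w_old must mirror model.parameters() in order and structure.
optimizer.set_old_weights([p.clone().detach() for p in model.parameters()])

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()  # proximal term pulls parameters toward w_old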
@@ -67,7 +75,12 @@ def __init__(
         if weight_decay < 0.0:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
         if mu < 0.0:
-            raise ValueError(f"Invalid mu value: {mu}")
+            import warnings
+            warnings.warn(
+                f"Negative mu value ({mu}) will cause the proximal term to reward "
+                f"deviations from global weights, which may be counterintuitive.",
+                UserWarning,
+            )
         defaults = {
             "dampening": dampening,
             "lr": lr,
@@ -117,8 +130,15 @@ def step(self, closure=None):
             mu = group["mu"]
             w_old = group["w_old"]
 
-            # Skip FedProx regularization if w_old is not set or mu is 0
-            apply_proximal = w_old is not None and mu > 0
+            # Check if FedProx regularization should be applied (mu > 0)
+            if mu > 0 and w_old is None:
+                raise ValueError(
+                    "FedProx requires old weights to be set when mu > 0. "
+                    "Please call set_old_weights() before optimization step."
+                )
+
+            # Apply proximal term when mu != 0
+            apply_proximal = w_old is not None and mu != 0
 
             for i, p in enumerate(group["params"]):
                 if p.grad is None:
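
Behavioral consequence of the new guard, as a hypothetical minimal repro: with mu > 0 and no prior set_old_weights() call, step() now raises instead of silently skipping the proximal term (and with mu != 0, including negative values, the term is now applied).

import torch
import torch.nn as nn
from openfl.utilities.optimizers.torch.fedprox import FedProxOptimizer

model = nn.Linear(4, 1)
opt = FedProxOptimizer(model.parameters(), lr=0.01, mu=0.1)
model(torch.randn(2, 4)).sum().backward()

try:
    opt.step()  # w_old was never set
except ValueError as err:
    print(err)  # "FedProx requires old weights to be set when mu > 0. ..."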
@@ -147,9 +167,20 @@ def step(self, closure=None):
 
     def set_old_weights(self, old_weights):
         """Set the global weights parameter to `old_weights` value.
+
+        This method must be called before training begins to set the reference point for
+        calculating the proximal term in FedProx. Typically, this should be set to the
+        initial global model parameters received from the aggregator at the beginning
+        of each federated learning round.
+
+        If mu > 0 and this method is not called, the optimizer will raise a ValueError
+        during the optimization step.
 
         Args:
-            old_weights: The old weights to be set.
+            old_weights: List of parameter tensors representing the global model weights.
+                Must match the order and structure of the model's parameters
+                being optimized (typically obtained by calling
+                [p.clone().detach() for p in model.parameters()]).
         """
         for param_group in self.param_groups:
             param_group["w_old"] = old_weights
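
The per-round refresh pattern the docstring recommends, spelled out (the round loop is a placeholder; local training would run where the comment indicates): snapshot w_old from the freshly received global weights before local training starts each round.

import torch.nn as nn
from openfl.utilities.optimizers.torch.fedprox import FedProxOptimizer

model = nn.Linear(8, 3)
optimizer = FedProxOptimizer(model.parameters(), lr=0.01, mu=0.01)

for round_num in range(3):  # placeholder round loop
    # After loading the aggregator's global weights into `model`, snapshot
    # them; detached clones keep w_old fixed while training mutates the
    # live parameters.
    optimizer.set_old_weights([p.clone().detach() for p in model.parameters()])
    # ... local training for this round would run here ...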
@@ -160,6 +191,14 @@ class FedProxAdam(Optimizer):
 
     Implements the FedProx optimization algorithm with Adam optimizer.
 
+    IMPORTANT: This optimizer requires a reference to the original (global) model parameters
+    to calculate the proximal term. These must be set explicitly using the set_old_weights()
+    method before training begins. The old weights (w_old) must match the order and structure
+    of the model's parameters. Typically, w_old should be set to the initial global model
+    parameters received from the aggregator at the beginning of each round.
+
+    If mu > 0 and w_old is not set, the optimizer will raise a ValueError.
+
     Attributes:
         params: Parameters to be stored for optimization.
         mu: Proximal term coefficient.
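
The same contract applies to the Adam variant; a minimal sketch, assuming Adam-style constructor defaults for everything except mu:

import torch
import torch.nn as nn
from openfl.utilities.optimizers.torch.fedprox import FedProxAdam

model = nn.Linear(10, 2)
optimizer = FedProxAdam(model.parameters(), lr=0.001, mu=0.1)
optimizer.set_old_weights([p.clone().detach() for p in model.parameters()])

model(torch.randn(4, 10)).sum().backward()
optimizer.step()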
@@ -211,7 +250,12 @@ def __init__(
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
         if mu < 0.0:
-            raise ValueError(f"Invalid mu value: {mu}")
+            import warnings
+            warnings.warn(
+                f"Negative mu value ({mu}) will cause the proximal term to reward "
+                f"deviations from global weights, which may be counterintuitive.",
+                UserWarning,
+            )
         defaults = {
             "lr": lr,
             "betas": betas,
@@ -231,9 +275,20 @@ def __setstate__(self, state):
 
     def set_old_weights(self, old_weights):
         """Set the global weights parameter to `old_weights` value.
+
+        This method must be called before training begins to set the reference point for
+        calculating the proximal term in FedProx. Typically, this should be set to the
+        initial global model parameters received from the aggregator at the beginning
+        of each federated learning round.
+
+        If mu > 0 and this method is not called, the optimizer will raise a ValueError
+        during the optimization step.
 
         Args:
-            old_weights: The old weights to be set.
+            old_weights: List of parameter tensors representing the global model weights.
+                Must match the order and structure of the model's parameters
+                being optimized (typically obtained by calling
+                [p.clone().detach() for p in model.parameters()]).
         """
         for param_group in self.param_groups:
             param_group["w_old"] = old_weights
@@ -356,8 +411,15 @@ def adam(
         mu (float): Proximal term coefficient.
         w_old: The old weights.
     """
-    # Skip FedProx regularization if w_old is not set or mu is 0
-    apply_proximal = w_old is not None and mu > 0
+    # Check if FedProx regularization should be applied (mu > 0)
+    if mu > 0 and w_old is None:
+        raise ValueError(
+            "FedProx requires old weights to be set when mu > 0. "
+            "Please call set_old_weights() before optimization step."
+        )
+
+    # Apply proximal term when mu != 0
+    apply_proximal = w_old is not None and mu != 0
 
     for i, param in enumerate(params):
         grad = grads[i]
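
A standalone sketch of what the proximal adjustment does inside the functional helper (tensor values are made up, and the exact in-place call is an assumption drawn from the surrounding context): the gradient is shifted by mu * (param - w_old) before the usual Adam moment updates consume it.

import torch

mu = 0.1
param = torch.tensor([1.0, 2.0])    # current local weights
w_old_i = torch.tensor([0.5, 0.5])  # matching slice of the global weights
grad = torch.tensor([0.2, -0.1])

apply_proximal = w_old_i is not None and mu != 0
if apply_proximal:
    grad = grad.add(param - w_old_i, alpha=mu)  # g <- g + mu * (w - w_old)

print(grad)  # tensor([0.2500, 0.0500])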
