Skip to content

Commit c7426f5

Browse files
authored
Merge pull request #299 from SubstraFoundation/grad-fus-v2
Add gradient fusion method for multi-partner-learning
2 parents b81ff8c + cddc385 commit c7426f5

4 files changed

Lines changed: 99 additions & 7 deletions

File tree

mplc/doc/documentation.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,10 @@ There are several parameters influencing how the collaborative and distributed l
301301
- `'fedavg'`: stands for federated averaging
302302

303303
![Schema fedavg](../../img/collaborative_rounds_fedavg.png)
304-
304+
305+
- `'fedgrads'`: stands for gradient averaging. It is quite similar to federated averaging, but the gradients of each partner's loss are averaged before the optimization step, instead of averaging the models' weights after the optimization step.
306+
Warning : This method needs a Keras model to work with. The `gradient_pass_per_update` is set to 1 in the current implementation.
307+
305308
- `'seq-...'`: stands for sequential and comes with 2 variations, `'seq-pure'` with no aggregation at all, and `'seq-with-final-agg'` where an aggregation is performed before evaluating on the validation set and test set (on last mini-batch of each epoch) for mitigating impact when the very last subset on which the model is trained is of low quality, or corrupted, or just detrimental to the model performance.
306309

307310
![Schema seq](../../img/collaborative_rounds_seq.png)
@@ -358,7 +361,7 @@ There are several parameters influencing how the collaborative and distributed l
358361
- "Federated SBS linear"
359362
- "Federated SBS quadratic"
360363
- "Federated SBS constant"
361-
- "LFlip"
364+
- "Smodel"
362365
- "PVRL"
363366
```
364367

mplc/mpl_utils.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import matplotlib.pyplot as plt
77
import numpy as np
88
import pandas as pd
9+
import tensorflow as tf
910

1011

1112
class History:
@@ -109,6 +110,23 @@ def save_data(self, binary=False):
109110
plt.close()
110111

111112

113+
#####################################
114+
#
115+
# tensorflow functions for aggregator
116+
#
117+
#####################################
118+
119+
@tf.function
def _tf_aggregete_grads(grads, agg_w):
    """Fuse the partners' gradients into a single gradient per layer.

    For every layer, computes the weighted average of the partners' gradients:
    ``sum(w_p * g_p) / sum(w_p)``. Dividing by the sum of the weights makes the
    result independent of whether the weights were normalised beforehand.

    :param grads: one entry per partner, each an iterable holding one gradient per layer
    :param agg_w: one aggregation weight per partner, in the same order as ``grads``
    :return: list holding one aggregated gradient per layer
    """
    global_grad = []
    for grad_per_layer in zip(*grads):
        # Cast the weights to the gradients' dtype so that float64 weights
        # can be combined with float32 gradients.
        dtype = grad_per_layer[0].dtype
        weights = [tf.cast(w, dtype) for w in agg_w]
        weighted = [g_p * w for g_p, w in zip(grad_per_layer, weights)]
        global_grad.append(tf.reduce_sum(weighted, axis=0) / tf.reduce_sum(weights))
    return global_grad
128+
129+
112130
class Aggregator(ABC):
113131
name = 'abstract'
114132

@@ -117,7 +135,7 @@ def __init__(self, mpl):
117135
:type mpl: MultiPartnerLearning
118136
"""
119137
self.mpl = mpl
120-
self.aggregation_weights = np.zeros(self.mpl.partners_count)
138+
self.aggregation_weights = np.zeros(self.mpl.partners_count, dtype='float32')
121139

122140
def __str__(self):
123141
return f'{self.name} aggregator'
@@ -136,13 +154,17 @@ def aggregate_model_weights(self):
136154

137155
return new_weights
138156

157+
def aggregate_gradients(self):
    """Aggregate the gradients stored on each partner.

    :return: the result of ``_tf_aggregete_grads`` applied to the partners'
             ``grads`` and to this aggregator's ``aggregation_weights``
    """
    # The aggregation weights are required to be a plain Python list.
    assert isinstance(self.aggregation_weights, list), 'Aggregation weights must be a list.'
    partners_grads = []
    for partner in self.mpl.partners_list:
        partners_grads.append(partner.grads)
    return _tf_aggregete_grads(partners_grads, self.aggregation_weights)
160+
139161

140162
class UniformAggregator(Aggregator):
    """Aggregator giving the same weight to every partner."""
    name = 'Uniform'

    def __init__(self, mpl):
        """
        :type mpl: MultiPartnerLearning
        """
        super(UniformAggregator, self).__init__(mpl)
        partners_count = self.mpl.partners_count
        # Every partner gets 1 / partners_count, so the weights sum to 1.
        self.aggregation_weights = list(np.ones(partners_count, dtype='float32') / partners_count)
146168

147169

148170
class DataVolumeAggregator(Aggregator):
@@ -151,7 +173,7 @@ class DataVolumeAggregator(Aggregator):
151173
def __init__(self, mpl):
    """Set the aggregation weights proportionally to each partner's data volume.

    :type mpl: MultiPartnerLearning
    """
    super(DataVolumeAggregator, self).__init__(mpl)
    partners_sizes = np.asarray([partner.data_volume for partner in self.mpl.partners_list])
    # Cast the normalised weights themselves (not only their sum) to float32.
    self.aggregation_weights = list((partners_sizes / np.sum(partners_sizes)).astype('float32'))
155177

156178

157179
class ScoresAggregator(Aggregator):
@@ -162,12 +184,16 @@ def __init__(self, mpl):
162184

163185
def prepare_aggregation_weights(self):
    """Recompute the aggregation weights from the partners' last round scores.

    Each weight is the partner's ``last_round_score`` divided by the sum of all
    the scores; the weights are stored as a list of float32 values.
    When the scores sum to zero, uniform weights are used instead, so that
    the aggregation is not fed NaN weights.
    """
    last_scores = np.asarray([partner.last_round_score for partner in self.mpl.partners_list])
    total = np.sum(last_scores)
    if total == 0:
        # No partner scored anything: fall back to an equal share for everyone.
        weights = np.full(last_scores.shape, 1 / max(len(last_scores), 1))
    else:
        weights = last_scores / total
    self.aggregation_weights = list(weights.astype('float32'))
166188

167189
def aggregate_model_weights(self):
    """Refresh the score-based weights, then aggregate the model weights.

    :return: the aggregated weights computed by the parent class
    """
    self.prepare_aggregation_weights()
    # Return the parent's result: without it this override yields None.
    return super(ScoresAggregator, self).aggregate_model_weights()
170192

193+
def aggregate_gradients(self):
    """Refresh the score-based weights, then aggregate the partners' gradients.

    :return: the aggregated gradients computed by the parent class
    """
    self.prepare_aggregation_weights()
    # Return the parent's result: without it this override yields None.
    return super(ScoresAggregator, self).aggregate_gradients()
196+
171197

172198
# Supported _aggregation weights approaches
173199
AGGREGATORS = {

mplc/multi_partner_learning.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from timeit import default_timer as timer
1010

1111
import numpy as np
12+
import tensorflow as tf
1213
from loguru import logger
1314
from sklearn.metrics import confusion_matrix
1415
from tensorflow.keras import Input, Model
@@ -603,12 +604,73 @@ def fit_minibatch(self):
603604
logger.debug("End of S-Model collaborative round.")
604605

605606

607+
class FederatedGradients(MultiPartnerLearning):
    """Multi-partner learning approach based on gradient fusion.

    At each mini-batch, the gradients of the shared model's loss are computed on
    every partner's mini-batch, fused by the aggregator, and applied in a single
    optimizer step on the shared model.
    """

    def __init__(self, scenario, **kwargs):
        """Build the shared model; raise ValueError when there is only one partner."""
        super(FederatedGradients, self).__init__(scenario, **kwargs)
        if self.partners_count == 1:
            raise ValueError('Only one partner is provided. Please use the dedicated SinglePartnerLearning class')
        self.model = self.build_model()

    def fit_epoch(self):
        """Run one epoch: split the data in mini-batches and fit each of them in turn."""
        # Split the train dataset in mini-batches
        self.split_in_minibatches()
        # Iterate over mini-batches and train
        for i in range(self.minibatch_count):
            self.minibatch_index = i
            self.fit_minibatch()

        self.minibatch_index = 0

    def fit_minibatch(self):
        """Proceed to a collaborative round with a gradient fusion approach"""

        logger.debug("Start new gradients fusion collaborative round ...")

        # Starting model for each partner is the aggregated model from the previous mini-batch iteration
        logger.info(f"(gradient fusion) Minibatch n°{self.minibatch_index} of epoch n°{self.epoch_index}, "
                    f"init each partner's models with a copy of the global model")

        for partner in self.partners_list:
            partner.model_weights = self.model_weights

        # Evaluate and store accuracy of mini-batch start model
        self.eval_and_log_model_val_perf()

        # Iterate over partners to compute the gradients of the shared model on each partner's mini-batch
        for partner in self.partners_list:
            with tf.GradientTape() as tape:
                loss = self.model.loss(partner.minibatched_y_train[self.minibatch_index],
                                       self.model(partner.minibatched_x_train[self.minibatch_index]))
            partner.grads = tape.gradient(loss, self.model.trainable_weights)

        # Fuse the partners' gradients and apply a single optimization step on the shared model
        global_grad = self.aggregator.aggregate_gradients()
        self.model.optimizer.apply_gradients(zip(global_grad, self.model.trainable_weights))
        self.model_weights = self.model.get_weights()

        # The model is the same for every partner at this point, so the
        # validation metrics are computed once rather than once per partner.
        val_history = self.model.evaluate(self.val_data[0], self.val_data[1], verbose=False)

        for partner_index, partner in enumerate(self.partners_list):
            train_history = self.model.evaluate(partner.minibatched_x_train[self.minibatch_index],
                                                partner.minibatched_y_train[self.minibatch_index], verbose=False)
            history = {
                "loss": [train_history[0]],
                'accuracy': [train_history[1]],
                'val_loss': [val_history[0]],
                'val_accuracy': [val_history[1]]
            }

            # Log results of the round
            self.log_partner_perf(partner.id, partner_index, history)

        logger.debug("End of grads-fusion collaborative round.")
664+
665+
606666
# Supported multi-partner learning approaches
607667

608668
# Maps each supported approach name to the class implementing it.
MULTI_PARTNER_LEARNING_APPROACHES = {
    "fedavg": FederatedAverageLearning,
    "fedgrads": FederatedGradients,
    "seq-pure": SequentialLearning,
    "seq-with-final-agg": SequentialWithFinalAggLearning,
    "seqavg": SequentialAverageLearning,
    "smodel": MplSModel,
}

mplc/partner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(self, partner_parent, mpl):
7676
:type partner_parent: Partner
7777
:type mpl: MultiPartnerLearning
7878
"""
79+
self.grads = None
7980
self.mpl = mpl
8081
self.id = partner_parent.id
8182
self.batch_size = partner_parent.batch_size

0 commit comments

Comments
 (0)