|
9 | 9 | from timeit import default_timer as timer |
10 | 10 |
|
11 | 11 | import numpy as np |
| 12 | +import tensorflow as tf |
12 | 13 | from loguru import logger |
13 | 14 | from sklearn.metrics import confusion_matrix |
14 | 15 | from tensorflow.keras import Input, Model |
@@ -603,12 +604,73 @@ def fit_minibatch(self): |
603 | 604 | logger.debug("End of S-Model collaborative round.") |
604 | 605 |
|
605 | 606 |
|
| 607 | +class FederatedGradients(MultiPartnerLearning): |
| 608 | + def __init__(self, scenario, **kwargs): |
| 609 | + super(FederatedGradients, self).__init__(scenario, **kwargs) |
| 610 | + if self.partners_count == 1: |
| 611 | + raise ValueError('Only one partner is provided. Please use the dedicated SinglePartnerLearning class') |
| 612 | + self.model = self.build_model() |
| 613 | + |
| 614 | + def fit_epoch(self): |
| 615 | + # Split the train dataset in mini-batches |
| 616 | + self.split_in_minibatches() |
| 617 | + # Iterate over mini-batches and train |
| 618 | + for i in range(self.minibatch_count): |
| 619 | + self.minibatch_index = i |
| 620 | + self.fit_minibatch() |
| 621 | + |
| 622 | + self.minibatch_index = 0 |
| 623 | + |
| 624 | + def fit_minibatch(self): |
| 625 | + """Proceed to a collaborative round with a federated averaging approach""" |
| 626 | + |
| 627 | + logger.debug("Start new gradients fusion collaborative round ...") |
| 628 | + |
| 629 | + # Starting model for each partner is the aggregated model from the previous mini-batch iteration |
| 630 | + logger.info(f"(gradient fusion) Minibatch n°{self.minibatch_index} of epoch n°{self.epoch_index}, " |
| 631 | + f"init each partner's models with a copy of the global model") |
| 632 | + |
| 633 | + for partner in self.partners_list: |
| 634 | + # Evaluate and store accuracy of mini-batch start model |
| 635 | + partner.model_weights = self.model_weights |
| 636 | + self.eval_and_log_model_val_perf() |
| 637 | + |
| 638 | + # Iterate over partners for training each individual model |
| 639 | + for partner_index, partner in enumerate(self.partners_list): |
| 640 | + with tf.GradientTape() as tape: |
| 641 | + loss = self.model.loss(partner.minibatched_y_train[self.minibatch_index], |
| 642 | + self.model(partner.minibatched_x_train[self.minibatch_index])) |
| 643 | + partner.grads = tape.gradient(loss, self.model.trainable_weights) |
| 644 | + |
| 645 | + global_grad = self.aggregator.aggregate_gradients() |
| 646 | + self.model.optimizer.apply_gradients(zip(global_grad, self.model.trainable_weights)) |
| 647 | + self.model_weights = self.model.get_weights() |
| 648 | + |
| 649 | + for partner_index, partner in enumerate(self.partners_list): |
| 650 | + val_history = self.model.evaluate(self.val_data[0], self.val_data[1], verbose=False) |
| 651 | + history = self.model.evaluate(partner.minibatched_x_train[self.minibatch_index], |
| 652 | + partner.minibatched_y_train[self.minibatch_index], verbose=False) |
| 653 | + history = { |
| 654 | + "loss": [history[0]], |
| 655 | + 'accuracy': [history[1]], |
| 656 | + 'val_loss': [val_history[0]], |
| 657 | + 'val_accuracy': [val_history[1]] |
| 658 | + } |
| 659 | + |
| 660 | + # Log results of the round |
| 661 | + self.log_partner_perf(partner.id, partner_index, history) |
| 662 | + |
| 663 | + logger.debug("End of grads-fusion collaborative round.") |
| 664 | + |
| 665 | + |
606 | 666 | # Supported multi-partner learning approaches |
607 | 667 |
|
608 | 668 | MULTI_PARTNER_LEARNING_APPROACHES = { |
609 | 669 | "fedavg": FederatedAverageLearning, |
| 670 | + 'fedgrads': FederatedGradients, |
610 | 671 | "seq-pure": SequentialLearning, |
611 | 672 | "seq-with-final-agg": SequentialWithFinalAggLearning, |
612 | 673 | "seqavg": SequentialAverageLearning, |
613 | | - "lflip": MplSModel |
| 674 | + "smodel": MplSModel, |
| 675 | + |
614 | 676 | } |
0 commit comments