class AggregatorException(Exception):
    """Error raised for aggregator failures, such as an unknown aggregation algorithm name."""
1717
18-
def create_target_aggregator(config, engine):
    """Build the aggregator selected by the participant's defense settings.

    The algorithm name is read from
    ``config.participant["defense_args"]["target_aggregation"]`` and resolved
    against the supported aggregators: FedAvg, Krum, Median and TrimmedMean.

    Args:
        config: Configuration object exposing a ``participant`` mapping.
        engine: Engine instance forwarded to the aggregator constructor.

    Returns:
        An instance of the selected aggregator class, constructed with
        ``config`` and ``engine`` as keyword arguments.

    Raises:
        AggregatorException: If the configured name is not a supported algorithm.
    """
    # Imports are resolved at call time rather than at module load
    # (presumably to avoid a circular import with the aggregator modules — TODO confirm).
    from nebula.core.aggregation.fedavg import FedAvg
    from nebula.core.aggregation.krum import Krum
    from nebula.core.aggregation.median import Median
    from nebula.core.aggregation.trimmedmean import TrimmedMean

    supported = {
        "FedAvg": FedAvg,
        "Krum": Krum,
        "Median": Median,
        "TrimmedMean": TrimmedMean,
    }
    algorithm = config.participant["defense_args"]["target_aggregation"]
    aggregator_cls = supported.get(algorithm)

    # Guard clause: reject unknown names before attempting construction.
    if not aggregator_cls:
        raise AggregatorException(f"Aggregation algorithm {algorithm} not found.")
    return aggregator_cls(config=config, engine=engine)
37-
38-
3918class Aggregator (ABC ):
4019 def __init__ (self , config = None , engine = None ):
4120 self .config = config
@@ -59,6 +38,7 @@ def __repr__(self):
5938
@property
def us(self):
    """Return the update handler held in ``_update_storage``.

    The handler is the federation-type UpdateHandler
    (e.g. DFL-UpdateHandler, CFL-UpdateHandler...).
    """
    handler = self._update_storage
    return handler
6343
6444 @abstractmethod
@@ -71,6 +51,20 @@ async def init(self):
7151 await self .us .init (self .config )
7252
7353 async def update_federation_nodes (self , federation_nodes : set ):
54+ """
55+ Updates the current set of nodes expected to participate in the upcoming aggregation round.
56+
57+ This method informs the update handler (`us`) about the new set of federation nodes,
58+ clears any pending models, and attempts to acquire the aggregation lock to prepare
59+ for model aggregation. If the aggregation process is already running, it raises an exception.
60+
61+ Args:
62+ federation_nodes (set): A set of addresses representing the nodes expected to contribute
63+ updates for the next aggregation round.
64+
65+ Raises:
66+ Exception: If the aggregation process is already running and the lock is currently held.
67+ """
7468 await self .us .round_expected_updates (federation_nodes = federation_nodes )
7569
7670 if not self ._aggregation_done_lock .locked ():
@@ -86,6 +80,23 @@ def get_nodes_pending_models_to_aggregate(self):
8680 return self ._federation_nodes
8781
8882 async def get_aggregation (self ):
83+ """
84+ Handles the aggregation process for a training round.
85+
86+ This method waits for all expected model updates from federation nodes or until a timeout occurs.
87+ It uses an asynchronous lock to coordinate access and includes an early exit mechanism if all
88+ updates are received before the timeout. Once the condition is satisfied, it releases the lock,
89+ collects the updates, identifies any missing nodes, and publishes an `AggregationEvent`.
90+ Finally, it runs the aggregation algorithm and returns the result.
91+
92+ Returns:
93+ Any: The result of the aggregation process, as returned by `run_aggregation`.
94+
95+ Raises:
96+ TimeoutError: If the aggregation lock is not acquired within the defined timeout.
97+ asyncio.CancelledError: If the aggregation lock acquisition is cancelled.
98+ Exception: For any other unexpected errors during the aggregation process.
99+ """
89100 try :
90101 timeout = self .config .participant ["aggregator_args" ]["aggregation_timeout" ]
91102 logging .info (f"Aggregation timeout: { timeout } starts..." )
0 commit comments