Skip to content

Commit 60711e3

Browse files
feature integrate sa training and policies
1 parent 353a480 commit 60711e3

7 files changed

Lines changed: 574 additions & 5 deletions

File tree

nebula/core/situationalawareness/awareness/sareasoner.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -315,11 +315,15 @@ async def _initialize_sa_components(self):
315315
await sacomp.init()
316316

317317
def _load_minimal_requirement_config(self):
318-
self._config["situational_awareness"]["sa_reasoner"]["sa_network"]["addr"] = self._addr
319-
self._config["situational_awareness"]["sa_reasoner"]["sa_network"]["sar"] = self
320-
self._config["situational_awareness"]["sa_reasoner"]["sa_network"]["strict_topology"] = self._config[
321-
"situational_awareness"
322-
]["strict_topology"]
318+
#self._config["situational_awareness"]["sa_reasoner"]["sa_network"]["addr"] = self._addr
319+
#self._config["situational_awareness"]["sa_reasoner"]["sa_network"]["sar"] = self
320+
self._config["situational_awareness"]["sa_reasoner"]["sa_network"]["strict_topology"] = self._config["situational_awareness"]["strict_topology"]
321+
322+
# SA Reasoner instance for all SA Reasoner Components
323+
sar_components: dict = self._config["situational_awareness"]["sa_reasoner"]["sar_components"]
324+
for sar_comp in sar_components.keys():
325+
self._config["situational_awareness"]["sa_reasoner"][sar_comp]["sar"] = self
326+
self._config["situational_awareness"]["sa_reasoner"][sar_comp]["addr"] = self._addr
323327

324328
async def _set_minimal_requirements(self):
325329
if self._sa_components:
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import asyncio
2+
import logging
3+
from nebula.core.utils.locker import Locker
4+
from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import factory_training_policy
5+
from nebula.core.situationalawareness.awareness.sareasoner import SAMComponent
6+
from nebula.addons.functions import print_msg_box
7+
from nebula.core.situationalawareness.awareness.sareasoner import SAReasoner, SAMComponent
8+
from nebula.core.eventmanager import EventManager
9+
10+
RESTRUCTURE_COOLDOWN = 5
11+
12+
class SATraining(SAMComponent):
13+
"""
14+
SATraining is a Situational Awareness (SA) component responsible for enhancing
15+
the training process in Distributed Federated Learning (DFL) environments
16+
by leveraging context-awareness and environmental knowledge.
17+
18+
This component dynamically instantiates a training policy based on the configuration,
19+
allowing the system to adapt training strategies depending on the local topology,
20+
node behavior, or environmental constraints.
21+
22+
Attributes:
23+
_config (dict): Configuration dictionary containing parameters and references.
24+
_sar (SAReasoner): Reference to the shared situational reasoner.
25+
_trainning_policy: Instantiated training policy strategy.
26+
"""
27+
28+
def __init__(self, config):
29+
"""
30+
Initialize the SATraining component with a given configuration.
31+
32+
Args:
33+
config (dict): Configuration dictionary containing:
34+
- 'addr': Node address.
35+
- 'verbose': Verbosity flag.
36+
- 'sar': Reference to the SAReasoner instance.
37+
- 'training_policy': Training policy name to be used.
38+
"""
39+
print_msg_box(
40+
msg=f"Starting Training SA\nTraining policy: {training_policy}",
41+
indent=2,
42+
title="Training SA module",
43+
)
44+
self._config = config
45+
self._sar: SAReasoner = self._config["sar"]
46+
tp_config = {}
47+
tp_config["addr"] = self._config["addr"]
48+
tp_config["verbose"] = self._config["verbose"]
49+
training_policy = self._config["training_policy"]
50+
self._trainning_policy = factory_training_policy(training_policy, tp_config)
51+
52+
@property
53+
def sar(self):
54+
"""
55+
Returns the current instance of the SAReasoner.
56+
"""
57+
return self._sar
58+
59+
@property
60+
def tp(self):
61+
"""
62+
Returns the currently active training policy instance.
63+
"""
64+
return self._trainning_policy
65+
66+
async def init(self):
67+
"""
68+
Initialize the training policy with the current known neighbors from the SAReasoner.
69+
This setup enables the policy to make informed decisions based on local topology.
70+
"""
71+
config = {}
72+
config["nodes"] = set(await self.sar.get_nodes_known(neighbors_only=True))
73+
await self.tp.init(config)
74+
75+
async def sa_component_actions(self):
76+
"""
77+
Periodically called action of the SA component to evaluate the current scenario.
78+
This invokes the evaluation logic defined in the training policy to adapt behavior.
79+
"""
80+
logging.info("SA Trainng evaluating current scenario")
81+
asyncio.create_task(self.tp.get_evaluation_results())
82+
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy
2+
from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer
3+
from nebula.core.situationalawareness.awareness.sautils.sacommand import SACommand, factory_sa_command, SACommandAction, SACommandPRIO
4+
from nebula.core.nebulaevents import RoundEndEvent
5+
6+
class BPSTrainingPolicy(TrainingPolicy):
7+
8+
def __init__(self, config=None):
9+
pass
10+
11+
async def init(self, config):
12+
await self.register_sa_agent()
13+
14+
async def get_evaluation_results(self):
15+
sac = factory_sa_command(
16+
"connectivity",
17+
SACommandAction.MAINTAIN_CONNECTIONS,
18+
self,
19+
"",
20+
SACommandPRIO.LOW,
21+
False,
22+
None,
23+
None
24+
)
25+
await self.suggest_action(sac)
26+
await self.notify_all_suggestions_done(RoundEndEvent)
27+
28+
async def get_agent(self) -> str:
29+
return "SATraining_BPSTP"
30+
31+
async def register_sa_agent(self):
32+
await SuggestionBuffer.get_instance().register_event_agents(RoundEndEvent, self)
33+
34+
async def suggest_action(self, sac : SACommand):
35+
await SuggestionBuffer.get_instance().register_suggestion(RoundEndEvent, self, sac)
36+
37+
async def notify_all_suggestions_done(self, event_type):
38+
await SuggestionBuffer.get_instance().notify_all_suggestions_done_for_agent(self, event_type)
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import logging
2+
from nebula.core.utils.locker import Locker
3+
from nebula.core.eventmanager import EventManager
4+
from nebula.core.nebulaevents import AggregationEvent, UpdateNeighborEvent
5+
from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy
6+
from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer
7+
from nebula.core.situationalawareness.awareness.sautils.sacommand import SACommand, SACommandAction, SACommandPRIO, factory_sa_command
8+
9+
VANILLA_LEARNING_RATE = 1e-3
10+
FR_LEARNING_RATE = 1e-3
11+
MAX_ROUNDS = 20
12+
DEFAULT_WEIGHT_MODIFIER = 3
13+
14+
15+
class FastReboot(TrainingPolicy):
16+
def __init__(
17+
self,
18+
config
19+
):
20+
logging.info("🌐 Initializing FastReboot")
21+
self._max_rounds = MAX_ROUNDS # Max rounds to be applied FastReboot
22+
self._weight_mod_value = DEFAULT_WEIGHT_MODIFIER
23+
self._default_lr = VANILLA_LEARNING_RATE # Stable value for learning rate
24+
self._upgrade_lr = FR_LEARNING_RATE # Increased value for learning rate
25+
self._current_lr = VANILLA_LEARNING_RATE
26+
self._verbose = config["verbose"]
27+
28+
self._learning_rate_lock = Locker(name="learning_rate_lock", async_lock=True)
29+
self._weight_modifier = {}
30+
self._weight_modifier_lock = Locker(name="weight_modifier_lock", async_lock=True)
31+
32+
self._fr_in_progress = False
33+
34+
async def init(self, config):
35+
#await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent)
36+
#await EventManager.get_instance().subscribe_node_event(AggregationEvent)
37+
pass
38+
39+
async def get_evaluation_results(self):
40+
pass
41+
42+
def __str__(self):
43+
return "FRTS"
44+
45+
async def _get_current_learning_rate(self):
46+
await self._learning_rate_lock.acquire_async()
47+
lr = self._current_lr
48+
await self._learning_rate_lock.release_async()
49+
return lr
50+
51+
async def discard_fastreboot_for(self, addr):
52+
await self._weight_modifier_lock.acquire_async()
53+
try:
54+
del self._weight_modifier[addr]
55+
except KeyError:
56+
pass
57+
await self._weight_modifier_lock.release_async()
58+
59+
async def _set_learning_rate(self, lr):
60+
await self._learning_rate_lock.acquire_async()
61+
self._current_lr = lr
62+
await self._learning_rate_lock.release_async()
63+
64+
async def add_fastReboot_addr(self, addr):
65+
await self._weight_modifier_lock.acquire_async()
66+
if addr not in self._weight_modifier:
67+
self._fr_in_progress = True
68+
wm = self._weight_mod_value
69+
logging.info(
70+
f"📝 Registering | FastReboot registered for source {addr} | round application: {self._max_rounds} | multiplier value: {wm}"
71+
)
72+
self._weight_modifier[addr] = (wm, 1)
73+
await self._set_learning_rate(self._upgrade_lr)
74+
current_lr = await self._get_current_learning_rate()
75+
#TODO modify learning rate suggestion await self.nm.update_learning_rate(current_lr)
76+
await self._weight_modifier_lock.release_async()
77+
78+
async def _remove_weight_modifier(self, addr):
79+
logging.info(f"📝 Removing | FastReboot removed for source {addr}")
80+
del self._weight_modifier[addr]
81+
82+
async def _weight_modifiers_empty(self):
83+
await self._weight_modifier_lock.acquire_async()
84+
empty = False if self._weight_modifier else True
85+
await self._weight_modifier_lock.release_async()
86+
return empty
87+
88+
async def apply_weight_strategy(self, updates: dict):
89+
if await self._weight_modifiers_empty():
90+
if self._fr_in_progress:
91+
await self._end_fastreboot()
92+
return
93+
logging.info("🔄 Applying FastReboot Strategy...")
94+
for addr, update in updates.items():
95+
weightmodifier, rounds = await self._get_weight_modifier(addr)
96+
if weightmodifier != 1:
97+
logging.info(
98+
f"📝 Appliying FastReboot strategy | addr: {addr} | multiplier value: {weightmodifier}, rounds applied: {rounds}"
99+
)
100+
model, weight = update
101+
updates.update({addr: (model, weight * weightmodifier)})
102+
await self._update_weight_modifiers()
103+
104+
async def _update_weight_modifiers(self):
105+
await self._weight_modifier_lock.acquire_async()
106+
if self._weight_modifier:
107+
logging.info("🔄 Update | weights being updated")
108+
remove_addrs = []
109+
for addr, (weight, rounds) in self._weight_modifier.items():
110+
new_weight = weight - 1 / (rounds**2)
111+
rounds = rounds + 1
112+
if new_weight > 1 and rounds <= self._max_rounds:
113+
self._weight_modifier[addr] = (new_weight, rounds)
114+
else:
115+
remove_addrs.append(addr)
116+
for a in remove_addrs:
117+
await self._remove_weight_modifier(a)
118+
await self._weight_modifier_lock.release_async()
119+
120+
async def _end_fastreboot(self):
121+
await self._weight_modifier_lock.acquire_async()
122+
if not self._weight_modifier and await self._is_lr_modified():
123+
logging.info("🔄 Finishing | FastReboot is completed")
124+
self._fr_in_progress = False
125+
await self._set_learning_rate(self._default_lr)
126+
#TODO modify learning rate suggestion await self.nm.update_learning_rate(self._default_lr)
127+
await self._weight_modifier_lock.release_async()
128+
129+
async def _get_weight_modifier(self, addr):
130+
await self._weight_modifier_lock.acquire_async()
131+
wm = self._weight_modifier.get(addr, (1, 0))
132+
await self._weight_modifier_lock.release_async()
133+
return wm
134+
135+
async def _is_lr_modified(self):
136+
await self._learning_rate_lock.acquire_async()
137+
mod = self._current_lr == self._upgrade_lr
138+
await self._learning_rate_lock.release_async()
139+
return mod
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy
2+
from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import factory_training_policy
3+
from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy
4+
import logging
5+
6+
# "Hybrid Training Strategy" (HTS)
7+
class HTSTrainingPolicy(TrainingPolicy):
8+
"""
9+
Implements a Hybrid Training Strategy (HTS) that combines multiple training policies
10+
(e.g., QDS, FRTS) to collaboratively decide on the evaluation and potential pruning
11+
of neighbors in a decentralized federated learning scenario.
12+
13+
Attributes:
14+
TRAINING_POLICY (set): Names of training policy classes to instantiate and manage.
15+
"""
16+
17+
TRAINING_POLICY = {
18+
"qds",
19+
"frts",
20+
}
21+
22+
def __init__(self, config):
23+
"""
24+
Initializes the HTS policy with the node's address and verbosity level.
25+
It creates instances of each sub-policy listed in TRAINING_POLICY.
26+
27+
Args:
28+
config (dict): Configuration dictionary with keys:
29+
- 'addr': Node's address
30+
- 'verbose': Enable verbose logging
31+
"""
32+
self._addr = config["addr"]
33+
self._verbose = config["verbose"]
34+
self._training_policies : set[TrainingPolicy] = set()
35+
self._training_policies.add([factory_training_policy(x, config) for x in self.TRAINING_POLICY])
36+
37+
def __str__(self):
38+
return "HTS"
39+
40+
@property
41+
def tps(self):
42+
return self._training_policies
43+
44+
async def init(self, config):
45+
for tp in self.tps:
46+
await tp.init(config)
47+
48+
async def update_neighbors(self, node, remove=False):
49+
pass
50+
51+
async def get_evaluation_results(self):
52+
"""
53+
Asynchronously calls the `get_evaluation_results` of each policy,
54+
and logs the nodes each policy would remove.
55+
56+
Returns:
57+
None (future version may merge all evaluations).
58+
"""
59+
nodes_to_remove = dict()
60+
for tp in self.tps:
61+
nodes_to_remove[tp] = await tp.get_evaluation_results()
62+
63+
for tp, nodes in nodes_to_remove.items():
64+
logging.info(f"Training Policy: {tp}, nodes to remove: {nodes}")
65+
66+
return None

0 commit comments

Comments
 (0)