|
| 1 | +import logging |
| 2 | +from nebula.core.utils.locker import Locker |
| 3 | +from nebula.core.eventmanager import EventManager |
| 4 | +from nebula.core.nebulaevents import AggregationEvent, UpdateNeighborEvent |
| 5 | +from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy |
| 6 | +from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer |
| 7 | +from nebula.core.situationalawareness.awareness.sautils.sacommand import SACommand, SACommandAction, SACommandPRIO, factory_sa_command |
| 8 | + |
| 9 | +VANILLA_LEARNING_RATE = 1e-3 |
| 10 | +FR_LEARNING_RATE = 1e-3 |
| 11 | +MAX_ROUNDS = 20 |
| 12 | +DEFAULT_WEIGHT_MODIFIER = 3 |
| 13 | + |
| 14 | + |
| 15 | +class FastReboot(TrainingPolicy): |
| 16 | + def __init__( |
| 17 | + self, |
| 18 | + config |
| 19 | + ): |
| 20 | + logging.info("🌐 Initializing FastReboot") |
| 21 | + self._max_rounds = MAX_ROUNDS # Max rounds to be applied FastReboot |
| 22 | + self._weight_mod_value = DEFAULT_WEIGHT_MODIFIER |
| 23 | + self._default_lr = VANILLA_LEARNING_RATE # Stable value for learning rate |
| 24 | + self._upgrade_lr = FR_LEARNING_RATE # Increased value for learning rate |
| 25 | + self._current_lr = VANILLA_LEARNING_RATE |
| 26 | + self._verbose = config["verbose"] |
| 27 | + |
| 28 | + self._learning_rate_lock = Locker(name="learning_rate_lock", async_lock=True) |
| 29 | + self._weight_modifier = {} |
| 30 | + self._weight_modifier_lock = Locker(name="weight_modifier_lock", async_lock=True) |
| 31 | + |
| 32 | + self._fr_in_progress = False |
| 33 | + |
| 34 | + async def init(self, config): |
| 35 | + #await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent) |
| 36 | + #await EventManager.get_instance().subscribe_node_event(AggregationEvent) |
| 37 | + pass |
| 38 | + |
| 39 | + async def get_evaluation_results(self): |
| 40 | + pass |
| 41 | + |
| 42 | + def __str__(self): |
| 43 | + return "FRTS" |
| 44 | + |
| 45 | + async def _get_current_learning_rate(self): |
| 46 | + await self._learning_rate_lock.acquire_async() |
| 47 | + lr = self._current_lr |
| 48 | + await self._learning_rate_lock.release_async() |
| 49 | + return lr |
| 50 | + |
| 51 | + async def discard_fastreboot_for(self, addr): |
| 52 | + await self._weight_modifier_lock.acquire_async() |
| 53 | + try: |
| 54 | + del self._weight_modifier[addr] |
| 55 | + except KeyError: |
| 56 | + pass |
| 57 | + await self._weight_modifier_lock.release_async() |
| 58 | + |
| 59 | + async def _set_learning_rate(self, lr): |
| 60 | + await self._learning_rate_lock.acquire_async() |
| 61 | + self._current_lr = lr |
| 62 | + await self._learning_rate_lock.release_async() |
| 63 | + |
| 64 | + async def add_fastReboot_addr(self, addr): |
| 65 | + await self._weight_modifier_lock.acquire_async() |
| 66 | + if addr not in self._weight_modifier: |
| 67 | + self._fr_in_progress = True |
| 68 | + wm = self._weight_mod_value |
| 69 | + logging.info( |
| 70 | + f"📝 Registering | FastReboot registered for source {addr} | round application: {self._max_rounds} | multiplier value: {wm}" |
| 71 | + ) |
| 72 | + self._weight_modifier[addr] = (wm, 1) |
| 73 | + await self._set_learning_rate(self._upgrade_lr) |
| 74 | + current_lr = await self._get_current_learning_rate() |
| 75 | + #TODO modify learning rate suggestion await self.nm.update_learning_rate(current_lr) |
| 76 | + await self._weight_modifier_lock.release_async() |
| 77 | + |
| 78 | + async def _remove_weight_modifier(self, addr): |
| 79 | + logging.info(f"📝 Removing | FastReboot removed for source {addr}") |
| 80 | + del self._weight_modifier[addr] |
| 81 | + |
| 82 | + async def _weight_modifiers_empty(self): |
| 83 | + await self._weight_modifier_lock.acquire_async() |
| 84 | + empty = False if self._weight_modifier else True |
| 85 | + await self._weight_modifier_lock.release_async() |
| 86 | + return empty |
| 87 | + |
| 88 | + async def apply_weight_strategy(self, updates: dict): |
| 89 | + if await self._weight_modifiers_empty(): |
| 90 | + if self._fr_in_progress: |
| 91 | + await self._end_fastreboot() |
| 92 | + return |
| 93 | + logging.info("🔄 Applying FastReboot Strategy...") |
| 94 | + for addr, update in updates.items(): |
| 95 | + weightmodifier, rounds = await self._get_weight_modifier(addr) |
| 96 | + if weightmodifier != 1: |
| 97 | + logging.info( |
| 98 | + f"📝 Appliying FastReboot strategy | addr: {addr} | multiplier value: {weightmodifier}, rounds applied: {rounds}" |
| 99 | + ) |
| 100 | + model, weight = update |
| 101 | + updates.update({addr: (model, weight * weightmodifier)}) |
| 102 | + await self._update_weight_modifiers() |
| 103 | + |
| 104 | + async def _update_weight_modifiers(self): |
| 105 | + await self._weight_modifier_lock.acquire_async() |
| 106 | + if self._weight_modifier: |
| 107 | + logging.info("🔄 Update | weights being updated") |
| 108 | + remove_addrs = [] |
| 109 | + for addr, (weight, rounds) in self._weight_modifier.items(): |
| 110 | + new_weight = weight - 1 / (rounds**2) |
| 111 | + rounds = rounds + 1 |
| 112 | + if new_weight > 1 and rounds <= self._max_rounds: |
| 113 | + self._weight_modifier[addr] = (new_weight, rounds) |
| 114 | + else: |
| 115 | + remove_addrs.append(addr) |
| 116 | + for a in remove_addrs: |
| 117 | + await self._remove_weight_modifier(a) |
| 118 | + await self._weight_modifier_lock.release_async() |
| 119 | + |
| 120 | + async def _end_fastreboot(self): |
| 121 | + await self._weight_modifier_lock.acquire_async() |
| 122 | + if not self._weight_modifier and await self._is_lr_modified(): |
| 123 | + logging.info("🔄 Finishing | FastReboot is completed") |
| 124 | + self._fr_in_progress = False |
| 125 | + await self._set_learning_rate(self._default_lr) |
| 126 | + #TODO modify learning rate suggestion await self.nm.update_learning_rate(self._default_lr) |
| 127 | + await self._weight_modifier_lock.release_async() |
| 128 | + |
| 129 | + async def _get_weight_modifier(self, addr): |
| 130 | + await self._weight_modifier_lock.acquire_async() |
| 131 | + wm = self._weight_modifier.get(addr, (1, 0)) |
| 132 | + await self._weight_modifier_lock.release_async() |
| 133 | + return wm |
| 134 | + |
| 135 | + async def _is_lr_modified(self): |
| 136 | + await self._learning_rate_lock.acquire_async() |
| 137 | + mod = self._current_lr == self._upgrade_lr |
| 138 | + await self._learning_rate_lock.release_async() |
| 139 | + return mod |
0 commit comments