Skip to content

Commit 73e5ce6

Browse files
committed
Adding RTD (Read the Docs) config and reformatting according to ruff.
1 parent b10f199 commit 73e5ce6

20 files changed

Lines changed: 282 additions & 146 deletions

.readthedocs.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
version: 2
2+
3+
build:
4+
os: ubuntu-22.04
5+
tools:
6+
python: "3.11"
7+
8+
sphinx:
9+
configuration: docs/source/conf.py
10+
11+
python:
12+
install:
13+
- method: pip
14+
path: .
15+
extra_requirements:
16+
- docs

distributed_resource_optimization/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ async def main():
167167
OptimizationFinishedMessage,
168168
StartCoordinatedDistributedOptimization,
169169
)
170+
170171
_MANGO_AVAILABLE = True
171172
except ImportError: # pragma: no cover
172173
pass

distributed_resource_optimization/algorithm/admm/consensus_admm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def primal_residual(self, x: list[np.ndarray], z: list[np.ndarray]) -> float:
8282
# Factories
8383
# ---------------------------------------------------------------------------
8484

85+
8586
def create_consensus_target_reach_admm_coordinator() -> ADMMGenericCoordinator:
8687
"""Create an :class:`ADMMGenericCoordinator` for the consensus variant."""
8788
return ADMMGenericCoordinator(global_actor=ADMMConsensusGlobalActor())

distributed_resource_optimization/algorithm/admm/core.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
# Message types
3131
# ---------------------------------------------------------------------------
3232

33+
3334
@dataclass
3435
class ADMMStart:
3536
"""Sent to the coordinator to begin a new ADMM run.
@@ -69,6 +70,7 @@ class ADMMAnswer:
6970
# Abstract global-actor interface
7071
# ---------------------------------------------------------------------------
7172

73+
7274
class ADMMGlobalActor(ABC):
7375
"""Interface for the coordinator-side global update in ADMM variants."""
7476

@@ -136,6 +138,7 @@ def objective(
136138
# Helper: max-norm over list-of-arrays or single array
137139
# ---------------------------------------------------------------------------
138140

141+
139142
def _max_norm(v: Any) -> float:
140143
"""Return ``max ||v_i||`` if *v* is a list, else ``max |v_j|`` for a vector."""
141144
if isinstance(v, list):
@@ -160,6 +163,7 @@ def _deepcopy_z(z: Any) -> Any:
160163
# Generic ADMM coordinator
161164
# ---------------------------------------------------------------------------
162165

166+
163167
class ADMMGenericCoordinator(Coordinator):
164168
"""Standard ADMM iteration loop.
165169
@@ -198,9 +202,7 @@ async def start_optimization(
198202
message_data: ADMMStart,
199203
meta: Any,
200204
) -> list[np.ndarray]:
201-
x, _z, _u = await self._run(
202-
carrier, message_data.data, message_data.solution_length
203-
)
205+
x, _z, _u = await self._run(carrier, message_data.data, message_data.solution_length)
204206
return x
205207

206208
async def _run(
@@ -248,23 +250,16 @@ async def _run(
248250
# 4. Convergence check
249251
r_norm = actor.primal_residual(x, z)
250252
s_norm = rho * _max_diff_norm(z, z_old)
251-
eps_pri = (
252-
np.sqrt(m * n) * self.abs_tol
253-
+ self.rel_tol * max(_max_norm(x), _max_norm(z))
254-
)
255-
eps_dual = (
256-
np.sqrt(m * n) * self.abs_tol
257-
+ self.rel_tol * _max_norm(u)
258-
)
253+
eps_pri = np.sqrt(m * n) * self.abs_tol + self.rel_tol * max(_max_norm(x), _max_norm(z))
254+
eps_dual = np.sqrt(m * n) * self.abs_tol + self.rel_tol * _max_norm(u)
259255

260256
if r_norm < eps_pri and s_norm < eps_dual:
261257
logger.debug("ADMM converged in %d iterations.", k)
262258
break
263259

264260
if k == self.max_iters:
265261
logger.warning(
266-
"ADMM reached max iterations (%d) without full convergence "
267-
"(r=%.4g, s=%.4g).",
262+
"ADMM reached max iterations (%d) without full convergence (r=%.4g, s=%.4g).",
268263
self.max_iters,
269264
r_norm,
270265
s_norm,
@@ -277,6 +272,7 @@ async def _run(
277272
# Factory
278273
# ---------------------------------------------------------------------------
279274

275+
280276
def create_admm_start(data: Any, length: int | None = None) -> ADMMStart:
281277
"""Create an :class:`ADMMStart` message.
282278
@@ -290,6 +286,4 @@ def create_admm_start(data: Any, length: int | None = None) -> ADMMStart:
290286
return ADMMStart(data=data, solution_length=data.solution_length)
291287
if hasattr(data, "target"):
292288
return ADMMStart(data=data, solution_length=len(data.target))
293-
raise ValueError(
294-
"Cannot infer solution_length; pass it explicitly as the second argument."
295-
)
289+
raise ValueError("Cannot infer solution_length; pass it explicitly as the second argument.")

distributed_resource_optimization/algorithm/admm/flex_actor.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,7 @@ def _local_update(actor: ADMMFlexActor, v: np.ndarray, rho: float) -> np.ndarray
7474
x_var = cp.Variable(m)
7575

7676
h = rho * np.asarray(v, dtype=float) + np.asarray(actor.S, dtype=float)
77-
objective = cp.Minimize(
78-
(rho / 2) * cp.sum_squares(x_var) + h @ x_var
79-
)
77+
objective = cp.Minimize((rho / 2) * cp.sum_squares(x_var) + h @ x_var)
8078

8179
constraints = [
8280
x_var >= actor.lb,

distributed_resource_optimization/algorithm/admm/sharing_admm.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
# Global objective (currently informational only)
3232
# ---------------------------------------------------------------------------
3333

34+
3435
class ADMMTargetDistanceObjective(ADMMGlobalObjective):
3536
"""Quadratic target-distance objective (informational)."""
3637

@@ -48,6 +49,7 @@ def objective(
4849
# Sharing data
4950
# ---------------------------------------------------------------------------
5051

52+
5153
@dataclass
5254
class ADMMSharingData:
5355
"""Input data for the sharing ADMM variant.
@@ -86,6 +88,7 @@ def create_admm_start(data: ADMMSharingData) -> ADMMStart:
8688
# Sharing global actor
8789
# ---------------------------------------------------------------------------
8890

91+
8992
class ADMMSharingGlobalActor(ADMMGlobalActor):
9093
"""Global actor for the sharing ADMM variant.
9194
@@ -119,16 +122,12 @@ def z_update(
119122
constraints.append(d_var[i] >= lhs)
120123
constraints.append(d_var[i] >= -lhs)
121124

122-
objective = cp.Minimize(
123-
(n * rho / 2) * cp.sum_squares(z_var - u - x_avg) + cp.sum(d_var)
124-
)
125+
objective = cp.Minimize((n * rho / 2) * cp.sum_squares(z_var - u - x_avg) + cp.sum(d_var))
125126
prob = cp.Problem(objective, constraints)
126127
prob.solve(solver=cp.OSQP, verbose=False)
127128

128129
if z_var.value is None:
129-
raise RuntimeError(
130-
f"Sharing ADMM z-update QP did not converge (status={prob.status})."
131-
)
130+
raise RuntimeError(f"Sharing ADMM z-update QP did not converge (status={prob.status}).")
132131
return np.asarray(z_var.value, dtype=float)
133132

134133
def u_update(
@@ -167,6 +166,7 @@ def primal_residual(self, x: list[np.ndarray], z: np.ndarray) -> float:
167166
# Factories
168167
# ---------------------------------------------------------------------------
169168

169+
170170
def create_sharing_target_distance_admm_coordinator() -> ADMMGenericCoordinator:
171171
"""Create an :class:`~.core.ADMMGenericCoordinator` for target-distance sharing."""
172172
return ADMMGenericCoordinator(

distributed_resource_optimization/algorithm/consensus/averaging.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
# ConsensusActor hierarchy
3535
# ---------------------------------------------------------------------------
3636

37+
3738
class ConsensusActor:
3839
"""Optional plug-in that adds a gradient term to the averaging update.
3940
@@ -59,6 +60,7 @@ class NoConsensusActor(ConsensusActor):
5960
# Message types
6061
# ---------------------------------------------------------------------------
6162

63+
6264
@dataclass
6365
class AveragingConsensusMessage(OptimizationMessage):
6466
"""Message exchanged between averaging-consensus participants.
@@ -94,6 +96,7 @@ class ConsensusFinishedMessage:
9496
# AveragingConsensusAlgorithm
9597
# ---------------------------------------------------------------------------
9698

99+
97100
class AveragingConsensusAlgorithm(DistributedAlgorithm):
98101
"""Distributed averaging consensus with an optional gradient correction.
99102
@@ -193,6 +196,7 @@ async def on_exchange_message(
193196
# Factories
194197
# ---------------------------------------------------------------------------
195198

199+
196200
def create_averaging_consensus_participant(
197201
finish_callback: Callable,
198202
consensus_actor: ConsensusActor | None = None,

distributed_resource_optimization/algorithm/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ async def start_optimization(
5757
# Module-level shim functions — preserved for compatibility with carrier code
5858
# ---------------------------------------------------------------------------
5959

60+
6061
async def on_exchange_message(
6162
algorithm: DistributedAlgorithm,
6263
carrier: "Carrier",
@@ -81,6 +82,7 @@ async def start_optimization(
8182
# CoordinatedDistributedAlgorithm
8283
# ---------------------------------------------------------------------------
8384

85+
8486
class CoordinatedDistributedAlgorithm:
8587
"""Bundle of a coordinator and its worker algorithms (informational only)."""
8688

distributed_resource_optimization/algorithm/heuristic/cohda/core.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
# Data structures
3434
# ---------------------------------------------------------------------------
3535

36+
3637
@dataclass
3738
class ScheduleSelection:
3839
"""A participant's chosen schedule together with its version counter."""
@@ -43,10 +44,7 @@ class ScheduleSelection:
4344
def __eq__(self, other: object) -> bool:
4445
if not isinstance(other, ScheduleSelection):
4546
return NotImplemented
46-
return (
47-
np.array_equal(self.schedule, other.schedule)
48-
and self.counter == other.counter
49-
)
47+
return np.array_equal(self.schedule, other.schedule) and self.counter == other.counter
5048

5149
def __hash__(self) -> int:
5250
return hash((tuple(float(v) for v in self.schedule), self.counter))
@@ -67,8 +65,7 @@ def __eq__(self, other: object) -> bool:
6765
if set(self.schedule_choices) != set(other.schedule_choices):
6866
return False
6967
return all(
70-
self.schedule_choices[k] == other.schedule_choices[k]
71-
for k in self.schedule_choices
68+
self.schedule_choices[k] == other.schedule_choices[k] for k in self.schedule_choices
7269
)
7370

7471
def __hash__(self) -> int:
@@ -106,9 +103,7 @@ def __eq__(self, other: object) -> bool:
106103
)
107104

108105
def __hash__(self) -> int:
109-
return hash(
110-
(self.participant_id, self.schedules.tobytes(), self.perf, self.present)
111-
)
106+
return hash((self.participant_id, self.schedules.tobytes(), self.perf, self.present))
112107

113108
def __repr__(self) -> str:
114109
return (
@@ -128,9 +123,8 @@ class TargetParams:
128123
def __eq__(self, other: object) -> bool:
129124
if not isinstance(other, TargetParams):
130125
return NotImplemented
131-
return (
132-
np.array_equal(self.schedule, other.schedule)
133-
and np.array_equal(self.weights, other.weights)
126+
return np.array_equal(self.schedule, other.schedule) and np.array_equal(
127+
self.weights, other.weights
134128
)
135129

136130
def __hash__(self) -> int:
@@ -142,10 +136,7 @@ def __hash__(self) -> int:
142136
)
143137

144138
def __repr__(self) -> str:
145-
return (
146-
f"TargetParams(schedule={self.schedule.tolist()}, "
147-
f"weights={self.weights.tolist()})"
148-
)
139+
return f"TargetParams(schedule={self.schedule.tolist()}, weights={self.weights.tolist()})"
149140

150141

151142
@dataclass
@@ -168,6 +159,7 @@ class WorkingMemory(OptimizationMessage):
168159
# Default performance function
169160
# ---------------------------------------------------------------------------
170161

162+
171163
def cohda_default_performance(
172164
cluster_schedule: np.ndarray,
173165
target_params: TargetParams,
@@ -178,7 +170,7 @@ def cohda_default_performance(
178170
:param target_params: Target schedule and weights.
179171
:returns: ``-sum(weights * abs(target - column_sums))``.
180172
"""
181-
sum_cs = cluster_schedule.sum(axis=0) # (n_intervals,)
173+
sum_cs = cluster_schedule.sum(axis=0) # (n_intervals,)
182174
diff = np.abs(target_params.schedule - sum_cs)
183175
return -float(np.sum(diff * target_params.weights))
184176

@@ -187,6 +179,7 @@ def cohda_default_performance(
187179
# Local decider hierarchy
188180
# ---------------------------------------------------------------------------
189181

182+
190183
class LocalDecider:
191184
"""Abstract strategy for selecting a local schedule in the decide step."""
192185

@@ -218,6 +211,7 @@ def initial_schedule(self, memory: WorkingMemory) -> np.ndarray:
218211
# COHDAAlgorithmData
219212
# ---------------------------------------------------------------------------
220213

214+
221215
class COHDAAlgorithmData(DistributedAlgorithm):
222216
"""Per-participant COHDA state machine.
223217
@@ -255,6 +249,7 @@ async def on_exchange_message(
255249
# Core algorithmic functions
256250
# ---------------------------------------------------------------------------
257251

252+
258253
def merge_sysconfigs(
259254
sysconfig_i: SystemConfig,
260255
sysconfig_j: SystemConfig,
@@ -271,8 +266,7 @@ def merge_sysconfigs(
271266
modified = False
272267
for aid in all_ids:
273268
if aid in choices_i and (
274-
aid not in choices_j
275-
or choices_i[aid].counter >= choices_j[aid].counter
269+
aid not in choices_j or choices_i[aid].counter >= choices_j[aid].counter
276270
):
277271
new_choices[aid] = choices_i[aid]
278272
else:
@@ -378,9 +372,7 @@ def perceive(
378372
counter=cohda_data.counter + 1,
379373
)
380374
cohda_data.counter += 1
381-
current_sysconfig = SystemConfig(
382-
dict(own_memory.system_config.schedule_choices)
383-
)
375+
current_sysconfig = SystemConfig(dict(own_memory.system_config.schedule_choices))
384376
else:
385377
current_sysconfig = own_memory.system_config
386378

@@ -456,9 +448,7 @@ def _decide_default(
456448
candidate: SolutionCandidate,
457449
) -> tuple[SystemConfig, SolutionCandidate]:
458450
"""Evaluate all feasible schedules; keep the best-performing candidate."""
459-
possible = [
460-
np.array(s, dtype=float) for s in decider.schedule_provider(cohda_data.memory)
461-
]
451+
possible = [np.array(s, dtype=float) for s in decider.schedule_provider(cohda_data.memory)]
462452
current_best = candidate
463453
if current_best.perf is None:
464454
current_best = _evaluated(
@@ -469,9 +459,7 @@ def _decide_default(
469459

470460
for schedule in possible:
471461
if decider.is_local_acceptable(schedule):
472-
new_cand = create_from_updated_sysconf(
473-
cohda_data.participant_id, sysconfig, schedule
474-
)
462+
new_cand = create_from_updated_sysconf(cohda_data.participant_id, sysconfig, schedule)
475463
new_perf = cohda_data.performance_function(
476464
new_cand.schedules, cohda_data.memory.target_params
477465
)
@@ -521,9 +509,7 @@ async def process_exchange_message(
521509
sysconf, candidate = perceive(algorithm_data, messages)
522510

523511
if sysconf != old_sysconf or candidate != old_candidate:
524-
sysconf, candidate = decide(
525-
algorithm_data, algorithm_data.decider, sysconf, candidate
526-
)
512+
sysconf, candidate = decide(algorithm_data, algorithm_data.decider, sysconf, candidate)
527513
wm = act(algorithm_data, sysconf, candidate)
528514
for other in carrier.others(str(algorithm_data.participant_id)):
529515
carrier.send_to_other(wm, other)
@@ -533,6 +519,7 @@ async def process_exchange_message(
533519
# Factory helpers
534520
# ---------------------------------------------------------------------------
535521

522+
536523
def create_cohda_start_message(
537524
target_schedule: list[float] | np.ndarray,
538525
weights: list[float] | np.ndarray | None = None,

0 commit comments

Comments
 (0)