Skip to content

Commit 6c5fb4b

Browse files
authored
Merge pull request #120 from ImogenBits/better_config
Better config structure
2 parents 0c8ad18 + 0db4bae commit 6c5fb4b

15 files changed

Lines changed: 1068 additions & 863 deletions

File tree

algobattle/battle.py

Lines changed: 117 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,38 @@
44
some basic battle types, and related classed.
55
"""
66
from dataclasses import dataclass
7+
from functools import wraps
78
from importlib.metadata import entry_points
89
from abc import abstractmethod
10+
from inspect import isclass
911
from typing import (
1012
Any,
13+
Awaitable,
14+
Callable,
1115
ClassVar,
12-
Generic,
16+
Concatenate,
17+
Hashable,
18+
Literal,
19+
ParamSpec,
1320
Protocol,
21+
Self,
1422
TypeAlias,
1523
TypeVar,
1624
)
1725

18-
from pydantic import Field
26+
from pydantic import Field, GetCoreSchemaHandler
27+
from pydantic.main import BaseModel as PydanticBase
28+
from pydantic_core import CoreSchema
29+
from pydantic_core.core_schema import tagged_union_schema
1930

2031
from algobattle.docker_util import (
2132
Generator,
2233
ProgramRunInfo,
23-
ProgramUiProxy,
34+
ProgramUi,
2435
Solver,
2536
)
26-
from algobattle.problem import InstanceT, Problem, SolutionT
27-
from algobattle.util import Encodable, Role, inherit_docs, BaseModel
37+
from algobattle.problem import AnyProblem
38+
from algobattle.util import Encodable, inherit_docs, BaseModel
2839

2940

3041
_BattleConfig: TypeAlias = Any
@@ -36,6 +47,9 @@
3647
the new battle type directly.
3748
"""
3849
T = TypeVar("T")
50+
P = ParamSpec("P")
51+
RunFight: TypeAlias = "Callable[Concatenate[FightHandler, P], Awaitable[Fight]]"
52+
Type = type
3953

4054

4155
class Fight(BaseModel):
@@ -57,41 +71,41 @@ class Fight(BaseModel):
5771
"""Data about the solver's execution."""
5872

5973

60-
class FightUiProxy(Protocol):
74+
class FightUi(ProgramUi, Protocol):
6175
"""Provides an interface for :class:`Fight` to update the ui."""
6276

63-
generator: ProgramUiProxy
64-
solver: ProgramUiProxy
65-
6677
@abstractmethod
67-
def start(self, max_size: int) -> None:
78+
def start_fight(self, max_size: int) -> None:
6879
"""Informs the ui that a new fight has been started."""
6980

7081
@abstractmethod
71-
def update(self, role: Role, data: ProgramRunInfo) -> None:
72-
"""Updates the ui's current fight section with new data about a program."""
73-
74-
@abstractmethod
75-
def end(self) -> None:
82+
def end_fight(self) -> None:
7683
"""Informs the ui that the fight has finished running and has been added to the battle's `.fight_results`."""
7784

7885

86+
def _save_result(func: "RunFight[P]") -> "RunFight[P]":
87+
@wraps(func)
88+
async def inner(self: "FightHandler", *args: P.args, **kwargs: P.kwargs) -> Fight:
89+
res = await func(self, *args, **kwargs)
90+
self.battle.fights.append(res)
91+
self.ui.end_fight()
92+
return res
93+
94+
return inner
95+
96+
7997
@dataclass
80-
class FightHandler(Generic[InstanceT, SolutionT]):
98+
class FightHandler:
8199
"""Helper class to run fights of a given battle."""
82100

83-
_problem: Problem[InstanceT, SolutionT]
84-
_generator: Generator[InstanceT, SolutionT]
85-
_solver: Solver[InstanceT, SolutionT]
86-
_battle: "Battle"
87-
_ui: FightUiProxy
88-
_set_cpus: str | None = None
89-
90-
def _saved(self, fight: Fight) -> Fight:
91-
self._battle.fight_results.append(fight)
92-
self._ui.end()
93-
return fight
101+
problem: AnyProblem
102+
generator: Generator
103+
solver: Solver
104+
battle: "Battle"
105+
ui: FightUi
106+
set_cpus: str | None
94107

108+
@_save_result
95109
async def run(
96110
self,
97111
max_size: int,
@@ -134,62 +148,58 @@ async def run(
134148
Returns:
135149
The resulting info about the executed fight.
136150
"""
137-
min_size = self._problem.min_size
151+
min_size = self.problem.min_size
138152
if max_size < min_size:
139153
raise ValueError(
140154
f"Cannot run battle at size {max_size} since it is smaller than the smallest "
141-
"size the problem allows ({min_size})."
155+
f"size the problem allows ({min_size})."
142156
)
143-
ui = self._ui
144-
ui.start(max_size)
145-
gen_result = await self._generator.run(
157+
ui = self.ui
158+
ui.start_fight(max_size)
159+
gen_result = await self.generator.run(
146160
max_size=max_size,
147161
timeout=timeout_generator,
148162
space=space_generator,
149163
cpus=cpus_generator,
150164
battle_input=generator_battle_input,
151165
battle_output=generator_battle_output,
152-
set_cpus=self._set_cpus,
153-
ui=ui.generator,
166+
set_cpus=self.set_cpus,
167+
ui=ui,
154168
)
155-
ui.update(Role.generator, gen_result.info)
156169
if gen_result.instance is None:
157-
return self._saved(Fight(score=1, max_size=max_size, generator=gen_result.info, solver=None))
170+
return Fight(score=1, max_size=max_size, generator=gen_result.info, solver=None)
158171

159-
sol_result = await self._solver.run(
172+
sol_result = await self.solver.run(
160173
gen_result.instance,
161174
max_size=max_size,
162175
timeout=timeout_solver,
163176
space=space_solver,
164177
cpus=cpus_solver,
165178
battle_input=solver_battle_input,
166179
battle_output=solver_battle_output,
167-
set_cpus=self._set_cpus,
168-
ui=ui.solver,
180+
set_cpus=self.set_cpus,
181+
ui=ui,
169182
)
170-
ui.update(Role.solver, sol_result.info)
171183
if sol_result.solution is None:
172-
return self._saved(Fight(score=0, max_size=max_size, generator=gen_result.info, solver=sol_result.info))
184+
return Fight(score=0, max_size=max_size, generator=gen_result.info, solver=sol_result.info)
173185

174-
if self._problem.with_solution:
186+
if self.problem.with_solution:
175187
assert gen_result.solution is not None
176-
score = self._problem.score(
188+
score = self.problem.score(
177189
gen_result.instance, solver_solution=sol_result.solution, generator_solution=gen_result.solution
178190
)
179191
else:
180-
score = self._problem.score(gen_result.instance, solution=sol_result.solution)
192+
score = self.problem.score(gen_result.instance, solution=sol_result.solution)
181193
score = max(0, min(1, float(score)))
182-
return self._saved(Fight(score=score, max_size=max_size, generator=gen_result.info, solver=sol_result.info))
194+
return Fight(score=score, max_size=max_size, generator=gen_result.info, solver=sol_result.info)
183195

184196

185197
# We need this to be here to prevent an import cycle between match.py and battle.py
186-
class BattleUiProxy(Protocol):
198+
class BattleUi(Protocol):
187199
"""Provides an interface for :class:`Battle` to update the Ui."""
188200

189-
fight_ui: FightUiProxy
190-
191201
@abstractmethod
192-
def update_data(self, data: "Battle.UiData") -> None:
202+
def update_battle_data(self, data: "Battle.UiData") -> None:
193203
"""Passes new custom display data to the Ui.
194204
195205
See :class:`Battle.UiData` for further details.
@@ -203,22 +213,49 @@ class Battle(BaseModel):
203213
they will ultimately be scored.
204214
"""
205215

206-
fight_results: list[Fight] = Field(default_factory=list)
216+
fights: list[Fight] = Field(default_factory=list)
207217
"""The list of fights that have been fought in this battle."""
208218
run_exception: str | None = None
209219
"""The description of an otherwise unhandeled exception that occured during the execution of :meth:`Battle.run`."""
210220

211-
_battle_types: ClassVar[dict[str, type["Battle"]]] = {}
221+
_battle_types: ClassVar[dict[str, type[Self]]] = {}
212222
"""Dictionary mapping the names of all registered battle types to their python classes."""
213223

214-
class BattleConfig(BaseModel):
224+
class Config(BaseModel):
215225
"""Config object for each specific battle type.
216226
217227
A custom battle type can override this class to specify config options it uses. They will be parsed from a
218-
dictionary located at `battle.NAME` in the main config file, where NAME is the specific batle type's name.
219-
The created object will then be passed to the :meth:`Battle.run` method with its fields set accordingly.
228+
dictionary located at `battle` in the main config file. The created object will then be passed to the
229+
:meth:`Battle.run` method with its fields set accordingly.
220230
"""
221231

232+
type: str = "Iterated"
233+
"""Type of battle that will be used."""
234+
235+
@classmethod
236+
def __get_pydantic_core_schema__(cls, source: Type[PydanticBase], handler: GetCoreSchemaHandler) -> CoreSchema:
237+
# there's two bugs we need to catch:
238+
# 1. this function is called during the pydantic BaseModel metaclass's __new__, so the Battle class
239+
# won't be ready at that point and be missing in the namespace
240+
# 2. pydantic uses the core schema to build child classes core schema. for them we want to behave like a
241+
# normal model, only our own schema gets modified
242+
try:
243+
if cls != Battle.Config:
244+
return handler(source)
245+
except NameError:
246+
return handler(source)
247+
battle_classes = Battle.all()
248+
match len(battle_classes):
249+
case 0:
250+
return handler(source)
251+
case 1:
252+
return handler(next(iter(battle_classes.values())).Config)
253+
case _:
254+
choices: dict[Hashable, CoreSchema] = {
255+
name: handler(sublass.Config) for name, sublass in battle_classes.items()
256+
}
257+
return tagged_union_schema(choices=choices, discriminator="type")
258+
222259
class UiData(BaseModel):
223260
"""Object containing custom diplay data.
224261
@@ -233,15 +270,20 @@ def all() -> dict[str, type["Battle"]]:
233270
It includes all subclasses of :class:`Battle` that have been initialized so far, including ones exposed to the
234271
algobattle module via the `algobattle.battle` entrypoint hook.
235272
"""
236-
for entrypoint in entry_points(group="algobattle.battle"):
237-
if entrypoint.name not in Battle._battle_types:
238-
battle: type[Battle] = entrypoint.load()
239-
Battle._battle_types[battle.name()] = battle
240273
return Battle._battle_types
241274

275+
@classmethod
276+
def load_entrypoints(cls) -> None:
277+
"""Loads all battle types presented via entrypoints."""
278+
for entrypoint in entry_points(group="algobattle.battle"):
279+
battle = entrypoint.load()
280+
if not (isclass(battle) and issubclass(battle, Battle)):
281+
raise ValueError(f"Entrypoint {entrypoint.name} targets something other than a Battle type")
282+
242283
def __init_subclass__(cls) -> None:
243284
if cls.name() not in Battle._battle_types:
244285
Battle._battle_types[cls.name()] = cls
286+
Battle.Config.model_rebuild(force=True)
245287
return super().__init_subclass__()
246288

247289
@abstractmethod
@@ -270,9 +312,7 @@ def name(cls) -> str:
270312
return cls.__name__
271313

272314
@abstractmethod
273-
async def run_battle(
274-
self, fight: FightHandler[InstanceT, SolutionT], config: _BattleConfig, min_size: int, ui: BattleUiProxy
275-
) -> None:
315+
async def run_battle(self, fight: FightHandler, config: _BattleConfig, min_size: int, ui: BattleUi) -> None:
276316
"""Executes one battle.
277317
278318
Args:
@@ -292,8 +332,11 @@ class Iterated(Battle):
292332

293333
results: list[int] = Field(default_factory=list)
294334

295-
@inherit_docs
296-
class BattleConfig(Battle.BattleConfig):
335+
class Config(Battle.Config):
336+
"""Config options for Iterated battles."""
337+
338+
type: Literal["Iterated"] = "Iterated"
339+
297340
rounds: int = 5
298341
"""Number of times the instance size will be increased until the solver fails to produce correct solutions."""
299342
maximum_size: int = 50_000
@@ -308,9 +351,7 @@ class UiData(Battle.UiData):
308351
reached: list[int]
309352
cap: int
310353

311-
async def run_battle(
312-
self, fight: FightHandler[InstanceT, SolutionT], config: BattleConfig, min_size: int, ui: BattleUiProxy
313-
) -> None:
354+
async def run_battle(self, fight: FightHandler, config: Config, min_size: int, ui: BattleUi) -> None:
314355
"""Execute an iterated battle.
315356
316357
Incrementally tries to search for the highest n for which the solver is still able to solve instances.
@@ -329,7 +370,7 @@ async def run_battle(
329370
cap = config.maximum_size
330371
current = min_size
331372
while alive:
332-
ui.update_data(self.UiData(reached=self.results + [reached], cap=cap))
373+
ui.update_battle_data(self.UiData(reached=self.results + [reached], cap=cap))
333374
result = await fight.run(current)
334375
score = result.score
335376
if score < config.minimum_score:
@@ -357,7 +398,6 @@ async def run_battle(
357398
base_increment = 1
358399
self.results.append(reached)
359400

360-
@inherit_docs
361401
def score(self) -> float:
362402
"""Averages the highest instance size reached in each round."""
363403
return 0 if len(self.results) == 0 else sum(self.results) / len(self.results)
@@ -371,8 +411,11 @@ def format_score(score: float) -> str:
371411
class Averaged(Battle):
372412
"""Class that executes an averaged battle."""
373413

374-
@inherit_docs
375-
class BattleConfig(Battle.BattleConfig):
414+
class Config(Battle.Config):
415+
"""Config options for Averaged battles."""
416+
417+
type: Literal["Averaged"] = "Averaged"
418+
376419
instance_size: int = 10
377420
"""Instance size that will be fought at."""
378421
num_fights: int = 10
@@ -382,26 +425,23 @@ class BattleConfig(Battle.BattleConfig):
382425
class UiData(Battle.UiData):
383426
round: int
384427

385-
async def run_battle(
386-
self, fight: FightHandler[InstanceT, SolutionT], config: BattleConfig, min_size: int, ui: BattleUiProxy
387-
) -> None:
428+
async def run_battle(self, fight: FightHandler, config: Config, min_size: int, ui: BattleUi) -> None:
388429
"""Execute an averaged battle.
389430
390431
This simple battle type just executes `iterations` many fights after each other at size `instance_size`.
391432
"""
392433
if config.instance_size < min_size:
393434
raise ValueError(f"size {config.instance_size} is smaller than the smallest valid size, {min_size}.")
394435
for i in range(config.num_fights):
395-
ui.update_data(self.UiData(round=i + 1))
436+
ui.update_battle_data(self.UiData(round=i + 1))
396437
await fight.run(config.instance_size)
397438

398-
@inherit_docs
399439
def score(self) -> float:
400440
"""Averages the score of each fight."""
401-
if len(self.fight_results) == 0:
441+
if len(self.fights) == 0:
402442
return 0
403443
else:
404-
return sum(f.score for f in self.fight_results) / len(self.fight_results)
444+
return sum(f.score for f in self.fights) / len(self.fights)
405445

406446
@inherit_docs
407447
@staticmethod

0 commit comments

Comments
 (0)