44some basic battle types, and related classed.
55"""
66from dataclasses import dataclass
7+ from functools import wraps
78from importlib .metadata import entry_points
89from abc import abstractmethod
10+ from inspect import isclass
911from typing import (
1012 Any ,
13+ Awaitable ,
14+ Callable ,
1115 ClassVar ,
12- Generic ,
16+ Concatenate ,
17+ Hashable ,
18+ Literal ,
19+ ParamSpec ,
1320 Protocol ,
21+ Self ,
1422 TypeAlias ,
1523 TypeVar ,
1624)
1725
18- from pydantic import Field
26+ from pydantic import Field , GetCoreSchemaHandler
27+ from pydantic .main import BaseModel as PydanticBase
28+ from pydantic_core import CoreSchema
29+ from pydantic_core .core_schema import tagged_union_schema
1930
2031from algobattle .docker_util import (
2132 Generator ,
2233 ProgramRunInfo ,
23- ProgramUiProxy ,
34+ ProgramUi ,
2435 Solver ,
2536)
26- from algobattle .problem import InstanceT , Problem , SolutionT
27- from algobattle .util import Encodable , Role , inherit_docs , BaseModel
37+ from algobattle .problem import AnyProblem
38+ from algobattle .util import Encodable , inherit_docs , BaseModel
2839
2940
3041_BattleConfig : TypeAlias = Any
3647the new battle type directly.
3748"""
3849T = TypeVar ("T" )
50+ P = ParamSpec ("P" )
51+ RunFight : TypeAlias = "Callable[Concatenate[FightHandler, P], Awaitable[Fight]]"
52+ Type = type
3953
4054
4155class Fight (BaseModel ):
@@ -57,41 +71,41 @@ class Fight(BaseModel):
5771 """Data about the solver's execution."""
5872
5973
60- class FightUiProxy ( Protocol ):
74+ class FightUi ( ProgramUi , Protocol ):
6175 """Provides an interface for :class:`Fight` to update the ui."""
6276
63- generator : ProgramUiProxy
64- solver : ProgramUiProxy
65-
6677 @abstractmethod
67- def start (self , max_size : int ) -> None :
78+ def start_fight (self , max_size : int ) -> None :
6879 """Informs the ui that a new fight has been started."""
6980
7081 @abstractmethod
71- def update (self , role : Role , data : ProgramRunInfo ) -> None :
72- """Updates the ui's current fight section with new data about a program."""
73-
74- @abstractmethod
75- def end (self ) -> None :
82+ def end_fight (self ) -> None :
7683 """Informs the ui that the fight has finished running and has been added to the battle's `.fight_results`."""
7784
7885
86+ def _save_result (func : "RunFight[P]" ) -> "RunFight[P]" :
87+ @wraps (func )
88+ async def inner (self : "FightHandler" , * args : P .args , ** kwargs : P .kwargs ) -> Fight :
89+ res = await func (self , * args , ** kwargs )
90+ self .battle .fights .append (res )
91+ self .ui .end_fight ()
92+ return res
93+
94+ return inner
95+
96+
7997@dataclass
80- class FightHandler ( Generic [ InstanceT , SolutionT ]) :
98+ class FightHandler :
8199 """Helper class to run fights of a given battle."""
82100
83- _problem : Problem [InstanceT , SolutionT ]
84- _generator : Generator [InstanceT , SolutionT ]
85- _solver : Solver [InstanceT , SolutionT ]
86- _battle : "Battle"
87- _ui : FightUiProxy
88- _set_cpus : str | None = None
89-
90- def _saved (self , fight : Fight ) -> Fight :
91- self ._battle .fight_results .append (fight )
92- self ._ui .end ()
93- return fight
101+ problem : AnyProblem
102+ generator : Generator
103+ solver : Solver
104+ battle : "Battle"
105+ ui : FightUi
106+ set_cpus : str | None
94107
108+ @_save_result
95109 async def run (
96110 self ,
97111 max_size : int ,
@@ -134,62 +148,58 @@ async def run(
134148 Returns:
135149 The resulting info about the executed fight.
136150 """
137- min_size = self ._problem .min_size
151+ min_size = self .problem .min_size
138152 if max_size < min_size :
139153 raise ValueError (
140154 f"Cannot run battle at size { max_size } since it is smaller than the smallest "
141- "size the problem allows ({min_size})."
155+ f "size the problem allows ({ min_size } )."
142156 )
143- ui = self ._ui
144- ui .start (max_size )
145- gen_result = await self ._generator .run (
157+ ui = self .ui
158+ ui .start_fight (max_size )
159+ gen_result = await self .generator .run (
146160 max_size = max_size ,
147161 timeout = timeout_generator ,
148162 space = space_generator ,
149163 cpus = cpus_generator ,
150164 battle_input = generator_battle_input ,
151165 battle_output = generator_battle_output ,
152- set_cpus = self ._set_cpus ,
153- ui = ui . generator ,
166+ set_cpus = self .set_cpus ,
167+ ui = ui ,
154168 )
155- ui .update (Role .generator , gen_result .info )
156169 if gen_result .instance is None :
157- return self . _saved ( Fight (score = 1 , max_size = max_size , generator = gen_result .info , solver = None ) )
170+ return Fight (score = 1 , max_size = max_size , generator = gen_result .info , solver = None )
158171
159- sol_result = await self ._solver .run (
172+ sol_result = await self .solver .run (
160173 gen_result .instance ,
161174 max_size = max_size ,
162175 timeout = timeout_solver ,
163176 space = space_solver ,
164177 cpus = cpus_solver ,
165178 battle_input = solver_battle_input ,
166179 battle_output = solver_battle_output ,
167- set_cpus = self ._set_cpus ,
168- ui = ui . solver ,
180+ set_cpus = self .set_cpus ,
181+ ui = ui ,
169182 )
170- ui .update (Role .solver , sol_result .info )
171183 if sol_result .solution is None :
172- return self . _saved ( Fight (score = 0 , max_size = max_size , generator = gen_result .info , solver = sol_result .info ) )
184+ return Fight (score = 0 , max_size = max_size , generator = gen_result .info , solver = sol_result .info )
173185
174- if self ._problem .with_solution :
186+ if self .problem .with_solution :
175187 assert gen_result .solution is not None
176- score = self ._problem .score (
188+ score = self .problem .score (
177189 gen_result .instance , solver_solution = sol_result .solution , generator_solution = gen_result .solution
178190 )
179191 else :
180- score = self ._problem .score (gen_result .instance , solution = sol_result .solution )
192+ score = self .problem .score (gen_result .instance , solution = sol_result .solution )
181193 score = max (0 , min (1 , float (score )))
182- return self . _saved ( Fight (score = score , max_size = max_size , generator = gen_result .info , solver = sol_result .info ) )
194+ return Fight (score = score , max_size = max_size , generator = gen_result .info , solver = sol_result .info )
183195
184196
185197# We need this to be here to prevent an import cycle between match.py and battle.py
186- class BattleUiProxy (Protocol ):
198+ class BattleUi (Protocol ):
187199 """Provides an interface for :class:`Battle` to update the Ui."""
188200
189- fight_ui : FightUiProxy
190-
191201 @abstractmethod
192- def update_data (self , data : "Battle.UiData" ) -> None :
202+ def update_battle_data (self , data : "Battle.UiData" ) -> None :
193203 """Passes new custom display data to the Ui.
194204
195205 See :class:`Battle.UiData` for further details.
@@ -203,22 +213,49 @@ class Battle(BaseModel):
203213 they will ultimately be scored.
204214 """
205215
206- fight_results : list [Fight ] = Field (default_factory = list )
216+ fights : list [Fight ] = Field (default_factory = list )
207217 """The list of fights that have been fought in this battle."""
208218 run_exception : str | None = None
209219 """The description of an otherwise unhandeled exception that occured during the execution of :meth:`Battle.run`."""
210220
211- _battle_types : ClassVar [dict [str , type ["Battle" ]]] = {}
221+ _battle_types : ClassVar [dict [str , type [Self ]]] = {}
212222 """Dictionary mapping the names of all registered battle types to their python classes."""
213223
214- class BattleConfig (BaseModel ):
224+ class Config (BaseModel ):
215225 """Config object for each specific battle type.
216226
217227 A custom battle type can override this class to specify config options it uses. They will be parsed from a
218- dictionary located at `battle.NAME ` in the main config file, where NAME is the specific batle type's name.
219- The created object will then be passed to the :meth:`Battle.run` method with its fields set accordingly.
228+ dictionary located at `battle` in the main config file. The created object will then be passed to the
229+ :meth:`Battle.run` method with its fields set accordingly.
220230 """
221231
232+ type : str = "Iterated"
233+ """Type of battle that will be used."""
234+
235+ @classmethod
236+ def __get_pydantic_core_schema__ (cls , source : Type [PydanticBase ], handler : GetCoreSchemaHandler ) -> CoreSchema :
237+ # there's two bugs we need to catch:
238+ # 1. this function is called during the pydantic BaseModel metaclass's __new__, so the Battle class
239+ # won't be ready at that point and be missing in the namespace
240+ # 2. pydantic uses the core schema to build child classes core schema. for them we want to behave like a
241+ # normal model, only our own schema gets modified
242+ try :
243+ if cls != Battle .Config :
244+ return handler (source )
245+ except NameError :
246+ return handler (source )
247+ battle_classes = Battle .all ()
248+ match len (battle_classes ):
249+ case 0 :
250+ return handler (source )
251+ case 1 :
252+ return handler (next (iter (battle_classes .values ())).Config )
253+ case _:
254+ choices : dict [Hashable , CoreSchema ] = {
255+ name : handler (sublass .Config ) for name , sublass in battle_classes .items ()
256+ }
257+ return tagged_union_schema (choices = choices , discriminator = "type" )
258+
222259 class UiData (BaseModel ):
223260 """Object containing custom diplay data.
224261
@@ -233,15 +270,20 @@ def all() -> dict[str, type["Battle"]]:
233270 It includes all subclasses of :class:`Battle` that have been initialized so far, including ones exposed to the
234271 algobattle module via the `algobattle.battle` entrypoint hook.
235272 """
236- for entrypoint in entry_points (group = "algobattle.battle" ):
237- if entrypoint .name not in Battle ._battle_types :
238- battle : type [Battle ] = entrypoint .load ()
239- Battle ._battle_types [battle .name ()] = battle
240273 return Battle ._battle_types
241274
275+ @classmethod
276+ def load_entrypoints (cls ) -> None :
277+ """Loads all battle types presented via entrypoints."""
278+ for entrypoint in entry_points (group = "algobattle.battle" ):
279+ battle = entrypoint .load ()
280+ if not (isclass (battle ) and issubclass (battle , Battle )):
281+ raise ValueError (f"Entrypoint { entrypoint .name } targets something other than a Battle type" )
282+
242283 def __init_subclass__ (cls ) -> None :
243284 if cls .name () not in Battle ._battle_types :
244285 Battle ._battle_types [cls .name ()] = cls
286+ Battle .Config .model_rebuild (force = True )
245287 return super ().__init_subclass__ ()
246288
247289 @abstractmethod
@@ -270,9 +312,7 @@ def name(cls) -> str:
270312 return cls .__name__
271313
272314 @abstractmethod
273- async def run_battle (
274- self , fight : FightHandler [InstanceT , SolutionT ], config : _BattleConfig , min_size : int , ui : BattleUiProxy
275- ) -> None :
315+ async def run_battle (self , fight : FightHandler , config : _BattleConfig , min_size : int , ui : BattleUi ) -> None :
276316 """Executes one battle.
277317
278318 Args:
@@ -292,8 +332,11 @@ class Iterated(Battle):
292332
293333 results : list [int ] = Field (default_factory = list )
294334
295- @inherit_docs
296- class BattleConfig (Battle .BattleConfig ):
335+ class Config (Battle .Config ):
336+ """Config options for Iterated battles."""
337+
338+ type : Literal ["Iterated" ] = "Iterated"
339+
297340 rounds : int = 5
298341 """Number of times the instance size will be increased until the solver fails to produce correct solutions."""
299342 maximum_size : int = 50_000
@@ -308,9 +351,7 @@ class UiData(Battle.UiData):
308351 reached : list [int ]
309352 cap : int
310353
311- async def run_battle (
312- self , fight : FightHandler [InstanceT , SolutionT ], config : BattleConfig , min_size : int , ui : BattleUiProxy
313- ) -> None :
354+ async def run_battle (self , fight : FightHandler , config : Config , min_size : int , ui : BattleUi ) -> None :
314355 """Execute an iterated battle.
315356
316357 Incrementally tries to search for the highest n for which the solver is still able to solve instances.
@@ -329,7 +370,7 @@ async def run_battle(
329370 cap = config .maximum_size
330371 current = min_size
331372 while alive :
332- ui .update_data (self .UiData (reached = self .results + [reached ], cap = cap ))
373+ ui .update_battle_data (self .UiData (reached = self .results + [reached ], cap = cap ))
333374 result = await fight .run (current )
334375 score = result .score
335376 if score < config .minimum_score :
@@ -357,7 +398,6 @@ async def run_battle(
357398 base_increment = 1
358399 self .results .append (reached )
359400
360- @inherit_docs
361401 def score (self ) -> float :
362402 """Averages the highest instance size reached in each round."""
363403 return 0 if len (self .results ) == 0 else sum (self .results ) / len (self .results )
@@ -371,8 +411,11 @@ def format_score(score: float) -> str:
371411class Averaged (Battle ):
372412 """Class that executes an averaged battle."""
373413
374- @inherit_docs
375- class BattleConfig (Battle .BattleConfig ):
414+ class Config (Battle .Config ):
415+ """Config options for Averaged battles."""
416+
417+ type : Literal ["Averaged" ] = "Averaged"
418+
376419 instance_size : int = 10
377420 """Instance size that will be fought at."""
378421 num_fights : int = 10
@@ -382,26 +425,23 @@ class BattleConfig(Battle.BattleConfig):
382425 class UiData (Battle .UiData ):
383426 round : int
384427
385- async def run_battle (
386- self , fight : FightHandler [InstanceT , SolutionT ], config : BattleConfig , min_size : int , ui : BattleUiProxy
387- ) -> None :
428+ async def run_battle (self , fight : FightHandler , config : Config , min_size : int , ui : BattleUi ) -> None :
388429 """Execute an averaged battle.
389430
390431 This simple battle type just executes `iterations` many fights after each other at size `instance_size`.
391432 """
392433 if config .instance_size < min_size :
393434 raise ValueError (f"size { config .instance_size } is smaller than the smallest valid size, { min_size } ." )
394435 for i in range (config .num_fights ):
395- ui .update_data (self .UiData (round = i + 1 ))
436+ ui .update_battle_data (self .UiData (round = i + 1 ))
396437 await fight .run (config .instance_size )
397438
398- @inherit_docs
399439 def score (self ) -> float :
400440 """Averages the score of each fight."""
401- if len (self .fight_results ) == 0 :
441+ if len (self .fights ) == 0 :
402442 return 0
403443 else :
404- return sum (f .score for f in self .fight_results ) / len (self .fight_results )
444+ return sum (f .score for f in self .fights ) / len (self .fights )
405445
406446 @inherit_docs
407447 @staticmethod
0 commit comments