99from abc import abstractmethod
1010from inspect import isclass
1111from typing import (
12+ TYPE_CHECKING ,
1213 Any ,
1314 Awaitable ,
1415 Callable ,
2223 TypeVar ,
2324)
2425
25- from pydantic import Field , GetCoreSchemaHandler
26+ from pydantic import (
27+ ConfigDict ,
28+ Field ,
29+ GetCoreSchemaHandler ,
30+ ValidationError ,
31+ ValidationInfo ,
32+ ValidatorFunctionWrapHandler ,
33+ )
2634from pydantic_core import CoreSchema
27- from pydantic_core .core_schema import tagged_union_schema
35+ from pydantic_core .core_schema import tagged_union_schema , general_wrap_validator_function
2836
2937from algobattle .program import (
3038 Generator ,
@@ -242,20 +250,63 @@ def __get_pydantic_core_schema__(cls, source: Type, handler: GetCoreSchemaHandle
242250 return handler (source )
243251 except NameError :
244252 return handler (source )
253+
245254 match len (Battle ._battle_types ):
246255 case 0 :
247- return handler (source )
256+ subclass_schema = handler (source )
248257 case 1 :
249- return handler (next (iter (Battle ._battle_types .values ())))
258+ subclass_schema = handler (next (iter (Battle ._battle_types .values ())))
250259 case _:
251- return tagged_union_schema (
260+ subclass_schema = tagged_union_schema (
252261 choices = {
253262 battle .Config .model_fields ["type" ].default : battle .Config .__pydantic_core_schema__
254263 for battle in Battle ._battle_types .values ()
255264 },
256265 discriminator = "type" ,
257266 )
258267
268+ # we want to validate into the actual battle type's config, so we need to treat them as a tagged union
269+ # but if we're initializing a project the type might not be installed yet, so we want to also parse
270+ # into an unspecified dummy object. This wrap validator will efficiently and transparently act as a tagged
271+ # union when ignore_uninstalled is not set. If it is set it catches only the error of a missing tag, other
272+ # errors are passed through
273+ def check_installed (val : object , handler : ValidatorFunctionWrapHandler , info : ValidationInfo ) -> object :
274+ try :
275+ return handler (val )
276+ except ValidationError as e :
277+ union_err = next (filter (lambda err : err ["type" ] == "union_tag_invalid" , e .errors ()), None )
278+ if union_err is None :
279+ raise
280+ if info .context is not None and info .context .get ("ignore_uninstalled" , False ):
281+ if info .config is not None :
282+ settings : dict [str , Any ] = {
283+ "strict" : info .config .get ("strict" , None ),
284+ "from_attributes" : info .config .get ("from_attributes" ),
285+ }
286+ else :
287+ settings = {}
288+ return Battle .FallbackConfig .model_validate (val , context = info .context , ** settings )
289+ else :
290+ passed = union_err ["input" ]["type" ]
291+ installed = ", " .join (b .name () for b in Battle ._battle_types .values ())
292+ raise ValueError (
293+ f"The specified battle type '{ passed } ' is not installed. Installed types are: { installed } "
294+ )
295+
296+ return general_wrap_validator_function (check_installed , subclass_schema )
297+
298+ class FallbackConfig (Config ):
299+ """Fallback config object to parse into if the proper battle typ isn't installed and we're ignoring installs."""
300+
301+ type : str
302+
303+ model_config = ConfigDict (extra = "allow" )
304+
305+ if TYPE_CHECKING :
306+ # to hint that we're gonna fill this with arbitrary data belonging to some supposed battle type
307+ def __getattr__ (self , __attr : str ) -> Any :
308+ ...
309+
259310 class UiData (BaseModel ):
260311 """Object containing custom diplay data.
261312
@@ -280,11 +331,12 @@ def load_entrypoints(cls) -> None:
280331 if not (isclass (battle ) and issubclass (battle , Battle )):
281332 raise ValueError (f"Entrypoint { entrypoint .name } targets something other than a Battle type" )
282333
283- def __init_subclass__ (cls ) -> None :
334+ @classmethod
335+ def __pydantic_init_subclass__ (cls , ** kwargs : Any ) -> None :
284336 if cls .name () not in Battle ._battle_types :
285337 Battle ._battle_types [cls .name ()] = cls
286338 Battle .Config .model_rebuild (force = True )
287- return super ().__init_subclass__ ( )
339+ return super ().__pydantic_init_subclass__ ( ** kwargs )
288340
289341 @abstractmethod
290342 def score (self ) -> float :
@@ -367,10 +419,11 @@ async def run_battle(self, fight: FightHandler, config: Config, min_size: int, u
367419 base_increment = 0
368420 alive = True
369421 reached = 0
422+ self .results .append (0 )
370423 cap = config .maximum_size
371424 current = min_size
372425 while alive :
373- ui .update_battle_data (self .UiData (reached = self .results + [ reached ] , cap = cap ))
426+ ui .update_battle_data (self .UiData (reached = self .results , cap = cap ))
374427 result = await fight .run (current )
375428 score = result .score
376429 if score < config .minimum_score :
@@ -384,7 +437,7 @@ async def run_battle(self, fight: FightHandler, config: Config, min_size: int, u
384437 alive = True
385438 elif current > reached and alive :
386439 # We solved an instance of bigger size than before
387- reached = current
440+ self . results [ - 1 ] = reached = current
388441
389442 if current + 1 > cap :
390443 alive = False
@@ -396,7 +449,7 @@ async def run_battle(self, fight: FightHandler, config: Config, min_size: int, u
396449 # We have failed at this value of n already, reset the step size!
397450 current -= base_increment ** config .exponent - 1
398451 base_increment = 1
399- self .results . append ( reached )
452+ self .results [ - 1 ] = reached
400453
401454 def score (self ) -> float :
402455 """Averages the highest instance size reached in each round."""
@@ -416,7 +469,7 @@ class Config(Battle.Config):
416469
417470 type : Literal ["Averaged" ] = "Averaged"
418471
419- instance_size : int = 10
472+ instance_size : int = 25
420473 """Instance size that will be fought at."""
421474 num_fights : int = 10
422475 """Number of iterations in each round."""
0 commit comments