Skip to content

Commit 82619dd

Browse files
committed
move config to match module
1 parent 2f15296 commit 82619dd

4 files changed

Lines changed: 444 additions & 480 deletions

File tree

algobattle/battle.py

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from abc import abstractmethod
1010
from inspect import isclass
1111
from typing import (
12-
TYPE_CHECKING,
1312
Any,
1413
Awaitable,
1514
Callable,
@@ -23,9 +22,9 @@
2322
TypeVar,
2423
)
2524

26-
from pydantic import Field
27-
28-
from algobattle.config import BattleConfig, IteratedConfig
25+
from pydantic import Field, GetCoreSchemaHandler
26+
from pydantic_core import CoreSchema
27+
from pydantic_core.core_schema import tagged_union_schema
2928

3029
from algobattle.program import (
3130
Generator,
@@ -220,10 +219,42 @@ class Battle(BaseModel):
220219
_battle_types: ClassVar[dict[str, type[Self]]] = {}
221220
"""Dictionary mapping the names of all registered battle types to their python classes."""
222221

223-
if TYPE_CHECKING:
224-
Config: TypeAlias = BattleConfig
225-
else:
226-
Config: ClassVar[TypeAlias] = BattleConfig
222+
class Config(BaseModel):
223+
"""Config object for each specific battle type.
224+
225+
A custom battle type can override this class to specify config options it uses. They will be parsed from a
226+
dictionary located at `battle` in the main config file. The created object will then be passed to the
227+
:meth:`Battle.run` method with its fields set accordingly.
228+
"""
229+
230+
type: str
231+
"""Type of battle that will be used."""
232+
233+
@classmethod
234+
def __get_pydantic_core_schema__(cls, source: Type, handler: GetCoreSchemaHandler) -> CoreSchema:
235+
# there's two bugs we need to catch:
236+
# 1. this function is called during the pydantic BaseModel metaclass's __new__, so the BattleConfig class
237+
# won't be ready at that point and be missing in the namespace
238+
# 2. pydantic uses the core schema to build child classes core schema. for them we want to behave like a
239+
# normal model, only our own schema gets modified
240+
try:
241+
if cls != Battle.Config:
242+
return handler(source)
243+
except NameError:
244+
return handler(source)
245+
match len(Battle._battle_types):
246+
case 0:
247+
return handler(source)
248+
case 1:
249+
return handler(next(iter(Battle._battle_types.values())))
250+
case _:
251+
return tagged_union_schema(
252+
choices={
253+
subclass.model_fields["type"].default: subclass.__pydantic_core_schema__
254+
for subclass in Battle._battle_types.values()
255+
},
256+
discriminator="type",
257+
)
227258

228259
class UiData(BaseModel):
229260
"""Object containing custom diplay data.
@@ -252,6 +283,7 @@ def load_entrypoints(cls) -> None:
252283
def __init_subclass__(cls) -> None:
253284
if cls.name() not in Battle._battle_types:
254285
Battle._battle_types[cls.name()] = cls
286+
Battle.Config.model_rebuild(force=True)
255287
return super().__init_subclass__()
256288

257289
@abstractmethod
@@ -300,10 +332,19 @@ class Iterated(Battle):
300332

301333
results: list[int] = Field(default_factory=list)
302334

303-
if TYPE_CHECKING:
304-
Config: TypeAlias = IteratedConfig
305-
else:
306-
Config: ClassVar[TypeAlias] = IteratedConfig
335+
class Config(Battle.Config):
336+
"""Config options for Iterated battles."""
337+
338+
type: Literal["Iterated"] = "Iterated"
339+
340+
rounds: int = 5
341+
"""Number of times the instance size will be increased until the solver fails to produce correct solutions."""
342+
maximum_size: int = 50_000
343+
"""Maximum instance size that will be tried."""
344+
exponent: int = 2
345+
"""Determines how quickly the instance size grows."""
346+
minimum_score: float = 1
347+
"""Minimum score that a solver needs to achieve in order to pass."""
307348

308349
@inherit_docs
309350
class UiData(Battle.UiData):

0 commit comments

Comments
 (0)