|
9 | 9 | from abc import abstractmethod |
10 | 10 | from inspect import isclass |
11 | 11 | from typing import ( |
12 | | - TYPE_CHECKING, |
13 | 12 | Any, |
14 | 13 | Awaitable, |
15 | 14 | Callable, |
|
23 | 22 | TypeVar, |
24 | 23 | ) |
25 | 24 |
|
26 | | -from pydantic import Field |
27 | | - |
28 | | -from algobattle.config import BattleConfig, IteratedConfig |
| 25 | +from pydantic import Field, GetCoreSchemaHandler |
| 26 | +from pydantic_core import CoreSchema |
| 27 | +from pydantic_core.core_schema import tagged_union_schema |
29 | 28 |
|
30 | 29 | from algobattle.program import ( |
31 | 30 | Generator, |
@@ -220,10 +219,42 @@ class Battle(BaseModel): |
220 | 219 | _battle_types: ClassVar[dict[str, type[Self]]] = {} |
221 | 220 | """Dictionary mapping the names of all registered battle types to their python classes.""" |
222 | 221 |
|
223 | | - if TYPE_CHECKING: |
224 | | - Config: TypeAlias = BattleConfig |
225 | | - else: |
226 | | - Config: ClassVar[TypeAlias] = BattleConfig |
| 222 | + class Config(BaseModel): |
| 223 | + """Config object for each specific battle type. |
| 224 | +
|
| 225 | + A custom battle type can override this class to specify config options it uses. They will be parsed from a |
| 226 | + dictionary located at `battle` in the main config file. The created object will then be passed to the |
| 227 | + :meth:`Battle.run` method with its fields set accordingly. |
| 228 | + """ |
| 229 | + |
| 230 | + type: str |
| 231 | + """Type of battle that will be used.""" |
| 232 | + |
| 233 | + @classmethod |
| 234 | + def __get_pydantic_core_schema__(cls, source: Type, handler: GetCoreSchemaHandler) -> CoreSchema: |
| 235 | + # there's two bugs we need to catch: |
| 236 | + # 1. this function is called during the pydantic BaseModel metaclass's __new__, so the BattleConfig class |
| 237 | + # won't be ready at that point and be missing in the namespace |
| 238 | + # 2. pydantic uses the core schema to build child classes core schema. for them we want to behave like a |
| 239 | + # normal model, only our own schema gets modified |
| 240 | + try: |
| 241 | + if cls != Battle.Config: |
| 242 | + return handler(source) |
| 243 | + except NameError: |
| 244 | + return handler(source) |
| 245 | + match len(Battle._battle_types): |
| 246 | + case 0: |
| 247 | + return handler(source) |
| 248 | + case 1: |
| 249 | + return handler(next(iter(Battle._battle_types.values()))) |
| 250 | + case _: |
| 251 | + return tagged_union_schema( |
| 252 | + choices={ |
| 253 | + subclass.model_fields["type"].default: subclass.__pydantic_core_schema__ |
| 254 | + for subclass in Battle._battle_types.values() |
| 255 | + }, |
| 256 | + discriminator="type", |
| 257 | + ) |
227 | 258 |
|
228 | 259 | class UiData(BaseModel): |
229 | 260 | """Object containing custom diplay data. |
@@ -252,6 +283,7 @@ def load_entrypoints(cls) -> None: |
252 | 283 | def __init_subclass__(cls) -> None: |
253 | 284 | if cls.name() not in Battle._battle_types: |
254 | 285 | Battle._battle_types[cls.name()] = cls |
| 286 | + Battle.Config.model_rebuild(force=True) |
255 | 287 | return super().__init_subclass__() |
256 | 288 |
|
257 | 289 | @abstractmethod |
@@ -300,10 +332,19 @@ class Iterated(Battle): |
300 | 332 |
|
301 | 333 | results: list[int] = Field(default_factory=list) |
302 | 334 |
|
303 | | - if TYPE_CHECKING: |
304 | | - Config: TypeAlias = IteratedConfig |
305 | | - else: |
306 | | - Config: ClassVar[TypeAlias] = IteratedConfig |
| 335 | + class Config(Battle.Config): |
| 336 | + """Config options for Iterated battles.""" |
| 337 | + |
| 338 | + type: Literal["Iterated"] = "Iterated" |
| 339 | + |
| 340 | + rounds: int = 5 |
| 341 | + """Number of times the instance size will be increased until the solver fails to produce correct solutions.""" |
| 342 | + maximum_size: int = 50_000 |
| 343 | + """Maximum instance size that will be tried.""" |
| 344 | + exponent: int = 2 |
| 345 | + """Determines how quickly the instance size grows.""" |
| 346 | + minimum_score: float = 1 |
| 347 | + """Minimum score that a solver needs to achieve in order to pass.""" |
307 | 348 |
|
308 | 349 | @inherit_docs |
309 | 350 | class UiData(Battle.UiData): |
|
0 commit comments