Skip to content

Commit b8f2ade

Browse files
authored
Generate CoreModel dynamically when using custom configs (#3083)
* Introduce generate_dual_core_model_with_config * Make schema_extra signature consistent * Use generate_dual_core_model for DecryptedString * Expand generate_dual_core_model note
1 parent 8a05d0a commit b8f2ade

File tree

11 files changed

+398
-292
lines changed

11 files changed

+398
-292
lines changed

src/dstack/_internal/core/models/common.py

Lines changed: 67 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import re
22
from enum import Enum
3-
from typing import Any, Callable, Optional, Union
3+
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, Union
44

55
import orjson
66
from pydantic import Field
7-
from pydantic_duality import DualBaseModel
7+
from pydantic_duality import generate_dual_base_model
88
from typing_extensions import Annotated
99

1010
from dstack._internal.utils.json_utils import pydantic_orjson_dumps
@@ -17,46 +17,73 @@
1717
IncludeExcludeType = Union[IncludeExcludeSetType, IncludeExcludeDictType]
1818

1919

20+
class CoreConfig:
21+
json_loads = orjson.loads
22+
json_dumps = pydantic_orjson_dumps
23+
24+
25+
# All dstack models inherit from pydantic-duality's DualBaseModel.
2026
# DualBaseModel creates two classes for the model:
2127
# one with extra = "forbid" (CoreModel/CoreModel.__request__),
2228
# and another with extra = "ignore" (CoreModel.__response__).
23-
# This allows to use the same model both for a strict parsing of the user input and
24-
# for a permissive parsing of the server responses.
25-
class CoreModel(DualBaseModel):
26-
class Config:
27-
json_loads = orjson.loads
28-
json_dumps = pydantic_orjson_dumps
29-
30-
def json(
31-
self,
32-
*,
33-
include: Optional[IncludeExcludeType] = None,
34-
exclude: Optional[IncludeExcludeType] = None,
35-
by_alias: bool = False,
36-
skip_defaults: Optional[bool] = None, # ignore as it's deprecated
37-
exclude_unset: bool = False,
38-
exclude_defaults: bool = False,
39-
exclude_none: bool = False,
40-
encoder: Optional[Callable[[Any], Any]] = None,
41-
models_as_dict: bool = True, # does not seems to be needed by dstack or dependencies
42-
**dumps_kwargs: Any,
43-
) -> str:
44-
"""
45-
Override `json()` method so that it calls `dict()`.
46-
Allows changing how models are serialized by overriding `dict()` only.
47-
By default, `json()` won't call `dict()`, so changes applied in `dict()` won't take place.
48-
"""
49-
data = self.dict(
50-
by_alias=by_alias,
51-
include=include,
52-
exclude=exclude,
53-
exclude_unset=exclude_unset,
54-
exclude_defaults=exclude_defaults,
55-
exclude_none=exclude_none,
56-
)
57-
if self.__custom_root_type__:
58-
data = data["__root__"]
59-
return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs)
29+
# This allows to use the same model both for strict parsing of the user input and
30+
# for permissive parsing of the server responses.
31+
#
32+
# We define a func to generate CoreModel dynamically that can be used
33+
# to define custom Config for both __request__ and __response__ models.
34+
# Note: Defining config in the model class directly overrides
35+
# pydantic-duality's base config, breaking __response__.
36+
def generate_dual_core_model(
37+
custom_config: Union[type, Mapping],
38+
) -> "type[CoreModel]":
39+
class CoreModel(generate_dual_base_model(custom_config)):
40+
def json(
41+
self,
42+
*,
43+
include: Optional[IncludeExcludeType] = None,
44+
exclude: Optional[IncludeExcludeType] = None,
45+
by_alias: bool = False,
46+
skip_defaults: Optional[bool] = None, # ignore as it's deprecated
47+
exclude_unset: bool = False,
48+
exclude_defaults: bool = False,
49+
exclude_none: bool = False,
50+
encoder: Optional[Callable[[Any], Any]] = None,
51+
models_as_dict: bool = True, # does not seems to be needed by dstack or dependencies
52+
**dumps_kwargs: Any,
53+
) -> str:
54+
"""
55+
Override `json()` method so that it calls `dict()`.
56+
Allows changing how models are serialized by overriding `dict()` only.
57+
By default, `json()` won't call `dict()`, so changes applied in `dict()` won't take place.
58+
"""
59+
data = self.dict(
60+
by_alias=by_alias,
61+
include=include,
62+
exclude=exclude,
63+
exclude_unset=exclude_unset,
64+
exclude_defaults=exclude_defaults,
65+
exclude_none=exclude_none,
66+
)
67+
if self.__custom_root_type__:
68+
data = data["__root__"]
69+
return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs)
70+
71+
return CoreModel
72+
73+
74+
if TYPE_CHECKING:
75+
76+
class CoreModel(generate_dual_base_model(CoreConfig)):
77+
pass
78+
else:
79+
CoreModel = generate_dual_core_model(CoreConfig)
80+
81+
82+
class FrozenConfig(CoreConfig):
83+
frozen = True
84+
85+
86+
FrozenCoreModel = generate_dual_core_model(FrozenConfig)
6087

6188

6289
class Duration(int):
@@ -93,7 +120,7 @@ def parse(cls, v: Union[int, str]) -> "Duration":
93120
raise ValueError(f"Cannot parse the duration {v}")
94121

95122

96-
class RegistryAuth(CoreModel):
123+
class RegistryAuth(FrozenCoreModel):
97124
"""
98125
Credentials for pulling a private Docker image.
99126
@@ -105,9 +132,6 @@ class RegistryAuth(CoreModel):
105132
username: Annotated[str, Field(description="The username")]
106133
password: Annotated[str, Field(description="The password or access token")]
107134

108-
class Config(CoreModel.Config):
109-
frozen = True
110-
111135

112136
class ApplyAction(str, Enum):
113137
CREATE = "create" # resource is to be created or overridden

src/dstack/_internal/core/models/configurations.py

Lines changed: 88 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,23 @@
1010
from typing_extensions import Self
1111

1212
from dstack._internal.core.errors import ConfigurationError
13-
from dstack._internal.core.models.common import CoreModel, Duration, RegistryAuth
13+
from dstack._internal.core.models.common import (
14+
CoreConfig,
15+
CoreModel,
16+
Duration,
17+
RegistryAuth,
18+
generate_dual_core_model,
19+
)
1420
from dstack._internal.core.models.envs import Env
1521
from dstack._internal.core.models.files import FilePathMapping
1622
from dstack._internal.core.models.fleets import FleetConfiguration
1723
from dstack._internal.core.models.gateways import GatewayConfiguration
18-
from dstack._internal.core.models.profiles import ProfileParams, parse_duration, parse_off_duration
24+
from dstack._internal.core.models.profiles import (
25+
ProfileParams,
26+
ProfileParamsConfig,
27+
parse_duration,
28+
parse_off_duration,
29+
)
1930
from dstack._internal.core.models.resources import Range, ResourcesSpec
2031
from dstack._internal.core.models.services import AnyModel, OpenAIChatModel
2132
from dstack._internal.core.models.unix import UnixUser
@@ -276,7 +287,20 @@ class HTTPHeaderSpec(CoreModel):
276287
]
277288

278289

279-
class ProbeConfig(CoreModel):
290+
class ProbeConfigConfig(CoreConfig):
291+
@staticmethod
292+
def schema_extra(schema: Dict[str, Any]):
293+
add_extra_schema_types(
294+
schema["properties"]["timeout"],
295+
extra_types=[{"type": "string"}],
296+
)
297+
add_extra_schema_types(
298+
schema["properties"]["interval"],
299+
extra_types=[{"type": "string"}],
300+
)
301+
302+
303+
class ProbeConfig(generate_dual_core_model(ProbeConfigConfig)):
280304
type: Literal["http"] # expect other probe types in the future, namely `exec`
281305
url: Annotated[
282306
Optional[str], Field(description=f"The URL to request. Defaults to `{DEFAULT_PROBE_URL}`")
@@ -331,18 +355,6 @@ class ProbeConfig(CoreModel):
331355
),
332356
] = None
333357

334-
class Config(CoreModel.Config):
335-
@staticmethod
336-
def schema_extra(schema: Dict[str, Any]):
337-
add_extra_schema_types(
338-
schema["properties"]["timeout"],
339-
extra_types=[{"type": "string"}],
340-
)
341-
add_extra_schema_types(
342-
schema["properties"]["interval"],
343-
extra_types=[{"type": "string"}],
344-
)
345-
346358
@validator("timeout", pre=True)
347359
def parse_timeout(cls, v: Optional[Union[int, str]]) -> Optional[int]:
348360
if v is None:
@@ -381,6 +393,19 @@ def validate_body_matches_method(cls, values):
381393
return values
382394

383395

396+
class BaseRunConfigurationConfig(CoreConfig):
397+
@staticmethod
398+
def schema_extra(schema: Dict[str, Any]):
399+
add_extra_schema_types(
400+
schema["properties"]["volumes"]["items"],
401+
extra_types=[{"type": "string"}],
402+
)
403+
add_extra_schema_types(
404+
schema["properties"]["files"]["items"],
405+
extra_types=[{"type": "string"}],
406+
)
407+
408+
384409
class BaseRunConfiguration(CoreModel):
385410
type: Literal["none"]
386411
name: Annotated[
@@ -484,18 +509,6 @@ class BaseRunConfiguration(CoreModel):
484509
# deprecated since 0.18.31; task, service -- no effect; dev-environment -- executed right before `init`
485510
setup: CommandsList = []
486511

487-
class Config(CoreModel.Config):
488-
@staticmethod
489-
def schema_extra(schema: Dict[str, Any]):
490-
add_extra_schema_types(
491-
schema["properties"]["volumes"]["items"],
492-
extra_types=[{"type": "string"}],
493-
)
494-
add_extra_schema_types(
495-
schema["properties"]["files"]["items"],
496-
extra_types=[{"type": "string"}],
497-
)
498-
499512
@validator("python", pre=True, always=True)
500513
def convert_python(cls, v, values) -> Optional[PythonVersion]:
501514
if v is not None and values.get("image"):
@@ -621,20 +634,25 @@ def parse_inactivity_duration(
621634
return None
622635

623636

637+
class DevEnvironmentConfigurationConfig(
638+
ProfileParamsConfig,
639+
BaseRunConfigurationConfig,
640+
):
641+
@staticmethod
642+
def schema_extra(schema: Dict[str, Any]):
643+
ProfileParamsConfig.schema_extra(schema)
644+
BaseRunConfigurationConfig.schema_extra(schema)
645+
646+
624647
class DevEnvironmentConfiguration(
625648
ProfileParams,
626649
BaseRunConfiguration,
627650
ConfigurationWithPortsParams,
628651
DevEnvironmentConfigurationParams,
652+
generate_dual_core_model(DevEnvironmentConfigurationConfig),
629653
):
630654
type: Literal["dev-environment"] = "dev-environment"
631655

632-
class Config(ProfileParams.Config, BaseRunConfiguration.Config):
633-
@staticmethod
634-
def schema_extra(schema: Dict[str, Any]):
635-
ProfileParams.Config.schema_extra(schema)
636-
BaseRunConfiguration.Config.schema_extra(schema)
637-
638656
@validator("entrypoint")
639657
def validate_entrypoint(cls, v: Optional[str]) -> Optional[str]:
640658
if v is not None:
@@ -646,20 +664,38 @@ class TaskConfigurationParams(CoreModel):
646664
nodes: Annotated[int, Field(description="Number of nodes", ge=1)] = 1
647665

648666

667+
class TaskConfigurationConfig(
668+
ProfileParamsConfig,
669+
BaseRunConfigurationConfig,
670+
):
671+
@staticmethod
672+
def schema_extra(schema: Dict[str, Any]):
673+
ProfileParamsConfig.schema_extra(schema)
674+
BaseRunConfigurationConfig.schema_extra(schema)
675+
676+
649677
class TaskConfiguration(
650678
ProfileParams,
651679
BaseRunConfiguration,
652680
ConfigurationWithCommandsParams,
653681
ConfigurationWithPortsParams,
654682
TaskConfigurationParams,
683+
generate_dual_core_model(TaskConfigurationConfig),
655684
):
656685
type: Literal["task"] = "task"
657686

658-
class Config(ProfileParams.Config, BaseRunConfiguration.Config):
659-
@staticmethod
660-
def schema_extra(schema: Dict[str, Any]):
661-
ProfileParams.Config.schema_extra(schema)
662-
BaseRunConfiguration.Config.schema_extra(schema)
687+
688+
class ServiceConfigurationParamsConfig(CoreConfig):
689+
@staticmethod
690+
def schema_extra(schema: Dict[str, Any]):
691+
add_extra_schema_types(
692+
schema["properties"]["replicas"],
693+
extra_types=[{"type": "integer"}, {"type": "string"}],
694+
)
695+
add_extra_schema_types(
696+
schema["properties"]["model"],
697+
extra_types=[{"type": "string"}],
698+
)
663699

664700

665701
class ServiceConfigurationParams(CoreModel):
@@ -719,18 +755,6 @@ class ServiceConfigurationParams(CoreModel):
719755
Field(description="List of probes used to determine job health"),
720756
] = []
721757

722-
class Config(CoreModel.Config):
723-
@staticmethod
724-
def schema_extra(schema: Dict[str, Any]):
725-
add_extra_schema_types(
726-
schema["properties"]["replicas"],
727-
extra_types=[{"type": "integer"}, {"type": "string"}],
728-
)
729-
add_extra_schema_types(
730-
schema["properties"]["model"],
731-
extra_types=[{"type": "string"}],
732-
)
733-
734758
@validator("port")
735759
def convert_port(cls, v) -> PortMapping:
736760
if isinstance(v, int):
@@ -797,25 +821,27 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
797821
return v
798822

799823

824+
class ServiceConfigurationConfig(
825+
ProfileParamsConfig,
826+
BaseRunConfigurationConfig,
827+
ServiceConfigurationParamsConfig,
828+
):
829+
@staticmethod
830+
def schema_extra(schema: Dict[str, Any]):
831+
ProfileParamsConfig.schema_extra(schema)
832+
BaseRunConfigurationConfig.schema_extra(schema)
833+
ServiceConfigurationParamsConfig.schema_extra(schema)
834+
835+
800836
class ServiceConfiguration(
801837
ProfileParams,
802838
BaseRunConfiguration,
803839
ConfigurationWithCommandsParams,
804840
ServiceConfigurationParams,
841+
generate_dual_core_model(ServiceConfigurationConfig),
805842
):
806843
type: Literal["service"] = "service"
807844

808-
class Config(
809-
ProfileParams.Config,
810-
BaseRunConfiguration.Config,
811-
ServiceConfigurationParams.Config,
812-
):
813-
@staticmethod
814-
def schema_extra(schema: Dict[str, Any]):
815-
ProfileParams.Config.schema_extra(schema)
816-
BaseRunConfiguration.Config.schema_extra(schema)
817-
ServiceConfigurationParams.Config.schema_extra(schema)
818-
819845

820846
AnyRunConfiguration = Union[DevEnvironmentConfiguration, TaskConfiguration, ServiceConfiguration]
821847

@@ -876,7 +902,7 @@ class DstackConfiguration(CoreModel):
876902
Field(discriminator="type"),
877903
]
878904

879-
class Config(CoreModel.Config):
905+
class Config(CoreConfig):
880906
json_loads = orjson.loads
881907
json_dumps = pydantic_orjson_dumps_with_indent
882908

0 commit comments

Comments
 (0)