Skip to content

Commit 12463ea

Browse files
committed
Fix replicas typing
1 parent 40ab38d commit 12463ea

1 file changed

Lines changed: 15 additions & 13 deletions

File tree

src/dstack/_internal/core/models/configurations.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from dstack._internal.core.models.unix import UnixUser
2121
from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point
2222
from dstack._internal.utils.common import has_duplicates
23+
from dstack._internal.utils.json_schema import add_extra_schema_types
2324
from dstack._internal.utils.json_utils import (
2425
pydantic_orjson_dumps_with_indent,
2526
)
@@ -561,7 +562,7 @@ class ServiceConfigurationParams(CoreModel):
561562
)
562563
auth: Annotated[bool, Field(description="Enable the authorization")] = True
563564
replicas: Annotated[
564-
Union[conint(ge=1), constr(regex=r"^[0-9]+..[1-9][0-9]*$"), Range[int]],
565+
Range[int],
565566
Field(
566567
description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
567568
"If it's a range, the `scaling` property is required"
@@ -592,20 +593,13 @@ def convert_model(cls, v: Optional[Union[AnyModel, str]]) -> Optional[AnyModel]:
592593
return v
593594

594595
@validator("replicas")
595-
def convert_replicas(cls, v: Any) -> Range[int]:
596-
if isinstance(v, str) and ".." in v:
597-
min, max = v.replace(" ", "").split("..")
598-
v = Range(min=min or 0, max=max or None)
599-
elif isinstance(v, (int, float)):
600-
v = Range(min=v, max=v)
596+
def convert_replicas(cls, v: Range[int]) -> Range[int]:
601597
if v.max is None:
602598
raise ValueError("The maximum number of replicas is required")
599+
if v.min is None:
600+
v.min = 0
603601
if v.min < 0:
604602
raise ValueError("The minimum number of replicas must be greater than or equal to 0")
605-
if v.max < v.min:
606-
raise ValueError(
607-
"The maximum number of replicas must be greater than or equal to the minimum number of replicas"
608-
)
609603
return v
610604

611605
@validator("gateway")
@@ -622,9 +616,9 @@ def validate_gateway(
622616
def validate_scaling(cls, values):
623617
scaling = values.get("scaling")
624618
replicas = values.get("replicas")
625-
if replicas.min != replicas.max and not scaling:
619+
if replicas and replicas.min != replicas.max and not scaling:
626620
raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.")
627-
if replicas.min == replicas.max and scaling:
621+
if replicas and replicas.min == replicas.max and scaling:
628622
raise ValueError("To use `scaling`, `replicas` must be set to a range.")
629623
return values
630624

@@ -655,6 +649,14 @@ class ServiceConfiguration(
655649
):
656650
type: Literal["service"] = "service"
657651

652+
class Config:
653+
@staticmethod
654+
def schema_extra(schema: Dict[str, Any]):
655+
add_extra_schema_types(
656+
schema["properties"]["replicas"],
657+
extra_types=[{"type": "integer"}, {"type": "string"}],
658+
)
659+
658660

659661
AnyRunConfiguration = Union[DevEnvironmentConfiguration, TaskConfiguration, ServiceConfiguration]
660662

0 commit comments

Comments
 (0)