Skip to content

Commit 39bbe36

Browse files
committed
Replace validators with dict() for serialization tweaks
1 parent fcc8dc5 commit 39bbe36

8 files changed

Lines changed: 136 additions & 108 deletions

File tree

src/dstack/_internal/core/models/common.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import re
22
from enum import Enum
3-
from typing import Any, Union
3+
from typing import Any, Callable, Union
44

55
import orjson
6+
from git import Optional
67
from pydantic import Field
78
from pydantic_duality import DualBaseModel
89
from typing_extensions import Annotated
@@ -35,6 +36,37 @@ class Config:
3536
json_loads = orjson.loads
3637
json_dumps = _orjson_dumps
3738

39+
def json(
40+
self,
41+
*,
42+
include: Optional[IncludeExcludeType] = None,
43+
exclude: Optional[IncludeExcludeType] = None,
44+
by_alias: bool = False,
45+
skip_defaults: Optional[bool] = None, # ignore as it's deprecated
46+
exclude_unset: bool = False,
47+
exclude_defaults: bool = False,
48+
exclude_none: bool = False,
49+
encoder: Optional[Callable[[Any], Any]] = None,
50+
models_as_dict: bool = True, # does not seem to be needed by dstack or dependencies
51+
**dumps_kwargs: Any,
52+
) -> str:
53+
"""
54+
Override `json()` method so that it calls `dict()`.
55+
Allows changing how models are serialized by overriding `dict()` only.
56+
By default, `json()` won't call `dict()`, so changes applied in `dict()` won't take place.
57+
"""
58+
data = self.dict(
59+
by_alias=by_alias,
60+
include=include,
61+
exclude=exclude,
62+
exclude_unset=exclude_unset,
63+
exclude_defaults=exclude_defaults,
64+
exclude_none=exclude_none,
65+
)
66+
if self.__custom_root_type__:
67+
data = data["__root__"]
68+
return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs)
69+
3870

3971
class Duration(int):
4072
"""

src/dstack/_internal/core/models/resources.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -382,14 +382,6 @@ def schema_extra(schema: Dict[str, Any]):
382382
gpu: Annotated[Optional[GPUSpec], Field(description="The GPU requirements")] = None
383383
disk: Annotated[Optional[DiskSpec], Field(description="The disk resources")] = DEFAULT_DISK
384384

385-
# TODO: Remove in 0.20. Added for backward compatibility.
386-
@root_validator
387-
def _post_validate(cls, values):
388-
cpu = values.get("cpu")
389-
if isinstance(cpu, CPUSpec) and cpu.arch in [None, gpuhunt.CPUArchitecture.X86]:
390-
values["cpu"] = cpu.count
391-
return values
392-
393385
def pretty_format(self) -> str:
394386
# TODO: Remove in 0.20. Use self.cpu directly
395387
cpu = parse_obj_as(CPUSpec, self.cpu)
@@ -407,3 +399,18 @@ def pretty_format(self) -> str:
407399
resources.update(disk_size=self.disk.size)
408400
res = pretty_resources(**resources)
409401
return res
402+
403+
def dict(self, *args, **kwargs) -> Dict:
404+
# super() does not work with pydantic-duality
405+
res = CoreModel.dict(self, *args, **kwargs)
406+
self._update_serialized_cpu(res)
407+
return res
408+
409+
# TODO: Remove in 0.20. Added for backward compatibility.
410+
def _update_serialized_cpu(self, values: Dict):
411+
cpu = values["cpu"]
412+
if cpu:
413+
arch = cpu.get("arch")
414+
count = cpu.get("count")
415+
if count and arch in [None, gpuhunt.CPUArchitecture.X86.value]:
416+
values["cpu"] = count

src/dstack/_internal/core/models/runs.py

Lines changed: 69 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -325,56 +325,45 @@ def duration(self) -> timedelta:
325325
end_time = self.finished_at
326326
return end_time - self.submitted_at
327327

328-
@root_validator
329-
def _status_message(cls, values) -> Dict:
330-
try:
331-
status = values["status"]
332-
termination_reason = values["termination_reason"]
333-
exit_code = values["exit_status"]
334-
except KeyError:
335-
return values
336-
values["status_message"] = JobSubmission._get_status_message(
337-
status=status,
338-
termination_reason=termination_reason,
339-
exit_status=exit_code,
340-
)
341-
return values
328+
def dict(self, *args, **kwargs) -> Dict:
329+
status_message = self._get_status_message()
330+
error = self._get_error()
331+
# super() does not work with pydantic-duality
332+
res = CoreModel.dict(self, *args, **kwargs)
333+
res["status_message"] = status_message
334+
res["error"] = error
335+
return res
342336

343-
@staticmethod
344-
def _get_status_message(
345-
status: JobStatus,
346-
termination_reason: Optional[JobTerminationReason],
347-
exit_status: Optional[int],
348-
) -> str:
349-
if status == JobStatus.DONE:
337+
def _get_status_message(self) -> Optional[str]:
338+
if self.status == JobStatus.DONE:
350339
return "exited (0)"
351-
elif status == JobStatus.FAILED:
352-
if termination_reason == JobTerminationReason.CONTAINER_EXITED_WITH_ERROR:
353-
return f"exited ({exit_status})"
354-
elif termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY:
340+
elif self.status == JobStatus.FAILED:
341+
if self.termination_reason == JobTerminationReason.CONTAINER_EXITED_WITH_ERROR:
342+
return f"exited ({self.exit_status})"
343+
elif (
344+
self.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY
345+
):
355346
return "no offers"
356-
elif termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY:
347+
elif self.termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY:
357348
return "interrupted"
358349
else:
359350
return "error"
360-
elif status == JobStatus.TERMINATED:
361-
if termination_reason == JobTerminationReason.TERMINATED_BY_USER:
351+
elif self.status == JobStatus.TERMINATED:
352+
if self.termination_reason == JobTerminationReason.TERMINATED_BY_USER:
362353
return "stopped"
363-
elif termination_reason == JobTerminationReason.ABORTED_BY_USER:
354+
elif self.termination_reason == JobTerminationReason.ABORTED_BY_USER:
364355
return "aborted"
365-
return status.value
356+
return self.status.value
366357

367-
@root_validator
368-
def _error(cls, values) -> Dict:
369-
try:
370-
termination_reason = values["termination_reason"]
371-
except KeyError:
372-
return values
373-
values["error"] = JobSubmission._get_error(termination_reason=termination_reason)
374-
return values
358+
def _get_error(self) -> Optional[str]:
359+
return JobSubmission._termination_reason_to_error(
360+
termination_reason=self.termination_reason
361+
)
375362

376363
@staticmethod
377-
def _get_error(termination_reason: Optional[JobTerminationReason]) -> Optional[str]:
364+
def _termination_reason_to_error(
365+
termination_reason: Optional[JobTerminationReason],
366+
) -> Optional[str]:
378367
error_mapping = {
379368
JobTerminationReason.INSTANCE_UNREACHABLE: "instance unreachable",
380369
JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED: "waiting instance limit exceeded",
@@ -395,6 +384,12 @@ class Job(CoreModel):
395384
job_spec: JobSpec
396385
job_submissions: List[JobSubmission]
397386

387+
def get_last_termination_reason(self) -> Optional[JobTerminationReason]:
388+
for submission in reversed(self.job_submissions):
389+
if submission.termination_reason is not None:
390+
return submission.termination_reason
391+
return None
392+
398393

399394
class RunSpec(CoreModel):
400395
# TODO: run_name, working_dir are redundant here since they are already passed in configuration
@@ -525,83 +520,66 @@ class Run(CoreModel):
525520
last_processed_at: datetime
526521
status: RunStatus
527522
status_message: Optional[str] = None
528-
termination_reason: Optional[RunTerminationReason]
523+
termination_reason: Optional[RunTerminationReason] = None
529524
run_spec: RunSpec
530525
jobs: List[Job]
531-
latest_job_submission: Optional[JobSubmission]
526+
latest_job_submission: Optional[JobSubmission] = None
532527
cost: float = 0
533528
service: Optional[ServiceSpec] = None
534529
deployment_num: int = 0 # default for compatibility with pre-0.19.14 servers
535530
# TODO: make error a computed field after migrating to pydanticV2
536531
error: Optional[str] = None
537532
deleted: Optional[bool] = None
538533

539-
@root_validator
540-
def _error(cls, values) -> Dict:
541-
try:
542-
termination_reason = values["termination_reason"]
543-
except KeyError:
544-
return values
545-
values["error"] = Run._get_error(termination_reason=termination_reason)
546-
return values
534+
def dict(self, *args, **kwargs) -> Dict:
535+
status_message = self._get_status_message()
536+
error = self._get_error()
537+
# super() does not work with pydantic-duality
538+
res = CoreModel.dict(self, *args, **kwargs)
539+
res["status_message"] = status_message
540+
res["error"] = error
541+
return res
542+
543+
def _get_error(self) -> Optional[str]:
544+
return Run._termination_reason_to_error(termination_reason=self.termination_reason)
547545

548546
@staticmethod
549-
def _get_error(termination_reason: Optional[RunTerminationReason]) -> Optional[str]:
547+
def _termination_reason_to_error(
548+
termination_reason: Optional[RunTerminationReason],
549+
) -> Optional[str]:
550550
if termination_reason == RunTerminationReason.RETRY_LIMIT_EXCEEDED:
551551
return "retry limit exceeded"
552552
elif termination_reason == RunTerminationReason.SERVER_ERROR:
553553
return "server error"
554554
else:
555555
return None
556556

557-
@root_validator
558-
def _status_message(cls, values) -> Dict:
559-
try:
560-
status = values["status"]
561-
jobs: List[Job] = values["jobs"]
562-
retry_on_events = (
563-
jobs[0].job_spec.retry.on_events if jobs and jobs[0].job_spec.retry else []
564-
)
565-
job_status = (
566-
jobs[0].job_submissions[-1].status
567-
if len(jobs) == 1 and jobs[0].job_submissions
568-
else None
569-
)
570-
termination_reason = Run.get_last_termination_reason(jobs[0]) if jobs else None
571-
except KeyError:
572-
return values
573-
values["status_message"] = Run._get_status_message(
574-
status=status,
575-
job_status=job_status,
576-
retry_on_events=retry_on_events,
577-
termination_reason=termination_reason,
578-
)
579-
return values
557+
def _get_status_message(self) -> Optional[str]:
558+
if len(self.jobs) == 0:
559+
return self.status.value
580560

581-
@staticmethod
582-
def get_last_termination_reason(job: "Job") -> Optional[JobTerminationReason]:
583-
for submission in reversed(job.job_submissions):
584-
if submission.termination_reason is not None:
585-
return submission.termination_reason
586-
return None
561+
last_job = self.jobs[0]
562+
last_job_termination_reason = last_job.get_last_termination_reason()
587563

588-
@staticmethod
589-
def _get_status_message(
590-
status: RunStatus,
591-
job_status: Optional[JobStatus],
592-
retry_on_events: List[RetryEvent],
593-
termination_reason: Optional[JobTerminationReason],
594-
) -> str:
595-
if job_status == JobStatus.PULLING:
596-
return "pulling"
564+
if len(self.jobs) == 1:
565+
# FIXME: Clarify why "pulling" is shown only in the case of a single job
566+
if (
567+
last_job.job_submissions
568+
and last_job.job_submissions[-1].status == JobStatus.PULLING
569+
):
570+
return "pulling"
571+
572+
retry_on_events = last_job.job_spec.retry.on_events if last_job.job_spec.retry else []
597573
# Currently, `retrying` is shown only for `no-capacity` events
598574
if (
599-
status in [RunStatus.SUBMITTED, RunStatus.PENDING]
600-
and termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY
575+
self.status in [RunStatus.SUBMITTED, RunStatus.PENDING]
576+
and last_job_termination_reason
577+
== JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY
601578
and RetryEvent.NO_CAPACITY in retry_on_events
602579
):
603580
return "retrying"
604-
return status.value
581+
582+
return self.status.value
605583

606584
def is_deployment_in_progress(self) -> bool:
607585
return any(

src/dstack/_internal/server/app.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ def create_app() -> FastAPI:
9494
app = FastAPI(
9595
docs_url="/api/docs",
9696
lifespan=lifespan,
97-
default_response_class=CustomORJSONResponse,
9897
)
9998
app.state.proxy_dependency_injector = ServerProxyDependencyInjector()
10099
return app

src/dstack/_internal/server/utils/routers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@ class CustomORJSONResponse(Response):
1313
"""
1414
Custom JSONResponse that uses orjson for serialization.
1515
16-
It's recommended to return this class from routers directly
17-
to avoid the FastAPI's jsonable_encoder overhead.
16+
It's recommended to return this class from routers directly instead of
17+
returning pydantic models to avoid FastAPI's jsonable_encoder overhead.
1818
See https://fastapi.tiangolo.com/advanced/custom-response/#use-orjsonresponse.
19+
20+
Beware that FastAPI skips model validation when responses are returned directly.
21+
If serialization needs to be modified, override `dict()` instead of adding validators.
1922
"""
2023

2124
media_type = "application/json"

src/dstack/_internal/utils/json_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import orjson
22
from pydantic import BaseModel
33

4+
FREEZEGUN = True
5+
try:
6+
from freezegun.api import FakeDatetime
7+
except ImportError:
8+
FREEZEGUN = False
9+
410

511
def orjson_default(obj):
612
if isinstance(obj, float):
@@ -10,6 +16,9 @@ def orjson_default(obj):
1016
# Allows calling orjson.dumps() on pydantic models
1117
# (e.g. to return from the API)
1218
return obj.dict()
19+
if FREEZEGUN:
20+
if isinstance(obj, FakeDatetime):
21+
return obj.isoformat()
1322
raise TypeError
1423

1524

src/tests/_internal/core/models/test_runs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_job_termination_reason_to_retry_event_works_with_all_enum_variants():
3333
assert retry_event is None or isinstance(retry_event, RetryEvent)
3434

3535

36-
# Will fail if JobTerminationReason value is added without updaing JobSubmission._get_error
36+
# Will fail if a JobTerminationReason value is added without updating JobSubmission._termination_reason_to_error
3737
def test_get_error_returns_expected_messages():
3838
no_error_reasons = [
3939
JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY,
@@ -47,7 +47,7 @@ def test_get_error_returns_expected_messages():
4747
]
4848

4949
for reason in JobTerminationReason:
50-
if JobSubmission._get_error(reason) is None:
50+
if JobSubmission._termination_reason_to_error(reason) is None:
5151
# Fail if a no-error reason is not in the list
5252
assert reason in no_error_reasons
5353

@@ -62,6 +62,6 @@ def test_run_get_error_returns_none_for_specific_reasons():
6262
]
6363

6464
for reason in RunTerminationReason:
65-
if Run._get_error(reason) is None:
65+
if Run._termination_reason_to_error(reason) is None:
6666
# Fail if a no-error reason is not in the list
6767
assert reason in no_error_reasons

src/tests/_internal/server/routers/test_fleets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -855,8 +855,8 @@ async def test_returns_plan(self, test_db, session: AsyncSession, client: AsyncC
855855
assert response.json() == {
856856
"project_name": project.name,
857857
"user": user.name,
858-
"spec": spec.dict(),
859-
"effective_spec": spec.dict(),
858+
"spec": json.loads(spec.json()),
859+
"effective_spec": json.loads(spec.json()),
860860
"current_resource": None,
861861
"offers": [json.loads(o.json()) for o in offers],
862862
"total_offers": len(offers),

0 commit comments

Comments
 (0)