Skip to content

Commit b2b7ff2

Browse files
authored
refactor: use typed patch instance for session updates (#500)
See #485. The session blueprint is also updated to use the `validated_json()` method for API responses.
1 parent 7b44f99 commit b2b7ff2

6 files changed

Lines changed: 188 additions & 139 deletions

File tree

components/renku_data_services/session/apispec_base.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Base models for API specifications."""
22

3+
from typing import Any
4+
35
from pydantic import BaseModel, field_validator
46
from ulid import ULID
57

@@ -12,8 +14,10 @@ class Config:
1214

1315
from_attributes = True
1416

15-
@field_validator("id", mode="before", check_fields=False)
17+
@field_validator("*", mode="before", check_fields=False)
1618
@classmethod
17-
def serialize_id(cls, id: str | ULID) -> str:
18-
"""Custom serializer that can handle ULIDs."""
19-
return str(id)
19+
def serialize_ulid(cls, value: Any) -> Any:
20+
"""Handle ULIDs."""
21+
if isinstance(value, ULID):
22+
return str(value)
23+
return value

components/renku_data_services/session/blueprints.py

Lines changed: 26 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,22 @@
22

33
from dataclasses import dataclass
44

5-
from sanic import HTTPResponse, Request, json
5+
from sanic import HTTPResponse, Request
66
from sanic.response import JSONResponse
77
from sanic_ext import validate
88
from ulid import ULID
99

10-
import renku_data_services.base_models as base_models
10+
from renku_data_services import base_models
1111
from renku_data_services.base_api.auth import authenticate, validate_path_project_id
1212
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
13+
from renku_data_services.base_models.validation import validated_json
1314
from renku_data_services.session import apispec
15+
from renku_data_services.session.core import (
16+
validate_environment_patch,
17+
validate_session_launcher_patch,
18+
validate_unsaved_environment,
19+
validate_unsaved_session_launcher,
20+
)
1421
from renku_data_services.session.db import SessionRepository
1522

1623

@@ -26,9 +33,7 @@ def get_all(self) -> BlueprintFactoryResponse:
2633

2734
async def _get_all(_: Request) -> JSONResponse:
2835
environments = await self.session_repo.get_environments()
29-
return json(
30-
[apispec.Environment.model_validate(e).model_dump(exclude_none=True, mode="json") for e in environments]
31-
)
36+
return validated_json(apispec.EnvironmentList, environments)
3237

3338
return "/environments", ["GET"], _get_all
3439

@@ -37,7 +42,7 @@ def get_one(self) -> BlueprintFactoryResponse:
3742

3843
async def _get_one(_: Request, environment_id: ULID) -> JSONResponse:
3944
environment = await self.session_repo.get_environment(environment_id=environment_id)
40-
return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"))
45+
return validated_json(apispec.Environment, environment)
4146

4247
return "/environments/<environment_id:ulid>", ["GET"], _get_one
4348

@@ -47,8 +52,9 @@ def post(self) -> BlueprintFactoryResponse:
4752
@authenticate(self.authenticator)
4853
@validate(json=apispec.EnvironmentPost)
4954
async def _post(_: Request, user: base_models.APIUser, body: apispec.EnvironmentPost) -> JSONResponse:
50-
environment = await self.session_repo.insert_environment(user=user, new_environment=body)
51-
return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"), 201)
55+
new_environment = validate_unsaved_environment(body)
56+
environment = await self.session_repo.insert_environment(user=user, environment=new_environment)
57+
return validated_json(apispec.Environment, environment, status=201)
5258

5359
return "/environments", ["POST"], _post
5460

@@ -60,11 +66,11 @@ def patch(self) -> BlueprintFactoryResponse:
6066
async def _patch(
6167
_: Request, user: base_models.APIUser, environment_id: ULID, body: apispec.EnvironmentPatch
6268
) -> JSONResponse:
63-
body_dict = body.model_dump(exclude_none=True)
69+
environment_patch = validate_environment_patch(body)
6470
environment = await self.session_repo.update_environment(
65-
user=user, environment_id=environment_id, **body_dict
71+
user=user, environment_id=environment_id, patch=environment_patch
6672
)
67-
return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"))
73+
return validated_json(apispec.Environment, environment)
6874

6975
return "/environments/<environment_id:ulid>", ["PATCH"], _patch
7076

@@ -92,12 +98,7 @@ def get_all(self) -> BlueprintFactoryResponse:
9298
@authenticate(self.authenticator)
9399
async def _get_all(_: Request, user: base_models.APIUser) -> JSONResponse:
94100
launchers = await self.session_repo.get_launchers(user=user)
95-
return json(
96-
[
97-
apispec.SessionLauncher.model_validate(item).model_dump(exclude_none=True, mode="json")
98-
for item in launchers
99-
]
100-
)
101+
return validated_json(apispec.SessionLaunchersList, launchers)
101102

102103
return "/session_launchers", ["GET"], _get_all
103104

@@ -107,7 +108,7 @@ def get_one(self) -> BlueprintFactoryResponse:
107108
@authenticate(self.authenticator)
108109
async def _get_one(_: Request, user: base_models.APIUser, launcher_id: ULID) -> JSONResponse:
109110
launcher = await self.session_repo.get_launcher(user=user, launcher_id=launcher_id)
110-
return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"))
111+
return validated_json(apispec.SessionLauncher, launcher)
111112

112113
return "/session_launchers/<launcher_id:ulid>", ["GET"], _get_one
113114

@@ -117,10 +118,9 @@ def post(self) -> BlueprintFactoryResponse:
117118
@authenticate(self.authenticator)
118119
@validate(json=apispec.SessionLauncherPost)
119120
async def _post(_: Request, user: base_models.APIUser, body: apispec.SessionLauncherPost) -> JSONResponse:
120-
launcher = await self.session_repo.insert_launcher(user=user, new_launcher=body)
121-
return json(
122-
apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"), 201
123-
)
121+
new_launcher = validate_unsaved_session_launcher(body)
122+
launcher = await self.session_repo.insert_launcher(user=user, launcher=new_launcher)
123+
return validated_json(apispec.SessionLauncher, launcher, status=201)
124124

125125
return "/session_launchers", ["POST"], _post
126126

@@ -132,9 +132,9 @@ def patch(self) -> BlueprintFactoryResponse:
132132
async def _patch(
133133
_: Request, user: base_models.APIUser, launcher_id: ULID, body: apispec.SessionLauncherPatch
134134
) -> JSONResponse:
135-
body_dict = body.model_dump(exclude_none=True)
136-
launcher = await self.session_repo.update_launcher(user=user, launcher_id=launcher_id, **body_dict)
137-
return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"))
135+
launcher_patch = validate_session_launcher_patch(body)
136+
launcher = await self.session_repo.update_launcher(user=user, launcher_id=launcher_id, patch=launcher_patch)
137+
return validated_json(apispec.SessionLauncher, launcher)
138138

139139
return "/session_launchers/<launcher_id:ulid>", ["PATCH"], _patch
140140

@@ -155,11 +155,6 @@ def get_project_launchers(self) -> BlueprintFactoryResponse:
155155
@validate_path_project_id
156156
async def _get_launcher(_: Request, user: base_models.APIUser, project_id: str) -> JSONResponse:
157157
launchers = await self.session_repo.get_project_launchers(user=user, project_id=project_id)
158-
return json(
159-
[
160-
apispec.SessionLauncher.model_validate(item).model_dump(exclude_none=True, mode="json")
161-
for item in launchers
162-
]
163-
)
158+
return validated_json(apispec.SessionLaunchersList, launchers)
164159

165160
return "/projects/<project_id>/session_launchers", ["GET"], _get_launcher
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""Business logic for sessions."""
2+
3+
from ulid import ULID
4+
5+
from renku_data_services.session import apispec, models
6+
7+
8+
def validate_unsaved_environment(environment: apispec.EnvironmentPost) -> models.UnsavedEnvironment:
9+
"""Validate an unsaved session environment."""
10+
return models.UnsavedEnvironment(
11+
name=environment.name,
12+
description=environment.description,
13+
container_image=environment.container_image,
14+
default_url=environment.default_url,
15+
)
16+
17+
18+
def validate_environment_patch(patch: apispec.EnvironmentPatch) -> models.EnvironmentPatch:
19+
"""Validate the update to a session environment."""
20+
return models.EnvironmentPatch(
21+
name=patch.name,
22+
description=patch.description,
23+
container_image=patch.container_image,
24+
default_url=patch.default_url,
25+
)
26+
27+
28+
def validate_unsaved_session_launcher(launcher: apispec.SessionLauncherPost) -> models.UnsavedSessionLauncher:
29+
"""Validate an unsaved session launcher."""
30+
return models.UnsavedSessionLauncher(
31+
project_id=ULID.from_str(launcher.project_id),
32+
name=launcher.name,
33+
description=launcher.description,
34+
environment_kind=launcher.environment_kind,
35+
environment_id=launcher.environment_id,
36+
resource_class_id=launcher.resource_class_id,
37+
container_image=launcher.container_image,
38+
default_url=launcher.default_url,
39+
)
40+
41+
42+
def validate_session_launcher_patch(patch: apispec.SessionLauncherPatch) -> models.SessionLauncherPatch:
43+
"""Validate the update to a session launcher."""
44+
return models.SessionLauncherPatch(
45+
name=patch.name,
46+
description=patch.description,
47+
environment_kind=patch.environment_kind,
48+
environment_id=patch.environment_id,
49+
resource_class_id=patch.resource_class_id,
50+
container_image=patch.container_image,
51+
default_url=patch.default_url,
52+
)

0 commit comments

Comments
 (0)