Skip to content

Commit 3930efa

Browse files
committed
fix: allow session launcher parameters to be reset
Allows the API to accept None as input for args, command and the session launcher resource class ID so that they can be reset to their defaults in patch endpoints.
1 parent bf53f3a commit 3930efa

7 files changed

Lines changed: 223 additions & 109 deletions

File tree

components/renku_data_services/base_models/core.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,17 @@ class Authenticator(Protocol[AnyAPIUser]):
212212
async def authenticate(self, access_token: str, request: Request) -> AnyAPIUser:
213213
"""Validates the user credentials (i.e. we can say that the user is a valid Renku user)."""
214214
...
215+
216+
217+
@dataclass(frozen=True, eq=True, kw_only=True)
218+
class Null:
219+
"""Parent class for distinguishing between None values."""
220+
221+
value: None = field(default=None, init=False, repr=False)
222+
223+
224+
@dataclass(frozen=True, eq=True, kw_only=True)
225+
class Reset(Null):
226+
"""Used to indicate a None value that has been deliberately set by the user or caller."""
227+
228+
...

components/renku_data_services/session/blueprints.py

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import renku_data_services.base_models as base_models
1212
from renku_data_services.base_api.auth import authenticate, validate_path_project_id
1313
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
14-
from renku_data_services.session import apispec, models
14+
from renku_data_services.session import apispec, converters, models
1515
from renku_data_services.session.db import SessionRepository
1616

1717

@@ -75,9 +75,11 @@ def patch(self) -> BlueprintFactoryResponse:
7575
async def _patch(
7676
_: Request, user: base_models.APIUser, environment_id: ULID, body: apispec.EnvironmentPatch
7777
) -> JSONResponse:
78-
body_dict = body.model_dump(exclude_none=True)
78+
update = converters.environment_update_from_patch(body)
7979
environment = await self.session_repo.update_environment(
80-
user=user, environment_id=environment_id, **body_dict
80+
user=user,
81+
environment_id=environment_id,
82+
update=update,
8183
)
8284
return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"))
8385

@@ -172,34 +174,14 @@ def patch(self) -> BlueprintFactoryResponse:
172174
async def _patch(
173175
_: Request, user: base_models.APIUser, launcher_id: ULID, body: apispec.SessionLauncherPatch
174176
) -> JSONResponse:
175-
body_dict = body.model_dump(exclude_none=True, mode="json")
176177
async with self.session_repo.session_maker() as session, session.begin():
177178
current_launcher = await self.session_repo.get_launcher(user, launcher_id)
178-
new_env: models.UnsavedEnvironment | None = None
179-
if (
180-
isinstance(body.environment, apispec.EnvironmentPatchInLauncher)
181-
and current_launcher.environment.environment_kind == models.EnvironmentKind.GLOBAL
182-
and body.environment.environment_kind == apispec.EnvironmentKind.CUSTOM
183-
):
184-
# This means that the global environment is being swapped for a custom one,
185-
# so we have to create a brand new environment, but we have to validate here.
186-
validated_env = apispec.EnvironmentPostInLauncher.model_validate(body_dict.pop("environment"))
187-
new_env = models.UnsavedEnvironment(
188-
name=validated_env.name,
189-
description=validated_env.description,
190-
container_image=validated_env.container_image,
191-
default_url=validated_env.default_url,
192-
port=validated_env.port,
193-
working_directory=PurePosixPath(validated_env.working_directory),
194-
mount_directory=PurePosixPath(validated_env.mount_directory),
195-
uid=validated_env.uid,
196-
gid=validated_env.gid,
197-
environment_kind=models.EnvironmentKind(validated_env.environment_kind.value),
198-
args=validated_env.args,
199-
command=validated_env.command,
200-
)
179+
update = converters.launcher_update_from_patch(body, current_launcher)
201180
launcher = await self.session_repo.update_launcher(
202-
user=user, launcher_id=launcher_id, new_custom_environment=new_env, session=session, **body_dict
181+
user=user,
182+
launcher_id=launcher_id,
183+
session=session,
184+
update=update,
203185
)
204186
return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"))
205187

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
"""Code used to convert from/to apispec and models."""
2+
3+
from pathlib import PurePosixPath
4+
5+
from renku_data_services.base_models.core import Reset
6+
from renku_data_services.session import apispec, models
7+
8+
9+
def environment_update_from_patch(data: apispec.EnvironmentPatch) -> models.EnvironmentUpdate:
10+
"""Create an update object from an apispec or any other pydantic model."""
11+
data_dict = data.model_dump(exclude_unset=True, mode="json")
12+
working_directory: PurePosixPath | None = None
13+
if data.working_directory is not None:
14+
working_directory = PurePosixPath(data.working_directory)
15+
mount_directory: PurePosixPath | None = None
16+
if data.mount_directory is not None:
17+
mount_directory = PurePosixPath(data.mount_directory)
18+
# NOTE: If the args or command are present in the data_dict and they are None they were passed in by the user.
19+
# The None specifically passed by the user indicates that the value should be removed from the DB.
20+
args = Reset() if "args" in data_dict and data_dict["args"] is None else data.args
21+
command = Reset() if "command" in data_dict and data_dict["command"] is None else data.command
22+
return models.EnvironmentUpdate(
23+
name=data.name,
24+
description=data.description,
25+
container_image=data.container_image,
26+
default_url=data.default_url,
27+
port=data.port,
28+
working_directory=working_directory,
29+
mount_directory=mount_directory,
30+
uid=data.uid,
31+
gid=data.gid,
32+
args=args,
33+
command=command,
34+
)
35+
36+
37+
def launcher_update_from_patch(
38+
data: apispec.SessionLauncherPatch,
39+
current_launcher: models.SessionLauncher | None = None,
40+
) -> models.SessionLauncherUpdate:
41+
"""Create an update object from an apispec or any other pydantic model."""
42+
data_dict = data.model_dump(exclude_unset=True, mode="json")
43+
environment: str | models.EnvironmentUpdate | models.UnsavedEnvironment | None = None
44+
if (
45+
isinstance(data.environment, apispec.EnvironmentPatchInLauncher)
46+
and current_launcher is not None
47+
and current_launcher.environment.environment_kind == models.EnvironmentKind.GLOBAL
48+
and data.environment.environment_kind == apispec.EnvironmentKind.CUSTOM
49+
):
50+
# This means that the global environment is being swapped for a custom one,
51+
# so we have to create a brand new environment, but we have to validate here.
52+
validated_env = apispec.EnvironmentPostInLauncher.model_validate(data_dict["environment"])
53+
environment = models.UnsavedEnvironment(
54+
name=validated_env.name,
55+
description=validated_env.description,
56+
container_image=validated_env.container_image,
57+
default_url=validated_env.default_url,
58+
port=validated_env.port,
59+
working_directory=PurePosixPath(validated_env.working_directory),
60+
mount_directory=PurePosixPath(validated_env.mount_directory),
61+
uid=validated_env.uid,
62+
gid=validated_env.gid,
63+
environment_kind=models.EnvironmentKind(validated_env.environment_kind.value),
64+
args=validated_env.args,
65+
command=validated_env.command,
66+
)
67+
elif isinstance(data.environment, apispec.EnvironmentPatchInLauncher):
68+
environment = environment_update_from_patch(data.environment)
69+
elif isinstance(data.environment, apispec.EnvironmentIdOnlyPatch):
70+
environment = data.environment.id
71+
resource_class_id: int | None | Reset = None
72+
if "resource_class_id" in data_dict and data_dict["resource_class_id"] is None:
73+
# NOTE: This means that the resource class set in the DB should be removed so that the
74+
# default resource class currently set in the CRC will be used.
75+
resource_class_id = Reset()
76+
else:
77+
resource_class_id = data_dict.get("resource_class_id")
78+
return models.SessionLauncherUpdate(
79+
name=data_dict.get("name"),
80+
description=data_dict.get("description"),
81+
environment=environment,
82+
resource_class_id=resource_class_id,
83+
)

components/renku_data_services/session/db.py

Lines changed: 68 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from collections.abc import Callable
66
from contextlib import AbstractAsyncContextManager, nullcontext
77
from datetime import UTC, datetime
8-
from typing import Any
98

109
from sqlalchemy import select
1110
from sqlalchemy.ext.asyncio import AsyncSession
@@ -15,6 +14,7 @@
1514
from renku_data_services import errors
1615
from renku_data_services.authz.authz import Authz, ResourceType
1716
from renku_data_services.authz.models import Scope
17+
from renku_data_services.base_models.core import Reset
1818
from renku_data_services.crc.db import ResourcePoolRepository
1919
from renku_data_services.session import models
2020
from renku_data_services.session import orm as schemas
@@ -101,53 +101,59 @@ async def insert_environment(
101101
env = await self.__insert_environment(user, session, new_environment)
102102
return env.dump()
103103

104-
async def __update_environment(
104+
def __update_environment(
105105
self,
106-
user: base_models.APIUser,
107-
session: AsyncSession,
108-
environment_id: ULID,
109-
kind: models.EnvironmentKind,
110-
**kwargs: dict,
111-
) -> models.Environment:
112-
res = await session.scalars(
113-
select(schemas.EnvironmentORM)
114-
.where(schemas.EnvironmentORM.id == str(environment_id))
115-
.where(schemas.EnvironmentORM.environment_kind == kind.value)
116-
)
117-
environment = res.one_or_none()
118-
if environment is None:
119-
raise errors.MissingResourceError(message=f"Session environment with id '{environment_id}' does not exist.")
120-
121-
for key, value in kwargs.items():
122-
# NOTE: Only some fields can be edited
123-
if key in [
124-
"name",
125-
"description",
126-
"container_image",
127-
"default_url",
128-
"port",
129-
"working_directory",
130-
"mount_directory",
131-
"uid",
132-
"gid",
133-
"args",
134-
"command",
135-
]:
136-
setattr(environment, key, value)
137-
138-
return environment.dump()
106+
environment: schemas.EnvironmentORM,
107+
update: models.EnvironmentUpdate,
108+
) -> None:
109+
# NOTE: this is more verbose than a loop and setattr but this way we get mypy type checks
110+
if update.name is not None:
111+
environment.name = update.name
112+
if update.description is not None:
113+
environment.description = update.description
114+
if update.container_image is not None:
115+
environment.container_image = update.container_image
116+
if update.default_url is not None:
117+
environment.default_url = update.default_url
118+
if update.port is not None:
119+
environment.port = update.port
120+
if update.working_directory is not None:
121+
environment.working_directory = update.working_directory
122+
if update.mount_directory is not None:
123+
environment.mount_directory = update.mount_directory
124+
if update.uid is not None:
125+
environment.uid = update.uid
126+
if update.gid is not None:
127+
environment.gid = update.gid
128+
if isinstance(update.args, Reset):
129+
environment.args = None
130+
elif isinstance(update.args, list):
131+
environment.args = update.args
132+
if isinstance(update.command, Reset):
133+
environment.command = None
134+
elif isinstance(update.command, list):
135+
environment.command = update.command
139136

140137
async def update_environment(
141-
self, user: base_models.APIUser, environment_id: ULID, **kwargs: dict
138+
self, user: base_models.APIUser, environment_id: ULID, update: models.EnvironmentUpdate
142139
) -> models.Environment:
143140
"""Update a global session environment entry."""
144141
if not user.is_admin:
145142
raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.")
146143

147144
async with self.session_maker() as session, session.begin():
148-
return await self.__update_environment(
149-
user, session, environment_id, models.EnvironmentKind.GLOBAL, **kwargs
145+
res = await session.scalars(
146+
select(schemas.EnvironmentORM)
147+
.where(schemas.EnvironmentORM.id == str(environment_id))
148+
.where(schemas.EnvironmentORM.environment_kind == models.EnvironmentKind.GLOBAL)
150149
)
150+
environment = res.one_or_none()
151+
if environment is None:
152+
raise errors.MissingResourceError(
153+
message=f"Session environment with id '{environment_id}' does not exist."
154+
)
155+
self.__update_environment(environment, update)
156+
return environment.dump()
151157

152158
async def delete_environment(self, user: base_models.APIUser, environment_id: ULID) -> None:
153159
"""Delete a global session environment entry."""
@@ -300,9 +306,8 @@ async def update_launcher(
300306
self,
301307
user: base_models.APIUser,
302308
launcher_id: ULID,
303-
new_custom_environment: models.UnsavedEnvironment | None,
309+
update: models.SessionLauncherUpdate,
304310
session: AsyncSession | None = None,
305-
**kwargs: Any,
306311
) -> models.SessionLauncher:
307312
"""Update a session launcher entry."""
308313
if not user.is_authenticated or user.id is None:
@@ -336,8 +341,8 @@ async def update_launcher(
336341
if not authorized:
337342
raise errors.ForbiddenError(message="You do not have the required permissions for this operation.")
338343

339-
resource_class_id = kwargs.get("resource_class_id")
340-
if resource_class_id is not None:
344+
resource_class_id = update.resource_class_id
345+
if isinstance(resource_class_id, int):
341346
res = await session.scalars(
342347
select(schemas.ResourceClassORM).where(schemas.ResourceClassORM.id == resource_class_id)
343348
)
@@ -354,30 +359,32 @@ async def update_launcher(
354359
message=f"You do not have access to resource class with id '{resource_class_id}'."
355360
)
356361

357-
for key, value in kwargs.items():
358-
# NOTE: Only some fields can be updated.
359-
if key in [
360-
"name",
361-
"description",
362-
"resource_class_id",
363-
]:
364-
setattr(launcher, key, value)
365-
366-
env_payload = kwargs.get("environment", {})
367-
await self.__update_launcher_environment(user, launcher, session, new_custom_environment, **env_payload)
362+
# NOTE: Only some fields can be updated.
363+
if update.name is not None:
364+
launcher.name = update.name
365+
if update.description is not None:
366+
launcher.description = update.description
367+
if isinstance(update.resource_class_id, int):
368+
launcher.resource_class_id = update.resource_class_id
369+
elif isinstance(update.resource_class_id, Reset):
370+
launcher.resource_class_id = None
371+
372+
if update.environment is None:
373+
return launcher.dump()
374+
375+
await self.__update_launcher_environment(user, launcher, session, update.environment)
368376
return launcher.dump()
369377

370378
async def __update_launcher_environment(
371379
self,
372380
user: base_models.APIUser,
373381
launcher: schemas.SessionLauncherORM,
374382
session: AsyncSession,
375-
new_custom_environment: models.UnsavedEnvironment | None,
376-
**kwargs: Any,
383+
update: models.EnvironmentUpdate | models.UnsavedEnvironment | str,
377384
) -> None:
378385
current_env_kind = launcher.environment.environment_kind
379-
match new_custom_environment, current_env_kind, kwargs:
380-
case None, _, {"id": env_id, **nothing_else} if len(nothing_else) == 0:
386+
match update, current_env_kind:
387+
case str() as env_id, _:
381388
# The environment in the launcher is set via ID, the new ID has to refer
382389
# to an environment that is GLOBAL.
383390
old_environment = launcher.environment
@@ -404,29 +411,11 @@ async def __update_launcher_environment(
404411
# We remove the custom environment to avoid accumulating custom environments that are not associated
405412
# with any launchers.
406413
await session.delete(old_environment)
407-
case None, models.EnvironmentKind.CUSTOM, {**rest} if (
408-
rest.get("environment_kind") is None
409-
or rest.get("environment_kind") == models.EnvironmentKind.CUSTOM.value
410-
):
414+
case models.EnvironmentUpdate(), models.EnvironmentKind.CUSTOM:
411415
# Custom environment being updated
412-
for key, val in rest.items():
413-
# NOTE: Only some fields can be updated.
414-
if key in [
415-
"name",
416-
"description",
417-
"container_image",
418-
"default_url",
419-
"port",
420-
"working_directory",
421-
"mount_directory",
422-
"uid",
423-
"gid",
424-
"args",
425-
"command",
426-
]:
427-
setattr(launcher.environment, key, val)
428-
case models.UnsavedEnvironment(), models.EnvironmentKind.GLOBAL, {**nothing_else} if (
429-
len(nothing_else) == 0 and new_custom_environment.environment_kind == models.EnvironmentKind.CUSTOM
416+
self.__update_environment(launcher.environment, update)
417+
case models.UnsavedEnvironment() as new_custom_environment, models.EnvironmentKind.GLOBAL if (
418+
new_custom_environment.environment_kind == models.EnvironmentKind.CUSTOM
430419
):
431420
# Global environment replaced by a custom one
432421
new_env = await self.__insert_environment(user, session, new_custom_environment)

0 commit comments

Comments
 (0)