Skip to content

Commit a8165cb

Browse files
committed
fix: allow session launcher parameters to be reset
Allows the API to accept None as input for args, command and the session launcher resource class ID so that they can be reset to their defaults in patch endpoints.
1 parent 55b3581 commit a8165cb

6 files changed

Lines changed: 221 additions & 111 deletions

File tree

components/renku_data_services/base_models/core.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,17 @@ class Authenticator(Protocol[AnyAPIUser]):
212212
async def authenticate(self, access_token: str, request: Request) -> AnyAPIUser:
213213
"""Validates the user credentials (i.e. we can say that the user is a valid Renku user)."""
214214
...
215+
216+
217+
@dataclass(frozen=True, eq=True, kw_only=True)
218+
class Null:
219+
"""Parent class for distinguishing between None values."""
220+
221+
value: None = field(default=None, init=False, repr=False)
222+
223+
224+
@dataclass(frozen=True, eq=True, kw_only=True)
225+
class Reset(Null):
226+
"""Used to indicate a None value that has been deliberately set by the user or caller."""
227+
228+
...

components/renku_data_services/session/blueprints.py

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from renku_data_services.base_api.auth import authenticate, only_authenticated
1313
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
1414
from renku_data_services.base_models.validation import validated_json
15-
from renku_data_services.session import apispec, models
15+
from renku_data_services.session import apispec, converters, models
1616
from renku_data_services.session.db import SessionRepository
1717

1818

@@ -76,9 +76,11 @@ def patch(self) -> BlueprintFactoryResponse:
7676
async def _patch(
7777
_: Request, user: base_models.APIUser, environment_id: ULID, body: apispec.EnvironmentPatch
7878
) -> JSONResponse:
79-
body_dict = body.model_dump(exclude_none=True)
79+
update = converters.environment_update_from_patch(body)
8080
environment = await self.session_repo.update_environment(
81-
user=user, environment_id=environment_id, **body_dict
81+
user=user,
82+
environment_id=environment_id,
83+
update=update,
8284
)
8385
return validated_json(apispec.Environment, environment)
8486

@@ -169,34 +171,14 @@ def patch(self) -> BlueprintFactoryResponse:
169171
async def _patch(
170172
_: Request, user: base_models.APIUser, launcher_id: ULID, body: apispec.SessionLauncherPatch
171173
) -> JSONResponse:
172-
body_dict = body.model_dump(exclude_none=True, mode="json")
173174
async with self.session_repo.session_maker() as session, session.begin():
174175
current_launcher = await self.session_repo.get_launcher(user, launcher_id)
175-
new_env: models.UnsavedEnvironment | None = None
176-
if (
177-
isinstance(body.environment, apispec.EnvironmentPatchInLauncher)
178-
and current_launcher.environment.environment_kind == models.EnvironmentKind.GLOBAL
179-
and body.environment.environment_kind == apispec.EnvironmentKind.CUSTOM
180-
):
181-
# This means that the global environment is being swapped for a custom one,
182-
# so we have to create a brand new environment, but we have to validate here.
183-
validated_env = apispec.EnvironmentPostInLauncher.model_validate(body_dict.pop("environment"))
184-
new_env = models.UnsavedEnvironment(
185-
name=validated_env.name,
186-
description=validated_env.description,
187-
container_image=validated_env.container_image,
188-
default_url=validated_env.default_url,
189-
port=validated_env.port,
190-
working_directory=PurePosixPath(validated_env.working_directory),
191-
mount_directory=PurePosixPath(validated_env.mount_directory),
192-
uid=validated_env.uid,
193-
gid=validated_env.gid,
194-
environment_kind=models.EnvironmentKind(validated_env.environment_kind.value),
195-
args=validated_env.args,
196-
command=validated_env.command,
197-
)
176+
update = converters.launcher_update_from_patch(body, current_launcher)
198177
launcher = await self.session_repo.update_launcher(
199-
user=user, launcher_id=launcher_id, new_custom_environment=new_env, session=session, **body_dict
178+
user=user,
179+
launcher_id=launcher_id,
180+
session=session,
181+
update=update,
200182
)
201183
return validated_json(apispec.SessionLauncher, launcher)
202184

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
"""Code used to convert from/to apispec and models."""
2+
3+
from pathlib import PurePosixPath
4+
5+
from renku_data_services.base_models.core import Reset
6+
from renku_data_services.session import apispec, models
7+
8+
9+
def environment_update_from_patch(data: apispec.EnvironmentPatch) -> models.EnvironmentUpdate:
10+
"""Create an update object from an apispec or any other pydantic model."""
11+
data_dict = data.model_dump(exclude_unset=True, mode="json")
12+
working_directory: PurePosixPath | None = None
13+
if data.working_directory is not None:
14+
working_directory = PurePosixPath(data.working_directory)
15+
mount_directory: PurePosixPath | None = None
16+
if data.mount_directory is not None:
17+
mount_directory = PurePosixPath(data.mount_directory)
18+
# NOTE: If the args or command are present in the data_dict and they are None they were passed in by the user.
19+
# The None specifically passed by the user indicates that the value should be removed from the DB.
20+
args = Reset() if "args" in data_dict and data_dict["args"] is None else data.args
21+
command = Reset() if "command" in data_dict and data_dict["command"] is None else data.command
22+
return models.EnvironmentUpdate(
23+
name=data.name,
24+
description=data.description,
25+
container_image=data.container_image,
26+
default_url=data.default_url,
27+
port=data.port,
28+
working_directory=working_directory,
29+
mount_directory=mount_directory,
30+
uid=data.uid,
31+
gid=data.gid,
32+
args=args,
33+
command=command,
34+
)
35+
36+
37+
def launcher_update_from_patch(
38+
data: apispec.SessionLauncherPatch,
39+
current_launcher: models.SessionLauncher | None = None,
40+
) -> models.SessionLauncherUpdate:
41+
"""Create an update object from an apispec or any other pydantic model."""
42+
data_dict = data.model_dump(exclude_unset=True, mode="json")
43+
environment: str | models.EnvironmentUpdate | models.UnsavedEnvironment | None = None
44+
if (
45+
isinstance(data.environment, apispec.EnvironmentPatchInLauncher)
46+
and current_launcher is not None
47+
and current_launcher.environment.environment_kind == models.EnvironmentKind.GLOBAL
48+
and data.environment.environment_kind == apispec.EnvironmentKind.CUSTOM
49+
):
50+
# This means that the global environment is being swapped for a custom one,
51+
# so we have to create a brand new environment, but we have to validate here.
52+
validated_env = apispec.EnvironmentPostInLauncher.model_validate(data_dict["environment"])
53+
environment = models.UnsavedEnvironment(
54+
name=validated_env.name,
55+
description=validated_env.description,
56+
container_image=validated_env.container_image,
57+
default_url=validated_env.default_url,
58+
port=validated_env.port,
59+
working_directory=PurePosixPath(validated_env.working_directory),
60+
mount_directory=PurePosixPath(validated_env.mount_directory),
61+
uid=validated_env.uid,
62+
gid=validated_env.gid,
63+
environment_kind=models.EnvironmentKind(validated_env.environment_kind.value),
64+
args=validated_env.args,
65+
command=validated_env.command,
66+
)
67+
elif isinstance(data.environment, apispec.EnvironmentPatchInLauncher):
68+
environment = environment_update_from_patch(data.environment)
69+
elif isinstance(data.environment, apispec.EnvironmentIdOnlyPatch):
70+
environment = data.environment.id
71+
resource_class_id: int | None | Reset = None
72+
if "resource_class_id" in data_dict and data_dict["resource_class_id"] is None:
73+
# NOTE: This means that the resource class set in the DB should be removed so that the
74+
# default resource class currently set in the CRC will be used.
75+
resource_class_id = Reset()
76+
else:
77+
resource_class_id = data_dict.get("resource_class_id")
78+
return models.SessionLauncherUpdate(
79+
name=data_dict.get("name"),
80+
description=data_dict.get("description"),
81+
environment=environment,
82+
resource_class_id=resource_class_id,
83+
)

components/renku_data_services/session/db.py

Lines changed: 68 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from collections.abc import Callable
66
from contextlib import AbstractAsyncContextManager, nullcontext
7-
from typing import Any
87

98
from sqlalchemy import select
109
from sqlalchemy.ext.asyncio import AsyncSession
@@ -14,6 +13,7 @@
1413
from renku_data_services import errors
1514
from renku_data_services.authz.authz import Authz, ResourceType
1615
from renku_data_services.authz.models import Scope
16+
from renku_data_services.base_models.core import Reset
1717
from renku_data_services.crc.db import ResourcePoolRepository
1818
from renku_data_services.session import models
1919
from renku_data_services.session import orm as schemas
@@ -101,53 +101,59 @@ async def insert_environment(
101101
await session.refresh(env)
102102
return env.dump()
103103

104-
async def __update_environment(
104+
def __update_environment(
105105
self,
106-
user: base_models.APIUser,
107-
session: AsyncSession,
108-
environment_id: ULID,
109-
kind: models.EnvironmentKind,
110-
**kwargs: dict,
111-
) -> models.Environment:
112-
res = await session.scalars(
113-
select(schemas.EnvironmentORM)
114-
.where(schemas.EnvironmentORM.id == str(environment_id))
115-
.where(schemas.EnvironmentORM.environment_kind == kind.value)
116-
)
117-
environment = res.one_or_none()
118-
if environment is None:
119-
raise errors.MissingResourceError(message=f"Session environment with id '{environment_id}' does not exist.")
120-
121-
for key, value in kwargs.items():
122-
# NOTE: Only some fields can be edited
123-
if key in [
124-
"name",
125-
"description",
126-
"container_image",
127-
"default_url",
128-
"port",
129-
"working_directory",
130-
"mount_directory",
131-
"uid",
132-
"gid",
133-
"args",
134-
"command",
135-
]:
136-
setattr(environment, key, value)
137-
138-
return environment.dump()
106+
environment: schemas.EnvironmentORM,
107+
update: models.EnvironmentUpdate,
108+
) -> None:
109+
# NOTE: this is more verbose than a loop and setattr but this way we get mypy type checks
110+
if update.name is not None:
111+
environment.name = update.name
112+
if update.description is not None:
113+
environment.description = update.description
114+
if update.container_image is not None:
115+
environment.container_image = update.container_image
116+
if update.default_url is not None:
117+
environment.default_url = update.default_url
118+
if update.port is not None:
119+
environment.port = update.port
120+
if update.working_directory is not None:
121+
environment.working_directory = update.working_directory
122+
if update.mount_directory is not None:
123+
environment.mount_directory = update.mount_directory
124+
if update.uid is not None:
125+
environment.uid = update.uid
126+
if update.gid is not None:
127+
environment.gid = update.gid
128+
if isinstance(update.args, Reset):
129+
environment.args = None
130+
elif isinstance(update.args, list):
131+
environment.args = update.args
132+
if isinstance(update.command, Reset):
133+
environment.command = None
134+
elif isinstance(update.command, list):
135+
environment.command = update.command
139136

140137
async def update_environment(
141-
self, user: base_models.APIUser, environment_id: ULID, **kwargs: dict
138+
self, user: base_models.APIUser, environment_id: ULID, update: models.EnvironmentUpdate
142139
) -> models.Environment:
143140
"""Update a global session environment entry."""
144141
if not user.is_admin:
145142
raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.")
146143

147144
async with self.session_maker() as session, session.begin():
148-
return await self.__update_environment(
149-
user, session, environment_id, models.EnvironmentKind.GLOBAL, **kwargs
145+
res = await session.scalars(
146+
select(schemas.EnvironmentORM)
147+
.where(schemas.EnvironmentORM.id == str(environment_id))
148+
.where(schemas.EnvironmentORM.environment_kind == models.EnvironmentKind.GLOBAL)
150149
)
150+
environment = res.one_or_none()
151+
if environment is None:
152+
raise errors.MissingResourceError(
153+
message=f"Session environment with id '{environment_id}' does not exist."
154+
)
155+
self.__update_environment(environment, update)
156+
return environment.dump()
151157

152158
async def delete_environment(self, user: base_models.APIUser, environment_id: ULID) -> None:
153159
"""Delete a global session environment entry."""
@@ -297,9 +303,8 @@ async def update_launcher(
297303
self,
298304
user: base_models.APIUser,
299305
launcher_id: ULID,
300-
new_custom_environment: models.UnsavedEnvironment | None,
306+
update: models.SessionLauncherUpdate,
301307
session: AsyncSession | None = None,
302-
**kwargs: Any,
303308
) -> models.SessionLauncher:
304309
"""Update a session launcher entry."""
305310
if not user.is_authenticated or user.id is None:
@@ -333,8 +338,8 @@ async def update_launcher(
333338
if not authorized:
334339
raise errors.ForbiddenError(message="You do not have the required permissions for this operation.")
335340

336-
resource_class_id = kwargs.get("resource_class_id")
337-
if resource_class_id is not None:
341+
resource_class_id = update.resource_class_id
342+
if isinstance(resource_class_id, int):
338343
res = await session.scalars(
339344
select(schemas.ResourceClassORM).where(schemas.ResourceClassORM.id == resource_class_id)
340345
)
@@ -351,32 +356,32 @@ async def update_launcher(
351356
message=f"You do not have access to resource class with id '{resource_class_id}'."
352357
)
353358

354-
for key, value in kwargs.items():
355-
# NOTE: Only some fields can be updated.
356-
if key in [
357-
"name",
358-
"description",
359-
"resource_class_id",
360-
]:
361-
setattr(launcher, key, value)
362-
363-
env_payload = kwargs.get("environment", {})
364-
await self.__update_launcher_environment(user, launcher, session, new_custom_environment, **env_payload)
365-
await session.flush()
366-
await session.refresh(launcher)
359+
# NOTE: Only some fields can be updated.
360+
if update.name is not None:
361+
launcher.name = update.name
362+
if update.description is not None:
363+
launcher.description = update.description
364+
if isinstance(update.resource_class_id, int):
365+
launcher.resource_class_id = update.resource_class_id
366+
elif isinstance(update.resource_class_id, Reset):
367+
launcher.resource_class_id = None
368+
369+
if update.environment is None:
370+
return launcher.dump()
371+
372+
await self.__update_launcher_environment(user, launcher, session, update.environment)
367373
return launcher.dump()
368374

369375
async def __update_launcher_environment(
370376
self,
371377
user: base_models.APIUser,
372378
launcher: schemas.SessionLauncherORM,
373379
session: AsyncSession,
374-
new_custom_environment: models.UnsavedEnvironment | None,
375-
**kwargs: Any,
380+
update: models.EnvironmentUpdate | models.UnsavedEnvironment | str,
376381
) -> None:
377382
current_env_kind = launcher.environment.environment_kind
378-
match new_custom_environment, current_env_kind, kwargs:
379-
case None, _, {"id": env_id, **nothing_else} if len(nothing_else) == 0:
383+
match update, current_env_kind:
384+
case str() as env_id, _:
380385
# The environment in the launcher is set via ID, the new ID has to refer
381386
# to an environment that is GLOBAL.
382387
old_environment = launcher.environment
@@ -403,29 +408,11 @@ async def __update_launcher_environment(
403408
# We remove the custom environment to avoid accumulating custom environments that are not associated
404409
# with any launchers.
405410
await session.delete(old_environment)
406-
case None, models.EnvironmentKind.CUSTOM, {**rest} if (
407-
rest.get("environment_kind") is None
408-
or rest.get("environment_kind") == models.EnvironmentKind.CUSTOM.value
409-
):
411+
case models.EnvironmentUpdate(), models.EnvironmentKind.CUSTOM:
410412
# Custom environment being updated
411-
for key, val in rest.items():
412-
# NOTE: Only some fields can be updated.
413-
if key in [
414-
"name",
415-
"description",
416-
"container_image",
417-
"default_url",
418-
"port",
419-
"working_directory",
420-
"mount_directory",
421-
"uid",
422-
"gid",
423-
"args",
424-
"command",
425-
]:
426-
setattr(launcher.environment, key, val)
427-
case models.UnsavedEnvironment(), models.EnvironmentKind.GLOBAL, {**nothing_else} if (
428-
len(nothing_else) == 0 and new_custom_environment.environment_kind == models.EnvironmentKind.CUSTOM
413+
self.__update_environment(launcher.environment, update)
414+
case models.UnsavedEnvironment() as new_custom_environment, models.EnvironmentKind.GLOBAL if (
415+
new_custom_environment.environment_kind == models.EnvironmentKind.CUSTOM
429416
):
430417
# Global environment replaced by a custom one
431418
new_env = await self.__insert_environment(user, session, new_custom_environment)

0 commit comments

Comments
 (0)