55from collections .abc import Callable
66from contextlib import AbstractAsyncContextManager , nullcontext
77from datetime import UTC , datetime
8- from typing import Any
98
109from sqlalchemy import select
1110from sqlalchemy .ext .asyncio import AsyncSession
1514from renku_data_services import errors
1615from renku_data_services .authz .authz import Authz , ResourceType
1716from renku_data_services .authz .models import Scope
17+ from renku_data_services .base_models .core import Reset
1818from renku_data_services .crc .db import ResourcePoolRepository
1919from renku_data_services .session import models
2020from renku_data_services .session import orm as schemas
@@ -101,53 +101,59 @@ async def insert_environment(
101101 env = await self .__insert_environment (user , session , new_environment )
102102 return env .dump ()
103103
104- async def __update_environment (
104+ def __update_environment (
105105 self ,
106- user : base_models .APIUser ,
107- session : AsyncSession ,
108- environment_id : ULID ,
109- kind : models .EnvironmentKind ,
110- ** kwargs : dict ,
111- ) -> models .Environment :
112- res = await session .scalars (
113- select (schemas .EnvironmentORM )
114- .where (schemas .EnvironmentORM .id == str (environment_id ))
115- .where (schemas .EnvironmentORM .environment_kind == kind .value )
116- )
117- environment = res .one_or_none ()
118- if environment is None :
119- raise errors .MissingResourceError (message = f"Session environment with id '{ environment_id } ' does not exist." )
120-
121- for key , value in kwargs .items ():
122- # NOTE: Only some fields can be edited
123- if key in [
124- "name" ,
125- "description" ,
126- "container_image" ,
127- "default_url" ,
128- "port" ,
129- "working_directory" ,
130- "mount_directory" ,
131- "uid" ,
132- "gid" ,
133- "args" ,
134- "command" ,
135- ]:
136- setattr (environment , key , value )
137-
138- return environment .dump ()
106+ environment : schemas .EnvironmentORM ,
107+ update : models .EnvironmentUpdate ,
108+ ) -> None :
109+ # NOTE: this is more verbose than a loop and setattr but this way we get mypy type checks
110+ if update .name is not None :
111+ environment .name = update .name
112+ if update .description is not None :
113+ environment .description = update .description
114+ if update .container_image is not None :
115+ environment .container_image = update .container_image
116+ if update .default_url is not None :
117+ environment .default_url = update .default_url
118+ if update .port is not None :
119+ environment .port = update .port
120+ if update .working_directory is not None :
121+ environment .working_directory = update .working_directory
122+ if update .mount_directory is not None :
123+ environment .mount_directory = update .mount_directory
124+ if update .uid is not None :
125+ environment .uid = update .uid
126+ if update .gid is not None :
127+ environment .gid = update .gid
128+ if isinstance (update .args , Reset ):
129+ environment .args = None
130+ elif isinstance (update .args , list ):
131+ environment .args = update .args
132+ if isinstance (update .command , Reset ):
133+ environment .command = None
134+ elif isinstance (update .command , list ):
135+ environment .command = update .command
139136
140137 async def update_environment (
141- self , user : base_models .APIUser , environment_id : ULID , ** kwargs : dict
138+ self , user : base_models .APIUser , environment_id : ULID , update : models . EnvironmentUpdate
142139 ) -> models .Environment :
143140 """Update a global session environment entry."""
144141 if not user .is_admin :
145142 raise errors .UnauthorizedError (message = "You do not have the required permissions for this operation." )
146143
147144 async with self .session_maker () as session , session .begin ():
148- return await self .__update_environment (
149- user , session , environment_id , models .EnvironmentKind .GLOBAL , ** kwargs
145+ res = await session .scalars (
146+ select (schemas .EnvironmentORM )
147+ .where (schemas .EnvironmentORM .id == str (environment_id ))
148+ .where (schemas .EnvironmentORM .environment_kind == models .EnvironmentKind .GLOBAL )
150149 )
150+ environment = res .one_or_none ()
151+ if environment is None :
152+ raise errors .MissingResourceError (
153+ message = f"Session environment with id '{ environment_id } ' does not exist."
154+ )
155+ self .__update_environment (environment , update )
156+ return environment .dump ()
151157
152158 async def delete_environment (self , user : base_models .APIUser , environment_id : ULID ) -> None :
153159 """Delete a global session environment entry."""
@@ -300,9 +306,8 @@ async def update_launcher(
300306 self ,
301307 user : base_models .APIUser ,
302308 launcher_id : ULID ,
303- new_custom_environment : models .UnsavedEnvironment | None ,
309+ update : models .SessionLauncherUpdate ,
304310 session : AsyncSession | None = None ,
305- ** kwargs : Any ,
306311 ) -> models .SessionLauncher :
307312 """Update a session launcher entry."""
308313 if not user .is_authenticated or user .id is None :
@@ -336,8 +341,8 @@ async def update_launcher(
336341 if not authorized :
337342 raise errors .ForbiddenError (message = "You do not have the required permissions for this operation." )
338343
339- resource_class_id = kwargs . get ( " resource_class_id" )
340- if resource_class_id is not None :
344+ resource_class_id = update . resource_class_id
345+ if isinstance ( resource_class_id , int ) :
341346 res = await session .scalars (
342347 select (schemas .ResourceClassORM ).where (schemas .ResourceClassORM .id == resource_class_id )
343348 )
@@ -354,30 +359,32 @@ async def update_launcher(
354359 message = f"You do not have access to resource class with id '{ resource_class_id } '."
355360 )
356361
357- for key , value in kwargs .items ():
358- # NOTE: Only some fields can be updated.
359- if key in [
360- "name" ,
361- "description" ,
362- "resource_class_id" ,
363- ]:
364- setattr (launcher , key , value )
365-
366- env_payload = kwargs .get ("environment" , {})
367- await self .__update_launcher_environment (user , launcher , session , new_custom_environment , ** env_payload )
362+ # NOTE: Only some fields can be updated.
363+ if update .name is not None :
364+ launcher .name = update .name
365+ if update .description is not None :
366+ launcher .description = update .description
367+ if isinstance (update .resource_class_id , int ):
368+ launcher .resource_class_id = update .resource_class_id
369+ elif isinstance (update .resource_class_id , Reset ):
370+ launcher .resource_class_id = None
371+
372+ if update .environment is None :
373+ return launcher .dump ()
374+
375+ await self .__update_launcher_environment (user , launcher , session , update .environment )
368376 return launcher .dump ()
369377
370378 async def __update_launcher_environment (
371379 self ,
372380 user : base_models .APIUser ,
373381 launcher : schemas .SessionLauncherORM ,
374382 session : AsyncSession ,
375- new_custom_environment : models .UnsavedEnvironment | None ,
376- ** kwargs : Any ,
383+ update : models .EnvironmentUpdate | models .UnsavedEnvironment | str ,
377384 ) -> None :
378385 current_env_kind = launcher .environment .environment_kind
379- match new_custom_environment , current_env_kind , kwargs :
380- case None , _, { "id" : env_id , ** nothing_else } if len ( nothing_else ) == 0 :
386+ match update , current_env_kind :
387+ case str () as env_id , _ :
381388 # The environment in the launcher is set via ID, the new ID has to refer
382389 # to an environment that is GLOBAL.
383390 old_environment = launcher .environment
@@ -404,29 +411,11 @@ async def __update_launcher_environment(
404411 # We remove the custom environment to avoid accumulating custom environments that are not associated
405412 # with any launchers.
406413 await session .delete (old_environment )
407- case None , models .EnvironmentKind .CUSTOM , {** rest } if (
408- rest .get ("environment_kind" ) is None
409- or rest .get ("environment_kind" ) == models .EnvironmentKind .CUSTOM .value
410- ):
414+ case models .EnvironmentUpdate (), models .EnvironmentKind .CUSTOM :
411415 # Custom environment being updated
412- for key , val in rest .items ():
413- # NOTE: Only some fields can be updated.
414- if key in [
415- "name" ,
416- "description" ,
417- "container_image" ,
418- "default_url" ,
419- "port" ,
420- "working_directory" ,
421- "mount_directory" ,
422- "uid" ,
423- "gid" ,
424- "args" ,
425- "command" ,
426- ]:
427- setattr (launcher .environment , key , val )
428- case models .UnsavedEnvironment (), models .EnvironmentKind .GLOBAL , {** nothing_else } if (
429- len (nothing_else ) == 0 and new_custom_environment .environment_kind == models .EnvironmentKind .CUSTOM
416+ self .__update_environment (launcher .environment , update )
417+ case models .UnsavedEnvironment () as new_custom_environment , models .EnvironmentKind .GLOBAL if (
418+ new_custom_environment .environment_kind == models .EnvironmentKind .CUSTOM
430419 ):
431420 # Global environment replaced by a custom one
432421 new_env = await self .__insert_environment (user , session , new_custom_environment )
0 commit comments