Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion components/renku_data_services/base_models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum, StrEnum
from typing import ClassVar, Optional, Protocol, Self, TypeVar
from typing import ClassVar, NewType, Optional, Protocol, Self, TypeVar

from sanic import Request

Expand Down Expand Up @@ -212,3 +212,12 @@ class Authenticator(Protocol[AnyAPIUser]):
async def authenticate(self, access_token: str, request: Request) -> AnyAPIUser:
"""Validates the user credentials (i.e. we can say that the user is a valid Renku user)."""
...


ResetType = NewType("ResetType", object)
"""This type represents that a value that may be None should be reset back to None or null.
This type should have only one instance, defined in the same file as this type.
"""

RESET: ResetType = ResetType(object())
"""The single instance of the ResetType, can be compared to similar to None, i.e. `if value is RESET`"""
38 changes: 10 additions & 28 deletions components/renku_data_services/session/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from renku_data_services.base_api.auth import authenticate, only_authenticated
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.base_models.validation import validated_json
from renku_data_services.session import apispec, models
from renku_data_services.session import apispec, converters, models
from renku_data_services.session.db import SessionRepository


Expand Down Expand Up @@ -76,9 +76,11 @@ def patch(self) -> BlueprintFactoryResponse:
async def _patch(
_: Request, user: base_models.APIUser, environment_id: ULID, body: apispec.EnvironmentPatch
) -> JSONResponse:
body_dict = body.model_dump(exclude_none=True)
update = converters.environment_update_from_patch(body)
environment = await self.session_repo.update_environment(
user=user, environment_id=environment_id, **body_dict
user=user,
environment_id=environment_id,
update=update,
)
return validated_json(apispec.Environment, environment)

Expand Down Expand Up @@ -169,34 +171,14 @@ def patch(self) -> BlueprintFactoryResponse:
async def _patch(
_: Request, user: base_models.APIUser, launcher_id: ULID, body: apispec.SessionLauncherPatch
) -> JSONResponse:
body_dict = body.model_dump(exclude_none=True, mode="json")
async with self.session_repo.session_maker() as session, session.begin():
current_launcher = await self.session_repo.get_launcher(user, launcher_id)
new_env: models.UnsavedEnvironment | None = None
if (
isinstance(body.environment, apispec.EnvironmentPatchInLauncher)
and current_launcher.environment.environment_kind == models.EnvironmentKind.GLOBAL
and body.environment.environment_kind == apispec.EnvironmentKind.CUSTOM
):
# This means that the global environment is being swapped for a custom one,
# so we have to create a brand new environment, but we have to validate here.
validated_env = apispec.EnvironmentPostInLauncher.model_validate(body_dict.pop("environment"))
new_env = models.UnsavedEnvironment(
name=validated_env.name,
description=validated_env.description,
container_image=validated_env.container_image,
default_url=validated_env.default_url,
port=validated_env.port,
working_directory=PurePosixPath(validated_env.working_directory),
mount_directory=PurePosixPath(validated_env.mount_directory),
uid=validated_env.uid,
gid=validated_env.gid,
environment_kind=models.EnvironmentKind(validated_env.environment_kind.value),
args=validated_env.args,
command=validated_env.command,
)
update = converters.launcher_update_from_patch(body, current_launcher)
launcher = await self.session_repo.update_launcher(
user=user, launcher_id=launcher_id, new_custom_environment=new_env, session=session, **body_dict
user=user,
launcher_id=launcher_id,
session=session,
update=update,
)
return validated_json(apispec.SessionLauncher, launcher)

Expand Down
83 changes: 83 additions & 0 deletions components/renku_data_services/session/converters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Code used to convert from/to apispec and models."""

from pathlib import PurePosixPath

from renku_data_services.base_models.core import RESET, ResetType
from renku_data_services.session import apispec, models


def environment_update_from_patch(data: apispec.EnvironmentPatch) -> models.EnvironmentUpdate:
"""Create an update object from an apispec or any other pydantic model."""
data_dict = data.model_dump(exclude_unset=True, mode="json")
working_directory: PurePosixPath | None = None
if data.working_directory is not None:
working_directory = PurePosixPath(data.working_directory)
mount_directory: PurePosixPath | None = None
if data.mount_directory is not None:
mount_directory = PurePosixPath(data.mount_directory)
# NOTE: If the args or command are present in the data_dict and they are None they were passed in by the user.
# The None specifically passed by the user indicates that the value should be removed from the DB.
args = RESET if "args" in data_dict and data_dict["args"] is None else data.args
command = RESET if "command" in data_dict and data_dict["command"] is None else data.command
return models.EnvironmentUpdate(
name=data.name,
description=data.description,
container_image=data.container_image,
default_url=data.default_url,
port=data.port,
working_directory=working_directory,
mount_directory=mount_directory,
uid=data.uid,
gid=data.gid,
args=args,
command=command,
)


def launcher_update_from_patch(
data: apispec.SessionLauncherPatch,
current_launcher: models.SessionLauncher | None = None,
) -> models.SessionLauncherUpdate:
"""Create an update object from an apispec or any other pydantic model."""
data_dict = data.model_dump(exclude_unset=True, mode="json")
environment: str | models.EnvironmentUpdate | models.UnsavedEnvironment | None = None
if (
isinstance(data.environment, apispec.EnvironmentPatchInLauncher)
and current_launcher is not None
and current_launcher.environment.environment_kind == models.EnvironmentKind.GLOBAL
and data.environment.environment_kind == apispec.EnvironmentKind.CUSTOM
):
# This means that the global environment is being swapped for a custom one,
# so we have to create a brand new environment, but we have to validate here.
validated_env = apispec.EnvironmentPostInLauncher.model_validate(data_dict["environment"])
environment = models.UnsavedEnvironment(
name=validated_env.name,
description=validated_env.description,
container_image=validated_env.container_image,
default_url=validated_env.default_url,
port=validated_env.port,
working_directory=PurePosixPath(validated_env.working_directory),
mount_directory=PurePosixPath(validated_env.mount_directory),
uid=validated_env.uid,
gid=validated_env.gid,
environment_kind=models.EnvironmentKind(validated_env.environment_kind.value),
args=validated_env.args,
command=validated_env.command,
)
elif isinstance(data.environment, apispec.EnvironmentPatchInLauncher):
environment = environment_update_from_patch(data.environment)
elif isinstance(data.environment, apispec.EnvironmentIdOnlyPatch):
environment = data.environment.id
resource_class_id: int | None | ResetType = None
if "resource_class_id" in data_dict and data_dict["resource_class_id"] is None:
# NOTE: This means that the resource class set in the DB should be removed so that the
# default resource class currently set in the CRC will be used.
resource_class_id = RESET
else:
resource_class_id = data_dict.get("resource_class_id")
return models.SessionLauncherUpdate(
name=data_dict.get("name"),
description=data_dict.get("description"),
environment=environment,
resource_class_id=resource_class_id,
)
150 changes: 69 additions & 81 deletions components/renku_data_services/session/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from collections.abc import Callable
from contextlib import AbstractAsyncContextManager, nullcontext
from typing import Any

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
Expand All @@ -14,6 +13,7 @@
from renku_data_services import errors
from renku_data_services.authz.authz import Authz, ResourceType
from renku_data_services.authz.models import Scope
from renku_data_services.base_models.core import RESET
from renku_data_services.crc.db import ResourcePoolRepository
from renku_data_services.session import models
from renku_data_services.session import orm as schemas
Expand Down Expand Up @@ -101,53 +101,59 @@ async def insert_environment(
await session.refresh(env)
return env.dump()

async def __update_environment(
def __update_environment(
self,
user: base_models.APIUser,
session: AsyncSession,
environment_id: ULID,
kind: models.EnvironmentKind,
**kwargs: dict,
) -> models.Environment:
res = await session.scalars(
select(schemas.EnvironmentORM)
.where(schemas.EnvironmentORM.id == str(environment_id))
.where(schemas.EnvironmentORM.environment_kind == kind.value)
)
environment = res.one_or_none()
if environment is None:
raise errors.MissingResourceError(message=f"Session environment with id '{environment_id}' does not exist.")

for key, value in kwargs.items():
# NOTE: Only some fields can be edited
if key in [
"name",
"description",
"container_image",
"default_url",
"port",
"working_directory",
"mount_directory",
"uid",
"gid",
"args",
"command",
]:
setattr(environment, key, value)

return environment.dump()
environment: schemas.EnvironmentORM,
update: models.EnvironmentUpdate,
) -> None:
# NOTE: this is more verbose than a loop and setattr but this way we get mypy type checks
if update.name is not None:
environment.name = update.name
if update.description is not None:
environment.description = update.description
if update.container_image is not None:
environment.container_image = update.container_image
if update.default_url is not None:
environment.default_url = update.default_url
if update.port is not None:
environment.port = update.port
if update.working_directory is not None:
environment.working_directory = update.working_directory
if update.mount_directory is not None:
environment.mount_directory = update.mount_directory
if update.uid is not None:
environment.uid = update.uid
if update.gid is not None:
environment.gid = update.gid
if update.args is RESET:
environment.args = None
elif isinstance(update.args, list):
environment.args = update.args
if update.command is RESET:
environment.command = None
elif isinstance(update.command, list):
environment.command = update.command

async def update_environment(
self, user: base_models.APIUser, environment_id: ULID, **kwargs: dict
self, user: base_models.APIUser, environment_id: ULID, update: models.EnvironmentUpdate
) -> models.Environment:
"""Update a global session environment entry."""
if not user.is_admin:
raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.")

async with self.session_maker() as session, session.begin():
return await self.__update_environment(
user, session, environment_id, models.EnvironmentKind.GLOBAL, **kwargs
res = await session.scalars(
select(schemas.EnvironmentORM)
.where(schemas.EnvironmentORM.id == str(environment_id))
.where(schemas.EnvironmentORM.environment_kind == models.EnvironmentKind.GLOBAL)
)
environment = res.one_or_none()
if environment is None:
raise errors.MissingResourceError(
message=f"Session environment with id '{environment_id}' does not exist."
)
self.__update_environment(environment, update)
return environment.dump()

async def delete_environment(self, user: base_models.APIUser, environment_id: ULID) -> None:
"""Delete a global session environment entry."""
Expand Down Expand Up @@ -297,9 +303,8 @@ async def update_launcher(
self,
user: base_models.APIUser,
launcher_id: ULID,
new_custom_environment: models.UnsavedEnvironment | None,
update: models.SessionLauncherUpdate,
session: AsyncSession | None = None,
**kwargs: Any,
) -> models.SessionLauncher:
"""Update a session launcher entry."""
if not user.is_authenticated or user.id is None:
Expand Down Expand Up @@ -333,8 +338,8 @@ async def update_launcher(
if not authorized:
raise errors.ForbiddenError(message="You do not have the required permissions for this operation.")

resource_class_id = kwargs.get("resource_class_id")
if resource_class_id is not None:
resource_class_id = update.resource_class_id
if isinstance(resource_class_id, int):
res = await session.scalars(
select(schemas.ResourceClassORM).where(schemas.ResourceClassORM.id == resource_class_id)
)
Expand All @@ -351,32 +356,32 @@ async def update_launcher(
message=f"You do not have access to resource class with id '{resource_class_id}'."
)

for key, value in kwargs.items():
# NOTE: Only some fields can be updated.
if key in [
"name",
"description",
"resource_class_id",
]:
setattr(launcher, key, value)

env_payload = kwargs.get("environment", {})
await self.__update_launcher_environment(user, launcher, session, new_custom_environment, **env_payload)
await session.flush()
await session.refresh(launcher)
# NOTE: Only some fields can be updated.
if update.name is not None:
launcher.name = update.name
if update.description is not None:
launcher.description = update.description
if isinstance(update.resource_class_id, int):
launcher.resource_class_id = update.resource_class_id
elif update.resource_class_id is RESET:
launcher.resource_class_id = None

if update.environment is None:
return launcher.dump()

await self.__update_launcher_environment(user, launcher, session, update.environment)
return launcher.dump()

async def __update_launcher_environment(
self,
user: base_models.APIUser,
launcher: schemas.SessionLauncherORM,
session: AsyncSession,
new_custom_environment: models.UnsavedEnvironment | None,
**kwargs: Any,
update: models.EnvironmentUpdate | models.UnsavedEnvironment | str,
) -> None:
current_env_kind = launcher.environment.environment_kind
match new_custom_environment, current_env_kind, kwargs:
case None, _, {"id": env_id, **nothing_else} if len(nothing_else) == 0:
match update, current_env_kind:
case str() as env_id, _:
# The environment in the launcher is set via ID, the new ID has to refer
# to an environment that is GLOBAL.
old_environment = launcher.environment
Expand All @@ -403,33 +408,16 @@ async def __update_launcher_environment(
# We remove the custom environment to avoid accumulating custom environments that are not associated
# with any launchers.
await session.delete(old_environment)
case None, models.EnvironmentKind.CUSTOM, {**rest} if (
rest.get("environment_kind") is None
or rest.get("environment_kind") == models.EnvironmentKind.CUSTOM.value
):
case models.EnvironmentUpdate(), models.EnvironmentKind.CUSTOM:
# Custom environment being updated
for key, val in rest.items():
# NOTE: Only some fields can be updated.
if key in [
"name",
"description",
"container_image",
"default_url",
"port",
"working_directory",
"mount_directory",
"uid",
"gid",
"args",
"command",
]:
setattr(launcher.environment, key, val)
case models.UnsavedEnvironment(), models.EnvironmentKind.GLOBAL, {**nothing_else} if (
len(nothing_else) == 0 and new_custom_environment.environment_kind == models.EnvironmentKind.CUSTOM
self.__update_environment(launcher.environment, update)
case models.UnsavedEnvironment() as new_custom_environment, models.EnvironmentKind.GLOBAL if (
new_custom_environment.environment_kind == models.EnvironmentKind.CUSTOM
):
# Global environment replaced by a custom one
new_env = await self.__insert_environment(user, session, new_custom_environment)
launcher.environment = new_env
await session.flush()
case _:
raise errors.ValidationError(
message="Encountered an invalid payload for updating a launcher environment", quiet=True
Expand Down
Loading