Skip to content

Commit 9f12418

Browse files
authored
fix: use protocol for SessionTasks (#1165)
Closes #1163. This should improve type safety with the SessionTasks Python class.
1 parent 4c9e7a1 commit 9f12418

3 files changed

Lines changed: 41 additions & 11 deletions

File tree

bases/renku_data_services/data_tasks/dependencies.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,9 @@ def from_env(cls, cfg: Config | None = None) -> "DependencyManager":
5959
metrics=metrics,
6060
authz=authz,
6161
)
62-
session_repo = SessionRepository(
62+
session_environment_repo = SessionRepository.make_session_environment_repo(
6363
session_maker=cfg.db.async_session_maker,
6464
project_authz=authz,
65-
resource_pools=None, # type: ignore
66-
shipwright_client=None,
67-
builds_config=None, # type: ignore
6865
)
6966
syncer = UsersSync(
7067
cfg.db.async_session_maker,
@@ -73,7 +70,7 @@ def from_env(cls, cfg: Config | None = None) -> "DependencyManager":
7370
metrics=metrics,
7471
authz=authz,
7572
)
76-
session_tasks = SessionTasks(session_repo=session_repo)
73+
session_tasks = SessionTasks(session_environment_repo=session_environment_repo)
7774
kc_api: IKeycloakAPI
7875
if cfg.dummy_stores:
7976
dummy_users = [

components/renku_data_services/session/db.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Callable
66
from contextlib import AbstractAsyncContextManager, nullcontext
77
from datetime import UTC, datetime
8-
from typing import TYPE_CHECKING
8+
from typing import TYPE_CHECKING, Protocol
99

1010
from sqlalchemy import select
1111
from sqlalchemy.ext.asyncio import AsyncSession
@@ -28,7 +28,21 @@
2828
from renku_data_services.session.config import BuildsConfig
2929

3030

31-
class SessionRepository:
31+
class SessionEnvironmentRepositoryProtocol(Protocol):
32+
"""Protocol for operations on session environments."""
33+
34+
async def get_environments(self, include_archived: bool = False) -> list[models.Environment]:
35+
"""Get all global session environments from the database."""
36+
...
37+
38+
async def insert_environment(
39+
self, user: base_models.APIUser, environment: models.UnsavedEnvironment
40+
) -> models.Environment:
41+
"""Insert a new session environment."""
42+
...
43+
44+
45+
class SessionRepository(SessionEnvironmentRepositoryProtocol):
3246
"""Repository for sessions."""
3347

3448
def __init__(
@@ -1160,3 +1174,22 @@ async def _get_environment_authorization(
11601174
if launcher:
11611175
authorized = await self.project_authz.has_permission(user, ResourceType.project, launcher.project_id, scope)
11621176
return authorized
1177+
1178+
@classmethod
1179+
def make_session_environment_repo(
1180+
cls,
1181+
session_maker: Callable[..., AsyncSession],
1182+
project_authz: Authz,
1183+
) -> SessionEnvironmentRepositoryProtocol:
1184+
"""Create an instance of SessionEnvironmentRepositoryProtocol."""
1185+
# NOTE: resource_pools, shipwright_client and builds_config are set to None
1186+
# because the SessionEnvironmentRepositoryProtocol only exposes database
1187+
# operations for session environments.
1188+
instance = cls(
1189+
session_maker,
1190+
project_authz=project_authz,
1191+
resource_pools=None, # type: ignore
1192+
shipwright_client=None,
1193+
builds_config=None, # type: ignore
1194+
)
1195+
return instance

components/renku_data_services/session/tasks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,29 @@
66

77
from renku_data_services.app_config import logging
88
from renku_data_services.base_models.core import InternalServiceAdmin
9-
from renku_data_services.session.db import SessionRepository
9+
from renku_data_services.session.db import SessionEnvironmentRepositoryProtocol
1010
from renku_data_services.session.models import EnvironmentImageSource, EnvironmentKind, UnsavedEnvironment
1111

1212

1313
@dataclass(kw_only=True)
1414
class SessionTasks:
1515
"""Task definitions for sessions."""
1616

17-
session_repo: SessionRepository
17+
session_environment_repo: SessionEnvironmentRepositoryProtocol
1818

1919
async def initialize_session_environments_task(self, requested_by: InternalServiceAdmin) -> None:
2020
"""Initialize session environments."""
2121
logger = logging.getLogger(self.__class__.__name__)
2222

2323
# Skip this task if global session environments already exist
24-
existing_envs = await self.session_repo.get_environments()
24+
existing_envs = await self.session_environment_repo.get_environments()
2525
if existing_envs:
2626
logger.debug("Global session environments are already initialized.")
2727
return None
2828

2929
for env in self._get_default_session_environments():
3030
try:
31-
await self.session_repo.insert_environment(user=requested_by, environment=env)
31+
await self.session_environment_repo.insert_environment(user=requested_by, environment=env)
3232
logger.info(f"Added global environment with image {env.container_image}")
3333
except Exception as err:
3434
logger.error(f"Failed to create global environment with image {env.container_image} because {err}")

0 commit comments

Comments
 (0)