Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions components/renku_data_services/k8s/watcher/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,29 +196,27 @@ async def collect_metrics(
await metrics.session_stopped(user=user, metadata={"session_id": new_obj.meta.name})
return
previous_state = previous_obj.manifest.get("status", {}).get("state", None) if previous_obj else None

resource_class_id = int(new_obj.obj.metadata.annotations.get("renku.io/resource_class_id"))
resource_pool = await rp_repo.get_resource_pool_from_class(user, resource_class_id)
resource_class = await rp_repo.get_resource_class(user, resource_class_id)
metadata = {
"cpu": int(resource_class.cpu * 1000),
"memory": resource_class.memory,
"gpu": resource_class.gpu,
"storage": new_obj.obj.spec.session.storage.size,
"resource_class_id": resource_class_id,
"resource_pool_id": resource_pool.id or "",
"resource_class_name": f"{resource_pool.name}.{resource_class.name}",
"session_id": new_obj.meta.name,
}
match new_obj.obj.raw.get("status", {}).get("state"):
case State.Running.value if previous_state is None or previous_state == State.NotReady.value:
# session starting
resource_class_id = int(new_obj.obj.metadata.annotations.get("renku.io/resource_class_id"))
resource_pool = await rp_repo.get_resource_pool_from_class(k8s_watcher_admin_user, resource_class_id)
resource_class = await rp_repo.get_resource_class(k8s_watcher_admin_user, resource_class_id)

await metrics.session_started(
user=user,
metadata={
"cpu": int(resource_class.cpu * 1000),
"memory": resource_class.memory,
"gpu": resource_class.gpu,
"storage": new_obj.obj.spec.session.storage.size,
"resource_class_id": resource_class_id,
"resource_pool_id": resource_pool.id or "",
"resource_class_name": f"{resource_pool.name}.{resource_class.name}",
"session_id": new_obj.meta.name,
},
)
await metrics.session_started(user=user, metadata=metadata)
case State.Running.value | State.NotReady.value if previous_state == State.Hibernated.value:
# session resumed
await metrics.session_resumed(user, metadata={"session_id": new_obj.meta.name})
await metrics.session_resumed(user=user, metadata=metadata)
case State.Hibernated.value if previous_state != State.Hibernated.value:
# session hibernated
await metrics.session_hibernated(user=user, metadata={"session_id": new_obj.meta.name})
Expand Down
28 changes: 28 additions & 0 deletions components/renku_data_services/metrics/db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Repository for the metrics staging table."""

import asyncio
from collections.abc import AsyncGenerator, Callable
from typing import Any

Expand Down Expand Up @@ -31,9 +32,36 @@ async def get_unprocessed_metrics(self) -> AsyncGenerator[MetricsORM, None]:
async for metrics in result:
yield metrics

async def delete_all_metrics(self) -> None:
    """Remove every row from the metrics staging table.

    Opens a session from ``self.session_maker`` and issues one unfiltered
    ``DELETE`` against ``MetricsORM`` inside a ``session.begin()`` block.
    """
    async with self.session_maker() as db:
        async with db.begin():
            # No WHERE clause: this wipes the whole staging table.
            wipe_stmt = delete(MetricsORM)
            await db.execute(wipe_stmt)

async def delete_processed_metrics(self, metrics_ids: list[ULID]) -> None:
    """Remove the given metrics events from the staging table.

    Args:
        metrics_ids: IDs of the rows to delete. When the list is empty the
            method returns without opening a session.
    """
    if metrics_ids:
        async with self.session_maker() as db:
            async with db.begin():
                # Restrict the delete to exactly the IDs handed in by the caller.
                purge_stmt = delete(MetricsORM).where(MetricsORM.id.in_(metrics_ids))
                await db.execute(purge_stmt)

async def wait_for_metrics(self, timeout: float = 5.0, poll_interval: float = 0.1) -> bool:
    """Wait until at least one unprocessed metrics event is available.

    Repeatedly queries ``get_unprocessed_metrics`` and returns as soon as it
    yields at least one event. The query always runs at least once, even when
    ``timeout`` is zero or negative, and once more when the deadline is
    reached, so events that show up during the last sleep are not missed.

    Args:
        timeout: Maximum time to wait in seconds. Values <= 0 mean
            "check once and return immediately".
        poll_interval: Time between polls in seconds. The final sleep is
            shortened so the total wait does not run past ``timeout``.

    Returns:
        True if metrics were found, False if timeout reached
    """
    loop = asyncio.get_running_loop()
    # loop.time() is the event loop's monotonic clock, so the deadline is
    # immune to wall-clock adjustments.
    deadline = loop.time() + max(timeout, 0.0)
    while True:
        # The generator is drained completely instead of stopping at the first
        # item, so it always runs to its natural end (presumably releasing the
        # DB session it holds -- confirm against get_unprocessed_metrics).
        pending = [m async for m in self.get_unprocessed_metrics()]
        if pending:
            return True
        remaining = deadline - loop.time()
        if remaining <= 0:
            return False
        # Never sleep past the deadline; a non-positive interval just yields.
        await asyncio.sleep(max(0.0, min(poll_interval, remaining)))
43 changes: 32 additions & 11 deletions components/renku_data_services/notebooks/core_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,10 +1138,9 @@ async def patch_session(
):
# Session is being resumed
patch.spec.hibernated = False
await metrics.user_requested_session_resume(user, metadata={"session_id": session_id})

rp: ResourcePool | None = None
# Resource class
# Resource class is being changed
if body.resource_class_id is not None:
new_cluster = await nb_config.k8s_v2_client.cluster_by_class_id(body.resource_class_id, user)
if new_cluster.id != cluster.id:
Expand All @@ -1153,10 +1152,25 @@ async def patch_session(
)
rp = await rp_repo.get_resource_pool_from_class(user, body.resource_class_id)
rc = rp.get_resource_class(body.resource_class_id)
if not rc:
raise errors.MissingResourceError(
message=f"The resource class you requested with ID {body.resource_class_id} does not exist"
)
# Resource class is not being changed but we still need to get the resource pool and class for patching
# in case they changed since the session was created
else:
rp = await rp_repo.get_resource_pool_from_class(user, session.resource_class_id())
rc = rp.get_resource_class(session.resource_class_id())

if not rc:
raise errors.MissingResourceError(
message=f"The resource class you requested with ID {body.resource_class_id} does not exist"
)
# If the session is being hibernated we do not need to patch anything else that is
# not specifically called for in the request body, we can refresh things when the user resumes.
if is_getting_hibernated:
return await nb_config.k8s_v2_client.patch_session(session_id, user.id, patch.to_rfc7386())

# If the session is being resumed, we need to patch the resource requests/limits to match the current
# values of the resource class since they might have changed since the session was created.
# We also patch the annotations for the resource pool and class to make sure they are up to date.
else:
if not patch.metadata:
patch.metadata = AmaltheaSessionV1Alpha1MetadataPatch()
# Patch the resource pool and class ID in the annotations
Expand Down Expand Up @@ -1186,14 +1200,21 @@ async def patch_session(
patch.spec.service_account_name = (
rp.cluster.service_account_name if rp.cluster.service_account_name is not None else RESET
)
await metrics.user_requested_session_resume(
user,
metadata={
"cpu": int(rc.cpu * 1000),
"memory": rc.memory,
"gpu": rc.gpu,
"resource_class_id": str(rc.id),
"resource_pool_id": str(rp.id) or "",
"resource_class_name": f"{rp.name}.{rc.name}",
"session_id": session_id,
},
)

patch.spec.culling = get_culling_patch(user, rp, nb_config, body.lastInteraction)

# If the session is being hibernated we do not need to patch anything else that is
# not specifically called for in the request body, we can refresh things when the user resumes.
if is_getting_hibernated:
return await nb_config.k8s_v2_client.patch_session(session_id, user.id, patch.to_rfc7386())

server_name = session.metadata.name
launcher = await session_repo.get_launcher(user, session.launcher_id)
project = await project_repo.get_project(user=user, project_id=session.project_id)
Expand Down
Loading
Loading