|
6 | 6 | from collections import defaultdict |
7 | 7 | from collections.abc import AsyncIterator, Mapping, Sequence |
8 | 8 | from contextlib import asynccontextmanager as actxmgr |
| 9 | +from dataclasses import dataclass |
9 | 10 | from datetime import datetime |
10 | 11 | from decimal import Decimal |
11 | 12 | from typing import TYPE_CHECKING, Any, cast |
@@ -187,6 +188,21 @@ def _create_resource_slot_from_policy( |
187 | 188 | return ResourceSlot.from_policy(resource_policy_map, cast(Mapping[str, Any], known_slot_types)) |
188 | 189 |
|
189 | 190 |
|
@dataclass(frozen=True)
class _ScalingGroupWithSlotInventory:
    """Pair a scaling-group row with the slot types its agents currently serve.

    The inventory is keyed by slot name; each value is the
    :class:`SlotTypes` unit registered for that slot. Only slots served by
    a non-terminated agent of this scaling group are present. The
    validator chain uses the mapping in two ways: as a membership test
    (a request for a slot the RG does not provide is rejected) and as a
    source of unit metadata (values are humanized when errors are
    formatted).
    """

    # The scaling group (resource group) row the inventory belongs to.
    sg_row: ScalingGroupRow
    # Slot name -> registered slot type unit, for slots served by live agents.
    active_slot_types: Mapping[SlotName, SlotTypes]
| 205 | + |
190 | 206 | class ScheduleDBSource: |
191 | 207 | """ |
192 | 208 | Database source for schedule-related operations. |
@@ -289,6 +305,46 @@ async def get_scheduling_data(self, scaling_group: str, spec: SchedulingSpec) -> |
289 | 305 | spec=spec, |
290 | 306 | ) |
291 | 307 |
|
async def _fetch_scaling_group_with_slot_inventory(
    self,
    db_sess: SASession,
    name: str,
) -> _ScalingGroupWithSlotInventory:
    """Fetch one scaling group plus the slot types its live agents serve.

    The ``agents`` -> ``agent_resource_rows`` -> ``slot_type_row`` chain is
    eager-loaded with ``selectinload``. Agents whose status is TERMINATED
    are skipped; every resource row of the remaining agents contributes
    one ``slot_name -> SlotTypes`` entry. Only the scaling-group row and
    the derived mapping are returned; the ``AgentRow`` objects are not
    handed to the caller.

    Raises:
        ScalingGroupNotFound: no scaling group is registered under ``name``.
    """
    inventory_loader = (
        selectinload(ScalingGroupRow.agents)
        .selectinload(AgentRow.agent_resource_rows)
        .selectinload(AgentResourceRow.slot_type_row)
    )
    stmt = (
        sa.select(ScalingGroupRow)
        .options(inventory_loader)
        .where(ScalingGroupRow.name == name)
    )
    result = await db_sess.scalars(stmt)
    sg_row = result.one_or_none()
    if sg_row is None:
        raise ScalingGroupNotFound(f"Resource group {name} not found")

    active_slot_types: dict[SlotName, SlotTypes] = {}
    for agent in sg_row.agents:
        if agent.status != AgentStatus.TERMINATED:
            # A slot seen on several agents is written once per row;
            # the last row visited determines the stored value.
            for resource_row in agent.agent_resource_rows:
                slot_name = SlotName(resource_row.slot_name)
                active_slot_types[slot_name] = SlotTypes(
                    resource_row.slot_type_row.slot_type
                )
    return _ScalingGroupWithSlotInventory(
        sg_row=sg_row,
        active_slot_types=active_slot_types,
    )
| 347 | + |
292 | 348 | async def _fetch_scaling_group( |
293 | 349 | self, db_sess: SASession, scaling_group: str |
294 | 350 | ) -> ScalingGroupMeta: |
@@ -1463,16 +1519,13 @@ async def fetch_session_spec_contexts( |
1463 | 1519 | network_info: ScalingGroupNetworkInfo | None = None |
1464 | 1520 | rg_defaults = None |
1465 | 1521 | resource_group_allow_fractional = False |
| 1522 | + known_slot_types: Mapping[SlotName, SlotTypes] = {} |
1466 | 1523 | if resource_group_name: |
1467 | | - sg_row = ( |
1468 | | - await db_sess.scalars( |
1469 | | - sa.select(ScalingGroupRow).where( |
1470 | | - ScalingGroupRow.name == resource_group_name |
1471 | | - ) |
1472 | | - ) |
1473 | | - ).one_or_none() |
1474 | | - if sg_row is None: |
1475 | | - raise ScalingGroupNotFound(f"Resource group {resource_group_name} not found") |
| 1524 | + rg_bundle = await self._fetch_scaling_group_with_slot_inventory( |
| 1525 | + db_sess, resource_group_name |
| 1526 | + ) |
| 1527 | + sg_row = rg_bundle.sg_row |
| 1528 | + known_slot_types = rg_bundle.active_slot_types |
1476 | 1529 | # Every production caller of ``enqueue_session_from_draft`` populates |
1477 | 1530 | # access_key/domain_name/project_id alongside resource_group_name; this |
1478 | 1531 | # branch flags the contract violation rather than letting the RG |
@@ -1632,6 +1685,7 @@ async def fetch_session_spec_contexts( |
1632 | 1685 | dotfile_data=dotfile_bundle, |
1633 | 1686 | active_session_count=active_session_count, |
1634 | 1687 | keypair_resource_policy=keypair_policy, |
| 1688 | + known_slot_types=known_slot_types, |
1635 | 1689 | ) |
1636 | 1690 |
|
1637 | 1691 | async def pick_default_resource_group( |
|
0 commit comments