Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/11515.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Reject session requests whose image or caller declares a resource slot the target resource group does not provide, returning a clear 4xx instead of failing internally.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import defaultdict
from collections.abc import AsyncIterator, Mapping, Sequence
from contextlib import asynccontextmanager as actxmgr
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
from typing import TYPE_CHECKING, Any, cast
Expand Down Expand Up @@ -187,6 +188,21 @@ def _create_resource_slot_from_policy(
return ResourceSlot.from_policy(resource_policy_map, cast(Mapping[str, Any], known_slot_types))


@dataclass(frozen=True)
class _ScalingGroupWithSlotInventory:
"""Scaling group bundled with the slot inventory served by its agents.

``active_slot_types`` maps each slot name served by a non-terminated
agent in this scaling group to its registered :class:`SlotTypes`
unit. The validator chain consults this map both for membership
(reject requests for slots the RG does not provide) and for unit
metadata (humanize values during error formatting).
"""

sg_row: ScalingGroupRow
active_slot_types: Mapping[SlotName, SlotTypes]


class ScheduleDBSource:
"""
Database source for schedule-related operations.
Expand Down Expand Up @@ -289,6 +305,46 @@ async def get_scheduling_data(self, scaling_group: str, spec: SchedulingSpec) ->
spec=spec,
)

async def _fetch_scaling_group_with_slot_inventory(
    self,
    db_sess: SASession,
    name: str,
) -> _ScalingGroupWithSlotInventory:
    """Fetch one scaling group plus the slot inventory of its live agents.

    The scaling group is selected by ``name`` with the relationship chain
    ``agents`` -> ``agent_resource_rows`` -> ``slot_type_row`` eager-loaded
    through ``selectinload``. Agents whose status is ``TERMINATED`` are
    skipped in Python; the resource rows of every other agent are folded
    into a ``{SlotName: SlotTypes}`` mapping. Only the scaling group row
    and that mapping are returned — the agent rows are not part of the
    result.

    Args:
        db_sess: Open database session used to run the query.
        name: Name of the scaling group (resource group) to load.

    Returns:
        The scaling group row bundled with its active slot inventory.

    Raises:
        ScalingGroupNotFound: when no scaling group has the given name.
    """
    eager_inventory = (
        selectinload(ScalingGroupRow.agents)
        .selectinload(AgentRow.agent_resource_rows)
        .selectinload(AgentResourceRow.slot_type_row)
    )
    stmt = (
        sa.select(ScalingGroupRow)
        .options(eager_inventory)
        .where(ScalingGroupRow.name == name)
    )
    result = await db_sess.scalars(stmt)
    sg_row = result.one_or_none()
    if sg_row is None:
        raise ScalingGroupNotFound(f"Resource group {name} not found")

    inventory: dict[SlotName, SlotTypes] = {}
    for agent in sg_row.agents:
        if agent.status == AgentStatus.TERMINATED:
            # Terminated agents no longer contribute hardware to the RG.
            continue
        for resource_row in agent.agent_resource_rows:
            # Several agents may serve the same slot; a later row simply
            # overwrites the earlier entry for that slot name.
            slot_unit = SlotTypes(resource_row.slot_type_row.slot_type)
            inventory[SlotName(resource_row.slot_name)] = slot_unit

    return _ScalingGroupWithSlotInventory(
        sg_row=sg_row,
        active_slot_types=inventory,
    )

async def _fetch_scaling_group(
self, db_sess: SASession, scaling_group: str
) -> ScalingGroupMeta:
Expand Down Expand Up @@ -1463,16 +1519,13 @@ async def fetch_session_spec_contexts(
network_info: ScalingGroupNetworkInfo | None = None
rg_defaults = None
resource_group_allow_fractional = False
known_slot_types: Mapping[SlotName, SlotTypes] = {}
if resource_group_name:
sg_row = (
await db_sess.scalars(
sa.select(ScalingGroupRow).where(
ScalingGroupRow.name == resource_group_name
)
)
).one_or_none()
if sg_row is None:
raise ScalingGroupNotFound(f"Resource group {resource_group_name} not found")
rg_bundle = await self._fetch_scaling_group_with_slot_inventory(
db_sess, resource_group_name
)
sg_row = rg_bundle.sg_row
known_slot_types = rg_bundle.active_slot_types
# Every production caller of ``enqueue_session_from_draft`` populates
# access_key/domain_name/project_id alongside resource_group_name; this
# branch flags the contract violation rather than letting the RG
Expand Down Expand Up @@ -1632,6 +1685,7 @@ async def fetch_session_spec_contexts(
dotfile_data=dotfile_bundle,
active_session_count=active_session_count,
keypair_resource_policy=keypair_policy,
known_slot_types=known_slot_types,
)

async def pick_default_resource_group(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Types for session creation and enqueueing."""

from collections.abc import Mapping
from dataclasses import dataclass, field
from decimal import Decimal
from typing import Any
Expand All @@ -10,6 +11,8 @@
from ai.backend.common.types import (
AccessKey,
SessionId,
SlotName,
SlotTypes,
VFolderMount,
)
from ai.backend.manager.data.dotfile.types import DotfileBundle
Expand Down Expand Up @@ -143,4 +146,5 @@ class SessionSpecContextFetch:
vfolder_mounts_by_role: dict[str, tuple[VFolderMount, ...]]
dotfile_data: DotfileBundle
keypair_resource_policy: Any | None # KeyPairResourcePolicyData
known_slot_types: Mapping[SlotName, SlotTypes] = field(default_factory=dict)
active_session_count: int = 0
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@
ConcurrentSessionLimitRule,
ContainerLimitRule,
DotfileVFolderConflictRule,
ImageSlotTypeRule,
InferenceModelFolderRule,
MountNameValidationRule,
RequestedSlotTypeRule,
ResourceLimitRule,
ServicePortRule,
SessionSpecValidationContext,
Expand Down Expand Up @@ -125,6 +127,8 @@ def __init__(self, args: SchedulingControllerArgs) -> None:
self._spec_validator = SessionSpecValidator([
ConcurrentSessionLimitRule(),
ContainerLimitRule(),
ImageSlotTypeRule(),
RequestedSlotTypeRule(),
ResourceLimitRule(),
ServicePortRule(),
MountNameValidationRule(),
Expand Down Expand Up @@ -161,9 +165,6 @@ async def enqueue_session_from_draft(
allowed_vfolder_types = list(
await self._config_provider.legacy_etcd_config_loader.get_vfolder_types()
)
known_slot_types = (
await self._config_provider.legacy_etcd_config_loader.get_resource_slots()
)

with self._metric_observer.measure_phase(
"scheduling_controller", rg_name, "spec_fetch_contexts"
Expand All @@ -186,7 +187,7 @@ async def enqueue_session_from_draft(
val_ctx = SessionSpecValidationContext(
keypair_resource_policy=fetched.keypair_resource_policy,
image_infos=fetched.image_infos,
known_slot_types=known_slot_types,
known_slot_types=fetched.known_slot_types,
dotfile_data=fetched.dotfile_data,
active_session_count=fetched.active_session_count,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from .concurrent_session_limit_rule import ConcurrentSessionLimitRule
from .container_limit_rule import ContainerLimitRule
from .dotfile_vfolder_conflict_rule import DotfileVFolderConflictRule
from .image_slot_type_rule import ImageSlotTypeRule
from .inference_model_folder_rule import InferenceModelFolderRule
from .mount_name_validation_rule import MountNameValidationRule
from .requested_slot_type_rule import RequestedSlotTypeRule
from .resource_limit_rule import ResourceLimitRule
from .service_port_rule import ServicePortRule
from .session_spec_base import (
Expand All @@ -17,8 +19,10 @@
"ConcurrentSessionLimitRule",
"ContainerLimitRule",
"DotfileVFolderConflictRule",
"ImageSlotTypeRule",
"InferenceModelFolderRule",
"MountNameValidationRule",
"RequestedSlotTypeRule",
"ResourceLimitRule",
"ServicePortRule",
"SessionSpecValidationContext",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Image-declared slot-type compatibility validator.

Every slot key declared in an image's ``resource_spec`` must be served
by some non-terminated agent in the requested resource group. The
context's ``known_slot_types`` is sourced from
``agent_resources`` joined with ``agents`` (status != TERMINATED) and
``resource_slot_types``, so it reflects the RG's hardware inventory and
the registered unit metadata in one mapping.

When the RG has no non-terminated agents the request is rejected
outright — an empty inventory cannot satisfy any image declaration and
would otherwise let the session reach the scheduler only to fail there.
"""

from __future__ import annotations

from ai.backend.manager.data.session.spec import SessionSpec
from ai.backend.manager.errors.api import InvalidAPIParameters
from ai.backend.manager.sokovan.scheduling_controller.validators.session_spec_base import (
SessionSpecValidationContext,
SessionSpecValidatorRule,
)


class ImageSlotTypeRule(SessionSpecValidatorRule):
    """Reject specs whose images declare slots the target RG does not serve.

    The rule compares the slot keys found in each kernel image's
    ``resource_spec`` against ``context.known_slot_types`` and raises on
    the first kernel that declares a slot missing from that mapping.
    """

    def name(self) -> str:
        """Return the identifier of this rule."""
        return "image_slot_type"

    def validate(
        self,
        spec: SessionSpec,
        context: SessionSpecValidationContext,
    ) -> None:
        """Check every kernel's image against the resource group inventory.

        Args:
            spec: Session spec whose ``kernel_specs`` are inspected.
            context: Validation context providing ``known_slot_types`` and
                ``image_infos``.

        Raises:
            InvalidAPIParameters: when the inventory is empty, or when an
                image declares a slot key absent from the inventory.
        """
        served = context.known_slot_types
        if not served:
            # An empty inventory cannot satisfy any image declaration.
            raise InvalidAPIParameters(
                extra_msg=(
                    f"resource group '{spec.scope.resource_group_name}' has no "
                    f"agents serving any resource slot."
                ),
            )

        for position, kernel_spec in enumerate(spec.kernel_specs):
            image = context.image_infos.get(kernel_spec.execution_spec.image_id)
            if image is None:
                # No image info was fetched for this kernel; nothing to compare.
                continue
            unserved = [slot for slot in image.resource_spec if slot not in served]
            if not unserved:
                continue
            # Sorted so the reported slot list is deterministic.
            unserved.sort()
            raise InvalidAPIParameters(
                extra_msg=(
                    f"kernel_specs[{position}]: image '{image.canonical}' "
                    f"requires resource slot(s) {unserved} that resource "
                    f"group '{spec.scope.resource_group_name}' does not "
                    f"serve. Pick an image whose required slots are "
                    f"available here, or switch to a resource group that "
                    f"supports these slots."
                ),
            )
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""User-requested slot-type compatibility validator.

Every ``resource_type`` in a kernel's requested resource list must be
served by some non-terminated agent in the requested resource group.
The context's ``known_slot_types`` is sourced from ``agent_resources``
joined with ``agents`` (status != TERMINATED) and
``resource_slot_types``, so it reflects the RG's hardware inventory and
the registered unit metadata in one mapping.

When the RG has no non-terminated agents the request is rejected
outright — an empty inventory cannot satisfy any caller-supplied
request and would otherwise let the session reach the scheduler only
to fail there.
"""

from __future__ import annotations

from ai.backend.manager.data.session.spec import SessionSpec
from ai.backend.manager.errors.api import InvalidAPIParameters
from ai.backend.manager.sokovan.scheduling_controller.validators.session_spec_base import (
SessionSpecValidationContext,
SessionSpecValidatorRule,
)


class RequestedSlotTypeRule(SessionSpecValidatorRule):
    """Reject specs that request slots the target RG does not serve.

    The rule compares the ``resource_type`` of every requested resource
    entry against ``context.known_slot_types`` and raises on the first
    kernel that asks for a slot missing from that mapping.
    """

    def name(self) -> str:
        """Return the identifier of this rule."""
        return "requested_slot_type"

    def validate(
        self,
        spec: SessionSpec,
        context: SessionSpecValidationContext,
    ) -> None:
        """Check every kernel's requested resources against the RG inventory.

        Args:
            spec: Session spec whose ``kernel_specs`` are inspected.
            context: Validation context providing ``known_slot_types``.

        Raises:
            InvalidAPIParameters: when the inventory is empty, or when a
                kernel requests a slot key absent from the inventory.
        """
        served = context.known_slot_types
        if not served:
            # An empty inventory cannot satisfy any caller-supplied request.
            raise InvalidAPIParameters(
                extra_msg=(
                    f"resource group '{spec.scope.resource_group_name}' has no "
                    f"agents serving any resource slot."
                ),
            )

        for position, kernel_spec in enumerate(spec.kernel_specs):
            # Collect into a set first so a slot requested twice is
            # reported only once.
            unserved_set = set()
            for entry in kernel_spec.execution_spec.resources:
                if entry.resource_type not in served:
                    unserved_set.add(entry.resource_type)
            unserved = sorted(unserved_set)
            if not unserved:
                continue
            # NOTE(review): a reviewer on this change asked whether a kernel
            # identifier should be reported here instead of the positional
            # ``kernel_specs[idx]`` index; the positional form is kept as-is.
            raise InvalidAPIParameters(
                extra_msg=(
                    f"kernel_specs[{position}]: the request asks for resource "
                    f"slot(s) {unserved} that resource group "
                    f"'{spec.scope.resource_group_name}' does not serve. "
                    f"Drop these slots from the request or switch to a "
                    f"resource group that supports them."
                ),
            )
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
SessionSpecDraft,
)
from ai.backend.manager.errors.api import InvalidAPIParameters
from ai.backend.manager.models.agent import AgentRow
from ai.backend.manager.models.resource_slot import AgentResourceRow, ResourceSlotTypeRow
from ai.backend.manager.models.scaling_group import ScalingGroupOpts, ScalingGroupRow
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
from ai.backend.manager.repositories.scheduler.db_source.db_source import ScheduleDBSource
Expand Down Expand Up @@ -110,7 +112,13 @@ async def db_with_rg(
short-circuit on ``ScalingGroupNotFound`` and never exercise the
invariant under test.
"""
async with with_tables(database_connection, [ScalingGroupRow]):
# Include the agent tables so the SG fetch's
# ``selectinload(agents).selectinload(agent_resource_rows)`` chain
# has tables to query, even though we seed no rows below.
async with with_tables(
database_connection,
[ScalingGroupRow, ResourceSlotTypeRow, AgentRow, AgentResourceRow],
):
async with database_connection.begin_session() as db_sess:
db_sess.add(
ScalingGroupRow(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@ def _fetch_bundle(image_id: ImageID) -> SessionSpecContextFetch:
vfolder_mounts_by_role={"main": (_vfolder_mount(),)},
dotfile_data=DotfileBundle(),
keypair_resource_policy=_keypair_policy(),
known_slot_types={
SlotName("cpu"): SlotTypes.COUNT,
SlotName("mem"): SlotTypes.BYTES,
},
)


Expand Down
Loading
Loading