Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/11556.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `RequiredResourceSlotRule` to the `SessionSpec` validator chain so that session creation fails with `InvalidAPIParameters` when a kernel omits a globally required resource slot or requests a non-positive quantity for it.
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,13 @@ async def _fetch_scaling_group_with_slot_inventory(
active_slot_types=active_slot_types,
)

async def _fetch_required_slot_names(self, db_sess: SASession) -> frozenset[SlotName]:
    """Load the names of every resource slot type flagged as required.

    Args:
        db_sess: Database session used to execute the query.

    Returns:
        The ``slot_name`` of each ``ResourceSlotTypeRow`` whose ``required``
        column is true, wrapped as ``SlotName``. Empty when no row matches.
    """
    query = sa.select(ResourceSlotTypeRow.slot_name).where(
        ResourceSlotTypeRow.required.is_(True)
    )
    result = await db_sess.execute(query)
    return frozenset(map(SlotName, result.scalars().all()))
Comment on lines +348 to +353

async def _fetch_scaling_group(
self, db_sess: SASession, scaling_group: str
) -> ScalingGroupMeta:
Expand Down Expand Up @@ -1520,6 +1527,7 @@ async def fetch_session_spec_contexts(
rg_defaults = None
resource_group_allow_fractional = False
known_slot_types: Mapping[SlotName, SlotTypes] = {}
required_slot_names = await self._fetch_required_slot_names(db_sess)
if resource_group_name:
rg_bundle = await self._fetch_scaling_group_with_slot_inventory(
db_sess, resource_group_name
Expand Down Expand Up @@ -1686,6 +1694,7 @@ async def fetch_session_spec_contexts(
active_session_count=active_session_count,
keypair_resource_policy=keypair_policy,
known_slot_types=known_slot_types,
required_slot_names=required_slot_names,
)

async def pick_default_resource_group(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,5 @@ class SessionSpecContextFetch:
dotfile_data: DotfileBundle
keypair_resource_policy: Any | None # KeyPairResourcePolicyData
known_slot_types: Mapping[SlotName, SlotTypes] = field(default_factory=dict)
required_slot_names: frozenset[SlotName] = field(default_factory=frozenset)
active_session_count: int = 0
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
InferenceModelFolderRule,
MountNameValidationRule,
RequestedSlotTypeRule,
RequiredResourceSlotRule,
ResourceLimitRule,
ServicePortRule,
SessionSpecValidationContext,
Expand Down Expand Up @@ -129,6 +130,7 @@ def __init__(self, args: SchedulingControllerArgs) -> None:
ContainerLimitRule(),
ImageSlotTypeRule(),
RequestedSlotTypeRule(),
RequiredResourceSlotRule(),
ResourceLimitRule(),
ServicePortRule(),
MountNameValidationRule(),
Expand Down Expand Up @@ -188,6 +190,7 @@ async def enqueue_session_from_draft(
keypair_resource_policy=fetched.keypair_resource_policy,
image_infos=fetched.image_infos,
known_slot_types=fetched.known_slot_types,
required_slot_names=fetched.required_slot_names,
dotfile_data=fetched.dotfile_data,
active_session_count=fetched.active_session_count,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .inference_model_folder_rule import InferenceModelFolderRule
from .mount_name_validation_rule import MountNameValidationRule
from .requested_slot_type_rule import RequestedSlotTypeRule
from .required_resource_slot_rule import RequiredResourceSlotRule
from .resource_limit_rule import ResourceLimitRule
from .service_port_rule import ServicePortRule
from .session_spec_base import (
Expand All @@ -23,6 +24,7 @@
"InferenceModelFolderRule",
"MountNameValidationRule",
"RequestedSlotTypeRule",
"RequiredResourceSlotRule",
"ResourceLimitRule",
"ServicePortRule",
"SessionSpecValidationContext",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Required resource-slot validator.

Required slot names come from ``resource_slot_types.required``. The
preparer chain runs first, so image-minimum fallback for intrinsic slots
has already had a chance to fill ``cpu`` / ``mem`` before this rule checks
the finalized ``SessionSpec``.
"""

from __future__ import annotations

from decimal import Decimal

from ai.backend.manager.data.session.spec import SessionSpec
from ai.backend.manager.errors.api import InvalidAPIParameters
from ai.backend.manager.sokovan.scheduling_controller.resource_parse import parse_quantity
from ai.backend.manager.sokovan.scheduling_controller.validators.session_spec_base import (
SessionSpecValidationContext,
SessionSpecValidatorRule,
)


class RequiredResourceSlotRule(SessionSpecValidatorRule):
    """Reject session specs in which a kernel lacks a globally required slot.

    The set of required slot names is read from
    ``context.required_slot_names``. A slot counts as missing when the kernel
    has no resource entry for it or when the entry's parsed quantity is zero
    or negative.
    """

    def name(self) -> str:
        """Return this rule's name string."""
        return "required_resource_slot"

    def validate(
        self,
        spec: SessionSpec,
        context: SessionSpecValidationContext,
    ) -> None:
        """Check every kernel in ``spec`` against the required slot names.

        Args:
            spec: Session spec whose ``kernel_specs`` are inspected in order.
            context: Validation context supplying ``required_slot_names``.

        Raises:
            InvalidAPIParameters: For the first kernel that lacks a positive
                quantity for one or more required slots; the message lists
                the offending slot names in sorted order.
        """
        required_names = context.required_slot_names
        if required_names:
            zero = Decimal(0)
            for kernel_index, kernel_spec in enumerate(spec.kernel_specs):
                # Later entries for the same resource type replace earlier ones.
                quantities = {}
                for resource in kernel_spec.execution_spec.resources:
                    quantities[resource.resource_type] = parse_quantity(resource.quantity)

                unsatisfied = []
                for slot_name in required_names:
                    slot_key = str(slot_name)
                    # An absent slot is treated the same as a zero quantity.
                    if quantities.get(slot_key, zero) <= zero:
                        unsatisfied.append(slot_key)

                if not unsatisfied:
                    continue

                # Sort so the message is deterministic despite set iteration order.
                unsatisfied.sort()
                message = (
                    f"kernel_specs[{kernel_index}].execution_spec.resources is missing "
                    f"required resource slot(s): {unsatisfied}."
                )
                raise InvalidAPIParameters(extra_msg=message)
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class SessionSpecValidationContext:
keypair_resource_policy: KeyPairResourcePolicyData | None = None
image_infos: Mapping[ImageID, ImageInfo] = field(default_factory=dict)
known_slot_types: Mapping[SlotName, SlotTypes] = field(default_factory=dict)
required_slot_names: frozenset[SlotName] = field(default_factory=frozenset)
dotfile_data: DotfileBundle = field(default_factory=DotfileBundle)
active_session_count: int = 0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@
from ai.backend.manager.sokovan.scheduling_controller.validators.requested_slot_type_rule import (
RequestedSlotTypeRule,
)
from ai.backend.manager.sokovan.scheduling_controller.validators.required_resource_slot_rule import (
RequiredResourceSlotRule,
)
from ai.backend.manager.sokovan.scheduling_controller.validators.resource_limit_rule import (
ResourceLimitRule,
)
Expand Down Expand Up @@ -167,12 +170,14 @@ def _ctx(
keypair_policy: KeyPairResourcePolicyData | None = None,
image_infos: dict[ImageID, ImageInfo] | None = None,
known_slot_types: dict[SlotName, SlotTypes] | None = None,
required_slot_names: frozenset[SlotName] | None = None,
dotfile_data: DotfileBundle | None = None,
) -> SessionSpecValidationContext:
return SessionSpecValidationContext(
keypair_resource_policy=keypair_policy,
image_infos=image_infos or {},
known_slot_types=known_slot_types or {},
required_slot_names=required_slot_names or frozenset(),
dotfile_data=dotfile_data or DotfileBundle(),
)

Expand Down Expand Up @@ -506,3 +511,41 @@ def test_rejects_when_rg_has_no_active_agents(self) -> None:
spec = _spec((_kernel_with_resources(img, resources=(("cpu", "1"),)),))
with pytest.raises(InvalidAPIParameters):
RequestedSlotTypeRule().validate(spec, _ctx())


class TestRequiredResourceSlotRule:
    """Tests for ``RequiredResourceSlotRule`` with ``cpu`` and ``mem`` required."""

    @pytest.fixture
    def image_id(self) -> ImageID:
        """Provide a random image ID for the kernel under test."""
        return ImageID(uuid.uuid4())

    @pytest.fixture
    def required_slot_ctx(self) -> SessionSpecValidationContext:
        """Provide a validation context that requires both ``cpu`` and ``mem``."""
        required = frozenset({SlotName("cpu"), SlotName("mem")})
        return _ctx(required_slot_names=required)

    def test_passes_when_required_slots_present(
        self,
        image_id: ImageID,
        required_slot_ctx: SessionSpecValidationContext,
    ) -> None:
        """A kernel requesting positive ``cpu`` and ``mem`` passes validation."""
        kernel = _kernel_with_resources(
            image_id,
            resources=(("cpu", "1"), ("mem", "1073741824")),
        )
        spec = _spec((kernel,))
        rule = RequiredResourceSlotRule()
        rule.validate(spec, required_slot_ctx)

    @pytest.mark.parametrize(
        ("resources", "expected_missing_slot"),
        [
            # ``mem`` is absent altogether.
            ((("cpu", "1"),), "mem"),
            # ``cpu`` is present but with a zero quantity.
            ((("cpu", "0"), ("mem", "1073741824")), "cpu"),
        ],
    )
    def test_rejects_missing_or_zero_required_slot(
        self,
        image_id: ImageID,
        required_slot_ctx: SessionSpecValidationContext,
        resources: tuple[tuple[str, str], ...],
        expected_missing_slot: str,
    ) -> None:
        """An absent or zero-quantity required slot is named in the error."""
        kernel = _kernel_with_resources(image_id, resources=resources)
        spec = _spec((kernel,))
        rule = RequiredResourceSlotRule()
        with pytest.raises(InvalidAPIParameters, match=expected_missing_slot):
            rule.validate(spec, required_slot_ctx)
Loading