diff --git a/changes/11556.feature.md b/changes/11556.feature.md new file mode 100644 index 00000000000..b458fa3d01e --- /dev/null +++ b/changes/11556.feature.md @@ -0,0 +1 @@ +Add RequiredResourceSlotRule to the SessionSpec validator chain so session creation fails with InvalidAPIParameters when a kernel omits a globally required resource slot diff --git a/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py b/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py index 70fdac4f802..dddab057169 100644 --- a/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py @@ -345,6 +345,13 @@ async def _fetch_scaling_group_with_slot_inventory( active_slot_types=active_slot_types, ) + async def _fetch_required_slot_names(self, db_sess: SASession) -> frozenset[SlotName]: + stmt = sa.select(ResourceSlotTypeRow.slot_name).where( + ResourceSlotTypeRow.required.is_(True) + ) + rows = (await db_sess.execute(stmt)).scalars().all() + return frozenset(SlotName(slot_name) for slot_name in rows) + async def _fetch_scaling_group( self, db_sess: SASession, scaling_group: str ) -> ScalingGroupMeta: @@ -1520,6 +1527,7 @@ async def fetch_session_spec_contexts( rg_defaults = None resource_group_allow_fractional = False known_slot_types: Mapping[SlotName, SlotTypes] = {} + required_slot_names = await self._fetch_required_slot_names(db_sess) if resource_group_name: rg_bundle = await self._fetch_scaling_group_with_slot_inventory( db_sess, resource_group_name @@ -1686,6 +1694,7 @@ async def fetch_session_spec_contexts( active_session_count=active_session_count, keypair_resource_policy=keypair_policy, known_slot_types=known_slot_types, + required_slot_names=required_slot_names, ) async def pick_default_resource_group( diff --git a/src/ai/backend/manager/repositories/scheduler/types/session_creation.py b/src/ai/backend/manager/repositories/scheduler/types/session_creation.py index 
63c5ef45e19..fb8f823332d 100644 --- a/src/ai/backend/manager/repositories/scheduler/types/session_creation.py +++ b/src/ai/backend/manager/repositories/scheduler/types/session_creation.py @@ -147,4 +147,5 @@ class SessionSpecContextFetch: dotfile_data: DotfileBundle keypair_resource_policy: Any | None # KeyPairResourcePolicyData known_slot_types: Mapping[SlotName, SlotTypes] = field(default_factory=dict) + required_slot_names: frozenset[SlotName] = field(default_factory=frozenset) active_session_count: int = 0 diff --git a/src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py b/src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py index d0514f9aded..d9e7575ca37 100644 --- a/src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py +++ b/src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py @@ -52,6 +52,7 @@ InferenceModelFolderRule, MountNameValidationRule, RequestedSlotTypeRule, + RequiredResourceSlotRule, ResourceLimitRule, ServicePortRule, SessionSpecValidationContext, @@ -129,6 +130,7 @@ def __init__(self, args: SchedulingControllerArgs) -> None: ContainerLimitRule(), ImageSlotTypeRule(), RequestedSlotTypeRule(), + RequiredResourceSlotRule(), ResourceLimitRule(), ServicePortRule(), MountNameValidationRule(), @@ -188,6 +190,7 @@ async def enqueue_session_from_draft( keypair_resource_policy=fetched.keypair_resource_policy, image_infos=fetched.image_infos, known_slot_types=fetched.known_slot_types, + required_slot_names=fetched.required_slot_names, dotfile_data=fetched.dotfile_data, active_session_count=fetched.active_session_count, ) diff --git a/src/ai/backend/manager/sokovan/scheduling_controller/validators/__init__.py b/src/ai/backend/manager/sokovan/scheduling_controller/validators/__init__.py index 217294f04ea..77e14a171da 100644 --- a/src/ai/backend/manager/sokovan/scheduling_controller/validators/__init__.py +++ 
b/src/ai/backend/manager/sokovan/scheduling_controller/validators/__init__.py @@ -7,6 +7,7 @@ from .inference_model_folder_rule import InferenceModelFolderRule from .mount_name_validation_rule import MountNameValidationRule from .requested_slot_type_rule import RequestedSlotTypeRule +from .required_resource_slot_rule import RequiredResourceSlotRule from .resource_limit_rule import ResourceLimitRule from .service_port_rule import ServicePortRule from .session_spec_base import ( @@ -23,6 +24,7 @@ "InferenceModelFolderRule", "MountNameValidationRule", "RequestedSlotTypeRule", + "RequiredResourceSlotRule", "ResourceLimitRule", "ServicePortRule", "SessionSpecValidationContext", diff --git a/src/ai/backend/manager/sokovan/scheduling_controller/validators/required_resource_slot_rule.py b/src/ai/backend/manager/sokovan/scheduling_controller/validators/required_resource_slot_rule.py new file mode 100644 index 00000000000..47917b0c0cb --- /dev/null +++ b/src/ai/backend/manager/sokovan/scheduling_controller/validators/required_resource_slot_rule.py @@ -0,0 +1,53 @@ +"""Required resource-slot validator. + +Required slot names come from ``resource_slot_types.required``. The +preparer chain runs first, so image-minimum fallback for intrinsic slots +has already had a chance to fill ``cpu`` / ``mem`` before this rule checks +the finalized ``SessionSpec``. 
class RequiredResourceSlotRule(SessionSpecValidatorRule):
    """Reject session specs whose kernels omit a globally required resource slot.

    The required slot names are read from ``context.required_slot_names``.
    A slot counts as satisfied only when the kernel requests it with a
    quantity strictly greater than zero; an absent slot and a slot requested
    with a zero or negative quantity are both reported.
    """

    def name(self) -> str:
        """Return the identifier of this rule."""
        return "required_resource_slot"

    def validate(
        self,
        spec: SessionSpec,
        context: SessionSpecValidationContext,
    ) -> None:
        """Check every kernel in ``spec`` against the required slot names.

        Args:
            spec: Session spec whose ``kernel_specs`` are inspected.
            context: Validation context supplying ``required_slot_names``.

        Raises:
            InvalidAPIParameters: For the first kernel (in ``kernel_specs``
                order) that lacks a required slot or requests it with a
                quantity that is not greater than zero.
        """
        # The required names do not depend on the kernel, so stringify and
        # sort them once. Filtering a sorted list keeps the reported names
        # in the same order as sorting the filtered result would.
        required_names = sorted(str(slot_name) for slot_name in context.required_slot_names)
        if not required_names:
            return

        zero = Decimal(0)
        for idx, kernel in enumerate(spec.kernel_specs):
            requested = {
                entry.resource_type: parse_quantity(entry.quantity)
                for entry in kernel.execution_spec.resources
            }
            # An absent slot defaults to zero, so "not requested" and
            # "requested with a non-positive quantity" fail the same check.
            missing = [
                slot_name
                for slot_name in required_names
                if requested.get(slot_name, zero) <= zero
            ]
            if missing:
                raise InvalidAPIParameters(
                    extra_msg=(
                        f"kernel_specs[{idx}].execution_spec.resources is missing "
                        f"required resource slot(s): {missing}. "
                        "A required slot must be requested with a quantity "
                        "greater than zero."
                    )
                )
class TestRequiredResourceSlotRule:
    """Unit tests for ``RequiredResourceSlotRule.validate``."""

    @pytest.fixture
    def image_id(self) -> ImageID:
        """Provide a freshly generated image ID for the kernel under test."""
        return ImageID(uuid.uuid4())

    @pytest.fixture
    def required_slot_ctx(self) -> SessionSpecValidationContext:
        """Provide a validation context that marks ``cpu`` and ``mem`` as required."""
        required = frozenset({SlotName("cpu"), SlotName("mem")})
        return _ctx(required_slot_names=required)

    def test_passes_when_required_slots_present(
        self,
        image_id: ImageID,
        required_slot_ctx: SessionSpecValidationContext,
    ) -> None:
        """A kernel requesting positive ``cpu`` and ``mem`` passes the rule."""
        kernel = _kernel_with_resources(
            image_id,
            resources=(("cpu", "1"), ("mem", "1073741824")),
        )
        spec = _spec((kernel,))
        RequiredResourceSlotRule().validate(spec, required_slot_ctx)

    @pytest.mark.parametrize(
        ("resources", "expected_missing_slot"),
        [
            # ``mem`` is not requested at all.
            ((("cpu", "1"),), "mem"),
            # ``cpu`` is requested, but with a zero quantity.
            ((("cpu", "0"), ("mem", "1073741824")), "cpu"),
        ],
    )
    def test_rejects_missing_or_zero_required_slot(
        self,
        image_id: ImageID,
        required_slot_ctx: SessionSpecValidationContext,
        resources: tuple[tuple[str, str], ...],
        expected_missing_slot: str,
    ) -> None:
        """An absent or zero-quantity required slot is named in the error."""
        kernel = _kernel_with_resources(image_id, resources=resources)
        spec = _spec((kernel,))
        rule = RequiredResourceSlotRule()
        with pytest.raises(InvalidAPIParameters, match=expected_missing_slot):
            rule.validate(spec, required_slot_ctx)