Skip to content

Commit 3f520c5

Browse files
jopemachine and claude committed
feat(BA-5978): make BackendAIModel.build_validation_error overridable
Add a public ``build_validation_error`` classmethod on ``BackendAIModel`` that returns the ``BackendAIError`` instance to raise when ``model_validate`` / ``model_validate_json`` fails. The default implementation returns the generic ``BackendAIModelValidationFailed``; subclasses override the method to inject a domain-specific 400 directly, without any caller-side try/except re-wrap.

Apply the override on the two models that previously needed wrapping:

* ``ModelDefinition`` raises ``ModelDefinitionValidationError`` (``ModelDefinitionDraft`` gets the same override). Moved that exception class from ``ai.backend.agent.errors.agent`` to ``ai.backend.common.exception`` (and dropped the agent re-export) so the model — which lives in ``common.config`` — can construct it without an upward-layer import. The agent-specific ``error_type`` URL segment is dropped in the move.
* ``SessionSpec`` raises ``IncompleteSessionSpec`` with the existing ``extra_data["missing"]`` shape, using a private ``_format_loc`` static-method helper on the class.

The caller-side try/except wrappers around ``ModelDefinition.model_validate`` in ``agent/agent.py`` and ``manager/services/model_card/service.py``, and the wrapper in ``sokovan/scheduling_controller/preparers/session_spec_preparer.py``, are all removed — the models now raise the right domain error directly.

Tests stay unchanged: they still expect ``IncompleteSessionSpec`` / ``ModelDefinitionValidationError`` because the overrides produce the same exception types.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 202f333 commit 3f520c5

11 files changed

Lines changed: 162 additions & 167 deletions

File tree

src/ai/backend/agent/agent.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@
158158
AbstractEvent,
159159
)
160160
from ai.backend.common.exception import (
161-
BackendAIModelValidationFailed,
162161
ConfigurationError,
163162
VolumeMountFailed,
164163
)
@@ -236,7 +235,6 @@
236235
ImagePullTimeoutError,
237236
ModelDefinitionEmptyError,
238237
ModelDefinitionNotFoundError,
239-
ModelDefinitionValidationError,
240238
ModelFolderNotSpecifiedError,
241239
PortConflictError,
242240
ReservedPortError,
@@ -3295,13 +3293,7 @@ async def _load_model_definition(
32953293
f" vFolder {model_folder.name} (ID {model_folder.vfid})",
32963294
)
32973295

3298-
try:
3299-
parsed = ModelDefinition.model_validate(inlined)
3300-
except BackendAIModelValidationFailed as e:
3301-
raise ModelDefinitionValidationError(
3302-
"Failed to validate model definition for vFolder"
3303-
f" {model_folder.name} (ID {model_folder.vfid})",
3304-
) from e
3296+
parsed = ModelDefinition.model_validate(inlined)
33053297
if not parsed.models:
33063298
raise ModelDefinitionEmptyError
33073299
model_definition = parsed.model_dump(mode="json")

src/ai/backend/agent/errors/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
ModelDefinitionEmptyError,
2121
ModelDefinitionInvalidYAMLError,
2222
ModelDefinitionNotFoundError,
23-
ModelDefinitionValidationError,
2423
ModelFolderNotSpecifiedError,
2524
PortConflictError,
2625
ReservedPortError,
@@ -64,7 +63,6 @@
6463
"ModelDefinitionEmptyError",
6564
"ModelDefinitionInvalidYAMLError",
6665
"ModelDefinitionNotFoundError",
67-
"ModelDefinitionValidationError",
6866
"ModelFolderNotSpecifiedError",
6967
"PortConflictError",
7068
"ReservedPortError",

src/ai/backend/agent/errors/agent.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -169,20 +169,6 @@ def error_code(self) -> ErrorCode:
169169
)
170170

171171

172-
class ModelDefinitionValidationError(BackendAIError, web.HTTPBadRequest):
173-
"""Raised when model definition validation fails."""
174-
175-
error_type = "https://api.backend.ai/probs/agent/model-definition-validation-failed"
176-
error_title = "Model definition validation failed."
177-
178-
def error_code(self) -> ErrorCode:
179-
return ErrorCode(
180-
domain=ErrorDomain.MODEL_SERVICE,
181-
operation=ErrorOperation.ACCESS,
182-
error_detail=ErrorDetail.INVALID_PARAMETERS,
183-
)
184-
185-
186172
class ModelFolderNotSpecifiedError(BackendAIError, web.HTTPBadRequest):
187173
"""Raised when no model virtual folder is specified."""
188174

src/ai/backend/common/config.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import sys
66
from collections.abc import Mapping, MutableMapping
77
from pathlib import Path
8-
from typing import Any
8+
from typing import Any, override
99

1010
import humps
1111
import tomli
@@ -19,8 +19,8 @@
1919

2020
from . import validators as tx
2121
from .etcd import AsyncEtcd, ConfigScopes
22-
from .exception import ConfigurationError
23-
from .types import BackendAIModel, RedisHelperConfig
22+
from .exception import BackendAIError, ConfigurationError, ModelDefinitionValidationError
23+
from .types import BackendAIModel, ModelValidationFailureInfo, RedisHelperConfig
2424

2525
__all__ = (
2626
"ConfigurationError",
@@ -477,6 +477,14 @@ class ModelDefinition(BaseConfigModel):
477477
description="List of models in the model definition.",
478478
)
479479

480+
@override
481+
@classmethod
482+
def build_validation_error(cls, info: ModelValidationFailureInfo) -> BackendAIError:
483+
return ModelDefinitionValidationError(
484+
extra_msg=info.summary,
485+
extra_data={"errors": info.errors},
486+
)
487+
480488
def merge(self, override: ModelDefinition) -> ModelDefinition:
481489
"""Merge the given override into this definition, returning a new instance."""
482490
return _merge_definition(self, override)
@@ -664,6 +672,14 @@ class ModelDefinitionDraft(BaseConfigModel):
664672

665673
models: list[ModelConfigDraft] | None = None
666674

675+
@override
676+
@classmethod
677+
def build_validation_error(cls, info: ModelValidationFailureInfo) -> BackendAIError:
678+
return ModelDefinitionValidationError(
679+
extra_msg=info.summary,
680+
extra_data={"errors": info.errors},
681+
)
682+
667683
def merge(self, override: ModelDefinitionDraft) -> ModelDefinitionDraft:
668684
"""Merge ``override`` over ``self`` and return a new draft.
669685

src/ai/backend/common/exception.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
11
from __future__ import annotations
22

33
import enum
4-
import logging
54
from abc import ABC, abstractmethod
65
from collections.abc import Mapping
76
from dataclasses import dataclass
87
from typing import Any, Self
98

109
from aiohttp import web
10+
from pydantic_core import ErrorDetails
1111

1212
from .json import dump_json
1313

14-
log = logging.getLogger(__spec__.name)
15-
1614

1715
class ConfigurationError(Exception):
1816
invalid_data: Mapping[str, Any]
@@ -447,13 +445,10 @@ def error_code(self) -> ErrorCode:
447445

448446

449447
class BackendAIModelValidationFailed(BackendAIError, web.HTTPBadRequest):
450-
"""Generic 400 raised when a :class:`BackendAIModel` fails validation.
448+
"""Default 400 raised by :class:`BackendAIModel.build_validation_error`.
451449
452-
Distinct from :class:`InvalidAPIParameters` so callers can tell
453-
"this came from a Pydantic model validator" apart from "this came
454-
from an explicit API parameter check," and so domain-specific
455-
handlers can choose to catch and re-wrap one without affecting the
456-
other.
450+
Kept distinct from :class:`InvalidAPIParameters` so handlers can
451+
catch one without picking up the other.
457452
"""
458453

459454
error_type = "https://api.backend.ai/probs/model-validation-failed"
@@ -466,6 +461,33 @@ def error_code(self) -> ErrorCode:
466461
error_detail=ErrorDetail.INVALID_PARAMETERS,
467462
)
468463

464+
def errors(self) -> list[ErrorDetails]:
465+
"""Per-field errors in the same shape as
466+
``pydantic.ValidationError.errors()``. Empty when no
467+
``extra_data["errors"]`` is attached."""
468+
if not self.extra_data:
469+
return []
470+
return list(self.extra_data.get("errors") or [])
471+
472+
473+
class ModelDefinitionValidationError(BackendAIError, web.HTTPBadRequest):
474+
"""400 raised by ``ModelDefinition.model_validate`` (via its
475+
:meth:`BackendAIModel.build_validation_error` override).
476+
477+
Lives in ``common`` so ``ModelDefinition`` (also in ``common``) can
478+
construct it without an upward-layer import.
479+
"""
480+
481+
error_type = "https://api.backend.ai/probs/model-definition-validation-failed"
482+
error_title = "Model definition validation failed."
483+
484+
def error_code(self) -> ErrorCode:
485+
return ErrorCode(
486+
domain=ErrorDomain.MODEL_SERVICE,
487+
operation=ErrorOperation.ACCESS,
488+
error_detail=ErrorDetail.INVALID_PARAMETERS,
489+
)
490+
469491

470492
class DeprecatedAPI(BackendAIError, web.HTTPBadRequest):
471493
error_type = "https://api.backend.ai/probs/deprecated"

src/ai/backend/common/types.py

Lines changed: 53 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,12 @@
5151
TypeAdapter,
5252
ValidationError,
5353
)
54+
from pydantic_core import ErrorDetails
5455
from redis.asyncio import Redis
5556

5657
from .defs import UNKNOWN_CONTAINER_ID, RedisRole
5758
from .exception import (
59+
BackendAIError,
5860
BackendAIModelValidationFailed,
5961
GenericNotImplementedError,
6062
InvalidIpAddressValue,
@@ -116,6 +118,7 @@
116118
"MetricValue",
117119
"ModelServiceProfile",
118120
"ModelServiceStatus",
121+
"ModelValidationFailureInfo",
119122
"MountExpression",
120123
"MountInfoEntry",
121124
"MountPermission",
@@ -181,56 +184,70 @@
181184
)
182185

183186

187+
@dataclass(frozen=True)
188+
class ModelValidationFailureInfo:
189+
"""Pydantic-decoupled view of a failed ``model_validate*`` call,
190+
passed to :meth:`BackendAIModel.build_validation_error`.
191+
192+
``summary`` is ``str(pydantic.ValidationError)``; ``errors`` is
193+
``exc.errors()`` as-is.
194+
"""
195+
196+
summary: str
197+
errors: list[ErrorDetails]
198+
199+
184200
class BackendAIModel(BaseModel):
185-
"""Project-wide Pydantic base for Backend.AI models.
186-
187-
Overrides ``model_validate`` / ``model_validate_json`` /
188-
``model_validate_strings`` so a ``ValidationError`` is auto-mapped
189-
to :class:`BackendAIModelValidationFailed` (HTTP 400) carrying the structured
190-
per-field error list. Call sites get a clean 4xx without repeating
191-
``try / except ValidationError`` at every site.
192-
193-
Notes:
194-
195-
* Pydantic v2 routes nested validation through
196-
``__pydantic_validator__`` directly, not the classmethod, so this
197-
override only affects explicit ``Model.model_validate(...)``
198-
calls — nested models are unaffected.
199-
* The ``__init__`` constructor and the compiled validator stay
200-
untouched, so internal default-value construction
201-
(``Model()`` / ``Model(field=...)``) still works exactly like
202-
stock Pydantic.
201+
"""Pydantic base whose ``model_validate`` / ``model_validate_json``
202+
auto-map ``ValidationError`` to a :class:`BackendAIError` (HTTP 4xx)
203+
via :meth:`build_validation_error`. Subclasses override the
204+
classmethod to inject a domain-specific 400::
205+
206+
class MyConfig(BackendAIModel):
207+
@override
208+
@classmethod
209+
def build_validation_error(
210+
cls, info: ModelValidationFailureInfo
211+
) -> BackendAIError:
212+
return MyConfigParseError(
213+
extra_msg=info.summary,
214+
extra_data={"errors": info.errors},
215+
)
216+
217+
``__init__`` is left alone: pydantic v2 invokes nested models'
218+
``__init__`` from inside the outer validator, so converting there
219+
would break ``loc``-path aggregation. Direct ``Model(field=...)``
220+
construction therefore still raises stock ``pydantic.ValidationError``;
221+
switch the call site to ``Model.model_validate({...})`` to opt into
222+
the override path.
203223
"""
204224

225+
@classmethod
226+
def build_validation_error(cls, info: ModelValidationFailureInfo) -> BackendAIError:
227+
"""Default override raising the generic
228+
:class:`BackendAIModelValidationFailed`."""
229+
return BackendAIModelValidationFailed(
230+
extra_msg=info.summary,
231+
extra_data={"errors": info.errors},
232+
)
233+
234+
@classmethod
235+
def _validation_failure_info(cls, exc: ValidationError) -> ModelValidationFailureInfo:
236+
return ModelValidationFailureInfo(summary=str(exc), errors=exc.errors())
237+
205238
@classmethod
206239
def model_validate(cls, *args: Any, **kwargs: Any) -> Self:
207240
try:
208241
return super().model_validate(*args, **kwargs)
209242
except ValidationError as e:
210-
raise BackendAIModelValidationFailed(
211-
extra_msg=str(e),
212-
extra_data={"errors": e.errors()},
213-
) from e
243+
raise cls.build_validation_error(cls._validation_failure_info(e)) from e
214244

215245
@classmethod
216246
def model_validate_json(cls, *args: Any, **kwargs: Any) -> Self:
217247
try:
218248
return super().model_validate_json(*args, **kwargs)
219249
except ValidationError as e:
220-
raise BackendAIModelValidationFailed(
221-
extra_msg=str(e),
222-
extra_data={"errors": e.errors()},
223-
) from e
224-
225-
@classmethod
226-
def model_validate_strings(cls, *args: Any, **kwargs: Any) -> Self:
227-
try:
228-
return super().model_validate_strings(*args, **kwargs)
229-
except ValidationError as e:
230-
raise BackendAIModelValidationFailed(
231-
extra_msg=str(e),
232-
extra_data={"errors": e.errors()},
233-
) from e
250+
raise cls.build_validation_error(cls._validation_failure_info(e)) from e
234251

235252

236253
class aobject:

src/ai/backend/manager/data/session/spec.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,30 @@
1717
from __future__ import annotations
1818

1919
from collections.abc import Mapping
20-
from typing import Any
20+
from typing import Any, override
2121
from uuid import UUID
2222

2323
import yarl
2424
from pydantic import ConfigDict, Field
2525

26+
from ai.backend.common.exception import BackendAIError
2627
from ai.backend.common.identifier.domain import DomainName
2728
from ai.backend.common.identifier.project import ProjectID
2829
from ai.backend.common.identifier.resource_group import ResourceGroupName
2930
from ai.backend.common.identifier.session import SessionID
30-
from ai.backend.common.types import AccessKey, BackendAIModel, SessionTypes, VFolderMount
31+
from ai.backend.common.types import (
32+
AccessKey,
33+
BackendAIModel,
34+
ModelValidationFailureInfo,
35+
SessionTypes,
36+
VFolderMount,
37+
)
3138
from ai.backend.manager.data.session.options import (
3239
InternalDataExtras,
3340
KernelExecutionSpec,
3441
SessionOptions,
3542
)
43+
from ai.backend.manager.errors.kernel import IncompleteSessionSpec
3644
from ai.backend.manager.models.network import NetworkType
3745

3846

@@ -138,3 +146,22 @@ class SessionSpec(_SpecBaseModel):
138146
options: SessionOptions
139147
kernel_specs: tuple[KernelSpec, ...]
140148
internal_data_extras: InternalDataExtras = Field(default_factory=InternalDataExtras)
149+
150+
@override
151+
@classmethod
152+
def build_validation_error(cls, info: ModelValidationFailureInfo) -> BackendAIError:
153+
missing_paths = [cls._format_loc(tuple(err["loc"])) for err in info.errors]
154+
return IncompleteSessionSpec(
155+
extra_msg="SessionSpec fields not resolved: " + ", ".join(missing_paths),
156+
extra_data={"missing": missing_paths},
157+
)
158+
159+
@staticmethod
160+
def _format_loc(loc: tuple[object, ...]) -> str:
161+
parts: list[str] = []
162+
for item in loc:
163+
if isinstance(item, int):
164+
parts.append(f"[{item}]")
165+
else:
166+
parts.append(f".{item}" if parts else str(item))
167+
return "".join(parts)

0 commit comments

Comments
 (0)