Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/11554.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Migrate every remaining pydantic `BaseModel` subclass across `src/ai/backend/` to `BackendAISchema`, so any `model_validate()` failure auto-converts to a `BackendAISchemaValidationFailed` (HTTP 400) instead of leaking as raw `pydantic.ValidationError`.
13 changes: 10 additions & 3 deletions src/ai/backend/account_manager/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from pydantic import BaseModel, ConfigDict, ValidationError

from ai.backend.account_manager.exceptions import AuthorizationFailed, InvalidAPIParameters
from ai.backend.common.exception import BackendAISchemaValidationFailed
from ai.backend.common.types import BackendAISchema
from ai.backend.logging import BraceStyleAdapter


Expand Down Expand Up @@ -69,7 +71,7 @@ def mask_sensitive_keys(data: Mapping[str, Any]) -> Mapping[str, Any]:
TBaseModel = TypeVar("TBaseModel", bound=BaseModel)


class RequestData(BaseModel):
class RequestData(BackendAISchema):
model_config = ConfigDict(
extra="allow",
)
Expand Down Expand Up @@ -173,8 +175,13 @@ async def wrapped(
kwargs["query"] = query_params
except (json.decoder.JSONDecodeError, yaml.YAMLError, yaml.MarkedYAMLError) as e:
raise InvalidAPIParameters("Malformed body") from e
except ValidationError as e:
raise InvalidAPIParameters("Input validation error", extra_data=e.errors()) from e
except (BackendAISchemaValidationFailed, ValidationError) as e:
# ``ValidationError`` covers plain ``BaseModel`` subclasses that
# skip the ``BackendAISchema`` auto-conversion override.
raise InvalidAPIParameters(
"Input validation error",
extra_data={"errors": e.errors()},
Comment on lines +181 to +183
) from e
result = await handler(request, checked_params, *args, **kwargs)
return ensure_stream_response_type(result)

Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/account_manager/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import click
from pydantic import (
BaseModel,
ConfigDict,
Field,
GetCoreSchemaHandler,
Expand All @@ -23,6 +22,7 @@
from pydantic_core import PydanticUndefined, core_schema

from ai.backend.common import config
from ai.backend.common.types import BackendAISchema
from ai.backend.logging import LogLevel

from .types import EventLoopType
Expand All @@ -46,7 +46,7 @@ class TransactionIsolationLevel(enum.StrEnum):
SERIALIZABLE = "SERIALIZABLE"


class BaseSchema(BaseModel):
class BaseSchema(BackendAISchema):
model_config = ConfigDict(
validate_by_name=True,
from_attributes=True,
Expand Down
6 changes: 3 additions & 3 deletions src/ai/backend/agent/kernel_registry/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Self

from pydantic import BaseModel, Field
from pydantic import Field

from ai.backend.agent.kernel import KernelOwnershipData
from ai.backend.agent.proxy import DomainSocketPathPair
from ai.backend.agent.resources import KernelResourceSpec
from ai.backend.common.docker import ImageRef
from ai.backend.common.types import AgentId, KernelId, ServicePort, SessionTypes
from ai.backend.common.types import AgentId, BackendAISchema, KernelId, ServicePort, SessionTypes

if TYPE_CHECKING:
from ai.backend.agent.docker.kernel import DockerKernel


class KernelRecoveryData(BaseModel):
class KernelRecoveryData(BackendAISchema):
"""
Data required for recovering a Kernel.
Agent should load and write Kernel data using this structure
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/agent/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from typing import Any

import attrs
from pydantic import BaseModel

from ai.backend.common.asyncio import current_loop
from ai.backend.common.types import BackendAISchema


@attrs.define(auto_attribs=True, slots=True)
Expand All @@ -16,7 +16,7 @@ class DomainSocketProxy:
proxy_server: asyncio.AbstractServer


class DomainSocketPathPair(BaseModel):
class DomainSocketPathPair(BackendAISchema):
host_sock_path: Path
host_proxy_path: Path

Expand Down
6 changes: 2 additions & 4 deletions src/ai/backend/agent/scratch/types.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from collections.abc import Mapping
from typing import Self

from pydantic import BaseModel

from ai.backend.agent.kernel import KernelOwnershipData
from ai.backend.agent.kernel_registry.types import KernelRecoveryData
from ai.backend.agent.proxy import DomainSocketPathPair
from ai.backend.agent.resources import KernelResourceSpec
from ai.backend.common.docker import ImageRef
from ai.backend.common.types import AgentId, KernelId, ServicePort, SessionTypes
from ai.backend.common.types import AgentId, BackendAISchema, KernelId, ServicePort, SessionTypes


class KernelRecoveryScratchData(BaseModel):
class KernelRecoveryScratchData(BackendAISchema):
"""
Serializable subset of KernelRecoveryData for scratch storage.
Excludes `resource_spec` and `environ` which are loaded separately.
Expand Down
5 changes: 2 additions & 3 deletions src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from callosum.ordering import ExitOrderedAsyncScheduler
from callosum.rpc import Peer, RPCMessage
from etcd_client import WatchEventType
from pydantic import ValidationError
from setproctitle import setproctitle
from zmq.auth.certs import load_certificate

Expand Down Expand Up @@ -87,7 +86,7 @@
KernelTerminatedBroadcastEvent,
)
from ai.backend.common.events.event_types.kernel.types import KernelLifecycleEventReason
from ai.backend.common.exception import ConfigurationError
from ai.backend.common.exception import BackendAISchemaValidationFailed, ConfigurationError
from ai.backend.common.health_checker.checkers.etcd import EtcdHealthChecker
from ai.backend.common.health_checker.checkers.valkey import ValkeyHealthChecker
from ai.backend.common.health_checker.probe import HealthProbe, HealthProbeOptions
Expand Down Expand Up @@ -1740,7 +1739,7 @@ def main(
if server_config.debug.enabled:
print("== Agent configuration ==")
pprint(server_config.model_dump(by_alias=True))
except ValidationError as e:
except BackendAISchemaValidationFailed as e:
print(
"ConfigurationError: Agent local config failed validation checks:",
file=sys.stderr,
Expand Down
5 changes: 3 additions & 2 deletions src/ai/backend/appproxy/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pydantic_core import PydanticUndefined, core_schema

from ai.backend.common.meta import BackendAIConfigMeta, CompositeType, ConfigExample
from ai.backend.common.types import BackendAISchema

from .errors import (
GroupNotFoundError,
Expand All @@ -30,7 +31,7 @@
# FIXME: merge majority of common definitions to ai.backend.common when ready


class BaseSchema(BaseModel):
class BaseSchema(BackendAISchema):
model_config = ConfigDict(
populate_by_name=True,
from_attributes=True,
Expand Down Expand Up @@ -723,7 +724,7 @@ class UnsupportedTypeError(RuntimeError):


def generate_example_json(
schema: type[BaseModel] | types.GenericAlias | types.UnionType,
schema: type[BackendAISchema] | types.GenericAlias | types.UnionType,
parent: list[str] | None = None,
) -> dict[str, Any] | list[Any]:
if parent is None:
Expand Down
3 changes: 2 additions & 1 deletion src/ai/backend/appproxy/common/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pydantic import BaseModel

from ai.backend.appproxy.common.types import PydanticResponse
from ai.backend.common.types import BackendAISchema

from . import __version__

Expand Down Expand Up @@ -131,7 +132,7 @@ def generate_openapi(
route_def["description"] = "\n".join(description)
type_hints = get_type_hints(route.handler)

def _parse_schema(model_cls: type[BaseModel]) -> dict[str, Any]:
def _parse_schema(model_cls: type[BackendAISchema]) -> dict[str, Any]:
if not issubclass(model_cls, BaseModel):
raise RuntimeError(f"{model_cls} not considered as a valid response type")

Expand Down
14 changes: 7 additions & 7 deletions src/ai/backend/appproxy/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from aiohttp import web
from pydantic import AliasChoices, AnyUrl, BaseModel, ConfigDict, Field

from ai.backend.common.types import ModelServiceStatus, RuntimeVariant
from ai.backend.common.types import BackendAISchema, ModelServiceStatus, RuntimeVariant

# FIXME: merge majority of common definitions to ai.backend.common when ready

Expand Down Expand Up @@ -84,7 +84,7 @@ class DigestModType(enum.StrEnum):
]


class RouteInfo(BaseModel):
class RouteInfo(BackendAISchema):
"""Information about a route within a circuit.

Routes describe a kernel endpoint (``kernel_host`` + ``kernel_port``)
Expand Down Expand Up @@ -157,7 +157,7 @@ def current_kernel_host(self) -> str:
return self.kernel_host or "localhost"


class SerializableCircuit(BaseModel):
class SerializableCircuit(BackendAISchema):
"""
Serializable representation of `ai.backend.appproxy.coordinator.models.Circuit`
"""
Expand Down Expand Up @@ -241,7 +241,7 @@ def traefik_router_name(self) -> str:
return f"bai_router_{self.id}@etcd"


class SerializableToken(BaseModel):
class SerializableToken(BackendAISchema):
login_session_token: str | None
kernel_host: str
kernel_port: int
Expand All @@ -252,7 +252,7 @@ class SerializableToken(BaseModel):
domain_name: str


class SessionConfig(BaseModel):
class SessionConfig(BackendAISchema):
model_config = ConfigDict(populate_by_name=True)

id: Annotated[UUID | None, Field(default=None)]
Expand All @@ -265,13 +265,13 @@ class SessionConfig(BaseModel):
domain_name: str


class EndpointConfig(BaseModel):
class EndpointConfig(BackendAISchema):
id: UUID
runtime_variant: Annotated[RuntimeVariant | None, Field(default=None)]
existing_url: AnyUrl | None


class HealthCheckState(BaseModel):
class HealthCheckState(BackendAISchema):
"""
Runtime health check state
"""
Expand Down
10 changes: 8 additions & 2 deletions src/ai/backend/appproxy/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from ai.backend.appproxy.common.types import PydanticResponse
from ai.backend.common import redis_helper
from ai.backend.common.exception import BackendAISchemaValidationFailed
from ai.backend.common.types import RedisConnectionInfo
from ai.backend.logging import BraceStyleAdapter

Expand Down Expand Up @@ -232,8 +233,13 @@ async def wrapped(
kwargs["query"] = query_params
except (json.decoder.JSONDecodeError, yaml.YAMLError, yaml.MarkedYAMLError) as e:
raise InvalidAPIParameters("Malformed body") from e
except ValidationError as e:
raise InvalidAPIParameters("Input validation error", extra_data=e.errors()) from e
except (BackendAISchemaValidationFailed, ValidationError) as e:
# ``ValidationError`` covers plain ``BaseModel`` subclasses that
# skip the ``BackendAISchema`` auto-conversion override.
raise InvalidAPIParameters(
"Input validation error",
extra_data={"errors": e.errors()},
Comment on lines +239 to +241
) from e
result = await handler(request, checked_params, *args, **kwargs)
return ensure_stream_response_type(result)

Expand Down
5 changes: 3 additions & 2 deletions src/ai/backend/appproxy/coordinator/api/circuit_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import aiohttp_cors
import sqlalchemy as sa
from aiohttp import web
from pydantic import BaseModel, Field, validator
from pydantic import Field, validator
from sqlalchemy.orm import selectinload

from ai.backend.appproxy.common.types import (
Expand All @@ -22,6 +22,7 @@
from ai.backend.appproxy.coordinator.models import Circuit
from ai.backend.appproxy.coordinator.models.utils import execute_with_txn_retry
from ai.backend.appproxy.coordinator.types import RootContext
from ai.backend.common.types import BackendAISchema

from .types import StubResponseModel
from .utils import auth_required
Expand All @@ -30,7 +31,7 @@
from sqlalchemy.ext.asyncio import AsyncSession as SASession


class BulkRemoveCircuitsRequestModel(BaseModel):
class BulkRemoveCircuitsRequestModel(BackendAISchema):
circuit_ids: Annotated[list[UUID], Field(description="Comma separated list of Circuit UUIDs.")]

@validator("circuit_ids", pre=True)
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/appproxy/coordinator/api/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import aiohttp_cors
from aiohttp import web
from pydantic import BaseModel

from ai.backend.appproxy.common.errors import AuthorizationFailed
from ai.backend.appproxy.common.types import CORSOptions, PydanticResponse, WebMiddleware
Expand All @@ -13,12 +12,13 @@
from ai.backend.appproxy.coordinator.errors import InvalidSessionParameterError
from ai.backend.appproxy.coordinator.models import Token
from ai.backend.appproxy.coordinator.types import RootContext
from ai.backend.common.types import BackendAISchema
from ai.backend.logging import BraceStyleAdapter

log = BraceStyleAdapter(logging.getLogger(__spec__.name))


class TokenResponseModel(BaseModel):
class TokenResponseModel(BackendAISchema):
token: str


Expand Down
15 changes: 8 additions & 7 deletions src/ai/backend/appproxy/coordinator/api/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import aiohttp_cors
from aiohttp import web
from pydantic import AnyUrl, BaseModel, Field
from pydantic import AnyUrl, Field

from ai.backend.appproxy.common.errors import ObjectNotFound
from ai.backend.appproxy.common.types import (
Expand Down Expand Up @@ -48,17 +48,18 @@
TagsModel,
)
from ai.backend.common.identifier.deployment import DeploymentID
from ai.backend.common.types import BackendAISchema

from .types import StubResponseModel
from .utils import auth_required


class EndpointTagConfig(BaseModel):
class EndpointTagConfig(BackendAISchema):
session: SessionConfig
endpoint: EndpointConfig


class EndpointCreationRequestModel(BaseModel):
class EndpointCreationRequestModel(BackendAISchema):
version: Annotated[str, Field(description="Creation API version")]
service_name: Annotated[str, Field(description="Name of the model service.")]
tags: Annotated[
Expand Down Expand Up @@ -101,7 +102,7 @@ class EndpointCreationRequestModel(BaseModel):
] = None


class EndpointCreationResponseModel(BaseModel):
class EndpointCreationResponseModel(BackendAISchema):
success: bool
endpoint: AnyUrl
health_check_enabled: bool
Expand Down Expand Up @@ -211,7 +212,7 @@ async def bulk_unregister_routes(
return PydanticResponse(BulkUnregisterRoutesResponse(endpoints=items))


class UpdateModelHealthCheckRequestModel(BaseModel):
class UpdateModelHealthCheckRequestModel(BackendAISchema):
health_check: ModelHealthCheck | None


Expand All @@ -238,7 +239,7 @@ async def inject_health_check_information(
return PydanticResponse(StubResponseModel(success=True))


class EndpointAPITokenGenerationRequestModel(BaseModel):
class EndpointAPITokenGenerationRequestModel(BackendAISchema):
user_uuid: UUID
"""
Token requester's user UUID.
Expand All @@ -249,7 +250,7 @@ class EndpointAPITokenGenerationRequestModel(BaseModel):
"""


class EndpointAPITokenResponseModel(BaseModel):
class EndpointAPITokenResponseModel(BackendAISchema):
token: str


Expand Down
Loading
Loading