diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 186b107..42cac9e 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -12,7 +12,7 @@ from collections import defaultdict from datetime import timedelta from functools import lru_cache -from typing import TYPE_CHECKING, Any, BinaryIO, NamedTuple +from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple import numpy as np import requests @@ -20,7 +20,7 @@ import torch.distributed as dist import zmq from loguru import logger -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, PlainSerializer, PlainValidator, WithJsonSchema from safetensors.torch import safe_open from torch.multiprocessing.reductions import reduce_tensor @@ -38,16 +38,47 @@ class FileMeta(TypedDict): tp_concat_dim: int -class ParameterMeta(BaseModel): - # now all classes are changed to pydantic BaseModel - # it will directly report validation errors for unknown types - # like torch.dtype, torch.Size, so we need this configuration - # see https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.validate_assignment - model_config = ConfigDict(arbitrary_types_allowed=True) +def _dt_validate(value: Any) -> torch.dtype: + if isinstance(value, str): + if not value.startswith("torch."): + raise ValueError(f"dtype {value} should start with torch.") + try: + value = getattr(torch, value.split(".")[1]) + except AttributeError as e: + raise ValueError(f"unknown dtype: {value}") from e + if not isinstance(value, torch.dtype): + raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}") + return value + + +_TorchDtype = Annotated[ + torch.dtype, + PlainValidator(_dt_validate), + PlainSerializer(lambda x: str(x), return_type=str), + WithJsonSchema({"type": "string"}, mode="serialization"), +] + +def _size_validate(value: Any) -> torch.Size: + if isinstance(value, list | tuple): + return torch.Size(value) + if not isinstance(value, torch.Size): + raise TypeError(f"size {value} should be torch.Size, got {type(value)}") + return value + + +_TorchSize = Annotated[ + torch.Size, + PlainValidator(_size_validate), + PlainSerializer(lambda x: tuple(x), return_type=tuple), + WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"), +] + + +class ParameterMeta(BaseModel): name: str - dtype: torch.dtype - shape: torch.Size + dtype: _TorchDtype + shape: _TorchSize class BucketRange(NamedTuple): diff --git a/pyproject.toml b/pyproject.toml index f4c9089..533d2ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,12 +8,13 @@ requires-python = ">=3.10" dependencies = [ "torch>=2.5.0", "fastapi", - "pydantic", + "pydantic>=2.0.0", "safetensors", "pyzmq", "uvicorn", "loguru", "numpy", + "requests", ] [project.optional-dependencies]