From b7d7b837c2de882847fdb09ad431b277d6b8be69 Mon Sep 17 00:00:00 2001 From: weixiao-huang Date: Mon, 15 Sep 2025 22:25:40 +0800 Subject: [PATCH 1/2] feat: make ParameterMeta JSON serializable --- checkpoint_engine/ps.py | 48 ++++++++++++++++++++++++++++++++--------- pyproject.toml | 3 ++- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 186b107..f04f276 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -12,7 +12,7 @@ from collections import defaultdict from datetime import timedelta from functools import lru_cache -from typing import TYPE_CHECKING, Any, BinaryIO, NamedTuple +from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple import numpy as np import requests @@ -20,7 +20,7 @@ import torch.distributed as dist import zmq from loguru import logger -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, PlainSerializer, PlainValidator, WithJsonSchema from safetensors.torch import safe_open from torch.multiprocessing.reductions import reduce_tensor @@ -38,16 +38,44 @@ class FileMeta(TypedDict): tp_concat_dim: int -class ParameterMeta(BaseModel): - # now all classes are changed to pydantic BaseModel - # it will directly report validation errors for unknown types - # like torch.dtype, torch.Size, so we need this configuration - # see https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.validate_assignment - model_config = ConfigDict(arbitrary_types_allowed=True) +def _dt_validate(value: Any) -> torch.dtype: + if isinstance(value, str): + assert value.startswith("torch."), f"dtype {value} should start with torch." + try: + value = getattr(torch, value.split(".")[1]) + except AttributeError as e: + raise ValueError(f"unknown dtype: {value}") from e + assert isinstance(value, torch.dtype), f"dtype {value} should be torch.dtype, got {type(value)}" + return value + + +_TorchDtype = Annotated[ + torch.dtype, + PlainValidator(_dt_validate), + PlainSerializer(lambda x: str(x), return_type=str), + WithJsonSchema({"type": "string"}, mode="serialization"), +] + +def _size_validate(value: Any) -> torch.Size: + if isinstance(value, list | tuple): + return torch.Size(value) + assert isinstance(value, torch.Size), f"size {value} should be torch.Size, got {type(value)}" + return value + + +_TorchSize = Annotated[ + torch.Size, + PlainValidator(_size_validate), + PlainSerializer(lambda x: tuple(x), return_type=tuple), + WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"), +] + + +class ParameterMeta(BaseModel): name: str - dtype: torch.dtype - shape: torch.Size + dtype: _TorchDtype + shape: _TorchSize class BucketRange(NamedTuple): diff --git a/pyproject.toml b/pyproject.toml index f4c9089..533d2ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,12 +8,13 @@ requires-python = ">=3.10" dependencies = [ "torch>=2.5.0", "fastapi", - "pydantic", + "pydantic>=2.0.0", "safetensors", "pyzmq", "uvicorn", "loguru", "numpy", + "requests", ] [project.optional-dependencies] From e82869cc3ff72da86c2a1a0ff00fe52bbc4576ee Mon Sep 17 00:00:00 2001 From: weixiao-huang Date: Tue, 16 Sep 2025 11:07:59 +0800 Subject: [PATCH 2/2] feat: change assert to raise --- checkpoint_engine/ps.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index f04f276..42cac9e 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -40,12 +40,14 @@ class FileMeta(TypedDict): def _dt_validate(value: Any) -> torch.dtype: if isinstance(value, str): - assert value.startswith("torch."), f"dtype {value} should start with torch." + if not value.startswith("torch."): + raise ValueError(f"dtype {value} should start with torch.") try: value = getattr(torch, value.split(".")[1]) except AttributeError as e: raise ValueError(f"unknown dtype: {value}") from e - assert isinstance(value, torch.dtype), f"dtype {value} should be torch.dtype, got {type(value)}" + if not isinstance(value, torch.dtype): + raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}") return value @@ -60,7 +62,8 @@ def _dt_validate(value: Any) -> torch.dtype: def _size_validate(value: Any) -> torch.Size: if isinstance(value, list | tuple): return torch.Size(value) - assert isinstance(value, torch.Size), f"size {value} should be torch.Size, got {type(value)}" + if not isinstance(value, torch.Size): + raise TypeError(f"size {value} should be torch.Size, got {type(value)}") return value