Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 41 additions & 10 deletions checkpoint_engine/ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
from collections import defaultdict
from datetime import timedelta
from functools import lru_cache
from typing import TYPE_CHECKING, Any, BinaryIO, NamedTuple
from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple

import numpy as np
import requests
import torch
import torch.distributed as dist
import zmq
from loguru import logger
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, PlainSerializer, PlainValidator, WithJsonSchema
from safetensors.torch import safe_open
from torch.multiprocessing.reductions import reduce_tensor

Expand All @@ -38,16 +38,47 @@ class FileMeta(TypedDict):
tp_concat_dim: int


class ParameterMeta(BaseModel):
# NOTE(review): another ParameterMeta definition appears further down in this
# view; this fragment looks like the superseded variant -- confirm.
#
# All model classes here are now pydantic BaseModel subclasses.
# Pydantic reports errors for field types it does not know
# (e.g. torch.dtype, torch.Size), so arbitrary types must be allowed explicitly.
# See https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.arbitrary_types_allowed
model_config = ConfigDict(arbitrary_types_allowed=True)
def _dt_validate(value: Any) -> torch.dtype:
if isinstance(value, str):
if not value.startswith("torch."):
raise ValueError(f"dtype {value} should start with torch.")
try:
value = getattr(torch, value.split(".")[1])
except AttributeError as e:
raise ValueError(f"unknown dtype: {value}") from e
if not isinstance(value, torch.dtype):
raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
return value


# Pydantic field type for torch.dtype values: input is validated through
# _dt_validate, output is serialized with str(), and the serialization-mode
# JSON schema describes the field as a string.
_TorchDtype = Annotated[
torch.dtype,
PlainValidator(_dt_validate),
PlainSerializer(lambda x: str(x), return_type=str),
WithJsonSchema({"type": "string"}, mode="serialization"),
]


def _size_validate(value: Any) -> torch.Size:
if isinstance(value, list | tuple):
Comment thread
weixiao-huang marked this conversation as resolved.
return torch.Size(value)
if not isinstance(value, torch.Size):
raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
return value


# Pydantic field type for torch.Size values: input is validated through
# _size_validate, output is serialized with tuple(), and the serialization-mode
# JSON schema describes the field as an array of integers.
_TorchSize = Annotated[
torch.Size,
PlainValidator(_size_validate),
PlainSerializer(lambda x: tuple(x), return_type=tuple),
WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
]


class ParameterMeta(BaseModel):
    """Pydantic model holding a parameter's name, dtype and shape."""

    name: str
    # dtype and shape use the annotated aliases so that torch.dtype and
    # torch.Size get explicit validators and serializers; the earlier bare
    # torch.dtype / torch.Size annotations were shadowed by these and are
    # dropped.
    dtype: _TorchDtype
    shape: _TorchSize


class BucketRange(NamedTuple):
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ requires-python = ">=3.10"
dependencies = [
"torch>=2.5.0",
"fastapi",
"pydantic",
"pydantic>=2.0.0",
"safetensors",
"pyzmq",
"uvicorn",
"loguru",
"numpy",
"requests",
]

[project.optional-dependencies]
Expand Down