Skip to content

Commit d083f34

Browse files
刘伟健claude
andcommitted
refactor: serialize metas as JSON instead of pickle
Replace pickle with pydantic TypeAdapter(dict[int, MemoryBufferMetaList]) for the metas wire format across the HTTP endpoints, join_cli, and examples/update.py. This reuses the existing pydantic schema (torch.dtype / torch.Size already have serializers in data_types.py), removes the arbitrary-code-execution risk of pickle.loads on request bodies, and makes the metas self-describing for cross-language consumers. - api.py: GET /metas returns application/json; POST /load-metas validates via validate_json and returns 400 on ValidationError (was broad except). - join_cli.py / examples/update.py: read/write metas as JSON; document the --metas-url HTTP path alongside --load-metas-file. - tests/test_api.py: use real MemoryBufferMetaList fixtures; add a schema-mismatch case (valid JSON, wrong shape -> 400). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 4045708 commit d083f34

4 files changed

Lines changed: 106 additions & 35 deletions

File tree

checkpoint_engine/api.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pickle
21
from collections.abc import Callable
32
from typing import Any
43

@@ -7,11 +6,15 @@
76
from fastapi import Request
87
from fastapi.responses import JSONResponse, Response
98
from loguru import logger
10-
from pydantic import BaseModel
9+
from pydantic import BaseModel, TypeAdapter, ValidationError
1110

11+
from checkpoint_engine.data_types import MemoryBufferMetaList
1212
from checkpoint_engine.ps import ParameterServer
1313

1414

15+
_METAS_ADAPTER = TypeAdapter(dict[int, MemoryBufferMetaList])
16+
17+
1518
def request_inference_to_update(
1619
url: str,
1720
socket_paths: dict[str, str],
@@ -87,15 +90,18 @@ async def get_metas(checkpoint_name: str) -> Response:
8790
except Exception as e: # noqa: BLE001
8891
logger.exception(f"get_metas for {checkpoint_name} failed")
8992
return JSONResponse(content=str(e), status_code=500)
90-
return Response(content=pickle.dumps(metas), media_type="application/octet-stream")
93+
return Response(
94+
content=_METAS_ADAPTER.dump_json(metas),
95+
media_type="application/json",
96+
)
9197

9298
@app.post("/v1/checkpoints/{checkpoint_name}/load-metas")
9399
async def load_metas(checkpoint_name: str, raw: Request) -> Response:
94100
body = await raw.body()
95101
try:
96-
metas = pickle.loads(body)
97-
except Exception as e: # noqa: BLE001
98-
logger.exception(f"load_metas pickle decode for {checkpoint_name} failed")
102+
metas = _METAS_ADAPTER.validate_json(body)
103+
except ValidationError as e:
104+
logger.exception(f"load_metas json validation for {checkpoint_name} failed")
99105
return JSONResponse(content=str(e), status_code=400)
100106
return wrap_exception(lambda: ps.load_metas(metas))
101107

checkpoint_engine/join_cli.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,47 @@
66
77
The remote side must already have done ``gather_metas`` so that its
88
``ps.get_metas()`` returns a usable ``dict[int, MemoryBufferMetaList]``,
9-
and the metas pickle bytes must be reachable via either a file path or
9+
and the metas JSON bytes must be reachable via either a file path or
1010
an HTTP URL.
1111
1212
Usage (one process per local GPU, e.g. via torchrun):
1313
14+
# From a local file (e.g. shared moonfs):
1415
torchrun --nproc-per-node N -m checkpoint_engine.join_cli \\
15-
--load-metas-file /path/to/metas.pkl \\
16+
--load-metas-file /path/to/metas.json \\
1617
--endpoint http://localhost:19730 \\
1718
--inference-parallel-size N \\
1819
[--checkpoint-name <name>]
1920
21+
# Or directly from the source ParameterServer's HTTP endpoint:
22+
torchrun --nproc-per-node N -m checkpoint_engine.join_cli \\
23+
--metas-url http://main-ps-host:19710/v1/checkpoints/<name>/metas \\
24+
--endpoint http://localhost:19730 \\
25+
--inference-parallel-size N
26+
2027
Environment variables (same as ``torchrun`` sets): ``RANK``, ``WORLD_SIZE``,
2128
``LOCAL_RANK``, ``MASTER_ADDR``, ``MASTER_PORT``.
2229
"""
2330

2431
import argparse
2532
import os
26-
import pickle
2733
import time
2834
from collections.abc import Callable
2935
from contextlib import contextmanager
3036

3137
import httpx
3238
from loguru import logger
39+
from pydantic import TypeAdapter
3340

3441
import checkpoint_engine.distributed as dist
3542
from checkpoint_engine import request_inference_to_update
43+
from checkpoint_engine.data_types import MemoryBufferMetaList
3644
from checkpoint_engine.ps import ParameterServer
3745

3846

47+
_METAS_ADAPTER = TypeAdapter(dict[int, MemoryBufferMetaList])
48+
49+
3950
@contextmanager
4051
def _timer(msg: str):
4152
start = time.perf_counter()
@@ -78,14 +89,14 @@ def req(socket_paths: list[tuple[str, str]]) -> None:
7889
return req
7990

8091

81-
def _load_metas(args: argparse.Namespace) -> dict:
92+
def _load_metas(args: argparse.Namespace) -> dict[int, MemoryBufferMetaList]:
8293
if args.load_metas_file:
8394
with open(args.load_metas_file, "rb") as f:
84-
return pickle.load(f)
95+
return _METAS_ADAPTER.validate_json(f.read())
8596
if args.metas_url:
8697
resp = httpx.get(args.metas_url, timeout=300.0)
8798
resp.raise_for_status()
88-
return pickle.loads(resp.content)
99+
return _METAS_ADAPTER.validate_json(resp.content)
89100
raise ValueError("either --load-metas-file or --metas-url is required")
90101

91102

@@ -116,11 +127,11 @@ def main() -> None:
116127
description="Join an existing P2P weight world via mooncake RDMA"
117128
)
118129
src = parser.add_mutually_exclusive_group(required=True)
119-
src.add_argument("--load-metas-file", type=str, help="Path to a metas pickle file")
130+
src.add_argument("--load-metas-file", type=str, help="Path to a metas JSON file")
120131
src.add_argument(
121132
"--metas-url",
122133
type=str,
123-
help="HTTP URL returning a metas pickle (application/octet-stream)",
134+
help="HTTP URL returning a metas JSON (application/json)",
124135
)
125136
parser.add_argument(
126137
"--endpoint",

examples/update.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import argparse
22
import json
33
import os
4-
import pickle
54
import time
65
from collections import defaultdict
76
from collections.abc import Callable
@@ -11,13 +10,18 @@
1110
import httpx
1211
import torch
1312
from loguru import logger
13+
from pydantic import TypeAdapter
1414
from safetensors import safe_open
1515

1616
import checkpoint_engine.distributed as dist
1717
from checkpoint_engine import request_inference_to_update
18+
from checkpoint_engine.data_types import MemoryBufferMetaList
1819
from checkpoint_engine.ps import ParameterServer
1920

2021

22+
_METAS_ADAPTER = TypeAdapter(dict[int, MemoryBufferMetaList])
23+
24+
2125
@contextmanager
2226
def timer(msg: str):
2327
start = time.perf_counter()
@@ -110,7 +114,7 @@ def update_weights(
110114
ps.gather_metas(checkpoint_name)
111115
if save_metas_file and int(os.getenv("RANK")) == 0:
112116
with open(save_metas_file, "wb") as f:
113-
pickle.dump(ps.get_metas(), f)
117+
f.write(_METAS_ADAPTER.dump_json(ps.get_metas()))
114118

115119
if update_method == "broadcast" or update_method == "all":
116120
with timer("Update weights without setting ranks"):
@@ -135,7 +139,7 @@ def join(
135139
):
136140
assert load_metas_file, "load_metas_file is required"
137141
with open(load_metas_file, "rb") as f:
138-
metas = pickle.load(f)
142+
metas = _METAS_ADAPTER.validate_json(f.read())
139143
ps.init_process_group()
140144
check_vllm_ready(endpoint, inference_parallel_size, uds)
141145
dist.barrier()

tests/test_api.py

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,67 @@
11
"""CPU-only tests for the metas endpoints in api.py."""
22

3-
import pickle
43
from unittest.mock import MagicMock
54

65
import pytest
6+
import torch
77
from fastapi.testclient import TestClient
8+
from pydantic import TypeAdapter
89

910
from checkpoint_engine.api import _init_api
11+
from checkpoint_engine.data_types import (
12+
MemoryBufferMetaList,
13+
MemoryBufferMetas,
14+
ParameterMeta,
15+
)
16+
17+
18+
_METAS_ADAPTER = TypeAdapter(dict[int, MemoryBufferMetaList])
19+
20+
21+
def _make_meta(rdma_device: str, ip: str) -> MemoryBufferMetaList:
22+
return MemoryBufferMetaList(
23+
p2p_store_addr=f"{ip}:12345",
24+
rdma_device=rdma_device,
25+
memory_buffer_metas_list=[
26+
MemoryBufferMetas(
27+
metas=[
28+
ParameterMeta(
29+
name="w",
30+
dtype=torch.float16,
31+
shape=torch.Size([2, 3]),
32+
aligned_size=12,
33+
)
34+
],
35+
ptr=0x12345678,
36+
size=1024,
37+
)
38+
],
39+
)
1040

1141

1242
@pytest.fixture
13-
def fake_metas() -> dict:
14-
# Mimic ParameterServer.get_metas() return shape (dict[int, ...]),
15-
# exact value type doesn't matter for the round-trip test.
16-
return {0: {"foo": [1, 2, 3]}, 1: {"bar": "baz"}}
43+
def fake_metas() -> dict[int, MemoryBufferMetaList]:
44+
return {
45+
0: _make_meta("mlx5_0", "192.168.1.1"),
46+
1: _make_meta("mlx5_1", "192.168.1.1"),
47+
}
1748

1849

1950
@pytest.fixture
20-
def ps_mock(fake_metas: dict) -> MagicMock:
51+
def ps_mock(fake_metas: dict[int, MemoryBufferMetaList]) -> MagicMock:
2152
ps = MagicMock()
2253
ps.get_metas.return_value = fake_metas
2354
return ps
2455

2556

26-
def test_get_metas_returns_pickle_bytes(ps_mock: MagicMock, fake_metas: dict) -> None:
57+
def test_get_metas_returns_json(
58+
ps_mock: MagicMock, fake_metas: dict[int, MemoryBufferMetaList]
59+
) -> None:
2760
client = TestClient(_init_api(ps_mock))
2861
resp = client.get("/v1/checkpoints/my-ckpt/metas")
2962
assert resp.status_code == 200
30-
assert resp.headers["content-type"] == "application/octet-stream"
31-
assert pickle.loads(resp.content) == fake_metas
63+
assert resp.headers["content-type"] == "application/json"
64+
assert _METAS_ADAPTER.validate_json(resp.content) == fake_metas
3265
ps_mock.get_metas.assert_called_once_with()
3366

3467

@@ -40,40 +73,57 @@ def test_get_metas_propagates_ps_error(ps_mock: MagicMock) -> None:
4073
assert "metas not gathered yet" in resp.text
4174

4275

43-
def test_load_metas_decodes_and_calls_ps(ps_mock: MagicMock, fake_metas: dict) -> None:
76+
def test_load_metas_decodes_and_calls_ps(
77+
ps_mock: MagicMock, fake_metas: dict[int, MemoryBufferMetaList]
78+
) -> None:
4479
client = TestClient(_init_api(ps_mock))
4580
resp = client.post(
4681
"/v1/checkpoints/my-ckpt/load-metas",
47-
content=pickle.dumps(fake_metas),
48-
headers={"content-type": "application/octet-stream"},
82+
content=_METAS_ADAPTER.dump_json(fake_metas),
83+
headers={"content-type": "application/json"},
4984
)
5085
assert resp.status_code == 200
5186
ps_mock.load_metas.assert_called_once_with(fake_metas)
5287

5388

54-
def test_load_metas_rejects_bad_pickle(ps_mock: MagicMock) -> None:
89+
def test_load_metas_rejects_bad_json(ps_mock: MagicMock) -> None:
90+
client = TestClient(_init_api(ps_mock))
91+
resp = client.post(
92+
"/v1/checkpoints/my-ckpt/load-metas",
93+
content=b"not a valid json",
94+
)
95+
assert resp.status_code == 400
96+
ps_mock.load_metas.assert_not_called()
97+
98+
99+
def test_load_metas_rejects_schema_mismatch(ps_mock: MagicMock) -> None:
100+
"""JSON that parses but doesn't match MemoryBufferMetaList shape -> 400."""
55101
client = TestClient(_init_api(ps_mock))
56102
resp = client.post(
57103
"/v1/checkpoints/my-ckpt/load-metas",
58-
content=b"not a valid pickle",
104+
content=b'{"0": {"foo": "bar"}}',
59105
)
60106
assert resp.status_code == 400
61107
ps_mock.load_metas.assert_not_called()
62108

63109

64-
def test_load_metas_propagates_ps_error(ps_mock: MagicMock, fake_metas: dict) -> None:
110+
def test_load_metas_propagates_ps_error(
111+
ps_mock: MagicMock, fake_metas: dict[int, MemoryBufferMetaList]
112+
) -> None:
65113
ps_mock.load_metas.side_effect = RuntimeError("rdma device mismatch")
66114
client = TestClient(_init_api(ps_mock))
67115
resp = client.post(
68116
"/v1/checkpoints/my-ckpt/load-metas",
69-
content=pickle.dumps(fake_metas),
117+
content=_METAS_ADAPTER.dump_json(fake_metas),
70118
)
71119
assert resp.status_code == 500
72120
assert "rdma device mismatch" in resp.text
73121

74122

75-
def test_round_trip_get_then_load(ps_mock: MagicMock, fake_metas: dict) -> None:
76-
"""Pickle bytes returned by GET /metas must be accepted by POST /load-metas."""
123+
def test_round_trip_get_then_load(
124+
ps_mock: MagicMock, fake_metas: dict[int, MemoryBufferMetaList]
125+
) -> None:
126+
"""JSON bytes returned by GET /metas must be accepted by POST /load-metas."""
77127
client = TestClient(_init_api(ps_mock))
78128
get_resp = client.get("/v1/checkpoints/source/metas")
79129
assert get_resp.status_code == 200

0 commit comments

Comments
 (0)