"""CPU-only tests for the metas endpoints in api.py."""
22
3- import pickle
43from unittest .mock import MagicMock
54
65import pytest
6+ import torch
77from fastapi .testclient import TestClient
8+ from pydantic import TypeAdapter
89
910from checkpoint_engine .api import _init_api
11+ from checkpoint_engine .data_types import (
12+ MemoryBufferMetaList ,
13+ MemoryBufferMetas ,
14+ ParameterMeta ,
15+ )
16+
17+
# Shared (de)serializer for the ``dict[int, MemoryBufferMetaList]`` payload:
# the tests use it to encode request bodies and to decode response bodies as JSON.
_METAS_ADAPTER = TypeAdapter(dict[int, MemoryBufferMetaList])
19+
20+
def _make_meta(rdma_device: str, ip: str) -> MemoryBufferMetaList:
    """Build a minimal ``MemoryBufferMetaList`` for use as test data.

    Args:
        rdma_device: Value stored in the ``rdma_device`` field.
        ip: Host part of ``p2p_store_addr``; the port is fixed to 12345.

    Returns:
        A meta list with one buffer (``ptr=0x12345678``, ``size=1024``) that
        describes a single float16 parameter named ``"w"`` with shape (2, 3).
    """
    return MemoryBufferMetaList(
        # Host and port are joined without whitespace: "<ip>:12345".
        p2p_store_addr=f"{ip}:12345",
        rdma_device=rdma_device,
        memory_buffer_metas_list=[
            MemoryBufferMetas(
                metas=[
                    ParameterMeta(
                        name="w",
                        dtype=torch.float16,
                        shape=torch.Size([2, 3]),
                        aligned_size=12,
                    )
                ],
                ptr=0x12345678,
                size=1024,
            )
        ],
    )
1040
1141
@pytest.fixture
def fake_metas() -> dict[int, MemoryBufferMetaList]:
    """Return a two-entry mapping (keys 0 and 1) of typed metas.

    Both entries share the same IP and differ only in their RDMA device name.
    """
    return {
        0: _make_meta("mlx5_0", "192.168.1.1"),
        1: _make_meta("mlx5_1", "192.168.1.1"),
    }
1748
1849
@pytest.fixture
def ps_mock(fake_metas: dict[int, MemoryBufferMetaList]) -> MagicMock:
    """Return a ``MagicMock`` whose ``get_metas()`` returns ``fake_metas``.

    The tests pass this mock to ``_init_api`` in place of a real server object.
    """
    ps = MagicMock()
    ps.get_metas.return_value = fake_metas
    return ps
2455
2556
def test_get_metas_returns_json(
    ps_mock: MagicMock, fake_metas: dict[int, MemoryBufferMetaList]
) -> None:
    """GET /metas answers 200 with a JSON body that decodes back to the fixture."""
    client = TestClient(_init_api(ps_mock))
    resp = client.get("/v1/checkpoints/my-ckpt/metas")
    assert resp.status_code == 200
    # Exact media type, no trailing whitespace.
    assert resp.headers["content-type"] == "application/json"
    # Decode through the typed adapter so the comparison is model-to-model.
    assert _METAS_ADAPTER.validate_json(resp.content) == fake_metas
    ps_mock.get_metas.assert_called_once_with()
3366
3467
@@ -40,40 +73,57 @@ def test_get_metas_propagates_ps_error(ps_mock: MagicMock) -> None:
4073 assert "metas not gathered yet" in resp .text
4174
4275
def test_load_metas_decodes_and_calls_ps(
    ps_mock: MagicMock, fake_metas: dict[int, MemoryBufferMetaList]
) -> None:
    """POST /load-metas with a valid JSON body calls ``load_metas`` with the decoded metas."""
    client = TestClient(_init_api(ps_mock))
    resp = client.post(
        "/v1/checkpoints/my-ckpt/load-metas",
        content=_METAS_ADAPTER.dump_json(fake_metas),
        headers={"content-type": "application/json"},
    )
    assert resp.status_code == 200
    # The mock must receive a value equal to the original fixture.
    ps_mock.load_metas.assert_called_once_with(fake_metas)
5287
5388
def test_load_metas_rejects_bad_json(ps_mock: MagicMock) -> None:
    """A body that is not parseable JSON -> 400, and ``load_metas`` is never called."""
    client = TestClient(_init_api(ps_mock))
    resp = client.post(
        "/v1/checkpoints/my-ckpt/load-metas",
        content=b"not a valid json",
    )
    assert resp.status_code == 400
    ps_mock.load_metas.assert_not_called()
97+
98+
def test_load_metas_rejects_schema_mismatch(ps_mock: MagicMock) -> None:
    """JSON that parses but doesn't match MemoryBufferMetaList shape -> 400."""
    client = TestClient(_init_api(ps_mock))
    resp = client.post(
        "/v1/checkpoints/my-ckpt/load-metas",
        # Well-formed JSON, but the value under "0" is not a valid meta list.
        content=b'{"0": {"foo": "bar"}}',
    )
    assert resp.status_code == 400
    ps_mock.load_metas.assert_not_called()
62108
63109
def test_load_metas_propagates_ps_error(
    ps_mock: MagicMock, fake_metas: dict[int, MemoryBufferMetaList]
) -> None:
    """An exception raised by ``load_metas`` surfaces as a 500 carrying its message."""
    ps_mock.load_metas.side_effect = RuntimeError("rdma device mismatch")
    client = TestClient(_init_api(ps_mock))
    resp = client.post(
        "/v1/checkpoints/my-ckpt/load-metas",
        # Valid payload, so the failure comes from the mocked call itself.
        content=_METAS_ADAPTER.dump_json(fake_metas),
    )
    assert resp.status_code == 500
    assert "rdma device mismatch" in resp.text
73121
74122
75- def test_round_trip_get_then_load (ps_mock : MagicMock , fake_metas : dict ) -> None :
76- """Pickle bytes returned by GET /metas must be accepted by POST /load-metas."""
123+ def test_round_trip_get_then_load (
124+ ps_mock : MagicMock , fake_metas : dict [int , MemoryBufferMetaList ]
125+ ) -> None :
126+ """JSON bytes returned by GET /metas must be accepted by POST /load-metas."""
77127 client = TestClient (_init_api (ps_mock ))
78128 get_resp = client .get ("/v1/checkpoints/source/metas" )
79129 assert get_resp .status_code == 200
0 commit comments