diff --git a/checkpoint_engine/api.py b/checkpoint_engine/api.py
index e61b41d..4f6119a 100644
--- a/checkpoint_engine/api.py
+++ b/checkpoint_engine/api.py
@@ -6,11 +6,16 @@
 from fastapi import Request
 from fastapi.responses import JSONResponse, Response
 from loguru import logger
-from pydantic import BaseModel
+from pydantic import BaseModel, TypeAdapter, ValidationError
 
+from checkpoint_engine.data_types import MemoryBufferMetaList
 from checkpoint_engine.ps import ParameterServer
 
 
+# Built once at import time; the /v1/metas handlers reuse it for JSON (de)serialization.
+_METAS_ADAPTER = TypeAdapter(dict[int, MemoryBufferMetaList])
+
+
 def request_inference_to_update(
     url: str,
     socket_paths: dict[str, str],
@@ -79,6 +84,36 @@ async def healthz() -> Response:
     async def gather_metas(checkpoint_name: str) -> Response:
         return wrap_exception(lambda: ps.gather_metas(checkpoint_name))
 
+    @app.get("/v1/metas")
+    async def get_metas() -> Response:
+        """Return the metas held by the parameter server as a JSON document.
+
+        Responds with 500 and the error text if fetching or serializing fails.
+        """
+        try:
+            # Serialize inside the try so an encoding failure is logged and
+            # reported like any other error instead of escaping the handler.
+            payload = _METAS_ADAPTER.dump_json(ps.get_metas())
+        except Exception as e:  # noqa: BLE001
+            logger.exception("get_metas failed")
+            return JSONResponse(content=str(e), status_code=500)
+        return Response(content=payload, media_type="application/json")
+
+    @app.post("/v1/metas")
+    async def load_metas(raw: Request) -> Response:
+        """Validate a JSON metas document and hand it to the parameter server.
+
+        Responds with 400 when the body fails validation; errors raised by
+        ``ps.load_metas`` are handled by ``wrap_exception``.
+        """
+        body = await raw.body()
+        try:
+            metas = _METAS_ADAPTER.validate_json(body)
+        except ValidationError as e:
+            logger.exception("load_metas json validation failed")
+            return JSONResponse(content=str(e), status_code=400)
+        return wrap_exception(lambda: ps.load_metas(metas))
+
     @app.post("/v1/checkpoints/{checkpoint_name}/update")
     async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
         def update_func(socket_paths: list[tuple[str, str]]):
diff --git a/examples/update.py b/examples/update.py
index d331a12..a792280 100644
--- a/examples/update.py
+++ b/examples/update.py
@@ -1,7 +1,6 @@
 import argparse
 import json
 import os
-import pickle
 import time
 from collections import defaultdict
 from collections.abc import Callable
@@ -11,13 +10,18 @@
 import httpx
 import torch
 from loguru import logger
+from pydantic import TypeAdapter
 from safetensors import safe_open
 
 import checkpoint_engine.distributed as dist
 from checkpoint_engine import request_inference_to_update
+from checkpoint_engine.data_types import MemoryBufferMetaList
 from checkpoint_engine.ps import ParameterServer
 
 
+_METAS_ADAPTER = TypeAdapter(dict[int, MemoryBufferMetaList])
+
+
 @contextmanager
 def timer(msg: str):
     start = time.perf_counter()
@@ -110,7 +114,7 @@ def update_weights(
         ps.gather_metas(checkpoint_name)
     if save_metas_file and int(os.getenv("RANK")) == 0:
         with open(save_metas_file, "wb") as f:
-            pickle.dump(ps.get_metas(), f)
+            f.write(_METAS_ADAPTER.dump_json(ps.get_metas()))
 
     if update_method == "broadcast" or update_method == "all":
         with timer("Update weights without setting ranks"):
@@ -127,15 +131,23 @@ def update_weights(
 def join(
     ps: ParameterServer,
     checkpoint_name: str,
-    load_metas_file: str,
+    load_metas_file: str | None,
+    metas_url: str | None,
     req_func: Callable[[list[tuple[str, str]]], None],
     inference_parallel_size: int,
     endpoint: str,
     uds: str | None = None,
 ):
-    assert load_metas_file, "load_metas_file is required"
-    with open(load_metas_file, "rb") as f:
-        metas = pickle.load(f)
+    # Metas are JSON, so they can be read from a local file or fetched over HTTP.
+    if load_metas_file:
+        with open(load_metas_file, "rb") as f:
+            metas = _METAS_ADAPTER.validate_json(f.read())
+    elif metas_url:
+        resp = httpx.get(metas_url, timeout=300.0)
+        resp.raise_for_status()
+        metas = _METAS_ADAPTER.validate_json(resp.content)
+    else:
+        raise ValueError("either load_metas_file or metas_url is required")
     ps.init_process_group()
     check_vllm_ready(endpoint, inference_parallel_size, uds)
     dist.barrier()
@@ -152,7 +164,19 @@ def join(
     parser = argparse.ArgumentParser(description="Update weights example")
     parser.add_argument("--checkpoint-path", type=str, default=None)
     parser.add_argument("--save-metas-file", type=str, default=None)
-    parser.add_argument("--load-metas-file", type=str, default=None)
+    metas_src = parser.add_mutually_exclusive_group()
+    metas_src.add_argument(
+        "--load-metas-file",
+        type=str,
+        default=None,
+        help="Path to a metas JSON file (triggers join mode)",
+    )
+    metas_src.add_argument(
+        "--metas-url",
+        type=str,
+        default=None,
+        help="HTTP URL returning a metas JSON (triggers join mode)",
+    )
     parser.add_argument("--sleep-time", type=int, default=0)
     parser.add_argument("--endpoint", type=str, default="http://localhost:19730")
     parser.add_argument("--inference-parallel-size", type=int, default=8)
@@ -167,11 +191,12 @@ def join(
     req_func = req_inference(args.endpoint, args.inference_parallel_size, args.uds)
     dist.use_backend(args.custom_dist)
     ps = ParameterServer(auto_pg=True)
-    if args.load_metas_file:
+    if args.load_metas_file or args.metas_url:
         join(
             ps,
             args.checkpoint_name,
             args.load_metas_file,
+            args.metas_url,
             req_func,
             args.inference_parallel_size,
             args.endpoint,
diff --git a/tests/test_api.py b/tests/test_api.py
new file mode 100644
index 0000000..ce850aa
--- /dev/null
+++ b/tests/test_api.py
@@ -0,0 +1,135 @@
+"""CPU-only tests for the metas endpoints in api.py."""
+
+from unittest.mock import MagicMock
+
+import pytest
+import torch
+from fastapi.testclient import TestClient
+from pydantic import TypeAdapter
+
+from checkpoint_engine.api import _init_api
+from checkpoint_engine.data_types import (
+    MemoryBufferMetaList,
+    MemoryBufferMetas,
+    ParameterMeta,
+)
+
+
+_METAS_ADAPTER = TypeAdapter(dict[int, MemoryBufferMetaList])
+
+
+def _make_meta(rdma_device: str, ip: str) -> MemoryBufferMetaList:
+    return MemoryBufferMetaList(
+        p2p_store_addr=f"{ip}:12345",
+        rdma_device=rdma_device,
+        memory_buffer_metas_list=[
+            MemoryBufferMetas(
+                metas=[
+                    ParameterMeta(
+                        name="w",
+                        dtype=torch.float16,
+                        shape=torch.Size([2, 3]),
+                        aligned_size=12,
+                    )
+                ],
+                ptr=0x12345678,
+                size=1024,
+            )
+        ],
+    )
+
+
+@pytest.fixture
+def fake_metas() -> dict[int, MemoryBufferMetaList]:
+    return {
+        0: _make_meta("mlx5_0", "192.168.1.1"),
+        1: _make_meta("mlx5_1", "192.168.1.1"),
+    }
+
+
+@pytest.fixture
+def ps_mock(fake_metas: dict[int, MemoryBufferMetaList]) -> MagicMock:
+    ps = MagicMock()
+    ps.get_metas.return_value = fake_metas
+    return ps
+
+
+def test_get_metas_returns_json(
+    ps_mock: MagicMock, fake_metas: dict[int, MemoryBufferMetaList]
+) -> None:
+    client = TestClient(_init_api(ps_mock))
+    resp = client.get("/v1/metas")
+    assert resp.status_code == 200
+    assert resp.headers["content-type"] == "application/json"
+    assert _METAS_ADAPTER.validate_json(resp.content) == fake_metas
+    ps_mock.get_metas.assert_called_once_with()
+
+
+def test_get_metas_propagates_ps_error(ps_mock: MagicMock) -> None:
+    ps_mock.get_metas.side_effect = RuntimeError("metas not gathered yet")
+    client = TestClient(_init_api(ps_mock))
+    resp = client.get("/v1/metas")
+    assert resp.status_code == 500
+    assert "metas not gathered yet" in resp.text
+
+
+def test_load_metas_decodes_and_calls_ps(
+    ps_mock: MagicMock, fake_metas: dict[int, MemoryBufferMetaList]
+) -> None:
+    client = TestClient(_init_api(ps_mock))
+    resp = client.post(
+        "/v1/metas",
+        content=_METAS_ADAPTER.dump_json(fake_metas),
+        headers={"content-type": "application/json"},
+    )
+    assert resp.status_code == 200
+    ps_mock.load_metas.assert_called_once_with(fake_metas)
+
+
+def test_load_metas_rejects_bad_json(ps_mock: MagicMock) -> None:
+    client = TestClient(_init_api(ps_mock))
+    resp = client.post(
+        "/v1/metas",
+        content=b"not a valid json",
+    )
+    assert resp.status_code == 400
+    ps_mock.load_metas.assert_not_called()
+
+
+def test_load_metas_rejects_schema_mismatch(ps_mock: MagicMock) -> None:
+    """JSON that parses but doesn't match MemoryBufferMetaList shape -> 400."""
+    client = TestClient(_init_api(ps_mock))
+    resp = client.post(
+        "/v1/metas",
+        content=b'{"0": {"foo": "bar"}}',
+    )
+    assert resp.status_code == 400
+    ps_mock.load_metas.assert_not_called()
+
+
+def test_load_metas_propagates_ps_error(
+    ps_mock: MagicMock, fake_metas: dict[int, MemoryBufferMetaList]
+) -> None:
+    ps_mock.load_metas.side_effect = RuntimeError("rdma device mismatch")
+    client = TestClient(_init_api(ps_mock))
+    resp = client.post(
+        "/v1/metas",
+        content=_METAS_ADAPTER.dump_json(fake_metas),
+    )
+    assert resp.status_code == 500
+    assert "rdma device mismatch" in resp.text
+
+
+def test_round_trip_get_then_load(
+    ps_mock: MagicMock, fake_metas: dict[int, MemoryBufferMetaList]
+) -> None:
+    """JSON bytes returned by GET /v1/metas must be accepted by POST /v1/metas."""
+    client = TestClient(_init_api(ps_mock))
+    get_resp = client.get("/v1/metas")
+    assert get_resp.status_code == 200
+    load_resp = client.post(
+        "/v1/metas",
+        content=get_resp.content,
+    )
+    assert load_resp.status_code == 200
+    ps_mock.load_metas.assert_called_once_with(fake_metas)