Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion checkpoint_engine/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@
from fastapi import Request
from fastapi.responses import JSONResponse, Response
from loguru import logger
from pydantic import BaseModel
from pydantic import BaseModel, TypeAdapter, ValidationError

from checkpoint_engine.data_types import MemoryBufferMetaList
from checkpoint_engine.ps import ParameterServer


# Module-level pydantic TypeAdapter for ``dict[int, MemoryBufferMetaList]``.
# The /v1/metas handlers reuse it to serialize metas to JSON (``dump_json``)
# and to parse/validate an incoming JSON body (``validate_json``).
# NOTE(review): the int keys are presumably ranks -- confirm against
# ParameterServer.get_metas().
_METAS_ADAPTER = TypeAdapter(dict[int, MemoryBufferMetaList])


def request_inference_to_update(
url: str,
socket_paths: dict[str, str],
Expand Down Expand Up @@ -79,6 +83,28 @@ async def healthz() -> Response:
async def gather_metas(checkpoint_name: str) -> Response:
return wrap_exception(lambda: ps.gather_metas(checkpoint_name))

@app.get("/v1/metas")
async def get_metas() -> Response:
    """Serve the parameter server's metas as a JSON document.

    Returns:
        A 200 response with media type ``application/json`` whose body is
        the result of ``ps.get_metas()`` dumped through ``_METAS_ADAPTER``.
        If ``ps.get_metas()`` raises, the exception is logged and a 500
        ``JSONResponse`` carrying ``str(exception)`` is returned instead.
    """
    try:
        current_metas = ps.get_metas()
    except Exception as exc:  # noqa: BLE001
        # Any failure while fetching metas is reported to the client as a 500.
        logger.exception("get_metas failed")
        return JSONResponse(content=str(exc), status_code=500)
    else:
        # Serialization happens outside the guarded call, as in the original.
        payload = _METAS_ADAPTER.dump_json(current_metas)
        return Response(content=payload, media_type="application/json")

@app.post("/v1/metas")
async def load_metas(raw: Request) -> Response:
    """Validate a metas JSON body and hand it to the parameter server.

    Args:
        raw: Incoming request; its raw body is parsed with
            ``_METAS_ADAPTER.validate_json``.

    Returns:
        A 400 ``JSONResponse`` carrying the validation error text when the
        body fails validation; otherwise whatever ``wrap_exception`` returns
        for the ``ps.load_metas`` call (``wrap_exception`` is defined outside
        this view).
    """
    payload = await raw.body()
    try:
        parsed_metas = _METAS_ADAPTER.validate_json(payload)
    except ValidationError as exc:
        logger.exception("load_metas json validation failed")
        return JSONResponse(content=str(exc), status_code=400)

    def apply_metas():
        # Deferred so that wrap_exception controls when load_metas runs.
        return ps.load_metas(parsed_metas)

    return wrap_exception(apply_metas)

@app.post("/v1/checkpoints/{checkpoint_name}/update")
async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
def update_func(socket_paths: list[tuple[str, str]]):
Expand Down
40 changes: 32 additions & 8 deletions examples/update.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import argparse
import json
import os
import pickle
import time
from collections import defaultdict
from collections.abc import Callable
Expand All @@ -11,13 +10,18 @@
import httpx
import torch
from loguru import logger
from pydantic import TypeAdapter
from safetensors import safe_open

import checkpoint_engine.distributed as dist
from checkpoint_engine import request_inference_to_update
from checkpoint_engine.data_types import MemoryBufferMetaList
from checkpoint_engine.ps import ParameterServer


# Module-level pydantic TypeAdapter for ``dict[int, MemoryBufferMetaList]``.
# Used in this script to write the gathered metas as JSON and to validate
# metas JSON read back from a file or fetched over HTTP.
_METAS_ADAPTER = TypeAdapter(dict[int, MemoryBufferMetaList])


@contextmanager
def timer(msg: str):
start = time.perf_counter()
Expand Down Expand Up @@ -110,7 +114,7 @@ def update_weights(
ps.gather_metas(checkpoint_name)
if save_metas_file and int(os.getenv("RANK")) == 0:
with open(save_metas_file, "wb") as f:
pickle.dump(ps.get_metas(), f)
f.write(_METAS_ADAPTER.dump_json(ps.get_metas()))

if update_method == "broadcast" or update_method == "all":
with timer("Update weights without setting ranks"):
Expand All @@ -127,15 +131,22 @@ def update_weights(
def join(
ps: ParameterServer,
checkpoint_name: str,
load_metas_file: str,
load_metas_file: str | None,
metas_url: str | None,
req_func: Callable[[list[tuple[str, str]]], None],
inference_parallel_size: int,
endpoint: str,
uds: str | None = None,
):
assert load_metas_file, "load_metas_file is required"
with open(load_metas_file, "rb") as f:
metas = pickle.load(f)
if load_metas_file:
with open(load_metas_file, "rb") as f:
metas = _METAS_ADAPTER.validate_json(f.read())
elif metas_url:
resp = httpx.get(metas_url, timeout=300.0)
resp.raise_for_status()
metas = _METAS_ADAPTER.validate_json(resp.content)
else:
raise ValueError("either load_metas_file or metas_url is required")
ps.init_process_group()
check_vllm_ready(endpoint, inference_parallel_size, uds)
dist.barrier()
Expand All @@ -152,7 +163,19 @@ def join(
parser = argparse.ArgumentParser(description="Update weights example")
parser.add_argument("--checkpoint-path", type=str, default=None)
parser.add_argument("--save-metas-file", type=str, default=None)
parser.add_argument("--load-metas-file", type=str, default=None)
metas_src = parser.add_mutually_exclusive_group()
metas_src.add_argument(
"--load-metas-file",
type=str,
default=None,
help="Path to a metas JSON file (triggers join mode)",
)
metas_src.add_argument(
"--metas-url",
type=str,
default=None,
help="HTTP URL returning a metas JSON (triggers join mode)",
)
parser.add_argument("--sleep-time", type=int, default=0)
parser.add_argument("--endpoint", type=str, default="http://localhost:19730")
parser.add_argument("--inference-parallel-size", type=int, default=8)
Expand All @@ -167,11 +190,12 @@ def join(
req_func = req_inference(args.endpoint, args.inference_parallel_size, args.uds)
dist.use_backend(args.custom_dist)
ps = ParameterServer(auto_pg=True)
if args.load_metas_file:
if args.load_metas_file or args.metas_url:
join(
ps,
args.checkpoint_name,
args.load_metas_file,
args.metas_url,
req_func,
args.inference_parallel_size,
args.endpoint,
Expand Down
135 changes: 135 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""CPU-only tests for the metas endpoints in api.py."""

from unittest.mock import MagicMock

import pytest
import torch
from fastapi.testclient import TestClient
from pydantic import TypeAdapter

from checkpoint_engine.api import _init_api
from checkpoint_engine.data_types import (
MemoryBufferMetaList,
MemoryBufferMetas,
ParameterMeta,
)


# Test-local pydantic TypeAdapter for ``dict[int, MemoryBufferMetaList]``;
# the tests use it to build JSON request bodies and to decode JSON responses.
_METAS_ADAPTER = TypeAdapter(dict[int, MemoryBufferMetaList])


def _make_meta(rdma_device: str, ip: str) -> MemoryBufferMetaList:
    """Build a small ``MemoryBufferMetaList`` for use as test data.

    Args:
        rdma_device: Value stored in the ``rdma_device`` field.
        ip: Host part of ``p2p_store_addr``; the port is fixed at ``12345``.

    Returns:
        A meta list holding one ``MemoryBufferMetas`` entry that describes a
        single float16 parameter named ``"w"`` with shape ``(2, 3)``.
    """
    weight_meta = ParameterMeta(
        name="w",
        dtype=torch.float16,
        shape=torch.Size([2, 3]),
        aligned_size=12,
    )
    buffer_metas = MemoryBufferMetas(
        metas=[weight_meta],
        ptr=0x12345678,
        size=1024,
    )
    return MemoryBufferMetaList(
        p2p_store_addr=f"{ip}:12345",
        rdma_device=rdma_device,
        memory_buffer_metas_list=[buffer_metas],
    )


@pytest.fixture
def fake_metas() -> dict[int, MemoryBufferMetaList]:
    """Provide a two-entry metas mapping (keys 0 and 1) sharing one IP."""
    devices = ("mlx5_0", "mlx5_1")
    return {
        index: _make_meta(device, "192.168.1.1")
        for index, device in enumerate(devices)
    }


@pytest.fixture
def ps_mock(fake_metas: dict[int, MemoryBufferMetaList]) -> MagicMock:
    """Provide a ``MagicMock`` parameter server whose ``get_metas()`` yields ``fake_metas``."""
    server = MagicMock()
    server.get_metas.return_value = fake_metas
    return server


def test_get_metas_returns_json(
    ps_mock: MagicMock, fake_metas: dict[int, MemoryBufferMetaList]
) -> None:
    """GET /v1/metas answers 200 with a JSON body that decodes to the metas."""
    api_client = TestClient(_init_api(ps_mock))

    response = api_client.get("/v1/metas")

    assert response.status_code == 200
    assert response.headers["content-type"] == "application/json"
    decoded = _METAS_ADAPTER.validate_json(response.content)
    assert decoded == fake_metas
    ps_mock.get_metas.assert_called_once_with()


def test_get_metas_propagates_ps_error(ps_mock: MagicMock) -> None:
    """A failure in ``ps.get_metas`` surfaces as a 500 carrying its message."""
    failure = RuntimeError("metas not gathered yet")
    ps_mock.get_metas.side_effect = failure
    api_client = TestClient(_init_api(ps_mock))

    response = api_client.get("/v1/metas")

    assert response.status_code == 500
    assert "metas not gathered yet" in response.text


def test_load_metas_decodes_and_calls_ps(
    ps_mock: MagicMock, fake_metas: dict[int, MemoryBufferMetaList]
) -> None:
    """POST /v1/metas decodes a valid body and forwards it to ``ps.load_metas``."""
    api_client = TestClient(_init_api(ps_mock))
    request_body = _METAS_ADAPTER.dump_json(fake_metas)

    response = api_client.post(
        "/v1/metas",
        content=request_body,
        headers={"content-type": "application/json"},
    )

    assert response.status_code == 200
    ps_mock.load_metas.assert_called_once_with(fake_metas)


def test_load_metas_rejects_bad_json(ps_mock: MagicMock) -> None:
    """A body that is not JSON yields 400 and never reaches ``ps.load_metas``."""
    api_client = TestClient(_init_api(ps_mock))
    malformed_body = b"not a valid json"

    response = api_client.post("/v1/metas", content=malformed_body)

    assert response.status_code == 400
    ps_mock.load_metas.assert_not_called()


def test_load_metas_rejects_schema_mismatch(ps_mock: MagicMock) -> None:
    """Well-formed JSON that does not fit the MemoryBufferMetaList schema yields 400."""
    api_client = TestClient(_init_api(ps_mock))
    wrong_shape_body = b'{"0": {"foo": "bar"}}'

    response = api_client.post("/v1/metas", content=wrong_shape_body)

    assert response.status_code == 400
    ps_mock.load_metas.assert_not_called()


def test_load_metas_propagates_ps_error(
    ps_mock: MagicMock, fake_metas: dict[int, MemoryBufferMetaList]
) -> None:
    """A failure in ``ps.load_metas`` surfaces as a 500 carrying its message."""
    failure = RuntimeError("rdma device mismatch")
    ps_mock.load_metas.side_effect = failure
    api_client = TestClient(_init_api(ps_mock))
    request_body = _METAS_ADAPTER.dump_json(fake_metas)

    response = api_client.post("/v1/metas", content=request_body)

    assert response.status_code == 500
    assert "rdma device mismatch" in response.text


def test_round_trip_get_then_load(
    ps_mock: MagicMock, fake_metas: dict[int, MemoryBufferMetaList]
) -> None:
    """The exact bytes served by GET /v1/metas are accepted by POST /v1/metas."""
    api_client = TestClient(_init_api(ps_mock))

    fetched = api_client.get("/v1/metas")
    assert fetched.status_code == 200

    # Feed the GET body back unchanged to prove both endpoints share a format.
    loaded = api_client.post("/v1/metas", content=fetched.content)
    assert loaded.status_code == 200
    ps_mock.load_metas.assert_called_once_with(fake_metas)
Loading