Skip to content

Commit 77e2d61

Browse files
committed
feat: split __main__.py from ps.py
1 parent e0a6a55 commit 77e2d61

2 files changed

Lines changed: 29 additions & 57 deletions

File tree

checkpoint_engine/__main__.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import argparse
2+
import os
3+
4+
from loguru import logger
5+
6+
from checkpoint_engine.api import _init_api
7+
from checkpoint_engine.ps import ParameterServer
8+
9+
10+
@logger.catch(reraise=True)
11+
def run_from_cli():
12+
import uvicorn
13+
14+
parser = argparse.ArgumentParser(description="Parameter Server")
15+
parser.add_argument("--uds", type=str)
16+
17+
args = parser.parse_args()
18+
logger.info(
19+
f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
20+
)
21+
22+
assert args.uds and len(args.uds) > 0, args.uds
23+
ps = ParameterServer(auto_pg=True)
24+
uvicorn.run(_init_api(ps), uds=args.uds, timeout_keep_alive=60)
25+
26+
27+
if __name__ == "__main__":
28+
run_from_cli()

checkpoint_engine/ps.py

Lines changed: 1 addition & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,17 @@
1-
import argparse
21
import ctypes
32
import os
43
import threading
54
from collections import defaultdict
65
from collections.abc import Callable
76
from datetime import timedelta
8-
from typing import TYPE_CHECKING, Any, BinaryIO
7+
from typing import TYPE_CHECKING
98

10-
import httpx
119
import torch
1210
import torch.distributed as dist
1311
import zmq
1412
from loguru import logger
15-
from pydantic import BaseModel
1613
from torch.multiprocessing.reductions import reduce_tensor
1714

18-
import httpx
1915
from checkpoint_engine.api import _init_api
2016
from checkpoint_engine.data_types import (
2117
BucketRange,
@@ -61,37 +57,6 @@ def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None
6157
raise ValueError(f"fail to get physical gpu id {device_index}") from e
6258

6359

64-
def request_inference_to_update(
65-
url: str,
66-
socket_paths: dict[str, str],
67-
timeout: float = 300.0,
68-
uds: str | None = None,
69-
):
70-
"""Send an inference update request to inference server via HTTP or Unix socket.
71-
72-
Args:
73-
url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
74-
socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights.
75-
timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
76-
uds (str, optional): Path to a Unix domain socket. If provided, the request
77-
will be sent via the Unix socket instead of HTTP. Defaults to None.
78-
79-
Raises:
80-
httpx.HTTPStatusError: If the response contains an HTTP error status.
81-
httpx.RequestError: If there was an issue while making the request.
82-
"""
83-
resp = httpx.Client(transport=httpx.HTTPTransport(uds=uds)).post(
84-
url,
85-
json={
86-
"method": "update_weights_from_ipc",
87-
"args": [socket_paths],
88-
"timeout": timeout,
89-
},
90-
timeout=timeout,
91-
)
92-
resp.raise_for_status()
93-
94-
9560
def _gen_h2d_buckets(
9661
global_metas: dict[int, MemoryBufferMetaList],
9762
bucket_size: int,
@@ -888,24 +853,3 @@ def _update_per_bucket(
888853
self._p2p_store.unregister_named_tensors([h2d_buffer_name])
889854

890855
self.device_manager.device_module.empty_cache()
891-
892-
893-
@logger.catch(reraise=True)
894-
def run_from_cli():
895-
import uvicorn
896-
897-
parser = argparse.ArgumentParser(description="Parameter Server")
898-
parser.add_argument("--uds", type=str)
899-
900-
args = parser.parse_args()
901-
logger.info(
902-
f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
903-
)
904-
905-
assert args.uds and len(args.uds) > 0, args.uds
906-
ps = ParameterServer(auto_pg=True)
907-
uvicorn.run(_init_api(ps), uds=args.uds, timeout_keep_alive=60)
908-
909-
910-
if __name__ == "__main__":
911-
run_from_cli()

0 commit comments

Comments
 (0)