|
1 | | -import argparse |
2 | 1 | import ctypes |
3 | 2 | import os |
4 | 3 | import threading |
5 | 4 | from collections import defaultdict |
6 | 5 | from collections.abc import Callable |
7 | 6 | from datetime import timedelta |
8 | | -from typing import TYPE_CHECKING, Any, BinaryIO |
| 7 | +from typing import TYPE_CHECKING |
9 | 8 |
|
10 | | -import httpx |
11 | 9 | import torch |
12 | 10 | import torch.distributed as dist |
13 | 11 | import zmq |
14 | 12 | from loguru import logger |
15 | | -from pydantic import BaseModel |
16 | 13 | from torch.multiprocessing.reductions import reduce_tensor |
17 | 14 |
|
18 | | -import httpx |
19 | 15 | from checkpoint_engine.api import _init_api |
20 | 16 | from checkpoint_engine.data_types import ( |
21 | 17 | BucketRange, |
@@ -61,37 +57,6 @@ def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None |
61 | 57 | raise ValueError(f"fail to get physical gpu id {device_index}") from e |
62 | 58 |
|
63 | 59 |
|
def request_inference_to_update(
    url: str,
    socket_paths: dict[str, str],
    timeout: float = 300.0,
    uds: str | None = None,
) -> None:
    """Ask the inference server to update its weights from IPC sockets.

    A JSON body with ``method="update_weights_from_ipc"`` is POSTed to ``url``.
    The same ``timeout`` value is used both as the HTTP client timeout and as
    the ``"timeout"`` field inside the JSON payload.

    Args:
        url (str): The HTTP URL or request path (e.g.,
            "http://localhost:19730/inference") to send the request to.
        socket_paths (dict[str, str]): A dictionary containing device uuid and
            IPC socket paths for updating weights.
        timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
        uds (str, optional): Path to a Unix domain socket. If provided, the
            request will be sent via the Unix socket instead of HTTP.
            Defaults to None.

    Raises:
        httpx.HTTPStatusError: If the response contains an HTTP error status.
        httpx.RequestError: If there was an issue while making the request.
    """
    payload = {
        "method": "update_weights_from_ipc",
        "args": [socket_paths],
        "timeout": timeout,
    }
    # Use the client as a context manager so its connection pool (and the
    # underlying Unix/TCP socket) is released even when the request fails.
    with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client:
        resp = client.post(url, json=payload, timeout=timeout)
        resp.raise_for_status()
93 | | - |
94 | | - |
95 | 60 | def _gen_h2d_buckets( |
96 | 61 | global_metas: dict[int, MemoryBufferMetaList], |
97 | 62 | bucket_size: int, |
@@ -888,24 +853,3 @@ def _update_per_bucket( |
888 | 853 | self._p2p_store.unregister_named_tensors([h2d_buffer_name]) |
889 | 854 |
|
890 | 855 | self.device_manager.device_module.empty_cache() |
891 | | - |
892 | | - |
@logger.catch(reraise=True)
def run_from_cli():
    """Command-line entry point: serve a ParameterServer API over a Unix socket.

    Parses the ``--uds`` option, logs the parsed arguments together with the
    ``MASTER_ADDR`` / ``MASTER_PORT`` environment variables, builds a
    ``ParameterServer(auto_pg=True)`` and serves the app returned by
    ``_init_api`` with uvicorn on the given Unix domain socket.

    Exits through ``argparse`` (usage message, exit status 2) when ``--uds``
    is missing or empty.
    """
    # Imported inside the function, as in the original code — presumably so
    # that importing this module does not require uvicorn (TODO confirm).
    import uvicorn

    parser = argparse.ArgumentParser(description="Parameter Server")
    parser.add_argument("--uds", type=str)

    args = parser.parse_args()
    logger.info(
        f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
    )

    # Explicit validation instead of `assert`: assertions are stripped under
    # `python -O`, which would let an empty/missing socket path reach uvicorn.
    if not args.uds:
        parser.error("--uds must be a non-empty Unix domain socket path")

    ps = ParameterServer(auto_pg=True)
    uvicorn.run(_init_api(ps), uds=args.uds, timeout_keep_alive=60)


if __name__ == "__main__":
    run_from_cli()
0 commit comments