Skip to content

Commit 960b2ac

Browse files
committed
fix: Target server exits automatically
1 parent b0f94b4 commit 960b2ac

4 files changed

Lines changed: 245 additions & 10 deletions

File tree

docs/advanced_features/remote_training.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,19 @@ Request processing flow:
5050
- Retrieves model configuration (hidden_size, vocab_size, etc.) via POST `/setup`
5151
- Each training step sends a request via POST `/generate` and receives results via NCCL recv
5252
- Supports TP>1 training: only rank 0 sends requests, results are broadcast to other ranks
53+
- After the first successful connection, a background heartbeat is started; on `close()`, a best-effort `/disconnect` is sent
54+
55+
### Client Lifecycle and Automatic Exit
56+
57+
The target server tracks client activity and automatically shuts down after the client exits, preventing leftover GPU-occupying server processes after training completes:
58+
59+
- After the client's first successful request or successful NCCL initialization, a background heartbeat thread is started, sending POST `/heartbeat` every 15 seconds by default
60+
- When the client exits normally, it sends a best-effort POST `/disconnect`; upon receiving it, the server immediately triggers shutdown
61+
- When the client exits abnormally, the server watchdog triggers shutdown once no client activity has been seen for longer than `--client-heartbeat-timeout` (default 60 seconds)
62+
- The server only counts actual client API calls as active requests; `GET /health` and unrelated POSTs do not renew the watchdog timer
63+
- `--client-heartbeat-timeout 0` disables the server-side timeout watchdog, but `/disconnect` will still trigger automatic shutdown
64+
65+
Since NCCL transport does not support safe disconnect and reconnect within the same server process, it is recommended to treat each target server process as a resource for a single training session: it automatically exits after training completes or the client disconnects, and a new instance is started for the next training run.
5366

5467
### NCCL Transport
5568

@@ -198,6 +211,7 @@ export NCCL_IB_GID_INDEX=3 # RoCE GID index
198211
| `SPECFORGE_TOPK` | `0` | Server-side target_p top-k compression (`0` = full distribution) |
199212
| `SPECFORGE_TARGET_DTYPE` | `fp32` | target_p computation precision |
200213
| `SPECFORGE_GPU_ID` | auto | Specify GPU device ID |
214+
| `SPECFORGE_HEARTBEAT_INTERVAL` | `15` | Client heartbeat send interval (seconds; `<=0` means the heartbeat thread is not started) |
201215

202216
## 📊 Benchmark Results
203217

scripts/launch_target_server.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@ def parse_args():
105105
default=None,
106106
help="NCCL TCP rendezvous port for GPU-to-GPU data transfer (default: HTTP port + 100)",
107107
)
108+
parser.add_argument(
109+
"--client-heartbeat-timeout",
110+
type=float,
111+
default=60.0,
112+
help="Seconds of client inactivity before treating as disconnected (0 to disable). Default: 60",
113+
)
108114
return parser.parse_args()
109115

110116

@@ -154,6 +160,7 @@ def main():
154160
nccl_port=args.nccl_port if args.nccl_port else args.port + 100,
155161
host=args.host,
156162
attention_backend=args.attention_backend,
163+
client_heartbeat_timeout=args.client_heartbeat_timeout,
157164
)
158165
server_app.load_model()
159166
logger.info("Model loaded successfully.")
@@ -169,16 +176,26 @@ def main():
169176
args.port,
170177
args.mode,
171178
)
179+
logger.info(
180+
"Client disconnect will shut down server (heartbeat timeout: %.0fs)",
181+
args.client_heartbeat_timeout,
182+
)
172183

173-
def shutdown(signum, frame):
174-
logger.info("Received signal %d, shutting down...", signum)
184+
def request_http_shutdown():
175185
# HTTPServer.shutdown() must be called from a different thread than
176186
# serve_forever(); signal handlers run on the main thread.
177187
threading.Thread(target=httpd.shutdown, daemon=True).start()
178188

189+
def shutdown(signum, frame):
190+
logger.info("Received signal %d, shutting down...", signum)
191+
request_http_shutdown()
192+
179193
signal.signal(signal.SIGINT, shutdown)
180194
signal.signal(signal.SIGTERM, shutdown)
181195

196+
# Register shutdown callback for client disconnect handling
197+
server_app._shutdown_callback = request_http_shutdown
198+
182199
try:
183200
httpd.serve_forever()
184201
except KeyboardInterrupt:

specforge/modeling/target/remote_target_client.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,14 @@ def __init__(
250250
self._nccl_transport: Optional[NCCLTransport] = None
251251
self._nccl_init_attempted = False
252252
self._nccl_init_lock = threading.Lock()
253+
# Heartbeat thread to keep server aware of client liveness
254+
self._heartbeat_interval = float(
255+
os.environ.get("SPECFORGE_HEARTBEAT_INTERVAL", "15")
256+
)
257+
self._heartbeat_stop = threading.Event()
258+
self._heartbeat_thread: Optional[threading.Thread] = None
259+
self._lifecycle_lock = threading.Lock()
260+
self._closed = False
253261
atexit.register(self.close)
254262

255263
def _get_nccl_port(self) -> int:
@@ -354,6 +362,8 @@ def _client_init():
354362
return False
355363

356364
logger.info("NCCL transport established successfully")
365+
# Start heartbeat to keep server aware of client liveness
366+
self._start_heartbeat()
357367
return True
358368

359369
def _request(self, endpoint: str, payload: bytes) -> bytes:
@@ -370,6 +380,7 @@ def _request(self, endpoint: str, payload: bytes) -> bytes:
370380
headers={"Content-Type": "application/octet-stream"},
371381
)
372382
resp.raise_for_status()
383+
self._start_heartbeat()
373384
return resp.content
374385
except (requests.ConnectionError, requests.Timeout) as exc:
375386
last_exc = exc
@@ -413,6 +424,7 @@ def _request_transport(
413424
url, data=payload, timeout=self.timeout, headers=headers
414425
)
415426
resp.raise_for_status()
427+
self._start_heartbeat()
416428

417429
nccl_used = resp.headers.get(NCCL_HEADER) == "1"
418430

@@ -442,11 +454,69 @@ def _request_transport(
442454
) from exc
443455

444456
def close(self):
    """Release the client's resources; safe to call more than once.

    Teardown order: stop the heartbeat thread, send a best-effort
    ``/disconnect`` to the server, destroy the NCCL transport (if one was
    created), then close the HTTP session. The session is closed even when
    destroying the NCCL transport raises; that exception still propagates.
    """
    # Only the first caller proceeds: close() is also registered with
    # atexit and invoked from __del__, so repeated calls are expected.
    with self._lifecycle_lock:
        if self._closed:
            return
        self._closed = True
    # Must run outside the lock: _stop_heartbeat() acquires the same
    # non-reentrant lock.
    self._stop_heartbeat()
    # Notify server of disconnect (best-effort)
    self._notify_disconnect()
    try:
        # Destroy NCCL transport
        if self._nccl_transport is not None:
            self._nccl_transport.destroy()
            self._nccl_transport = None
    finally:
        # _closed is already True, so nothing would retry this later;
        # close the session now even if NCCL teardown failed.
        self._session.close()
449470

471+
def _notify_disconnect(self):
    """Send disconnect notification to the server (best-effort).

    Issues an empty POST to ``{self.url}/disconnect`` with a 5-second
    timeout. A failed request is logged at debug level and otherwise
    ignored, so request errors never propagate to the caller.
    """
    try:
        self._session.post(
            f"{self.url}/disconnect",
            data=b"",
            timeout=5,
            headers={"Content-Type": "application/octet-stream"},
        )
    except Exception as exc:
        # Best-effort: server may already be gone, network may be down.
        # Keep the reason visible for debugging instead of dropping it.
        logger.debug("Disconnect notification failed: %s", exc)
483+
484+
def _start_heartbeat(self):
    """Start the background heartbeat thread if one is not already running.

    No-op when the client is closed, when a heartbeat thread is still
    alive, or when the configured interval is ``<= 0`` (heartbeat
    disabled). A thread that has already exited (the loop ends on its own
    when heartbeat POSTs fail) is replaced, so the next successful request
    revives the heartbeat instead of leaving it off for the rest of the
    session.
    """
    with self._lifecycle_lock:
        if self._closed:
            return
        current = self._heartbeat_thread
        # A non-None but dead thread must not block a restart.
        if current is not None and current.is_alive():
            return
        if self._heartbeat_interval <= 0:
            return
        self._heartbeat_stop.clear()
        self._heartbeat_thread = threading.Thread(
            target=self._heartbeat_loop, daemon=True
        )
        self._heartbeat_thread.start()
496+
497+
def _stop_heartbeat(self):
    """Signal the heartbeat thread to exit and wait briefly for it.

    The stop event is set and the thread reference is detached while
    holding the lifecycle lock; the thread is then joined for at most
    3 seconds.
    """
    with self._lifecycle_lock:
        self._heartbeat_stop.set()
        # Detach the reference and clear the attribute in one step.
        worker, self._heartbeat_thread = self._heartbeat_thread, None
    if worker is None:
        return
    worker.join(timeout=3)
505+
506+
def _heartbeat_loop(self):
    """Periodically ping the server to indicate liveness.

    Target of the heartbeat thread. Every ``self._heartbeat_interval``
    seconds it sends an empty POST to ``{self.url}/heartbeat`` (5-second
    timeout) until the stop event is set. An isolated failure, such as a
    timeout while the server is busy, is tolerated; the loop gives up only
    after several consecutive failures, at which point the server is
    assumed to be gone.
    """
    # Consecutive failed pings after which the server is treated as gone.
    max_consecutive_failures = 3
    failures = 0
    while not self._heartbeat_stop.wait(self._heartbeat_interval):
        try:
            self._session.post(
                f"{self.url}/heartbeat",
                data=b"",
                timeout=5,
                headers={"Content-Type": "application/octet-stream"},
            )
        except Exception:
            failures += 1
            if failures >= max_consecutive_failures:
                # Server may be gone — stop sending heartbeats
                break
        else:
            # A successful ping resets the failure streak.
            failures = 0
519+
450520
def __del__(self):
451521
try:
452522
self.close()

0 commit comments

Comments
 (0)