Skip to content

Commit d2d75cc

Browse files
committed
heartbeat
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
1 parent 2c0561e commit d2d75cc

7 files changed

Lines changed: 304 additions & 28 deletions

File tree

tensorrt_llm/executor/base_worker.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import gc
1818
import json
1919
import os
20+
import time
2021
import weakref
2122
from pathlib import Path
2223
from queue import Queue
@@ -50,7 +51,7 @@
5051
from .result import (GenerationResult, LogProbsResult, ResponseWrapper,
5152
compute_logprobs, get_metrics_dict)
5253
from .utils import (ErrorResponse, IntraProcessQueue, RequestError,
53-
is_llm_response)
54+
WorkerHeartbeat, is_llm_response)
5455

5556
if TYPE_CHECKING:
5657
from ..disaggregated_params import DisaggregatedParams
@@ -118,6 +119,7 @@ def __init__(
118119

119120
self.engine = None
120121
self.result_queue: Optional[IpcQueue] = None
122+
self.heartbeat_queue: Optional[IpcQueue] = None
121123
self.postproc_queues: Optional[List[IpcQueue]] = None
122124
self.rank = mpi_rank()
123125
self.global_rank = global_mpi_rank()
@@ -345,6 +347,10 @@ def set_result_queue(self, queue):
345347
assert self.postproc_queues is None
346348
self.result_queue = queue
347349

350+
def set_heartbeat_queue(self, queue):
351+
"""Set the IPC queue used to send worker liveness heartbeats to the proxy."""
352+
self.heartbeat_queue = queue
353+
348354
def set_postproc_queues(self, queues: List["IpcQueue"]):
349355
""" Set the IPC queues for feeding post-processing processes. """
350356
assert self.result_queue is None
@@ -904,6 +910,11 @@ def __init__(self, worker: "BaseWorker"):
904910
self.enable_postprocprocess_parallel = self.worker.enable_postprocess_parallel
905911
# The error responses when submit request failed will be put here
906912
self.temp_error_responses = Queue()
913+
self._heartbeat_interval_secs = float(
914+
os.environ.get("TLLM_EXECUTOR_HEARTBEAT_INTERVAL_SECS", "1"))
915+
self._last_heartbeat_time = 0.0
916+
self._heartbeat_pid = os.getpid()
917+
self._heartbeat_rank = mpi_rank()
907918

908919
def responses_handler(self, responses: List[tllm.Response]):
909920
HandlerKind = AwaitResponseHelper.HandlerKind
@@ -971,8 +982,21 @@ def __call__(self, timeout: Optional[float] = None) -> bool:
971982
error = getattr(self.worker.engine, "_event_loop_error", None)
972983
if error is not None:
973984
return self._broadcast_event_loop_error(error)
985+
self._send_heartbeat()
974986
return True
975987

988+
def _send_heartbeat(self) -> None:
989+
heartbeat_queue = self.worker.heartbeat_queue
990+
if heartbeat_queue is None:
991+
return
992+
now = time.monotonic()
993+
if (now - self._last_heartbeat_time) < self._heartbeat_interval_secs:
994+
return
995+
self._last_heartbeat_time = now
996+
heartbeat_queue.put_noblock(WorkerHeartbeat(pid=self._heartbeat_pid,
997+
rank=self._heartbeat_rank),
998+
retry=1)
999+
9761000
def _broadcast_event_loop_error(self, error: BaseException) -> bool:
9771001
"""Wake every pending ``GenerationResult`` after an event-loop crash.
9781002

tensorrt_llm/executor/heartbeat.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
def worker_heartbeat_timed_out(
18+
*,
19+
has_inflight_requests: bool,
20+
now: float,
21+
last_heartbeat_time: float,
22+
timeout_secs: float,
23+
) -> bool:
24+
"""Return whether an in-flight worker has exceeded its heartbeat timeout."""
25+
return has_inflight_requests and (now - last_heartbeat_time) > timeout_secs

tensorrt_llm/executor/proxy.py

Lines changed: 76 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,17 @@
3636
from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
3737
enable_llm_debug, logger_debug, print_colored)
3838
from .executor import GenerationExecutor
39+
from .heartbeat import worker_heartbeat_timed_out
3940
from .ipc import FusedIpcQueue, IpcQueue
4041
from .postproc_worker import PostprocWorker, PostprocWorkerConfig
4142
from .request import CancellingRequest, GenerationRequest
4243
from .result import GenerationResult, IterationResult
4344
from .rpc import RPCClient
4445
from .rpc.rpc_common import RPCError, get_unique_ipc_addr
4546
from .utils import (ErrorResponse, RequestError, WorkerCommIpcAddrs,
46-
create_mpi_comm_session, get_spawn_proxy_process_env,
47-
is_llm_response, print_alive_threads)
47+
WorkerHeartbeat, create_mpi_comm_session,
48+
get_spawn_proxy_process_env, is_llm_response,
49+
print_alive_threads)
4850
from .worker import GenerationExecutorWorker, worker_main
4951

5052
__all__ = [
@@ -125,17 +127,28 @@ def __init__(
125127
self._results: Dict[int, GenerationResult] = {}
126128

127129
# --- liveness / stall detection state ---
128-
# Time of the last sign of worker progress (a request submitted or a result received). The
129-
# error monitor uses this to detect a worker that has silently stopped servicing requests.
130-
self._last_progress_time = time.monotonic()
130+
# Time of the last result and worker heartbeat. Long non-streaming requests can legitimately
131+
# go quiet on the result queue, so fatal stall detection is based on heartbeats from the
132+
# worker response-polling thread rather than result traffic.
133+
self._last_result_time = time.monotonic()
134+
self._last_worker_heartbeat_time = self._last_result_time
131135
# Max time to wait for the worker to accept a submitted request before declaring it
132136
# dead/stalled. With an unbounded send HWM the send only blocks when the worker disconnected
133137
# or stopped draining.
134138
self._submit_timeout_secs = float(
135139
os.environ.get("TLLM_EXECUTOR_SUBMIT_TIMEOUT_SECS", "300"))
136-
# Max time with requests in flight but no result before treating the worker as stalled.
137-
self._stall_timeout_secs = float(
138-
os.environ.get("TLLM_EXECUTOR_STALL_TIMEOUT_SECS", "300"))
140+
# Max time with requests in flight but no worker heartbeat before treating the worker as
141+
# stalled. The legacy env var is still accepted as an alias for compatibility.
142+
self._heartbeat_timeout_secs = float(
143+
os.environ.get(
144+
"TLLM_EXECUTOR_HEARTBEAT_TIMEOUT_SECS",
145+
os.environ.get("TLLM_EXECUTOR_STALL_TIMEOUT_SECS", "300")))
146+
# Warn about long result-quiet periods without killing the worker. This is diagnostic only:
147+
# a healthy heartbeat means the worker is still polling responses.
148+
self._result_quiet_warning_secs = float(
149+
os.environ.get("TLLM_EXECUTOR_RESULT_QUIET_WARNING_SECS",
150+
str(self._heartbeat_timeout_secs)))
151+
self._last_result_quiet_warning_time = self._last_result_time
139152
# PID of the leader worker process, learned from the init handshake; used to request a
140153
# thread-stack dump (SIGUSR1) when a stall is detected. Stays `None` for remote/out-of-host
141154
# worker sessions.
@@ -282,23 +295,34 @@ def _error_monitor_loop(self) -> None:
282295
if self._fatal_error is not None:
283296
return
284297

285-
# Progress watchdog: a worker that silently stops servicing requests (no result for
286-
# a long time while requests are in flight) is treated as a fatal stall so callers
287-
# fail fast instead of hanging indefinitely.
288-
if self._results and (
289-
time.monotonic() -
290-
self._last_progress_time) > self._stall_timeout_secs:
298+
try:
299+
self._drain_heartbeat_queue()
300+
except Exception as exc:
301+
logger.warning(
302+
"Error monitor: failed to drain worker heartbeat "
303+
f"queue; continuing timeout check: {exc!r}")
304+
305+
# Heartbeat watchdog: this checks liveness of the worker's response-polling thread,
306+
# not generation forward progress. Lack of result traffic alone is not fatal because
307+
# long non-streaming requests can be healthy but silent.
308+
if worker_heartbeat_timed_out(
309+
has_inflight_requests=bool(self._results),
310+
now=time.monotonic(),
311+
last_heartbeat_time=self._last_worker_heartbeat_time,
312+
timeout_secs=self._heartbeat_timeout_secs):
291313
logger.error(
292-
f"Error monitor: no result progress for {self._stall_timeout_secs:.2f}s "
314+
f"Error monitor: no worker heartbeat for {self._heartbeat_timeout_secs:.2f}s "
293315
f"with {len(self._results)} request(s) in flight; "
294316
"treating worker as stalled.")
295317
self._maybe_dump_worker_traceback()
296318
self._set_fatal_error(
297319
RuntimeError(
298-
f"Worker stalled: no result for {self._stall_timeout_secs:.2f}s "
320+
f"Worker stalled: no heartbeat for {self._heartbeat_timeout_secs:.2f}s "
299321
f"with {len(self._results)} request(s) in flight."))
300322
self.pre_shutdown()
301323
return
324+
325+
self._maybe_log_result_quiet_warning()
302326
except Exception as exc:
303327
logger.debug(f"Error monitor: unexpected exception (ignored): "
304328
f"{exc!r}")
@@ -325,6 +349,30 @@ def _maybe_dump_worker_traceback(self) -> None:
325349
except OSError as e:
326350
logger.debug(f"Could not signal worker pid {pid}: {e!r}")
327351

352+
def _drain_heartbeat_queue(self) -> None:
353+
"""Drain worker heartbeats and refresh the proxy-local liveness timestamp."""
354+
while self.heartbeat_queue.poll(0):
355+
heartbeat = self.heartbeat_queue.get()
356+
if isinstance(heartbeat, WorkerHeartbeat):
357+
self._last_worker_heartbeat_time = time.monotonic()
358+
if heartbeat.rank == 0:
359+
self._worker_pid = heartbeat.pid
360+
361+
def _maybe_log_result_quiet_warning(self) -> None:
362+
if not self._results:
363+
return
364+
now = time.monotonic()
365+
if (now - self._last_result_time) <= self._result_quiet_warning_secs:
366+
return
367+
time_since_last_warning = now - self._last_result_quiet_warning_time
368+
if time_since_last_warning <= self._result_quiet_warning_secs:
369+
return
370+
logger.warning(
371+
f"No result emitted for {self._result_quiet_warning_secs:.2f}s "
372+
f"with {len(self._results)} request(s) in flight, but worker "
373+
"heartbeats are still arriving.")
374+
self._last_result_quiet_warning_time = now
375+
328376
def _setup_queues(self) -> WorkerCommIpcAddrs:
329377

330378
self.request_queue = IpcQueue(is_server=True,
@@ -342,6 +390,9 @@ def _setup_queues(self) -> WorkerCommIpcAddrs:
342390
socket_type=zmq.PULL
343391
if self.enable_postprocess_parallel else zmq.PAIR,
344392
name="proxy_result_queue")
393+
self.heartbeat_queue = IpcQueue(is_server=True,
394+
socket_type=zmq.PULL,
395+
name="proxy_heartbeat_queue")
345396
self._resource_governor_queue = IpcQueue(
346397
is_server=True, name="proxy_resource_governor_queue"
347398
) if self._enable_resource_governor else None
@@ -350,6 +401,7 @@ def _setup_queues(self) -> WorkerCommIpcAddrs:
350401
request_queue_addr=self.request_queue.address,
351402
worker_init_status_queue_addr=self.worker_init_status_queue.address,
352403
result_queue_addr=self.result_queue.address,
404+
heartbeat_queue_addr=self.heartbeat_queue.address,
353405
resource_governor_queue_addr=self._resource_governor_queue.address
354406
if self._resource_governor_queue is not None else None,
355407
)
@@ -383,7 +435,7 @@ def dispatch_result_task(self) -> bool:
383435
return False # shutdown the thread
384436

385437
# A result arrived: the worker is making progress.
386-
self._last_progress_time = time.monotonic()
438+
self._last_result_time = time.monotonic()
387439

388440
async_queues = []
389441
event_loop = None
@@ -633,6 +685,7 @@ def shutdown(self):
633685
self.request_queue.close()
634686
self.worker_init_status_queue.close()
635687
self.result_queue.close()
688+
self.heartbeat_queue.close()
636689
if self._resource_governor_queue is not None:
637690
self._resource_governor_queue.close()
638691

@@ -663,6 +716,9 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
663716
executor=self,
664717
disaggregated_params=request.disaggregated_params,
665718
logprob_params=logprob_params)
719+
now = time.monotonic()
720+
self._last_result_time = now
721+
self._last_worker_heartbeat_time = now
666722
self._results[request.id] = result
667723

668724
with nvtx_range_debug("request_queue.put"):
@@ -698,7 +754,9 @@ def _submit_request(self, request: GenerationRequest) -> None:
698754
self._set_fatal_error(err)
699755
self.pre_shutdown()
700756
raise err
701-
self._last_progress_time = time.monotonic()
757+
now = time.monotonic()
758+
self._last_result_time = now
759+
self._last_worker_heartbeat_time = now
702760
self.request_queue.put(request)
703761

704762
def collective_rpc(

tensorrt_llm/executor/utils.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@ class LlmLauncherEnvs(StrEnum):
3030

3131

3232
def get_spawn_proxy_process_ipc_addr_env() -> str | None:
33-
''' Get the IPC address for the spawn proxy process dynamically. '''
33+
"""Get the IPC address for the spawn proxy process dynamically."""
3434
return os.getenv(LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR)
3535

3636

3737
def get_spawn_proxy_process_ipc_hmac_key_env() -> bytes:
38-
''' Get the HMAC key for the spawn proxy process dynamically. '''
38+
"""Get the HMAC key for the spawn proxy process dynamically."""
3939
key = os.getenv("TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY")
4040
assert key is not None, (
4141
f"{LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY} is not set. "
@@ -44,7 +44,7 @@ def get_spawn_proxy_process_ipc_hmac_key_env() -> bytes:
4444

4545

4646
def get_spawn_proxy_process_env() -> bool:
47-
''' Get the environment variable for the spawn proxy process dynamically. '''
47+
"""Get the environment variable for the spawn proxy process dynamically."""
4848
return os.getenv(LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS) == "1"
4949

5050

@@ -77,7 +77,7 @@ def has_event_loop() -> bool:
7777

7878

7979
class RequestError(RuntimeError):
80-
''' The error raised when the request is failed. '''
80+
"""The error raised when the request is failed."""
8181

8282

8383
class ProcessPoolExecutorSession(MpiSession):
@@ -113,8 +113,15 @@ class ErrorResponse(NamedTuple):
113113
request_id: int
114114

115115

116+
class WorkerHeartbeat(NamedTuple):
117+
"""A liveness pulse from the worker process to the proxy."""
118+
119+
pid: int
120+
rank: int
121+
122+
116123
class IntraProcessQueue:
117-
''' A Queue-like container for IPC within the same process. '''
124+
"""A Queue-like container for IPC within the same process."""
118125

119126
def __init__(self):
120127
self.queue = Queue()
@@ -149,11 +156,12 @@ def poll(self, timeout=None) -> bool:
149156

150157

151158
class WorkerCommIpcAddrs(NamedTuple):
152-
''' IPC addresses (str) and HMAC keys (bytes) for communication with the worker processes. '''
159+
"""IPC addresses (str) and HMAC keys (bytes) for communication with the worker processes."""
153160
request_queue_addr: tuple[str, Optional[bytes]]
154161
worker_init_status_queue_addr: tuple[str, Optional[bytes]]
155162
result_queue_addr: tuple[str, Optional[bytes]]
156163
resource_governor_queue_addr: Optional[tuple[str, Optional[bytes]]] = None
164+
heartbeat_queue_addr: Optional[tuple[str, Optional[bytes]]] = None
157165

158166

159167
def is_llm_response(instance):

tensorrt_llm/executor/worker.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def _print_stacks():
229229

230230
result_queue: Optional[IpcQueue] = None
231231
result_queues: Optional[List[IpcQueue]] = None
232+
heartbeat_queue: Optional[IpcQueue] = None
232233
resource_governor_queue: Optional[IpcQueue] = None
233234

234235
postproc_worker_config = postproc_worker_config or PostprocWorkerConfig()
@@ -263,6 +264,12 @@ def _print_stacks():
263264
is_server=False,
264265
name="worker_resource_governor_queue"
265266
) if worker_queues.resource_governor_queue_addr else None
267+
heartbeat_queue = IpcQueue(
268+
worker_queues.heartbeat_queue_addr,
269+
is_server=False,
270+
socket_type=zmq.PUSH,
271+
name="worker_heartbeat_queue"
272+
) if worker_queues.heartbeat_queue_addr else None
266273

267274
if postproc_worker_config.enabled:
268275
# IPC queues for sending inputs to the postprocess parallel
@@ -361,6 +368,8 @@ def notify_proxy_threads_to_quit():
361368
worker.set_postproc_queues(result_queues)
362369
else:
363370
worker.set_result_queue(result_queue)
371+
if heartbeat_queue is not None:
372+
worker.set_heartbeat_queue(heartbeat_queue)
364373

365374
# Send ready signal with confirmation. The payload carries the worker PID so the
366375
# proxy can signal it (SIGUSR1 -> thread-stack dump) if it later stalls.

tensorrt_llm/serialization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@
8282
"GenerationResult", "GenerationResultBase", "IterationResult",
8383
"Logprob", "LogProbsResult", "ResponseWrapper"
8484
],
85-
"tensorrt_llm.executor.utils": ["ErrorResponse", "WorkerCommIpcAddrs"],
85+
"tensorrt_llm.executor.utils":
86+
["ErrorResponse", "WorkerCommIpcAddrs", "WorkerHeartbeat"],
8687
"tensorrt_llm.executor.worker": ["GenerationExecutorWorker", "worker_main"],
8788
"tensorrt_llm.llmapi.llm_args": [
8889
"_ModelFormatKind", "_ParallelConfig", "CalibConfig",

0 commit comments

Comments
 (0)