Skip to content

Commit 3ed8a62

Browse files
committed
fix: address AlltoAll watchdog review comments
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent d62f307 commit 3ed8a62

7 files changed

Lines changed: 344 additions & 41 deletions

File tree

cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ __device__ void vectorized_combine_impl(T* dst_typed_base, int size_per_token, i
741741
{
742742
int target_rank = ptrs.topk_target_ranks[local_token_idx * TOP_K + k];
743743
int dst_idx = ptrs.topk_send_indices[local_token_idx * TOP_K + k];
744-
if (dst_idx < 0)
744+
if (dst_idx < 0 || !is_rank_active(ptrs.active_rank_mask, target_rank))
745745
{
746746
acc[k].fill(0.0f);
747747
continue;
@@ -766,8 +766,12 @@ __device__ void vectorized_combine_impl(T* dst_typed_base, int size_per_token, i
766766
#pragma unroll
767767
for (int k = 0; k < TOP_K; ++k)
768768
{
769-
if (ptrs.topk_send_indices[local_token_idx * TOP_K + k] < 0)
769+
int target_rank = ptrs.topk_target_ranks[local_token_idx * TOP_K + k];
770+
int dst_idx = ptrs.topk_send_indices[local_token_idx * TOP_K + k];
771+
if (dst_idx < 0 || !is_rank_active(ptrs.active_rank_mask, target_rank))
772+
{
770773
continue; // acc[k] already holds 0.0f from fill() above
774+
}
771775
#pragma unroll
772776
for (int j = elems_per_vec - 1; j >= 0; --j)
773777
acc[k][j] = static_cast<float>(reinterpret_cast<InT const*>(&acc[k])[j]);

cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ struct CombineKernelPointers
9393
int const* topk_send_indices; // dst index per k, -1 for duplicates
9494

9595
// Active-rank bitmask: see DispatchKernelPointers::active_rank_mask. Combine skips flag
96-
// writes/waits to/from masked peers; per-token accumulation uses topk_send_indices[k] < 0
97-
// (set by dispatch) to skip dead-targeted slots, so no explicit mask check is needed there.
96+
// writes/waits to/from masked peers and also skips per-token accumulation for ranks that
97+
// become inactive between dispatch and combine.
9898
uint64_t active_rank_mask[kRankMaskWords];
9999
};
100100

tensorrt_llm/_torch/alltoall_watchdog.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@
2626
import threading
2727
import time
2828
from collections import deque
29+
from collections.abc import Callable, Mapping, Sequence
2930
from dataclasses import dataclass
30-
from typing import Callable, Deque, Mapping, Optional, Protocol, Sequence
31+
from typing import Protocol
3132

3233
import torch
3334

@@ -110,6 +111,8 @@ def __init__(
110111
}
111112
self._device_copy_timeout_s = float(device_copy_timeout_s)
112113
self._copy_stream: torch.cuda.Stream | None = None
114+
self._host_flags: torch.Tensor | None = None
115+
self._copy_event: torch.cuda.Event | None = None
113116
self._retired_copies: list[tuple[torch.Tensor, torch.cuda.Event]] = []
114117
if workspace.device.type == "cuda":
115118
self._copy_stream = torch.cuda.Stream(device=workspace.device)
@@ -123,13 +126,17 @@ def _read_cuda_flags(self, flags: torch.Tensor) -> tuple[int, ...]:
123126
assert self._copy_stream is not None
124127
self._prune_retired_copies()
125128

126-
host_flags = torch.empty(
127-
(self._ep_size,),
128-
dtype=torch.int32,
129-
device="cpu",
130-
pin_memory=prefer_pinned(),
131-
)
132-
event = torch.cuda.Event(blocking=False)
129+
if self._host_flags is None:
130+
self._host_flags = torch.empty(
131+
(self._ep_size,),
132+
dtype=torch.int32,
133+
device="cpu",
134+
pin_memory=prefer_pinned(),
135+
)
136+
if self._copy_event is None:
137+
self._copy_event = torch.cuda.Event(blocking=False)
138+
host_flags = self._host_flags
139+
event = self._copy_event
133140
with torch.cuda.device(flags.device), torch.cuda.stream(self._copy_stream):
134141
host_flags.copy_(flags.detach(), non_blocking=True)
135142
event.record(self._copy_stream)
@@ -139,6 +146,8 @@ def _read_cuda_flags(self, flags: torch.Tensor) -> tuple[int, ...]:
139146
remaining_s = deadline_s - time.monotonic()
140147
if remaining_s <= 0:
141148
self._retired_copies.append((host_flags, event))
149+
self._host_flags = None
150+
self._copy_event = None
142151
raise CompletionFlagReadTimeout(
143152
"timed out copying AlltoAll completion flags to host"
144153
)
@@ -175,8 +184,8 @@ def __init__(
175184
completion_reader: CompletionFlagReader,
176185
timeout_s: float = DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S,
177186
poll_interval_s: float = DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S,
178-
health: Optional[EPGroupHealthLike] = None,
179-
on_timeout: Optional[Callable[[AlltoAllWatchdogTimeout], None]] = None,
187+
health: EPGroupHealthLike | None = None,
188+
on_timeout: Callable[[AlltoAllWatchdogTimeout], None] | None = None,
180189
) -> None:
181190
if ep_size <= 0:
182191
raise ValueError(f"ep_size must be > 0, got {ep_size}")
@@ -196,7 +205,7 @@ def __init__(
196205
self._on_timeout = on_timeout
197206

198207
self._cv = threading.Condition()
199-
self._queue: Deque[_CollectiveWatch] = deque()
208+
self._queue: deque[_CollectiveWatch] = deque()
200209
self._closed = False
201210
self._stopping = False
202211
self._thread: threading.Thread | None = None
@@ -213,8 +222,8 @@ def from_workspace(
213222
ep_size: int,
214223
timeout_s: float = DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S,
215224
poll_interval_s: float = DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S,
216-
health: Optional[EPGroupHealthLike] = None,
217-
on_timeout: Optional[Callable[[AlltoAllWatchdogTimeout], None]] = None,
225+
health: EPGroupHealthLike | None = None,
226+
on_timeout: Callable[[AlltoAllWatchdogTimeout], None] | None = None,
218227
) -> "AlltoAllWatchdog":
219228
"""Build a watchdog from the MoE AlltoAll workspace and metainfo."""
220229
dispatch_offset = int(
@@ -421,7 +430,7 @@ def _run(self) -> None:
421430
except CompletionFlagReadTimeout:
422431
observed_flags = last_observed_flags
423432
poll_timed_out = True
424-
except BaseException as exc: # noqa: BLE001 - keep watchdog failures visible.
433+
except Exception as exc: # noqa: BLE001 - keep watchdog failures visible.
425434
with self._cv:
426435
self._last_error = exc
427436
self._queue.clear()

tensorrt_llm/_torch/distributed/moe_alltoall.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import os
1111
import sys
12+
import threading
1213
from dataclasses import dataclass
1314
from typing import Callable, Dict, Optional
1415

@@ -212,6 +213,8 @@ def __init__(
212213
"mnnvl_mem": mnnvl_mem,
213214
"workspace": workspace,
214215
"metainfo": metainfo,
216+
"watchdog_flag_generation": 0,
217+
"watchdog_flag_generation_lock": threading.Lock(),
215218
}
216219
else:
217220
assert self._WORKSPACE[
@@ -229,17 +232,20 @@ def __init__(
229232
self.mnnvl_mem = self._WORKSPACE["mnnvl_mem"]
230233
self.workspace = self._WORKSPACE["workspace"]
231234
self.metainfo = self._WORKSPACE["metainfo"]
235+
if "watchdog_flag_generation_lock" not in self._WORKSPACE:
236+
self._WORKSPACE["watchdog_flag_generation_lock"] = threading.Lock()
237+
self._WORKSPACE[
238+
"watchdog_flag_generation"] = self._read_current_flag_val()
232239
# Internal state
233240
self._state: _A2AState = _A2AState()
234241
self.ep_group_health = ep_group_health
235242
self._destroyed = False
236-
self._watchdog_flag_generation = 0
237243
self._alltoall_watchdog: AlltoAllWatchdog | None = None
238244
if (alltoall_watchdog_timeout_s is None
239245
and self.ep_group_health is not None):
240246
alltoall_watchdog_timeout_s = DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S
241247
if alltoall_watchdog_timeout_s is not None:
242-
self._watchdog_flag_generation = self._read_current_flag_val()
248+
self._sync_watchdog_flag_generation()
243249
self._alltoall_watchdog = AlltoAllWatchdog.from_workspace(
244250
workspace=self.workspace,
245251
metainfo=self.metainfo,
@@ -276,6 +282,25 @@ def _read_current_flag_val(self) -> int:
276282
flag_val = flag_val.detach().cpu()
277283
return int(flag_val.item())
278284

285+
def _sync_watchdog_flag_generation(self) -> None:
286+
workspace_state = self._WORKSPACE
287+
assert workspace_state is not None
288+
lock = workspace_state["watchdog_flag_generation_lock"]
289+
with lock:
290+
workspace_state["watchdog_flag_generation"] = max(
291+
int(workspace_state["watchdog_flag_generation"]),
292+
self._read_current_flag_val(),
293+
)
294+
295+
def _next_watchdog_flag_generation(self) -> int:
296+
workspace_state = self._WORKSPACE
297+
assert workspace_state is not None
298+
lock = workspace_state["watchdog_flag_generation_lock"]
299+
with lock:
300+
workspace_state["watchdog_flag_generation"] = (
301+
int(workspace_state["watchdog_flag_generation"]) + 1)
302+
return int(workspace_state["watchdog_flag_generation"])
303+
279304
def _get_active_rank_mask_tensor(
280305
self,
281306
active_rank_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
@@ -302,10 +327,9 @@ def _watch_collective(self, phase: str,
302327
active_rank_mask: Optional[torch.Tensor]) -> None:
303328
if self._alltoall_watchdog is None:
304329
return
305-
self._watchdog_flag_generation += 1
306330
self._alltoall_watchdog.watch(
307331
phase=phase,
308-
expected_flag=self._watchdog_flag_generation,
332+
expected_flag=self._next_watchdog_flag_generation(),
309333
active_mask=self._active_mask_int(active_rank_mask),
310334
)
311335

tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"""
2626

2727
import os
28+
import threading
2829
from typing import Callable, Dict, List, Optional, Tuple
2930

3031
import torch
@@ -287,6 +288,8 @@ def __init__(
287288
"mnnvl_mem": mnnvl_mem,
288289
"workspace": workspace,
289290
"metainfo": metainfo,
291+
"watchdog_flag_generation": 0,
292+
"watchdog_flag_generation_lock": threading.Lock(),
290293
}
291294
NVLinkOneSided._WORKSPACES[self._workspace_key] = workspace_state
292295
else:
@@ -312,17 +315,20 @@ def __init__(
312315
NVLinkOneSided._WORKSPACE_REFCOUNTS.get(self._workspace_key, 0) + 1
313316
)
314317
self._destroyed = False
318+
self._workspace_state = workspace_state
315319
self.mnnvl_mem = workspace_state["mnnvl_mem"]
316320
self.workspace = workspace_state["workspace"]
317321
self.moe_a2a_metainfo = workspace_state["metainfo"]
318322
self.max_num_tokens_per_rank = workspace_state["max_num_tokens_per_rank"]
323+
if "watchdog_flag_generation_lock" not in workspace_state:
324+
workspace_state["watchdog_flag_generation_lock"] = threading.Lock()
325+
workspace_state["watchdog_flag_generation"] = self._read_current_flag_val()
319326
self.ep_group_health = ep_group_health
320-
self._watchdog_flag_generation = 0
321327
self._alltoall_watchdog: AlltoAllWatchdog | None = None
322328
if alltoall_watchdog_timeout_s is None and self.ep_group_health is not None:
323329
alltoall_watchdog_timeout_s = DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S
324330
if alltoall_watchdog_timeout_s is not None:
325-
self._watchdog_flag_generation = self._read_current_flag_val()
331+
self._sync_watchdog_flag_generation()
326332
self._alltoall_watchdog = AlltoAllWatchdog.from_workspace(
327333
workspace=self.workspace,
328334
metainfo=self.moe_a2a_metainfo,
@@ -354,6 +360,22 @@ def _read_current_flag_val(self) -> int:
354360
flag_val = flag_val.detach().cpu()
355361
return int(flag_val.item())
356362

363+
def _sync_watchdog_flag_generation(self) -> None:
364+
lock = self._workspace_state["watchdog_flag_generation_lock"]
365+
with lock:
366+
self._workspace_state["watchdog_flag_generation"] = max(
367+
int(self._workspace_state["watchdog_flag_generation"]),
368+
self._read_current_flag_val(),
369+
)
370+
371+
def _next_watchdog_flag_generation(self) -> int:
372+
lock = self._workspace_state["watchdog_flag_generation_lock"]
373+
with lock:
374+
self._workspace_state["watchdog_flag_generation"] = (
375+
int(self._workspace_state["watchdog_flag_generation"]) + 1
376+
)
377+
return int(self._workspace_state["watchdog_flag_generation"])
378+
357379
def _get_active_rank_mask_tensor(
358380
self, active_rank_mask: Optional[torch.Tensor]
359381
) -> Optional[torch.Tensor]:
@@ -374,10 +396,9 @@ def _active_mask_int(self, active_rank_mask: Optional[torch.Tensor]) -> Optional
374396
def _watch_collective(self, phase: str, active_rank_mask: Optional[torch.Tensor]) -> None:
375397
if self._alltoall_watchdog is None:
376398
return
377-
self._watchdog_flag_generation += 1
378399
self._alltoall_watchdog.watch(
379400
phase=phase,
380-
expected_flag=self._watchdog_flag_generation,
401+
expected_flag=self._next_watchdog_flag_generation(),
381402
active_mask=self._active_mask_int(active_rank_mask),
382403
)
383404

@@ -424,6 +445,7 @@ def destroy(self):
424445
self.mnnvl_mem = None
425446
self.workspace = None
426447
self.moe_a2a_metainfo = None
448+
self._workspace_state = None
427449
self._dispatch_state = {"phase": "destroyed"}
428450

429451
def is_workload_feasible(self, all_rank_num_tokens: List[int], num_chunks: int) -> bool:

0 commit comments

Comments
 (0)