Skip to content

Commit 07270b7

Browse files
[https://nvbugs/6290345][fix] Fix allreduce benchmark input setup (#15427)
Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
1 parent 8412a17 commit 07270b7

1 file changed

Lines changed: 52 additions & 18 deletions

File tree

tests/microbenchmarks/all_reduce.py

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
# limitations under the License.
1515

1616
import os
17+
18+
# For benchmarks, we disable the CUDA graph support overhead.
19+
# TRT-LLM production and profiling runs should do the same.
20+
os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = "0"
21+
1722
from argparse import ArgumentParser
1823
from itertools import product
1924

@@ -45,6 +50,37 @@
4550
from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
4651

4752

53+
def _allocate_nccl_window_tensor(input_tensor: torch.Tensor,
54+
mapping: Mapping) -> torch.Tensor | None:
55+
"""Allocate a tensor from the NCCL symmetric-memory window pool."""
56+
try:
57+
from tensorrt_llm.bindings.internal.thop import BufferKind
58+
window_tensor, actual_kind = torch.ops.trtllm.allocate_output(
59+
input_tensor, int(BufferKind.NCCL_WINDOW), mapping.tp_group)
60+
except RuntimeError:
61+
return None
62+
if actual_kind != int(BufferKind.NCCL_WINDOW):
63+
return None
64+
return window_tensor
65+
66+
67+
def _register_nccl_window_input(input_tensor: torch.Tensor,
68+
mapping: Mapping,
69+
*,
70+
required: bool = False) -> torch.Tensor:
71+
"""Run NCCL_SYMMETRIC on the already-registered input path."""
72+
window_tensor = _allocate_nccl_window_tensor(input_tensor, mapping)
73+
if window_tensor is None:
74+
if required:
75+
raise RuntimeError(
76+
"NCCL_SYMMETRIC benchmark requires a registered NCCL_WINDOW "
77+
"input tensor, but allocation did not return NCCL_WINDOW.")
78+
return input_tensor
79+
window_tensor.copy_(input_tensor)
80+
torch.cuda.synchronize(input_tensor.device)
81+
return window_tensor
82+
83+
4884
def profile_allreduce(
4985
mapping: Mapping,
5086
dist: Distributed,
@@ -61,8 +97,19 @@ def profile_allreduce(
6197
allreduce_instance=None,
6298
profile_gemm_allreduce: bool = False,
6399
gemm_in_features: int | None = None,
100+
register_symmetric_input: bool = True,
101+
require_registered_symmetric_input: bool = True,
64102
):
65103

104+
allreduce = allreduce_instance or AllReduce(mapping=mapping,
105+
strategy=strategy)
106+
effective_strategy = getattr(allreduce, "strategy", strategy)
107+
strategy = effective_strategy
108+
if (register_symmetric_input and not profile_gemm_allreduce
109+
and effective_strategy == AllReduceStrategy.NCCL_SYMMETRIC):
110+
input = _register_nccl_window_input(
111+
input, mapping, required=require_registered_symmetric_input)
112+
66113
allreduce_params = AllReduceParams(
67114
fusion_op=fusion,
68115
residual=residual,
@@ -72,8 +119,6 @@ def profile_allreduce(
72119
bias=bias,
73120
)
74121

75-
allreduce = allreduce_instance or AllReduce(mapping=mapping,
76-
strategy=strategy)
77122
linear = None
78123
if profile_gemm_allreduce:
79124
if gemm_in_features is None:
@@ -114,11 +159,11 @@ def func(x, loop_num=inner_loop):
114159
func(input, loop_num=1)
115160

116161
if enable_cudagraph:
117-
# Untimed warmup run outside of graph capture
118-
func(input, loop_num=1)
119-
# CUDA graph warmup then capture
120-
for _ in range(2):
121-
func(input, loop_num=1)
162+
# Run one multi-iteration warmup, not repeated single-iteration
163+
# warmups: `output = allreduce(x)` keeps the previous output alive
164+
# while the next RHS allocates its output, seeding the same two
165+
# reusable output windows that graph capture will need.
166+
func(input, loop_num=3)
122167
with torch.cuda.graph(graph, stream=stream):
123168
output = func(input)
124169

@@ -266,17 +311,6 @@ def allreduce_benchmark(
266311
start_col = mapping.tp_rank * local_in_features
267312
input_for_profile = input[:, start_col:start_col +
268313
local_in_features].contiguous()
269-
elif strategy in (AllReduceStrategy.NCCL_SYMMETRIC,
270-
AllReduceStrategy.NCCL, AllReduceStrategy.AUTO):
271-
try:
272-
from tensorrt_llm.bindings.internal.thop import BufferKind
273-
window_out, actual_kind = torch.ops.trtllm.allocate_output(
274-
input, int(BufferKind.NCCL_WINDOW), mapping.tp_group)
275-
if actual_kind == int(BufferKind.NCCL_WINDOW):
276-
window_out.copy_(input)
277-
input_for_profile = window_out
278-
except RuntimeError:
279-
pass
280314

281315
median_ms = profile_allreduce(
282316
mapping=mapping,

0 commit comments

Comments
 (0)