1414# limitations under the License.
1515
1616import os
17+
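# Benchmarks do not need NCCL's graph-mixing support, so turn it off
# (NCCL_GRAPH_MIXING_SUPPORT=0) to avoid the extra overhead it adds when
# NCCL is used together with CUDA graphs. The variable is set here, ahead
# of the remaining imports, presumably so that it is already in place
# before anything initializes NCCL.
# TRT-LLM production and profiling runs should do the same.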
20+ os .environ ["NCCL_GRAPH_MIXING_SUPPORT" ] = "0"
21+
1722from argparse import ArgumentParser
1823from itertools import product
1924
4550from tensorrt_llm .plugin .plugin import CustomAllReduceHelper
4651
4752
def _allocate_nccl_window_tensor(input_tensor: torch.Tensor,
                                 mapping: Mapping) -> torch.Tensor | None:
    """Try to obtain an NCCL_WINDOW-backed tensor for ``input_tensor``.

    Args:
        input_tensor: Tensor handed to ``torch.ops.trtllm.allocate_output``
            as its first argument (presumably the shape/dtype template for
            the allocation -- confirm against the op definition).
        mapping: Parallel mapping; its ``tp_group`` is forwarded to the op.

    Returns:
        The tensor returned by the op when the buffer kind it reports equals
        ``BufferKind.NCCL_WINDOW``. ``None`` when the op raises
        ``RuntimeError`` or reports any other buffer kind.
    """
    try:
        # Imported lazily, inside the function, as in the original code.
        from tensorrt_llm.bindings.internal.thop import BufferKind
        requested_kind = int(BufferKind.NCCL_WINDOW)
        allocated, granted_kind = torch.ops.trtllm.allocate_output(
            input_tensor, requested_kind, mapping.tp_group)
    except RuntimeError:
        # Best effort: a failed allocation is reported as "no window tensor".
        return None
    # The op may hand back a different buffer kind than the one requested;
    # only a genuine NCCL_WINDOW buffer counts as success.
    return allocated if granted_kind == requested_kind else None
65+
66+
def _register_nccl_window_input(input_tensor: torch.Tensor,
                                mapping: Mapping,
                                *,
                                required: bool = False) -> torch.Tensor:
    """Copy ``input_tensor`` into an NCCL window tensor when one is available.

    Args:
        input_tensor: Tensor whose contents should be placed in an
            NCCL_WINDOW buffer.
        mapping: Parallel mapping forwarded to
            ``_allocate_nccl_window_tensor``.
        required: When true, failing to obtain an NCCL_WINDOW buffer is an
            error instead of a silent fallback.

    Returns:
        The window tensor holding a copy of ``input_tensor`` on success;
        otherwise ``input_tensor`` itself (only when ``required`` is false).

    Raises:
        RuntimeError: If no NCCL_WINDOW buffer could be allocated and
            ``required`` is true.
    """
    staged = _allocate_nccl_window_tensor(input_tensor, mapping)
    if staged is not None:
        staged.copy_(input_tensor)
        # Block until outstanding work on the input's device has finished,
        # so the copy is complete before the caller uses the window tensor.
        torch.cuda.synchronize(input_tensor.device)
        return staged

    if not required:
        # Best-effort mode: fall back to the caller's original tensor.
        return input_tensor

    raise RuntimeError(
        "NCCL_SYMMETRIC benchmark requires a registered NCCL_WINDOW "
        "input tensor, but allocation did not return NCCL_WINDOW.")
82+
83+
4884def profile_allreduce (
4985 mapping : Mapping ,
5086 dist : Distributed ,
@@ -61,8 +97,19 @@ def profile_allreduce(
6197 allreduce_instance = None ,
6298 profile_gemm_allreduce : bool = False ,
6399 gemm_in_features : int | None = None ,
100+ register_symmetric_input : bool = True ,
101+ require_registered_symmetric_input : bool = True ,
64102):
65103
104+ allreduce = allreduce_instance or AllReduce (mapping = mapping ,
105+ strategy = strategy )
106+ effective_strategy = getattr (allreduce , "strategy" , strategy )
107+ strategy = effective_strategy
108+ if (register_symmetric_input and not profile_gemm_allreduce
109+ and effective_strategy == AllReduceStrategy .NCCL_SYMMETRIC ):
110+ input = _register_nccl_window_input (
111+ input , mapping , required = require_registered_symmetric_input )
112+
66113 allreduce_params = AllReduceParams (
67114 fusion_op = fusion ,
68115 residual = residual ,
@@ -72,8 +119,6 @@ def profile_allreduce(
72119 bias = bias ,
73120 )
74121
75- allreduce = allreduce_instance or AllReduce (mapping = mapping ,
76- strategy = strategy )
77122 linear = None
78123 if profile_gemm_allreduce :
79124 if gemm_in_features is None :
@@ -114,11 +159,11 @@ def func(x, loop_num=inner_loop):
114159 func (input , loop_num = 1 )
115160
116161 if enable_cudagraph :
117- # Untimed warmup run outside of graph capture
118- func ( input , loop_num = 1 )
119- # CUDA graph warmup then capture
120- for _ in range ( 2 ):
121- func (input , loop_num = 1 )
162+ # Run one multi-iteration warmup, not repeated single-iteration
163+ # warmups: `output = allreduce(x)` keeps the previous output alive
164+ # while the next RHS allocates its output, seeding the same two
165+ # reusable output windows that graph capture will need.
166+ func (input , loop_num = 3 )
122167 with torch .cuda .graph (graph , stream = stream ):
123168 output = func (input )
124169
@@ -266,17 +311,6 @@ def allreduce_benchmark(
266311 start_col = mapping .tp_rank * local_in_features
267312 input_for_profile = input [:, start_col :start_col +
268313 local_in_features ].contiguous ()
269- elif strategy in (AllReduceStrategy .NCCL_SYMMETRIC ,
270- AllReduceStrategy .NCCL , AllReduceStrategy .AUTO ):
271- try :
272- from tensorrt_llm .bindings .internal .thop import BufferKind
273- window_out , actual_kind = torch .ops .trtllm .allocate_output (
274- input , int (BufferKind .NCCL_WINDOW ), mapping .tp_group )
275- if actual_kind == int (BufferKind .NCCL_WINDOW ):
276- window_out .copy_ (input )
277- input_for_profile = window_out
278- except RuntimeError :
279- pass
280314
281315 median_ms = profile_allreduce (
282316 mapping = mapping ,
0 commit comments