diff --git a/tensorrt_llm/_torch/compilation/backend.py b/tensorrt_llm/_torch/compilation/backend.py index 0e2f4f7ecf0d..3d16ad271436 100644 --- a/tensorrt_llm/_torch/compilation/backend.py +++ b/tensorrt_llm/_torch/compilation/backend.py @@ -1,4 +1,5 @@ import os +from collections import OrderedDict from typing import List, Optional import torch @@ -14,6 +15,7 @@ from tensorrt_llm.mapping import Mapping from .multi_stream.auto_multi_stream import multi_stream_schedule +from .patterns import MATCHER_SUBSYSTEM from .patterns.ar_residual_norm import register_ar_fusions from .patterns.residual_add_norm import register_add_norm from .piecewise_optimizer import piecewise_optimizer @@ -23,7 +25,6 @@ class Backend: - _custom_pass_instances: List[PatternMatcherPass] = None _graph_pool_handle: tuple[int, int] = None # Following classes are used to let weakref ref the stream and eventlist objects. @@ -48,8 +49,8 @@ def __init__( self.module_inference_time = 0 self.call_count = 0 self.mapping = mapping - self.custom_passes = Backend.get_custom_pass(enable_userbuffers, - mapping) + self.custom_passes = Backend.build_custom_passes( + enable_userbuffers, mapping) self.rank = tensorrt_llm.mpi_rank() self.enable_inductor = enable_inductor self.capture_num_tokens = sorted(capture_num_tokens or []) @@ -63,27 +64,29 @@ def __init__( Backend._graph_pool_handle = torch.cuda.graph_pool_handle() self.match_count = [] + self.match_count_by_pass = OrderedDict() @classmethod - def get_custom_pass(cls, enable_userbuffers, mapping: Mapping): + def build_custom_passes(cls, enable_userbuffers, mapping: Mapping): world_size = tensorrt_llm.mpi_world_size() - if not cls._custom_pass_instances: - # Really naive pass manager here - cls._custom_pass_instances = [PatternMatcherPass()] - if world_size > 1: - # Currently torch compile cannot work properly with lamport fusion kernel - # TO-DO: Fix this issue - os.environ["DISABLE_LAMPORT_REDUCE_NORM_FUSION"] = "1" - ub_enabled = enable_userbuffers and 
tensorrt_llm.bindings.internal.userbuffers.ub_supported( - ) - register_ar_fusions(cls._custom_pass_instances, mapping, - ub_enabled) - # Fallback: fuse remaining add+rmsnorm not preceded by allreduce - cls._custom_pass_instances.append(PatternMatcherPass()) - register_add_norm(cls._custom_pass_instances[-1]) - else: - register_add_norm(cls._custom_pass_instances[0]) - return cls._custom_pass_instances + # Really naive pass manager here + custom_passes = [PatternMatcherPass("add_norm", MATCHER_SUBSYSTEM)] + if world_size > 1: + # Currently torch compile cannot work properly with lamport fusion kernel + # TO-DO: Fix this issue + os.environ["DISABLE_LAMPORT_REDUCE_NORM_FUSION"] = "1" + ub_enabled = enable_userbuffers and tensorrt_llm.bindings.internal.userbuffers.ub_supported( + ) + custom_passes[-1] = PatternMatcherPass("ar_residual_norm", + MATCHER_SUBSYSTEM) + register_ar_fusions(custom_passes, mapping, ub_enabled) + # Fallback: fuse remaining add+rmsnorm not preceded by allreduce + custom_passes.append( + PatternMatcherPass("add_norm_fallback", MATCHER_SUBSYSTEM)) + register_add_norm(custom_passes[-1]) + else: + register_add_norm(custom_passes[0]) + return custom_passes def bypass_optimization(self): self.no_optimization = True @@ -103,10 +106,20 @@ def optimize( example_inputs: List[torch.Tensor], ): graph = gm.graph + self.match_count = [] + self.match_count_by_pass = OrderedDict() for custom_pass in self.custom_passes: - self.match_count.append(custom_pass.apply(graph)) - while self.match_count[-1]: - self.match_count.append(custom_pass.apply(graph)) + total_match_count = 0 + match_count = custom_pass.apply(graph) + self.match_count.append(match_count) + total_match_count += match_count + while match_count: + match_count = custom_pass.apply(graph) + self.match_count.append(match_count) + total_match_count += match_count + pass_name = custom_pass.pass_name or ( + f"unnamed_pass_{len(self.match_count_by_pass)}") + self.match_count_by_pass[pass_name] = 
total_match_count graph.eliminate_dead_code() # After this pass, cannot run any dce!!! remove_copy_for_mutates_args(graph) diff --git a/tensorrt_llm/_torch/compilation/patterns/__init__.py b/tensorrt_llm/_torch/compilation/patterns/__init__.py index e69de29bb2d1..bd3c588297ec 100644 --- a/tensorrt_llm/_torch/compilation/patterns/__init__.py +++ b/tensorrt_llm/_torch/compilation/patterns/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +MATCHER_SUBSYSTEM = "torch_compile" diff --git a/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py b/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py index 3d2b2e96245a..7a6a4ba618a4 100644 --- a/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py +++ b/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py @@ -8,10 +8,49 @@ PatternMatcherPass, fwd_only, register_replacement) +from tensorrt_llm.mapping import Mapping + from ...distributed import AllReduceFusionOp, AllReduceStrategy +from . 
import MATCHER_SUBSYSTEM aten = torch.ops.aten -from tensorrt_llm.mapping import Mapping + + +def _append_named_pass(custom_passes: List[PatternMatcherPass], pass_name: str): + custom_passes.append(PatternMatcherPass(pass_name, MATCHER_SUBSYSTEM)) + + +def _check_getitem_only_users(match: Match, pattern_node) -> bool: + node = match.ctx.pattern_to_node[pattern_node] + if not isinstance(node, torch.fx.graph.Node): + return False + for user in node.users: + if user.op != "call_function" or user.target is not getitem: + return False + return True + + +def _has_getitem_user(match: Match, pattern_node, index: int) -> bool: + node = match.ctx.pattern_to_node[pattern_node] + if not isinstance(node, torch.fx.graph.Node): + return False + for user in node.users: + if (user.op == "call_function" and user.target is getitem + and user.args[1] == index): + return True + return False + + +def _make_fp8_quant_extra_check(input_node, strategy_node, quant_node, + require_scale_output: bool): + + def extra_check(match: Match) -> bool: + return (check_f16_bf16_input(match, input_node) + and check_non_ub_strategy(match, strategy_node) + and _check_getitem_only_users(match, quant_node) and + _has_getitem_user(match, quant_node, 1) == require_scale_output) + + return extra_check def register_ar_residual_norm(custom_pass: PatternMatcherPass, mapping: Mapping, @@ -134,15 +173,16 @@ def register_ar_residual_norm_out_fp8_quant(custom_pass: PatternMatcherPass, torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor.default, getitem_0, KeywordArg("scale"), - _users=2) - getitem_2 = CallFunction(getitem, - static_quantize_e4m3_per_tensor_default, - 0, - _users=2) + _users=MULTIPLE) + getitem_2 = CallFunction(getitem, static_quantize_e4m3_per_tensor_default, + 0) getitem_3 = CallFunction(getitem, static_quantize_e4m3_per_tensor_default, 1) - pattern = MultiOutputPattern([getitem_0, getitem_1, getitem_2, getitem_3 - ]) # norm_out, residual_out, quant_out, scale + pattern_with_scale = 
MultiOutputPattern( + [getitem_0, getitem_1, getitem_2, + getitem_3]) # norm_out, residual_out, quant_out, scale + pattern_without_scale = MultiOutputPattern( + [getitem_0, getitem_1, getitem_2]) # norm_out, residual_out, quant_out def empty_pattern( input: torch.Tensor, @@ -173,9 +213,29 @@ def target_pattern( trigger_completion_at_end) return allreduce[0], allreduce[2], allreduce[1], scale - def extra_check(match: Match) -> bool: - return check_f16_bf16_input( - match, input_node) and check_non_ub_strategy(match, strategy_node) + def target_pattern_without_scale( + input: torch.Tensor, + residual: torch.Tensor, + gamma: torch.Tensor, + workspace: torch.LongTensor, + strategy: int, + eps: float, + scale: torch.Tensor, + trigger_completion_at_end: bool, + ): + allreduce = allreduce_func( + input, residual, gamma, scale, None, workspace, mapping.tp_group, + int(strategy), + int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8), float(eps), + trigger_completion_at_end) + return allreduce[0], allreduce[2], allreduce[1] + + extra_check_with_scale = _make_fp8_quant_extra_check( + input_node, strategy_node, static_quantize_e4m3_per_tensor_default, + True) + extra_check_without_scale = _make_fp8_quant_extra_check( + input_node, strategy_node, static_quantize_e4m3_per_tensor_default, + False) register_replacement( empty_pattern, @@ -183,8 +243,18 @@ def extra_check(match: Match) -> bool: [], fwd_only, custom_pass, - search_fn_pattern=pattern, - extra_check=extra_check, + search_fn_pattern=pattern_with_scale, + extra_check=extra_check_with_scale, + ) + + register_replacement( + empty_pattern, + target_pattern_without_scale, + [], + fwd_only, + custom_pass, + search_fn_pattern=pattern_without_scale, + extra_check=extra_check_without_scale, ) @@ -212,15 +282,15 @@ def register_ar_residual_norm_fp8_quant(custom_pass: PatternMatcherPass, torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor.default, getitem_0, KeywordArg("scale"), - _users=2) - getitem_2 = 
CallFunction(getitem, - static_quantize_e4m3_per_tensor_default, - 0, - _users=2) + _users=MULTIPLE) + getitem_2 = CallFunction(getitem, static_quantize_e4m3_per_tensor_default, + 0) getitem_3 = CallFunction(getitem, static_quantize_e4m3_per_tensor_default, 1) - pattern = MultiOutputPattern([getitem_1, getitem_2, - getitem_3]) # residual_out, quant_out, scale + pattern_with_scale = MultiOutputPattern( + [getitem_1, getitem_2, getitem_3]) # residual_out, quant_out, scale + pattern_without_scale = MultiOutputPattern([getitem_1, getitem_2 + ]) # residual_out, quant_out def empty_pattern( input: torch.Tensor, @@ -250,9 +320,28 @@ def target_pattern( float(eps), trigger_completion_at_end) return allreduce[1], allreduce[0], scale - def extra_check(match: Match) -> bool: - return check_f16_bf16_input( - match, input_node) and check_non_ub_strategy(match, strategy_node) + def target_pattern_without_scale( + input: torch.Tensor, + residual: torch.Tensor, + gamma: torch.Tensor, + workspace: torch.LongTensor, + strategy: int, + eps: float, + scale: torch.Tensor, + trigger_completion_at_end: bool, + ): + allreduce = allreduce_func( + input, residual, gamma, scale, None, workspace, mapping.tp_group, + int(strategy), int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8), + float(eps), trigger_completion_at_end) + return allreduce[1], allreduce[0] + + extra_check_with_scale = _make_fp8_quant_extra_check( + input_node, strategy_node, static_quantize_e4m3_per_tensor_default, + True) + extra_check_without_scale = _make_fp8_quant_extra_check( + input_node, strategy_node, static_quantize_e4m3_per_tensor_default, + False) register_replacement( empty_pattern, @@ -260,8 +349,18 @@ def extra_check(match: Match) -> bool: [], fwd_only, custom_pass, - search_fn_pattern=pattern, - extra_check=extra_check, + search_fn_pattern=pattern_with_scale, + extra_check=extra_check_with_scale, + ) + + register_replacement( + empty_pattern, + target_pattern_without_scale, + [], + fwd_only, + custom_pass, + 
search_fn_pattern=pattern_without_scale, + extra_check=extra_check_without_scale, ) @@ -772,16 +871,20 @@ def extra_check(match: Match) -> bool: extra_check=extra_check, ) - custom_passes.append(PatternMatcherPass()) + _append_named_pass( + custom_passes, + f"ub_convert_supported_ar_to_ub:{allreduce_func.__name__}") register_convert_supported_ar_to_ub(custom_passes[-1]) - custom_passes.append(PatternMatcherPass()) + _append_named_pass(custom_passes, f"ub_prologue:{allreduce_func.__name__}") register_ub_prologue_patterns(custom_passes[-1]) - custom_passes.append(PatternMatcherPass()) + _append_named_pass(custom_passes, f"ub_finalize:{allreduce_func.__name__}") register_ub_finalize_patterns(custom_passes[-1]) - custom_passes.append(PatternMatcherPass()) + _append_named_pass( + custom_passes, + f"insert_copy_for_graph_output:{allreduce_func.__name__}") insert_copy_for_graph_output(custom_passes[-1]) @@ -792,7 +895,7 @@ def register_ar_fusions(custom_passes: List[PatternMatcherPass], register_ar_residual_norm(custom_passes[-1], mapping, torch.ops.trtllm.tunable_allreduce) - custom_passes.append(PatternMatcherPass()) + _append_named_pass(custom_passes, "ar_residual_norm_quant") for allreduce_func in [ torch.ops.trtllm.allreduce, torch.ops.trtllm.tunable_allreduce ]: diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 0cfeddceaab5..062a1f55e365 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -266,22 +266,6 @@ unittest/_torch/visual_gen/test_wan.py::TestWanTwoStageTransformer::test_two_sta accuracy/test_llm_api_autodeploy.py::TestNemotronNanoV3::test_accuracy[fp8-1-trtllm] SKIP (https://nvbugs/5921674) full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=True] SKIP (https://nvbugs/5929339) accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/5940463) 
-unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens256-_hidden512] SKIP (https://nvbugs/5940460) -unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens256-_hidden32] SKIP (https://nvbugs/5940460) -unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens16-_hidden512] SKIP (https://nvbugs/5940460) -unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens16-_hidden32] SKIP (https://nvbugs/5940460) -unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-bf16-_tokens256-_hidden512] SKIP (https://nvbugs/5940460) -unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-bf16-_tokens256-_hidden32] SKIP (https://nvbugs/5940460) -unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-bf16-_tokens16-_hidden512] SKIP (https://nvbugs/5940460) -unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-bf16-_tokens16-_hidden32] SKIP (https://nvbugs/5940460) -unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-fp16-_tokens256-_hidden512] SKIP (https://nvbugs/5940460) -unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-fp16-_tokens256-_hidden32] SKIP (https://nvbugs/5940460) -unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-fp16-_tokens16-_hidden512] SKIP (https://nvbugs/5940460) -unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-fp16-_tokens16-_hidden32] SKIP (https://nvbugs/5940460) -unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-bf16-_tokens256-_hidden512] SKIP (https://nvbugs/5940460) -unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-bf16-_tokens256-_hidden32] SKIP (https://nvbugs/5940460) 
-unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-bf16-_tokens16-_hidden512] SKIP (https://nvbugs/5940460) -unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-bf16-_tokens16-_hidden32] SKIP (https://nvbugs/5940460) unittest/_torch/multimodal/test_mm_encoder_standalone.py::test_multi_request_batch_chat[llava_7b-False-False-False] SKIP (https://nvbugs/5983320) unittest/_torch/multimodal/test_mm_encoder_standalone.py::test_multi_request_batch_chat[llava_7b-False-True-True] SKIP (https://nvbugs/5983320) cpp/test_e2e.py::test_model[-gpt-80] SKIP (https://nvbugs/5983283) diff --git a/tests/unittest/_torch/multi_gpu/test_user_buffers.py b/tests/unittest/_torch/multi_gpu/test_user_buffers.py index eb5051afcb52..2eea3ccb7b9b 100644 --- a/tests/unittest/_torch/multi_gpu/test_user_buffers.py +++ b/tests/unittest/_torch/multi_gpu/test_user_buffers.py @@ -33,6 +33,15 @@ pytestmark = pytest.mark.threadleak(enabled=False) +def _assert_match_counts(backend, expected_match_count_by_pass): + # Key format: + # - "<pass_name>" for passes that are registered once + # - "<pass_name>:<allreduce_func>" for UB passes that are registered + # per allreduce backend, where <allreduce_func> is "allreduce" or + # "tunable_allreduce" + assert dict(backend.match_count_by_pass) == expected_match_count_by_pass + + def create_tp_mapping(tp_size, rank): return Mapping( world_size=tp_size, @@ -428,6 +437,130 @@ def forward(self, input): return res +def run_single_rank_ar_rms_norm_fp8_live_scale_compile(tensor_parallel_size, a, + b, c, gamma, scale): + rank = tensorrt_llm.mpi_rank() + torch.cuda.set_device(rank) + try: + eps = 1e-6 + dtype = a.dtype + + a = a.cuda() + c = c.cuda() + gamma = gamma.cuda() + scale = scale.cuda() + + ub_size = c.nelement() * c.element_size() + init_userbuffers_allocator(tensor_parallel_size, rank, ub_size) + + b_partial = torch.chunk(b, tensor_parallel_size, 0) + weight = b_partial[rank].cuda() + mapping = create_tp_mapping(tensor_parallel_size, rank) + + class 
Model(nn.Module): + + def __init__(self): + super().__init__() + self.allreduce = AllReduce(mapping=mapping, + strategy=AllReduceStrategy.AUTO, + dtype=dtype) + + def forward(self, input, residual): + local = torch.chunk(input, tensor_parallel_size, + 1)[rank].contiguous() + hidden = torch.matmul(local, weight) + norm, fused_residual = self.allreduce( + hidden, + all_reduce_params=AllReduceParams( + strategy=AllReduceStrategy.AUTO, + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=gamma, + eps=eps, + )) + q_norm, q_scale = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( + norm, scale) + dequantized = dequant(q_norm, q_scale, input.dtype) + return dequantized, fused_residual, q_scale + + model = Model() + backend = Backend(enable_inductor=False, + enable_userbuffers=True, + mapping=mapping) + model_opt = torch.compile(model, backend=backend, fullgraph=True) + + with torch.inference_mode(): + ref_dequantized, ref_residual, ref_scale = model(a, c) + fused_dequantized, fused_residual, fused_scale = model_opt(a, c) + + _assert_match_counts( + backend, { + "ar_residual_norm": 0, + "ar_residual_norm_quant": 1, + "ub_convert_supported_ar_to_ub:allreduce": 0, + "ub_prologue:allreduce": 0, + "ub_finalize:allreduce": 0, + "insert_copy_for_graph_output:allreduce": 0, + "ub_convert_supported_ar_to_ub:tunable_allreduce": 1, + "ub_prologue:tunable_allreduce": 1, + "ub_finalize:tunable_allreduce": 0, + "insert_copy_for_graph_output:tunable_allreduce": 0, + "add_norm_fallback": 0, + }) + + torch.cuda.synchronize() + if rank == 0: + torch.testing.assert_close(fused_dequantized, + ref_dequantized, + atol=5e-1, + rtol=1e-2) + torch.testing.assert_close(fused_residual, + ref_residual, + atol=5e-1, + rtol=1e-2) + torch.testing.assert_close(fused_scale, ref_scale) + except Exception: + traceback.print_exc() + raise + return True + + +def run_single_rank_backend_passes_are_per_instance(tensor_parallel_size): + rank = tensorrt_llm.mpi_rank() + 
torch.cuda.set_device(rank) + support = ub.ub_supported() + if not support: + return True + + mapping = create_tp_mapping(tensor_parallel_size, rank) + backend_no_ub = Backend(enable_inductor=False, + enable_userbuffers=False, + mapping=mapping) + backend_ub = Backend(enable_inductor=False, + enable_userbuffers=True, + mapping=mapping) + + assert [ + custom_pass.pass_name for custom_pass in backend_no_ub.custom_passes + ] == ["ar_residual_norm", "ar_residual_norm_quant", "add_norm_fallback"] + assert [custom_pass.pass_name + for custom_pass in backend_ub.custom_passes] == [ + "ar_residual_norm", + "ar_residual_norm_quant", + "ub_convert_supported_ar_to_ub:allreduce", + "ub_prologue:allreduce", + "ub_finalize:allreduce", + "insert_copy_for_graph_output:allreduce", + "ub_convert_supported_ar_to_ub:tunable_allreduce", + "ub_prologue:tunable_allreduce", + "ub_finalize:tunable_allreduce", + "insert_copy_for_graph_output:tunable_allreduce", + "add_norm_fallback", + ] + + return True + + def run_single_rank_ub_pass( tensor_parallel_size, input, l0_weight, l0_input_scale, l0_weight_scale, l1_weight, l1_input_scale, l1_weight_scale, l2_weight, l2_input_scale, @@ -460,15 +593,24 @@ def run_single_rank_ub_pass( model_opt = torch.compile(model, backend=backend, fullgraph=True) with torch.inference_mode(): output_fused = model_opt(input) - # 3 AR_NORM fusion happens first - # 2 AR_NORM fused with Quant - # 3 AR_NORM replacement - # 3 Scaled MM Prologue - # 2 UB Finalize Removal - # 1 Insert copy for graph output - assert backend.match_count == [ - 3, 0, 2, 0, 0, 0, 0, 0, 3, 0, 3, 0, 2, 0, 1, 0 - ] + # Assert the exact named pass totals rather than the raw fixed-point + # trace in backend.match_count. This is still an intentional tripwire + # for optimizer changes, but on semantic pass names instead of + # pass-manager bookkeeping zeros. 
+ _assert_match_counts( + backend, { + "ar_residual_norm": 3, + "ar_residual_norm_quant": 2, + "ub_convert_supported_ar_to_ub:allreduce": 0, + "ub_prologue:allreduce": 0, + "ub_finalize:allreduce": 0, + "insert_copy_for_graph_output:allreduce": 0, + "ub_convert_supported_ar_to_ub:tunable_allreduce": 3, + "ub_prologue:tunable_allreduce": 3, + "ub_finalize:tunable_allreduce": 2, + "insert_copy_for_graph_output:tunable_allreduce": 1, + "add_norm_fallback": 0, + }) torch.cuda.synchronize() if rank == 0: @@ -533,8 +675,7 @@ def ref_scaled_mm_col(x, w, in_s, w_s): rtol=1e-2) except Exception: traceback.print_exc() - - return False + raise return True @@ -581,6 +722,44 @@ def test_user_buffers_pass(hidden, tokens, dtype, mpi_pool_executor): assert r is True +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason='needs 2 GPUs to run this test') +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], + ids=lambda x: "fp16" if x == torch.float16 else "bf16") +@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True) +def test_user_buffers_ar_rms_norm_fp8_live_scale_compile( + dtype, mpi_pool_executor): + torch.manual_seed(44) + tensor_parallel_size = 2 + m = 16 + n = 32 + k = 16 + a = torch.randn((m, k), dtype=dtype) + b = torch.randn((k, n), dtype=dtype) + c = torch.randn((m, n), dtype=dtype) + gamma = torch.randn((n), dtype=dtype) + scale = torch.full((1, ), 0.1, dtype=torch.float32) + + results = mpi_pool_executor.map( + run_single_rank_ar_rms_norm_fp8_live_scale_compile, + *zip(*[(tensor_parallel_size, a, b, c, gamma, scale)] * + tensor_parallel_size)) + for r in results: + assert r is True + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason='needs 2 GPUs to run this test') +@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True) +def test_backend_passes_are_per_instance(mpi_pool_executor): + tensor_parallel_size = 2 + results = mpi_pool_executor.map( + run_single_rank_backend_passes_are_per_instance, + 
*zip(*[(tensor_parallel_size, )] * tensor_parallel_size)) + for r in results: + assert r is True + + def run_single_rank_ar_rms_norm_fp4(tensor_parallel_size, a, b, c, gamma): rank = tensorrt_llm.mpi_rank() torch.cuda.set_device(rank) @@ -762,15 +941,24 @@ def run_single_rank_ub_mm_add_pass(tensor_parallel_size, num_tokens, torch.cuda.synchronize() output_ref = model(mm0_input_0, mm0_input_1, mm1_input_0, mm1_input_1, residual_0, residual_1) - # 3 AR_NORM fusion happens first - # 0 AR_NORM fused with Quant - # 3 AR_NORM replacement - # 3 Prologue - # 1 UB Finalize Removal - # 1 Insert copy for graph output - assert backend.match_count == [ - 3, 0, 0, 0, 0, 0, 0, 3, 0, 3, 0, 1, 0, 1, 0 - ] + # Assert the exact named pass totals rather than the raw fixed-point + # trace in backend.match_count. This is still an intentional tripwire + # for optimizer changes, but on semantic pass names instead of + # pass-manager bookkeeping zeros. + _assert_match_counts( + backend, { + "ar_residual_norm": 3, + "ar_residual_norm_quant": 0, + "ub_convert_supported_ar_to_ub:allreduce": 0, + "ub_prologue:allreduce": 0, + "ub_finalize:allreduce": 0, + "insert_copy_for_graph_output:allreduce": 0, + "ub_convert_supported_ar_to_ub:tunable_allreduce": 3, + "ub_prologue:tunable_allreduce": 3, + "ub_finalize:tunable_allreduce": 1, + "insert_copy_for_graph_output:tunable_allreduce": 1, + "add_norm_fallback": 0, + }) torch.cuda.synchronize() if rank == 0: @@ -1004,15 +1192,24 @@ def block_scale_unswizzled(scale): output_ref = model(input) output_fused = model_opt(input) - # 3 AR_NORM fusion happens first - # 2 AR_NORM fused with Quant - # 3 AR_NORM replacement - # 3 Scaled MM Prologue - # 2 UB Finalize Removal - # 1 Insert copy for graph output - assert backend.match_count == [ - 3, 0, 2, 0, 0, 0, 0, 0, 3, 0, 3, 0, 2, 0, 1, 0 - ] + # Assert the exact named pass totals rather than the raw fixed-point + # trace in backend.match_count. 
This is still an intentional tripwire + # for optimizer changes, but on semantic pass names instead of + # pass-manager bookkeeping zeros. + _assert_match_counts( + backend, { + "ar_residual_norm": 3, + "ar_residual_norm_quant": 2, + "ub_convert_supported_ar_to_ub:allreduce": 0, + "ub_prologue:allreduce": 0, + "ub_finalize:allreduce": 0, + "insert_copy_for_graph_output:allreduce": 0, + "ub_convert_supported_ar_to_ub:tunable_allreduce": 3, + "ub_prologue:tunable_allreduce": 3, + "ub_finalize:tunable_allreduce": 2, + "insert_copy_for_graph_output:tunable_allreduce": 1, + "add_norm_fallback": 0, + }) torch.cuda.synchronize() torch.testing.assert_close(output_fused, output_ref,