Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 37 additions & 24 deletions tensorrt_llm/_torch/compilation/backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from collections import OrderedDict
from typing import List, Optional

import torch
Expand All @@ -14,6 +15,7 @@
from tensorrt_llm.mapping import Mapping

from .multi_stream.auto_multi_stream import multi_stream_schedule
from .patterns import MATCHER_SUBSYSTEM
from .patterns.ar_residual_norm import register_ar_fusions
from .patterns.residual_add_norm import register_add_norm
from .piecewise_optimizer import piecewise_optimizer
Expand All @@ -23,7 +25,6 @@

class Backend:

_custom_pass_instances: List[PatternMatcherPass] = None
_graph_pool_handle: tuple[int, int] = None

# Following classes are used to let weakref ref the stream and eventlist objects.
Expand All @@ -48,8 +49,8 @@ def __init__(
self.module_inference_time = 0
self.call_count = 0
self.mapping = mapping
self.custom_passes = Backend.get_custom_pass(enable_userbuffers,
mapping)
self.custom_passes = Backend.build_custom_passes(
enable_userbuffers, mapping)
self.rank = tensorrt_llm.mpi_rank()
self.enable_inductor = enable_inductor
self.capture_num_tokens = sorted(capture_num_tokens or [])
Expand All @@ -63,27 +64,29 @@ def __init__(
Backend._graph_pool_handle = torch.cuda.graph_pool_handle()

self.match_count = []
self.match_count_by_pass = OrderedDict()

@classmethod
def get_custom_pass(cls, enable_userbuffers, mapping: Mapping):
def build_custom_passes(cls, enable_userbuffers, mapping: Mapping):
world_size = tensorrt_llm.mpi_world_size()
if not cls._custom_pass_instances:
# Really naive pass manager here
cls._custom_pass_instances = [PatternMatcherPass()]
if world_size > 1:
# Currently torch compile cannot work properly with lamport fusion kernel
# TO-DO: Fix this issue
os.environ["DISABLE_LAMPORT_REDUCE_NORM_FUSION"] = "1"
ub_enabled = enable_userbuffers and tensorrt_llm.bindings.internal.userbuffers.ub_supported(
)
register_ar_fusions(cls._custom_pass_instances, mapping,
ub_enabled)
# Fallback: fuse remaining add+rmsnorm not preceded by allreduce
cls._custom_pass_instances.append(PatternMatcherPass())
register_add_norm(cls._custom_pass_instances[-1])
else:
register_add_norm(cls._custom_pass_instances[0])
return cls._custom_pass_instances
# Really naive pass manager here
custom_passes = [PatternMatcherPass("add_norm", MATCHER_SUBSYSTEM)]
if world_size > 1:
# Currently torch compile cannot work properly with lamport fusion kernel
# TO-DO: Fix this issue
os.environ["DISABLE_LAMPORT_REDUCE_NORM_FUSION"] = "1"
ub_enabled = enable_userbuffers and tensorrt_llm.bindings.internal.userbuffers.ub_supported(
)
custom_passes[-1] = PatternMatcherPass("ar_residual_norm",
MATCHER_SUBSYSTEM)
register_ar_fusions(custom_passes, mapping, ub_enabled)
# Fallback: fuse remaining add+rmsnorm not preceded by allreduce
custom_passes.append(
PatternMatcherPass("add_norm_fallback", MATCHER_SUBSYSTEM))
register_add_norm(custom_passes[-1])
else:
register_add_norm(custom_passes[0])
return custom_passes

def bypass_optimization(self):
self.no_optimization = True
Expand All @@ -103,10 +106,20 @@ def optimize(
example_inputs: List[torch.Tensor],
):
graph = gm.graph
self.match_count = []
self.match_count_by_pass = OrderedDict()
for custom_pass in self.custom_passes:
self.match_count.append(custom_pass.apply(graph))
while self.match_count[-1]:
self.match_count.append(custom_pass.apply(graph))
total_match_count = 0
match_count = custom_pass.apply(graph)
self.match_count.append(match_count)
total_match_count += match_count
while match_count:
match_count = custom_pass.apply(graph)
self.match_count.append(match_count)
total_match_count += match_count
pass_name = custom_pass.pass_name or (
f"unnamed_pass_{len(self.match_count_by_pass)}")
self.match_count_by_pass[pass_name] = total_match_count
graph.eliminate_dead_code()
# After this pass, cannot run any dce!!!
remove_copy_for_mutates_args(graph)
Expand Down
16 changes: 16 additions & 0 deletions tensorrt_llm/_torch/compilation/patterns/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

MATCHER_SUBSYSTEM = "torch_compile"
Comment thread
dhansen-nvidia marked this conversation as resolved.
163 changes: 133 additions & 30 deletions tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,49 @@
PatternMatcherPass, fwd_only,
register_replacement)

from tensorrt_llm.mapping import Mapping

from ...distributed import AllReduceFusionOp, AllReduceStrategy
from . import MATCHER_SUBSYSTEM

aten = torch.ops.aten
from tensorrt_llm.mapping import Mapping


def _append_named_pass(custom_passes: List[PatternMatcherPass], pass_name: str):
    """Create a ``PatternMatcherPass`` labelled *pass_name* under the shared
    matcher subsystem tag and push it onto *custom_passes*."""
    named_pass = PatternMatcherPass(pass_name, MATCHER_SUBSYSTEM)
    custom_passes.append(named_pass)


def _check_getitem_only_users(match: Match, pattern_node) -> bool:
    """Return True iff *pattern_node* matched a real FX node whose every
    consumer is a ``getitem`` call (its outputs are only ever unpacked,
    never consumed directly)."""
    candidate = match.ctx.pattern_to_node[pattern_node]
    if not isinstance(candidate, torch.fx.graph.Node):
        return False
    return all(consumer.op == "call_function" and consumer.target is getitem
               for consumer in candidate.users)


def _has_getitem_user(match: Match, pattern_node, index: int) -> bool:
node = match.ctx.pattern_to_node[pattern_node]
if not isinstance(node, torch.fx.graph.Node):
return False
for user in node.users:
if (user.op == "call_function" and user.target is getitem
and user.args[1] == index):
return True
return False


def _make_fp8_quant_extra_check(input_node, strategy_node, quant_node,
                                require_scale_output: bool):
    """Build the ``extra_check`` predicate shared by the FP8-quant fusion
    patterns.

    The returned predicate accepts a candidate match only when the input is
    fp16/bf16, the allreduce strategy is a non-UB one, every consumer of the
    quantize op is a ``getitem``, and the presence of a ``getitem`` on output
    slot 1 (the scale) agrees with *require_scale_output*.
    """

    def extra_check(match: Match) -> bool:
        if not check_f16_bf16_input(match, input_node):
            return False
        if not check_non_ub_strategy(match, strategy_node):
            return False
        if not _check_getitem_only_users(match, quant_node):
            return False
        return _has_getitem_user(match, quant_node, 1) == require_scale_output

    return extra_check


def register_ar_residual_norm(custom_pass: PatternMatcherPass, mapping: Mapping,
Expand Down Expand Up @@ -134,15 +173,16 @@ def register_ar_residual_norm_out_fp8_quant(custom_pass: PatternMatcherPass,
torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor.default,
getitem_0,
KeywordArg("scale"),
_users=2)
getitem_2 = CallFunction(getitem,
static_quantize_e4m3_per_tensor_default,
0,
_users=2)
_users=MULTIPLE)
getitem_2 = CallFunction(getitem, static_quantize_e4m3_per_tensor_default,
0)
getitem_3 = CallFunction(getitem, static_quantize_e4m3_per_tensor_default,
1)
pattern = MultiOutputPattern([getitem_0, getitem_1, getitem_2, getitem_3
]) # norm_out, residual_out, quant_out, scale
pattern_with_scale = MultiOutputPattern(
[getitem_0, getitem_1, getitem_2,
getitem_3]) # norm_out, residual_out, quant_out, scale
pattern_without_scale = MultiOutputPattern(
[getitem_0, getitem_1, getitem_2]) # norm_out, residual_out, quant_out

def empty_pattern(
input: torch.Tensor,
Expand Down Expand Up @@ -173,18 +213,48 @@ def target_pattern(
trigger_completion_at_end)
return allreduce[0], allreduce[2], allreduce[1], scale

def extra_check(match: Match) -> bool:
return check_f16_bf16_input(
match, input_node) and check_non_ub_strategy(match, strategy_node)
def target_pattern_without_scale(
input: torch.Tensor,
residual: torch.Tensor,
gamma: torch.Tensor,
workspace: torch.LongTensor,
strategy: int,
eps: float,
scale: torch.Tensor,
trigger_completion_at_end: bool,
):
allreduce = allreduce_func(
input, residual, gamma, scale, None, workspace, mapping.tp_group,
int(strategy),
int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8), float(eps),
trigger_completion_at_end)
return allreduce[0], allreduce[2], allreduce[1]

extra_check_with_scale = _make_fp8_quant_extra_check(
input_node, strategy_node, static_quantize_e4m3_per_tensor_default,
True)
extra_check_without_scale = _make_fp8_quant_extra_check(
input_node, strategy_node, static_quantize_e4m3_per_tensor_default,
False)

register_replacement(
empty_pattern,
target_pattern,
[],
fwd_only,
custom_pass,
search_fn_pattern=pattern,
extra_check=extra_check,
search_fn_pattern=pattern_with_scale,
extra_check=extra_check_with_scale,
)

register_replacement(
empty_pattern,
target_pattern_without_scale,
[],
fwd_only,
custom_pass,
search_fn_pattern=pattern_without_scale,
extra_check=extra_check_without_scale,
)


Expand Down Expand Up @@ -212,15 +282,15 @@ def register_ar_residual_norm_fp8_quant(custom_pass: PatternMatcherPass,
torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor.default,
getitem_0,
KeywordArg("scale"),
_users=2)
getitem_2 = CallFunction(getitem,
static_quantize_e4m3_per_tensor_default,
0,
_users=2)
_users=MULTIPLE)
getitem_2 = CallFunction(getitem, static_quantize_e4m3_per_tensor_default,
0)
getitem_3 = CallFunction(getitem, static_quantize_e4m3_per_tensor_default,
1)
pattern = MultiOutputPattern([getitem_1, getitem_2,
getitem_3]) # residual_out, quant_out, scale
pattern_with_scale = MultiOutputPattern(
[getitem_1, getitem_2, getitem_3]) # residual_out, quant_out, scale
pattern_without_scale = MultiOutputPattern([getitem_1, getitem_2
]) # residual_out, quant_out

def empty_pattern(
input: torch.Tensor,
Expand Down Expand Up @@ -250,18 +320,47 @@ def target_pattern(
float(eps), trigger_completion_at_end)
return allreduce[1], allreduce[0], scale

def extra_check(match: Match) -> bool:
return check_f16_bf16_input(
match, input_node) and check_non_ub_strategy(match, strategy_node)
def target_pattern_without_scale(
input: torch.Tensor,
residual: torch.Tensor,
gamma: torch.Tensor,
workspace: torch.LongTensor,
strategy: int,
eps: float,
scale: torch.Tensor,
trigger_completion_at_end: bool,
):
allreduce = allreduce_func(
input, residual, gamma, scale, None, workspace, mapping.tp_group,
int(strategy), int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8),
float(eps), trigger_completion_at_end)
return allreduce[1], allreduce[0]

extra_check_with_scale = _make_fp8_quant_extra_check(
input_node, strategy_node, static_quantize_e4m3_per_tensor_default,
True)
extra_check_without_scale = _make_fp8_quant_extra_check(
input_node, strategy_node, static_quantize_e4m3_per_tensor_default,
False)

register_replacement(
empty_pattern,
target_pattern,
[],
fwd_only,
custom_pass,
search_fn_pattern=pattern,
extra_check=extra_check,
search_fn_pattern=pattern_with_scale,
extra_check=extra_check_with_scale,
)

register_replacement(
empty_pattern,
target_pattern_without_scale,
[],
fwd_only,
custom_pass,
search_fn_pattern=pattern_without_scale,
extra_check=extra_check_without_scale,
)


Expand Down Expand Up @@ -772,16 +871,20 @@ def extra_check(match: Match) -> bool:
extra_check=extra_check,
)

custom_passes.append(PatternMatcherPass())
_append_named_pass(
custom_passes,
f"ub_convert_supported_ar_to_ub:{allreduce_func.__name__}")
register_convert_supported_ar_to_ub(custom_passes[-1])

custom_passes.append(PatternMatcherPass())
_append_named_pass(custom_passes, f"ub_prologue:{allreduce_func.__name__}")
register_ub_prologue_patterns(custom_passes[-1])

custom_passes.append(PatternMatcherPass())
_append_named_pass(custom_passes, f"ub_finalize:{allreduce_func.__name__}")
register_ub_finalize_patterns(custom_passes[-1])

custom_passes.append(PatternMatcherPass())
_append_named_pass(
custom_passes,
f"insert_copy_for_graph_output:{allreduce_func.__name__}")
insert_copy_for_graph_output(custom_passes[-1])


Expand All @@ -792,7 +895,7 @@ def register_ar_fusions(custom_passes: List[PatternMatcherPass],
register_ar_residual_norm(custom_passes[-1], mapping,
torch.ops.trtllm.tunable_allreduce)

custom_passes.append(PatternMatcherPass())
_append_named_pass(custom_passes, "ar_residual_norm_quant")
for allreduce_func in [
torch.ops.trtllm.allreduce, torch.ops.trtllm.tunable_allreduce
]:
Expand Down
16 changes: 0 additions & 16 deletions tests/integration/test_lists/waives.txt
Original file line number Diff line number Diff line change
Expand Up @@ -266,22 +266,6 @@ unittest/_torch/visual_gen/test_wan.py::TestWanTwoStageTransformer::test_two_sta
accuracy/test_llm_api_autodeploy.py::TestNemotronNanoV3::test_accuracy[fp8-1-trtllm] SKIP (https://nvbugs/5921674)
full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=True] SKIP (https://nvbugs/5929339)
accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/5940463)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens256-_hidden512] SKIP (https://nvbugs/5940460)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens256-_hidden32] SKIP (https://nvbugs/5940460)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens16-_hidden512] SKIP (https://nvbugs/5940460)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens16-_hidden32] SKIP (https://nvbugs/5940460)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-bf16-_tokens256-_hidden512] SKIP (https://nvbugs/5940460)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-bf16-_tokens256-_hidden32] SKIP (https://nvbugs/5940460)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-bf16-_tokens16-_hidden512] SKIP (https://nvbugs/5940460)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-bf16-_tokens16-_hidden32] SKIP (https://nvbugs/5940460)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-fp16-_tokens256-_hidden512] SKIP (https://nvbugs/5940460)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-fp16-_tokens256-_hidden32] SKIP (https://nvbugs/5940460)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-fp16-_tokens16-_hidden512] SKIP (https://nvbugs/5940460)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-fp16-_tokens16-_hidden32] SKIP (https://nvbugs/5940460)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-bf16-_tokens256-_hidden512] SKIP (https://nvbugs/5940460)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-bf16-_tokens256-_hidden32] SKIP (https://nvbugs/5940460)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-bf16-_tokens16-_hidden512] SKIP (https://nvbugs/5940460)
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-bf16-_tokens16-_hidden32] SKIP (https://nvbugs/5940460)
unittest/_torch/multimodal/test_mm_encoder_standalone.py::test_multi_request_batch_chat[llava_7b-False-False-False] SKIP (https://nvbugs/5983320)
unittest/_torch/multimodal/test_mm_encoder_standalone.py::test_multi_request_batch_chat[llava_7b-False-True-True] SKIP (https://nvbugs/5983320)
cpp/test_e2e.py::test_model[-gpt-80] SKIP (https://nvbugs/5983283)
Expand Down
Loading
Loading