Skip to content

Commit daf5f25

Browse files
committed
[https://nvbugs/5940460][fix] Harden FP8 quant fusion matching after PyTorch 26.02 update
The PyTorch 26.02 stack changed the traced graph shape around static_quantize_e4m3_per_tensor: the unused scale getitem may be dead-code-eliminated, but some graphs still retain a live scale consumer.

Update the FP8 AR+Residual+RMSNorm quant fusion patterns to match both live-scale and pruned-scale graphs, with an explicit guard so the 2-output rewrite only fires when the scale output is absent. Name the custom passes and track aggregate match_count_by_pass totals so tests can assert exact semantic pass totals instead of the fixed-point bookkeeping trace.

Also build the custom pass pipeline per Backend instance rather than sharing a process-global cache, and add regressions for the live-scale compile path and per-instance backend pass configuration. This keeps the user-buffer tests unwaived without reducing the intended fusion coverage.

Signed-off-by: Dan Hansen <1+dhansen-nvidia@users.noreply.github.com>
1 parent 7ee9e8b commit daf5f25

File tree

5 files changed

+396
-98
lines changed

5 files changed

+396
-98
lines changed

tensorrt_llm/_torch/compilation/backend.py

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from collections import OrderedDict
23
from typing import List, Optional
34

45
import torch
@@ -14,6 +15,7 @@
1415
from tensorrt_llm.mapping import Mapping
1516

1617
from .multi_stream.auto_multi_stream import multi_stream_schedule
18+
from .patterns import MATCHER_SUBSYSTEM
1719
from .patterns.ar_residual_norm import register_ar_fusions
1820
from .patterns.residual_add_norm import register_add_norm
1921
from .piecewise_optimizer import piecewise_optimizer
@@ -23,7 +25,6 @@
2325

2426
class Backend:
2527

26-
_custom_pass_instances: List[PatternMatcherPass] = None
2728
_graph_pool_handle: tuple[int, int] = None
2829

2930
# Following classes are used to let weakref ref the stream and eventlist objects.
@@ -48,8 +49,8 @@ def __init__(
4849
self.module_inference_time = 0
4950
self.call_count = 0
5051
self.mapping = mapping
51-
self.custom_passes = Backend.get_custom_pass(enable_userbuffers,
52-
mapping)
52+
self.custom_passes = Backend.build_custom_passes(
53+
enable_userbuffers, mapping)
5354
self.rank = tensorrt_llm.mpi_rank()
5455
self.enable_inductor = enable_inductor
5556
self.capture_num_tokens = sorted(capture_num_tokens or [])
@@ -63,27 +64,29 @@ def __init__(
6364
Backend._graph_pool_handle = torch.cuda.graph_pool_handle()
6465

6566
self.match_count = []
67+
self.match_count_by_pass = OrderedDict()
6668

6769
@classmethod
68-
def get_custom_pass(cls, enable_userbuffers, mapping: Mapping):
70+
def build_custom_passes(cls, enable_userbuffers, mapping: Mapping):
6971
world_size = tensorrt_llm.mpi_world_size()
70-
if not cls._custom_pass_instances:
71-
# Really naive pass manager here
72-
cls._custom_pass_instances = [PatternMatcherPass()]
73-
if world_size > 1:
74-
# Currently torch compile cannot work properly with lamport fusion kernel
75-
# TO-DO: Fix this issue
76-
os.environ["DISABLE_LAMPORT_REDUCE_NORM_FUSION"] = "1"
77-
ub_enabled = enable_userbuffers and tensorrt_llm.bindings.internal.userbuffers.ub_supported(
78-
)
79-
register_ar_fusions(cls._custom_pass_instances, mapping,
80-
ub_enabled)
81-
# Fallback: fuse remaining add+rmsnorm not preceded by allreduce
82-
cls._custom_pass_instances.append(PatternMatcherPass())
83-
register_add_norm(cls._custom_pass_instances[-1])
84-
else:
85-
register_add_norm(cls._custom_pass_instances[0])
86-
return cls._custom_pass_instances
72+
# Really naive pass manager here
73+
custom_passes = [PatternMatcherPass("add_norm", MATCHER_SUBSYSTEM)]
74+
if world_size > 1:
75+
# Currently torch compile cannot work properly with lamport fusion kernel
76+
# TO-DO: Fix this issue
77+
os.environ["DISABLE_LAMPORT_REDUCE_NORM_FUSION"] = "1"
78+
ub_enabled = enable_userbuffers and tensorrt_llm.bindings.internal.userbuffers.ub_supported(
79+
)
80+
custom_passes[-1] = PatternMatcherPass("ar_residual_norm",
81+
MATCHER_SUBSYSTEM)
82+
register_ar_fusions(custom_passes, mapping, ub_enabled)
83+
# Fallback: fuse remaining add+rmsnorm not preceded by allreduce
84+
custom_passes.append(
85+
PatternMatcherPass("add_norm_fallback", MATCHER_SUBSYSTEM))
86+
register_add_norm(custom_passes[-1])
87+
else:
88+
register_add_norm(custom_passes[0])
89+
return custom_passes
8790

8891
def bypass_optimization(self):
8992
self.no_optimization = True
@@ -103,10 +106,20 @@ def optimize(
103106
example_inputs: List[torch.Tensor],
104107
):
105108
graph = gm.graph
109+
self.match_count = []
110+
self.match_count_by_pass = OrderedDict()
106111
for custom_pass in self.custom_passes:
107-
self.match_count.append(custom_pass.apply(graph))
108-
while self.match_count[-1]:
109-
self.match_count.append(custom_pass.apply(graph))
112+
total_match_count = 0
113+
match_count = custom_pass.apply(graph)
114+
self.match_count.append(match_count)
115+
total_match_count += match_count
116+
while match_count:
117+
match_count = custom_pass.apply(graph)
118+
self.match_count.append(match_count)
119+
total_match_count += match_count
120+
pass_name = custom_pass.pass_name or (
121+
f"unnamed_pass_{len(self.match_count_by_pass)}")
122+
self.match_count_by_pass[pass_name] = total_match_count
110123
graph.eliminate_dead_code()
111124
# After this pass, cannot run any dce!!!
112125
remove_copy_for_mutates_args(graph)

tensorrt_llm/_torch/compilation/patterns/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
MATCHER_SUBSYSTEM = "torch_compile"

tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py

Lines changed: 132 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,45 @@
1313
aten = torch.ops.aten
1414
from tensorrt_llm.mapping import Mapping
1515

16+
from . import MATCHER_SUBSYSTEM
17+
18+
19+
def _append_named_pass(custom_passes: List[PatternMatcherPass],
                       pass_name: str):
    """Append a freshly constructed, named pattern-matcher pass to a list.

    Args:
        custom_passes: Pass list that is extended in place.
        pass_name: Name given to the new pass (first constructor argument).

    The new pass is built as ``PatternMatcherPass(pass_name,
    MATCHER_SUBSYSTEM)``; nothing is returned.
    """
    named_pass = PatternMatcherPass(pass_name, MATCHER_SUBSYSTEM)
    custom_passes.append(named_pass)
21+
22+
23+
def _check_getitem_only_users(match: Match, pattern_node) -> bool:
24+
node = match.ctx.pattern_to_node[pattern_node]
25+
if not isinstance(node, torch.fx.graph.Node):
26+
return False
27+
for user in node.users:
28+
if user.op != "call_function" or user.target is not getitem:
29+
return False
30+
return True
31+
32+
33+
def _has_getitem_user(match: Match, pattern_node, index: int) -> bool:
34+
node = match.ctx.pattern_to_node[pattern_node]
35+
if not isinstance(node, torch.fx.graph.Node):
36+
return False
37+
for user in node.users:
38+
if (user.op == "call_function" and user.target is getitem
39+
and user.args[1] == index):
40+
return True
41+
return False
42+
43+
44+
def _make_fp8_quant_extra_check(input_node, strategy_node, quant_node,
                                require_scale_output: bool):
    """Build the ``extra_check`` callback for an FP8 quant fusion pattern.

    Args:
        input_node: Pattern node handed to ``check_f16_bf16_input``.
        strategy_node: Pattern node handed to ``check_non_ub_strategy``.
        quant_node: Pattern node of the quantize call whose users are
            inspected.
        require_scale_output: Whether the matched quantize node must have a
            ``getitem(..., 1)`` user. ``True`` requires one to exist and
            ``False`` requires that none exists.

    Returns:
        A function taking a ``Match`` that is truthy only when every
        condition below holds.
    """

    def extra_check(match: Match) -> bool:
        # The checks short-circuit left to right:
        #   1. input check and strategy check pass,
        #   2. the quantize node is consumed only through getitem,
        #   3. presence of a getitem user at index 1 equals the requested
        #      value. This keeps the with-scale and without-scale variants
        #      from both accepting the same match.
        return (check_f16_bf16_input(match, input_node)
                and check_non_ub_strategy(match, strategy_node)
                and _check_getitem_only_users(match, quant_node)
                and require_scale_output == _has_getitem_user(
                    match, quant_node, 1))

    return extra_check
54+
1655

1756
def register_ar_residual_norm(custom_pass: PatternMatcherPass, mapping: Mapping,
1857
allreduce_func: Callable):
@@ -134,15 +173,16 @@ def register_ar_residual_norm_out_fp8_quant(custom_pass: PatternMatcherPass,
134173
torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor.default,
135174
getitem_0,
136175
KeywordArg("scale"),
137-
_users=2)
138-
getitem_2 = CallFunction(getitem,
139-
static_quantize_e4m3_per_tensor_default,
140-
0,
141-
_users=2)
176+
_users=MULTIPLE)
177+
getitem_2 = CallFunction(getitem, static_quantize_e4m3_per_tensor_default,
178+
0)
142179
getitem_3 = CallFunction(getitem, static_quantize_e4m3_per_tensor_default,
143180
1)
144-
pattern = MultiOutputPattern([getitem_0, getitem_1, getitem_2, getitem_3
145-
]) # norm_out, residual_out, quant_out, scale
181+
pattern_with_scale = MultiOutputPattern(
182+
[getitem_0, getitem_1, getitem_2,
183+
getitem_3]) # norm_out, residual_out, quant_out, scale
184+
pattern_without_scale = MultiOutputPattern(
185+
[getitem_0, getitem_1, getitem_2]) # norm_out, residual_out, quant_out
146186

147187
def empty_pattern(
148188
input: torch.Tensor,
@@ -173,18 +213,48 @@ def target_pattern(
173213
trigger_completion_at_end)
174214
return allreduce[0], allreduce[2], allreduce[1], scale
175215

176-
def extra_check(match: Match) -> bool:
177-
return check_f16_bf16_input(
178-
match, input_node) and check_non_ub_strategy(match, strategy_node)
216+
def target_pattern_without_scale(
217+
input: torch.Tensor,
218+
residual: torch.Tensor,
219+
gamma: torch.Tensor,
220+
workspace: torch.LongTensor,
221+
strategy: int,
222+
eps: float,
223+
scale: torch.Tensor,
224+
trigger_completion_at_end: bool,
225+
):
226+
allreduce = allreduce_func(
227+
input, residual, gamma, scale, None, workspace, mapping.tp_group,
228+
int(strategy),
229+
int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8), float(eps),
230+
trigger_completion_at_end)
231+
return allreduce[0], allreduce[2], allreduce[1]
232+
233+
extra_check_with_scale = _make_fp8_quant_extra_check(
234+
input_node, strategy_node, static_quantize_e4m3_per_tensor_default,
235+
True)
236+
extra_check_without_scale = _make_fp8_quant_extra_check(
237+
input_node, strategy_node, static_quantize_e4m3_per_tensor_default,
238+
False)
179239

180240
register_replacement(
181241
empty_pattern,
182242
target_pattern,
183243
[],
184244
fwd_only,
185245
custom_pass,
186-
search_fn_pattern=pattern,
187-
extra_check=extra_check,
246+
search_fn_pattern=pattern_with_scale,
247+
extra_check=extra_check_with_scale,
248+
)
249+
250+
register_replacement(
251+
empty_pattern,
252+
target_pattern_without_scale,
253+
[],
254+
fwd_only,
255+
custom_pass,
256+
search_fn_pattern=pattern_without_scale,
257+
extra_check=extra_check_without_scale,
188258
)
189259

190260

@@ -212,15 +282,15 @@ def register_ar_residual_norm_fp8_quant(custom_pass: PatternMatcherPass,
212282
torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor.default,
213283
getitem_0,
214284
KeywordArg("scale"),
215-
_users=2)
216-
getitem_2 = CallFunction(getitem,
217-
static_quantize_e4m3_per_tensor_default,
218-
0,
219-
_users=2)
285+
_users=MULTIPLE)
286+
getitem_2 = CallFunction(getitem, static_quantize_e4m3_per_tensor_default,
287+
0)
220288
getitem_3 = CallFunction(getitem, static_quantize_e4m3_per_tensor_default,
221289
1)
222-
pattern = MultiOutputPattern([getitem_1, getitem_2,
223-
getitem_3]) # residual_out, quant_out, scale
290+
pattern_with_scale = MultiOutputPattern(
291+
[getitem_1, getitem_2, getitem_3]) # residual_out, quant_out, scale
292+
pattern_without_scale = MultiOutputPattern([getitem_1, getitem_2
293+
]) # residual_out, quant_out
224294

225295
def empty_pattern(
226296
input: torch.Tensor,
@@ -250,18 +320,47 @@ def target_pattern(
250320
float(eps), trigger_completion_at_end)
251321
return allreduce[1], allreduce[0], scale
252322

253-
def extra_check(match: Match) -> bool:
254-
return check_f16_bf16_input(
255-
match, input_node) and check_non_ub_strategy(match, strategy_node)
323+
def target_pattern_without_scale(
324+
input: torch.Tensor,
325+
residual: torch.Tensor,
326+
gamma: torch.Tensor,
327+
workspace: torch.LongTensor,
328+
strategy: int,
329+
eps: float,
330+
scale: torch.Tensor,
331+
trigger_completion_at_end: bool,
332+
):
333+
allreduce = allreduce_func(
334+
input, residual, gamma, scale, None, workspace, mapping.tp_group,
335+
int(strategy), int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8),
336+
float(eps), trigger_completion_at_end)
337+
return allreduce[1], allreduce[0]
338+
339+
extra_check_with_scale = _make_fp8_quant_extra_check(
340+
input_node, strategy_node, static_quantize_e4m3_per_tensor_default,
341+
True)
342+
extra_check_without_scale = _make_fp8_quant_extra_check(
343+
input_node, strategy_node, static_quantize_e4m3_per_tensor_default,
344+
False)
256345

257346
register_replacement(
258347
empty_pattern,
259348
target_pattern,
260349
[],
261350
fwd_only,
262351
custom_pass,
263-
search_fn_pattern=pattern,
264-
extra_check=extra_check,
352+
search_fn_pattern=pattern_with_scale,
353+
extra_check=extra_check_with_scale,
354+
)
355+
356+
register_replacement(
357+
empty_pattern,
358+
target_pattern_without_scale,
359+
[],
360+
fwd_only,
361+
custom_pass,
362+
search_fn_pattern=pattern_without_scale,
363+
extra_check=extra_check_without_scale,
265364
)
266365

267366

@@ -772,16 +871,20 @@ def extra_check(match: Match) -> bool:
772871
extra_check=extra_check,
773872
)
774873

775-
custom_passes.append(PatternMatcherPass())
874+
_append_named_pass(
875+
custom_passes,
876+
f"ub_convert_supported_ar_to_ub:{allreduce_func.__name__}")
776877
register_convert_supported_ar_to_ub(custom_passes[-1])
777878

778-
custom_passes.append(PatternMatcherPass())
879+
_append_named_pass(custom_passes, f"ub_prologue:{allreduce_func.__name__}")
779880
register_ub_prologue_patterns(custom_passes[-1])
780881

781-
custom_passes.append(PatternMatcherPass())
882+
_append_named_pass(custom_passes, f"ub_finalize:{allreduce_func.__name__}")
782883
register_ub_finalize_patterns(custom_passes[-1])
783884

784-
custom_passes.append(PatternMatcherPass())
885+
_append_named_pass(
886+
custom_passes,
887+
f"insert_copy_for_graph_output:{allreduce_func.__name__}")
785888
insert_copy_for_graph_output(custom_passes[-1])
786889

787890

@@ -792,7 +895,7 @@ def register_ar_fusions(custom_passes: List[PatternMatcherPass],
792895
register_ar_residual_norm(custom_passes[-1], mapping,
793896
torch.ops.trtllm.tunable_allreduce)
794897

795-
custom_passes.append(PatternMatcherPass())
898+
_append_named_pass(custom_passes, "ar_residual_norm_quant")
796899
for allreduce_func in [
797900
torch.ops.trtllm.allreduce, torch.ops.trtllm.tunable_allreduce
798901
]:

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -269,22 +269,6 @@ unittest/_torch/visual_gen/test_wan.py::TestWanTwoStageTransformer::test_two_sta
269269
accuracy/test_llm_api_autodeploy.py::TestNemotronNanoV3::test_accuracy[fp8-1-trtllm] SKIP (https://nvbugs/5921674)
270270
full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=True] SKIP (https://nvbugs/5929339)
271271
accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/5940463)
272-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens256-_hidden512] SKIP (https://nvbugs/5940460)
273-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens256-_hidden32] SKIP (https://nvbugs/5940460)
274-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens16-_hidden512] SKIP (https://nvbugs/5940460)
275-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens16-_hidden32] SKIP (https://nvbugs/5940460)
276-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-bf16-_tokens256-_hidden512] SKIP (https://nvbugs/5940460)
277-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-bf16-_tokens256-_hidden32] SKIP (https://nvbugs/5940460)
278-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-bf16-_tokens16-_hidden512] SKIP (https://nvbugs/5940460)
279-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-bf16-_tokens16-_hidden32] SKIP (https://nvbugs/5940460)
280-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-fp16-_tokens256-_hidden512] SKIP (https://nvbugs/5940460)
281-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-fp16-_tokens256-_hidden32] SKIP (https://nvbugs/5940460)
282-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-fp16-_tokens16-_hidden512] SKIP (https://nvbugs/5940460)
283-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-fp16-_tokens16-_hidden32] SKIP (https://nvbugs/5940460)
284-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-bf16-_tokens256-_hidden512] SKIP (https://nvbugs/5940460)
285-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-bf16-_tokens256-_hidden32] SKIP (https://nvbugs/5940460)
286-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-bf16-_tokens16-_hidden512] SKIP (https://nvbugs/5940460)
287-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_mm_add_prologue[2-bf16-_tokens16-_hidden32] SKIP (https://nvbugs/5940460)
288272
unittest/_torch/multimodal/test_mm_encoder_standalone.py::test_multi_request_batch_chat[llava_7b-False-False-False] SKIP (https://nvbugs/5983320)
289273
unittest/_torch/multimodal/test_mm_encoder_standalone.py::test_multi_request_batch_chat[llava_7b-False-True-True] SKIP (https://nvbugs/5983320)
290274
cpp/test_e2e.py::test_model[-gpt-80] SKIP (https://nvbugs/5983283)

0 commit comments

Comments
 (0)