Skip to content

Commit 2ffab8d

Browse files
authored
[NVBUG-6266259][fix] Fix userbuffers prologue patterns (#15220)
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
1 parent 93ad566 commit 2ffab8d

2 files changed

Lines changed: 60 additions & 20 deletions

File tree

tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py

Lines changed: 60 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -583,11 +583,13 @@ def extra_check_convert_supported_ar_to_ub(match: Match) -> bool:
583583
def register_ub_prologue_patterns(custom_pass: PatternMatcherPass):
584584

585585
def register_scaled_mm_prologue(custom_pass: PatternMatcherPass):
586+
output_buffer_kind_key = KeywordArg('output_buffer_kind')
586587
trtllm_cublas_scaled_mm_default = CallFunction(
587588
torch.ops.trtllm.cublas_scaled_mm.default, KeywordArg('mm0_a'),
588589
KeywordArg('mm0_b'), KeywordArg('mm0_a_scale'),
589590
KeywordArg('mm0_b_scale'), KeywordArg('mm0_bias'),
590-
KeywordArg('mm_dtype'))
591+
KeywordArg('mm_dtype'), output_buffer_kind_key,
592+
mapping.tp_group)
591593
ub_copy = CallFunction(torch.ops.trtllm.copy_to_userbuffers,
592594
trtllm_cublas_scaled_mm_default)
593595

@@ -598,6 +600,7 @@ def empty_scaled_mm_prologue_pattern(
598600
mm0_b_scale: torch.Tensor,
599601
mm0_bias: Optional[torch.Tensor],
600602
mm_dtype: torch.dtype,
603+
output_buffer_kind: int,
601604
):
602605
return
603606

@@ -608,10 +611,11 @@ def target_scaled_mm_prologue_pattern(
608611
mm0_b_scale: torch.Tensor,
609612
mm0_bias: Optional[torch.Tensor],
610613
mm_dtype: torch.dtype,
614+
output_buffer_kind: int,
611615
):
612616
scaled_mm_output = torch.ops.trtllm.cublas_scaled_mm(
613617
mm0_a, mm0_b, mm0_a_scale, mm0_b_scale, mm0_bias, mm_dtype,
614-
True)
618+
int(BufferKind.USERBUFFERS), mapping.tp_group)
615619
return scaled_mm_output
616620

617621
# No extra check needed as the output dtype of scaled_mm has been verified when
@@ -635,15 +639,9 @@ def register_nvfp4_gemm_prologue(custom_pass: PatternMatcherPass):
635639
output_buffer_kind_key = KeywordArg('output_buffer_kind')
636640
allowed_backends_key = KeywordArg('allowed_backends')
637641
trtllm_nvfp4_gemm_default = CallFunction(
638-
torch.ops.trtllm.nvfp4_gemm.default,
639-
act_fp4_key,
640-
weight_key,
641-
act_sf_key,
642-
weight_scale_key,
643-
alpha_key,
644-
output_dtype_key,
645-
output_buffer_kind=output_buffer_kind_key,
646-
allowed_backends=allowed_backends_key)
642+
torch.ops.trtllm.nvfp4_gemm.default, act_fp4_key, weight_key,
643+
act_sf_key, weight_scale_key, alpha_key, output_dtype_key,
644+
output_buffer_kind_key, allowed_backends_key, mapping.tp_group)
647645
ub_copy = CallFunction(torch.ops.trtllm.copy_to_userbuffers,
648646
trtllm_nvfp4_gemm_default)
649647

@@ -671,7 +669,48 @@ def target_nvfp4_gemm_prologue_pattern(
671669
):
672670
nvfp4_gemm_output = torch.ops.trtllm.nvfp4_gemm(
673671
act_fp4, weight, act_sf, weight_scale, alpha, output_dtype,
674-
int(BufferKind.USERBUFFERS), allowed_backends)
672+
int(BufferKind.USERBUFFERS), allowed_backends,
673+
mapping.tp_group)
674+
return nvfp4_gemm_output
675+
676+
bias_key = KeywordArg('bias')
677+
trtllm_nvfp4_gemm_with_bias_default = CallFunction(
678+
torch.ops.trtllm.nvfp4_gemm.default, act_fp4_key, weight_key,
679+
act_sf_key, weight_scale_key, alpha_key, output_dtype_key,
680+
output_buffer_kind_key, allowed_backends_key, mapping.tp_group,
681+
bias_key)
682+
ub_copy_with_bias = CallFunction(
683+
torch.ops.trtllm.copy_to_userbuffers,
684+
trtllm_nvfp4_gemm_with_bias_default)
685+
686+
def empty_nvfp4_gemm_bias_prologue_pattern(
687+
act_fp4: torch.Tensor,
688+
weight: torch.Tensor,
689+
act_sf: torch.Tensor,
690+
weight_scale: torch.Tensor,
691+
alpha: torch.Tensor,
692+
output_dtype: torch.dtype,
693+
output_buffer_kind: int,
694+
allowed_backends: str,
695+
bias: Optional[torch.Tensor],
696+
):
697+
return
698+
699+
def target_nvfp4_gemm_bias_prologue_pattern(
700+
act_fp4: torch.Tensor,
701+
weight: torch.Tensor,
702+
act_sf: torch.Tensor,
703+
weight_scale: torch.Tensor,
704+
alpha: torch.Tensor,
705+
output_dtype: torch.dtype,
706+
output_buffer_kind: int,
707+
allowed_backends: str,
708+
bias: Optional[torch.Tensor],
709+
):
710+
nvfp4_gemm_output = torch.ops.trtllm.nvfp4_gemm(
711+
act_fp4, weight, act_sf, weight_scale, alpha, output_dtype,
712+
int(BufferKind.USERBUFFERS), allowed_backends,
713+
mapping.tp_group, bias)
675714
return nvfp4_gemm_output
676715

677716
def extra_check(match: Match) -> bool:
@@ -702,6 +741,15 @@ def extra_check(match: Match) -> bool:
702741
search_fn_pattern=ub_copy,
703742
extra_check=extra_check,
704743
)
744+
register_replacement(
745+
empty_nvfp4_gemm_bias_prologue_pattern,
746+
target_nvfp4_gemm_bias_prologue_pattern,
747+
[],
748+
fwd_only,
749+
custom_pass,
750+
search_fn_pattern=ub_copy_with_bias,
751+
extra_check=extra_check,
752+
)
705753

706754
def register_mm_prologue(custom_pass: PatternMatcherPass):
707755
aten_mm_default = CallFunction(aten.mm.default, KeywordArg('mm0_a'),

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -454,14 +454,6 @@ unittest/_torch/modules/moe/test_moe_backend.py::test_moe_backend[act=Relu2-e60_
454454
unittest/_torch/modules/moe/test_moe_module.py::test_configurable_moe_multi_gpu -k "CUTLASS and FP8 and not FP8_BLOCK_SCALES and not W4A8" SKIP (https://nvbugs/6402048)
455455
unittest/_torch/modules/tests_lora_modules/test_lora_attention_pytorch_flow_vs_trt.py::TestLoraAttentionPytorchFlowVsTRT::test_lora_attention SKIP (https://nvbugs/5701421)
456456
unittest/_torch/multi_gpu/test_mnnvl_allreduce.py::test_mnnvl_nvfp4_rejects_fp32_before_launch[2] SKIP (https://nvbugs/6396420)
457-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-bf16-_tokens16-_hidden32] SKIP (https://nvbugs/6266259)
458-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-bf16-_tokens16-_hidden512] SKIP (https://nvbugs/6266259)
459-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-bf16-_tokens256-_hidden32] SKIP (https://nvbugs/6266259)
460-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-bf16-_tokens256-_hidden512] SKIP (https://nvbugs/6266259)
461-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens16-_hidden32] SKIP (https://nvbugs/6266259)
462-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens16-_hidden512] SKIP (https://nvbugs/6266259)
463-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens256-_hidden32] SKIP (https://nvbugs/6266259)
464-
unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens256-_hidden512] SKIP (https://nvbugs/6266259)
465457
unittest/_torch/ray_orchestrator/multi_gpu/test_llm_update_weights_multi_gpu.py -m "part0" SKIP (https://nvbugs/6372711)
466458
unittest/_torch/ray_orchestrator/multi_gpu/test_llm_update_weights_multi_gpu.py::test_llm_partial_update_weights_nvfp4[auto-Qwen3/Qwen3-8B] SKIP (https://nvbugs/6372690)
467459
unittest/_torch/ray_orchestrator/multi_gpu/test_llm_update_weights_multi_gpu.py::test_llm_partial_update_weights_nvfp4[fp8-Qwen3/Qwen3-30B-A3B] SKIP (https://nvbugs/6372690)

0 commit comments

Comments
 (0)