Skip to content

Commit 75aea27

Browse files
committed
[NVBUG-6224637][fix] Enable CuTe DSL BF16 kernels on SM100 PP
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
1 parent cd1b886 commit 75aea27

3 files changed

Lines changed: 15 additions & 3 deletions

File tree

tensorrt_llm/_torch/modules/attention.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -529,7 +529,8 @@ def __init__(
529529
force_dynamic_quantization=config.force_dynamic_quantization,
530530
disable_deep_gemm=disable_deep_gemm,
531531
use_custom_cublas_mm=use_custom_cublas_mm,
532-
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
532+
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm,
533+
use_cute_dsl_bf16_gemm=self.use_cute_dsl_bf16_gemm)
533534

534535
self.quant_config = config.get_quant_config()
535536
self.attn_backend = config.attn_backend
@@ -1462,7 +1463,8 @@ def __init__(
14621463
reduce_output=reduce_output,
14631464
allreduce_strategy=config.allreduce_strategy,
14641465
force_dynamic_quantization=config.force_dynamic_quantization,
1465-
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
1466+
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm,
1467+
use_cute_dsl_bf16_gemm=self.use_cute_dsl_bf16_gemm)
14661468

14671469
def yarn_get_mscale(scale=1, mscale=1):
14681470
if scale <= 1:

tensorrt_llm/_torch/modules/gated_mlp.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,7 @@ def __init__(
8484
allreduce_strategy=config.allreduce_strategy,
8585
force_dynamic_quantization=config.force_dynamic_quantization,
8686
use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
87+
use_cute_dsl_bf16_gemm=config.use_cute_dsl_bf16_gemm,
8788
disable_deep_gemm=disable_deep_gemm,
8889
fused_weight_shard_indices_mapping=gateup_shard_indices_mapping,
8990
use_custom_cublas_mm=use_custom_cublas_mm,
@@ -114,6 +115,7 @@ def __init__(
114115
allreduce_strategy=config.allreduce_strategy,
115116
force_dynamic_quantization=config.force_dynamic_quantization,
116117
use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
118+
use_cute_dsl_bf16_gemm=config.use_cute_dsl_bf16_gemm,
117119
disable_deep_gemm=disable_deep_gemm,
118120
use_custom_cublas_mm=use_custom_cublas_mm,
119121
)

tensorrt_llm/llmapi/llm_args.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,8 @@
4646
from tensorrt_llm.lora_helper import (LoraConfig,
4747
get_default_trtllm_modules_to_hf_modules)
4848

49-
from .._utils import _str_to_torch_dtype_dict, mpi_rank, prefer_pinned
49+
from .._utils import (_str_to_torch_dtype_dict, is_sm_100f, mpi_rank,
50+
prefer_pinned)
5051

5152
# yapf: disable
5253
# isort: off
@@ -5077,6 +5078,13 @@ def validate_ray_placement_config(self) -> 'TorchLlmArgs':
50775078

50785079
@model_validator(mode='after')
50795080
def validate_cute_dsl_bf16(self) -> 'TorchLlmArgs':
5081+
if (not (self.use_cute_dsl_bf16_bmm and self.use_cute_dsl_bf16_gemm)
5082+
and self.pipeline_parallel_size > 1 and is_sm_100f()):
5083+
logger.info("Automatically enabling CuTe DSL BF16 BMM and GEMM for "
5084+
"SM100/SM103 PP.")
5085+
self.use_cute_dsl_bf16_bmm = True
5086+
self.use_cute_dsl_bf16_gemm = True
5087+
50805088
if self.use_cute_dsl_bf16_bmm or self.use_cute_dsl_bf16_gemm:
50815089
major, minor = torch.cuda.get_device_capability()
50825090
sm = major * 10 + minor

0 commit comments

Comments
 (0)