Skip to content

Commit 410988d

Browse files
authored
[OP] Support DeepGEMM for SM103 (#7073)
* support DeepGEMM for SM103 * add assert * modify code style * add assert * modify SM version condition * remove assert
1 parent ba1aa1e commit 410988d

2 files changed

Lines changed: 5 additions & 4 deletions

File tree

fastdeploy/model_executor/layers/quantization/block_wise_fp8.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(self, weight_block_size: list = [-1, -1], is_checkpoint_bf16: bool
6767
self.quant_round_type = 1
6868
self.use_deep_gemm = bool(envs.FD_USE_DEEP_GEMM)
6969
self.is_checkpoint_bf16 = is_checkpoint_bf16
70-
self.deepgemm_scale_ue8m0 = True if get_sm_version() == 100 else False
70+
self.deepgemm_scale_ue8m0 = True if get_sm_version() >= 100 else False
7171

7272
def name(self) -> str:
7373
return "block_wise_fp8"
@@ -125,7 +125,8 @@ def deep_gemm_fp8_gemm_nt(
125125
layer_output_size: int,
126126
bias: paddle.Tensor = None,
127127
):
128-
if get_sm_version() == 100 and current_platform.is_cuda():
128+
sm_version = get_sm_version()
129+
if sm_version >= 100 and current_platform.is_cuda():
129130
# disable_ue8m0_cast is default False for SM100
130131
fp8_gemm_nt(
131132
(x, x_scale_tensor),

fastdeploy/model_executor/layers/quantization/fp8_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def load_deep_gemm():
6565
"""
6666

6767
if current_platform.is_cuda():
68-
if get_sm_version() == 100:
68+
if get_sm_version() >= 100:
6969
# SM100 should use PFCC DeepGemm
7070
paddle.compat.enable_torch_proxy(scope={"deep_gemm"})
7171
try:
@@ -245,7 +245,7 @@ def fused_stack_transpose_quant(expert_weight_list, use_ue8m0=False):
245245
# Blackwell (SM100) GPUs require pow2_scale quantization.
246246
# Guard with is_cuda() so non-CUDA environments do not call into
247247
# paddle.device.cuda.* and cause a crash.
248-
use_pow2_scale = current_platform.is_cuda() and get_sm_version() == 100
248+
use_pow2_scale = current_platform.is_cuda() and get_sm_version() >= 100
249249

250250
w, scale = paddlefleet_ops.fuse_stack_transpose_fp8_quant(
251251
expert_weight_list,

0 commit comments

Comments
 (0)