
Commit f1e1a05

support qwen3.5 on volta (#4405)
* support qwen3.5 on volta
* fix copilot comment
* fix float32
* update kernel
1 parent 693082c commit f1e1a05

6 files changed

Lines changed: 230 additions & 115 deletions
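
The diffs below share one idea: Volta GPUs (compute capability 7.0) have no native bfloat16 support, so the engine now threads `device_type` through the config pipeline, falls back to float16 where bf16 is unavailable, and adds fused-MoE Triton autotune configs usable on SM7x devices. As a rough illustration of the capability gate (a sketch built on public torch APIs, not the body of lmdeploy's `is_bf16_supported`):

```python
# Illustrative only: a CUDA-oriented bf16 gate. lmdeploy's is_bf16_supported
# also takes a device_type argument and covers non-CUDA backends.
import torch


def bf16_supported_on_cuda() -> bool:
    if not torch.cuda.is_available():
        return False
    # Volta reports (7, 0); bf16 tensor cores arrive with Ampere (8, 0).
    return torch.cuda.is_bf16_supported()
```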


lmdeploy/pytorch/check_env/model.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -52,7 +52,10 @@ def check_dtype(self, config):
 
             from lmdeploy.pytorch.config import ModelConfig
             from lmdeploy.utils import is_bf16_supported
-            model_config = ModelConfig.from_hf_config(config, model_path=model_path, dtype=dtype)
+            model_config = ModelConfig.from_hf_config(config,
+                                                      model_path=model_path,
+                                                      dtype=dtype,
+                                                      device_type=device_type)
             if model_config.dtype == torch.bfloat16:
                 if not is_bf16_supported(device_type):
                     logger.warning('Device does not support bfloat16.')
```
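
With `device_type` threaded into `ModelConfig.from_hf_config`, the environment check resolves the dtype against the actual target backend. A hedged usage sketch; the model id below is a placeholder, not a real checkpoint:

```python
# Hypothetical call site after this change; 'Qwen/Qwen3.5-example' is a
# placeholder model id.
from transformers import AutoConfig

from lmdeploy.pytorch.config import ModelConfig

hf_config = AutoConfig.from_pretrained('Qwen/Qwen3.5-example', trust_remote_code=True)
model_config = ModelConfig.from_hf_config(hf_config,
                                          model_path='Qwen/Qwen3.5-example',
                                          dtype='auto',
                                          device_type='cuda')
# On Volta this now resolves to torch.float16 instead of warning about bf16.
print(model_config.dtype)
```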

lmdeploy/pytorch/config.py

Lines changed: 9 additions & 3 deletions
```diff
@@ -8,18 +8,19 @@
 from lmdeploy.messages import PytorchEngineConfig
 from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend
 from lmdeploy.pytorch.utils import maybe_register_config_serialize_by_value
-from lmdeploy.utils import get_logger
+from lmdeploy.utils import get_logger, is_bf16_supported
 
 logger = get_logger('lmdeploy')
 
 
-def _update_torch_dtype(config: 'ModelConfig', dtype: str):
+def _update_torch_dtype(config: 'ModelConfig', dtype: str, device_type: str = 'auto'):
     """Update the torch dtype from the model config.
 
     Args:
         config (ModelConfig): The input model config.
         dtype (str): user specified data type. Refer to
             `PyTorchEngineConfig.dtype` for detailed info
+        device_type (str): The device type. Refer to `PyTorchEngineConfig.device_type` for detailed info
     """
     quantization_config = getattr(config.hf_config, 'quantization_config', dict())
     quant_method = quantization_config.get('quant_method', None)
@@ -48,6 +49,8 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str):
         # update hf_config as well
         setattr(config.hf_config, 'torch_dtype', torch_dtype)
     else:
+        if torch_dtype == 'bfloat16' and not is_bf16_supported(device_type):
+            torch_dtype = 'float16'
         # change to user specified data type if it is not 'auto'
         if dtype == 'auto':
             torch_dtype = torch_dtype if torch_dtype in ['float16', 'bfloat16'] else 'float16'
@@ -356,6 +359,7 @@
         is_draft_model: bool = False,
         spec_method: str = None,
         model_format: str = None,
+        device_type: str = 'auto',
     ):
         """Instantiate one of the configuration classes of the library from a
         pretrained model configuration.
@@ -386,6 +390,7 @@
             dist_config=dist_config,
             is_draft_model=is_draft_model,
             spec_method=spec_method,
+            device_type=device_type,
         )
         fp32_lm_head = False
         if hf_overrides is not None:
@@ -413,6 +418,7 @@
         dist_config: DistConfig = None,
         is_draft_model: bool = False,
         spec_method: str = None,
+        device_type: str = 'auto',
     ):
         """From huggingface config."""
         from lmdeploy.pytorch.configurations import AutoModelConfigBuilder
@@ -441,7 +447,7 @@ def from_hf_config(
         assert tp % model_config.num_key_value_heads == 0
 
         # should after setting `hf_config` and `model_arch` attributes
-        model_config = _update_torch_dtype(model_config, dtype)
+        model_config = _update_torch_dtype(model_config, dtype, device_type=device_type)
 
         # update eos_token_id to list
         if isinstance(model_config.eos_token_id, int):
```
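
The new branch in `_update_torch_dtype` downgrades a bfloat16 checkpoint to float16 before the user-specified dtype is resolved; an explicit user `dtype='bfloat16'` still wins (and then triggers the warning from the environment check above). A self-contained sketch of that resolution order, with illustrative names rather than lmdeploy's API:

```python
def resolve_dtype(checkpoint_dtype: str, user_dtype: str, bf16_ok: bool) -> str:
    """Sketch of the non-quantized path in _update_torch_dtype."""
    if checkpoint_dtype == 'bfloat16' and not bf16_ok:
        checkpoint_dtype = 'float16'  # the new device-aware fallback
    if user_dtype == 'auto':
        return checkpoint_dtype if checkpoint_dtype in ('float16', 'bfloat16') else 'float16'
    return user_dtype  # explicit user choice still wins


assert resolve_dtype('bfloat16', 'auto', bf16_ok=False) == 'float16'   # Volta
assert resolve_dtype('bfloat16', 'auto', bf16_ok=True) == 'bfloat16'   # Ampere+
```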

lmdeploy/pytorch/configurations/qwen3_5.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -1,6 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 
+from lmdeploy.utils import is_bf16_supported
+
 from .builder import AutoModelConfigBuilder
 from .default import DefaultModelConfigBuilder
 from .qwen3_next import _check_env_qwen3_next
@@ -42,7 +44,10 @@ def build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs):
 
         conv_state_shape = (num_delta_layers, conv_dim, conv_kernel_size)
         recurrent_state_shape = (num_delta_layers, num_v_heads, head_k_dim, head_v_dim)
-        dtype = torch.bfloat16
+        if is_bf16_supported():
+            dtype = torch.bfloat16
+        else:
+            dtype = torch.float16
         cfg.states_shapes = [(conv_state_shape, dtype), (recurrent_state_shape, dtype)]
         cfg.check_env_func = _check_env_qwen3_next
         return cfg
```
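
The builder now picks the dtype of the delta-layer state caches at build time instead of hard-coding bfloat16, so the cache dtype matches what the device can run. An illustrative construction; the dimension values are placeholders, not Qwen3.5's real sizes:

```python
import torch

# Placeholder dimensions, for illustration only.
num_delta_layers, conv_dim, conv_kernel_size = 24, 4096, 4
num_v_heads, head_k_dim, head_v_dim = 16, 128, 128

dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
conv_state_shape = (num_delta_layers, conv_dim, conv_kernel_size)
recurrent_state_shape = (num_delta_layers, num_v_heads, head_k_dim, head_v_dim)
states_shapes = [(conv_state_shape, dtype), (recurrent_state_shape, dtype)]
```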

lmdeploy/pytorch/engine/executor/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -79,6 +79,7 @@ def build_executor(
         is_draft_model=False,
         spec_method=None if specdecode_config is None else specdecode_config.method,
         model_format=misc_config.model_format,
+        device_type=device_type,
     )
 
     if distributed_executor_backend is None:
```

lmdeploy/pytorch/kernels/cuda/fused_moe.py

Lines changed: 31 additions & 2 deletions
```diff
@@ -27,6 +27,7 @@ def get_cuda_autotune_config():
         },
                       num_stages=4,
                       num_warps=4),
+        # SM8
         triton.Config({
             'BLOCK_SIZE_M': 128,
             'BLOCK_SIZE_N': 128,
@@ -51,18 +52,46 @@ def get_cuda_autotune_config():
         },
                       num_stages=4,
                       num_warps=4),
+        # SM7-
+        triton.Config({
+            'BLOCK_SIZE_M': 64,
+            'BLOCK_SIZE_N': 128,
+            'BLOCK_SIZE_K': 32,
+            'GROUP_SIZE_M': 1,
+        },
+                      num_stages=4,
+                      num_warps=4),
+        triton.Config({
+            'BLOCK_SIZE_M': 128,
+            'BLOCK_SIZE_N': 32,
+            'BLOCK_SIZE_K': 32,
+            'GROUP_SIZE_M': 1,
+        },
+                      num_stages=4,
+                      num_warps=4),
+        triton.Config({
+            'BLOCK_SIZE_M': 64,
+            'BLOCK_SIZE_N': 32,
+            'BLOCK_SIZE_K': 32,
+            'GROUP_SIZE_M': 1,
+        },
+                      num_stages=5,
+                      num_warps=2),
     ]
 
 
-def _config_prune_func(config: dict, *args, **kwargs):
+def _config_prune_func(config: list, *args, **kwargs):
     """Fused moe config prune."""
     device_cap = torch.cuda.get_device_capability()
     num_sm9x = 2
+    cum_num_sm8x = 5
 
     if device_cap[0] >= 9:
         return config[:num_sm9x]
+    elif device_cap[0] >= 8:
+        return config[num_sm9x:cum_num_sm8x]
     else:
-        return config[num_sm9x:]
+        return config[cum_num_sm8x:]
 
 
 @triton.autotune(
```
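
After this change the autotune list is laid out in three bands: two SM9x configs, three SM8x configs, then the SM7- configs appended here, and `_config_prune_func` slices out the band matching the device's major compute capability (which also explains the `dict` to `list` annotation fix). A tiny sketch of the banding, with string stand-ins for the `triton.Config` objects:

```python
def prune(configs: list, major: int) -> list:
    """Sketch of _config_prune_func's banding logic."""
    num_sm9x = 2      # configs[0:2] -> Hopper and newer (SM90+)
    cum_num_sm8x = 5  # configs[2:5] -> Ampere/Ada (SM80/SM89)
    if major >= 9:
        return configs[:num_sm9x]
    elif major >= 8:
        return configs[num_sm9x:cum_num_sm8x]
    return configs[cum_num_sm8x:]  # configs[5:] -> Volta and older (SM7-)


bands = ['sm9_a', 'sm9_b', 'sm8_a', 'sm8_b', 'sm8_c', 'sm7_a', 'sm7_b', 'sm7_c']
assert prune(bands, 7) == ['sm7_a', 'sm7_b', 'sm7_c']  # e.g. V100 (Volta)
assert prune(bands, 9) == ['sm9_a', 'sm9_b']           # e.g. H100 (Hopper)
```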
