Skip to content

Commit 456aca0

Browse files
authored
Compatible with transformers 5.0 at TurboMind side (#4304)
* Compatible with transformers 5.0 * no constraint on transformers * qwen2.5 vl * fix internvl * fix internlm * minor fix qwen2-vl * improve type hint
1 parent d9a5856 commit 456aca0

7 files changed

Lines changed: 38 additions & 146 deletions

File tree

lmdeploy/turbomind/deploy/parameter.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from abc import abstractmethod
3-
from typing import List
43

54
import torch
65

@@ -23,7 +22,7 @@ def to_fp8(x: torch.Tensor):
2322

2423

2524
def pack_u4_row(x: torch.Tensor) -> torch.Tensor:
26-
assert x.dtype == torch.uint8
25+
assert x.dtype == torch.uint8, f'x.dtype: {x.dtype}'
2726
xs = x.view(*x.shape[:-1], -1, 8).split(1, dim=-1)
2827
a = torch.zeros(xs[0].shape, dtype=torch.int32, device=x.device)
2928
for t in reversed(xs):
@@ -45,7 +44,7 @@ class Parameter:
4544
KEY = ()
4645

4746
@classmethod
48-
def take(cls, keys: List[str]):
47+
def take(cls, keys: list[str]):
4948
if not any(k.endswith(cls.KEYS[0]) for k in keys):
5049
return False
5150
xs = []
@@ -126,7 +125,7 @@ def __call__(self, f, g, i):
126125
f(i, g('Plora_B.weight'), 'lora_b.weight', identity)
127126

128127

129-
def get_params(keys: List[str], bias=0):
128+
def get_params(keys: list[str], bias=0):
130129
ps = []
131130
if PLora.take(keys):
132131
ps.append(PLora())

lmdeploy/turbomind/deploy/source_model/deepseek2.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,11 @@ def model_info(self):
146146
info['router_n_groups'] = cfg['router_n_groups']
147147
rope_param: RopeParam = info['rope_param']
148148
rope_param.dim = qk_rope_dim
149-
rope_scaling = cfg.get('rope_scaling')
149+
if 'rope_parameters' in cfg:
150+
# transformers v5.0.0 aggregates all rope-related parameters into 'rope_parameters'
151+
rope_scaling = cfg['rope_parameters']
152+
else:
153+
rope_scaling = cfg.get('rope_scaling')
150154
if rope_scaling and rope_scaling.get('type') == 'yarn':
151155
attention_factor, yarn_scale = get_yarn_params(rope_scaling)
152156
yarn_scale *= q_head_dim**(-0.5)

lmdeploy/turbomind/deploy/source_model/internvl.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@ class InternVLReader(LlamaReader):
1515
norm_weight_key = 'language_model.model.norm.weight'
1616
output_weight_key = 'language_model.lm_head.weight'
1717

18-
def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_cfg: dict, **kwargs):
19-
model_cfg = model_cfg.get('llm_config') or model_cfg.get('text_config')
20-
super().__init__(new_params, unused_params, last_bin, model_cfg, **kwargs)
21-
2218

2319
# Note the subtle difference in keys
2420
class InternVL2Reader(InternLM2Reader):
@@ -30,10 +26,6 @@ class InternVL2Reader(InternLM2Reader):
3026
norm_weight_key = 'language_model.model.norm.weight'
3127
output_weight_key = 'language_model.output.weight'
3228

33-
def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_cfg: dict, **kwargs):
34-
model_cfg = model_cfg.get('llm_config')
35-
super().__init__(new_params, unused_params, last_bin, model_cfg, **kwargs)
36-
3729

3830
class InternVL3d5Reader(Qwen3Reader):
3931
attn_layer_prefix = 'language_model.model.layers'
@@ -42,10 +34,6 @@ class InternVL3d5Reader(Qwen3Reader):
4234
norm_weight_key = 'language_model.model.norm.weight'
4335
output_weight_key = 'language_model.lm_head.weight'
4436

45-
def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_cfg: dict, **kwargs):
46-
model_cfg = model_cfg.get('llm_config') or model_cfg.get('text_config')
47-
super().__init__(new_params, unused_params, last_bin, model_cfg, **kwargs)
48-
4937

5038
class InternVL3d5Qwen3MoEReader(Qwen3MoeReader):
5139
attn_layer_prefix = 'language_model.model.layers'
@@ -54,10 +42,6 @@ class InternVL3d5Qwen3MoEReader(Qwen3MoeReader):
5442
norm_weight_key = 'language_model.model.norm.weight'
5543
output_weight_key = 'language_model.lm_head.weight'
5644

57-
def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_cfg: dict, **kwargs):
58-
model_cfg = model_cfg.get('llm_config') or model_cfg.get('text_config')
59-
super().__init__(new_params, unused_params, last_bin, model_cfg, **kwargs)
60-
6145

6246
class InternVL3d5GptOSSReader(GptOssReader):
6347
attn_layer_prefix = 'language_model.model.layers'
@@ -66,10 +50,6 @@ class InternVL3d5GptOSSReader(GptOssReader):
6650
norm_weight_key = 'language_model.model.norm.weight'
6751
output_weight_key = 'language_model.lm_head.weight'
6852

69-
def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_cfg: dict, **kwargs):
70-
model_cfg = model_cfg.get('llm_config') or model_cfg.get('text_config')
71-
super().__init__(new_params, unused_params, last_bin, model_cfg, **kwargs)
72-
7353

7454
class InternS1Reader(Qwen3MoeReader):
7555
"""InternS1Reader for internlm/InternS1 model."""
@@ -80,12 +60,6 @@ class InternS1Reader(Qwen3MoeReader):
8060
norm_weight_key = 'model.language_model.norm.weight'
8161
output_weight_key = 'lm_head.weight'
8262

83-
def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_cfg: dict, **kwargs):
84-
model_cfg = model_cfg.get('text_config')
85-
if model_cfg is None:
86-
raise ValueError(f'Miss "text_config" in model config: {model_cfg}')
87-
super().__init__(new_params, unused_params, last_bin, model_cfg, **kwargs)
88-
8963

9064
class InternS1MiniReader(Qwen3Reader):
9165

@@ -95,12 +69,6 @@ class InternS1MiniReader(Qwen3Reader):
9569
norm_weight_key = 'model.language_model.norm.weight'
9670
output_weight_key = 'lm_head.weight'
9771

98-
def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_cfg: dict, **kwargs):
99-
model_cfg = model_cfg.get('text_config')
100-
if model_cfg is None:
101-
raise ValueError(f'Miss "text_config" in model config: {model_cfg}')
102-
super().__init__(new_params, unused_params, last_bin, model_cfg, **kwargs)
103-
10472

10573
@INPUT_MODELS.register_module(name='internvl')
10674
class InternVLModel(LlamaModel):

lmdeploy/turbomind/deploy/source_model/llama.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,12 @@ class LlamaModel(BaseInputModel):
132132
def __init__(self, model_path: str, tokenizer_path: str, **kwargs: dict):
133133
super().__init__(model_path, tokenizer_path)
134134
self.policy = kwargs.get('input_policy')
135-
_, self.model_config = get_model_arch(model_path)
136-
self.model_config = self.model_config.to_dict()
135+
_, model_config = get_model_arch(model_path)
136+
if hasattr(model_config, 'text_config'):
137+
model_config = model_config.text_config
138+
elif hasattr(model_config, 'llm_config'):
139+
model_config = model_config.llm_config
140+
self.model_config = model_config.to_dict()
137141
self.fp8_quant = kwargs.get('fp8_quant', False)
138142

139143
def readers(self):
@@ -171,27 +175,21 @@ def model_info(self):
171175
max_position_embeddings = int(model_arg.get('max_position_embeddings', 0))
172176
rope_param = RopeParam(type='default', base=rope_theta, dim=head_dim)
173177
if isinstance(rope_scaling, dict):
174-
llama2_scaling_type = rope_scaling.get('type', '')
175-
llama3_scaling_type = rope_scaling.get('rope_type', '')
176-
if llama2_scaling_type and llama3_scaling_type \
177-
and llama2_scaling_type != llama3_scaling_type:
178-
raise ValueError(f'Ambiguous rope_scaling in config: {model_arg}')
179-
scaling_type = llama2_scaling_type if llama2_scaling_type \
180-
else llama3_scaling_type
178+
rope_type = rope_scaling.get('rope_type', '') or rope_scaling.get('type', '')
181179
if rope_scaling.get('mrope_section') is not None:
182180
# TODO: treat mrope as an option to the common rope functions
183-
scaling_type = 'mrope'
181+
rope_type = 'mrope'
184182
scaling_factor = rope_scaling.get('factor', 0.0)
185-
if scaling_type == 'default':
183+
if rope_type == 'default':
186184
pass
187-
elif scaling_type == 'dynamic':
185+
elif rope_type == 'dynamic':
188186
rope_param.type = 'dynamic'
189187
rope_param.factor = scaling_factor
190188
rope_param.max_position_embeddings = max_position_embeddings
191-
elif scaling_type == 'linear':
189+
elif rope_type == 'linear':
192190
rope_param.type = 'linear'
193191
rope_param.factor = scaling_factor
194-
elif scaling_type == 'llama3':
192+
elif rope_type == 'llama3':
195193
low_freq_factor = rope_scaling.get('low_freq_factor', 1.0)
196194
high_freq_factor = rope_scaling.get('high_freq_factor', 1.0)
197195
original_max_position_embeddings = rope_scaling.get('original_max_position_embeddings', 0)
@@ -200,7 +198,7 @@ def model_info(self):
200198
rope_param.low_freq_factor = low_freq_factor
201199
rope_param.high_freq_factor = high_freq_factor
202200
rope_param.original_max_position_embeddings = original_max_position_embeddings
203-
elif scaling_type == 'yarn':
201+
elif rope_type == 'yarn':
204202
attention_factor = rope_scaling.get('attention_factor', None)
205203
if attention_factor is None:
206204
attention_factor = 0.1 * math.log(scaling_factor) + 1.0
@@ -217,12 +215,12 @@ def model_info(self):
217215
rope_param.attention_factor = attention_factor
218216
rope_param.beta_fast = beta_fast
219217
rope_param.beta_slow = beta_slow
220-
elif scaling_type == 'mrope':
218+
elif rope_type == 'mrope':
221219
mrope_section = rope_scaling.get('mrope_section')
222220
rope_param.type = 'mrope'
223221
rope_param.mrope_section = mrope_section
224222
else:
225-
raise RuntimeError(f'Unsupported rope type: {scaling_type}')
223+
raise RuntimeError(f'Unsupported rope type: {rope_type}')
226224

227225
return dict(size_per_head=head_dim,
228226
num_layer=num_layer,

lmdeploy/vl/model/qwen2.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from typing import Dict, List, Tuple
3-
42
import torch
53

64
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel
@@ -35,7 +33,7 @@ def build_preprocessor(self):
3533
self.image_token = self.processor.image_token
3634
self.image_token_id = tokenizer.encode(self.image_token)[-1]
3735

38-
def preprocess(self, messages: List[Dict]) -> List[Dict]:
36+
def preprocess(self, messages: list[dict]) -> list[dict]:
3937
"""Refer to `super().preprocess()` for spec."""
4038
from qwen_vl_utils import process_vision_info
4139

@@ -48,7 +46,7 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]:
4846
item = dict(type='image', image=image)
4947
item.update({key: params[key] for key in params.keys() if key in optional_keys})
5048
image_inputs, _ = process_vision_info([dict(content=[item])])
51-
result = self.processor.image_processor(images=image_inputs, videos=None, return_tensors='pt')
49+
result = self.processor.image_processor(images=image_inputs, return_tensors='pt')
5250
merge_length = self.processor.image_processor.merge_size**2
5351
image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length
5452
result.update(dict(image_size=image.size, image_tokens=image_tokens, image_token_id=self.image_token_id))
@@ -77,10 +75,7 @@ def build_model(self):
7775
if hasattr(config, 'text_config'):
7876
config.text_config.tie_word_embeddings = False
7977
model = AutoModelCls._from_config(config)
80-
if hasattr(AutoModelCls, 'visual'):
81-
# transformers >= 4.52.0 modified model structure
82-
# https://github.com/huggingface/transformers/blob/v4.52.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1791-L1800
83-
model.visual = model.model.visual
78+
model.visual = model.model.visual
8479
del model.model
8580
del model.lm_head
8681
model.half()
@@ -96,12 +91,12 @@ def build_model(self):
9691
self.model = model.eval()
9792

9893
@torch.no_grad()
99-
def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
94+
def forward(self, messages: list[dict], max_batch_size: int = 1) -> list[dict]:
10095
"""Extract image feature. ONLY implement it when the backend is
10196
turbomind engine.
10297
10398
Args:
104-
messages(List[Dict]): the outputs of `preprocess`
99+
messages(list[dict]): the outputs of `preprocess`
105100
max_batch_size(int): the max batch size when forwarding vision
106101
model
107102
Return:
@@ -117,6 +112,10 @@ def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
117112
pixel_values = torch.cat(pixel_values, dim=0).to(device)
118113
image_grid_thw = torch.cat(image_grid_thw, dim=0).to(device)
119114
image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw)
115+
if hasattr(image_embeds, 'pooler_output'):
116+
# transformers >= 5.0.0, the type of image_embeds is `BaseModelOutputWithPooling`
117+
# rather than torch.Tensor
118+
image_embeds = image_embeds.pooler_output
120119
merge_length = self.processor.image_processor.merge_size**2
121120
split_size = image_grid_thw.prod(dim=1) // merge_length
122121
image_embeds = image_embeds.split(split_size.tolist())
@@ -162,8 +161,8 @@ def proc_messages(self, messages, chat_template, sequence_start, chat_template_k
162161

163162
@staticmethod
164163
def get_mrope_info(seq_len: int,
165-
grid_thws: List[Tuple[int, int, int]] = None,
166-
ranges: List[Tuple[int, int]] = None):
164+
grid_thws: list[tuple[int, int, int]] = None,
165+
ranges: list[tuple[int, int]] = None):
167166
mrope_position_ids = [torch.arange(ranges[0][0]).expand(3, -1)]
168167
st_idx = ranges[0][0]
169168
for i, (grid_thw, embedding_range) in enumerate(zip(grid_thws, ranges)):

lmdeploy/vl/model/utils.py

Lines changed: 3 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,10 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22

33
import inspect
4-
import os
5-
import sys
64
from contextlib import contextmanager
7-
from typing import Callable, Dict, Iterator, List, MutableSequence, Union
5+
from typing import Callable, MutableSequence
86

97
import torch
10-
import torch.nn as nn
11-
from safetensors.torch import load_file
12-
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
13-
from transformers.utils.hub import get_checkpoint_shard_files
14-
15-
16-
def load_weight_ckpt(ckpt: str) -> Dict[str, torch.Tensor]:
17-
"""Load checkpoint."""
18-
if ckpt.endswith('.safetensors'):
19-
return load_file(ckpt)
20-
else:
21-
return torch.load(ckpt, weights_only=True)
22-
23-
24-
def get_used_weight_files(folder: str, state_dict: Dict[str, torch.Tensor]) -> List[str]:
25-
"""Get used checkpoint which contains keys in state_dict."""
26-
_index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
27-
_safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)
28-
if os.path.exists(_index_file):
29-
index_file = _index_file
30-
elif os.path.exists(_safe_index_file):
31-
index_file = _safe_index_file
32-
elif os.path.isfile(os.path.join(folder, SAFE_WEIGHTS_NAME)): # Single safetensor file
33-
return [SAFE_WEIGHTS_NAME]
34-
elif os.path.isfile(os.path.join(folder, WEIGHTS_NAME)):
35-
return [WEIGHTS_NAME]
36-
else:
37-
raise FileNotFoundError
38-
_, sharded_metadata = get_checkpoint_shard_files(folder, index_file)
39-
potential_keys = set(state_dict.keys())
40-
supplied_keys = set(sharded_metadata['weight_map'].keys())
41-
shared_keys = potential_keys & supplied_keys
42-
valid_files = set(sharded_metadata['weight_map'][k] for k in shared_keys)
43-
return valid_files
44-
45-
46-
def load_model_from_weight_files(model: nn.Module, folder: str) -> None:
47-
"""Load nn.Module weight from folder."""
48-
valid_files = get_used_weight_files(folder, model.state_dict())
49-
for file_name in valid_files:
50-
ckpt = os.path.join(folder, file_name)
51-
state_dict = load_weight_ckpt(ckpt)
52-
model.load_state_dict(state_dict, strict=False)
53-
54-
55-
@contextmanager
56-
def add_sys_path(path: Union[str, os.PathLike]) -> Iterator[None]:
57-
"""Temporarily add the given path to `sys.path`."""
58-
path = os.fspath(path)
59-
try:
60-
sys.path.insert(0, path)
61-
yield
62-
finally:
63-
sys.path.remove(path)
648

659

6610
@contextmanager
@@ -82,27 +26,7 @@ def disable_logging():
8226
logging.disable(previous_level)
8327

8428

85-
@contextmanager
86-
def hack_import_with(src: List[str], dst: str = 'torch'):
87-
"""Replace wanted and uninstalled package with a dummy one.
88-
89-
Args:
90-
src (List): a list of package name
91-
dst (str): dummy package name. Default to 'torch'.
92-
"""
93-
import sys
94-
from importlib.util import find_spec
95-
not_installed = []
96-
for item in src:
97-
if not find_spec(item):
98-
not_installed.append(item)
99-
sys.modules[item] = __import__(dst)
100-
yield
101-
for item in not_installed:
102-
sys.modules.pop(item, None)
103-
104-
105-
def _set_func(origin_func_path: Union[str, None], rewrite_func: Callable, origin_func: Callable = None):
29+
def _set_func(origin_func_path: str | None, rewrite_func: Callable, origin_func: Callable = None):
10630
"""Replace old function with the new function.
10731
10832
Args:
@@ -148,7 +72,7 @@ def _set_func(origin_func_path: Union[str, None], rewrite_func: Callable, origin
14872

14973

15074
@contextmanager
151-
def rewrite_ctx(origin_func_path: List[Union[str, Callable]], rewrite_func: List[Callable]):
75+
def rewrite_ctx(origin_func_path: list[str | Callable], rewrite_func: list[Callable]):
15276
"""Rewrite context."""
15377
assert len(origin_func_path) == len(rewrite_func)
15478
origin_func_list = []

requirements/runtime_cuda.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ tiktoken
2525
tilelang
2626
torch<=2.10.0,>=2.0.0
2727
torchvision<=0.25.0,>=0.15.0
28-
transformers<5.0.0
28+
transformers>=4.52.0
2929
triton<=3.6.0,>=3.0.0; sys_platform == "linux" and "aarch64" not in platform_machine and "arm" not in platform_machine
3030
uvicorn
3131
xgrammar

0 commit comments

Comments
 (0)