Commit df9b428

Authored by CUHKSZzxy, RunningLeon, and lvhan028

Support InternS2 Preview (InternLM#4575)

* support interns2preview
* support time series
* fix time series
* fix visual
* fix: address InternS2 preview review comments
* fix: align InternS1 Pro time-series handling
* fix: restore InternS1 Pro processor dtype contract
* fix: require dtype for Qwen3 VL input processor

Co-authored-by: RunningLeon <mnsheng@yeah.net>
Co-authored-by: 吕晗 <lvhan@pjlab.org.cn>

1 parent 0bf8a07, commit df9b428

14 files changed, with 227 additions and 33 deletions

lmdeploy/archs.py

Lines changed: 7 additions & 2 deletions
@@ -114,15 +114,20 @@ def check_vl_llm(backend: str, config: dict) -> bool:
         'Qwen3_5MoeForConditionalGeneration', 'MllamaForConditionalGeneration', 'MolmoForCausalLM',
         'Gemma3ForConditionalGeneration', 'Llama4ForConditionalGeneration', 'InternVLForConditionalGeneration',
         'InternS1ForConditionalGeneration', 'InternS1ProForConditionalGeneration',
-        'InternS1_1_ForConditionalGeneration', 'Glm4vForConditionalGeneration'
+        'InternS1_1_ForConditionalGeneration', 'Glm4vForConditionalGeneration',
+        'InternS2PreviewForConditionalGeneration', 'InternS2PreviewForCausalLM',
     ])
+    turbomind_unsupported_archs = ['Qwen3_5ForConditionalGeneration',
+                                   'Qwen3_5MoeForConditionalGeneration',
+                                   'InternS2PreviewForConditionalGeneration',
+                                   'InternS2PreviewForCausalLM']
     if arch == 'QWenLMHeadModel' and 'visual' in config:
         return True
     elif arch == 'MultiModalityCausalLM' and 'language_config' in config:
         return True
     elif arch in ['ChatGLMModel', 'ChatGLMForConditionalGeneration'] and 'vision_config' in config:
         return True
-    elif arch in ['Qwen3_5ForConditionalGeneration', 'Qwen3_5MoeForConditionalGeneration'] and backend == 'turbomind':
+    elif arch in turbomind_unsupported_archs and backend == 'turbomind':
         return False
     elif arch in supported_archs:
         return True
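For readers skimming the diff, a minimal standalone sketch of the gating performed by the new turbomind_unsupported_archs list; the helper below is illustrative only, not lmdeploy's actual check_vl_llm body:

# Minimal sketch of the backend gate added above; `is_vl_arch_supported` is a
# hypothetical helper, not part of lmdeploy's public API.
TURBOMIND_UNSUPPORTED_ARCHS = {
    'Qwen3_5ForConditionalGeneration',
    'Qwen3_5MoeForConditionalGeneration',
    'InternS2PreviewForConditionalGeneration',
    'InternS2PreviewForCausalLM',
}


def is_vl_arch_supported(backend: str, arch: str) -> bool:
    """Return False when a VL arch must fall back to the pytorch backend."""
    if arch in TURBOMIND_UNSUPPORTED_ARCHS and backend == 'turbomind':
        return False
    return True


assert is_vl_arch_supported('pytorch', 'InternS2PreviewForCausalLM')
assert not is_vl_arch_supported('turbomind', 'InternS2PreviewForCausalLM')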

lmdeploy/pytorch/config.py

Lines changed: 2 additions & 0 deletions
@@ -567,6 +567,7 @@ def from_config(
         target_model: str = None,
         dtype: str = 'auto',
         trust_remote_code: bool = False,
+        hf_overrides: dict[str, Any] = None,
     ):
         model = model or target_model
         model_config = ModelConfig.from_pretrained(model,
@@ -575,6 +576,7 @@ def from_config(
                                                    is_draft_model=True,
                                                    spec_method=method,
                                                    block_size=target_cache_cfg.block_size,
+                                                   hf_overrides=hf_overrides,
                                                    )
         cache_config = None
         # include medusa

lmdeploy/pytorch/configurations/qwen3_5.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ class Qwen3_5ModelConfigBuilder(AutoModelConfigBuilder):
     @classmethod
     def condition(cls, hf_config):
         """config."""
-        return hf_config.model_type in ['qwen3_5', 'qwen3_5_moe']
+        return hf_config.model_type in ['qwen3_5', 'qwen3_5_moe', 'intern_s2_preview']

     @classmethod
     def build(cls,

lmdeploy/pytorch/engine/config_builder.py

Lines changed: 1 addition & 0 deletions
@@ -115,5 +115,6 @@ def build_specdecode_config(target_model, speculative_config: SpeculativeConfig,
         target_cache_cfg=cache_config,
         dtype=engine_config.dtype,
         trust_remote_code=trust_remote_code,
+        hf_overrides=engine_config.hf_overrides,
     )
     return specdecode_config
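The net effect of the two hf_overrides hunks above is that overrides supplied on the engine config now also reach the speculative (draft) model config. A hedged usage sketch, assuming the pytorch backend config exposes an hf_overrides field (as engine_config does in this hunk); the model path and override values are placeholders:

# Hedged usage sketch: set hf_overrides once on the backend config; with this
# commit it is forwarded to the draft-model config as well. The override keys
# and model path below are illustrative, not a recommendation.
from lmdeploy import PytorchEngineConfig, pipeline

backend_config = PytorchEngineConfig(hf_overrides={'rope_scaling': {'rope_type': 'yarn', 'factor': 2.0}})
pipe = pipeline('path/to/model', backend_config=backend_config)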

lmdeploy/pytorch/messages.py

Lines changed: 5 additions & 0 deletions
@@ -12,6 +12,7 @@
 from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest
 from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs
 from lmdeploy.utils import get_logger
+from lmdeploy.vl.constants import Modality

 from .block import LogicalTokenBlocks

@@ -872,6 +873,10 @@ def _update_mrope_pos_ids(self):
         modal_datas = list(multimodals.values())[0]
         mm_offset = next_pos
         for modal_data in modal_datas:
+            # InternS2Preview uses mrope for image / video, except time series
+            if modal_data.modality == Modality.TIME_SERIES:
+                continue
+
             mm_start = modal_data.start + mm_offset

             # tokens
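A simplified, self-contained illustration of the skip added above: only image/video segments shift the mrope offset, while TIME_SERIES segments are left to plain positions. The Modality and ModalData stand-ins below are toy versions of the real lmdeploy types:

# Toy sketch of the loop above; only the skip logic is reproduced.
from dataclasses import dataclass
from enum import Enum


class Modality(Enum):  # stand-in for lmdeploy.vl.constants.Modality
    IMAGE = 'image'
    TIME_SERIES = 'time_series'


@dataclass
class ModalData:  # simplified stand-in for the real multimodal record
    modality: Modality
    start: int


def mm_starts(modal_datas, mm_offset):
    starts = []
    for modal_data in modal_datas:
        if modal_data.modality == Modality.TIME_SERIES:
            continue  # time series does not take part in mrope position ids
        starts.append(modal_data.start + mm_offset)
    return starts


print(mm_starts([ModalData(Modality.IMAGE, 4), ModalData(Modality.TIME_SERIES, 20)], mm_offset=2))
# -> [6]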

lmdeploy/pytorch/models/interns1_pro.py

Lines changed: 6 additions & 2 deletions
@@ -123,6 +123,8 @@ def forward(
             multimodal_mask = multimodal_mask.unsqueeze(-1).expand_as(inputs_embeds)
             inputs_embeds = inputs_embeds.masked_scatter(multimodal_mask, image_embeds)
         elif ts_values is not None:
+            if not hasattr(self, 'time_series'):
+                raise RuntimeError('Time-series inputs require a time_series module.')
             ts_embeds = self.time_series(ts_values, ts_lens, ts_sr)  # [B, T, C]
             inputs_embeds = inputs_embeds.masked_scatter(multimodal_mask[..., None], ts_embeds)

@@ -182,8 +184,8 @@ def prepare_inputs_for_generation(

             if modality == Modality.TIME_SERIES:
                 ts_values = torch.cat([inp.data for inp in mm_inputs])
-                ts_lens = mm_inputs[0].meta['ts_lens']
-                ts_sr = mm_inputs[0].meta['ts_sr']
+                ts_lens = torch.cat([inp.meta['ts_lens'] for inp in mm_inputs])
+                ts_sr = torch.cat([inp.meta['ts_sr'] for inp in mm_inputs])
             else:
                 pixel_values = torch.cat([inp.data for inp in mm_inputs])
                 grid_thw = torch.stack([data.meta['grid_thw'] for data in mm_inputs]).cpu()
@@ -346,6 +348,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
             elif name in buffers_dict:
                 param = buffers_dict[name]
                 load_weight(param, loaded_weight)
+            else:
+                raise KeyError(f'Unexpected weight name: {name}')

     def get_input_processor(self) -> BaseModelInputProcessor:
         """Get input processor."""

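A toy example of the masked_scatter pattern used in the forward hunk above: time-series embeddings fill exactly the token positions flagged by the multimodal mask. Shapes here are illustrative:

# Toy illustration of the masked_scatter pattern: positions flagged True in
# the mask are overwritten, in order, with the time-series embeddings.
import torch

hidden = 4
inputs_embeds = torch.zeros(1, 6, hidden)                                   # [B, S, C]
multimodal_mask = torch.tensor([[False, True, True, True, False, False]])   # [B, S]
ts_embeds = torch.ones(1, 3, hidden)                                        # [B, T, C], T == mask.sum()

out = inputs_embeds.masked_scatter(multimodal_mask[..., None], ts_embeds)
assert out[0, 1:4].eq(1).all() and out[0, 0].eq(0).all()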
lmdeploy/pytorch/models/interns1_pro_time_series.py

Lines changed: 0 additions & 1 deletion
@@ -21,7 +21,6 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:

         self.embed_dim = config.d_model
         self.num_mel_bins = config.num_mel_bins
-        self.padding_idx = config.pad_token_id
         self.max_source_positions = config.max_source_positions
         self.embed_scale = math.sqrt(self.embed_dim) if config.scale_embedding else 1.0

lmdeploy/pytorch/models/module_map.py

Lines changed: 10 additions & 0 deletions
@@ -186,6 +186,16 @@
     f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_5_moe.Qwen3_5MoeForConditionalGeneration',
 })

+# interns2preview
+MODULE_MAP.update({
+    'InternS2PreviewForConditionalGeneration':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_5_moe.Qwen3_5MoeForConditionalGeneration',
+})
+MODULE_MAP.update({
+    'InternS2PreviewForCausalLM':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_5_moe.Qwen3_5MoeForConditionalGeneration',
+})
+
 MODULE_MAP.update({
     'Qwen3_5MTPModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_5_mtp.Qwen3_5MTPModel',
 })
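Both InternS2 Preview architectures reuse the existing Qwen3_5MoeForConditionalGeneration implementation. Below is an illustrative resolver for such a dotted-path map; lmdeploy's actual model loader differs, and the literal path assumes LMDEPLOY_PYTORCH_MODEL_PATH points at lmdeploy.pytorch.models:

# Illustrative resolver for a dotted-path module map; this only sketches how
# the new entries would be consumed, not lmdeploy's real loading code.
import importlib

MODULE_MAP = {
    'InternS2PreviewForCausalLM': 'lmdeploy.pytorch.models.qwen3_5_moe.Qwen3_5MoeForConditionalGeneration',
}


def resolve(arch: str):
    """Import the module named by the dotted path and return the class."""
    path = MODULE_MAP[arch]
    module_name, _, class_name = path.rpartition('.')
    return getattr(importlib.import_module(module_name), class_name)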

lmdeploy/pytorch/models/qwen3_5.py

Lines changed: 43 additions & 11 deletions
@@ -26,6 +26,7 @@
 )
 from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters
 from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader, load_weight
+from lmdeploy.vl.constants import Modality

 from .patch import add_prefix, get_build_model_context
 from .qwen2_5_vl import Qwen2_5_VisionRotaryEmbedding as Qwen3_5VisionRotaryEmbedding
@@ -1026,6 +1027,10 @@ def forward(
         grid_thw: torch.Tensor | None = None,
         all_routed_experts: torch.Tensor | None = None,
         return_input_embeds: bool = False,
+        # for time series
+        ts_values: torch.Tensor = None,
+        ts_lens: torch.Tensor = None,
+        ts_sr: torch.Tensor = None,
     ):
         """Model forward, return logits."""

@@ -1052,6 +1057,11 @@ def forward(
             # mask and scatter to create final input embeddings
             multimodal_mask = multimodal_mask.unsqueeze(-1).expand_as(inputs_embeds)
             inputs_embeds = inputs_embeds.masked_scatter(multimodal_mask, image_embeds)
+        elif ts_values is not None:
+            if not hasattr(self, 'time_series'):
+                raise RuntimeError('Time-series inputs require a time_series module.')
+            ts_embeds = self.time_series(ts_values, ts_lens, ts_sr)  # [B, T, C]
+            inputs_embeds = inputs_embeds.masked_scatter(multimodal_mask[..., None], ts_embeds)

         output_inputs_embeds = inputs_embeds if return_input_embeds else None

@@ -1098,7 +1108,7 @@ def __init__(self,
         self.ctx_mgr = ctx_mgr

         # build preprocessor
-        self.input_processor = Qwen3_5InputProcessor(self.config)
+        self.input_processor = Qwen3_5InputProcessor(self.config, dtype)

         # build model
         self.model = Qwen3_5Model(config, dtype=dtype, device=device, prefix=add_prefix('model', prefix))
@@ -1129,6 +1139,10 @@ def forward(
         pos_embeds: torch.Tensor | None = None,
         grid_thw: torch.Tensor | None = None,
         return_input_embeds: bool = False,
+        # for time series
+        ts_values: torch.Tensor = None,
+        ts_lens: torch.Tensor = None,
+        ts_sr: torch.Tensor = None,
         **kwargs,
     ):
         """Model forward, return logits."""
@@ -1155,6 +1169,10 @@ def forward(
             grid_thw=grid_thw,
             all_routed_experts=all_routed_experts,
             return_input_embeds=return_input_embeds,
+            # for time series
+            ts_values=ts_values,
+            ts_lens=ts_lens,
+            ts_sr=ts_sr,
         )
         return dict(hidden_states=hidden_states,
                     all_routed_experts=all_routed_experts,
@@ -1194,23 +1212,33 @@ def prepare_inputs_for_generation(
         multimodal_mask = None
         grid_thw = None
         pos_embeds = None
+        # for time series
+        ts_values = None
+        ts_lens = None
+        ts_sr = None
         if context.input_multimodals is not None:
             mm_inputs = [input_mm.get('mm_data', []) for input_mm in context.input_multimodals]
             # flatten batch
             mm_inputs = [item for sublist in mm_inputs for item in sublist]

             if len(mm_inputs) > 0:
-                pixel_values = torch.cat([inp.data for inp in mm_inputs])
-
+                modality = mm_inputs[0].modality
                 multimodal_mask = self.get_multimodal_mask(input_ids, mm_inputs)
-                grid_thw = torch.stack([data.meta['grid_thw'] for data in mm_inputs]).cpu()
-                vis_pos_emb = self.model.visual.rot_pos_emb(grid_thw)
-                pos_embeds = self.model.visual.fast_pos_embed_interpolate(grid_thw)
-                vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
-                                                         grid_thw[:, 0]).to(pixel_values.device)
-                vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32)
-                vis_pos_emb = vis_pos_emb.repeat(1, 2)
-                vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin())
+
+                if modality == Modality.TIME_SERIES:
+                    ts_values = torch.cat([inp.data for inp in mm_inputs])
+                    ts_lens = torch.cat([inp.meta['ts_lens'] for inp in mm_inputs])
+                    ts_sr = torch.cat([inp.meta['ts_sr'] for inp in mm_inputs])
+                else:
+                    pixel_values = torch.cat([inp.data for inp in mm_inputs])
+                    grid_thw = torch.stack([data.meta['grid_thw'] for data in mm_inputs]).cpu()
+                    vis_pos_emb = self.model.visual.rot_pos_emb(grid_thw)
+                    pos_embeds = self.model.visual.fast_pos_embed_interpolate(grid_thw)
+                    vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
+                                                             grid_thw[:, 0]).to(pixel_values.device)
+                    vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32)
+                    vis_pos_emb = vis_pos_emb.repeat(1, 2)
+                    vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin())

         mrope_position_ids = getattr(context, 'mrope_position_ids', None)

@@ -1242,6 +1270,10 @@ def prepare_inputs_for_generation(
             grid_thw=grid_thw,
             pos_embeds=pos_embeds,
             return_input_embeds=return_input_embeds,
+            # for time series
+            ts_values=ts_values,
+            ts_lens=ts_lens,
+            ts_sr=ts_sr,
         )

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
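The prepare_inputs_for_generation hunk also fixes time-series batching: ts_lens and ts_sr are now concatenated across all multimodal inputs instead of being read from the first one only. A toy check of the difference, with made-up values:

# Toy check of the batching fix: metadata must be concatenated across all
# time-series inputs, not taken from the first input only.
import torch

metas = [{'ts_lens': torch.tensor([128]), 'ts_sr': torch.tensor([50])},
         {'ts_lens': torch.tensor([256]), 'ts_sr': torch.tensor([100])}]

ts_lens = torch.cat([m['ts_lens'] for m in metas])   # tensor([128, 256])
ts_sr = torch.cat([m['ts_sr'] for m in metas])       # tensor([ 50, 100])

# the old behaviour (metas[0]['ts_lens']) would drop the second series' metadata
assert ts_lens.tolist() == [128, 256] and ts_sr.tolist() == [50, 100]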

lmdeploy/pytorch/models/qwen3_5_moe.py

Lines changed: 21 additions & 8 deletions
@@ -13,6 +13,7 @@
 from lmdeploy.pytorch.nn.moe import build_fused_moe
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

+from .interns1_pro_time_series import InternS1ProTimeSeriesModel
 from .patch import add_prefix, get_build_model_context
 from .qwen3_5 import (
     Qwen3_5Attention,
@@ -232,6 +233,9 @@ def __init__(self,
                                       device=device,
                                       prefix=add_prefix('language_model', prefix))

+        # build time series model
+        if hasattr(config, 'ts_config'):
+            self.time_series = InternS1ProTimeSeriesModel(config.ts_config, dtype=dtype, device=device)

 class Qwen3_5MoeForConditionalGeneration(Qwen3_5ForConditionalGeneration):
     """ModelForCausalLM."""
@@ -259,7 +263,7 @@ def __init__(self,
         self.ctx_mgr = ctx_mgr

         # build preprocessor
-        self.input_processor = Qwen3_5MoeInputProcessor(self.config)
+        self.input_processor = Qwen3_5MoeInputProcessor(self.config, dtype)

         # build model
         self.model = Qwen3_5MoeModel(config, dtype=dtype, device=device, prefix=add_prefix('model', prefix))
@@ -351,6 +355,7 @@ def __skip_layers(name):
         rms_norm_keys = ['model.norm', '.input_layernorm', '.post_attention_layernorm', '.q_norm', '.k_norm']

         params_dict = dict(self.named_parameters())
+        buffers_dict = dict(self.named_buffers())
         for name, loaded_weight in weights:

             if __skip_layers(name):
@@ -369,7 +374,9 @@ def __skip_layers(name):
                 self._load_weight_experts(name, loaded_weight, params_dict)
             else:
                 for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                    if weight_name not in name:
+                    # include dot to avoid partial match
+                    # e.g. in_proj_ba (in linear attn) vs in_proj_bias (in time series)
+                    if f'{weight_name}.' not in name:
                         continue
                     name = name.replace(weight_name, param_name)
                     param = params_dict[name]
@@ -384,9 +391,15 @@ def __skip_layers(name):
                     load_weight(param, k, shard_id='k')
                     load_weight(param, v, shard_id='v')
                 else:
-                    for rms_norm_key in rms_norm_keys:
-                        if rms_norm_key in name and 'weight' in name:
-                            loaded_weight = loaded_weight + 1
-                            break
-                    param = params_dict[name]
-                    load_weight(param, loaded_weight)
+                    if name in params_dict:
+                        for rms_norm_key in rms_norm_keys:
+                            if rms_norm_key in name and 'weight' in name:
+                                loaded_weight = loaded_weight + 1
+                                break
+                        param = params_dict[name]
+                        load_weight(param, loaded_weight)
+                    elif name in buffers_dict:
+                        param = buffers_dict[name]
+                        load_weight(param, loaded_weight)
+                    else:
+                        raise KeyError(f'Unexpected weight name: {name}')
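The trailing dot in the stacked-params match makes a mapping key match only a complete attribute segment, so a key cannot accidentally capture a longer, unrelated weight name. A generic illustration with hypothetical names (not taken from the model):

# Why the trailing dot helps: a bare substring test on a short key can match
# a longer attribute name; requiring 'key.' matches only the full segment.
# All names below are hypothetical.
weight_name = 'in_proj'

full_match = 'layers.0.attn.in_proj.weight'
longer_name = 'layers.0.attn.in_proj_bias'

assert weight_name in longer_name            # bare substring test: false positive
assert f'{weight_name}.' in full_match       # dot-suffixed test: real match
assert f'{weight_name}.' not in longer_name  # dot-suffixed test: no false positive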
