2626)
2727from lmdeploy .pytorch .nn .rotary_embedding import get_rope_parameters
2828from lmdeploy .pytorch .weight_loader .model_weight_loader import default_weight_loader , load_weight
29+ from lmdeploy .vl .constants import Modality
2930
3031from .patch import add_prefix , get_build_model_context
3132from .qwen2_5_vl import Qwen2_5_VisionRotaryEmbedding as Qwen3_5VisionRotaryEmbedding
@@ -1026,6 +1027,10 @@ def forward(
10261027 grid_thw : torch .Tensor | None = None ,
10271028 all_routed_experts : torch .Tensor | None = None ,
10281029 return_input_embeds : bool = False ,
1030+ # for time series
1031+ ts_values : torch .Tensor | None = None ,
1032+ ts_lens : torch .Tensor | None = None ,
1033+ ts_sr : torch .Tensor | None = None ,
10291034 ):
10301035 """Model forward, return logits."""
10311036
@@ -1052,6 +1057,11 @@ def forward(
10521057 # mask and scatter to create final input embeddings
10531058 multimodal_mask = multimodal_mask .unsqueeze (- 1 ).expand_as (inputs_embeds )
10541059 inputs_embeds = inputs_embeds .masked_scatter (multimodal_mask , image_embeds )
1060+ elif ts_values is not None :
1061+ if not hasattr (self , 'time_series' ):
1062+ raise RuntimeError ('Time-series inputs require a time_series module.' )
1063+ ts_embeds = self .time_series (ts_values , ts_lens , ts_sr ) # [B, T, C]
1064+ inputs_embeds = inputs_embeds .masked_scatter (multimodal_mask [..., None ], ts_embeds )
10551065
10561066 output_inputs_embeds = inputs_embeds if return_input_embeds else None
10571067
@@ -1098,7 +1108,7 @@ def __init__(self,
10981108 self .ctx_mgr = ctx_mgr
10991109
11001110 # build preprocessor
1101- self .input_processor = Qwen3_5InputProcessor (self .config )
1111+ self .input_processor = Qwen3_5InputProcessor (self .config , dtype )
11021112
11031113 # build model
11041114 self .model = Qwen3_5Model (config , dtype = dtype , device = device , prefix = add_prefix ('model' , prefix ))
@@ -1129,6 +1139,10 @@ def forward(
11291139 pos_embeds : torch .Tensor | None = None ,
11301140 grid_thw : torch .Tensor | None = None ,
11311141 return_input_embeds : bool = False ,
1142+ # for time series
1143+ ts_values : torch .Tensor | None = None ,
1144+ ts_lens : torch .Tensor | None = None ,
1145+ ts_sr : torch .Tensor | None = None ,
11321146 ** kwargs ,
11331147 ):
11341148 """Model forward, return logits."""
@@ -1155,6 +1169,10 @@ def forward(
11551169 grid_thw = grid_thw ,
11561170 all_routed_experts = all_routed_experts ,
11571171 return_input_embeds = return_input_embeds ,
1172+ # for time series
1173+ ts_values = ts_values ,
1174+ ts_lens = ts_lens ,
1175+ ts_sr = ts_sr ,
11581176 )
11591177 return dict (hidden_states = hidden_states ,
11601178 all_routed_experts = all_routed_experts ,
@@ -1194,23 +1212,33 @@ def prepare_inputs_for_generation(
11941212 multimodal_mask = None
11951213 grid_thw = None
11961214 pos_embeds = None
1215+ # for time series
1216+ ts_values = None
1217+ ts_lens = None
1218+ ts_sr = None
11971219 if context .input_multimodals is not None :
11981220 mm_inputs = [input_mm .get ('mm_data' , []) for input_mm in context .input_multimodals ]
11991221 # flatten batch
12001222 mm_inputs = [item for sublist in mm_inputs for item in sublist ]
12011223
12021224 if len (mm_inputs ) > 0 :
1203- pixel_values = torch .cat ([inp .data for inp in mm_inputs ])
1204-
1225+ modality = mm_inputs [0 ].modality
12051226 multimodal_mask = self .get_multimodal_mask (input_ids , mm_inputs )
1206- grid_thw = torch .stack ([data .meta ['grid_thw' ] for data in mm_inputs ]).cpu ()
1207- vis_pos_emb = self .model .visual .rot_pos_emb (grid_thw )
1208- pos_embeds = self .model .visual .fast_pos_embed_interpolate (grid_thw )
1209- vis_cu_seqlens = torch .repeat_interleave (grid_thw [:, 1 ] * grid_thw [:, 2 ],
1210- grid_thw [:, 0 ]).to (pixel_values .device )
1211- vis_cu_seqlens = vis_cu_seqlens .cumsum (dim = 0 , dtype = torch .int32 )
1212- vis_pos_emb = vis_pos_emb .repeat (1 , 2 )
1213- vis_pos_emb = (vis_pos_emb .cos (), vis_pos_emb .sin ())
1227+
1228+ if modality == Modality .TIME_SERIES :
1229+ ts_values = torch .cat ([inp .data for inp in mm_inputs ])
1230+ ts_lens = torch .cat ([inp .meta ['ts_lens' ] for inp in mm_inputs ])
1231+ ts_sr = torch .cat ([inp .meta ['ts_sr' ] for inp in mm_inputs ])
1232+ else :
1233+ pixel_values = torch .cat ([inp .data for inp in mm_inputs ])
1234+ grid_thw = torch .stack ([data .meta ['grid_thw' ] for data in mm_inputs ]).cpu ()
1235+ vis_pos_emb = self .model .visual .rot_pos_emb (grid_thw )
1236+ pos_embeds = self .model .visual .fast_pos_embed_interpolate (grid_thw )
1237+ vis_cu_seqlens = torch .repeat_interleave (grid_thw [:, 1 ] * grid_thw [:, 2 ],
1238+ grid_thw [:, 0 ]).to (pixel_values .device )
1239+ vis_cu_seqlens = vis_cu_seqlens .cumsum (dim = 0 , dtype = torch .int32 )
1240+ vis_pos_emb = vis_pos_emb .repeat (1 , 2 )
1241+ vis_pos_emb = (vis_pos_emb .cos (), vis_pos_emb .sin ())
12141242
12151243 mrope_position_ids = getattr (context , 'mrope_position_ids' , None )
12161244
@@ -1242,6 +1270,10 @@ def prepare_inputs_for_generation(
12421270 grid_thw = grid_thw ,
12431271 pos_embeds = pos_embeds ,
12441272 return_input_embeds = return_input_embeds ,
1273+ # for time series
1274+ ts_values = ts_values ,
1275+ ts_lens = ts_lens ,
1276+ ts_sr = ts_sr ,
12451277 )
12461278
12471279 def load_weights (self , weights : Iterable [tuple [str , torch .Tensor ]]):
0 commit comments