@@ -130,6 +130,20 @@ def paddingtensor3D_CBN(tensor_list):
130130 return torch .cat (out_tensor_list , dim = 1 )
131131
132132
133+ def paddingtensor3D_BCN (tensor_list ):
134+ if all (tensor is None for tensor in tensor_list ):
135+ return None
136+ N = max (tensor .shape [- 1 ] for tensor in tensor_list if tensor is not None )
137+ out_tensor_list = []
138+ for tensor in tensor_list :
139+ b , c , n = tensor .shape
140+ outtensor = torch .zeros (b , c , N , dtype = tensor_list [0 ].dtype )
141+ if tensor is not None :
142+ outtensor [:, :, :n ] = tensor
143+ out_tensor_list .append (outtensor )
144+ return torch .cat (out_tensor_list , dim = 0 )
145+
146+
133147def paddingtensor3D_BHW (tensor_list ):
134148 if all (tensor is None for tensor in tensor_list ):
135149 return None
@@ -240,11 +254,90 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
240254 batch ["target_hiddens" ] = torch .cat (
241255 [paddingtensor (item ["target_hiddens" ], max_length ) for item in features ]
242256 )
257+ if all (
258+ "inputs_embeds" in item and item ["inputs_embeds" ] is not None
259+ for item in features
260+ ):
243261 batch ["inputs_embeds" ] = torch .cat (
244262 [paddingtensor (item ["inputs_embeds" ], max_length ) for item in features ]
245263 )
264+ if all (
265+ "position_ids" in item and item ["position_ids" ] is not None
266+ for item in features
267+ ):
246268 batch ["position_ids" ] = paddingtensor3D_CBN (
247269 [item ["position_ids" ] for item in features ]
248270 )
249271
250272 return batch
273+
274+
275+ class VLMHunyuanDataCollatorWithPadding :
276+
277+ def __call__ (self , features : List [Dict [str , Any ]]) -> Dict [str , Any ]:
278+ max_length = max (item ["input_ids" ].shape [1 ] for item in features )
279+ batch_input_ids = torch .cat (
280+ [paddingtensor2D (item ["input_ids" ], max_length ) for item in features ]
281+ )
282+ batch_attention_mask = torch .cat (
283+ [paddingtensor2D (item ["attention_mask" ], max_length ) for item in features ]
284+ )
285+ batch_loss_mask = torch .cat (
286+ [paddingtensor2D (item ["loss_mask" ], max_length ) for item in features ]
287+ )
288+ batch = {
289+ "input_ids" : batch_input_ids ,
290+ "attention_mask" : batch_attention_mask ,
291+ "loss_mask" : batch_loss_mask ,
292+ "hidden_states" : None ,
293+ "target_hiddens" : None ,
294+ "inputs_embeds" : None ,
295+ "position_ids" : None ,
296+ "input_position_ids" : None ,
297+ }
298+
299+ if "pixel_values" in features [0 ]:
300+ batch ["pixel_values" ] = paddingtensor3D_BHW (
301+ [item ["pixel_values" ] for item in features ]
302+ )
303+
304+ if all (
305+ "image_grid_thw" in item and item ["image_grid_thw" ] is not None
306+ for item in features
307+ ):
308+ batch ["image_grid_thw" ] = torch .cat (
309+ [item ["image_grid_thw" ] for item in features ], dim = 0
310+ )
311+
312+ # Check if both hidden_states and target_hiddens exist in all features
313+ if all (
314+ "hidden_states" in item and "target_hiddens" in item for item in features
315+ ):
316+ batch ["hidden_states" ] = torch .cat (
317+ [paddingtensor (item ["hidden_states" ], max_length ) for item in features ]
318+ )
319+ batch ["target_hiddens" ] = torch .cat (
320+ [paddingtensor (item ["target_hiddens" ], max_length ) for item in features ]
321+ )
322+ if all (
323+ "inputs_embeds" in item and item ["inputs_embeds" ] is not None
324+ for item in features
325+ ):
326+ batch ["inputs_embeds" ] = torch .cat (
327+ [paddingtensor (item ["inputs_embeds" ], max_length ) for item in features ]
328+ )
329+ if all (
330+ "input_position_ids" in item and item ["input_position_ids" ] is not None
331+ for item in features
332+ ):
333+ batch ["input_position_ids" ] = paddingtensor3D_BCN (
334+ [item ["input_position_ids" ] for item in features ]
335+ )
336+ if all (
337+ "position_ids" in item and item ["position_ids" ] is not None
338+ for item in features
339+ ):
340+ batch ["position_ids" ] = torch .cat (
341+ [paddingtensor2D (item ["position_ids" ], max_length ) for item in features ]
342+ )
343+ return batch
0 commit comments