1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import json
1516from typing import Any , Dict , List
1617
1718import torch
19+ from transformers .image_utils import load_image
1820
1921__all__ = [
2022 "process_token_dict_to_mappings" ,
@@ -195,6 +197,15 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
195197
196198class VLMDataCollatorWithPadding :
197199
200+ def __init__ (self , processor = None ):
201+ """
202+ Args:
203+ processor: Optional VLM processor (e.g. AutoProcessor for qwen3_vl); defaults to None.
204+ When set and the first feature has an "image_paths" key, __call__ JSON-decodes it, loads the
205+ images and runs processor.image_processor on them to get pixel_values (online training).
206+ """
207+ self .processor = processor
208+
198209 def __call__ (self , features : List [Dict [str , Any ]]) -> Dict [str , Any ]:
199210 max_length = max (item ["input_ids" ].shape [1 ] for item in features )
200211 batch_input_ids = torch .cat (
@@ -217,27 +228,53 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
217228 "position_ids" : None ,
218229 }
219230
220- if "pixel_values" in features [0 ]:
221- batch ["pixel_values" ] = paddingtensor3D_BHW (
222- [item ["pixel_values" ] for item in features ]
223- )
224- if "video_pixel_values" in features [0 ]:
225- batch ["video_pixel_values" ] = paddingtensor3D_BHW (
226- [item ["video_pixel_values" ] for item in features ]
227- )
228-
229- if all (
230- "image_grid_thw" in item and item ["image_grid_thw" ] is not None for item in features
231- ):
232- batch ["image_grid_thw" ] = torch .cat (
233- [item ["image_grid_thw" ] for item in features ], dim = 0
234- )
235- if all (
236- "video_grid_thw" in item and item ["video_grid_thw" ] is not None for item in features
237- ):
238- batch ["video_grid_thw" ] = torch .cat (
239- [item ["video_grid_thw" ] for item in features ], dim = 0
240- )
231+ # Online training: decode image_paths -> pixel_values on-the-fly
232+ if self .processor is not None and "image_paths" in features [0 ]:
233+ all_pixel_values , all_image_grid_thw = [], []
234+ all_video_pixel_values , all_video_grid_thw = [], []
235+ for item in features :
236+ image_paths = json .loads (item ["image_paths" ])
237+ if image_paths :
238+ images = [load_image (p ) for p in image_paths ]
239+ vision_enc = self .processor .image_processor (images = images , return_tensors = "pt" )
240+ all_pixel_values .append (vision_enc ["pixel_values" ])
241+ if "image_grid_thw" in vision_enc :
242+ all_image_grid_thw .append (vision_enc ["image_grid_thw" ])
243+ if "video_pixel_values" in vision_enc :
244+ all_video_pixel_values .append (vision_enc ["video_pixel_values" ])
245+ if "video_grid_thw" in vision_enc :
246+ all_video_grid_thw .append (vision_enc ["video_grid_thw" ])
247+ if all_pixel_values :
248+ batch ["pixel_values" ] = paddingtensor3D_BHW (all_pixel_values )
249+ if all_image_grid_thw :
250+ batch ["image_grid_thw" ] = torch .cat (all_image_grid_thw , dim = 0 )
251+ if all_video_pixel_values :
252+ batch ["video_pixel_values" ] = paddingtensor3D_BHW (all_video_pixel_values )
253+ if all_video_grid_thw :
254+ batch ["video_grid_thw" ] = torch .cat (all_video_grid_thw , dim = 0 )
255+ else :
256+ if "pixel_values" in features [0 ]:
257+ batch ["pixel_values" ] = paddingtensor3D_BHW (
258+ [item ["pixel_values" ] for item in features ]
259+ )
260+ if "video_pixel_values" in features [0 ]:
261+ batch ["video_pixel_values" ] = paddingtensor3D_BHW (
262+ [item ["video_pixel_values" ] for item in features ]
263+ )
264+ if all (
265+ "image_grid_thw" in item and item ["image_grid_thw" ] is not None
266+ for item in features
267+ ):
268+ batch ["image_grid_thw" ] = torch .cat (
269+ [item ["image_grid_thw" ] for item in features ], dim = 0
270+ )
271+ if all (
272+ "video_grid_thw" in item and item ["video_grid_thw" ] is not None
273+ for item in features
274+ ):
275+ batch ["video_grid_thw" ] = torch .cat (
276+ [item ["video_grid_thw" ] for item in features ], dim = 0
277+ )
241278
242279 # Check if both hidden_states and target_hiddens exist in all features
243280 if all ("hidden_states" in item and "target_hiddens" in item for item in features ):
@@ -261,6 +298,15 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
261298
262299class VLMHunyuanDataCollatorWithPadding :
263300
301+ def __init__ (self , processor = None ):
302+ """
303+ Args:
304+ processor: Optional VLM processor (e.g. AutoProcessor for hunyuan_vl); defaults to None.
305+ When set and the first feature has an "image_paths" key, __call__ JSON-decodes it, loads the
306+ images and calls the processor itself on them to get pixel_values (online training).
307+ """
308+ self .processor = processor
309+
264310 def __call__ (self , features : List [Dict [str , Any ]]) -> Dict [str , Any ]:
265311 max_length = max (item ["input_ids" ].shape [1 ] for item in features )
266312 batch_input_ids = torch .cat (
@@ -283,17 +329,33 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
283329 "input_position_ids" : None ,
284330 }
285331
286- if "pixel_values" in features [0 ]:
287- batch ["pixel_values" ] = paddingtensor3D_BHW (
288- [item ["pixel_values" ] for item in features ]
289- )
290-
291- if all (
292- "image_grid_thw" in item and item ["image_grid_thw" ] is not None for item in features
293- ):
294- batch ["image_grid_thw" ] = torch .cat (
295- [item ["image_grid_thw" ] for item in features ], dim = 0
296- )
332+ # Online training: decode image_paths -> pixel_values on-the-fly
333+ if self .processor is not None and "image_paths" in features [0 ]:
334+ all_pixel_values , all_image_grid_thw = [], []
335+ for item in features :
336+ image_paths = json .loads (item ["image_paths" ])
337+ if image_paths :
338+ images = [load_image (p ) for p in image_paths ]
339+ vision_enc = self .processor (images = images , return_tensors = "pt" )
340+ all_pixel_values .append (vision_enc ["pixel_values" ])
341+ if "image_grid_thw" in vision_enc :
342+ all_image_grid_thw .append (vision_enc ["image_grid_thw" ])
343+ if all_pixel_values :
344+ batch ["pixel_values" ] = paddingtensor3D_BHW (all_pixel_values )
345+ if all_image_grid_thw :
346+ batch ["image_grid_thw" ] = torch .cat (all_image_grid_thw , dim = 0 )
347+ else :
348+ if "pixel_values" in features [0 ]:
349+ batch ["pixel_values" ] = paddingtensor3D_BHW (
350+ [item ["pixel_values" ] for item in features ]
351+ )
352+ if all (
353+ "image_grid_thw" in item and item ["image_grid_thw" ] is not None
354+ for item in features
355+ ):
356+ batch ["image_grid_thw" ] = torch .cat (
357+ [item ["image_grid_thw" ] for item in features ], dim = 0
358+ )
297359
298360 # Check if both hidden_states and target_hiddens exist in all features
299361 if all ("hidden_states" in item and "target_hiddens" in item for item in features ):
0 commit comments