@@ -175,9 +175,10 @@ def collate_function(self, batch):
175175class Qwen3OmniImageProcessor (BaseImageProcessor ):
176176 """Image processor for Qwen3-Omni multimodal model."""
177177
178- def __init__ (self , tokenizer , device = "auto" , use_audio_in_video = False ):
178+ def __init__ (self , tokenizer , device = "auto" , dtype = None , use_audio_in_video = False ):
179179 """Constructor."""
180180 super ().__init__ (tokenizer , device )
181+ self .dtype = dtype
181182 self .use_audio_in_video = use_audio_in_video
182183 # Try to import qwen_omni_utils for multimodal processing
183184 try :
@@ -251,7 +252,8 @@ def collate_function(self, batch):
251252 """Collate function to process inputs during data loading."""
252253 result = {}
253254
254- # Take first item from batch (batch_size handling)
255+ # Take first item only — multimodal inputs have variable-length sequences
256+ # (images, audio) that cannot be stacked, so batch_size=1 is expected.
255257 first = batch [0 ]
256258
257259 # Convert lists to tensors and move to device
@@ -262,7 +264,10 @@ def collate_function(self, batch):
262264
263265 # Handle pixel values for images
264266 if first .get ("pixel_values" ) is not None :
265- result ["pixel_values" ] = torch .tensor (first ["pixel_values" ]).to (self .device )
267+ pv = torch .tensor (first ["pixel_values" ])
268+ if self .dtype is not None :
269+ pv = pv .to (self .dtype )
270+ result ["pixel_values" ] = pv .to (self .device )
266271
267272 # Handle image grid thw (temporal, height, width grid sizes per image)
268273 if first .get ("image_grid_thw" ) is not None :
@@ -274,7 +279,10 @@ def collate_function(self, batch):
274279 self .device
275280 )
276281 if first .get ("audio_features" ) is not None :
277- result ["audio_features" ] = torch .tensor (first ["audio_features" ]).to (self .device )
282+ af = torch .tensor (first ["audio_features" ])
283+ if self .dtype is not None :
284+ af = af .to (self .dtype )
285+ result ["audio_features" ] = af .to (self .device )
278286
279287 # Handle video features if present
280288 if first .get ("video_grid_thw" ) is not None :
0 commit comments