@@ -115,6 +115,75 @@ def process_highres_image(image, processor, grid_pinpoints):
115115 image_patches = [processor .preprocess (image_patch , return_tensors = "pt" )["pixel_values" ][0 ] for image_patch in image_patches ]
116116 return torch .stack (image_patches , dim = 0 )
117117
def smart_resize(
    height: int,
    width: int,
    patch_size: int = 16,
    min_pixels: int = 32 * 32,
    max_pixels=None,
):
    """Rescale image dimensions for a patch-based vision encoder.

    Both returned dimensions are divisible by ``patch_size``, the total
    pixel count is at least ``min_pixels`` (and at most ``max_pixels``
    when one is given), and the aspect ratio is preserved as closely as
    possible.  Similar to Qwen2VL's smart_resize, adapted for Siglip2's
    requirements.

    Args:
        height: Original image height.
        width: Original image width.
        patch_size: Factor both output dimensions must be divisible by
            (default: 16).
        min_pixels: Minimum number of output pixels (default: 1024 = 32*32).
        max_pixels: Optional maximum number of output pixels; ``None``
            (the default) applies no upper bound.

    Returns:
        Tuple of (resized_height, resized_width).

    Raises:
        ValueError: If either dimension is non-positive, or if the
            aspect ratio exceeds 200.
    """
    if height <= 0 or width <= 0:
        raise ValueError(f"Image dimensions must be positive, got {height}x{width}")
    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"Absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )

    # Round each dimension to the nearest multiple of patch_size,
    # but never below a single patch.
    h_bar = max(patch_size, round(height / patch_size) * patch_size)
    w_bar = max(patch_size, round(width / patch_size) * patch_size)

    if h_bar * w_bar < min_pixels:
        # Upscale uniformly so the area reaches min_pixels; ceil keeps
        # the result at or above the bound after re-quantizing.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / patch_size) * patch_size
        w_bar = math.ceil(width * beta / patch_size) * patch_size
    elif max_pixels is not None and h_bar * w_bar > max_pixels:
        # Downscale uniformly so the area fits max_pixels; floor keeps
        # the result at or below the bound after re-quantizing.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(patch_size, math.floor(height / beta / patch_size) * patch_size)
        w_bar = max(patch_size, math.floor(width / beta / patch_size) * patch_size)

    return h_bar, w_bar
161+
162+
def process_native_image(image, processor):
    """Preprocess a single image at (approximately) its native resolution.

    The image is resized via ``smart_resize`` so both sides become
    multiples of the encoder patch size — 16 when the processor's class
    name contains "siglip", 14 otherwise — and is then preprocessed
    without any further resizing.

    Args:
        image: PIL image to process.
        processor: HuggingFace image processor; a Siglip processor is
            detected from its class name.

    Returns:
        Dict with 'pixel_values' (preprocessed tensor) and 'grid_thw'
        ([t, h, w] patch-grid sizes; t is always 1 for a single image).
    """
    orig_width, orig_height = image.size
    is_siglip = 'siglip' in processor.__class__.__name__.lower()
    patch_size = 16 if is_siglip else 14
    target_height, target_width = smart_resize(
        height=orig_height,
        width=orig_width,
        patch_size=patch_size,
        min_pixels=patch_size * 4,
    )
    image = image.resize((target_width, target_height), Image.BICUBIC)
    # Siglip processors have no center-crop step, so the flag is only
    # passed to non-Siglip (e.g. OneVision/CLIP-style) processors.
    extra_kwargs = {} if is_siglip else {'do_center_crop': False}
    image_patches = [
        processor.preprocess(image, return_tensors="pt", do_resize=False, **extra_kwargs)["pixel_values"]
    ]
    grid_thw = [1, target_height // patch_size, target_width // patch_size]
    return {'pixel_values': torch.cat(image_patches, dim=0), 'grid_thw': grid_thw}
118187
119188def select_best_resolution (original_size , possible_resolutions ):
120189 """
@@ -274,6 +343,14 @@ def process_anyres_image(image, processor, grid_pinpoints):
274343 possible_resolutions = ast .literal_eval (grid_pinpoints )
275344 best_resolution = select_best_resolution (image .size , possible_resolutions )
276345 image_padded = resize_and_pad_image (image , best_resolution )
346+ if 'siglip' in processor .__class__ .__name__ .lower ():
347+ image_patches = [processor .preprocess (image_padded , return_tensors = "pt" , do_resize = False )["pixel_values" ]]
348+ grid_thw = [1 , best_resolution [1 ] // 16 , best_resolution [0 ] // 16 ]
349+ return {'pixel_values' : torch .cat (image_patches , dim = 0 ), 'grid_thw' : grid_thw }
350+ else : # FIXME: for onevision encoder
351+ image_patches = [processor .preprocess (image_padded , return_tensors = "pt" , do_resize = False , do_center_crop = False )["pixel_values" ]]
352+ grid_thw = [1 , best_resolution [1 ] // 14 , best_resolution [0 ] // 14 ]
353+ return {'pixel_values' : torch .cat (image_patches , dim = 0 ), 'grid_thw' : grid_thw }
277354
278355 patches = divide_to_patches (image_padded , processor .crop_size ["height" ])
279356
@@ -314,23 +391,51 @@ def expand2square(pil_img, background_color):
314391def process_images (images , image_processor , model_cfg ):
315392 image_aspect_ratio = getattr (model_cfg , "image_aspect_ratio" , None )
316393 new_images = []
394+ if len (images ) == 8 : #FIXME hardcoded for 8 images input as video sample
395+ image_aspect_ratio = 'pad'
396+
317397 if image_aspect_ratio == "highres" :
318398 for image in images :
319399 image = process_highres_image (image , image_processor , model_cfg .image_grid_pinpoints )
320400 new_images .append (image )
401+ elif image_aspect_ratio == "native" :
402+ for image in images :
403+ image = process_native_image (image , image_processor )
404+ new_images .append (image )
321405 elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio :
322406 for image in images :
323407 image = process_anyres_image (image , image_processor , model_cfg .image_grid_pinpoints )
324408 new_images .append (image )
409+ return {'image_patchs' : [img ['pixel_values' ] for img in new_images ], 'grid_thw' : [img ['grid_thw' ] for img in new_images ]}
325410 elif image_aspect_ratio == "crop_split" :
326411 for image in images :
327412 image = process_highres_image_crop_split (image , model_cfg , image_processor )
328413 new_images .append (image )
329414 elif image_aspect_ratio == "pad" :
330- for image in images :
331- image = expand2square (image , tuple (int (x * 255 ) for x in image_processor .image_mean ))
332- image = image_processor .preprocess (image , return_tensors = "pt" )["pixel_values" ][0 ]
333- new_images .append (image )
415+ if 'siglip' in image_processor .__class__ .__name__ .lower ():
416+ image_patchs = []
417+ grid_thw = []
418+ for image in images :
419+ image = expand2square (image , tuple (int (0 * 255 ) for x in [0 ,0 ,0 ]))
420+ image = image .resize ((512 , 512 ))
421+ image_patchs .append (image_processor .preprocess (image , return_tensors = "pt" , do_resize = False )["pixel_values" ])
422+ grid_thw .append ([1 , 32 , 32 ])
423+ return {'image_patchs' : image_patchs , 'grid_thw' : torch .tensor (grid_thw )}
424+
425+ else : # FIXME: for onevision encoder video
426+ image_patchs = []
427+ grid_thw = []
428+ for image in images :
429+ image = expand2square (image , tuple (int (0 * 255 ) for x in [0 ,0 ,0 ]))
430+ image = image .resize ((504 , 504 ))
431+ image_patchs .append (image_processor .preprocess (image , return_tensors = "pt" , do_resize = False )["pixel_values" ])
432+ grid_thw .append ([1 , 36 , 36 ])
433+ return {'image_patchs' : image_patchs , 'grid_thw' : torch .tensor (grid_thw )}
434+
435+ image = image .resize ((504 , 504 ))
436+ # image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
437+ image = image_processor .preprocess (image , return_tensors = "pt" , do_resize = False )["pixel_values" ]
438+ new_images .append (image )
334439 else :
335440 return image_processor .preprocess (images , return_tensors = "pt" )["pixel_values" ]
336441 if all (x .shape == new_images [0 ].shape for x in new_images ):
0 commit comments