@@ -115,6 +115,75 @@ def process_highres_image(image, processor, grid_pinpoints):
115115 image_patches = [processor .preprocess (image_patch , return_tensors = "pt" )["pixel_values" ][0 ] for image_patch in image_patches ]
116116 return torch .stack (image_patches , dim = 0 )
117117
def smart_resize(
    height: int,
    width: int,
    patch_size: int = 16,
    min_pixels: int = 32 * 32,
):
    """Compute a patch-aligned target size for an image.

    The returned dimensions satisfy:

    1. Both height and width are positive multiples of ``patch_size``.
    2. Their product is at least ``min_pixels``.
    3. The aspect ratio stays as close to the original as the alignment
       permits.

    This follows the idea of Qwen2VL's ``smart_resize`` but, unlike it, does
    not take a ``max_pixels`` argument: no upper bound is applied to the
    result.

    Args:
        height: Original image height in pixels; must be positive.
        width: Original image width in pixels; must be positive.
        patch_size: Value both returned dimensions must be divisible by
            (default: 16); must be positive.
        min_pixels: Minimum number of pixels in the result
            (default: 1024 = 32 * 32).

    Returns:
        Tuple of ``(resized_height, resized_width)``.

    Raises:
        ValueError: If ``height``, ``width`` or ``patch_size`` is not
            positive, or if the longer side is more than 200 times the
            shorter side.
    """
    # Reject degenerate input up front; otherwise the ratio check below would
    # fail with an unrelated ZeroDivisionError (or silently accept negatives).
    if height <= 0 or width <= 0:
        raise ValueError(
            f"height and width must be positive, got height={height}, width={width}"
        )
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > 200:
        raise ValueError(
            f"Absolute aspect ratio must be smaller than 200, got {aspect_ratio}"
        )

    # Snap each side to the nearest multiple of patch_size (round() resolves
    # exact .5 ties to the even multiple), but never below a single patch.
    h_bar = max(patch_size, round(height / patch_size) * patch_size)
    w_bar = max(patch_size, round(width / patch_size) * patch_size)

    if h_bar * w_bar < min_pixels:
        # Scale both sides by the same factor so the area reaches min_pixels,
        # rounding up so the aligned result cannot fall short of it.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / patch_size) * patch_size
        w_bar = math.ceil(width * beta / patch_size) * patch_size

    return h_bar, w_bar
161+
162+
def process_native_image(image, processor):
    """Resize an image to a patch-aligned size and run it through ``processor``.

    The patch size is 16 when the processor's class name contains "siglip"
    (case-insensitive) and 14 otherwise. The image is resized with bicubic
    interpolation to the size chosen by ``smart_resize`` and then handed to
    ``processor.preprocess`` with the processor's own resizing turned off; for
    non-siglip processors centre cropping is turned off as well.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize((width, height), resample)`` method
            (presumably a ``PIL.Image.Image`` -- confirm against callers).
        processor: Image processor exposing
            ``preprocess(image, return_tensors=..., ...)`` whose result has a
            ``"pixel_values"`` entry.

    Returns:
        Dict with ``'pixel_values'`` (the processor output passed through
        ``torch.cat`` along dim 0) and ``'grid_thw'``, the list
        ``[1, target_height // patch, target_width // patch]``.
    """
    src_width, src_height = image.size

    uses_siglip = 'siglip' in processor.__class__.__name__.lower()
    patch = 16 if uses_siglip else 14

    # Both processor families skip their own resize; only the non-siglip path
    # additionally disables centre cropping.
    preprocess_kwargs = {"return_tensors": "pt", "do_resize": False}
    if not uses_siglip:
        preprocess_kwargs["do_center_crop"] = False

    new_height, new_width = smart_resize(
        height=src_height,
        width=src_width,
        patch_size=patch,
        min_pixels=patch * 4,
    )
    resized = image.resize((new_width, new_height), Image.BICUBIC)
    pixel_values = processor.preprocess(resized, **preprocess_kwargs)["pixel_values"]

    return {
        'pixel_values': torch.cat([pixel_values], dim=0),
        'grid_thw': [1, new_height // patch, new_width // patch],
    }
118187
119188def select_best_resolution (original_size , possible_resolutions ):
120189 """
@@ -279,7 +348,7 @@ def process_anyres_image(image, processor, grid_pinpoints):
279348 grid_thw = [1 , best_resolution [1 ] // 16 , best_resolution [0 ] // 16 ]
280349 return {'pixel_values' : torch .cat (image_patches , dim = 0 ), 'grid_thw' : grid_thw }
281350 else : # FIXME: for onevision encoder
282- image_patches = [processor .preprocess (image_padded , return_tensors = "pt" , do_resize = False )["pixel_values" ]]
351+ image_patches = [processor .preprocess (image_padded , return_tensors = "pt" , do_resize = False , do_center_crop = False )["pixel_values" ]]
283352 grid_thw = [1 , best_resolution [1 ] // 14 , best_resolution [0 ] // 14 ]
284353 return {'pixel_values' : torch .cat (image_patches , dim = 0 ), 'grid_thw' : grid_thw }
285354
0 commit comments