Skip to content

Commit b403293

Browse files
committed
fix a bug
1 parent a2a8d48 commit b403293

1 file changed

Lines changed: 70 additions & 1 deletion

File tree

llava_next/llava/mm_utils.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,75 @@ def process_highres_image(image, processor, grid_pinpoints):
115115
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
116116
return torch.stack(image_patches, dim=0)
117117

118+
def smart_resize(
119+
height: int,
120+
width: int,
121+
patch_size: int = 16,
122+
min_pixels: int = 32 * 32,
123+
):
124+
"""
125+
Rescales the image dimensions so that:
126+
1. Both dimensions (height and width) are divisible by 'factor' (32 for Siglip2).
127+
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
128+
3. The aspect ratio of the image is maintained as closely as possible.
129+
130+
This is similar to Qwen2VL's smart_resize but adapted for Siglip2's requirements.
131+
132+
Args:
133+
height: Original image height
134+
width: Original image width
135+
factor: Factor that dimensions must be divisible by (default: 32 = 2 * 16)
136+
min_pixels: Minimum number of pixels (default: 1024 = 32*32)
137+
max_pixels: Maximum number of pixels (default: 262144 = 512*512)
138+
139+
Returns:
140+
Tuple of (resized_height, resized_width)
141+
"""
142+
if max(height, width) / min(height, width) > 200:
143+
raise ValueError(
144+
f"Absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
145+
)
146+
147+
# Round to nearest factor
148+
h_bar = round(height / patch_size) * patch_size
149+
w_bar = round(width / patch_size) * patch_size
150+
151+
# Ensure minimum factor size
152+
h_bar = max(patch_size, h_bar)
153+
w_bar = max(patch_size, w_bar)
154+
155+
if h_bar * w_bar < min_pixels:
156+
beta = math.sqrt(min_pixels / (height * width))
157+
h_bar = math.ceil(height * beta / patch_size) * patch_size
158+
w_bar = math.ceil(width * beta / patch_size) * patch_size
159+
160+
return h_bar, w_bar
161+
162+
163+
def process_native_image(image, processor):
164+
orig_width, orig_height = image.size
165+
if 'siglip' in processor.__class__.__name__.lower():
166+
target_height, target_width = smart_resize(
167+
height=orig_height,
168+
width=orig_width,
169+
patch_size=16,
170+
min_pixels=16*4,
171+
)
172+
image = image.resize((target_width, target_height), Image.BICUBIC)
173+
image_patches = [processor.preprocess(image, return_tensors="pt", do_resize=False)["pixel_values"]]
174+
grid_thw = [1, target_height // 16, target_width // 16]
175+
return {'pixel_values': torch.cat(image_patches, dim=0), 'grid_thw': grid_thw}
176+
else:
177+
target_height, target_width = smart_resize(
178+
height=orig_height,
179+
width=orig_width,
180+
patch_size=14,
181+
min_pixels=14*4,
182+
)
183+
image = image.resize((target_width, target_height), Image.BICUBIC)
184+
image_patches = [processor.preprocess(image, return_tensors="pt", do_resize=False, do_center_crop=False)["pixel_values"]]
185+
grid_thw = [1, target_height // 14, target_width // 14]
186+
return {'pixel_values': torch.cat(image_patches, dim=0), 'grid_thw': grid_thw}
118187

119188
def select_best_resolution(original_size, possible_resolutions):
120189
"""
@@ -279,7 +348,7 @@ def process_anyres_image(image, processor, grid_pinpoints):
279348
grid_thw = [1, best_resolution[1] // 16, best_resolution[0] // 16]
280349
return {'pixel_values': torch.cat(image_patches, dim=0), 'grid_thw': grid_thw}
281350
else: # FIXME: for onevision encoder
282-
image_patches = [processor.preprocess(image_padded, return_tensors="pt", do_resize=False)["pixel_values"]]
351+
image_patches = [processor.preprocess(image_padded, return_tensors="pt", do_resize=False, do_center_crop=False)["pixel_values"]]
283352
grid_thw = [1, best_resolution[1] // 14, best_resolution[0] // 14]
284353
return {'pixel_values': torch.cat(image_patches, dim=0), 'grid_thw': grid_thw}
285354

0 commit comments

Comments
 (0)