Skip to content

Commit d101215

Browse files
authored
Merge pull request #2 from EvolvingLMMs-Lab/yiye
Prepare for open source release
2 parents 7201427 + d29d1f1 commit d101215

File tree

176 files changed

+594
-19664
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

176 files changed

+594
-19664
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ ray_results/
188188
comet_ml/
189189
neptune/
190190
optuna/
191+
checkpoints/
191192

192193
# Common data file formats (uncomment any you DO want to track)
193194
*.csv

llava_next/llava/mm_utils.py

Lines changed: 109 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,75 @@ def process_highres_image(image, processor, grid_pinpoints):
115115
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
116116
return torch.stack(image_patches, dim=0)
117117

118+
def smart_resize(
    height: int,
    width: int,
    patch_size: int = 16,
    min_pixels: int = 32 * 32,
    max_pixels: "int | None" = None,
):
    """
    Rescales the image dimensions so that:
    1. Both dimensions (height and width) are divisible by ``patch_size``.
    2. The total number of pixels is at least ``min_pixels`` (and, when a cap
       is given, at most ``max_pixels``).
    3. The aspect ratio of the image is maintained as closely as possible.

    This is similar to Qwen2VL's smart_resize but adapted for Siglip2's requirements.

    Args:
        height: Original image height in pixels.
        width: Original image width in pixels.
        patch_size: Factor that both output dimensions must be divisible by
            (default: 16, Siglip2's patch size).
        min_pixels: Minimum number of output pixels (default: 1024 = 32*32).
        max_pixels: Optional maximum number of output pixels. ``None`` (the
            default) applies no cap, preserving the original behavior.

    Returns:
        Tuple of (resized_height, resized_width).

    Raises:
        ValueError: If the aspect ratio exceeds 200:1.
    """
    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"Absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )

    # Round each side to the nearest multiple of the patch size.
    h_bar = round(height / patch_size) * patch_size
    w_bar = round(width / patch_size) * patch_size

    # Never collapse a side below a single patch.
    h_bar = max(patch_size, h_bar)
    w_bar = max(patch_size, w_bar)

    if h_bar * w_bar < min_pixels:
        # Scale up uniformly; ceil keeps the result at or above min_pixels.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / patch_size) * patch_size
        w_bar = math.ceil(width * beta / patch_size) * patch_size
    elif max_pixels is not None and h_bar * w_bar > max_pixels:
        # Scale down uniformly; floor keeps the result at or below max_pixels
        # (same strategy as Qwen2VL's smart_resize).
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(patch_size, math.floor(height / beta / patch_size) * patch_size)
        w_bar = max(patch_size, math.floor(width / beta / patch_size) * patch_size)

    return h_bar, w_bar
161+
162+
163+
def process_native_image(image, processor):
    """Resize an image to encoder-native dimensions and preprocess it.

    The target size is computed with ``smart_resize`` so that both sides are
    multiples of the encoder's patch size; the processor's own resizing is
    disabled because the image is already at the target resolution.

    Args:
        image: PIL image to process.
        processor: HuggingFace image processor. Siglip-family processors
            (detected by class name) use 16px patches; any other processor
            (the onevision encoder path) uses 14px patches.

    Returns:
        Dict with:
            'pixel_values': preprocessed patch tensor,
            'grid_thw': [t, h, w] patch-grid shape (t fixed to 1 for a
                single image).
    """
    orig_width, orig_height = image.size
    is_siglip = 'siglip' in processor.__class__.__name__.lower()
    # Siglip2 uses 16px patches; the onevision encoder uses 14px.
    patch_size = 16 if is_siglip else 14
    target_height, target_width = smart_resize(
        height=orig_height,
        width=orig_width,
        patch_size=patch_size,
        min_pixels=patch_size * 4,  # at least four patches' worth of pixels
    )
    image = image.resize((target_width, target_height), Image.BICUBIC)
    # Only the non-siglip path disables center-cropping, matching the
    # original per-branch preprocess calls exactly.
    extra_kwargs = {} if is_siglip else {'do_center_crop': False}
    image_patches = [processor.preprocess(image, return_tensors="pt", do_resize=False, **extra_kwargs)["pixel_values"]]
    grid_thw = [1, target_height // patch_size, target_width // patch_size]
    return {'pixel_values': torch.cat(image_patches, dim=0), 'grid_thw': grid_thw}
118187

119188
def select_best_resolution(original_size, possible_resolutions):
120189
"""
@@ -274,6 +343,14 @@ def process_anyres_image(image, processor, grid_pinpoints):
274343
possible_resolutions = ast.literal_eval(grid_pinpoints)
275344
best_resolution = select_best_resolution(image.size, possible_resolutions)
276345
image_padded = resize_and_pad_image(image, best_resolution)
346+
if 'siglip' in processor.__class__.__name__.lower():
347+
image_patches = [processor.preprocess(image_padded, return_tensors="pt", do_resize=False)["pixel_values"]]
348+
grid_thw = [1, best_resolution[1] // 16, best_resolution[0] // 16]
349+
return {'pixel_values': torch.cat(image_patches, dim=0), 'grid_thw': grid_thw}
350+
else: # FIXME: for onevision encoder
351+
image_patches = [processor.preprocess(image_padded, return_tensors="pt", do_resize=False, do_center_crop=False)["pixel_values"]]
352+
grid_thw = [1, best_resolution[1] // 14, best_resolution[0] // 14]
353+
return {'pixel_values': torch.cat(image_patches, dim=0), 'grid_thw': grid_thw}
277354

278355
patches = divide_to_patches(image_padded, processor.crop_size["height"])
279356

@@ -314,23 +391,51 @@ def expand2square(pil_img, background_color):
314391
def process_images(images, image_processor, model_cfg):
315392
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
316393
new_images = []
394+
if len(images) == 8: #FIXME hardcoded for 8 images input as video sample
395+
image_aspect_ratio = 'pad'
396+
317397
if image_aspect_ratio == "highres":
318398
for image in images:
319399
image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints)
320400
new_images.append(image)
401+
elif image_aspect_ratio == "native":
402+
for image in images:
403+
image = process_native_image(image, image_processor)
404+
new_images.append(image)
321405
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
322406
for image in images:
323407
image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
324408
new_images.append(image)
409+
return {'image_patchs': [img['pixel_values'] for img in new_images], 'grid_thw': [img['grid_thw'] for img in new_images]}
325410
elif image_aspect_ratio == "crop_split":
326411
for image in images:
327412
image = process_highres_image_crop_split(image, model_cfg, image_processor)
328413
new_images.append(image)
329414
elif image_aspect_ratio == "pad":
330-
for image in images:
331-
image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
332-
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
333-
new_images.append(image)
415+
if 'siglip' in image_processor.__class__.__name__.lower():
416+
image_patchs = []
417+
grid_thw = []
418+
for image in images:
419+
image = expand2square(image, tuple(int(0 * 255) for x in [0,0,0]))
420+
image = image.resize((512, 512))
421+
image_patchs.append(image_processor.preprocess(image, return_tensors="pt", do_resize=False)["pixel_values"])
422+
grid_thw.append([1, 32, 32])
423+
return {'image_patchs': image_patchs, 'grid_thw': torch.tensor(grid_thw)}
424+
425+
else: # FIXME: for onevision encoder video
426+
image_patchs = []
427+
grid_thw = []
428+
for image in images:
429+
image = expand2square(image, tuple(int(0 * 255) for x in [0,0,0]))
430+
image = image.resize((504, 504))
431+
image_patchs.append(image_processor.preprocess(image, return_tensors="pt", do_resize=False)["pixel_values"])
432+
grid_thw.append([1, 36, 36])
433+
return {'image_patchs': image_patchs, 'grid_thw': torch.tensor(grid_thw)}
434+
435+
image = image.resize((504, 504))
436+
# image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
437+
image = image_processor.preprocess(image, return_tensors="pt", do_resize=False)["pixel_values"]
438+
new_images.append(image)
334439
else:
335440
return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
336441
if all(x.shape == new_images[0].shape for x in new_images):

llava_next/llava/model/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@
1717

1818

1919
from .language_model.llava_qwen import LlavaQwenForCausalLM, LlavaQwenConfig
20+
from .language_model.llava_qwen3 import LlavaQwen3ForCausalLM, LlavaQwen3Config

llava_next/llava/model/builder.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,16 @@ def load_from_hf(repo_id, filename, subfolder=None):
221221
model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
222222
else:
223223
model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
224+
elif "qwen3" in model_name.lower():
225+
from llava.model.language_model.llava_qwen3 import LlavaQwen3Config
226+
if overwrite_config is not None:
227+
llava_cfg = LlavaQwen3Config.from_pretrained(model_path)
228+
rank0_print(f"Overwriting config with {overwrite_config}")
229+
for k, v in overwrite_config.items():
230+
setattr(llava_cfg, k, v)
231+
model = LlavaQwen3ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
232+
else:
233+
model = LlavaQwen3ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
224234

225235
else:
226236
from llava.model.language_model.llava_qwen import LlavaQwenConfig

llava_next/llava/model/language_model/llava_gemma.py

Lines changed: 0 additions & 122 deletions
This file was deleted.

0 commit comments

Comments
 (0)