1 change: 1 addition & 0 deletions .gitignore
@@ -188,6 +188,7 @@ ray_results/
comet_ml/
neptune/
optuna/
checkpoints/

# Common data file formats (uncomment any you DO want to track)
*.csv
113 changes: 109 additions & 4 deletions llava_next/llava/mm_utils.py
@@ -115,6 +115,75 @@ def process_highres_image(image, processor, grid_pinpoints):
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
return torch.stack(image_patches, dim=0)

def smart_resize(
height: int,
width: int,
patch_size: int = 16,
min_pixels: int = 32 * 32,
):
"""
Rescales the image dimensions so that:
1. Both dimensions (height and width) are divisible by 'factor' (32 for Siglip2).
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.

This is similar to Qwen2VL's smart_resize but adapted for Siglip2's requirements.

Args:
height: Original image height
width: Original image width
factor: Factor that dimensions must be divisible by (default: 32 = 2 * 16)
min_pixels: Minimum number of pixels (default: 1024 = 32*32)
max_pixels: Maximum number of pixels (default: 262144 = 512*512)

Returns:
Tuple of (resized_height, resized_width)
"""
if max(height, width) / min(height, width) > 200:
raise ValueError(
f"Absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
)

# Round each dimension to the nearest multiple of patch_size
h_bar = round(height / patch_size) * patch_size
w_bar = round(width / patch_size) * patch_size

# Ensure each dimension is at least one patch
h_bar = max(patch_size, h_bar)
w_bar = max(patch_size, w_bar)

# Upscale uniformly so the total pixel count reaches min_pixels
if h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / patch_size) * patch_size
w_bar = math.ceil(width * beta / patch_size) * patch_size

return h_bar, w_bar
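
# Example behaviour of smart_resize under the defaults above (illustrative; Python's
# round() uses ties-to-even):
#   smart_resize(480, 640)   -> (480, 640)    both already multiples of 16 and >= min_pixels
#   smart_resize(1080, 1920) -> (1088, 1920)  1080 rounds to the nearest multiple of 16
#   smart_resize(20, 20)     -> (32, 32)      upscaled so 32 * 32 >= min_pixels (1024)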


def process_native_image(image, processor):
"""Resize the image to its nearest patch-aligned size and preprocess it at native resolution, returning pixel values plus a [t, h, w] patch grid."""
orig_width, orig_height = image.size
if 'siglip' in processor.__class__.__name__.lower():
target_height, target_width = smart_resize(
height=orig_height,
width=orig_width,
patch_size=16,
min_pixels=16*4,
)
image = image.resize((target_width, target_height), Image.BICUBIC)
image_patches = [processor.preprocess(image, return_tensors="pt", do_resize=False)["pixel_values"]]
grid_thw = [1, target_height // 16, target_width // 16]
return {'pixel_values': torch.cat(image_patches, dim=0), 'grid_thw': grid_thw}
else:
target_height, target_width = smart_resize(
height=orig_height,
width=orig_width,
patch_size=14,
min_pixels=14*4,
)
image = image.resize((target_width, target_height), Image.BICUBIC)
image_patches = [processor.preprocess(image, return_tensors="pt", do_resize=False, do_center_crop=False)["pixel_values"]]
grid_thw = [1, target_height // 14, target_width // 14]
return {'pixel_values': torch.cat(image_patches, dim=0), 'grid_thw': grid_thw}
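
# Illustrative example for the SigLIP branch (16-px patches): a 640x480 RGB image is already
# patch-aligned, so smart_resize keeps its size and the returned dict holds the preprocessed
# pixel values for the full 480x640 image with grid_thw = [1, 480 // 16, 640 // 16] = [1, 30, 40].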

def select_best_resolution(original_size, possible_resolutions):
"""
@@ -274,6 +343,14 @@ def process_anyres_image(image, processor, grid_pinpoints):
possible_resolutions = ast.literal_eval(grid_pinpoints)
best_resolution = select_best_resolution(image.size, possible_resolutions)
image_padded = resize_and_pad_image(image, best_resolution)
if 'siglip' in processor.__class__.__name__.lower():
image_patches = [processor.preprocess(image_padded, return_tensors="pt", do_resize=False)["pixel_values"]]
grid_thw = [1, best_resolution[1] // 16, best_resolution[0] // 16]
return {'pixel_values': torch.cat(image_patches, dim=0), 'grid_thw': grid_thw}
else: # FIXME: for onevision encoder
image_patches = [processor.preprocess(image_padded, return_tensors="pt", do_resize=False, do_center_crop=False)["pixel_values"]]
grid_thw = [1, best_resolution[1] // 14, best_resolution[0] // 14]
return {'pixel_values': torch.cat(image_patches, dim=0), 'grid_thw': grid_thw}
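# Illustrative: with best_resolution = (1024, 512) given as (width, height), the SigLIP branch
# above yields grid_thw = [1, 512 // 16, 1024 // 16] = [1, 32, 64].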

patches = divide_to_patches(image_padded, processor.crop_size["height"])

@@ -314,23 +391,51 @@ def expand2square(pil_img, background_color):
def process_images(images, image_processor, model_cfg):
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
new_images = []
if len(images) == 8: # FIXME: hardcoded; an 8-image input is treated as a video sample
image_aspect_ratio = 'pad'

if image_aspect_ratio == "highres":
for image in images:
image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints)
new_images.append(image)
elif image_aspect_ratio == "native":
for image in images:
image = process_native_image(image, image_processor)
new_images.append(image)
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
for image in images:
image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
new_images.append(image)
return {'image_patchs': [img['pixel_values'] for img in new_images], 'grid_thw': [img['grid_thw'] for img in new_images]}
elif image_aspect_ratio == "crop_split":
for image in images:
image = process_highres_image_crop_split(image, model_cfg, image_processor)
new_images.append(image)
elif image_aspect_ratio == "pad":
for image in images:
image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
new_images.append(image)
if 'siglip' in image_processor.__class__.__name__.lower():
image_patchs = []
grid_thw = []
for image in images:
image = expand2square(image, (0, 0, 0))  # pad with a black background
image = image.resize((512, 512))
image_patchs.append(image_processor.preprocess(image, return_tensors="pt", do_resize=False)["pixel_values"])
grid_thw.append([1, 32, 32])
return {'image_patchs': image_patchs, 'grid_thw': torch.tensor(grid_thw)}

else: # FIXME: for onevision encoder video
image_patchs = []
grid_thw = []
for image in images:
image = expand2square(image, (0, 0, 0))  # pad with a black background
image = image.resize((504, 504))
image_patchs.append(image_processor.preprocess(image, return_tensors="pt", do_resize=False)["pixel_values"])
grid_thw.append([1, 36, 36])
return {'image_patchs': image_patchs, 'grid_thw': torch.tensor(grid_thw)}

image = image.resize((504, 504))
# image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
image = image_processor.preprocess(image, return_tensors="pt", do_resize=False)["pixel_values"]
new_images.append(image)
else:
return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
if all(x.shape == new_images[0].shape for x in new_images):
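A minimal usage sketch for the new dict-returning paths in process_images (illustrative only: the checkpoint name, image path, and config values below are placeholders, and exact tensor shapes depend on the processor):

from types import SimpleNamespace

from PIL import Image
from transformers import AutoImageProcessor

from llava.mm_utils import process_images

# Any SigLIP-style processor with 16-px patches exercises the 'siglip' branches above;
# the checkpoint name here is only an example.
image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")

# Hypothetical config: this path only reads image_aspect_ratio and image_grid_pinpoints.
model_cfg = SimpleNamespace(
    image_aspect_ratio="anyres",
    image_grid_pinpoints="[(512, 512), (512, 1024), (1024, 512)]",
)

image = Image.open("example.jpg").convert("RGB")
out = process_images([image], image_processor, model_cfg)
# out["image_patchs"] holds one preprocessed pixel-value tensor per input image, and
# out["grid_thw"] the matching [t, h, w] patch grids, e.g. [1, 32, 32] for a 512x512 best fit.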
1 change: 1 addition & 0 deletions llava_next/llava/model/__init__.py
@@ -17,3 +17,4 @@


from .language_model.llava_qwen import LlavaQwenForCausalLM, LlavaQwenConfig
from .language_model.llava_qwen3 import LlavaQwen3ForCausalLM, LlavaQwen3Config
10 changes: 10 additions & 0 deletions llava_next/llava/model/builder.py
@@ -221,6 +221,16 @@ def load_from_hf(repo_id, filename, subfolder=None):
model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
else:
model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
elif "qwen3" in model_name.lower():
from llava.model.language_model.llava_qwen3 import LlavaQwen3Config
if overwrite_config is not None:
llava_cfg = LlavaQwen3Config.from_pretrained(model_path)
rank0_print(f"Overwriting config with {overwrite_config}")
for k, v in overwrite_config.items():
setattr(llava_cfg, k, v)
model = LlavaQwen3ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
else:
model = LlavaQwen3ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)

else:
from llava.model.language_model.llava_qwen import LlavaQwenConfig
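A brief usage sketch for the new Qwen3 branch (illustrative: the checkpoint path and model name are placeholders, and the keyword arguments follow the builder's existing load_pretrained_model signature):

from llava.model.builder import load_pretrained_model

# Keeping "qwen3" in the model name routes loading to LlavaQwen3ForCausalLM, as in the branch above.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="checkpoints/llava-qwen3-8b",  # hypothetical local checkpoint
    model_base=None,
    model_name="llava_qwen3",
    overwrite_config={"image_aspect_ratio": "native"},  # optional; merged into the config as shown above
)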
122 changes: 0 additions & 122 deletions llava_next/llava/model/language_model/llava_gemma.py

This file was deleted.
