Skip to content

Commit ab13a1c

Browse files
committed
support native mode
1 parent fb9dcb7 commit ab13a1c

1 file changed

Lines changed: 7 additions & 6 deletions

File tree

llava_next/llava/train/train.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545

4646
from llava import conversation as conversation_lib
4747
from llava.model import *
48-
from llava.mm_utils import process_highres_image, process_anyres_image, process_highres_image_crop_split, tokenizer_image_token
48+
from llava.mm_utils import process_highres_image, process_native_image, process_anyres_image, process_highres_image_crop_split, tokenizer_image_token
4949
from llava.utils import rank0_print, process_video_with_pyav, process_video_with_decord
5050

5151
torch.multiprocessing.set_sharing_strategy("file_system")
@@ -1086,6 +1086,11 @@ def process_image(self, image_file: Union[bytes, str], overwrite_image_aspect_ra
10861086
image_aspect_ratio = overwrite_image_aspect_ratio
10871087
if image_aspect_ratio == "highres":
10881088
image = process_highres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints)
1089+
elif image_aspect_ratio == "native":
1090+
image = process_native_image(image, self.data_args.image_processor)
1091+
if type(image) is dict:
1092+
grid_thw = image['grid_thw']
1093+
image = image['pixel_values']
10891094
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
10901095
image = process_anyres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints)
10911096
if type(image) is dict:
@@ -1115,19 +1120,15 @@ def expand2square(pil_img, background_color):
11151120
image = image.resize((504, 504))
11161121
grid_thw = [1,36,36]
11171122

1118-
image = processor.preprocess(image, return_tensors="pt", do_resize=False)["pixel_values"]
1119-
# assert 1==3, f'image.size = {image['pixel_values'].shape}, processor = {processor}'
1120-
1123+
image = processor.preprocess(image, return_tensors="pt", do_resize=False)["pixel_values"]
11211124
else:
1122-
# assert 1==3, f'processor:{processor}, image.size={image.size}, image_aspect_ratio={image_aspect_ratio} not supported yet'
11231125
if 'siglip' in str(processor).lower():
11241126
image = image.resize((512, 512))
11251127
image = processor.preprocess(image, return_tensors="pt", do_resize=False)["pixel_values"]
11261128
else:
11271129
image = image.resize((504, 504))
11281130
image = processor.preprocess(image, return_tensors="pt", do_resize=False, do_center_crop=False)["pixel_values"]
11291131

1130-
# assert 1==3, f'image.size = {image.shape}, processor = {processor}'
11311132
if grid_thw is None:
11321133
return image, image_size, "image"
11331134
return image, image_size, "image", grid_thw

0 commit comments

Comments
 (0)