Skip to content

Commit 5b8b4e4

Browse files
committed
Prepare for open source release
- Clean up code for public release - Remove unused imports and deprecated code - Translate Chinese comments to English - Remove sensitive internal paths - Delete deprecated scripts and unused encoders - Add OneVision encoder and SigLip2 NaFlex support - Add Qwen3 language model support
1 parent e35a5bd commit 5b8b4e4

176 files changed

Lines changed: 580 additions & 19663 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
export OMP_NUM_THREADS=8
2+
export NCCL_IB_DISABLE=0
3+
export NCCL_IB_GID_INDEX=3
4+
export NCCL_SOCKET_IFNAME=eth0
5+
export PYTHONPATH=$(pwd)
6+
export CUDA_VISIBLE_DEVICES=6,7
7+
8+
LLM_VERSION="/vlm/pretrain_models/Qwen/Qwen2.5-1.5B-Instruct"
9+
LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
10+
VISION_MODEL_VERSION="/video_vit/pretrain_models/deepglint/onevision-encoder-large"
11+
VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
12+
export WANDB_MODE=disabled
13+
14+
15+
export PORT=29502
16+
PROMPT_VERSION="qwen_1_5"
17+
18+
BASE_RUN_NAME="./checkpoints/date1220_llavanext-llavavit_-2hid-qwen2.5-1.5b-sigvid-8nodes"
19+
echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
20+
21+
mkdir -p $BASE_RUN_NAME
22+
cp $0 $BASE_RUN_NAME/$(basename $0)
23+
24+
deepspeed --master_port 65535 \
25+
llava/train/train_mem.py \
26+
--deepspeed scripts/zero3.json \
27+
--model_name_or_path ${LLM_VERSION} \
28+
--version ${PROMPT_VERSION} \
29+
--data_path /rice_vl/llava_video_8f_imgs_1027/video_800k_llavanextsig_740k_shuffled.jsonl \
30+
--image_folder /rice_vl/llava_video_8f_imgs_1027 \
31+
--pretrain_mm_mlp_adapter="/vlm/yinxie/code/checkpoints/projectors/llavanext-llavavit_-2hid-qwen2.5-1.5b-instruct-pretrain_blip558k_plain-1220-dist/mm_projector.bin" \
32+
--mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \
33+
--mm_vision_tower_lr=2e-6 \
34+
--vision_tower ${VISION_MODEL_VERSION} \
35+
--mm_projector_type mlp2x_gelu \
36+
--mm_vision_select_layer -2 \
37+
--mm_use_im_start_end False \
38+
--mm_use_im_patch_token False \
39+
--group_by_modality_length True \
40+
--image_aspect_ratio anyres \
41+
--image_grid_pinpoints "[(574, 1120), (1120, 574), (1120, 1120), (1694, 574), (574, 1694)]" \
42+
--mm_patch_merge_type flat \
43+
--bf16 True \
44+
--run_name $BASE_RUN_NAME \
45+
--output_dir $BASE_RUN_NAME \
46+
--num_train_epochs 1 \
47+
--per_device_train_batch_size 1 \
48+
--gradient_accumulation_steps 2 \
49+
--save_strategy "steps" \
50+
--save_steps 500 \
51+
--save_total_limit 20 \
52+
--learning_rate 1e-5 \
53+
--logging_steps 1 \
54+
--tf32 True \
55+
--model_max_length 321120 \
56+
--gradient_checkpointing True \
57+
--dataloader_num_workers 1 \
58+
--lazy_preprocess True \
59+
--dataloader_drop_last True \
60+
--attn_implementation flash_attention_2 | tee $BASE_RUN_NAME/train.log
61+
62+
# You can delete the sdpa attn_implementation if you want to use flash attn

llava_next/llava/mm_utils.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,14 @@ def process_anyres_image(image, processor, grid_pinpoints):
274274
possible_resolutions = ast.literal_eval(grid_pinpoints)
275275
best_resolution = select_best_resolution(image.size, possible_resolutions)
276276
image_padded = resize_and_pad_image(image, best_resolution)
277+
if 'siglip' in processor.__class__.__name__.lower():
278+
image_patches = [processor.preprocess(image_padded, return_tensors="pt", do_resize=False)["pixel_values"]]
279+
grid_thw = [1, best_resolution[1] // 16, best_resolution[0] // 16]
280+
return {'pixel_values': torch.cat(image_patches, dim=0), 'grid_thw': grid_thw}
281+
else: # FIXME: for onevision encoder
282+
image_patches = [processor.preprocess(image_padded, return_tensors="pt", do_resize=False)["pixel_values"]]
283+
grid_thw = [1, best_resolution[1] // 14, best_resolution[0] // 14]
284+
return {'pixel_values': torch.cat(image_patches, dim=0), 'grid_thw': grid_thw}
277285

278286
patches = divide_to_patches(image_padded, processor.crop_size["height"])
279287

@@ -314,6 +322,9 @@ def expand2square(pil_img, background_color):
314322
def process_images(images, image_processor, model_cfg):
315323
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
316324
new_images = []
325+
if len(images) == 8: #FIXME hardcoded for 8 images input as video sample
326+
image_aspect_ratio = 'pad'
327+
317328
if image_aspect_ratio == "highres":
318329
for image in images:
319330
image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints)
@@ -322,15 +333,36 @@ def process_images(images, image_processor, model_cfg):
322333
for image in images:
323334
image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
324335
new_images.append(image)
336+
return {'image_patchs': [img['pixel_values'] for img in new_images], 'grid_thw': [img['grid_thw'] for img in new_images]}
325337
elif image_aspect_ratio == "crop_split":
326338
for image in images:
327339
image = process_highres_image_crop_split(image, model_cfg, image_processor)
328340
new_images.append(image)
329341
elif image_aspect_ratio == "pad":
330-
for image in images:
331-
image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
332-
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
333-
new_images.append(image)
342+
if 'siglip' in image_processor.__class__.__name__.lower():
343+
image_patchs = []
344+
grid_thw = []
345+
for image in images:
346+
image = expand2square(image, tuple(int(0 * 255) for x in [0,0,0]))
347+
image = image.resize((512, 512))
348+
image_patchs.append(image_processor.preprocess(image, return_tensors="pt", do_resize=False)["pixel_values"])
349+
grid_thw.append([1, 32, 32])
350+
return {'image_patchs': image_patchs, 'grid_thw': torch.tensor(grid_thw)}
351+
352+
else: # FIXME: for onevision encoder video
353+
image_patchs = []
354+
grid_thw = []
355+
for image in images:
356+
image = expand2square(image, tuple(int(0 * 255) for x in [0,0,0]))
357+
image = image.resize((504, 504))
358+
image_patchs.append(image_processor.preprocess(image, return_tensors="pt", do_resize=False)["pixel_values"])
359+
grid_thw.append([1, 36, 36])
360+
return {'image_patchs': image_patchs, 'grid_thw': torch.tensor(grid_thw)}
361+
362+
image = image.resize((504, 504))
363+
# image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
364+
image = image_processor.preprocess(image, return_tensors="pt", do_resize=False)["pixel_values"]
365+
new_images.append(image)
334366
else:
335367
return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
336368
if all(x.shape == new_images[0].shape for x in new_images):

llava_next/llava/model/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@
1717

1818

1919
from .language_model.llava_qwen import LlavaQwenForCausalLM, LlavaQwenConfig
20+
from .language_model.llava_qwen3 import LlavaQwen3ForCausalLM, LlavaQwen3Config

llava_next/llava/model/builder.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,16 @@ def load_from_hf(repo_id, filename, subfolder=None):
221221
model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
222222
else:
223223
model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
224+
elif "qwen3" in model_name.lower():
225+
from llava.model.language_model.llava_qwen3 import LlavaQwen3Config
226+
if overwrite_config is not None:
227+
llava_cfg = LlavaQwen3Config.from_pretrained(model_path)
228+
rank0_print(f"Overwriting config with {overwrite_config}")
229+
for k, v in overwrite_config.items():
230+
setattr(llava_cfg, k, v)
231+
model = LlavaQwen3ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
232+
else:
233+
model = LlavaQwen3ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
224234

225235
else:
226236
from llava.model.language_model.llava_qwen import LlavaQwenConfig

llava_next/llava/model/language_model/llava_gemma.py

Lines changed: 0 additions & 122 deletions
This file was deleted.

0 commit comments

Comments
 (0)