1+ from typing import Optional
2+
3+ import os
4+ import hashlib
5+ import random
6+ import time
7+ import numpy as np
8+ import torch
9+
10+ from vllm import LLM
11+ from vllm .sampling_params import SamplingParams
12+
13+ from transformers import Qwen3VLForConditionalGeneration , AutoProcessor
14+ from peft import PeftModel
15+
16+ from qwen_vl_utils import process_vision_info
17+
18+
def set_seed(seed: int):
    """Set the seed in `random`, `numpy`, and `torch` for reproducible behavior.

    Args:
        seed (`int`): The seed to set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Also seed every visible CUDA device (no-op when CUDA is unavailable).
    torch.cuda.manual_seed_all(seed)
29+
30+
def apply_chat_template(prompt, num_images: int = 2):
    """Manually build a Qwen-VL chat prompt with numbered image tags.

    Used to work around a transformers bug that does not support vision ids:
    https://github.com/QwenLM/Qwen2.5-VL/issues/716#issuecomment-2723316100

    Args:
        prompt: The user text to place after the image placeholders.
        num_images (`int`): Number of image slots to emit (tags <img1>..<imgN>).

    Returns:
        The fully formatted chat-template string, ending with the assistant turn
        opener so the model continues as the assistant.
    """
    # NOTE: the special tokens must appear exactly as below (no stray spaces),
    # otherwise the tokenizer will not map them to the vision placeholders.
    template = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
    template += "".join(
        f"<img{i}>: <|vision_start|><|image_pad|><|vision_end|>" for i in range(1, num_images + 1)
    )
    template += f"{prompt}<|im_end|>\n<|im_start|>assistant\n"
    return template
39+
40+
class Qwen3VL:
    """vLLM-backed wrapper around a Qwen3-VL checkpoint.

    Optionally merges a LoRA adapter into the base model once, caching the
    merged checkpoint on disk so later constructions skip the merge.
    """

    def __init__(
        self,
        vlm_model,
        max_model_len: int = 1536,
        tensor_parallel_size=1,
        max_num_seqs=32,
        max_num_batched_tokens=1536,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        lora_path: Optional[str] = None,
        cache_dir: Optional[str] = None,
        limit_images: int = 2,
    ) -> None:
        """Load the model into vLLM, merging a LoRA adapter first if given.

        Args:
            vlm_model: HF model id or local path of the base Qwen3-VL model.
            max_model_len: Maximum context length for the vLLM engine.
            tensor_parallel_size: Number of GPUs for tensor parallelism.
            max_num_seqs: Maximum concurrent sequences in the engine.
            max_num_batched_tokens: Token budget per engine batch.
            temperature: Sampling temperature used by inference methods.
            seed: Default sampling seed (overridable per call).
            lora_path: Optional LoRA adapter to merge into the base model.
            cache_dir: Where to store the merged checkpoint; derived from the
                model name and a hash of ``lora_path`` when omitted.
            limit_images: Maximum images allowed per prompt (default 2).
        """
        if lora_path:
            # Replace the model path with the merged-and-cached checkpoint.
            vlm_model = self._merge_lora(vlm_model, lora_path, cache_dir)

        self.model = LLM(
            model=vlm_model,
            max_model_len=max_model_len,
            tensor_parallel_size=tensor_parallel_size,
            max_num_seqs=max_num_seqs,
            max_num_batched_tokens=max_num_batched_tokens,
            limit_mm_per_prompt={"image": limit_images},
            enable_prefix_caching=True,
        )
        self.processor = AutoProcessor.from_pretrained(vlm_model)
        self.temperature = temperature
        self.seed = seed

    @staticmethod
    def _merge_lora(vlm_model, lora_path: str, cache_dir: Optional[str]) -> str:
        """Merge a LoRA adapter into ``vlm_model`` and cache the result.

        Returns the directory containing the merged checkpoint (reused if it
        already exists).
        """
        if cache_dir is None:
            root_dir = torch.hub.get_dir()  # default: ~/.cache/torch/hub
            lora_filename = os.path.splitext(os.path.basename(lora_path))[0]
            # Short path hash disambiguates adapters sharing a basename.
            lora_hash = hashlib.md5(lora_path.encode()).hexdigest()[:8]
            lora_identifier = f"{lora_filename}_{lora_hash}"
            cache_dir = os.path.join(
                root_dir,
                "EditScore",
                f"{os.path.basename(vlm_model)}_merged_lora_{lora_identifier}",
            )

        if os.path.exists(cache_dir):
            print(f"Skipping merging LORA, as merged model already exists in {cache_dir}", flush=True)
            return cache_dir

        print(f"Merging LORA to {vlm_model} and saving to {cache_dir}", flush=True)
        start_time = time.time()
        # Merge on CPU to avoid holding two full copies of the model on GPU.
        model = Qwen3VLForConditionalGeneration.from_pretrained(
            vlm_model, torch_dtype=torch.bfloat16, device_map="cpu"
        )
        model = PeftModel.from_pretrained(model, lora_path)
        model = model.merge_and_unload()
        model.save_pretrained(cache_dir)

        # Save the processor alongside so cache_dir is a self-contained checkpoint.
        processor = AutoProcessor.from_pretrained(vlm_model)
        processor.save_pretrained(cache_dir)

        print(
            f"Merging LORA to {vlm_model} and saving to {cache_dir} took {time.time() - start_time} seconds",
            flush=True,
        )
        return cache_dir

    def prepare_input(self, images, text_prompt: str = ""):
        """Build a vLLM ``generate()`` request from image(s) and a text prompt.

        Args:
            images: A single image or a list of images (any form accepted by
                ``process_vision_info``).
            text_prompt: The user text appended after the images.

        Returns:
            A dict with ``prompt`` (chat-templated text) and ``multi_modal_data``.
        """
        if not isinstance(images, list):
            images = [images]

        messages = [
            {
                "role": "user",
                "content": [{"type": "image", "image": image} for image in images]
                + [{"type": "text", "text": text_prompt}],
            }
        ]
        # NOTE(review): a manual apply_chat_template(...) helper exists at module
        # level as a workaround for a transformers vision-id bug; the processor's
        # built-in template is used here.
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)

        return {
            "prompt": text,
            "multi_modal_data": {"image": image_inputs},
        }

    def _sampling_params(self, seed: Optional[int]) -> SamplingParams:
        """Sampling params shared by `inference` and `batch_inference`."""
        return SamplingParams(
            max_tokens=512,
            temperature=self.temperature,
            top_p=0.9,
            top_k=20,
            seed=self.seed if seed is None else seed,
        )

    def inference(self, messages, seed: Optional[int] = None):
        """Generate for a single request; returns the first decoded response string."""
        return self.batch_inference(messages, seed=seed)[0]

    def batch_inference(self, messages, seed: Optional[int] = None):
        """Generate for a batch of prepared requests.

        Args:
            messages: Request dict(s) as produced by `prepare_input`.
            seed: Optional per-call sampling seed; falls back to ``self.seed``.

        Returns:
            A list of stripped response strings, one per request.
        """
        outputs = self.model.generate(messages, self._sampling_params(seed), use_tqdm=False)
        return [output.outputs[0].text.strip() for output in outputs]