11import base64
2+ import re
23from io import BytesIO
34from typing import List , Optional , Tuple , Union
45
@@ -48,6 +49,9 @@ def __init__(
4849 use_custom_video_loader : Optional [bool ] = False ,
4950 fps : Optional [float ] = None , # Only applicable if use_custom_video_loader is True
5051 max_image_size : Optional [int ] = None , # Only applicable if use_custom_video_loader is True
52+ system_prompt : Optional [str ] = "You are a helpful assistant." ,
53+ interleave_visuals : Optional [bool ] = False ,
54+ reasoning_prompt : Optional [str ] = None ,
5155 ** kwargs ,
5256 ) -> None :
5357 super ().__init__ ()
@@ -66,12 +70,9 @@ def __init__(
6670 if accelerator .num_processes > 1 :
6771 self ._device = torch .device (f"cuda:{ accelerator .local_process_index } " )
6872 self .device_map = f"cuda:{ accelerator .local_process_index } "
69- elif accelerator .num_processes == 1 and device_map == "auto" :
70- self ._device = torch .device (device )
71- self .device_map = device_map
7273 else :
73- self ._device = torch .device (f"cuda: { accelerator . local_process_index } " )
74- self .device_map = f"cuda: { accelerator . local_process_index } "
74+ self ._device = torch .device (device )
75+ self .device_map = device_map if device_map else device
7576
7677 if use_flash_attention_2 :
7778 self ._model = Qwen2_5_VLForConditionalGeneration .from_pretrained (
@@ -85,10 +86,18 @@ def __init__(
8586 self .max_pixels = max_pixels
8687 self .min_pixels = min_pixels
8788 self .max_num_frames = max_num_frames
88- self .processor = AutoProcessor .from_pretrained (pretrained , max_pixels = max_pixels , min_pixels = min_pixels , padding_side = "left" )
89- self ._tokenizer = AutoTokenizer .from_pretrained (pretrained , padding_side = "left" )
89+
90+ if reasoning_prompt :
91+ self .reasoning_prompt = reasoning_prompt .replace ("\\ n" , "\n " )
92+ else :
93+ self .reasoning_prompt = None
94+ self .processor = AutoProcessor .from_pretrained (pretrained , max_pixels = max_pixels , min_pixels = min_pixels )
95+ self ._tokenizer = AutoTokenizer .from_pretrained (pretrained )
96+ self .system_prompt = system_prompt
97+ self .interleave_visuals = interleave_visuals
9098
9199 self ._config = self .model .config
100+ self ._max_length = kwargs .get ("max_length" , 2048 )
92101 self .batch_size_per_gpu = int (batch_size )
93102 self .use_cache = use_cache
94103
@@ -184,8 +193,11 @@ def _collate(x):
184193 contexts , all_gen_kwargs , doc_to_visual , doc_id , task , split = zip (* chunk )
185194 task = task [0 ]
186195 split = split [0 ]
187- visuals = [doc_to_visual [0 ](self .task_dict [task ][split ][ids ]) for ids in doc_id ]
188- visuals = self .flatten (visuals )
196+ visual_list = [doc_to_visual [0 ](self .task_dict [task ][split ][ids ]) for ids in doc_id ]
197+ if None in visual_list :
198+ visual_list = []
199+ else :
200+ visual_list = self .flatten (visual_list )
189201
190202 gen_kwargs = all_gen_kwargs [0 ]
191203
@@ -200,112 +212,116 @@ def _collate(x):
200212 elif not isinstance (until , list ):
201213 raise ValueError (f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got { type (until )} " )
202214
203- # if isinstance(contexts, tuple):
204- # contexts = list(contexts)
205-
206- # for i in range(len(contexts)):
207- # for j in range(32):
208- # if f"<image {j}>" in contexts[i]:
209- # contexts[i] = contexts[i].replace(f"<image {j}>", "<image>")
210- # if f"\\<image {j}\\>" in contexts[i]:
211- # contexts[i] = contexts[i].replace(f"\\<image {j}\\>", "<image>")
212- # if "<image>" in contexts[i]:
213- # contexts[i] = contexts[i].replace("<image>", "")
214- # print(contexts[i])
215-
216- # for i in range(len(contexts)):
217- # if "<image>" in contexts[i]:
218- # contexts[i] = contexts[i].replace("<image>", "")
219-
220- messages = []
221- processed_visuals = []
215+ if isinstance (contexts , tuple ):
216+ contexts = list (contexts )
217+
218+ for i in range (len (contexts )):
219+ if "<image>" in contexts [i ]:
220+ contexts [i ] = contexts [i ].replace ("<image>" , "" )
221+
222+ batched_messages = []
222223 for i , context in enumerate (contexts ):
223- # context += "\nPlease think step by step."
224- # if "<image>" in context:
225- # context = context.replace("<image>", "")
224+ if "<image>" in context :
225+ context = context .replace ("<image>" , "" )
226226
227- message = [{"role" : "system" , "content" : "You are a helpful assistant." }]
227+ message = [{"role" : "system" , "content" : self .system_prompt }]
228+ if self .reasoning_prompt :
229+ context = context .strip () + self .reasoning_prompt
230+ contexts [i ] = context
228231
229- if len ( visuals ) > 0 :
230- visual = visuals [ i ] if i < len ( visuals ) else None
232+ processed_visuals = []
233+ for visual in visual_list :
231234 if isinstance (visual , str ) and visual .endswith ((".mp4" , ".avi" , ".mov" )): # Video file
232- if self .use_custom_video_loader :
233- visual = read_video_pyav_base64 (visual , num_frm = self .max_num_frames , fps = self .fps , img_format = "JPEG" , max_image_size = self .max_image_size )
234- image_contents = list (map (lambda x : f"data:image/jpeg;base64,{ x } " , visual ))
235- message .append ({"role" : "user" , "content" : [{"type" : "video" , "video" : image_contents }, {"type" : "text" , "text" : context }]})
236- else :
237- vr = decord .VideoReader (visual )
238- first_frame = vr [0 ].asnumpy ()
239- height , width = first_frame .shape [:2 ]
240- # max_pixels = height * width
241- message .append ({"role" : "user" , "content" : [{"type" : "video" , "video" : visual , "max_pixels" : 360 * 420 }, {"type" : "text" , "text" : context }]})
242- elif isinstance (visual , Image .Image ): # Single image
235+ vr = decord .VideoReader (visual )
236+ first_frame = vr [0 ].asnumpy ()
237+ height , width = first_frame .shape [:2 ]
238+ # max_pixels = height * width
239+ processed_visuals .append ({"type" : "video" , "video" : visual , "max_pixels" : self .max_pixels , "min_pixels" : self .min_pixels })
240+ elif isinstance (visual , Image .Image ): # Handle both single and multiple images
243241 base64_image = visual .convert ("RGB" )
244242 buffer = BytesIO ()
245243 base64_image .save (buffer , format = "JPEG" )
246244 base64_bytes = base64 .b64encode (buffer .getvalue ())
247245 base64_string = base64_bytes .decode ("utf-8" )
248- message .append ({"role" : "user" , "content" : [{"type" : "image" , "image" : f"data:image/jpeg;base64,{ base64_string } " }, {"type" : "text" , "text" : context }]})
249- elif isinstance (visual , (list , tuple )) and all (isinstance (v , Image .Image ) for v in visual ): # Multiple images
250- image_content = []
251- for v in visual :
252- base64_image = v .convert ("RGB" )
253- buffer = BytesIO ()
254- base64_image .save (buffer , format = "JPEG" )
255- base64_bytes = base64 .b64encode (buffer .getvalue ())
256- base64_string = base64_bytes .decode ("utf-8" )
257- image_content .append ({"type" : "image" , "image" : f"data:image/jpeg;base64,{ base64_string } " })
258- message .append ({"role" : "user" , "content" : image_content + [{"type" : "text" , "text" : context }]})
259- else :
260- message .append ({"role" : "user" , "content" : [{"type" : "text" , "text" : context }]})
261- else :
262- message .append ({"role" : "user" , "content" : [{"type" : "text" , "text" : context }]})
263-
264- messages .append (message )
265- # print("message")
266-
267- text = self .processor .apply_chat_template (messages , tokenize = False , add_generation_prompt = True )
268- image_inputs , video_inputs = process_vision_info (messages )
269- inputs = self .processor (
270- text = text ,
271- images = image_inputs ,
272- videos = video_inputs ,
273- # fps=self.fps,
274- padding = True ,
275- return_tensors = "pt" ,
276- )
246+ processed_visuals .append ({"type" : "image" , "image" : f"data:image/jpeg;base64,{ base64_string } " , "max_pixels" : self .max_pixels , "min_pixels" : self .min_pixels })
247+
248+ if self .interleave_visuals is False :
249+ message .append (
250+ {
251+ "role" : "user" ,
252+ "content" : processed_visuals + [{"type" : "text" , "text" : context }],
253+ }
254+ )
255+ else : # currently support find <image x> in the context
256+ image_placeholders = re .findall (r"<image \d+>" , context )
257+ content_parts = []
258+ text_parts = re .split (r"<image \d+>" , context )
259+ if text_parts [0 ]:
260+ content_parts .append ({"type" : "text" , "text" : text_parts [0 ]})
261+
262+ for i , placeholder in enumerate (image_placeholders ):
263+ img_idx = int (re .search (r"<image (\d+)>" , placeholder ).group (1 )) - 1
264+ image_idx = min (img_idx , len (processed_visuals ) - 1 ) if processed_visuals else 0
265+ if processed_visuals and image_idx < len (processed_visuals ):
266+ content_parts .append (processed_visuals [image_idx ])
267+ if i + 1 < len (text_parts ) and text_parts [i + 1 ]:
268+ content_parts .append ({"type" : "text" , "text" : text_parts [i + 1 ]})
269+
270+ message .append (
271+ {
272+ "role" : "user" ,
273+ "content" : content_parts ,
274+ }
275+ )
276+
277+ batched_messages .append (message )
278+
279+ texts = [self .processor .apply_chat_template (msg , tokenize = False , add_generation_prompt = True ) for msg in batched_messages ]
280+ image_inputs , video_inputs = process_vision_info (batched_messages )
281+ if video_inputs is not None :
282+ total_frames = video_inputs [0 ].shape [0 ]
283+ indices = np .linspace (0 , total_frames - 1 , self .max_num_frames , dtype = int )
284+ # Append the last frame index if not already included
285+ if total_frames - 1 not in indices :
286+ indices = np .append (indices , total_frames - 1 )
287+ video_inputs [0 ] = video_inputs [0 ][indices ]
288+ inputs = self .processor (text = texts , images = image_inputs , videos = video_inputs , padding = True , return_tensors = "pt" )
277289
278290 if self .device_map == "auto" :
279291 inputs = inputs .to ("cuda" )
280292 else :
281293 inputs = inputs .to (self .device )
282294
283- if "max_new_tokens" not in gen_kwargs :
284- gen_kwargs ["max_new_tokens" ] = 4096
285- if "temperature" not in gen_kwargs :
286- gen_kwargs ["temperature" ] = 0
287- if "top_p" not in gen_kwargs :
288- gen_kwargs ["top_p" ] = None
289- if "num_beams" not in gen_kwargs :
290- gen_kwargs ["num_beams" ] = 1
295+ # Set default generation kwargs
296+ default_gen_kwargs = {
297+ "max_new_tokens" : 128 ,
298+ "temperature" : 0.0 , # Set to 0 for greedy default
299+ "top_p" : None ,
300+ "num_beams" : 1 ,
301+ }
302+ # Update with provided kwargs
303+ current_gen_kwargs = {** default_gen_kwargs , ** gen_kwargs }
291304
292305 pad_token_id = self .tokenizer .pad_token_id
293306
294307 cont = self .model .generate (
295308 ** inputs ,
296309 eos_token_id = self .tokenizer .eos_token_id ,
297310 pad_token_id = pad_token_id ,
298- do_sample = True if gen_kwargs ["temperature" ] > 0 else False ,
299- temperature = gen_kwargs ["temperature" ],
300- top_p = gen_kwargs ["top_p" ],
301- num_beams = gen_kwargs ["num_beams" ],
302- max_new_tokens = gen_kwargs ["max_new_tokens" ],
311+ do_sample = True if current_gen_kwargs ["temperature" ] > 0 else False ,
312+ temperature = current_gen_kwargs ["temperature" ],
313+ top_p = current_gen_kwargs ["top_p" ],
314+ num_beams = current_gen_kwargs ["num_beams" ],
315+ max_new_tokens = current_gen_kwargs ["max_new_tokens" ],
303316 use_cache = self .use_cache ,
304317 )
305318
306319 generated_ids_trimmed = [out_ids [len (in_ids ) :] for in_ids , out_ids in zip (inputs .input_ids , cont )]
307320 answers = self .processor .batch_decode (generated_ids_trimmed , skip_special_tokens = True , clean_up_tokenization_spaces = False )
308321 for i , ans in enumerate (answers ):
322+ for term in until :
323+ if len (term ) > 0 :
324+ ans = ans .split (term )[0 ]
309325 answers [i ] = ans
310326
311327 for ans , context in zip (answers , contexts ):
0 commit comments