@@ -252,40 +252,32 @@ def to_pytorch_with_input_ids(self, messages):
252252
253253 return dict (prompt = None , input_ids = input_ids , multimodal = preps )
254254
def to_pytorch_aux(self, messages, prompt, mm_placeholder, tokenizer, sequence_start):
    """Auxiliary function to pack the preprocessing results in a format
    compatible with what is required by pytorch engine.

    Args:
        messages (List[Dict]): the output of `preprocess`; exactly one
            entry with role 'preprocess' carries the list of multimodal
            preprocessing results
        prompt (str): the prompt after applying the chat template
        mm_placeholder (str): the placeholder string in `prompt` where
            multimodal tokens will be inserted
        tokenizer: the tokenizer model
        sequence_start: starting flag of a sequence

    Returns:
        Dict: `prompt`, the flattened `input_ids`, and the multimodal
        items under `multimodal`, each annotated in-place with the token
        `offset` of its first placeholder token inside `input_ids`
    """
    # collect the single multi-modal preprocessing result from messages
    mm_items = [x['content'] for x in messages if x['role'] == 'preprocess']
    assert len(mm_items) == 1
    mm_items = mm_items[0]

    # split prompt into segments and validate data.
    # BUGFIX: the counts in the error message were swapped — the number of
    # placeholders is len(prompt_segments) - 1, the number of items is
    # len(mm_items)
    prompt_segments = prompt.split(mm_placeholder)
    assert len(prompt_segments) == len(mm_items) + 1, (
        f'the number of {mm_placeholder} is not equal '
        f'to input multi modal items, {len(prompt_segments) - 1} vs {len(mm_items)}')

    # calculate the token offset for each multi modal item
    input_ids = []
    # take the last id in case the placeholder tokenizes into several tokens
    # (e.g. a leading special token); assumes an HF-style
    # `add_special_tokens` kwarg on `encode` — TODO confirm for this tokenizer
    mm_placeholder_id = tokenizer.encode(mm_placeholder, add_special_tokens=False)[-1]
    for i, seg in enumerate(prompt_segments):
        if 0 < i <= len(mm_items):
            # record where this item's tokens begin in input_ids, then pad
            # with the placeholder id so the engine can locate/replace them
            mm_items[i - 1].update(offset=len(input_ids))
            mm_token_num = mm_items[i - 1]['mm_token_num']
            input_ids.extend([mm_placeholder_id] * mm_token_num)
        # BOS is only prepended to the very first segment of a new sequence
        token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start))
        input_ids.extend(token_ids)

    return dict(prompt=prompt, input_ids=input_ids, multimodal=mm_items)
289281
290282 def to_turbomind_aux (self , messages , prompt , IMAGE_TOKEN , tokenizer , sequence_start ):
291283 """Auxiliary function to pack the forwarding results in a format
0 commit comments