|
96 | 96 | "image_emb_start": 5, |
97 | 97 | "image_emb_end": 581, |
98 | 98 | "image_emb_len": 576, |
99 | | - "double_return_token_id": 271, |
100 | 99 | } |
101 | 100 |
|
102 | 101 |
|
@@ -299,7 +298,6 @@ def _get_llama_prompt_embeds( |
299 | 298 | image_emb_len = prompt_template.get("image_emb_len", 576) |
300 | 299 | image_emb_start = prompt_template.get("image_emb_start", 5) |
301 | 300 | image_emb_end = prompt_template.get("image_emb_end", 581) |
302 | | - double_return_token_id = prompt_template.get("double_return_token_id", 271) |
303 | 301 |
|
304 | 302 | if crop_start is None: |
305 | 303 | prompt_template_input = self.tokenizer( |
@@ -351,23 +349,30 @@ def _get_llama_prompt_embeds( |
351 | 349 |
|
352 | 350 | if crop_start is not None and crop_start > 0: |
353 | 351 | text_crop_start = crop_start - 1 + image_emb_len |
354 | | - batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id) |
355 | 352 |
|
356 | | - if last_double_return_token_indices.shape[0] == 3: |
| 353 | + # Find assistant section marker using <|end_header_id|> token (works across all transformers versions) |
| 354 | + end_header_token_id = self.tokenizer.convert_tokens_to_ids("<|end_header_id|>") |
| 355 | + batch_indices, end_header_indices = torch.where(text_input_ids == end_header_token_id) |
| 356 | + |
| 357 | + # Expected: 3 <|end_header_id|> per prompt (system, user, assistant) |
| 358 | + # If truncated (only 2 found for batch_size=1), append the final token index (seq_len - 1) as a fallback |
| 359 | + if end_header_indices.shape[0] == 2: |
357 | 360 | # in case the prompt is too long |
358 | | - last_double_return_token_indices = torch.cat( |
359 | | - (last_double_return_token_indices, torch.tensor([text_input_ids.shape[-1]])) |
| 361 | + end_header_indices = torch.cat( |
| 362 | + ( |
| 363 | + end_header_indices, |
| 364 | + torch.tensor([text_input_ids.shape[-1] - 1], device=end_header_indices.device), |
| 365 | + ) |
360 | 366 | ) |
361 | | - batch_indices = torch.cat((batch_indices, torch.tensor([0]))) |
| 367 | + batch_indices = torch.cat((batch_indices, torch.tensor([0], device=batch_indices.device))) |
362 | 368 |
|
363 | | - last_double_return_token_indices = last_double_return_token_indices.reshape(text_input_ids.shape[0], -1)[ |
364 | | - :, -1 |
365 | | - ] |
| 369 | + # Get the last <|end_header_id|> position per batch, then +1 to get the position after it |
| 370 | + assistant_start_indices = end_header_indices.reshape(text_input_ids.shape[0], -1)[:, -1] + 1 |
366 | 371 | batch_indices = batch_indices.reshape(text_input_ids.shape[0], -1)[:, -1] |
367 | | - assistant_crop_start = last_double_return_token_indices - 1 + image_emb_len - 4 |
368 | | - assistant_crop_end = last_double_return_token_indices - 1 + image_emb_len |
369 | | - attention_mask_assistant_crop_start = last_double_return_token_indices - 4 |
370 | | - attention_mask_assistant_crop_end = last_double_return_token_indices |
| 372 | + assistant_crop_start = assistant_start_indices - 1 + image_emb_len - 4 |
| 373 | + assistant_crop_end = assistant_start_indices - 1 + image_emb_len |
| 374 | + attention_mask_assistant_crop_start = assistant_start_indices - 4 |
| 375 | + attention_mask_assistant_crop_end = assistant_start_indices |
371 | 376 |
|
372 | 377 | prompt_embed_list = [] |
373 | 378 | prompt_attention_mask_list = [] |
|
0 commit comments