diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
index c599488c2379..c7d43424c344 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
@@ -96,7 +96,6 @@
     "image_emb_start": 5,
     "image_emb_end": 581,
     "image_emb_len": 576,
-    "double_return_token_id": 271,
 }
 
 
@@ -299,7 +298,6 @@ def _get_llama_prompt_embeds(
         image_emb_len = prompt_template.get("image_emb_len", 576)
         image_emb_start = prompt_template.get("image_emb_start", 5)
         image_emb_end = prompt_template.get("image_emb_end", 581)
-        double_return_token_id = prompt_template.get("double_return_token_id", 271)
 
         if crop_start is None:
             prompt_template_input = self.tokenizer(
@@ -351,23 +349,30 @@
 
         if crop_start is not None and crop_start > 0:
             text_crop_start = crop_start - 1 + image_emb_len
-            batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)
-            if last_double_return_token_indices.shape[0] == 3:
+            # Find assistant section marker using <|end_header_id|> token (works across all transformers versions)
+            end_header_token_id = self.tokenizer.convert_tokens_to_ids("<|end_header_id|>")
+            batch_indices, end_header_indices = torch.where(text_input_ids == end_header_token_id)
+
+            # Expected: 3 <|end_header_id|> per prompt (system, user, assistant)
+            # If truncated (only 2 found for batch_size=1), add text length as fallback position
+            if end_header_indices.shape[0] == 2:
                 # in case the prompt is too long
-                last_double_return_token_indices = torch.cat(
-                    (last_double_return_token_indices, torch.tensor([text_input_ids.shape[-1]]))
+                end_header_indices = torch.cat(
+                    (
+                        end_header_indices,
+                        torch.tensor([text_input_ids.shape[-1] - 1], device=end_header_indices.device),
+                    )
                 )
-                batch_indices = torch.cat((batch_indices, torch.tensor([0])))
+                batch_indices = torch.cat((batch_indices, torch.tensor([0], device=batch_indices.device)))
 
-            last_double_return_token_indices = last_double_return_token_indices.reshape(text_input_ids.shape[0], -1)[
-                :, -1
-            ]
+            # Get the last <|end_header_id|> position per batch, then +1 to get the position after it
+            assistant_start_indices = end_header_indices.reshape(text_input_ids.shape[0], -1)[:, -1] + 1
             batch_indices = batch_indices.reshape(text_input_ids.shape[0], -1)[:, -1]
-            assistant_crop_start = last_double_return_token_indices - 1 + image_emb_len - 4
-            assistant_crop_end = last_double_return_token_indices - 1 + image_emb_len
-            attention_mask_assistant_crop_start = last_double_return_token_indices - 4
-            attention_mask_assistant_crop_end = last_double_return_token_indices
+            assistant_crop_start = assistant_start_indices - 1 + image_emb_len - 4
+            assistant_crop_end = assistant_start_indices - 1 + image_emb_len
+            attention_mask_assistant_crop_start = assistant_start_indices - 4
+            attention_mask_assistant_crop_end = assistant_start_indices
 
             prompt_embed_list = []
             prompt_attention_mask_list = []
 
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
index 1732ac06d1f1..4a0129e0826f 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
@@ -207,7 +207,6 @@ def get_dummy_inputs(self, device, seed=0):
                 "image_emb_len": 49,
                 "image_emb_start": 5,
                 "image_emb_end": 54,
-                "double_return_token_id": 0,
             },
             "generator": generator,
             "num_inference_steps": 2,