Skip to content

Commit 10ba0be

Browse files
kaixuanliu and DN6
authored
Fix IndexError in HunyuanVideo I2V pipeline (#13244)
* add fallback logic for Hunyuan pipeline to make it compatible with latest transformers
* use the last <|end_header_id|> token position + 1 as the assistant section marker
* fix format
* update variant name

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent b8ec64c commit 10ba0be

File tree

2 files changed

+19
-15
lines changed

2 files changed

+19
-15
lines changed

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@
9696
"image_emb_start": 5,
9797
"image_emb_end": 581,
9898
"image_emb_len": 576,
99-
"double_return_token_id": 271,
10099
}
101100

102101

@@ -299,7 +298,6 @@ def _get_llama_prompt_embeds(
299298
image_emb_len = prompt_template.get("image_emb_len", 576)
300299
image_emb_start = prompt_template.get("image_emb_start", 5)
301300
image_emb_end = prompt_template.get("image_emb_end", 581)
302-
double_return_token_id = prompt_template.get("double_return_token_id", 271)
303301

304302
if crop_start is None:
305303
prompt_template_input = self.tokenizer(
@@ -351,23 +349,30 @@ def _get_llama_prompt_embeds(
351349

352350
if crop_start is not None and crop_start > 0:
353351
text_crop_start = crop_start - 1 + image_emb_len
354-
batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)
355352

356-
if last_double_return_token_indices.shape[0] == 3:
353+
# Find assistant section marker using <|end_header_id|> token (works across all transformers versions)
354+
end_header_token_id = self.tokenizer.convert_tokens_to_ids("<|end_header_id|>")
355+
batch_indices, end_header_indices = torch.where(text_input_ids == end_header_token_id)
356+
357+
# Expected: 3 <|end_header_id|> per prompt (system, user, assistant)
358+
# If truncated (only 2 found for batch_size=1), add text length as fallback position
359+
if end_header_indices.shape[0] == 2:
357360
# in case the prompt is too long
358-
last_double_return_token_indices = torch.cat(
359-
(last_double_return_token_indices, torch.tensor([text_input_ids.shape[-1]]))
361+
end_header_indices = torch.cat(
362+
(
363+
end_header_indices,
364+
torch.tensor([text_input_ids.shape[-1] - 1], device=end_header_indices.device),
365+
)
360366
)
361-
batch_indices = torch.cat((batch_indices, torch.tensor([0])))
367+
batch_indices = torch.cat((batch_indices, torch.tensor([0], device=batch_indices.device)))
362368

363-
last_double_return_token_indices = last_double_return_token_indices.reshape(text_input_ids.shape[0], -1)[
364-
:, -1
365-
]
369+
# Get the last <|end_header_id|> position per batch, then +1 to get the position after it
370+
assistant_start_indices = end_header_indices.reshape(text_input_ids.shape[0], -1)[:, -1] + 1
366371
batch_indices = batch_indices.reshape(text_input_ids.shape[0], -1)[:, -1]
367-
assistant_crop_start = last_double_return_token_indices - 1 + image_emb_len - 4
368-
assistant_crop_end = last_double_return_token_indices - 1 + image_emb_len
369-
attention_mask_assistant_crop_start = last_double_return_token_indices - 4
370-
attention_mask_assistant_crop_end = last_double_return_token_indices
372+
assistant_crop_start = assistant_start_indices - 1 + image_emb_len - 4
373+
assistant_crop_end = assistant_start_indices - 1 + image_emb_len
374+
attention_mask_assistant_crop_start = assistant_start_indices - 4
375+
attention_mask_assistant_crop_end = assistant_start_indices
371376

372377
prompt_embed_list = []
373378
prompt_attention_mask_list = []

tests/pipelines/hunyuan_video/test_hunyuan_image2video.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,6 @@ def get_dummy_inputs(self, device, seed=0):
207207
"image_emb_len": 49,
208208
"image_emb_start": 5,
209209
"image_emb_end": 54,
210-
"double_return_token_id": 0,
211210
},
212211
"generator": generator,
213212
"num_inference_steps": 2,

0 commit comments

Comments (0)