Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@
"image_emb_start": 5,
"image_emb_end": 581,
"image_emb_len": 576,
"double_return_token_id": 271,
}


Expand Down Expand Up @@ -299,7 +298,6 @@ def _get_llama_prompt_embeds(
image_emb_len = prompt_template.get("image_emb_len", 576)
image_emb_start = prompt_template.get("image_emb_start", 5)
image_emb_end = prompt_template.get("image_emb_end", 581)
double_return_token_id = prompt_template.get("double_return_token_id", 271)

if crop_start is None:
prompt_template_input = self.tokenizer(
Expand Down Expand Up @@ -351,23 +349,30 @@ def _get_llama_prompt_embeds(

if crop_start is not None and crop_start > 0:
text_crop_start = crop_start - 1 + image_emb_len
batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)

if last_double_return_token_indices.shape[0] == 3:
# Find assistant section marker using <|end_header_id|> token (works across all transformers versions)
end_header_token_id = self.tokenizer.convert_tokens_to_ids("<|end_header_id|>")
batch_indices, end_header_indices = torch.where(text_input_ids == end_header_token_id)

# Expected: 3 <|end_header_id|> per prompt (system, user, assistant)
# If truncated (only 2 found for batch_size=1), add text length as fallback position
if end_header_indices.shape[0] == 2:
# in case the prompt is too long
last_double_return_token_indices = torch.cat(
(last_double_return_token_indices, torch.tensor([text_input_ids.shape[-1]]))
end_header_indices = torch.cat(
(
end_header_indices,
torch.tensor([text_input_ids.shape[-1] - 1], device=end_header_indices.device),
)
)
batch_indices = torch.cat((batch_indices, torch.tensor([0])))
batch_indices = torch.cat((batch_indices, torch.tensor([0], device=batch_indices.device)))

last_double_return_token_indices = last_double_return_token_indices.reshape(text_input_ids.shape[0], -1)[
:, -1
]
# Get the last <|end_header_id|> position per batch, then +1 to get the position after it
assistant_start_indices = end_header_indices.reshape(text_input_ids.shape[0], -1)[:, -1] + 1
batch_indices = batch_indices.reshape(text_input_ids.shape[0], -1)[:, -1]
assistant_crop_start = last_double_return_token_indices - 1 + image_emb_len - 4
assistant_crop_end = last_double_return_token_indices - 1 + image_emb_len
attention_mask_assistant_crop_start = last_double_return_token_indices - 4
attention_mask_assistant_crop_end = last_double_return_token_indices
assistant_crop_start = assistant_start_indices - 1 + image_emb_len - 4
assistant_crop_end = assistant_start_indices - 1 + image_emb_len
attention_mask_assistant_crop_start = assistant_start_indices - 4
attention_mask_assistant_crop_end = assistant_start_indices

prompt_embed_list = []
prompt_attention_mask_list = []
Expand Down
1 change: 0 additions & 1 deletion tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ def get_dummy_inputs(self, device, seed=0):
"image_emb_len": 49,
"image_emb_start": 5,
"image_emb_end": 54,
"double_return_token_id": 0,
},
"generator": generator,
"num_inference_steps": 2,
Expand Down
Loading