hunyuan_video1_5 model/pipeline review #13582

@hlky

Description


Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules in .ai/review-rules.md and referenced rule files.

Local checks:

  • Top-level imports for the HunyuanVideo 1.5 model, standard pipelines, and modular pipeline classes passed.
  • tests/modular_pipelines/hunyuan_video1_5/test_modular_pipeline_hunyuan_video1_5.py ran with 11 passed, 3 skipped.
  • tests/models/transformers/test_models_transformer_hunyuan_1_5.py and tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py could not collect in this local .venv because the installed Torch build lacks torch._C._distributed_c10d, which is imported by shared training test utilities.

Duplicate search performed before filing:

  • Searched Issues and PRs for hunyuan_video1_5, HunyuanVideo15*, affected class/function names, num_videos_per_prompt, VAE batch attention mask, modular glyph regex, slow/image2video coverage, and doc snippet failures.
  • Known duplicates are called out in the relevant items below.

Issue 1: VAE attention mask breaks batch > 1

Affected code:

attention_mask = self.prepare_causal_attention_mask(
    frames, height * width, query.dtype, query.device, batch_size=batch_size
)
x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)

Problem:
HunyuanVideo15AttnBlock passes a (batch, seq, seq) mask to 4D SDPA queries shaped (batch, 1, seq, channels). For batch > 1, the 3D mask aligns against the (batch, 1, seq, seq) attention weights as (1, batch, seq, seq), so SDPA rejects the mask and raises.

Duplicate:
This is already covered by PR #13133.

Impact:
Batch generation, num_videos_per_prompt > 1, and the skipped modular batch tests fail at VAE encode/decode time.

Reproduction:

import torch
from diffusers import AutoencoderKLHunyuanVideo15

vae = AutoencoderKLHunyuanVideo15(
    in_channels=3,
    out_channels=3,
    latent_channels=4,
    block_out_channels=(16, 16),
    layers_per_block=1,
    spatial_compression_ratio=4,
    temporal_compression_ratio=2,
    downsample_match_channel=False,
    upsample_match_channel=False,
).eval()

with torch.no_grad():
    vae(torch.randn(2, 3, 9, 16, 16), return_dict=False)

Relevant precedent:
PR #13133 applies the same fix and cites the original implementation.

Suggested fix:

x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask.unsqueeze(1))
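
As a shape-level illustration of why the extra dimension matters, here is a minimal SDPA-only sketch (plain tensors, not the actual VAE block):

import torch
import torch.nn.functional as F

batch, seq, dim = 2, 4, 8
q = torch.randn(batch, 1, seq, dim)  # SDPA inputs in the block are (batch, heads=1, seq, channels)
k = torch.randn(batch, 1, seq, dim)
v = torch.randn(batch, 1, seq, dim)
mask = torch.ones(batch, seq, seq, dtype=torch.bool).tril()  # 3D mask, as the block builds it

# The 3D mask misaligns against the (batch, 1, seq, seq) attention weights for
# batch > 1; depending on the backend this raises or mis-broadcasts.
try:
    F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
except RuntimeError as exc:
    print(f"3D mask rejected: {exc}")

# unsqueeze(1) inserts the heads dimension, giving the expected (batch, 1, seq, seq) layout.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.unsqueeze(1))
print(out.shape)  # torch.Size([2, 1, 4, 8])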

Issue 2: T2V num_videos_per_prompt > 1 does not repeat zero image embeddings

Affected code:

image_embeds = torch.zeros(
    batch_size,
    self.vision_num_semantic_tokens,
    self.vision_states_dim,
    dtype=self.transformer.dtype,
    device=device,
)

block_state.image_embeds = torch.zeros(
    block_state.batch_size,
    components.vision_num_semantic_tokens,
    components.vision_states_dim,
    dtype=dtype,
    device=device,
)

Problem:
The T2V path expands latents and text embeddings to batch_size * num_videos_per_prompt, but zero image_embeds are allocated with only batch_size. The transformer then receives mismatched batch dimensions.

Duplicate:
No duplicate issue/PR found.

Impact:
num_videos_per_prompt > 1 crashes on the first denoising step. The modular test suite currently skips this exact path.

Reproduction:

import torch
from diffusers import HunyuanVideo15Transformer3DModel

model = HunyuanVideo15Transformer3DModel(
    in_channels=9, out_channels=4, num_attention_heads=2, attention_head_dim=8,
    num_layers=1, num_refiner_layers=1, mlp_ratio=2.0, patch_size=1, patch_size_t=1,
    text_embed_dim=16, text_embed_2_dim=32, image_embed_dim=12,
    rope_axes_dim=(2, 2, 4), target_size=16, task_type="t2v",
).eval()

with torch.no_grad():
    model(
        hidden_states=torch.randn(2, 9, 1, 2, 4),
        timestep=torch.ones(2),
        encoder_hidden_states=torch.randn(2, 6, 16),
        encoder_attention_mask=torch.ones(2, 6),
        encoder_hidden_states_2=torch.randn(2, 4, 32),
        encoder_attention_mask_2=torch.ones(2, 4),
        image_embeds=torch.zeros(1, 3, 12),
        return_dict=False,
    )

Relevant precedent:

image_embeds = self.encode_image(
    image=image,
    batch_size=batch_size * num_videos_per_prompt,
    device=device,
    dtype=self.transformer.dtype,
)

Suggested fix:

image_embeds = torch.zeros(
    batch_size * num_videos_per_prompt,
    self.vision_num_semantic_tokens,
    self.vision_states_dim,
    dtype=self.transformer.dtype,
    device=device,
)

For the modular path, use the effective batch size when constructing zero-filled image_embeds in HunyuanVideo15PrepareLatentsStep.
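
A sketch of that change, assuming num_videos_per_prompt is available on the block state in that step:

block_state.image_embeds = torch.zeros(
    block_state.batch_size * block_state.num_videos_per_prompt,
    components.vision_num_semantic_tokens,
    components.vision_states_dim,
    dtype=dtype,
    device=device,
)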

Issue 3: Modular glyph extraction drops curly-quoted text

Affected code:

def extract_glyph_texts(prompt):
    pattern = r"\"(.*?)\"|\"(.*?)\""
    matches = re.findall(pattern, prompt)

Problem:
The modular regex repeats the straight-quote alternative twice and omits the curly-quote alternative used by the standard pipeline, so curly-quoted text never matches.

Duplicate:
This is already covered by PR #13523.

Impact:
Prompts using “...” lose glyph text conditioning in the modular pipeline, so text rendering behavior diverges from the standard pipeline.

Reproduction:

from diffusers.modular_pipelines.hunyuan_video1_5.encoders import extract_glyph_texts as modular_extract
from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5 import extract_glyph_texts as standard_extract

prompt = 'A sign says “HELLO”.'
assert standard_extract(prompt) == 'Text "HELLO". '
assert modular_extract(prompt) is None

Relevant precedent:

pattern = r"\"(.*?)\"|“(.*?)”"
matches = re.findall(pattern, prompt)
result = [match[0] or match[1] for match in matches]
result = list(dict.fromkeys(result)) if len(result) > 1 else result
if result:
    formatted_result = ". ".join([f'Text "{text}"' for text in result]) + ". "
else:
    formatted_result = None
return formatted_result

Suggested fix:

pattern = r"\"(.*?)\"|“(.*?)”"
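
A quick sanity check of the corrected pattern against mixed quoting styles:

import re

pattern = r"\"(.*?)\"|“(.*?)”"
matches = re.findall(pattern, 'A sign says “HELLO” next to a sign that says "WORLD".')
print([m[0] or m[1] for m in matches])  # ['HELLO', 'WORLD']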

Issue 4: Modular blocks bypass declared IO contracts

Affected code:

@property
def intermediate_outputs(self) -> list[OutputParam]:
    return [
        OutputParam.template("prompt_embeds"),
        OutputParam.template("prompt_embeds_mask"),
        OutputParam.template("negative_prompt_embeds"),
        OutputParam.template("negative_prompt_embeds_mask"),
        OutputParam(
            "prompt_embeds_2",
            type_hint=torch.Tensor,
            kwargs_type="denoiser_input_fields",
            description="ByT5 glyph-text embeddings used as a second conditioning stream for the transformer.",
        ),
        OutputParam(
            "prompt_embeds_mask_2",
            type_hint=torch.Tensor,
            kwargs_type="denoiser_input_fields",
            description="Attention mask for the ByT5 glyph-text embeddings.",
        ),
        OutputParam(
            "negative_prompt_embeds_2",
            type_hint=torch.Tensor,
            kwargs_type="denoiser_input_fields",
            description="ByT5 glyph-text negative embeddings for classifier-free guidance.",
        ),
        OutputParam(
            "negative_prompt_embeds_mask_2",
            type_hint=torch.Tensor,
            kwargs_type="denoiser_input_fields",
            description="Attention mask for the ByT5 glyph-text negative embeddings.",
        ),
    ]

state.set("batch_size", batch_size)
self.set_block_state(state, block_state)
return components, state


def retrieve_latents(
    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


class HunyuanVideo15VaeEncoderStep(ModularPipelineBlocks):
    model_name = "hunyuan-video-1.5"

    @property
    def description(self) -> str:
        return "VAE Encoder step that encodes an input image into latent space for image-to-video generation"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLHunyuanVideo15),
            ComponentSpec(
                "video_processor",
                HunyuanVideo15ImageProcessor,
                config=FrozenDict({"vae_scale_factor": 16}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("image", required=True),
            InputParam.template("height"),
            InputParam.template("width"),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="Encoded image latents from the VAE encoder",
            ),
            OutputParam("height", type_hint=int, description="Target height resolved from image"),
            OutputParam("width", type_hint=int, description="Target width resolved from image"),
        ]

    @torch.no_grad()
    def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        image = block_state.image
        height = block_state.height
        width = block_state.width
        if height is None or width is None:
            height, width = components.video_processor.calculate_default_height_width(
                height=image.size[1], width=image.size[0], target_size=components.target_size
            )
        image = components.video_processor.resize(image, height=height, width=width, resize_mode="crop")
        vae_dtype = components.vae.dtype
        image_tensor = components.video_processor.preprocess(image, height=height, width=width).to(
            device=device, dtype=vae_dtype
        )
        image_tensor = image_tensor.unsqueeze(2)
        image_latents = retrieve_latents(components.vae.encode(image_tensor), sample_mode="argmax")
        image_latents = image_latents * components.vae.config.scaling_factor
        block_state.image_latents = image_latents
        block_state.height = height
        block_state.width = width
        state.set("image", image)

Problem:
HunyuanVideo15TextEncoderStep writes batch_size with state.set() without declaring it as an output, and HunyuanVideo15VaeEncoderStep writes the resized image back via state.set("image", image), also undeclared. The modular rules require writing through declared outputs via block state.

Duplicate:
No duplicate issue/PR found.

Impact:
The generated block interface/docs are incomplete, standalone block reuse is harder, and downstream dependencies can be hidden from modular validation.

Reproduction:

import inspect
from diffusers.modular_pipelines.hunyuan_video1_5.encoders import (
    HunyuanVideo15TextEncoderStep,
    HunyuanVideo15VaeEncoderStep,
)

text_step = HunyuanVideo15TextEncoderStep()
vae_step = HunyuanVideo15VaeEncoderStep()

assert "batch_size" not in [p.name for p in text_step.intermediate_outputs]
assert 'state.set("batch_size"' in inspect.getsource(text_step.__call__)
assert "image" not in [p.name for p in vae_step.intermediate_outputs]
assert 'state.set("image"' in inspect.getsource(vae_step.__call__)

Relevant precedent:
The modular review rule says not to call state.set() inside a block and to declare every written output.

Suggested fix:
Declare OutputParam("batch_size", type_hint=int) and set block_state.batch_size = batch_size. Avoid mutating image, or declare it as an output if the resized/cropped image is part of the public block contract.
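
A sketch of the contract-conforming text-encoder step (existing outputs elided; the description string is illustrative):

@property
def intermediate_outputs(self) -> list[OutputParam]:
    return [
        # ... existing prompt-embedding outputs ...
        OutputParam("batch_size", type_hint=int, description="Number of prompts in the batch"),
    ]

# in __call__, replacing state.set("batch_size", batch_size):
block_state.batch_size = batch_size
self.set_block_state(state, block_state)
return components, state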

Issue 5: Missing VAE/I2V/slow coverage and skipped batch coverage

Affected code:

class HunyuanVideo15PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = HunyuanVideo15Pipeline
    params = frozenset(
        [
            "prompt",
            "negative_prompt",
            "height",
            "width",
            "prompt_embeds",
            "prompt_embeds_mask",
            "negative_prompt_embeds",
            "negative_prompt_embeds_mask",
            "prompt_embeds_2",
            "prompt_embeds_mask_2",
            "negative_prompt_embeds_2",
            "negative_prompt_embeds_mask_2",
        ]
    )
    batch_params = ["prompt", "negative_prompt"]
    required_optional_params = frozenset(["num_inference_steps", "generator", "latents", "return_dict"])
    test_attention_slicing = False
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = False
    supports_dduf = False

    def get_dummy_components(self, num_layers: int = 1):
        torch.manual_seed(0)
        transformer = HunyuanVideo15Transformer3DModel(
            in_channels=9,
            out_channels=4,
            num_attention_heads=2,
            attention_head_dim=8,
            num_layers=num_layers,
            num_refiner_layers=1,
            mlp_ratio=2.0,
            patch_size=1,
            patch_size_t=1,
            text_embed_dim=16,
            text_embed_2_dim=32,
            image_embed_dim=12,
            rope_axes_dim=(2, 2, 4),
            target_size=16,
            task_type="t2v",
        )
        torch.manual_seed(0)
        vae = AutoencoderKLHunyuanVideo15(
            in_channels=3,
            out_channels=3,
            latent_channels=4,
            block_out_channels=(16, 16),
            layers_per_block=1,
            spatial_compression_ratio=4,
            temporal_compression_ratio=2,
            downsample_match_channel=False,
            upsample_match_channel=False,
        )
        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
        torch.manual_seed(0)
        qwen_config = Qwen2_5_VLTextConfig(
            **{
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_hidden_layers": 2,
                "num_attention_heads": 2,
                "num_key_value_heads": 2,
                "rope_scaling": {
                    "mrope_section": [1, 1, 2],
                    "rope_type": "default",
                    "type": "default",
                },
                "rope_theta": 1000000.0,
            }
        )
        text_encoder = Qwen2_5_VLTextModel(qwen_config)
        tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
        torch.manual_seed(0)
        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
        text_encoder_2 = T5EncoderModel(config)
        tokenizer_2 = ByT5Tokenizer()
        guider = ClassifierFreeGuidance(guidance_scale=1.0)
        components = {
            "transformer": transformer.eval(),
            "vae": vae.eval(),
            "scheduler": scheduler,
            "text_encoder": text_encoder.eval(),
            "text_encoder_2": text_encoder_2.eval(),
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
            "guider": guider,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "monkey",
            "generator": generator,
            "num_inference_steps": 2,
            "height": 16,
            "width": 16,
            "num_frames": 9,
            "output_type": "pt",
        }
        return inputs

    def test_inference(self):
        device = "cpu"
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)
        inputs = self.get_dummy_inputs(device)
        result = pipe(**inputs)
        video = result.frames
        generated_video = video[0]
        self.assertEqual(generated_video.shape, (9, 3, 16, 16))
        generated_slice = generated_video.flatten()
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
        # fmt: off
        expected_slice = torch.tensor([0.4296, 0.5549, 0.3088, 0.9115, 0.5049, 0.7926, 0.5549, 0.8618, 0.5091, 0.5075, 0.7117, 0.5292, 0.7053, 0.4864, 0.5206, 0.3878])
        # fmt: on
        self.assertTrue(
            torch.abs(generated_slice - expected_slice).max() < 1e-3,
            f"output_slice: {generated_slice}, expected_slice: {expected_slice}",
        )

    @unittest.skip("TODO: Test not supported for now because needs to be adjusted to work with guiders.")
    def test_encode_prompt_works_in_isolation(self):
        pass

    @unittest.skip("Needs to be revisited.")
    def test_inference_batch_consistent(self):
        super().test_inference_batch_consistent()

    @unittest.skip("Needs to be revisited.")
    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical()

pretrained_model_name_or_path = "akshan-main/tiny-hunyuanvideo1_5-modular-pipe"
params = frozenset(["prompt", "height", "width", "num_frames"])
batch_params = frozenset(["prompt"])
optional_params = frozenset(["num_inference_steps", "num_videos_per_prompt", "latents"])
expected_workflow_blocks = HUNYUANVIDEO15_WORKFLOWS
output_name = "videos"

def get_dummy_inputs(self, seed=0):
    generator = self.get_generator(seed)
    inputs = {
        "prompt": "A painting of a squirrel eating a burger",
        "generator": generator,
        "num_inference_steps": 2,
        "height": 32,
        "width": 32,
        "num_frames": 9,
        "output_type": "pt",
    }
    return inputs

@pytest.mark.skip(reason="num_videos_per_prompt")
def test_num_images_per_prompt(self):
    pass

@pytest.mark.skip(reason="VAE causal attention mask does not support batch>1 decode")
def test_inference_batch_consistent(self):
    pass

@pytest.mark.skip(reason="VAE causal attention mask does not support batch>1 decode")
def test_inference_batch_single_identical(self):

from diffusers import AutoencoderKLHunyuanVideo
from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask

from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin

enable_full_determinism()


class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
    model_class = AutoencoderKLHunyuanVideo

Problem:
There is no fast test class for HunyuanVideo15ImageToVideoPipeline, no AutoencoderKLHunyuanVideo15 model test, and no slow tests for the HunyuanVideo 1.5 family. The modular test uses a contributor repo (akshan-main/...) instead of hf-internal-testing/... and skips num_videos_per_prompt plus batch consistency.

Duplicate:
No duplicate issue/PR found.

Impact:
The failures above are not caught in CI, and slow-test coverage is absent.

Reproduction:

from pathlib import Path

checks = {
    "i2v_pipeline_fast": "HunyuanVideo15ImageToVideoPipeline" in Path("tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py").read_text(),
    "vae15_model_fast": "AutoencoderKLHunyuanVideo15Tests" in Path("tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py").read_text(),
    "target_slow": "@slow" in Path("tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py").read_text(),
    "internal_tiny_model": "hf-internal-testing/" in Path("tests/modular_pipelines/hunyuan_video1_5/test_modular_pipeline_hunyuan_video1_5.py").read_text(),
}
print(checks)
assert checks == {key: True for key in checks}

Relevant precedent:
Existing video families such as tests/pipelines/hunyuan_video/test_hunyuan_image2video.py keep I2V coverage separate from T2V coverage.

Suggested fix:
Add fast tests for HunyuanVideo15ImageToVideoPipeline and AutoencoderKLHunyuanVideo15, add slow tests for published T2V/I2V checkpoints, move the modular tiny fixture to hf-internal-testing/, and unskip batch/num_videos_per_prompt tests after Issues 1 and 2 are fixed.
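
A possible shape for the missing I2V fast-test class, mirroring the existing T2V class (the params/batch_params values here are assumptions to be checked against the pipeline signature):

class HunyuanVideo15ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = HunyuanVideo15ImageToVideoPipeline
    params = frozenset(["prompt", "image", "height", "width"])  # assumed; align with the pipeline signature
    batch_params = frozenset(["prompt", "image"])

    def get_dummy_components(self, num_layers: int = 1):
        # Reuse HunyuanVideo15PipelineFastTests.get_dummy_components, but build the
        # transformer with task_type="i2v" so the image-conditioning path is exercised.
        ...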

Issue 6: HunyuanVideo15 docs contain broken loading snippets

Affected code:

```python
from diffusers import HunyuanVideo15Transformer3DModel
transformer = HunyuanVideo15Transformer3DModel.from_pretrained("hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v" subfolder="transformer", torch_dtype=torch.bfloat16)
```

```py
import torch
from diffusers import AutoModel, HunyuanVideo15Pipeline
from diffusers.utils import export_to_video

pipeline = HunyuanVideo15Pipeline.from_pretrained(
    "HunyuanVideo-1.5-Diffusers-480p_t2v",
    torch_dtype=torch.bfloat16,
```

Problem:
The transformer snippet is a Python syntax error: it is missing the comma before subfolder. The pipeline snippet also imports AutoModel without using it and passes an unqualified model id, even though the surrounding text says the examples use hunyuanvideo-community.

Duplicate:
No duplicate issue/PR found.

Impact:
Users copying the docs hit immediate syntax or loading errors.

Reproduction:

import ast

snippet = '''
from diffusers import HunyuanVideo15Transformer3DModel

transformer = HunyuanVideo15Transformer3DModel.from_pretrained("hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v" subfolder="transformer", torch_dtype=torch.bfloat16)
'''
ast.parse(snippet)

Relevant precedent:

```python
from diffusers import AutoencoderKLHunyuanVideo15

vae = AutoencoderKLHunyuanVideo15.from_pretrained("hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v", subfolder="vae", torch_dtype=torch.float32)
# make sure to enable tiling to avoid OOM
vae.enable_tiling()
```

Suggested fix:

import torch
from diffusers import HunyuanVideo15Transformer3DModel

transformer = HunyuanVideo15Transformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
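
The pipeline snippet needs the matching fixes: drop the unused AutoModel import and qualify the model id with the hunyuanvideo-community namespace the surrounding text references.

import torch
from diffusers import HunyuanVideo15Pipeline
from diffusers.utils import export_to_video

pipeline = HunyuanVideo15Pipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v",
    torch_dtype=torch.bfloat16,
)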
