
Commit 6ee2933

Address LTX review feedback: add AutoBlocks, refactor I2V latents, lift encoders

1 parent 330c5f6 commit 6ee2933

9 files changed (377 additions, 197 deletions)


src/diffusers/modular_pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -89,6 +89,7 @@
     "QwenImageLayeredAutoBlocks",
 ]
 _import_structure["hunyuan_video1_5"] = [
+    "HunyuanVideo15AutoBlocks",
     "HunyuanVideo15Blocks",
     "HunyuanVideo15Image2VideoBlocks",
     "HunyuanVideo15ModularPipeline",
@@ -125,6 +126,7 @@
         HeliosPyramidModularPipeline,
     )
     from .hunyuan_video1_5 import (
+        HunyuanVideo15AutoBlocks,
         HunyuanVideo15Blocks,
         HunyuanVideo15Image2VideoBlocks,
         HunyuanVideo15ModularPipeline,

src/diffusers/modular_pipelines/hunyuan_video1_5/__init__.py

Lines changed: 10 additions & 2 deletions
@@ -21,7 +21,11 @@

     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
-    _import_structure["modular_blocks_hunyuan_video1_5"] = ["HunyuanVideo15Blocks", "HunyuanVideo15Image2VideoBlocks"]
+    _import_structure["modular_blocks_hunyuan_video1_5"] = [
+        "HunyuanVideo15AutoBlocks",
+        "HunyuanVideo15Blocks",
+        "HunyuanVideo15Image2VideoBlocks",
+    ]
     _import_structure["modular_pipeline"] = ["HunyuanVideo15ModularPipeline"]

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -31,7 +35,11 @@
     except OptionalDependencyNotAvailable:
         from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
-        from .modular_blocks_hunyuan_video1_5 import HunyuanVideo15Blocks, HunyuanVideo15Image2VideoBlocks
+        from .modular_blocks_hunyuan_video1_5 import (
+            HunyuanVideo15AutoBlocks,
+            HunyuanVideo15Blocks,
+            HunyuanVideo15Image2VideoBlocks,
+        )
     from .modular_pipeline import HunyuanVideo15ModularPipeline
 else:
     import sys
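Taken together, the two `__init__.py` changes expose the new preset at the package root. A minimal usage sketch, assuming the standard modular-pipelines flow (`init_pipeline` plus `load_components`); the repo id below is a placeholder, not something named in this commit:

import torch
from diffusers.modular_pipelines import HunyuanVideo15AutoBlocks

blocks = HunyuanVideo15AutoBlocks()                # auto blocks pick the T2V or I2V path from the inputs given
pipe = blocks.init_pipeline("<modular-repo-id>")   # placeholder repo id
pipe.load_components(torch_dtype=torch.bfloat16)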

src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py

Lines changed: 62 additions & 106 deletions
@@ -17,7 +17,6 @@
 import numpy as np
 import torch

-from ...configuration_utils import FrozenDict
 from ...models import HunyuanVideo15Transformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import logging
@@ -33,23 +32,53 @@
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
-    num_inference_steps=None,
-    device=None,
-    timesteps=None,
-    sigmas=None,
+    num_inference_steps: int | None = None,
+    device: str | torch.device | None = None,
+    timesteps: list[int] | None = None,
+    sigmas: list[float] | None = None,
     **kwargs,
 ):
+    r"""
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`list[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`list[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
     if timesteps is not None and sigmas is not None:
-        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
     if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
         timesteps = scheduler.timesteps
         num_inference_steps = len(timesteps)
     elif sigmas is not None:
         accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
         if not accept_sigmas:
             raise ValueError(
-                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas."
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
             )
         scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
         timesteps = scheduler.timesteps
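The new docstring pins down the calling convention, so a short usage sketch of `retrieve_timesteps` with a custom sigma schedule may help; the scheduler instance and schedule values here are illustrative only:

import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()
# A custom sigma schedule overrides the scheduler's default timestep spacing.
sigmas = np.linspace(1.0, 1 / 50, 50).tolist()
timesteps, num_inference_steps = retrieve_timesteps(scheduler, device="cpu", sigmas=sigmas)
assert num_inference_steps == len(timesteps)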
@@ -65,13 +94,7 @@ class HunyuanVideo15TextInputStep(ModularPipelineBlocks):

     @property
     def description(self) -> str:
-        return "Input processing step that determines batch_size and dtype"
-
-    @property
-    def expected_components(self) -> list[ComponentSpec]:
-        return [
-            ComponentSpec("transformer", HunyuanVideo15Transformer3DModel),
-        ]
+        return "Input processing step that determines batch_size"

     @property
     def inputs(self) -> list[InputParam]:
@@ -85,14 +108,12 @@ def inputs(self) -> list[InputParam]:
     def intermediate_outputs(self) -> list[OutputParam]:
         return [
             OutputParam("batch_size", type_hint=int),
-            OutputParam("dtype", type_hint=torch.dtype),
         ]

     @torch.no_grad()
     def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
         block_state.batch_size = getattr(block_state, "batch_size", None) or block_state.prompt_embeds.shape[0]
-        block_state.dtype = components.transformer.dtype
         self.set_block_state(state, block_state)
         return components, state

@@ -122,7 +143,6 @@ def intermediate_outputs(self) -> list[OutputParam]:
             OutputParam("num_inference_steps", type_hint=int),
         ]

-    # Copied from pipeline_hunyuan_video1_5.py line 702-704
     @torch.no_grad()
     def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
@@ -147,6 +167,10 @@ class HunyuanVideo15PrepareLatentsStep(ModularPipelineBlocks):
     def description(self) -> str:
         return "Prepare latents, conditioning latents, mask, and image_embeds for T2V"

+    @property
+    def expected_components(self) -> list[ComponentSpec]:
+        return [ComponentSpec("transformer", HunyuanVideo15Transformer3DModel)]
+
     @property
     def inputs(self) -> list[InputParam]:
         return [
@@ -157,24 +181,22 @@ def inputs(self) -> list[InputParam]:
             InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"),
             InputParam.template("generator"),
             InputParam.template("batch_size", required=True, default=None),
-            InputParam.template("dtype", default=None),
         ]

     @property
     def intermediate_outputs(self) -> list[OutputParam]:
         return [
-            OutputParam.template("latents"),
+            OutputParam("latents", type_hint=torch.Tensor, description="Pure noise latents"),
             OutputParam("cond_latents_concat", type_hint=torch.Tensor),
             OutputParam("mask_concat", type_hint=torch.Tensor),
             OutputParam("image_embeds", type_hint=torch.Tensor),
         ]

-    # Copied from pipeline_hunyuan_video1_5.py lines 652-655, 477-524, 706-725 with self->components
     @torch.no_grad()
     def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
         device = components._execution_device
-        dtype = block_state.dtype
+        dtype = components.transformer.dtype

         height = block_state.height
         width = block_state.width
@@ -186,7 +208,6 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta
         batch_size = block_state.batch_size * block_state.num_videos_per_prompt
         num_frames = block_state.num_frames

-        # Copied from HunyuanVideo15Pipeline.prepare_latents with self->components
         latents = block_state.latents
         if latents is not None:
             latents = latents.to(device=device, dtype=dtype)
@@ -207,12 +228,10 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta

         block_state.latents = latents

-        # Copied from HunyuanVideo15Pipeline.prepare_cond_latents_and_mask with self->components
         b, c, f, h, w = latents.shape
         block_state.cond_latents_concat = torch.zeros(b, c, f, h, w, dtype=dtype, device=device)
         block_state.mask_concat = torch.zeros(b, 1, f, h, w, dtype=dtype, device=device)

-        # T2V: zero image_embeds
         block_state.image_embeds = torch.zeros(
             block_state.batch_size,
             components.vision_num_semantic_tokens,
@@ -225,125 +244,62 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta
         return components, state


-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator=None, sample_mode="sample"):
-    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
-        return encoder_output.latent_dist.sample(generator)
-    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
-        return encoder_output.latent_dist.mode()
-    elif hasattr(encoder_output, "latents"):
-        return encoder_output.latents
-    raise AttributeError("Could not access latents of provided encoder_output")
-
-
 class HunyuanVideo15Image2VideoPrepareLatentsStep(ModularPipelineBlocks):
     model_name = "hunyuan-video-1.5"

     @property
     def description(self) -> str:
-        return "Prepare latents, conditioning latents, mask, and image_embeds for I2V"
+        return (
+            "Prepare I2V conditioning from image_latents and image_embeds. "
+            "Expects pure noise `latents` from HunyuanVideo15PrepareLatentsStep. "
+            "Builds cond_latents_concat and mask_concat for the denoiser."
+        )

     @property
     def expected_components(self) -> list[ComponentSpec]:
-        from transformers import SiglipImageProcessor, SiglipVisionModel
-
-        from ...models import AutoencoderKLHunyuanVideo15
-        from ...pipelines.hunyuan_video1_5.image_processor import HunyuanVideo15ImageProcessor
-
-        return [
-            ComponentSpec("vae", AutoencoderKLHunyuanVideo15),
-            ComponentSpec(
-                "video_processor",
-                HunyuanVideo15ImageProcessor,
-                config=FrozenDict({"vae_scale_factor": 16}),
-                default_creation_method="from_config",
-            ),
-            ComponentSpec("image_encoder", SiglipVisionModel),
-            ComponentSpec("feature_extractor", SiglipImageProcessor),
-        ]
+        return [ComponentSpec("transformer", HunyuanVideo15Transformer3DModel)]

     @property
     def inputs(self) -> list[InputParam]:
         return [
-            InputParam.template("image"),
-            InputParam("num_frames", type_hint=int, default=121),
-            InputParam.template("latents"),
+            InputParam("image_latents", type_hint=torch.Tensor, required=True),
+            InputParam("image_embeds", type_hint=torch.Tensor, required=True),
+            InputParam.template("latents", required=True),
             InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"),
-            InputParam.template("generator"),
             InputParam.template("batch_size", required=True, default=None),
-            InputParam.template("dtype", default=None),
         ]

     @property
     def intermediate_outputs(self) -> list[OutputParam]:
         return [
-            OutputParam.template("latents"),
             OutputParam("cond_latents_concat", type_hint=torch.Tensor),
             OutputParam("mask_concat", type_hint=torch.Tensor),
             OutputParam("image_embeds", type_hint=torch.Tensor),
         ]

-    # Copied from pipeline_hunyuan_video1_5_image2video.py lines 756-839 with self->components
     @torch.no_grad()
     def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
         device = components._execution_device
-        dtype = block_state.dtype
+        dtype = components.transformer.dtype

-        image = block_state.image
         batch_size = block_state.batch_size * block_state.num_videos_per_prompt
-        num_frames = block_state.num_frames

-        # Resize/crop image to target resolution (line 756-759)
-        height, width = components.video_processor.calculate_default_height_width(
-            height=image.size[1], width=image.size[0], target_size=components.target_size
-        )
-        image = components.video_processor.resize(image, height=height, width=width, resize_mode="crop")
-
-        # Encode image with Siglip (lines 776-781)
-        image_encoder_dtype = next(components.image_encoder.parameters()).dtype
-        image_inputs = components.feature_extractor.preprocess(
-            images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True
-        )
-        image_inputs = image_inputs.to(device=device, dtype=image_encoder_dtype)
-        image_embeds = components.image_encoder(**image_inputs).last_hidden_state
-        image_embeds = image_embeds.repeat(batch_size, 1, 1)
-        block_state.image_embeds = image_embeds.to(device=device, dtype=dtype)
+        b, c, f, h, w = block_state.latents.shape

-        # Prepare latents (lines 818-829)
-        latents = block_state.latents
-        if latents is not None:
-            latents = latents.to(device=device, dtype=dtype)
-        else:
-            shape = (
-                batch_size,
-                components.num_channels_latents,
-                (num_frames - 1) // components.vae_scale_factor_temporal + 1,
-                int(height) // components.vae_scale_factor_spatial,
-                int(width) // components.vae_scale_factor_spatial,
-            )
-            latents = randn_tensor(shape, generator=block_state.generator, device=device, dtype=dtype)
-        block_state.latents = latents
-
-        # Prepare cond latents and mask (lines 594-632, 831-839)
-        b, c, f, h, w = latents.shape
-
-        # Copied from _get_image_latents (lines 375-388) with self->components
-        vae_dtype = components.vae.dtype
-        image_tensor = components.video_processor.preprocess(
-            image, height=h * components.vae_scale_factor_spatial, width=w * components.vae_scale_factor_spatial
-        ).to(device, dtype=vae_dtype)
-        image_tensor = image_tensor.unsqueeze(2)
-        image_latents = retrieve_latents(components.vae.encode(image_tensor), sample_mode="argmax")
-        image_latents = image_latents * components.vae.config.scaling_factor
-
-        latent_condition = image_latents.repeat(batch_size, 1, f, 1, 1)
+        latent_condition = block_state.image_latents.to(device=device, dtype=dtype)
+        latent_condition = latent_condition.repeat(batch_size, 1, f, 1, 1)
         latent_condition[:, :, 1:, :, :] = 0
-        block_state.cond_latents_concat = latent_condition.to(device=device, dtype=dtype)
+        block_state.cond_latents_concat = latent_condition

         latent_mask = torch.zeros(b, 1, f, h, w, dtype=dtype, device=device)
         latent_mask[:, :, 0, :, :] = 1.0
         block_state.mask_concat = latent_mask

+        image_embeds = block_state.image_embeds.to(device=device, dtype=dtype)
+        if image_embeds.shape[0] == 1 and batch_size > 1:
+            image_embeds = image_embeds.repeat(batch_size, 1, 1)
+        block_state.image_embeds = image_embeds
+
         self.set_block_state(state, block_state)
         return components, state
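Net effect of the refactor: image encoding (VAE latents, SigLIP embeddings) is lifted into upstream encoder blocks, and this step only assembles conditioning tensors from their outputs. A standalone sketch of that assembly, using illustrative shapes rather than values from the commit:

import torch

# Illustrative latent shape: (batch, channels, frames, height, width).
b, c, f, h, w = 1, 16, 31, 45, 80
latents = torch.randn(b, c, f, h, w)        # pure noise from HunyuanVideo15PrepareLatentsStep
image_latents = torch.randn(1, c, 1, h, w)  # first-frame latents from an upstream VAE-encode block

# Broadcast the image latents across all frames, then zero everything after frame 0.
latent_condition = image_latents.repeat(b, 1, f, 1, 1)
latent_condition[:, :, 1:, :, :] = 0

# The mask flags frame 0 as the conditioned frame for the denoiser.
mask_concat = torch.zeros(b, 1, f, h, w)
mask_concat[:, :, 0, :, :] = 1.0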

src/diffusers/modular_pipelines/hunyuan_video1_5/decoders.py

Lines changed: 3 additions & 9 deletions
@@ -58,19 +58,13 @@ def intermediate_outputs(self) -> list[OutputParam]:
             OutputParam.template("videos"),
         ]

-    # Copied from pipeline_hunyuan_video1_5.py lines 823-829
     @torch.no_grad()
     def __call__(self, components, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)

-        if block_state.output_type == "latent":
-            block_state.videos = block_state.latents
-        else:
-            latents = block_state.latents.to(components.vae.dtype) / components.vae.config.scaling_factor
-            video = components.vae.decode(latents, return_dict=False)[0]
-            block_state.videos = components.video_processor.postprocess_video(
-                video, output_type=block_state.output_type
-            )
+        latents = block_state.latents.to(components.vae.dtype) / components.vae.config.scaling_factor
+        video = components.vae.decode(latents, return_dict=False)[0]
+        block_state.videos = components.video_processor.postprocess_video(video, output_type=block_state.output_type)

         self.set_block_state(state, block_state)
         return components, state
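With the `output_type == "latent"` branch removed, latent passthrough presumably happens by skipping this block rather than inside it, leaving a plain decode-and-postprocess path. A standalone sketch of that path; `decode_video` is a hypothetical helper, and the VAE and video processor are assumed to be already loaded:

import torch

@torch.no_grad()
def decode_video(latents, vae, video_processor, output_type="np"):
    # Hypothetical helper: undo the VAE scaling applied at encode time, then decode to pixels.
    latents = latents.to(vae.dtype) / vae.config.scaling_factor
    video = vae.decode(latents, return_dict=False)[0]
    return video_processor.postprocess_video(video, output_type=output_type)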
