Skip to content

Commit e5aeaf9

Browse files
committed
Address review: remove unused code, add LTXAutoBlocks, refactor I2V latents flow
1 parent a972af0 commit e5aeaf9

7 files changed

Lines changed: 83 additions & 35 deletions

File tree

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,7 @@
455455
"HeliosPyramidDistilledAutoBlocks",
456456
"HeliosPyramidDistilledModularPipeline",
457457
"HeliosPyramidModularPipeline",
458+
"LTXAutoBlocks",
458459
"LTXBlocks",
459460
"LTXImage2VideoBlocks",
460461
"LTXModularPipeline",
@@ -1237,6 +1238,7 @@
12371238
HeliosPyramidDistilledAutoBlocks,
12381239
HeliosPyramidDistilledModularPipeline,
12391240
HeliosPyramidModularPipeline,
1241+
LTXAutoBlocks,
12401242
LTXBlocks,
12411243
LTXImage2VideoBlocks,
12421244
LTXModularPipeline,

src/diffusers/modular_pipelines/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
"QwenImageLayeredAutoBlocks",
9090
]
9191
_import_structure["ltx"] = [
92+
"LTXAutoBlocks",
9293
"LTXBlocks",
9394
"LTXImage2VideoBlocks",
9495
"LTXModularPipeline",
@@ -124,7 +125,7 @@
124125
HeliosPyramidDistilledModularPipeline,
125126
HeliosPyramidModularPipeline,
126127
)
127-
from .ltx import LTXBlocks, LTXImage2VideoBlocks, LTXModularPipeline
128+
from .ltx import LTXAutoBlocks, LTXBlocks, LTXImage2VideoBlocks, LTXModularPipeline
128129
from .modular_pipeline import (
129130
AutoPipelineBlocks,
130131
BlockState,

src/diffusers/modular_pipelines/ltx/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
2323
else:
24-
_import_structure["modular_blocks_ltx"] = ["LTXBlocks", "LTXImage2VideoBlocks"]
24+
_import_structure["modular_blocks_ltx"] = ["LTXAutoBlocks", "LTXBlocks", "LTXImage2VideoBlocks"]
2525
_import_structure["modular_pipeline"] = ["LTXModularPipeline"]
2626

2727
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -31,7 +31,7 @@
3131
except OptionalDependencyNotAvailable:
3232
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
3333
else:
34-
from .modular_blocks_ltx import LTXBlocks, LTXImage2VideoBlocks
34+
from .modular_blocks_ltx import LTXAutoBlocks, LTXBlocks, LTXImage2VideoBlocks
3535
from .modular_pipeline import LTXModularPipeline
3636
else:
3737
import sys

src/diffusers/modular_pipelines/ltx/before_denoise.py

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,12 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
124124
return latents
125125

126126

127-
def _normalize_latents(
128-
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
127+
def _unpack_latents(
128+
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
129129
) -> torch.Tensor:
130-
# Normalize latents across the channel dimension [B, C, F, H, W]
131-
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
132-
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
133-
latents = (latents - latents_mean) * scaling_factor / latents_std
130+
batch_size = latents.size(0)
131+
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
132+
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
134133
return latents
135134

136135

@@ -343,19 +342,19 @@ class LTXImage2VideoPrepareLatentsStep(ModularPipelineBlocks):
343342
@property
344343
def description(self) -> str:
345344
return (
346-
"Prepare latents step for image-to-video: takes pre-encoded image latents and creates a conditioning mask"
345+
"Prepare image-to-video latents: adds noise to pre-encoded image latents and creates a conditioning mask. "
346+
"Expects pure noise `latents` from LTXPrepareLatentsStep."
347347
)
348348

349349
@property
350350
def inputs(self) -> list[InputParam]:
351351
return [
352352
InputParam("image_latents", type_hint=torch.Tensor, required=True),
353+
InputParam.template("latents", required=True),
353354
InputParam.template("height", default=512),
354355
InputParam.template("width", default=704),
355356
InputParam("num_frames", type_hint=int, default=161),
356-
InputParam.template("latents"),
357357
InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"),
358-
InputParam.template("generator"),
359358
InputParam.template("batch_size", required=True),
360359
]
361360

@@ -377,37 +376,31 @@ def __call__(self, components: LTXModularPipeline, state: PipelineState) -> Pipe
377376
width = block_state.width // components.vae_spatial_compression_ratio
378377
num_frames = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1
379378

380-
mask_shape = (batch_size, 1, num_frames, height, width)
381-
382-
if block_state.latents is not None:
383-
conditioning_mask = block_state.latents.new_zeros(mask_shape)
384-
conditioning_mask[:, :, 0] = 1.0
385-
conditioning_mask = _pack_latents(
386-
conditioning_mask,
387-
components.transformer_spatial_patch_size,
388-
components.transformer_temporal_patch_size,
389-
).squeeze(-1)
390-
block_state.latents = block_state.latents.to(device=device, dtype=torch.float32)
391-
block_state.conditioning_mask = conditioning_mask
392-
self.set_block_state(state, block_state)
393-
return components, state
394-
395379
init_latents = block_state.image_latents.to(device=device, dtype=torch.float32)
396380
if init_latents.shape[0] < batch_size:
397381
init_latents = init_latents.repeat_interleave(batch_size // init_latents.shape[0], dim=0)
398382
init_latents = init_latents.repeat(1, 1, num_frames, 1, 1)
399383

400-
actual_mask_shape = (
384+
conditioning_mask = torch.zeros(
401385
init_latents.shape[0],
402386
1,
403387
init_latents.shape[2],
404388
init_latents.shape[3],
405389
init_latents.shape[4],
390+
device=device,
391+
dtype=torch.float32,
406392
)
407-
conditioning_mask = torch.zeros(actual_mask_shape, device=device, dtype=torch.float32)
408393
conditioning_mask[:, :, 0] = 1.0
409394

410-
noise = randn_tensor(init_latents.shape, generator=block_state.generator, device=device, dtype=torch.float32)
395+
# Unpack the pure noise latents from LTXPrepareLatentsStep to mix with image latents
396+
noise = _unpack_latents(
397+
block_state.latents,
398+
num_frames,
399+
height,
400+
width,
401+
components.transformer_spatial_patch_size,
402+
components.transformer_temporal_patch_size,
403+
)
411404
latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)
412405

413406
conditioning_mask = _pack_latents(

src/diffusers/modular_pipelines/ltx/denoise.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -344,9 +344,6 @@ def __init__(
344344

345345
@property
346346
def expected_components(self) -> list[ComponentSpec]:
347-
from ...configuration_utils import FrozenDict
348-
from ...guiders import ClassifierFreeGuidance
349-
350347
return [
351348
ComponentSpec(
352349
"guider",

src/diffusers/modular_pipelines/ltx/modular_blocks_ltx.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,11 @@ class LTXImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
145145
block_classes = [
146146
LTXTextInputStep,
147147
LTXSetTimestepsStep,
148+
LTXPrepareLatentsStep,
148149
LTXImage2VideoPrepareLatentsStep,
149150
LTXImage2VideoDenoiseStep,
150151
]
151-
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
152+
block_names = ["input", "set_timesteps", "prepare_latents", "prepare_i2v_latents", "denoise"]
152153

153154
@property
154155
def description(self):
@@ -268,6 +269,60 @@ def description(self):
268269
)
269270

270271

272+
# auto_docstring
273+
class LTXAutoCoreDenoiseStep(AutoPipelineBlocks):
274+
"""
275+
Auto denoise block that selects the appropriate denoise pipeline based on inputs.
276+
- `LTXImage2VideoCoreDenoiseStep` is used when `image_latents` is provided.
277+
- `LTXCoreDenoiseStep` is used otherwise (text-to-video).
278+
"""
279+
280+
model_name = "ltx"
281+
block_classes = [LTXImage2VideoCoreDenoiseStep, LTXCoreDenoiseStep]
282+
block_names = ["image2video", "text2video"]
283+
block_trigger_inputs = ["image_latents", None]
284+
285+
@property
286+
def description(self):
287+
return (
288+
"Auto denoise block that selects the appropriate denoise pipeline based on inputs.\n"
289+
" - `LTXImage2VideoCoreDenoiseStep` is used when `image_latents` is provided.\n"
290+
" - `LTXCoreDenoiseStep` is used otherwise (text-to-video)."
291+
)
292+
293+
294+
# auto_docstring
295+
class LTXAutoBlocks(SequentialPipelineBlocks):
296+
"""
297+
Auto blocks for LTX Video that support both text-to-video and image-to-video workflows.
298+
299+
Supported workflows:
300+
- `text2video`: requires `prompt`
301+
- `image2video`: requires `image`, `prompt`
302+
"""
303+
304+
model_name = "ltx"
305+
block_classes = [
306+
LTXTextEncoderStep,
307+
LTXAutoVaeEncoderStep,
308+
LTXAutoCoreDenoiseStep,
309+
LTXVaeDecoderStep,
310+
]
311+
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
312+
313+
@property
314+
def description(self):
315+
return (
316+
"Auto blocks for LTX Video that support both text-to-video and image-to-video workflows.\n"
317+
" - text2video: requires `prompt`\n"
318+
" - image2video: requires `image`, `prompt`"
319+
)
320+
321+
@property
322+
def outputs(self):
323+
return [OutputParam.template("videos")]
324+
325+
271326
# auto_docstring
272327
class LTXImage2VideoBlocks(SequentialPipelineBlocks):
273328
"""

src/diffusers/modular_pipelines/ltx/modular_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class LTXModularPipeline(
3131
> [!WARNING] > This is an experimental feature and is likely to change in the future.
3232
"""
3333

34-
default_blocks_name = "LTXBlocks"
34+
default_blocks_name = "LTXAutoBlocks"
3535

3636
@property
3737
def vae_spatial_compression_ratio(self):

0 commit comments

Comments
 (0)