From 9ceb4bdd79b66ef2ff1c7ea196a32f518a82b75f Mon Sep 17 00:00:00 2001 From: Jerry Song <46962917+Songrui625@users.noreply.github.com> Date: Fri, 27 Feb 2026 16:55:01 +0800 Subject: [PATCH 001/215] Fix LTX-2 image-to-video generation failure in two stages generation (#13187) * Fix LTX-2 image-to-video generation failure in two stages generation In LTX-2's two-stage image-to-video generation task, specifically after the upsampling step, a shape mismatch occurs between the `latents` and the `conditioning_mask`, which causes an error in function `_create_noised_state`. Fix it by creating the `conditioning_mask` based on the shape of the `latents`. * Add unit test for LTX-2 i2v two stages inference with upsampler * Downscaling the upsampler in LTX-2 image-to-video unit test * Apply style fixes --------- Co-authored-by: github-actions[bot] --- .../ltx2/pipeline_ltx2_image2video.py | 11 ++- tests/pipelines/ltx2/test_ltx2_image2video.py | 67 ++++++++++++++++++- 2 files changed, 75 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 98d323efd477..83ba2cd7c685 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -699,9 +699,13 @@ def prepare_latents( mask_shape = (batch_size, 1, num_frames, height, width) if latents is not None: - conditioning_mask = latents.new_zeros(mask_shape) - conditioning_mask[:, :, 0] = 1.0 if latents.ndim == 5: + # conditioning_mask needs to be the same shape as latents in two stages generation. 
+ batch_size, _, num_frames, height, width = latents.shape + mask_shape = (batch_size, 1, num_frames, height, width) + conditioning_mask = latents.new_zeros(mask_shape) + conditioning_mask[:, :, 0] = 1.0 + latents = self._normalize_latents( latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor ) @@ -710,6 +714,9 @@ def prepare_latents( latents = self._pack_latents( latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ) + else: + conditioning_mask = latents.new_zeros(mask_shape) + conditioning_mask[:, :, 0] = 1.0 conditioning_mask = self._pack_latents( conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ).squeeze(-1) diff --git a/tests/pipelines/ltx2/test_ltx2_image2video.py b/tests/pipelines/ltx2/test_ltx2_image2video.py index 3653e1cfc5e4..92c000c7bf7c 100644 --- a/tests/pipelines/ltx2/test_ltx2_image2video.py +++ b/tests/pipelines/ltx2/test_ltx2_image2video.py @@ -24,7 +24,8 @@ LTX2ImageToVideoPipeline, LTX2VideoTransformer3DModel, ) -from diffusers.pipelines.ltx2 import LTX2TextConnectors +from diffusers.pipelines.ltx2 import LTX2LatentUpsamplePipeline, LTX2TextConnectors +from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder from ...testing_utils import enable_full_determinism @@ -174,6 +175,15 @@ def get_dummy_components(self): return components + def get_dummy_upsample_component(self, in_channels=4, mid_channels=32, num_blocks_per_stage=1): + upsampler = LTX2LatentUpsamplerModel( + in_channels=in_channels, + mid_channels=mid_channels, + num_blocks_per_stage=num_blocks_per_stage, + ) + + return upsampler + def get_dummy_inputs(self, device, seed=0): if str(device).startswith("mps"): generator = torch.manual_seed(seed) @@ -287,5 +297,60 @@ def test_two_stages_inference(self): assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) assert 
torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + def test_two_stages_inference_with_upsampler(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["output_type"] = "latent" + first_stage_output = pipe(**inputs) + video_latent = first_stage_output.frames + audio_latent = first_stage_output.audio + + self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16)) + self.assertEqual(audio_latent.shape, (1, 2, 5, 2)) + self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels) + + upsampler = self.get_dummy_upsample_component(in_channels=video_latent.shape[1]) + upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=upsampler) + upscaled_video_latent = upsample_pipe(latents=video_latent, output_type="latent", return_dict=False)[0] + self.assertEqual(upscaled_video_latent.shape, (1, 4, 3, 32, 32)) + + inputs["latents"] = upscaled_video_latent + inputs["audio_latents"] = audio_latent + inputs["output_type"] = "pt" + second_stage_output = pipe(**inputs) + video = second_stage_output.frames + audio = second_stage_output.audio + + self.assertEqual(video.shape, (1, 5, 3, 64, 64)) + self.assertEqual(audio.shape[0], 1) + self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels) + + # fmt: off + expected_video_slice = torch.tensor( + [ + 0.4497, 0.6757, 0.4219, 0.7686, 0.4525, 0.6483, 0.3969, 0.7404, 0.3541, 0.3039, 0.4592, 0.3521, 0.3665, 0.2785, 0.3336, 0.3079 + ] + ) + expected_audio_slice = torch.tensor( + [ + 0.0271, 0.0492, 0.1249, 0.1126, 0.1661, 0.1060, 0.1717, 0.0944, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + ] + ) + # fmt: on + + video = video.flatten() + audio = audio.flatten() + generated_video_slice = torch.cat([video[:8], video[-8:]]) + generated_audio_slice = torch.cat([audio[:8], 
audio[-8:]]) + + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) + assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2) From 9dce1bd494134685d51bc0634d7c0b33bc32ca79 Mon Sep 17 00:00:00 2001 From: Christopher Date: Fri, 27 Feb 2026 11:13:41 +0100 Subject: [PATCH 002/215] Fixing Kohya loras loading: Flux.1-dev loras with TE ("lora_te1_" prefix) (#13188) * fixing text encoder lora loading * following Cursor's review --- src/diffusers/loaders/lora_conversion_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index f56632ced819..8b0f95b905e4 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -856,7 +856,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict): ) state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")} - has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict) + has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_", "lora_te1_")) for k in state_dict) if has_diffb: zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b") if zero_status_diff_b: @@ -895,7 +895,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict): state_dict = { _custom_replace(k, limit_substrings): v for k, v in state_dict.items() - if k.startswith(("lora_unet_", "lora_te_")) + if k.startswith(("lora_unet_", "lora_te_", "lora_te1_")) } if any("text_projection" in k for k in state_dict): From 1c2e7c01722c0eb3b0dc733f4d84ea66628ba2ce Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Fri, 27 Feb 2026 10:50:35 -1000 Subject: [PATCH 003/215] [Modular] 
update the auto pipeline blocks doc (#13148) * update * Apply suggestion from @yiyixuxu * Update docs/source/en/modular_diffusers/auto_pipeline_blocks.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/modular_diffusers/auto_pipeline_blocks.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/modular_diffusers/auto_pipeline_blocks.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/modular_diffusers/auto_pipeline_blocks.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * add to api --------- Co-authored-by: yiyi@huggingface.co Co-authored-by: Sayak Paul Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: yiyi@huggingface.co --- .../api/modular_diffusers/pipeline_blocks.md | 6 +- .../modular_diffusers/auto_pipeline_blocks.md | 77 ++++++++++++++++++- 2 files changed, 78 insertions(+), 5 deletions(-) diff --git a/docs/source/en/api/modular_diffusers/pipeline_blocks.md b/docs/source/en/api/modular_diffusers/pipeline_blocks.md index 8ad581e679ac..4808f2cf3bbe 100644 --- a/docs/source/en/api/modular_diffusers/pipeline_blocks.md +++ b/docs/source/en/api/modular_diffusers/pipeline_blocks.md @@ -14,4 +14,8 @@ ## AutoPipelineBlocks -[[autodoc]] diffusers.modular_pipelines.modular_pipeline.AutoPipelineBlocks \ No newline at end of file +[[autodoc]] diffusers.modular_pipelines.modular_pipeline.AutoPipelineBlocks + +## ConditionalPipelineBlocks + +[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ConditionalPipelineBlocks \ No newline at end of file diff --git a/docs/source/en/modular_diffusers/auto_pipeline_blocks.md b/docs/source/en/modular_diffusers/auto_pipeline_blocks.md index 2d4d82c735bd..1bcf1d691036 100644 --- a/docs/source/en/modular_diffusers/auto_pipeline_blocks.md +++ b/docs/source/en/modular_diffusers/auto_pipeline_blocks.md @@ -121,7 +121,7 @@ from 
diffusers.modular_pipelines import AutoPipelineBlocks class AutoImageBlocks(AutoPipelineBlocks): # List of sub-block classes to choose from - block_classes = [block_inpaint_cls, block_i2i_cls, block_t2i_cls] + block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock] # Names for each block in the same order block_names = ["inpaint", "img2img", "text2img"] # Trigger inputs that determine which block to run @@ -129,8 +129,8 @@ class AutoImageBlocks(AutoPipelineBlocks): # - "image" triggers img2img workflow (but only if mask is not provided) # - if none of above, runs the text2img workflow (default) block_trigger_inputs = ["mask", "image", None] - # Description is extremely important for AutoPipelineBlocks + @property def description(self): return ( "Pipeline generates images given different types of conditions!\n" @@ -141,7 +141,7 @@ class AutoImageBlocks(AutoPipelineBlocks): ) ``` -It is **very** important to include a `description` to avoid any confusion over how to run a block and what inputs are required. While [`~modular_pipelines.AutoPipelineBlocks`] are convenient, it's conditional logic may be difficult to figure out if it isn't properly explained. +It is **very** important to include a `description` to avoid any confusion over how to run a block and what inputs are required. While [`~modular_pipelines.AutoPipelineBlocks`] are convenient, its conditional logic may be difficult to figure out if it isn't properly explained. Create an instance of `AutoImageBlocks`. @@ -152,5 +152,74 @@ auto_blocks = AutoImageBlocks() For more complex compositions, such as nested [`~modular_pipelines.AutoPipelineBlocks`] blocks when they're used as sub-blocks in larger pipelines, use the [`~modular_pipelines.SequentialPipelineBlocks.get_execution_blocks`] method to extract the a block that is actually run based on your input. 
```py -auto_blocks.get_execution_blocks("mask") +auto_blocks.get_execution_blocks(mask=True) +``` + +## ConditionalPipelineBlocks + +[`~modular_pipelines.AutoPipelineBlocks`] is a special case of [`~modular_pipelines.ConditionalPipelineBlocks`]. While [`~modular_pipelines.AutoPipelineBlocks`] selects blocks based on whether a trigger input is provided or not, [`~modular_pipelines.ConditionalPipelineBlocks`] is able to select a block based on custom selection logic provided in the `select_block` method. + +Here is the same example written using [`~modular_pipelines.ConditionalPipelineBlocks`] directly: + +```py +from diffusers.modular_pipelines import ConditionalPipelineBlocks + +class AutoImageBlocks(ConditionalPipelineBlocks): + block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock] + block_names = ["inpaint", "img2img", "text2img"] + block_trigger_inputs = ["mask", "image"] + default_block_name = "text2img" + + @property + def description(self): + return ( + "Pipeline generates images given different types of conditions!\n" + + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n" + + " - inpaint workflow is run when `mask` is provided.\n" + + " - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\n" + + " - text2img workflow is run when neither `image` nor `mask` is provided.\n" + ) + + def select_block(self, mask=None, image=None) -> str | None: + if mask is not None: + return "inpaint" + if image is not None: + return "img2img" + return None # falls back to default_block_name ("text2img") +``` + +The inputs listed in `block_trigger_inputs` are passed as keyword arguments to `select_block()`. When `select_block` returns `None`, it falls back to `default_block_name`. If `default_block_name` is also `None`, the entire conditional block is skipped — this is useful for optional processing steps that should only run when specific inputs are provided. 
+ +## Workflows + +Pipelines that contain conditional blocks ([`~modular_pipelines.AutoPipelineBlocks`] or [`~modular_pipelines.ConditionalPipelineBlocks`]) can support multiple workflows — for example, our SDXL modular pipeline supports a dozen workflows all in one pipeline. But this also means it can be confusing for users to know what workflows are supported and how to run them. For pipeline builders, it's useful to be able to extract only the blocks relevant to a specific workflow. + +We recommend defining a `_workflow_map` to give each workflow a name and explicitly list the inputs it requires. + +```py +from diffusers.modular_pipelines import SequentialPipelineBlocks + +class MyPipelineBlocks(SequentialPipelineBlocks): + block_classes = [TextEncoderBlock, AutoImageBlocks, DecodeBlock] + block_names = ["text_encoder", "auto_image", "decode"] + + _workflow_map = { + "text2image": {"prompt": True}, + "image2image": {"image": True, "prompt": True}, + "inpaint": {"mask": True, "image": True, "prompt": True}, + } +``` + +All of our built-in modular pipelines come with pre-defined workflows. The `available_workflows` property lists all supported workflows: + +```py +pipeline_blocks = MyPipelineBlocks() +pipeline_blocks.available_workflows +# ['text2image', 'image2image', 'inpaint'] +``` + +Retrieve a specific workflow with `get_workflow` to inspect and debug a specific block that executes the workflow. + +```py +pipeline_blocks.get_workflow("inpaint") ``` \ No newline at end of file From 72929cad75854f86d54dfa170850be9be9d11d9d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 28 Feb 2026 08:47:21 +0530 Subject: [PATCH 004/215] [tests] consistency tests for modular index (#13192) * add a test to check modular index consistency * check for compulsory keys. 
--- .../test_modular_pipelines_common.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index e97b543ff85d..c94f41935938 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -1,4 +1,6 @@ import gc +import json +import os import tempfile from typing import Callable @@ -349,6 +351,33 @@ def test_save_from_pretrained(self): assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + def test_modular_index_consistency(self): + pipe = self.get_pipeline() + components_spec = pipe._component_specs + components = sorted(components_spec.keys()) + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + index_file = os.path.join(tmpdir, "modular_model_index.json") + assert os.path.exists(index_file) + + with open(index_file) as f: + index_contents = json.load(f) + + compulsory_keys = {"_blocks_class_name", "_class_name", "_diffusers_version"} + for k in compulsory_keys: + assert k in index_contents + + to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"} + for component in components: + spec = components_spec[component] + for attr in to_check_attrs: + if getattr(spec, "pretrained_model_name_or_path", None) is not None: + for attr in to_check_attrs: + assert component in index_contents, f"{component} should be present in index but isn't." 
+ attr_value_from_index = index_contents[component][2][attr] + assert getattr(spec, attr) == attr_value_from_index + def test_workflow_map(self): blocks = self.pipeline_blocks_class() if blocks._workflow_map is None: From 52e8cdbfaeefa22102880b52a1133da2c49c90e9 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Fri, 27 Feb 2026 18:58:01 -1000 Subject: [PATCH 005/215] [modular] fallback to default_blocks_name when loading base block classes in ModularPipeline (#13193) up Co-authored-by: yiyi@huggingface.co --- .../modular_pipelines/modular_pipeline.py | 9 ++++++- .../test_modular_pipelines_common.py | 24 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 76a850b63c4e..5d6c4064ef96 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1633,7 +1633,14 @@ def __init__( blocks_class_name = self.default_blocks_name if blocks_class_name is not None: diffusers_module = importlib.import_module("diffusers") - blocks_class = getattr(diffusers_module, blocks_class_name) + blocks_class = getattr(diffusers_module, blocks_class_name, None) + # If the blocks_class is not found or is a base class (e.g. 
SequentialPipelineBlocks saved by from_blocks_dict) with empty block_classes + # fall back to default_blocks_name + if blocks_class is None or not blocks_class.block_classes: + blocks_class_name = self.default_blocks_name + blocks_class = getattr(diffusers_module, blocks_class_name) + + if blocks_class is not None: blocks = blocks_class() else: logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}") diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index c94f41935938..5aceae77da27 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -728,3 +728,27 @@ def test_load_components_skips_invalid_pretrained_path(self): # Verify test_component was not loaded assert not hasattr(pipe, "test_component") or pipe.test_component is None + + +class TestModularPipelineInitFallback: + """Test that ModularPipeline.__init__ falls back to default_blocks_name when + _blocks_class_name is a base class (e.g. SequentialPipelineBlocks saved by from_blocks_dict).""" + + def test_init_fallback_when_blocks_class_name_is_base_class(self, tmp_path): + # 1. Load pipeline and get a workflow (returns a base SequentialPipelineBlocks) + pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe") + t2i_blocks = pipe.blocks.get_workflow("text2image") + assert t2i_blocks.__class__.__name__ == "SequentialPipelineBlocks" + + # 2. Use init_pipeline to create a new pipeline from the workflow blocks + t2i_pipe = t2i_blocks.init_pipeline("hf-internal-testing/tiny-stable-diffusion-xl-pipe") + + # 3. Save and reload — the saved config will have _blocks_class_name="SequentialPipelineBlocks" + save_dir = str(tmp_path / "pipeline") + t2i_pipe.save_pretrained(save_dir) + loaded_pipe = ModularPipeline.from_pretrained(save_dir) + + # 4. 
Verify it fell back to default_blocks_name and has correct blocks + assert loaded_pipe.__class__.__name__ == pipe.__class__.__name__ + assert loaded_pipe._blocks.__class__.__name__ == pipe._blocks.__class__.__name__ + assert len(loaded_pipe._blocks.sub_blocks) == len(pipe._blocks.sub_blocks) From f71980d09136c9f4d466983e8d865e50027fa393 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 2 Mar 2026 14:34:49 +0530 Subject: [PATCH 006/215] [chore] updates in the pypi publication workflow. (#12805) * updates in the pypi publication workflow. * change to 3.10 --- .github/workflows/pypi_publish.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pypi_publish.yaml b/.github/workflows/pypi_publish.yaml index 1e8ce4599544..99cbcc1fade2 100644 --- a/.github/workflows/pypi_publish.yaml +++ b/.github/workflows/pypi_publish.yaml @@ -54,7 +54,6 @@ jobs: python -m pip install --upgrade pip pip install -U setuptools wheel twine pip install -U torch --index-url https://download.pytorch.org/whl/cpu - pip install -U transformers - name: Build the dist files run: python setup.py bdist_wheel && python setup.py sdist @@ -69,6 +68,8 @@ jobs: run: | pip install diffusers && pip uninstall diffusers -y pip install -i https://test.pypi.org/simple/ diffusers + pip install -U transformers + python utils/print_env.py python -c "from diffusers import __version__; print(__version__)" python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()" python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')" From 40358e344304ff6eb16e5f6e8e4e9616b041f4eb Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 2 Mar 2026 15:03:58 +0530 Subject: [PATCH 007/215] [tests] enable cpu offload test in torchao without compilation. 
(#12704) enable cpu offload test in torchao without compilation. --- tests/quantization/torchao/test_torchao.py | 27 +++++++++++----------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 7a8e3cc67877..a722eaece4d1 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -74,7 +74,7 @@ @require_torch @require_torch_accelerator -@require_torchao_version_greater_or_equal("0.7.0") +@require_torchao_version_greater_or_equal("0.14.0") class TorchAoConfigTest(unittest.TestCase): def test_to_dict(self): """ @@ -132,7 +132,7 @@ def test_repr(self): # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners @require_torch @require_torch_accelerator -@require_torchao_version_greater_or_equal("0.7.0") +@require_torchao_version_greater_or_equal("0.14.0") class TorchAoTest(unittest.TestCase): def tearDown(self): gc.collect() @@ -587,7 +587,7 @@ def test_aobase_config(self): # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners @require_torch @require_torch_accelerator -@require_torchao_version_greater_or_equal("0.7.0") +@require_torchao_version_greater_or_equal("0.14.0") class TorchAoSerializationTest(unittest.TestCase): model_name = "hf-internal-testing/tiny-flux-pipe" @@ -698,23 +698,22 @@ def test_aobase_config(self): self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) -@require_torchao_version_greater_or_equal("0.7.0") +@require_torchao_version_greater_or_equal("0.14.0") class TorchAoCompileTest(QuantCompileTests, unittest.TestCase): @property def quantization_config(self): return PipelineQuantizationConfig( - quant_mapping={ - "transformer": TorchAoConfig(quant_type="int8_weight_only"), - }, + quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig())}, ) - @unittest.skip( - "Changing the device of AQT tensor with 
module._apply (called from doing module.to() in accelerate) does not work " - "when compiling." - ) def test_torch_compile_with_cpu_offload(self): + pipe = self._init_pipeline(self.quantization_config, torch.bfloat16) + pipe.enable_model_cpu_offload() + # No compilation because it fails with: # RuntimeError: _apply(): Couldn't swap Linear.weight - super().test_torch_compile_with_cpu_offload() + + # small resolutions to ensure speedy execution. + pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256) @parameterized.expand([False, True]) @unittest.skip( @@ -745,7 +744,7 @@ def test_torch_compile_with_group_offload_leaf(self, use_stream): # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners @require_torch @require_torch_accelerator -@require_torchao_version_greater_or_equal("0.7.0") +@require_torchao_version_greater_or_equal("0.14.0") @slow @nightly class SlowTorchAoTests(unittest.TestCase): @@ -907,7 +906,7 @@ def test_memory_footprint_int8wo(self): @require_torch @require_torch_accelerator -@require_torchao_version_greater_or_equal("0.7.0") +@require_torchao_version_greater_or_equal("0.14.0") @slow @nightly class SlowTorchAoPreserializedModelTests(unittest.TestCase): From bd591b291ede8b665f55feadf39e909b81753e50 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 2 Mar 2026 16:39:56 +0530 Subject: [PATCH 008/215] remove db utils from benchmarking (#13199) --- .github/workflows/benchmark.yml | 14 --- benchmarks/populate_into_db.py | 166 -------------------------------- 2 files changed, 180 deletions(-) delete mode 100644 benchmarks/populate_into_db.py diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 3017fc96a5e3..3ca9435d97e0 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -62,20 +62,6 @@ jobs: with: name: benchmark_test_reports path: benchmarks/${{ env.BASE_PATH }} - - # TODO: enable this once the connection problem has been resolved. 
- - name: Update benchmarking results to DB - env: - PGDATABASE: metrics - PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }} - PGUSER: transformers_benchmarks - PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }} - BRANCH_NAME: ${{ github.head_ref || github.ref_name }} - run: | - git config --global --add safe.directory /__w/diffusers/diffusers - commit_id=$GITHUB_SHA - commit_msg=$(git show -s --format=%s "$commit_id" | cut -c1-70) - cd benchmarks && python populate_into_db.py "$BRANCH_NAME" "$commit_id" "$commit_msg" - name: Report success status if: ${{ success() }} diff --git a/benchmarks/populate_into_db.py b/benchmarks/populate_into_db.py deleted file mode 100644 index 55e46b058683..000000000000 --- a/benchmarks/populate_into_db.py +++ /dev/null @@ -1,166 +0,0 @@ -import argparse -import os -import sys - -import gpustat -import pandas as pd -import psycopg2 -import psycopg2.extras -from psycopg2.extensions import register_adapter -from psycopg2.extras import Json - - -register_adapter(dict, Json) - -FINAL_CSV_FILENAME = "collated_results.csv" -# https://github.com/huggingface/transformers/blob/593e29c5e2a9b17baec010e8dc7c1431fed6e841/benchmark/init_db.sql#L27 -BENCHMARKS_TABLE_NAME = "benchmarks" -MEASUREMENTS_TABLE_NAME = "model_measurements" - - -def _init_benchmark(conn, branch, commit_id, commit_msg): - gpu_stats = gpustat.GPUStatCollection.new_query() - metadata = {"gpu_name": gpu_stats[0]["name"]} - repository = "huggingface/diffusers" - with conn.cursor() as cur: - cur.execute( - f"INSERT INTO {BENCHMARKS_TABLE_NAME} (repository, branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s, %s) RETURNING benchmark_id", - (repository, branch, commit_id, commit_msg, metadata), - ) - benchmark_id = cur.fetchone()[0] - print(f"Initialised benchmark #{benchmark_id}") - return benchmark_id - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "branch", - type=str, - help="The branch name on which the benchmarking 
is performed.", - ) - - parser.add_argument( - "commit_id", - type=str, - help="The commit hash on which the benchmarking is performed.", - ) - - parser.add_argument( - "commit_msg", - type=str, - help="The commit message associated with the commit, truncated to 70 characters.", - ) - args = parser.parse_args() - return args - - -if __name__ == "__main__": - args = parse_args() - try: - conn = psycopg2.connect( - host=os.getenv("PGHOST"), - database=os.getenv("PGDATABASE"), - user=os.getenv("PGUSER"), - password=os.getenv("PGPASSWORD"), - ) - print("DB connection established successfully.") - except Exception as e: - print(f"Problem during DB init: {e}") - sys.exit(1) - - try: - benchmark_id = _init_benchmark( - conn=conn, - branch=args.branch, - commit_id=args.commit_id, - commit_msg=args.commit_msg, - ) - except Exception as e: - print(f"Problem during initializing benchmark: {e}") - sys.exit(1) - - cur = conn.cursor() - - df = pd.read_csv(FINAL_CSV_FILENAME) - - # Helper to cast values (or None) given a dtype - def _cast_value(val, dtype: str): - if pd.isna(val): - return None - - if dtype == "text": - return str(val).strip() - - if dtype == "float": - try: - return float(val) - except ValueError: - return None - - if dtype == "bool": - s = str(val).strip().lower() - if s in ("true", "t", "yes", "1"): - return True - if s in ("false", "f", "no", "0"): - return False - if val in (1, 1.0): - return True - if val in (0, 0.0): - return False - return None - - return val - - try: - rows_to_insert = [] - for _, row in df.iterrows(): - scenario = _cast_value(row.get("scenario"), "text") - model_cls = _cast_value(row.get("model_cls"), "text") - num_params_B = _cast_value(row.get("num_params_B"), "float") - flops_G = _cast_value(row.get("flops_G"), "float") - time_plain_s = _cast_value(row.get("time_plain_s"), "float") - mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float") - time_compile_s = _cast_value(row.get("time_compile_s"), "float") - mem_compile_GB = 
_cast_value(row.get("mem_compile_GB"), "float") - fullgraph = _cast_value(row.get("fullgraph"), "bool") - mode = _cast_value(row.get("mode"), "text") - - # If "github_sha" column exists in the CSV, cast it; else default to None - if "github_sha" in df.columns: - github_sha = _cast_value(row.get("github_sha"), "text") - else: - github_sha = None - - measurements = { - "scenario": scenario, - "model_cls": model_cls, - "num_params_B": num_params_B, - "flops_G": flops_G, - "time_plain_s": time_plain_s, - "mem_plain_GB": mem_plain_GB, - "time_compile_s": time_compile_s, - "mem_compile_GB": mem_compile_GB, - "fullgraph": fullgraph, - "mode": mode, - "github_sha": github_sha, - } - rows_to_insert.append((benchmark_id, measurements)) - - # Batch-insert all rows - insert_sql = f""" - INSERT INTO {MEASUREMENTS_TABLE_NAME} ( - benchmark_id, - measurements - ) - VALUES (%s, %s); - """ - - psycopg2.extras.execute_batch(cur, insert_sql, rows_to_insert) - conn.commit() - - cur.close() - conn.close() - except Exception as e: - print(f"Exception: {e}") - sys.exit(1) From 05a4d5b20f7d03cd787e90f1ccd707c89228a4f3 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 2 Mar 2026 17:44:25 +0530 Subject: [PATCH 009/215] [AutoModel] Fix bug with subfolders and local model paths when loading custom code (#13197) * update * update --- src/diffusers/utils/dynamic_modules_utils.py | 11 ++++- tests/models/test_models_auto.py | 43 ++++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index 2a1cea10e14d..856966dd29b5 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -299,7 +299,10 @@ def get_cached_module_file( # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. 
pretrained_model_name_or_path = str(pretrained_model_name_or_path) - module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) + if subfolder is not None: + module_file_or_url = os.path.join(pretrained_model_name_or_path, subfolder, module_file) + else: + module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) if os.path.isfile(module_file_or_url): resolved_module_file = module_file_or_url @@ -384,7 +387,11 @@ def get_cached_module_file( if not os.path.exists(submodule_path / module_folder): os.makedirs(submodule_path / module_folder) module_needed = f"{module_needed}.py" - shutil.copyfile(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed) + if subfolder is not None: + source_path = os.path.join(pretrained_model_name_or_path, subfolder, module_needed) + else: + source_path = os.path.join(pretrained_model_name_or_path, module_needed) + shutil.copyfile(source_path, submodule_path / module_needed) else: # Get the commit hash # TODO: we will get this info in the etag soon, so retrieve it from there and not here. 
diff --git a/tests/models/test_models_auto.py b/tests/models/test_models_auto.py index 3506f8fa0d17..57a7760ea841 100644 --- a/tests/models/test_models_auto.py +++ b/tests/models/test_models_auto.py @@ -1,6 +1,10 @@ +import json +import os +import tempfile import unittest from unittest.mock import MagicMock, patch +import torch from transformers import CLIPTextModel, LongformerModel from diffusers.models import AutoModel, UNet2DConditionModel @@ -35,6 +39,45 @@ def test_load_from_model_index(self): ) assert isinstance(model, CLIPTextModel) + def test_load_dynamic_module_from_local_path_with_subfolder(self): + CUSTOM_MODEL_CODE = ( + "import torch\n" + "from diffusers import ModelMixin, ConfigMixin\n" + "from diffusers.configuration_utils import register_to_config\n" + "\n" + "class CustomModel(ModelMixin, ConfigMixin):\n" + " @register_to_config\n" + " def __init__(self, hidden_size=8):\n" + " super().__init__()\n" + " self.linear = torch.nn.Linear(hidden_size, hidden_size)\n" + "\n" + " def forward(self, x):\n" + " return self.linear(x)\n" + ) + + with tempfile.TemporaryDirectory() as tmpdir: + subfolder = "custom_model" + model_dir = os.path.join(tmpdir, subfolder) + os.makedirs(model_dir) + + with open(os.path.join(model_dir, "modeling.py"), "w") as f: + f.write(CUSTOM_MODEL_CODE) + + config = { + "_class_name": "CustomModel", + "_diffusers_version": "0.0.0", + "auto_map": {"AutoModel": "modeling.CustomModel"}, + "hidden_size": 8, + } + with open(os.path.join(model_dir, "config.json"), "w") as f: + json.dump(config, f) + + torch.save({}, os.path.join(model_dir, "diffusion_pytorch_model.bin")) + + model = AutoModel.from_pretrained(tmpdir, subfolder=subfolder, trust_remote_code=True) + assert model.__class__.__name__ == "CustomModel" + assert model.config["hidden_size"] == 8 + class TestAutoModelFromConfig(unittest.TestCase): @patch( From ae20345b295098fd79cafb98207cff452f2c9323 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 2 Mar 2026 22:13:25 +0530 
Subject: [PATCH 010/215] [AutoModel] Allow registering `auto_map` to model config (#13186) * update * update --- docs/source/en/using-diffusers/automodel.md | 27 +++++++++++ src/diffusers/configuration_utils.py | 38 ++++++++++++++++ tests/models/test_models_auto.py | 50 +++++++++++++++++++++ 3 files changed, 115 insertions(+) diff --git a/docs/source/en/using-diffusers/automodel.md b/docs/source/en/using-diffusers/automodel.md index d8cea79a10c9..82d4d14a10a9 100644 --- a/docs/source/en/using-diffusers/automodel.md +++ b/docs/source/en/using-diffusers/automodel.md @@ -97,5 +97,32 @@ If the custom model inherits from the [`ModelMixin`] class, it gets access to th > ) > ``` +### Saving custom models + +Use [`~ConfigMixin.register_for_auto_class`] to add the `auto_map` entry to `config.json` automatically when saving. This avoids having to manually edit the config file. + +```py +# my_model.py +from diffusers import ModelMixin, ConfigMixin + +class MyCustomModel(ModelMixin, ConfigMixin): + ... + +MyCustomModel.register_for_auto_class("AutoModel") + +model = MyCustomModel(...) +model.save_pretrained("./my_model") +``` + +The saved `config.json` will include the `auto_map` field. + +```json +{ + "auto_map": { + "AutoModel": "my_model.MyCustomModel" + } +} +``` + > [!NOTE] > Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide. 
\ No newline at end of file diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 2545c9db35b2..7a95ce20aaff 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -107,6 +107,38 @@ class ConfigMixin: has_compatibles = False _deprecated_kwargs = [] + _auto_class = None + + @classmethod + def register_for_auto_class(cls, auto_class="AutoModel"): + """ + Register this class with the given auto class so that it can be loaded with `AutoModel.from_pretrained(..., + trust_remote_code=True)`. + + When the config is saved, the resulting `config.json` will include an `auto_map` entry mapping the auto class + to this class's module and class name. + + Args: + auto_class (`str` or type, *optional*, defaults to `"AutoModel"`): + The auto class to register this class with. Can be a string (e.g. `"AutoModel"`) or the class itself. + Currently only `"AutoModel"` is supported. + + Example: + + ```python + from diffusers import ModelMixin, ConfigMixin + + + class MyCustomModel(ModelMixin, ConfigMixin): ... + + + MyCustomModel.register_for_auto_class("AutoModel") + ``` + """ + if auto_class != "AutoModel": + raise ValueError(f"Only 'AutoModel' is supported, got '{auto_class}'.") + + cls._auto_class = auto_class def register_to_config(self, **kwargs): if self.config_name is None: @@ -621,6 +653,12 @@ def to_json_saveable(value): # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. 
_ = config_dict.pop("_pre_quantization_dtype", None) + if getattr(self, "_auto_class", None) is not None: + module = self.__class__.__module__.split(".")[-1] + auto_map = config_dict.get("auto_map", {}) + auto_map[self._auto_class] = f"{module}.{self.__class__.__name__}" + config_dict["auto_map"] = auto_map + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" def to_json_file(self, json_file_path: str | os.PathLike): diff --git a/tests/models/test_models_auto.py b/tests/models/test_models_auto.py index 57a7760ea841..e35fb26518ef 100644 --- a/tests/models/test_models_auto.py +++ b/tests/models/test_models_auto.py @@ -7,7 +7,9 @@ import torch from transformers import CLIPTextModel, LongformerModel +from diffusers import ConfigMixin from diffusers.models import AutoModel, UNet2DConditionModel +from diffusers.models.modeling_utils import ModelMixin class TestAutoModel(unittest.TestCase): @@ -143,3 +145,51 @@ def test_from_config_with_model_type_routes_to_transformers(self, mock_get_class def test_from_config_raises_on_none(self): with self.assertRaises(ValueError, msg="Please provide a `pretrained_model_name_or_path_or_dict`"): AutoModel.from_config(None) + + +class TestRegisterForAutoClass(unittest.TestCase): + def test_register_for_auto_class_sets_attribute(self): + class DummyModel(ModelMixin, ConfigMixin): + config_name = "config.json" + + DummyModel.register_for_auto_class("AutoModel") + self.assertEqual(DummyModel._auto_class, "AutoModel") + + def test_register_for_auto_class_rejects_unsupported(self): + class DummyModel(ModelMixin, ConfigMixin): + config_name = "config.json" + + with self.assertRaises(ValueError, msg="Only 'AutoModel' is supported"): + DummyModel.register_for_auto_class("AutoPipeline") + + def test_auto_map_in_saved_config(self): + class DummyModel(ModelMixin, ConfigMixin): + config_name = "config.json" + + DummyModel.register_for_auto_class("AutoModel") + model = DummyModel() + + with tempfile.TemporaryDirectory() as tmpdir: + 
model.save_config(tmpdir) + config_path = os.path.join(tmpdir, "config.json") + with open(config_path, "r") as f: + config = json.load(f) + + self.assertIn("auto_map", config) + self.assertIn("AutoModel", config["auto_map"]) + module_name = DummyModel.__module__.split(".")[-1] + self.assertEqual(config["auto_map"]["AutoModel"], f"{module_name}.DummyModel") + + def test_no_auto_map_without_register(self): + class DummyModel(ModelMixin, ConfigMixin): + config_name = "config.json" + + model = DummyModel() + + with tempfile.TemporaryDirectory() as tmpdir: + model.save_config(tmpdir) + config_path = os.path.join(tmpdir, "config.json") + with open(config_path, "r") as f: + config = json.load(f) + + self.assertNotIn("auto_map", config) From eb231a6548039abef6acb0c428f889aaa9a95328 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 2 Mar 2026 22:20:42 +0530 Subject: [PATCH 011/215] [Modular] Save Modular Pipeline weights to Hub (#13168) * update * update * update * update * update * update --- .claude/CLAUDE.md | 100 ++++++ _modular_model_index.json | 75 +++++ custom_model_automodel_guide.md | 239 +++++++++++++++ example.py | 120 ++++++++ modular_model_index.json | 73 +++++ pr_review/12498.md | 56 ++++ pr_review/12744.md | 186 ++++++++++++ pr_review/13028.md | 99 ++++++ pr_review/13075.md | 97 ++++++ pr_review/13116.md | 66 ++++ pr_review/pr_12700_flashpack.md | 144 +++++++++ pr_review/teacache_pr_12652_review.md | 286 ++++++++++++++++++ release_notes/v0.37.0.md | 129 ++++++++ scripts/compare_test_coverage.py | 183 +++++++++++ .../modular_pipelines/modular_pipeline.py | 144 +++++++-- .../modular_pipeline_utils.py | 12 + test_automodel_meta.py | 14 + test_dataclass_config.py | 11 + test_pretrained_config.py | 29 ++ .../test_modular_pipelines_common.py | 76 +++++ 20 files changed, 2112 insertions(+), 27 deletions(-) create mode 100644 .claude/CLAUDE.md create mode 100644 _modular_model_index.json create mode 100644 custom_model_automodel_guide.md create mode 100644 
example.py create mode 100644 modular_model_index.json create mode 100644 pr_review/12498.md create mode 100644 pr_review/12744.md create mode 100644 pr_review/13028.md create mode 100644 pr_review/13075.md create mode 100644 pr_review/13116.md create mode 100644 pr_review/pr_12700_flashpack.md create mode 100644 pr_review/teacache_pr_12652_review.md create mode 100644 release_notes/v0.37.0.md create mode 100644 scripts/compare_test_coverage.py create mode 100644 test_automodel_meta.py create mode 100644 test_dataclass_config.py create mode 100644 test_pretrained_config.py diff --git a/.claude/CLAUDE.md b/.claude/CLAUDE.md new file mode 100644 index 000000000000..ae8010084af7 --- /dev/null +++ b/.claude/CLAUDE.md @@ -0,0 +1,100 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Build, Lint, and Test Commands + +```bash +# Install in development mode +pip install -e ".[dev]" + +# Run full test suite (requires beefy machine) +make test +# Or directly: +python -m pytest -n auto --dist=loadfile -s -v ./tests/ + +# Run a single test file +python -m pytest tests/<TEST_TO_RUN>.py + +# Run slow tests (downloads many GBs of models) +RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/ + +# Format code (ruff + doc-builder) +make style + +# Check code quality without modifying +make quality + +# Fast fixup for modified files only (recommended before commits) +make fixup + +# Fix copied code snippets and dummy objects +make fix-copies + +# Check repository consistency (dummies, inits, repo structure) +make repo-consistency +``` + +## Code Architecture + +Diffusers is built on three core component types that work together: + +### Pipelines (`src/diffusers/pipelines/`) +- End-to-end inference workflows combining models and schedulers +- Base class: `DiffusionPipeline` (in `pipeline_utils.py`) +- Follow **single-file policy**: each pipeline in its own directory +- Loaded via
`DiffusionPipeline.from_pretrained()` which reads `model_index.json` +- Components registered via `register_modules()` become pipeline attributes +- ~99 pipeline implementations (Stable Diffusion, SDXL, Flux, etc.) + +### Models (`src/diffusers/models/`) +- Configurable neural network architectures extending PyTorch's Module +- Base classes: `ModelMixin` + `ConfigMixin` (in `modeling_utils.py`) +- **Do NOT follow single-file policy**: use shared building blocks (`attention.py`, `embeddings.py`, `resnet.py`) +- Key subdirectories: + - `autoencoders/`: VAEs for latent space compression + - `unets/`: Diffusion model architectures (UNet2DConditionModel, etc.) + - `transformers/`: Transformer-based models (Flux, SD3, etc.) + - `controlnets/`: ControlNet variants + +### Schedulers (`src/diffusers/schedulers/`) +- Guide denoising process during inference +- Base class: `SchedulerMixin` + `ConfigMixin` (in `scheduling_utils.py`) +- Follow **single-file policy**: one scheduler per file +- Key methods: `set_num_inference_steps()`, `step()`, `timesteps` property +- Easily swappable via `ConfigMixin.from_config()` +- ~55 scheduler algorithms (DDPM, DDIM, Euler, DPM-Solver, etc.) + +### Supporting Systems + +- **Loaders** (`src/diffusers/loaders/`): Mixins for LoRA, IP-Adapter, textual inversion, single-file loading +- **Quantizers** (`src/diffusers/quantizers/`): BitsAndBytes, GGUF, TorchAO, Quanto support +- **Hooks** (`src/diffusers/hooks/`): Runtime optimizations (offloading, layer skipping, caching) +- **Guiders** (`src/diffusers/guiders/`): Guidance algorithms (CFG, PAG, etc.) + +## Configuration System + +All components use `ConfigMixin` for serialization: +- Constructor arguments stored via `register_to_config(**kwargs)` +- Instantiate from config: `Component.from_config(config_dict)` +- Save/load as JSON files + +## Key Design Principles + +1. **Usability over Performance**: Models load at float32/CPU by default +2. 
**Simple over Easy**: Explicit > implicit; expose complexity rather than hide it +3. **Single-file policy**: Pipelines and schedulers are self-contained; models share building blocks +4. **Copy-paste over abstraction**: Prefer duplicated code over hasty abstractions for contributor-friendliness + +## Code Style + +- Uses `ruff` for linting and formatting (line length: 119) +- Documentation follows [Google style](https://google.github.io/styleguide/pyguide.html) +- Use `# Copied from` mechanism for sharing code between similar files +- Avoid lambda functions and advanced PyTorch operators for readability + +## Testing + +- Tests use `pytest` with `pytest-xdist` for parallelization +- Slow tests gated by `RUN_SLOW=yes` environment variable +- Test dependencies: `pip install -e ".[test]"` diff --git a/_modular_model_index.json b/_modular_model_index.json new file mode 100644 index 000000000000..b0eba6916d3d --- /dev/null +++ b/_modular_model_index.json @@ -0,0 +1,75 @@ +{ + "_blocks_class_name": "SequentialPipelineBlocks", + "_class_name": "Flux2ModularPipeline", + "_diffusers_version": "0.36.0.dev0", + "scheduler": [ + "diffusers", + "FlowMatchEulerDiscreteScheduler", + { + "repo": "hf-internal-testing/tiny-flux2", + "revision": null, + "subfolder": "scheduler", + "type_hint": [ + "diffusers", + "FlowMatchEulerDiscreteScheduler" + ], + "variant": null + } + ], + "text_encoder": [ + "transformers", + "Mistral3ForConditionalGeneration", + { + "repo": "hf-internal-testing/tiny-flux2", + "revision": null, + "subfolder": "text_encoder", + "type_hint": [ + "transformers", + "Mistral3ForConditionalGeneration" + ], + "variant": null + } + ], + "tokenizer": [ + "transformers", + "AutoProcessor", + { + "repo": "hf-internal-testing/Mistral-Small-3.1-24B-Instruct-2503-only-processor", + "revision": null, + "subfolder": "", + "type_hint": [ + "transformers", + "AutoProcessor" + ], + "variant": null + } + ], + "transformer": [ + "diffusers", + "Flux2Transformer2DModel", + { +
"repo": "hf-internal-testing/tiny-flux2", + "revision": null, + "subfolder": "transformer", + "type_hint": [ + "diffusers", + "Flux2Transformer2DModel" + ], + "variant": null + } + ], + "vae": [ + "diffusers", + "AutoencoderKLFlux2", + { + "repo": "hf-internal-testing/tiny-flux2", + "revision": null, + "subfolder": "vae", + "type_hint": [ + "diffusers", + "AutoencoderKLFlux2" + ], + "variant": null + } + ] +} diff --git a/custom_model_automodel_guide.md b/custom_model_automodel_guide.md new file mode 100644 index 000000000000..66343023e644 --- /dev/null +++ b/custom_model_automodel_guide.md @@ -0,0 +1,239 @@ +# Loading Custom Models with `AutoModel` and `trust_remote_code` + +This guide shows how to create a custom model class that lives outside the `diffusers` library and load it via `AutoModel` with `trust_remote_code=True`. + +## How It Works + +When `AutoModel.from_pretrained()` (or `from_config()`) is called with `trust_remote_code=True`, it: + +1. Loads the `config.json` from the model repository. +2. Checks for an `"auto_map"` key in the config that maps `"AutoModel"` to a `"<module_file>.<ClassName>"` reference. +3. Downloads the referenced Python module from the repository. +4. Dynamically imports and instantiates the class from that module. + +This allows anyone to define and share completely custom model architectures without requiring changes to the `diffusers` library itself. + +## Step 1: Define Your Custom Model + +Create a Python file (e.g., `modeling_my_model.py`) that defines your model class. The class must inherit from `ModelMixin` and `ConfigMixin`, and use the `@register_to_config` decorator on `__init__`.
+ +```python +# modeling_my_model.py + +import torch +from torch import nn +from diffusers import ModelMixin, ConfigMixin +from diffusers.configuration_utils import register_to_config + + +class MyCustomModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__(self, in_channels: int = 3, hidden_dim: int = 64, out_channels: int = 3): + super().__init__() + self.net = nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1), + nn.SiLU(), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), + nn.SiLU(), + nn.Conv2d(hidden_dim, out_channels, kernel_size=3, padding=1), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) +``` + +Key requirements: + +- **`ModelMixin`** provides `save_pretrained()` / `from_pretrained()` for weight serialization. +- **`ConfigMixin`** provides `save_config()` / `from_config()` and the `config.json` machinery. +- **`@register_to_config`** automatically captures all `__init__` parameters into `config.json` so the model can be reconstructed from config alone. + +## Step 2: Save the Model Locally + +```python +from modeling_my_model import MyCustomModel + +model = MyCustomModel(in_channels=3, hidden_dim=128, out_channels=3) +model.save_pretrained("./my-custom-model") +``` + +This creates a directory with: + +``` +my-custom-model/ +├── config.json +└── diffusion_pytorch_model.safetensors +``` + +The generated `config.json` will look like: + +```json +{ + "_class_name": "MyCustomModel", + "_diffusers_version": "0.32.0", + "in_channels": 3, + "hidden_dim": 128, + "out_channels": 3 +} +``` + +## Step 3: Add the `auto_map` and Model File to the Repository + +To make `AutoModel` aware of your custom class, you need to: + +1. **Copy `modeling_my_model.py` into the saved model directory.** +2. **Add an `"auto_map"` entry to `config.json`** that points `AutoModel` to your class. 
+ +The `auto_map` value format is `"<module_file>.<ClassName>"`: + +```json +{ + "_class_name": "MyCustomModel", + "_diffusers_version": "0.32.0", + "in_channels": 3, + "hidden_dim": 128, + "out_channels": 3, + "auto_map": { + "AutoModel": "modeling_my_model.MyCustomModel" + } +} +``` + +Your final directory structure should be: + +``` +my-custom-model/ +├── config.json # with auto_map added +├── diffusion_pytorch_model.safetensors +└── modeling_my_model.py # your custom model code +``` + +## Step 4: Load with `AutoModel` + +### From a Local Directory + +```python +from diffusers import AutoModel + +model = AutoModel.from_pretrained("./my-custom-model", trust_remote_code=True) +print(model) +``` + +### From the Hugging Face Hub + +First, push the model directory to a Hub repository: + +```python +from huggingface_hub import HfApi + +api = HfApi() +api.create_repo("your-username/my-custom-model", exist_ok=True) +api.upload_folder( + folder_path="./my-custom-model", + repo_id="your-username/my-custom-model", +) +``` + +Then load it: + +```python +from diffusers import AutoModel + +model = AutoModel.from_pretrained( + "your-username/my-custom-model", + trust_remote_code=True, +) +``` + +### Initializing from Config (Random Weights) + +```python +from diffusers import AutoModel + +model = AutoModel.from_config("./my-custom-model", trust_remote_code=True) +``` + +## Complete Example + +```python +import torch +from torch import nn +from diffusers import ModelMixin, ConfigMixin, AutoModel +from diffusers.configuration_utils import register_to_config + + +# 1.
Define +class MyCustomModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__(self, in_channels: int = 3, hidden_dim: int = 64, out_channels: int = 3): + super().__init__() + self.net = nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1), + nn.SiLU(), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), + nn.SiLU(), + nn.Conv2d(hidden_dim, out_channels, kernel_size=3, padding=1), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +# 2. Save +model = MyCustomModel(in_channels=3, hidden_dim=128, out_channels=3) +model.save_pretrained("./my-custom-model") + +# 3. Manually add auto_map to config.json and copy modeling file +import json, shutil + +config_path = "./my-custom-model/config.json" +with open(config_path) as f: + config = json.load(f) + +config["auto_map"] = {"AutoModel": "modeling_my_model.MyCustomModel"} + +with open(config_path, "w") as f: + json.dump(config, f, indent=2) + +shutil.copy("modeling_my_model.py", "./my-custom-model/modeling_my_model.py") + +# 4. Load via AutoModel +loaded_model = AutoModel.from_pretrained("./my-custom-model", trust_remote_code=True) + +# 5. Verify +x = torch.randn(1, 3, 32, 32) +with torch.no_grad(): + out_original = model(x) + out_loaded = loaded_model(x) + +assert torch.allclose(out_original, out_loaded) +print("Models produce identical outputs!") +``` + +## Using Relative Imports in Custom Code + +If your custom model depends on additional modules, you can use relative imports. For example, if your model uses a custom attention layer defined in a separate file: + +``` +my-custom-model/ +├── config.json +├── diffusion_pytorch_model.safetensors +├── modeling_my_model.py # imports from .my_attention +└── my_attention.py # custom attention implementation +``` + +In `modeling_my_model.py`: + +```python +from .my_attention import MyAttention +``` + +The dynamic module loader will automatically resolve and download all relatively imported files. 
+ +## Security Note + +`trust_remote_code=True` executes arbitrary Python code from the model repository. Only use it with repositories you trust. You can globally disable remote code execution by setting the environment variable: + +```bash +export DIFFUSERS_DISABLE_REMOTE_CODE=1 +``` diff --git a/example.py b/example.py new file mode 100644 index 000000000000..bb0a5b430e3a --- /dev/null +++ b/example.py @@ -0,0 +1,120 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch + +from diffusers import QwenImageTransformer2DModel +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import LoraHotSwappingForModelTesterMixin +from ..testing_utils import ( + AttentionTesterMixin, + ContextParallelTesterMixin, + LoraTesterMixin, + MemoryTesterMixin, + ModelTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) + + +enable_full_determinism() + + +class QwenImageTransformerTesterConfig: + model_class = QwenImageTransformer2DModel + pretrained_model_name_or_path = "" + pretrained_model_kwargs = {"subfolder": "transformer"} + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict[str, int | list[int]]: + # __init__ parameters: + # patch_size: int = 2 + # in_channels: int = 64 + # out_channels: Optional[int] = 16 + # num_layers: int = 60 + # attention_head_dim: int = 128 + # num_attention_heads: int = 24 + # joint_attention_dim: int = 3584 + # guidance_embeds: bool = False + # axes_dims_rope: Tuple[int, int, int] = + return {} + + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + # forward() parameters: + # hidden_states: torch.Tensor + # encoder_hidden_states: torch.Tensor + # encoder_hidden_states_mask: torch.Tensor + # timestep: torch.LongTensor + # img_shapes: Optional[List[Tuple[int, int, int]]] + # txt_seq_lens: Optional[List[int]] + # guidance: torch.Tensor + # attention_kwargs: Optional[Dict[str, Any]] + # controlnet_block_samples + # return_dict: bool = True + # TODO: Fill in dummy inputs + return {} + + @property + def input_shape(self) -> tuple[int, ...]: + return (1, 1) + + @property + def output_shape(self) -> tuple[int, ...]: + return (1, 1) + + +class TestQwenImageTransformerModel(QwenImageTransformerTesterConfig, ModelTesterMixin): + pass + + +class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin): + 
pass + + +class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin): + pass + + +class TestQwenImageTransformerTorchCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin): + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: + # TODO: Implement dynamic input generation + return {} + + +class TestQwenImageTransformerLora(QwenImageTransformerTesterConfig, LoraTesterMixin): + pass + + +class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin): + pass + + +class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin): + pass + + +class TestQwenImageTransformerLoraHotSwappingForModel(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin): + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: + # TODO: Implement dynamic input generation + return {} diff --git a/modular_model_index.json b/modular_model_index.json new file mode 100644 index 000000000000..48872b1c98d7 --- /dev/null +++ b/modular_model_index.json @@ -0,0 +1,73 @@ +{ + "_blocks_class_name": "SequentialPipelineBlocks", + "_class_name": "Flux2ModularPipeline", + "_diffusers_version": "0.36.0.dev0", + "scheduler": [ + null, + null, + { + "pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev", + "revision": null, + "subfolder": "scheduler", + "type_hint": [ + "diffusers", + "FlowMatchEulerDiscreteScheduler" + ], + "variant": null + } + ], + "text_encoder": [ + null, + null, + { + "revision": null, + "subfolder": "text_encoder", + "type_hint": [ + "transformers", + "Mistral3ForConditionalGeneration" + ], + "variant": null + } + ], + "tokenizer": [ + null, + null, + { + "revision": null, + "subfolder": "tokenizer", + "type_hint": [ + "transformers", + 
"AutoProcessor" + ], + "variant": null + } + ], + "transformer": [ + null, + null, + { + "pretrained_model_name_or_path": "diffusers/FLUX.2-dev-bnb-4bit", + "revision": null, + "subfolder": "transformer", + "type_hint": [ + "diffusers", + "Flux2Transformer2DModel" + ], + "variant": null + } + ], + "vae": [ + null, + null, + { + "pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev", + "revision": null, + "subfolder": "vae", + "type_hint": [ + "diffusers", + "AutoencoderKLFlux2" + ], + "variant": null + } + ] +} diff --git a/pr_review/12498.md b/pr_review/12498.md new file mode 100644 index 000000000000..1f52d1f7acc9 --- /dev/null +++ b/pr_review/12498.md @@ -0,0 +1,56 @@ +Now let me look at what `dequantize_gguf_tensor` does to understand the fix better: + +## Code Review: GGUF fix for unquantized types + +### 1. Summary of Changes + +The PR fixes a bug in the `_fused_mul_mat_gguf` function (line 79-105) where unquantized GGUF tensor types (F32, F16, BF16) were incorrectly handled. + +**Before:** When `qweight_type` was an unquantized type, the code directly performed matrix multiplication: `x @ qweight.T` + +**After:** It now calls `dequantize_gguf_tensor(qweight)` first, then performs the matrix multiplication: `x @ weight.T` + +The issue was that even "unquantized" GGUF tensors are stored in an 8-bit tensor format and need to be converted to their proper data type representation before use. + +### 2. Potential Issues or Bugs + +**None identified.** The fix is correct and addresses a real bug: + +- The `dequantize_gguf_tensor` function (lines 509-527) checks if the tensor has a `quant_type` attribute and handles the appropriate conversion +- For BF16 specifically, there's a dedicated `dequantize_blocks_BF16` function (lines 428-429) that properly converts the 8-bit storage format +- The fix aligns with how the native path already works in `forward_native` (lines 593-599), which always calls `dequantize_gguf_tensor` + +### 3. 
Code Quality Observations + +**Strengths:** +- The fix is minimal and surgical - only changes what's necessary +- Maintains consistency with the `forward_native` path which already uses `dequantize_gguf_tensor` +- The variable naming (`weight` instead of reusing `qweight`) makes it clear a transformation occurred + +**Minor observation:** +- The comment on line 80 "there is no need to call any kernel for fp16/bf16" is now slightly misleading since we DO need to call dequantization logic. Consider updating it to something like: "no need to call specialized GGUF kernel for fp16/bf16, but still need to dequantize from 8-bit storage" + +### 4. Security Considerations + +**No security concerns.** The change: +- Doesn't introduce any external input handling +- Doesn't modify control flow in a way that could bypass security checks +- Only fixes a data type conversion issue + +### 5. Suggestions for Improvement + +1. **Update the comment** on line 80 in `src/diffusers/quantizers/gguf/utils.py:80`: + ```python + # unquantized types still need dequantization from 8-bit storage, but don't need specialized kernels + if qweight_type in UNQUANTIZED_TYPES: + weight = dequantize_gguf_tensor(qweight) + return x @ weight.T + ``` + +2. **Consider adding a test** to prevent regression of this issue. A test should verify that unquantized GGUF tensors produce correct output shapes and values. + +3. **Documentation:** The PR description mentions torch 2.8/2.9 build availability. This might be worth tracking in a GitHub issue if not already done. + +### Verdict + +**Approve with minor comment update suggestion.** The fix correctly addresses a real shape mismatch bug where GGUF's 8-bit storage format wasn't being properly converted for unquantized types. The logic is sound and aligns with the existing native implementation path. 
diff --git a/pr_review/12744.md b/pr_review/12744.md new file mode 100644 index 000000000000..30e2ac1a71b6 --- /dev/null +++ b/pr_review/12744.md @@ -0,0 +1,186 @@ +I'll provide a comprehensive code review of this MagCache PR. + +## Summary of Changes + +This PR implements MagCache (Magnitude-aware Cache), a training-free inference acceleration technique for diffusion transformers. The implementation: + +- Adds a `MagCacheConfig` class for configuration +- Implements `MagCacheHeadHook` and `MagCacheBlockHook` following the existing ModelHook pattern +- Includes calibration mode to compute magnitude ratios for any transformer model +- Provides pre-computed `FLUX_MAG_RATIOS` for Flux models +- Adds comprehensive documentation and tests + +## Potential Issues and Bugs + +### 1. **Critical: Missing Hook Removal in `disable_cache()`** +```python +# In cache_utils.py, line ~127 +elif isinstance(self._cache_config, MagCacheConfig): + registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, recurse=True) +``` + +**Issue**: The code only removes the leader/head hook but not the block hooks (`_MAG_CACHE_BLOCK_HOOK`). This will leave hooks attached when disabling the cache. + +**Fix**: Add removal of block hooks: +```python +elif isinstance(self._cache_config, MagCacheConfig): + registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, recurse=True) + registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True) +``` + +### 2. **Shape Mismatch Handling Logic Issue** +In `mag_cache.py` lines 224-248, the shape mismatch handling has a potential issue: + +```python +elif ( + output.ndim == 3 + and res.ndim == 3 + and output.shape[0] == res.shape[0] + and output.shape[2] == res.shape[2] +): + diff = output.shape[1] - res.shape[1] + if diff > 0: + output = output.clone() + output[:, diff:, :] = output[:, diff:, :] + res +``` + +**Issue**: This assumes text tokens come first and image tokens come last. This may not be universal across all models (e.g., some models interleave tokens differently). 
+ +**Suggestion**: Add a comment explaining this assumption or add configuration to specify the concatenation strategy. + +### 3. **Residual Calculation Fallback is Unsafe** +In `mag_cache.py` line 343: + +```python +else: + # Fallback for completely mismatched shapes + residual = out_hidden +``` + +**Issue**: This fallback doesn't compute a residual at all—it just uses the output. This will cause incorrect behavior in subsequent steps. + +**Suggestion**: Either raise an error or add a warning that calibration is required for this model architecture. + +### 4. **Device Mismatch Handling is Incomplete** +```python +if res.device != output.device: + res = res.to(output.device) +``` + +**Issue**: This only handles device mismatch for the residual, but doesn't handle dtype mismatches which could occur with mixed precision training. + +**Suggestion**: Add dtype handling: +```python +if res.device != output.device or res.dtype != output.dtype: + res = res.to(device=output.device, dtype=output.dtype) +``` + +### 5. **Calibration Logging Could Be Missed** +The calibration results are printed to stdout (line 380) and logged. However, if the user has logging disabled or redirected, they might miss this critical information. + +**Suggestion**: Consider returning calibration results from the pipeline or raising a more visible notification. + +### 6. **Test Suite is Skipped** +```python +@unittest.skip("MagCache unit tests are skipped.") +class MagCacheTests(unittest.TestCase): +``` + +**Issue**: All unit tests are skipped, which means the core logic isn't being validated in CI. + +**Action Required**: Remove the skip decorator before merging or add a comment explaining why it's temporarily skipped. + +## Code Quality Observations + +### Strengths: +1. **Well-structured**: Follows existing patterns (ModelHook, StateManager) consistently +2. **Good documentation**: Comprehensive docstrings and inline comments +3. 
**Calibration mode**: Clever design allowing model-agnostic usage +4. **Error handling**: Validates configuration upfront +5. **Interpolation logic**: Smart handling of different step counts via `nearest_interp()` + +### Areas for Improvement: + +1. **Magic Numbers**: Several hardcoded values could be constants: + ```python + eps = 1e-8 # Line 335 in _perform_calibration_step + expected_atol = 0.1 # Line 2989 in test + ``` + +2. **Code Duplication**: The logic for handling tuple returns appears multiple times. Consider extracting to a helper method. + +3. **Type Hints**: Some methods lack return type hints (e.g., `nearest_interp`) + +4. **Compiler Disable Decorator**: The `@torch.compiler.disable` decorator is used but not explained. Add a comment about why compilation is disabled. + +## Security Considerations + +### Low Risk: +- No external network calls +- No file system access beyond logging +- No execution of arbitrary code +- Tensor operations are standard PyTorch + +### Observations: +1. **Device Transfer**: The `.to(device)` calls are safe but could consume unexpected memory if tensors are large +2. **State Management**: The state is properly isolated and reset between inference runs + +## Suggestions for Improvement + +### 1. Add Configuration Validation +```python +def __post_init__(self): + # Existing checks... + + # Add bounds checking + if not 0.0 <= self.retention_ratio <= 1.0: + raise ValueError(f"retention_ratio must be in [0, 1], got {self.retention_ratio}") + if self.max_skip_steps < 1: + raise ValueError(f"max_skip_steps must be >= 1, got {self.max_skip_steps}") + if self.threshold <= 0: + raise ValueError(f"threshold must be positive, got {self.threshold}") +``` + +### 2. Add Metrics/Statistics +Consider adding optional statistics collection: +- How many blocks were skipped per step +- Average accumulated error +- Total compute savings + +This would help users optimize their thresholds. + +### 3. 
Improve Documentation Example +The documentation example could show expected speedup or quality metrics to set user expectations. + +### 4. Add Gradient Mode Check +```python +if torch.is_grad_enabled(): + logger.warning("MagCache is designed for inference only. Gradients are enabled but will not flow correctly through cached blocks.") +``` + +### 5. Consider Memory Cleanup +The `previous_residual` is held in state indefinitely. Consider adding explicit cleanup: +```python +def cleanup(self): + if self.previous_residual is not None: + del self.previous_residual + self.previous_residual = None +``` + +## Minor Issues + +1. **Line 26**: Unused import or should be used in logger initialization +2. **Line 332**: Comment says "Fallback to matching tail" but logic is unclear +3. **Documentation**: The TIP about batched CFG could include more detail about why this works + +## Conclusion + +This is a **well-implemented feature** with good design patterns and documentation. The main concerns are: + +1. **Critical**: Fix the missing block hook removal in `disable_cache()` (Line 127) +2. **Important**: Unskip and fix the unit tests +3. **Recommended**: Improve shape mismatch handling with better error messages + +The implementation is production-ready once these issues are addressed. The calibration mode is particularly clever and makes this genuinely model-agnostic. + +**Recommendation**: Request changes for items #1 and #2, then approve once fixed. 
diff --git a/pr_review/13028.md b/pr_review/13028.md new file mode 100644 index 000000000000..7988498aecf1 --- /dev/null +++ b/pr_review/13028.md @@ -0,0 +1,99 @@ +# PR #13028: [Modular] add explicit workflow support + +**Author:** @yiyixuxu +**Branch:** `modular-workflow` -> `main` +**Files changed:** `modular_pipeline.py`, `modular_pipeline_utils.py`, `qwenimage/modular_blocks_qwenimage.py` +**+298 / -165** + +--- + +## Summary + +This PR adds a `_workflow_map` class attribute to `SequentialPipelineBlocks` that maps named workflows (e.g., `"text2image"`, `"inpainting"`) to their trigger inputs. Users can then call `get_workflow("text2image")` to get the execution blocks for that workflow. The PR also refactors `get_execution_blocks` into `ConditionalPipelineBlocks` and `SequentialPipelineBlocks`, moves `combine_inputs`/`combine_outputs` to module-level functions, and improves docstrings. + +## Main Concern: "Workflow" as a New Concept + +Modular Diffusers already requires users to learn: **Pipelines**, **Blocks** (Sequential, Conditional, Auto, Loop), **Steps**, **Components**, **Inputs/Outputs**, **Trigger Inputs**, **Execution Blocks**, **PipelineState**, and **BlockState**. Adding "workflow" as yet another term increases cognitive overhead. + +The underlying feature is useful — named presets for trigger inputs are genuinely helpful for discoverability. But "workflow" may not be the right label: + +1. **Overloaded term**: "Workflow" is heavily used in the AI/ML ecosystem (ComfyUI workflows, orchestration workflows, CI/CD workflows). Users may expect something more complex than what this is. + +2. **It's really a task/mode, not a workflow**: `"text2image"`, `"inpainting"`, `"image2image"` are *tasks* or *modes*. The rest of diffusers already uses "task" terminology — `AutoPipelineForText2Image`, `AutoPipelineForInpainting`, etc. Calling the same concept "workflow" in Modular Diffusers creates inconsistency. + +3. 
**It's a thin wrapper**: `get_workflow("text2image")` is just `get_execution_blocks(prompt=True)`. Users still need to understand `get_execution_blocks` and trigger inputs to do anything beyond the predefined workflows. The abstraction doesn't save much complexity. + +**Suggestion**: Consider `_task_map` / `get_task()` / `task_names` to align with existing diffusers terminology, or `_mode_map` / `get_mode()` / `mode_names` for something more neutral. The existing `auto_pipeline.py` already uses "task" internally — `_get_task_class()` maps pipeline class names to task-specific variants (text2image, image2image, inpainting), and the public API follows the `AutoPipelineFor` naming pattern. These are the exact same concepts this PR calls "workflows." Alternatively, this could simply be better documentation on `get_execution_blocks` with named examples, rather than a new API surface. + +## Code Issues + +### Behavioral change: `outputs` -> `intermediate_outputs` in traversal + +`modular_pipeline.py` — In `SequentialPipelineBlocks.get_execution_blocks`, the old `_traverse_trigger_blocks` tracked `block.outputs` to propagate available values to downstream blocks. The new code tracks `block.intermediate_outputs` instead: + +```python +# Old +if hasattr(block, "outputs"): + for out in block.outputs: + active_inputs[out.name] = True + +# New +if hasattr(block, "intermediate_outputs"): + for out in block.intermediate_outputs: + active_inputs[out.name] = True +``` + +`intermediate_outputs` and `outputs` can differ — `intermediate_outputs` includes values passed between blocks in the pipeline state, while `outputs` are the final outputs. This could change which downstream conditional blocks get triggered. If this is intentional, it should be called out explicitly in the PR description since it affects existing behavior. 
+ +### `_workflow_map` on base class, implementations only on `SequentialPipelineBlocks` + +`_workflow_map = None` is defined on `ModularPipelineBlocks` (the base class), but `workflow_names` and `get_workflow()` are only implemented on `SequentialPipelineBlocks`. The base class stubs raise `NotImplementedError`. This is misleading — it suggests workflows *could* be implemented for other block types. If workflows are intentionally only for `SequentialPipelineBlocks`, define `_workflow_map` there and don't add stubs to the base class. + +### `get_execution_blocks` no longer filters None values + +Old code: +```python +active_inputs = {k: v for k, v in kwargs.items() if v is not None} +``` + +New code: +```python +active_inputs = dict(kwargs) +``` + +This is a behavioral change to the public `get_execution_blocks` API. The old code explicitly stripped `None` values so users could write `get_execution_blocks(prompt="a cat", image=None)` and `image` wouldn't trigger anything. The new code passes `None` through. It happens to still work because `select_block` checks `is not None` internally, but callers can no longer rely on the documented filtering behavior. This should be noted. + +### `default_block_name` changed from property to instance attribute + +In `AutoPipelineBlocks`, `default_block_name` was a `@property` that derived the default from `block_trigger_inputs` on every access. It's now set as an instance attribute in `__init__`. This is mostly fine, but the new code also adds a validation that raises an error if `default_block_name` is already not `None` before `__init__` sets it — so subclasses that accidentally set `default_block_name` as a class attribute will now break. This is a stricter contract that should be documented. + +### Typo + +`modular_pipeline.py` — `# currentlyonly ConditionalPipelineBlocks` should be `# currently only`. 
+ +### `_get_trigger_inputs()` called multiple times in `__repr__` + +In `SequentialPipelineBlocks.__repr__`, `self._get_trigger_inputs()` is called 3 times (condition check, trigger inputs display, example input). This recursively traverses all blocks each time. Should be computed once and reused. + +### Duplicate `format_workflow` calls in `__repr__` and `doc` + +Both `SequentialPipelineBlocks.__repr__` and `SequentialPipelineBlocks.doc` build the description + workflow string independently with identical logic: + +```python +description = self.description +if self._workflow_map is not None: + workflow_str = format_workflow(self._workflow_map) + description = f"{self.description}\n\n{workflow_str}" +``` + +This should be extracted into a property or helper. + +### No tests + +The PR description mentions "I will add a test suite for this too!" but there are no tests included. Workflow resolution, edge cases (empty workflow map, missing workflow name, workflows with overlapping triggers), and the `get_execution_blocks` refactoring should all be tested before merge. + +## Refactoring Quality + +The refactoring of `get_execution_blocks` from a monolithic method on `SequentialPipelineBlocks` into separate implementations on `ConditionalPipelineBlocks` and `SequentialPipelineBlocks` is a good separation of concerns. Moving `combine_inputs`/`combine_outputs` to module-level functions is also reasonable since they don't depend on instance state. + +The improved `AutoPipelineBlocks` docstring with the example is a significant documentation improvement. diff --git a/pr_review/13075.md b/pr_review/13075.md new file mode 100644 index 000000000000..23b042b0cbfa --- /dev/null +++ b/pr_review/13075.md @@ -0,0 +1,97 @@ +I'll review this PR that addresses PyTorch version compatibility for distributed operations. 
+ +## Summary of Changes + +The PR refactors the `gather_size_by_comm` function in `_modeling_parallel.py` to handle PyTorch versions prior to 2.6 that don't have the `torch.accelerator` API. The changes replace a single ternary expression with a multi-level conditional that: + +1. First checks if "cpu" is in the backend string +2. Then checks if `torch.accelerator` exists (PyTorch >= 2.6) +3. Falls back to CUDA as a default device + +## Potential Issues or Bugs + +**1. Device Type Inconsistency** +The original code returns a string `"cpu"` but the new code returns `torch.device("cuda")` objects. This inconsistency could cause issues: + +```python +gather_device = "cpu" # str +# vs +gather_device = torch.device("cuda") # torch.device object +``` + +**Recommendation:** Use `torch.device()` consistently: +```python +if "cpu" in comm_backends: + gather_device = torch.device("cpu") +elif hasattr(torch, "accelerator"): + acc = torch.accelerator.current_accelerator() + gather_device = torch.device(acc if acc is not None else "cuda") +else: + gather_device = torch.device("cuda") +``` + +**2. Unclear Accelerator Return Behavior** +The comment states "Fall back to CUDA when no accelerator is returned" but it's unclear when `torch.accelerator.current_accelerator()` would return `None`. This should be verified or documented. + +**3. Missing Type Information** +What type does `torch.accelerator.current_accelerator()` return? If it returns a string like `"cuda"` or `"mps"`, the code should handle it consistently. If it returns a device object, the logic might need adjustment. + +## Code Quality Observations + +**Positive:** +- Clear comments explaining the fallback logic +- Proper use of `hasattr()` for backward compatibility +- Addresses the reported issue #13074 + +**Areas for Improvement:** + +1. **Device type consistency** (mentioned above) + +2. 
**Consider alternative hardware accelerators:** The fallback to CUDA might not be appropriate for all systems (e.g., MPS on macOS, XPU on Intel). Consider: + ```python + else: + # Fallback for PyTorch < 2.6 + if torch.cuda.is_available(): + gather_device = torch.device("cuda") + else: + gather_device = torch.device("cpu") + ``` + +3. **Code style:** The expanded conditional is more readable but could benefit from extracting into a helper function if this pattern appears elsewhere: + ```python + def _get_gather_device(comm_backends: str) -> torch.device: + """Determine device for distributed gather operations.""" + # ... implementation + ``` + +## Security Considerations + +No significant security issues identified. This is primarily a compatibility fix for internal device selection logic. + +## Suggestions for Improvement + +1. **Add a test case** to verify behavior on PyTorch < 2.6 (if not already covered) + +2. **Document the behavior** more explicitly: + ```python + # Determine gather device based on backend and PyTorch version + # Priority: CPU backend > torch.accelerator (>= 2.6) > CUDA fallback (< 2.6) + ``` + +3. **Consider this more defensive approach:** + ```python + if "cpu" in comm_backends: + gather_device = torch.device("cpu") + elif hasattr(torch, "accelerator"): + acc = torch.accelerator.current_accelerator() + gather_device = torch.device(acc if acc else "cuda") + elif torch.cuda.is_available(): + gather_device = torch.device("cuda") + else: + # Fallback to CPU if no GPU available + gather_device = torch.device("cpu") + ``` + +## Verdict + +The PR addresses the compatibility issue but has a **type inconsistency bug** that should be fixed before merging. The string vs `torch.device` object mismatch could cause runtime errors. Once that's addressed, the change is sound for backward compatibility. 
diff --git a/pr_review/13116.md b/pr_review/13116.md new file mode 100644 index 000000000000..664550cc45c5 --- /dev/null +++ b/pr_review/13116.md @@ -0,0 +1,66 @@ +# PR #13116: [tests] tests for `modules_to_not_convert` + +**Author:** @sayakpaul +**Branch:** `fix-modules-no-convert-torchao` -> `main` +**Files changed:** `tests/models/testing_utils/quantization.py`, `tests/models/transformers/test_models_transformer_flux.py` + +--- + +## Summary + +This PR fixes the `modules_to_not_convert` tests that were effectively dead code. They existed in the base `QuantizationTesterMixin` but never ran because no test class defined `modules_to_not_convert_for_test`. The PR activates these tests for Flux and fixes several underlying bugs that would have caused them to fail. + +## Key Changes + +1. **BnB config key fix**: `BitsAndBytesConfig` uses `llm_int8_skip_modules`, not `modules_to_not_convert`. The base test was setting the wrong key, so modules were never actually excluded. + +2. **TorchAO `_verify_if_layer_quantized` fix**: Previously only checked `isinstance(module, torch.nn.Linear)`, which is always true for TorchAO (it doesn't replace the module class). Now properly checks weight tensor types (`AffineQuantizedTensor`, `LinearActivationQuantizedTensor`). + +3. **`_is_module_quantized` fix**: Now passes `quant_config_kwargs` to `_verify_if_layer_quantized`. Previously it passed `{}`, which caused BnB to always check for `Int8Params` even on 4-bit models. + +4. **Cleanup**: Removes unused guard blocks (`is_gguf_available`, `is_torchao_available`) that only contained `pass`. + +5. **Activates tests**: Adds `modules_to_not_convert_for_test` returning `["norm_out.linear"]` to BnB, Quanto, TorchAo, and ModelOpt Flux test classes. 
+ +## Issues + +### `to_not_convert_key` parameter pollutes the base class interface + +`quantization.py:271-273` — The new `to_not_convert_key` parameter on `_test_quantization_modules_to_not_convert` exists solely for BnB's naming quirk (`llm_int8_skip_modules` vs `modules_to_not_convert`). Every other backend uses the default. This leaks a BnB-specific detail into the shared base method. + +BnB already has its own `test_bnb_modules_to_not_convert` that could handle the key translation locally — either by building the correct `config_kwargs` with `llm_int8_skip_modules` before calling `_create_quantized_model` directly, or by overriding the test. This keeps the base method clean and isolates BnB's naming quirk in `BitsAndBytesTesterMixin` where it belongs. + +### Code duplication in TorchAO `test_torchao_modules_to_not_convert` + +`quantization.py:915-950` — The TorchAO test inlines ~30 lines from `_test_quantization_modules_to_not_convert` to skip the memory footprint comparison. If the base method is updated in the future, this copy won't get the fix. Consider parameterizing the base method instead: + +```python +def _test_quantization_modules_to_not_convert( + self, config_kwargs, modules_to_not_convert, check_memory_footprint=True, +): + # ... existing module-walking logic ... + + if check_memory_footprint: + # Compare memory footprint with fully quantized model + ... +``` + +Then TorchAO could simply call: +```python +self._test_quantization_modules_to_not_convert( + TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude, + check_memory_footprint=False, +) +``` + +### TorchAO imports inside method body + +`quantization.py:822-823` — The `torchao` imports are placed inside `_verify_if_layer_quantized`. While functional (avoids import errors when torchao isn't installed), these could be placed at module level under the existing `is_torchao_available()` guard for consistency with how `bnb` and `QLinear` imports are handled. Minor style point. 
+ +### `_is_module_quantized` callers not updated + +`quantization.py:368` — The `_test_dequantize` method still calls `self._is_module_quantized(module)` without `quant_config_kwargs`. This happens to work correctly (for BnB, checking `Int8Params` after dequantization correctly returns False; for TorchAO, the weight won't be an `AffineQuantizedTensor`), but it means BnB dequantize for 4-bit models asserts the weight is not `Int8Params` rather than asserting it's not `Params4bit`. Consider updating for correctness. + +### Missing GGUF test coverage + +GGUF's `GGUFTesterMixin` doesn't have a `test_gguf_modules_to_not_convert` method. If GGUF is expected to support `modules_to_not_convert`, a test should be added. If not, a comment explaining why would be helpful. diff --git a/pr_review/pr_12700_flashpack.md b/pr_review/pr_12700_flashpack.md new file mode 100644 index 000000000000..975fbd6ca18c --- /dev/null +++ b/pr_review/pr_12700_flashpack.md @@ -0,0 +1,144 @@ +# PR #12700 — FlashPack Integration Review + +**URL**: https://github.com/huggingface/diffusers/pull/12700 +**State**: OPEN +**Branch**: `flashpack` → `main` + +## Summary + +Adds FlashPack as a new weight serialization format for faster model loading. FlashPack packs model weights into a single contiguous file (`model.flashpack`) that can be loaded efficiently, especially for larger models. The PR integrates it across `ModelMixin` (save/load), `DiffusionPipeline` (save/load/download), and supporting utilities. 
+ +## Files Changed + +- `setup.py` / `dependency_versions_table.py` — add `flashpack` dependency +- `src/diffusers/utils/constants.py` — `FLASHPACK_WEIGHTS_NAME`, `FLASHPACK_FILE_EXTENSION` +- `src/diffusers/utils/import_utils.py` — `is_flashpack_available()` +- `src/diffusers/utils/__init__.py` — re-exports +- `src/diffusers/models/model_loading_utils.py` — `load_flashpack_checkpoint()`, dispatch in `load_state_dict()` +- `src/diffusers/models/modeling_utils.py` — `save_pretrained(use_flashpack=...)`, `from_pretrained(use_flashpack=..., flashpack_kwargs=...)` +- `src/diffusers/pipelines/pipeline_utils.py` — pipeline-level `save_pretrained`, `from_pretrained`, `download` with `use_flashpack` +- `src/diffusers/pipelines/pipeline_loading_utils.py` — `load_sub_model`, `_get_ignore_patterns`, `get_class_obj_and_candidates`, `filter_model_files` + +--- + +## Issues + +### 1. `use_flashpack=True` default in `DiffusionPipeline.download()` + +```python +# pipeline_utils.py, in download() +use_flashpack = kwargs.pop("use_flashpack", True) +``` + +This defaults to `True`, meaning `download()` will always try to download FlashPack files by default. Every other call site defaults to `False`. This looks like a bug — it would change download behavior for all users even if they never asked for FlashPack. Should be `False`. + +### 2. `load_flashpack_checkpoint` is unused in the `from_pretrained` path + +`load_flashpack_checkpoint()` is added to `model_loading_utils.py` and wired into `load_state_dict()`. However, in `ModelMixin.from_pretrained`, when `use_flashpack=True`, the code **early-returns** after calling `flashpack.mixin.assign_from_file()` directly — it never goes through `load_state_dict()`. So `load_flashpack_checkpoint` is dead code in the `from_pretrained` flow. Either: +- Remove it if FlashPack always uses its own assign path, or +- Use it consistently (load state dict → assign to model, like safetensors/pickle) + +### 3. 
`resolved_model_file` may be undefined when `use_flashpack=True` and file fetch fails + +```python +# modeling_utils.py, from_pretrained +elif use_flashpack: + try: + resolved_model_file = _get_model_file(...) + except IOError as e: + logger.error(...) + if not allow_pickle: + raise + logger.warning("Defaulting to unsafe serialization...") +``` + +If the `IOError` is caught and `allow_pickle` is truthy, `resolved_model_file` is never set but is used later at `flashpack.mixin.assign_from_file(model=model, path=resolved_model_file[0], ...)`. This would crash with `NameError` or `UnboundLocalError`. The fallback logic (copied from the safetensors block) doesn't make sense for FlashPack — there's no pickle fallback for FlashPack. The `except` block should just re-raise unconditionally. + +### 4. `resolved_model_file[0]` assumes a list, but `_get_model_file` returns a string + +```python +flashpack.mixin.assign_from_file( + model=model, + path=resolved_model_file[0], # indexing into a string + ... +) +``` + +`_get_model_file` returns a single file path (string), not a list. `resolved_model_file[0]` would give the first character of the path. Should be just `resolved_model_file`. + +### 5. `device_map` handling assumes `device_map[""]` exists + +```python +flashpack_device = device_map[""] +``` + +`device_map` can be a dict with arbitrary keys (layer names, module names), not just `{"": device}`. This would raise `KeyError` for any non-trivial device map. Should handle the general case or document the constraint. + +### 6. `FlashPack` prefix stripping in `get_class_obj_and_candidates` is unexplained + +```python +if class_name.startswith("FlashPack"): + class_name = class_name.removeprefix("FlashPack") +``` + +This is injected into a general-purpose utility function with no explanation of when/why a class name would have a `FlashPack` prefix. 
This seems like it handles a specific config format but there's no corresponding code that writes `FlashPack`-prefixed class names. If this is for some external convention, it should be documented. If not needed, remove it. + +### 7. Duplicated availability check pattern + +The `is_flashpack_available()` check + import + error message pattern is repeated 3 times: +- `load_flashpack_checkpoint()` in `model_loading_utils.py` +- `save_pretrained()` in `modeling_utils.py` +- `from_pretrained()` in `modeling_utils.py` + +Each has slightly different wording. Should be consolidated — e.g., a helper or just use a single `require_flashpack()` function, consistent with how other optional deps are handled. + +### 8. `save_pretrained` error message says "load" instead of "save" + +```python +# modeling_utils.py, save_pretrained, use_flashpack=True branch +raise ImportError("Please install torch and flashpack to load a FlashPack checkpoint in PyTorch.") +``` + +This is in the **save** path, but the message says "load". Should say "save". + +### 9. No `config.json` saved alongside FlashPack weights in `save_pretrained` + +When `use_flashpack=True` in `ModelMixin.save_pretrained`, the model config is saved normally at the top of the method, but the FlashPack branch calls `flashpack.serialization.pack_to_file()` with `target_dtype=self.dtype`. It's not clear if FlashPack's own `config.json` (mentioned in the benchmark script as `flashpack_config.json`) is the same as diffusers' `config.json`. If they're different files, loading back with `from_pretrained(use_flashpack=True)` might fail to reconstruct the model architecture since `from_config` needs the diffusers config. + +### 10. `output_loading_info` warning placement + +```python +if output_loading_info: + logger.warning("`output_loading_info` is not supported with FlashPack.") + return model, {} +``` + +This returns an empty dict silently. 
The warning is fine, but returning `{}` instead of a proper `loading_info` structure (with `missing_keys`, `unexpected_keys`, etc.) could break code that destructures the result. + +### 11. No tests included + +The PR has no test files. At minimum there should be: +- Unit tests for `load_flashpack_checkpoint` (mocking `flashpack`) +- Unit tests for save/load roundtrip with `use_flashpack=True` +- Integration test for pipeline save/load + +### 12. FlashPack doesn't support sharding + +The `save_pretrained` FlashPack branch ignores `max_shard_size` entirely and always saves a single file. This is fine for the format but should either: +- Log a warning if `max_shard_size` is explicitly set alongside `use_flashpack=True` +- Document this limitation + +--- + +## Minor Issues + +- The benchmark in the PR description shows FlashPack is actually **slower** for fp16 SD v1.5 (0.95x). The claimed benefit is only for bf16. This should be prominently noted. +- `FLASHPACK_WEIGHTS_NAME = "model.flashpack"` breaks the diffusers naming convention (`diffusion_pytorch_model.*` for other formats). +- The PR modifies `_get_ignore_patterns` but doesn't handle the case where both `use_safetensors` and `use_flashpack` are True. +- `filter_model_files` adds `FLASHPACK_WEIGHTS_NAME` to the known weights list but there are no corresponding tests for this filtering. + +--- + +## Verdict + +The PR needs significant work before it's mergeable. The critical issues are the `use_flashpack=True` default in `download()`, the `resolved_model_file[0]` indexing bug, the dead code path with `load_flashpack_checkpoint`, and the lack of tests. The integration pattern also doesn't feel consistent with how other formats (safetensors, GGUF) are integrated — FlashPack bypasses the standard state dict loading path entirely via its own `assign_from_file`, making it a special case that's harder to maintain. 
diff --git a/pr_review/teacache_pr_12652_review.md b/pr_review/teacache_pr_12652_review.md new file mode 100644 index 000000000000..1cd76e9637f1 --- /dev/null +++ b/pr_review/teacache_pr_12652_review.md @@ -0,0 +1,286 @@ +# TeaCache PR #12652 Review Notes + +## PR Overview + +- **PR**: https://github.com/huggingface/diffusers/pull/12652 +- **Title**: Implement TeaCache +- **Author**: LawJarp-A (Prajwal A) +- **Status**: Open +- **Changes**: +1335 / -22 lines across 6 files + +### What is TeaCache? + +[TeaCache](https://huggingface.co/papers/2411.19108) (Timestep Embedding Aware Cache) is a training-free caching technique that speeds up diffusion model inference by **1.5x-2.6x** by reusing transformer block computations when consecutive timestep embeddings are similar. + +### Algorithm + +1. Extract modulated input from first transformer block (after norm1 + timestep embedding) +2. Compute relative L1 distance vs previous timestep +3. Apply model-specific polynomial rescaling: `c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]` +4. Accumulate rescaled distance across timesteps +5. If accumulated < threshold → Reuse cached residual (FAST) +6. If accumulated >= threshold → Full transformer pass (SLOW, update cache) + +--- + +## The Mid-Forward Intercept Problem + +### Why TeaCache is Model-Specific + +TeaCache needs to intercept **within** a model's forward method, not just at module boundaries: + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Model Forward │ +│ │ +│ PREPROCESSING (must always run) │ +│ ├── x_embedder(hidden_states) │ +│ ├── time_text_embed(timestep, ...) 
│ +│ └── context_embedder(encoder_hidden_states) │ +│ │ +│ ═══════════════════════════════════════════════════════════│ +│ DECISION POINT ◄── TeaCache needs to intercept HERE │ +│ └── Extract: transformer_blocks[0].norm1(hs, temb)[0] │ +│ ═══════════════════════════════════════════════════════════│ +│ │ +│ CACHEABLE REGION (can be skipped if cached) │ +│ ├── for block in transformer_blocks: ... │ +│ └── for block in single_transformer_blocks: ... │ +│ │ +│ POSTPROCESSING (must always run) │ +│ ├── norm_out(hidden_states, temb) │ +│ └── proj_out(hidden_states) │ +└─────────────────────────────────────────────────────────────┘ +``` + +PyTorch hooks only intercept at **module boundaries** (before/after `forward()`), not within a forward method. The `for` loop over blocks is Python control flow - there's no hook point to skip it. + +### Workaround: Custom Forward Replacement + +The PR replaces the entire model forward with a custom implementation that has cache logic inserted at the right point. This works but requires maintaining separate forward functions for each model. 
+ +--- + +## Comparison of Caching Approaches + +### TeaCache vs FirstBlockCache vs FasterCache + +| Aspect | TeaCache | FirstBlockCache | FasterCache | +|--------|----------|-----------------|-------------| +| **Hook target** | Model forward | Transformer blocks | Attention layers | +| **Decision signal** | Modulated input (norm1 output) | Block output residual | Iteration count | +| **Where signal is** | Inside first block | Block boundary | Attention output | +| **Model-specific needs** | norm1 structure | Block output format | Attention class type | +| **Model-agnostic?** | ❌ No | ✅ Yes | ✅ Yes | + +### Why FirstBlockCache is Model-Agnostic + +FirstBlockCache uses the **first block's output residual** as its signal: + +```python +# FirstBlockCache: hooks individual blocks +def new_forward(self, module, *args, **kwargs): + original_hidden_states = args[0] + output = self.fn_ref.original_forward(*args, **kwargs) # Run block fully + residual = output - original_hidden_states # Signal from OUTPUT + should_compute = self._compare_residual(residual) + ... +``` + +It doesn't need to understand block internals - just input and output. + +### Why FasterCache is Model-Agnostic + +FasterCache hooks **attention layers** (not blocks) using class type checking: + +```python +_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin) + +for name, submodule in module.named_modules(): + if isinstance(submodule, _ATTENTION_CLASSES): + # Hook this attention module +``` + +All transformer models use standardized attention classes. 
+ +--- + +## Model Architecture Analysis + +### Models That Fit TeaCache Pattern + +Models with `norm1(hidden_states, temb)` returning modulated input: + +| Model | norm1 Signature | Modulation Location | Single Residual | +|-------|----------------|---------------------|-----------------| +| FLUX 1 | `norm1(hs, emb=temb) → (tensor, gate)` | Inside norm1 | ✅ | +| FLUX Kontext | `norm1(hs, emb=temb) → (tensor, gate)` | Inside norm1 | ✅ | +| Mochi | `norm1(hs, temb) → (tensor, g, s, g)` | Inside norm1 | ✅ | +| Lumina2 | `norm1(hs, temb) → (tensor, gate)` | Inside norm1 | ✅ | + +### Models That DON'T Fit Pattern + +| Model | norm1 Signature | Modulation Location | Issue | +|-------|----------------|---------------------|-------| +| **FLUX 2** | `norm1(hs) → tensor` | Outside norm1 | Plain LayerNorm | +| **Wan** | `norm1(hs) → tensor` | Outside norm1 | Plain LayerNorm | +| **ZImage** | `attention_norm1(x) → tensor` | Outside norm1 | Plain LayerNorm | +| **CogVideoX** | N/A (uses `emb` directly) | N/A | Dual residual needed | + +### FLUX 1 vs FLUX 2 Architecture Difference + +**FLUX 1** (AdaLayerNorm - modulation inside): +```python +class FluxTransformerBlock: + self.norm1 = AdaLayerNormZero(dim) # Takes temb! + + def forward(self, hidden_states, temb, ...): + norm_hs, gate = self.norm1(hidden_states, emb=temb) # Modulation inside +``` + +**FLUX 2** (Plain LayerNorm - modulation outside): +```python +class Flux2TransformerBlock: + self.norm1 = nn.LayerNorm(dim) # NO temb! + + def forward(self, hidden_states, temb_mod_params_img, ...): + (shift_msa, scale_msa, gate_msa), ... = temb_mod_params_img + norm_hs = self.norm1(hidden_states) # Plain norm + norm_hs = (1 + scale_msa) * norm_hs + shift_msa # Modulation outside +``` + +FLUX 2 follows the Wan/ZImage pattern and would need a separate custom forward. + +--- + +## CogVideoX: The Architectural Outlier + +CogVideoX has two unique requirements that don't fit the pattern: + +### 1. 
Different Modulated Input Source + +```python +# Other models: extract from norm1 +modulated_inp = block.norm1(hidden_states, temb)[0] + +# CogVideoX: uses timestep embedding directly +modulated_inp = emb # Just the embedding, computed before blocks! +``` + +### 2. Dual Residual Caching + +CogVideoX blocks return and modify TWO tensors: +```python +def forward(self, hidden_states, encoder_hidden_states, temb, ...): + # Both are modified! + return hidden_states, encoder_hidden_states +``` + +Requires caching two residuals: +```python +state.previous_residual = hs_output - hs_input +state.previous_residual_encoder = enc_output - enc_input # Extra! +``` + +--- + +## Recommendations + +### Simplification: FLUX-Only Support + +Given the architectural diversity, recommend supporting only FLUX 1 and FLUX Kontext initially: + +```python +_MODEL_CONFIG = { + "FluxKontext": { + "forward_func": _flux_teacache_forward, + "coefficients": [-1.04655119e03, 3.12563399e02, -1.69500694e01, 4.10995971e-01, 3.74537863e-02], + }, + "Flux": { + "forward_func": _flux_teacache_forward, + "coefficients": [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01], + }, +} +``` + +### What to Remove from PR + +1. **CogVideoX support** - Dual residual architecture doesn't fit +2. **Mochi support** - Can be added later if needed +3. **Lumina2 support** - Can be added later if needed +4. 
**FLUX 2 support** - Different architecture (plain LayerNorm) + +### Estimated Code Reduction + +| Component | Original (PR) | FLUX-Only | +|-----------|---------------|-----------| +| Forward functions | 4 (~400 lines) | 1 (~100 lines) | +| Model configs | 10 entries | 2 entries | +| State fields | 8 | 5 | +| Utility functions | 6 | 3 | +| **Total teacache.py** | ~900 lines | ~350 lines | + +### Simplified State + +```python +class TeaCacheState(BaseState): + def __init__(self): + self.cnt = 0 + self.num_steps = 0 + self.accumulated_rel_l1_distance = 0.0 + self.previous_modulated_input = None + self.previous_residual = None + # Removed: previous_residual_encoder (CogVideoX) + # Removed: cache_dict (Lumina2) + # Removed: uncond_seq_len (Lumina2) +``` + +--- + +## Why Custom Forwards Are Necessary + +Despite the maintenance burden, custom forwards are the pragmatic approach for TeaCache because: + +1. **Mid-forward intercept required** - Need to access `norm1` output before blocks run +2. **Architectural diversity** - Models differ in where/how modulation happens +3. **Block-level hooks insufficient** - Can't extract modulated input from block hooks +4. **Algorithm requirements** - TeaCache paper specifically uses modulated input as signal + +### Alternative Approaches Considered + +| Approach | Works? | Issue | +|----------|--------|-------| +| Block-level hooks (like FirstBlockCache) | ❌ | Can't access modulated input inside block | +| Attention-level hooks (like FasterCache) | ❌ | Different algorithm, not TeaCache | +| Hook norm1 directly | ⚠️ | norm1 interface varies per model | +| Hybrid (FirstBlockCache signal + TeaCache algorithm) | ⚠️ | Loses "optimal" signal per paper | + +--- + +## PR Code Quality Issues (From Review) + +1. **torch.compile incompatibility** - `.item()` calls in `_compute_rel_l1_distance` create graph breaks +2. **Boundary check bug** - when `num_steps=0`, the right-hand side of `state.cnt == state.num_steps - 1` evaluates to `-1` +3.
**Incomplete Lumina2 state reset** - `cache_dict` and `uncond_seq_len` not reset +4. **Model auto-detection fragility** - Substring matching relies on iteration order + +--- + +## Extension Path + +If support for additional models is needed later: + +1. **Mochi** - Same pattern as FLUX, just add coefficients and reuse `_flux_teacache_forward` or create similar +2. **Lumina2** - Same pattern but needs per-sequence-length caching for CFG +3. **FLUX 2 / Wan / ZImage** - Need separate forwards that extract modulated input differently +4. **CogVideoX** - Needs dual residual support, significant additional complexity + +--- + +## Summary + +- **TeaCache requires custom forwards** due to mid-forward intercept requirement +- **FLUX 1 + FLUX Kontext only** is the recommended scope for initial implementation +- **~60% code reduction** possible by removing unsupported models +- **Clear extension path** for adding models later as needed +- **Maintenance burden** is acceptable given the architectural constraints diff --git a/release_notes/v0.37.0.md b/release_notes/v0.37.0.md new file mode 100644 index 000000000000..4a06621e0154 --- /dev/null +++ b/release_notes/v0.37.0.md @@ -0,0 +1,129 @@ +# Diffusers v0.37.0 Release Notes + +*Release based on 191 commits since v0.36.0* + +--- + +## Highlights + +- **Modular Pipelines overhaul**: Major investment in the modular pipeline system with explicit workflow support, improved loaders, documentation, and modular implementations for Wan, Flux2, Z-Image, Qwen, and Mellon pipelines. +- **New pipelines and models**: Cosmos Predict2.5, LTX 2.0 Video, LongCat-Image, Fibo Edit, Z-Image Omni Base, and more. +- **Distributed inference improvements**: Unified Sequence Parallel attention, Ulysses Anything Attention, and context parallel support in native flash attention. +- **Python 3.8 dropped**: Sunset Python 3.8 and cleaned up explicit `typing` exports. 
+ +--- + +## New Pipelines and Models + +- **Cosmos Predict2.5**: Base inference pipeline, scheduler, and checkpoint conversion; 14B model support (#12852, #12863) +- **Cosmos Transfer2.5**: General transfer pipelines for segmentation, depth, blur, and edge (#13066) +- **LTX 2.0 Video Pipelines**: New video generation pipelines (#12915), distilled checkpoint support (#12934), single-file loading (#12983), LoRA support (#12933), long multi-prompt (#12614) +- **LongCat-Image**: New pipeline with offloading/quantization support and regional compile acceleration (#12828, #12963, #12699, #13019, #13021) +- **Fibo Edit Pipeline**: New editing pipeline (#12930) +- **Z-Image Omni Base**: New implementation (#12857) +- **Z-Image Turbo ControlNet**: ControlNet support for Z-Image Turbo (#12792) +- **Z-Image Inpaint Pipeline**: Inpainting support (#13006) +- **Z-Image ControlNet CFG**: CFG support for Z-Image ControlNet (#13080) +- **Chroma Inpaint Pipeline**: New inpainting pipeline for Chroma (#12848) +- **Flux2 Klein**: New model variant (#12982) +- **Qwen Image Edit 2511**: New editing support (#12839) +- **Qwen Image Layered Support** (#12853) + +## Modular Pipelines + +- Explicit workflow support for modular pipelines (#13028) +- Modular implementations for: Wan (#13063), Flux2 (#12763), Z-Image (#12808), Qwen (#12872), Mellon (#12978, #12924, #13051) +- Improved loader support (#13025) +- Custom block tests (#12557) +- Auto-docstring generation and documentation refactors (#12958) +- Quick start guide (#13029) +- Guard `ModularPipeline.blocks` attribute (#13014) +- Better docstrings and template pipeline card (#13072, #12932) + +## Core Improvements + +- **Device-type device maps with offloading support** (#12811) +- **`disable_mmap` in pipeline `from_pretrained`** (#12854) +- **`apply_lora_scale` helper** to remove boilerplate (#12994) +- **MagCache support**: Caching mechanism for faster inference (#12744) +- **Mambo-G Guidance**: New guider implementation (#12862) 
+- **Laplace Scheduler for DDPM** (#11320) +- **Custom sigmas in UniPCMultistepScheduler** (#12109) +- **Control-LoRA support** (#10686) +- **Latent Perceptual Loss (LPL) for SDXL** (#11573) +- **MultiControlNet support for SD3 Inpainting** (#11251) +- Remove 8-bit device restriction (#12972) +- Graceful error for unsupported attn-backend / context-parallel combos (#12832) +- Handle progress bar and logging in distributed environments (#12806) +- Remove unneeded autoencoder methods from `AutoencoderMixin` subclasses (#12873) +- Remove k-diffusion support (#13152) +- Flag Flax schedulers as deprecated (#13031) + +## Distributed Inference + +- **Unified Sequence Parallel attention** (#12693) +- **Ulysses Anything Attention** (#12996) +- **Context parallel in native flash attention** (#12829) +- NPU Ulysses attention support (#12919) +- Fix Wan 2.1 I2V context parallel (#12909) +- Fix Qwen-Image context parallel (#12970) + +## LoRA + +- Z-Image LoRA training (#13056) +- Fix non-diffusers LoRA key handling for Flux2 (#13119) +- Fix LoRA loading for Flux2 Klein with adaptive block enumeration (#13030) +- Fix wrong LTX2 LoRA mixin (#13144) + +## Bug Fixes + +- Fix QwenImageEditPlus on NPU (#13017) +- Fix MT5Tokenizer → use `T5Tokenizer` for Transformers v5.0+ compatibility (#12877) +- Fix Wan/WanI2V patchification (#13038) +- Fix LTX-2 inference with `num_videos_per_prompt > 1` and CFG (#13121) +- Fix Flux2 img2img prediction (#12855) +- Fix QwenImage `txt_seq_lens` handling (#12702) +- Fix `prefix_token_len` bug (#12845) +- Fix ftfy imports in Wan and SkyReels-V2 (#12314, #13113) +- Fix `is_fsdp` determination (#12960) +- Fix GLM-Image `get_image_features` API (#13052) +- Fix Wan 2.2 when either transformer isn't present (#13055) +- Fix guider issue (#13147) +- Fix torchao quantizer for new versions (#12901) +- Fix GGUF for unquantized types with unquantize kernels (#12498) +- Make Qwen hidden states contiguous for torchao (#13081) +- Make Flux hidden states contiguous 
(#13068) +- Fix Kandinsky 5 hardcoded CUDA autocast (#12814) +- Fix `aiter` availability check (#13059) +- Fix attention mask check for unsupported backends (#12892) +- Allow `prompt` and `prior_token_ids` simultaneously in `GlmImagePipeline` (#13092) +- GLM-Image batch support (#13007) +- Cosmos 2.5 Video2World frame extraction fix (#13018) +- ResNet: only use contiguous in training mode (#12977) + +## Testing and CI + +- Refactor model tests (#12822) +- Refactor Wan model tests (#13082) +- Accept `recompile_limit` from user in tests (#13150) +- CodeQL workflow for security analysis (#12917) +- Upgrade GitHub Actions for Node 24 compatibility (#12865, #12866) +- Fix `setuptools` / `pkg_resources` CI bugs (#13129, #13132) +- CUDA 12.9 upgrade (#13045) +- FSDP option for Flux2 (#12860) + +## Documentation + +- Custom code AutoModel guide (#13099) +- Remote inference docs (#12372) +- Improved distributed inference docs (#12810, #12827, #12971) +- Improved caching docs (#12684) +- Numerous scheduler docstring improvements (#12798, #12871, #12928, #12931, #12936, #12992, #13010, #13020, #13023, #13024, #13027, #13044, #13083, #13085, #13122, #13127, #13130) +- Various typo and syntax fixes + +## Breaking Changes + +- **Python 3.8 support removed** (#12524) +- **k-diffusion removed** (#13152) +- **Flax schedulers flagged as deprecated** (#13031) +- ControlNet implementations outside the controlnet module removed (#12152) diff --git a/scripts/compare_test_coverage.py b/scripts/compare_test_coverage.py new file mode 100644 index 000000000000..1a002fc16813 --- /dev/null +++ b/scripts/compare_test_coverage.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +""" +Compare test coverage between main and model-test-refactor branches +for the Flux transformer tests. 
+ +Usage: + python scripts/compare_test_coverage.py +""" + +import subprocess + + +TEST_FILE = "tests/models/transformers/test_models_transformer_flux.py" +BRANCHES = ["main", "model-test-refactor"] + + +def run_command(cmd, capture=True): + """Run a shell command and return output.""" + result = subprocess.run(cmd, shell=True, capture_output=capture, text=True) + return result.stdout, result.stderr, result.returncode + + +def get_current_branch(): + """Get the current git branch name.""" + stdout, _, _ = run_command("git branch --show-current") + return stdout.strip() + + +def stash_changes(): + """Stash any uncommitted changes.""" + run_command("git stash") + + +def pop_stash(): + """Pop stashed changes.""" + run_command("git stash pop") + + +def checkout_branch(branch): + """Checkout a git branch.""" + _, stderr, code = run_command(f"git checkout {branch}") + if code != 0: + print(f"Failed to checkout {branch}: {stderr}") + return False + return True + + +def collect_tests(test_file): + """Collect tests from a test file and return test info.""" + cmd = f"python -m pytest {test_file} --collect-only -q 2>/dev/null" + stdout, stderr, code = run_command(cmd) + + tests = [] + for line in stdout.strip().split("\n"): + if "::" in line and not line.startswith("="): + tests.append(line.strip()) + + return tests + + +def run_tests_verbose(test_file): + """Run tests and capture pass/skip/fail status.""" + cmd = f"python -m pytest {test_file} -v --tb=no 2>&1" + stdout, _, _ = run_command(cmd) + + results = {"passed": [], "skipped": [], "failed": [], "errors": []} + + for line in stdout.split("\n"): + if " PASSED" in line: + test_name = line.split(" PASSED")[0].strip() + results["passed"].append(test_name) + elif " SKIPPED" in line: + test_name = line.split(" SKIPPED")[0].strip() + reason = "" + if "SKIPPED" in line and "[" in line: + reason = line.split("[")[-1].rstrip("]") if "[" in line else "" + results["skipped"].append((test_name, reason)) + elif " FAILED" in line: + 
test_name = line.split(" FAILED")[0].strip() + results["failed"].append(test_name) + elif " ERROR" in line: + test_name = line.split(" ERROR")[0].strip() + results["errors"].append(test_name) + + return results + + +def compare_results(main_results, pr_results): + """Compare test results between branches.""" + print("\n" + "=" * 70) + print("COVERAGE COMPARISON REPORT") + print("=" * 70) + + print("\n## Test Counts") + print(f"{'Category':<20} {'main':<15} {'PR':<15} {'Diff':<10}") + print("-" * 60) + + for category in ["passed", "skipped", "failed", "errors"]: + main_count = len(main_results[category]) + pr_count = len(pr_results[category]) + diff = pr_count - main_count + diff_str = f"+{diff}" if diff > 0 else str(diff) + print(f"{category:<20} {main_count:<15} {pr_count:<15} {diff_str:<10}") + + main_tests = set(main_results["passed"] + [t[0] for t in main_results["skipped"]]) + pr_tests = set(pr_results["passed"] + [t[0] for t in pr_results["skipped"]]) + + missing_in_pr = main_tests - pr_tests + new_in_pr = pr_tests - main_tests + + if missing_in_pr: + print("\n## Tests in main but MISSING in PR:") + for test in sorted(missing_in_pr): + print(f" - {test}") + + if new_in_pr: + print("\n## NEW tests in PR (not in main):") + for test in sorted(new_in_pr): + print(f" + {test}") + + print("\n## Skipped Tests Comparison") + main_skipped = {t[0]: t[1] for t in main_results["skipped"]} + pr_skipped = {t[0]: t[1] for t in pr_results["skipped"]} + + newly_skipped = set(pr_skipped.keys()) - set(main_skipped.keys()) + no_longer_skipped = set(main_skipped.keys()) - set(pr_skipped.keys()) + + if newly_skipped: + print("\nNewly skipped in PR:") + for test in sorted(newly_skipped): + print(f" - {test}: {pr_skipped.get(test, 'unknown reason')}") + + if no_longer_skipped: + print("\nNo longer skipped in PR (now running):") + for test in sorted(no_longer_skipped): + print(f" + {test}") + + if not newly_skipped and not no_longer_skipped: + print("\nNo changes in skipped tests.") 
+ + print("\n" + "=" * 70) + + +def main(): + original_branch = get_current_branch() + print(f"Current branch: {original_branch}") + + results = {} + + print("Stashing uncommitted changes...") + stash_changes() + + try: + for branch in BRANCHES: + print(f"\n--- Analyzing branch: {branch} ---") + + if not checkout_branch(branch): + print(f"Skipping {branch}") + continue + + print(f"Collecting and running tests from {TEST_FILE}...") + results[branch] = run_tests_verbose(TEST_FILE) + + print(f" Passed: {len(results[branch]['passed'])}") + print(f" Skipped: {len(results[branch]['skipped'])}") + print(f" Failed: {len(results[branch]['failed'])}") + + checkout_branch(original_branch) + + if "main" in results and "model-test-refactor" in results: + compare_results(results["main"], results["model-test-refactor"]) + else: + print("Could not compare - missing results from one or both branches") + + finally: + print("\nRestoring stashed changes...") + pop_stash() + + checkout_branch(original_branch) + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 5d6c4064ef96..c1ac7f3aab4c 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -14,6 +14,7 @@ import importlib import inspect import os +import sys import traceback import warnings from collections import OrderedDict @@ -28,10 +29,16 @@ from typing_extensions import Self from ..configuration_utils import ConfigMixin, FrozenDict -from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj +from ..pipelines.pipeline_loading_utils import ( + LOADABLE_CLASSES, + _fetch_class_library_tuple, + _unwrap_model, + simple_get_class_obj, +) from ..utils import PushToHubMixin, is_accelerate_available, logging from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code from 
..utils.hub_utils import load_or_create_model_card, populate_model_card +from ..utils.torch_utils import is_compiled_module from .components_manager import ComponentsManager from .modular_pipeline_utils import ( MODULAR_MODEL_CARD_TEMPLATE, @@ -1826,29 +1833,124 @@ def from_pretrained( ) return pipeline - def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs): + def save_pretrained( + self, + save_directory: str | os.PathLike, + safe_serialization: bool = True, + variant: str | None = None, + max_shard_size: int | str | None = None, + push_to_hub: bool = False, + **kwargs, + ): """ - Save the pipeline to a directory. It does not save components, you need to save them separately. + Save the pipeline and all its components to a directory, so that it can be re-loaded using the + [`~ModularPipeline.from_pretrained`] class method. Args: save_directory (`str` or `os.PathLike`): - Path to the directory where the pipeline will be saved. - push_to_hub (`bool`, optional): - Whether to push the pipeline to the huggingface hub. - **kwargs: Additional arguments passed to `save_config()` method - """ + Directory to save the pipeline to. Will be created if it doesn't exist. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + variant (`str`, *optional*): + If specified, weights are saved in the format `pytorch_model.<variant>.bin`. + max_shard_size (`int` or `str`, defaults to `None`): + The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`). + If expressed as an integer, the unit is bytes. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the pipeline to the Hugging Face model hub after saving it.
+ **kwargs: Additional keyword arguments: + - `overwrite_modular_index` (`bool`, *optional*, defaults to `False`): + When saving a Modular Pipeline, its components in `modular_model_index.json` may reference repos + different from the destination repo. Setting this to `True` updates all component references in + `modular_model_index.json` so they point to the repo specified by `repo_id`. + - `repo_id` (`str`, *optional*): + The repository ID to push the pipeline to. Defaults to the last component of `save_directory`. + - `commit_message` (`str`, *optional*): + Commit message for the push to hub operation. + - `private` (`bool`, *optional*): + Whether the repository should be private. + - `create_pr` (`bool`, *optional*, defaults to `False`): + Whether to create a pull request instead of pushing directly. + - `token` (`str`, *optional*): + The Hugging Face token to use for authentication. + """ + overwrite_modular_index = kwargs.pop("overwrite_modular_index", False) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + if push_to_hub: commit_message = kwargs.pop("commit_message", None) private = kwargs.pop("private", None) create_pr = kwargs.pop("create_pr", False) token = kwargs.pop("token", None) - repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id - # Generate modular pipeline card content - card_content = generate_modular_model_card_content(self.blocks) + for component_name, component_spec in self._component_specs.items(): + if component_spec.default_creation_method != "from_pretrained": + continue + + component = getattr(self, component_name, None) + if component is None: + continue + + model_cls = component.__class__ + if is_compiled_module(component): + component = _unwrap_model(component) + model_cls = component.__class__ + + save_method_name = None + for library_name, library_classes in LOADABLE_CLASSES.items(): + if library_name in 
sys.modules: + library = importlib.import_module(library_name) + else: + logger.info( + f"{library_name} is not installed. Cannot save {component_name} as {library_classes} from {library_name}" + ) + continue + + for base_class, save_load_methods in library_classes.items(): + class_candidate = getattr(library, base_class, None) + if class_candidate is not None and issubclass(model_cls, class_candidate): + save_method_name = save_load_methods[0] + break + if save_method_name is not None: + break + + if save_method_name is None: + logger.warning(f"self.{component_name}={component} of type {type(component)} cannot be saved.") + continue + + save_method = getattr(component, save_method_name) + save_method_signature = inspect.signature(save_method) + save_method_accept_safe = "safe_serialization" in save_method_signature.parameters + save_method_accept_variant = "variant" in save_method_signature.parameters + save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters - # Create a new empty model card and eventually tag it + save_kwargs = {} + if save_method_accept_safe: + save_kwargs["safe_serialization"] = safe_serialization + if save_method_accept_variant: + save_kwargs["variant"] = variant + if save_method_accept_max_shard_size and max_shard_size is not None: + save_kwargs["max_shard_size"] = max_shard_size + + component_save_path = os.path.join(save_directory, component_name) + save_method(component_save_path, **save_kwargs) + + if component_name not in self.config: + continue + + has_no_load_id = not hasattr(component, "_diffusers_load_id") or component._diffusers_load_id == "null" + if overwrite_modular_index or has_no_load_id: + library, class_name, component_spec_dict = self.config[component_name] + component_spec_dict["pretrained_model_name_or_path"] = repo_id if push_to_hub else save_directory + component_spec_dict["subfolder"] = component_name + self.register_to_config(**{component_name: (library, class_name, 
component_spec_dict)}) + + self.save_config(save_directory=save_directory) + + if push_to_hub: + card_content = generate_modular_model_card_content(self.blocks) model_card = load_or_create_model_card( repo_id, token=token, @@ -1857,13 +1959,8 @@ def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = is_modular=True, ) model_card = populate_model_card(model_card, tags=card_content["tags"]) - model_card.save(os.path.join(save_directory, "README.md")) - # YiYi TODO: maybe order the json file to make it more readable: configs first, then components - self.save_config(save_directory=save_directory) - - if push_to_hub: self._upload_folder( save_directory, repo_id, @@ -2131,8 +2228,9 @@ def update_components(self, **kwargs): ``` Notes: - - Components with trained weights should be loaded with `AutoModel.from_pretrained()` or - `ComponentSpec.load()` so that loading specs are preserved for serialization. + - Components loaded with `AutoModel.from_pretrained()` or `ComponentSpec.load()` will have + loading specs preserved for serialization. Custom or locally loaded components without Hub references will + have their `modular_model_index.json` entries updated automatically during `save_pretrained()`. - ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly. """ @@ -2154,14 +2252,6 @@ def update_components(self, **kwargs): new_component_spec = current_component_spec if hasattr(self, name) and getattr(self, name) is not None: logger.warning(f"ModularPipeline.update_components: setting {name} to None (spec unchanged)") - elif current_component_spec.default_creation_method == "from_pretrained" and not ( - hasattr(component, "_diffusers_load_id") and component._diffusers_load_id is not None - ): - logger.warning( - f"ModularPipeline.update_components: {name} has no valid _diffusers_load_id. 
" - f"This will result in empty loading spec, use ComponentSpec.load() for proper specs" - ) - new_component_spec = ComponentSpec(name=name, type_hint=type(component)) else: new_component_spec = ComponentSpec.from_component(name, component) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index cab17c2aed5c..fa81d81920eb 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -311,6 +311,12 @@ def load(self, **kwargs) -> Any: f"`type_hint` is required when loading a single file model but is missing for component: {self.name}" ) + # `torch_dtype` is not an accepted parameter for tokenizers and processors. + # As a result, it gets stored in `init_kwargs`, which are written to the config + # during save. This causes JSON serialization to fail when saving the component. + if self.type_hint is not None and not issubclass(self.type_hint, torch.nn.Module): + kwargs.pop("torch_dtype", None) + if self.type_hint is None: try: from diffusers import AutoModel @@ -328,6 +334,12 @@ def load(self, **kwargs) -> Any: else getattr(self.type_hint, "from_pretrained") ) + # `torch_dtype` is not an accepted parameter for tokenizers and processors. + # As a result, it gets stored in `init_kwargs`, which are written to the config + # during save. This causes JSON serialization to fail when saving the component. 
+ if not issubclass(self.type_hint, torch.nn.Module): + kwargs.pop("torch_dtype", None) + try: component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs) except Exception as e: diff --git a/test_automodel_meta.py b/test_automodel_meta.py new file mode 100644 index 000000000000..f0dbe7f4a3b9 --- /dev/null +++ b/test_automodel_meta.py @@ -0,0 +1,14 @@ +import torch +from diffusers import AutoModel + +repo = "meituan-longcat/LongCat-Image" +subfolder = "transformer" + +config = AutoModel.load_config(repo, subfolder=subfolder) + +with torch.device("meta"): + model = AutoModel.from_config(config) +print(f"model.config:") +for k, v in dict(model.config).items(): + if not k.startswith("_"): + print(f" {k}: {v}") diff --git a/test_dataclass_config.py b/test_dataclass_config.py new file mode 100644 index 000000000000..ab7eb48eb7bd --- /dev/null +++ b/test_dataclass_config.py @@ -0,0 +1,11 @@ +import dataclasses +from diffusers import AutoModel, LongCatImageTransformer2DModel + +config_dict = AutoModel.load_config( + "meituan-longcat/LongCat-Image", + subfolder="transformer", +) +# import DiT based on _class_name +typed_config = LongCatImageTransformer2DModel._get_dataclass_from_config(config_dict) +for f in dataclasses.fields(typed_config): + print(f"{f.name}: {f.type}") diff --git a/test_pretrained_config.py b/test_pretrained_config.py new file mode 100644 index 000000000000..40b871d4163d --- /dev/null +++ b/test_pretrained_config.py @@ -0,0 +1,29 @@ +import dataclasses +import torch +from diffusers import FluxTransformer2DModel +from diffusers.models import AutoModel + +repo = "black-forest-labs/FLUX.2-dev" +subfolder = "transformer" + +print("=== From load_config (no model instantiation) ===") +config_dict = FluxTransformer2DModel.load_config(repo, subfolder=subfolder) +tc = FluxTransformer2DModel._get_dataclass_from_config(config_dict) +print(f"Type: {type(tc).__name__}") +for k, v in dataclasses.asdict(tc).items(): + print(f" {k}: {v}") + +print() 
+print("=== From AutoModel.from_config on meta device ===") +with torch.device("meta"): + model = AutoModel.from_config(repo, subfolder=subfolder) +print(f"model.config:") +for k, v in dict(model.config).items(): + if not k.startswith("_"): + print(f" {k}: {v}") + +print() +print("=== Comparison ===") +dc_dict = dataclasses.asdict(tc) +config = {k: v for k, v in dict(model.config).items() if not k.startswith("_")} +print(f"Match: {dc_dict == config}") diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index 5aceae77da27..bd96516785d9 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -730,6 +730,82 @@ def test_load_components_skips_invalid_pretrained_path(self): assert not hasattr(pipe, "test_component") or pipe.test_component is None +class TestCustomModelSavePretrained: + def test_save_pretrained_updates_index_for_local_model(self, tmp_path): + """When a component without _diffusers_load_id (custom/local model) is saved, + modular_model_index.json should point to the save directory.""" + import json + + pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe") + pipe.load_components(torch_dtype=torch.float32) + + pipe.unet._diffusers_load_id = "null" + + save_dir = str(tmp_path / "my-pipeline") + pipe.save_pretrained(save_dir) + + with open(os.path.join(save_dir, "modular_model_index.json")) as f: + index = json.load(f) + + _library, _cls, unet_spec = index["unet"] + assert unet_spec["pretrained_model_name_or_path"] == save_dir + assert unet_spec["subfolder"] == "unet" + + _library, _cls, vae_spec = index["vae"] + assert vae_spec["pretrained_model_name_or_path"] == "hf-internal-testing/tiny-stable-diffusion-xl-pipe" + + def test_save_pretrained_roundtrip_with_local_model(self, tmp_path): + """A pipeline with a custom/local model should be saveable and re-loadable with 
identical outputs.""" + pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe") + pipe.load_components(torch_dtype=torch.float32) + + pipe.unet._diffusers_load_id = "null" + + original_state_dict = pipe.unet.state_dict() + + save_dir = str(tmp_path / "my-pipeline") + pipe.save_pretrained(save_dir) + + loaded_pipe = ModularPipeline.from_pretrained(save_dir) + loaded_pipe.load_components(torch_dtype=torch.float32) + + assert loaded_pipe.unet is not None + assert loaded_pipe.unet.__class__.__name__ == pipe.unet.__class__.__name__ + + loaded_state_dict = loaded_pipe.unet.state_dict() + assert set(original_state_dict.keys()) == set(loaded_state_dict.keys()) + for key in original_state_dict: + assert torch.equal(original_state_dict[key], loaded_state_dict[key]), f"Mismatch in {key}" + + def test_save_pretrained_overwrite_modular_index(self, tmp_path): + """With overwrite_modular_index=True, all component references should point to the save directory.""" + import json + + pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe") + pipe.load_components(torch_dtype=torch.float32) + + save_dir = str(tmp_path / "my-pipeline") + pipe.save_pretrained(save_dir, overwrite_modular_index=True) + + with open(os.path.join(save_dir, "modular_model_index.json")) as f: + index = json.load(f) + + for component_name in ["unet", "vae", "text_encoder", "text_encoder_2"]: + if component_name not in index: + continue + _library, _cls, spec = index[component_name] + assert spec["pretrained_model_name_or_path"] == save_dir, ( + f"{component_name} should point to save dir but got {spec['pretrained_model_name_or_path']}" + ) + assert spec["subfolder"] == component_name + + loaded_pipe = ModularPipeline.from_pretrained(save_dir) + loaded_pipe.load_components(torch_dtype=torch.float32) + + assert loaded_pipe.unet is not None + assert loaded_pipe.vae is not None + + class TestModularPipelineInitFallback: """Test that 
ModularPipeline.__init__ falls back to default_blocks_name when _blocks_class_name is a base class (e.g. SequentialPipelineBlocks saved by from_blocks_dict).""" From 387568ad5d38d449ccac5e66a7bb954403416b82 Mon Sep 17 00:00:00 2001 From: David El Malih Date: Mon, 2 Mar 2026 18:42:55 +0100 Subject: [PATCH 012/215] docs: improve docstring scheduling_ipndm.py (#13198) Improve docstring scheduling ipndm --- src/diffusers/schedulers/scheduling_ipndm.py | 51 +++++++++++++++++--- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ipndm.py b/src/diffusers/schedulers/scheduling_ipndm.py index 0a02311ba9b6..f2e744e0f2af 100644 --- a/src/diffusers/schedulers/scheduling_ipndm.py +++ b/src/diffusers/schedulers/scheduling_ipndm.py @@ -31,14 +31,18 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin): Args: num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. - trained_betas (`np.ndarray`, *optional*): + trained_betas (`np.ndarray` or `List[float]`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. """ order = 1 @register_to_config - def __init__(self, num_train_timesteps: int = 1000, trained_betas: np.ndarray | list[float] | None = None): + def __init__( + self, + num_train_timesteps: int = 1000, + trained_betas: np.ndarray | list[float] | None = None, + ): # set `betas`, `alphas`, `timesteps` self.set_timesteps(num_train_timesteps) @@ -56,21 +60,29 @@ def __init__(self, num_train_timesteps: int = 1000, trained_betas: np.ndarray | self._begin_index = None @property - def step_index(self): + def step_index(self) -> int | None: """ The index counter for current timestep. It will increase 1 after each scheduler step. + + Returns: + `int` or `None`: + The index counter for current timestep. """ return self._step_index @property - def begin_index(self): + def begin_index(self) -> int | None: """ The index for the first timestep. 
It should be set from pipeline with `set_begin_index` method. + + Returns: + `int` or `None`: + The index for the first timestep. """ return self._begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): + def set_begin_index(self, begin_index: int = 0) -> None: """ Sets the begin index for the scheduler. This function should be run from pipeline before the inference. @@ -169,7 +181,7 @@ def step( Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`int`): + timestep (`int` or `torch.Tensor`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. @@ -228,7 +240,30 @@ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tens """ return sample - def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets): + def _get_prev_sample( + self, + sample: torch.Tensor, + timestep_index: int, + prev_timestep_index: int, + ets: torch.Tensor, + ) -> torch.Tensor: + """ + Predicts the previous sample based on the current sample, timestep indices, and running model outputs. + + Args: + sample (`torch.Tensor`): + The current sample. + timestep_index (`int`): + Index of the current timestep in the schedule. + prev_timestep_index (`int`): + Index of the previous timestep in the schedule. + ets (`torch.Tensor`): + The running sequence of model outputs. + + Returns: + `torch.Tensor`: + The predicted previous sample. 
+ """ alpha = self.alphas[timestep_index] sigma = self.betas[timestep_index] @@ -240,5 +275,5 @@ def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets): return prev_sample - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps From 710b56440ec04bd2a090695285a4c3e77b351d53 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 3 Mar 2026 00:35:58 +0530 Subject: [PATCH 013/215] Clean up accidental files (#13202) update --- .claude/CLAUDE.md | 100 --------- _modular_model_index.json | 75 ------- custom_model_automodel_guide.md | 239 --------------------- example.py | 120 ----------- modular_model_index.json | 73 ------- pr_review/12498.md | 56 ----- pr_review/12744.md | 186 ----------------- pr_review/13028.md | 99 --------- pr_review/13075.md | 97 --------- pr_review/13116.md | 66 ------ pr_review/pr_12700_flashpack.md | 144 ------------- pr_review/teacache_pr_12652_review.md | 286 -------------------------- release_notes/v0.37.0.md | 129 ------------ scripts/compare_test_coverage.py | 183 ---------------- test_automodel_meta.py | 14 -- test_dataclass_config.py | 11 - test_pretrained_config.py | 29 --- 17 files changed, 1907 deletions(-) delete mode 100644 .claude/CLAUDE.md delete mode 100644 _modular_model_index.json delete mode 100644 custom_model_automodel_guide.md delete mode 100644 example.py delete mode 100644 modular_model_index.json delete mode 100644 pr_review/12498.md delete mode 100644 pr_review/12744.md delete mode 100644 pr_review/13028.md delete mode 100644 pr_review/13075.md delete mode 100644 pr_review/13116.md delete mode 100644 pr_review/pr_12700_flashpack.md delete mode 100644 pr_review/teacache_pr_12652_review.md delete mode 100644 release_notes/v0.37.0.md delete mode 100644 scripts/compare_test_coverage.py delete mode 100644 test_automodel_meta.py delete mode 100644 test_dataclass_config.py delete mode 100644 test_pretrained_config.py diff --git a/.claude/CLAUDE.md b/.claude/CLAUDE.md deleted 
file mode 100644 index ae8010084af7..000000000000 --- a/.claude/CLAUDE.md +++ /dev/null @@ -1,100 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Build, Lint, and Test Commands - -```bash -# Install in development mode -pip install -e ".[dev]" - -# Run full test suite (requires beefy machine) -make test -# Or directly: -python -m pytest -n auto --dist=loadfile -s -v ./tests/ - -# Run a single test file -python -m pytest tests/.py - -# Run slow tests (downloads many GBs of models) -RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/ - -# Format code (ruff + doc-builder) -make style - -# Check code quality without modifying -make quality - -# Fast fixup for modified files only (recommended before commits) -make fixup - -# Fix copied code snippets and dummy objects -make fix-copies - -# Check repository consistency (dummies, inits, repo structure) -make repo-consistency -``` - -## Code Architecture - -Diffusers is built on three core component types that work together: - -### Pipelines (`src/diffusers/pipelines/`) -- End-to-end inference workflows combining models and schedulers -- Base class: `DiffusionPipeline` (in `pipeline_utils.py`) -- Follow **single-file policy**: each pipeline in its own directory -- Loaded via `DiffusionPipeline.from_pretrained()` which reads `model_index.json` -- Components registered via `register_modules()` become pipeline attributes -- ~99 pipeline implementations (Stable Diffusion, SDXL, Flux, etc.) 
- -### Models (`src/diffusers/models/`) -- Configurable neural network architectures extending PyTorch's Module -- Base classes: `ModelMixin` + `ConfigMixin` (in `modeling_utils.py`) -- **Do NOT follow single-file policy**: use shared building blocks (`attention.py`, `embeddings.py`, `resnet.py`) -- Key subdirectories: - - `autoencoders/`: VAEs for latent space compression - - `unets/`: Diffusion model architectures (UNet2DConditionModel, etc.) - - `transformers/`: Transformer-based models (Flux, SD3, etc.) - - `controlnets/`: ControlNet variants - -### Schedulers (`src/diffusers/schedulers/`) -- Guide denoising process during inference -- Base class: `SchedulerMixin` + `ConfigMixin` (in `scheduling_utils.py`) -- Follow **single-file policy**: one scheduler per file -- Key methods: `set_num_inference_steps()`, `step()`, `timesteps` property -- Easily swappable via `ConfigMixin.from_config()` -- ~55 scheduler algorithms (DDPM, DDIM, Euler, DPM-Solver, etc.) - -### Supporting Systems - -- **Loaders** (`src/diffusers/loaders/`): Mixins for LoRA, IP-Adapter, textual inversion, single-file loading -- **Quantizers** (`src/diffusers/quantizers/`): BitsAndBytes, GGUF, TorchAO, Quanto support -- **Hooks** (`src/diffusers/hooks/`): Runtime optimizations (offloading, layer skipping, caching) -- **Guiders** (`src/diffusers/guiders/`): Guidance algorithms (CFG, PAG, etc.) - -## Configuration System - -All components use `ConfigMixin` for serialization: -- Constructor arguments stored via `register_to_config(**kwargs)` -- Instantiate from config: `Component.from_config(config_dict)` -- Save/load as JSON files - -## Key Design Principles - -1. **Usability over Performance**: Models load at float32/CPU by default -2. **Simple over Easy**: Explicit > implicit; expose complexity rather than hide it -3. **Single-file policy**: Pipelines and schedulers are self-contained; models share building blocks -4. 
**Copy-paste over abstraction**: Prefer duplicated code over hasty abstractions for contributor-friendliness - -## Code Style - -- Uses `ruff` for linting and formatting (line length: 119) -- Documentation follows [Google style](https://google.github.io/styleguide/pyguide.html) -- Use `# Copied from` mechanism for sharing code between similar files -- Avoid lambda functions and advanced PyTorch operators for readability - -## Testing - -- Tests use `pytest` with `pytest-xdist` for parallelization -- Slow tests gated by `RUN_SLOW=yes` environment variable -- Test dependencies: `pip install -e ".[test]"` diff --git a/_modular_model_index.json b/_modular_model_index.json deleted file mode 100644 index b0eba6916d3d..000000000000 --- a/_modular_model_index.json +++ /dev/null @@ -1,75 +0,0 @@ -{ - "_blocks_class_name": "SequentialPipelineBlocks", - "_class_name": "Flux2ModularPipeline", - "_diffusers_version": "0.36.0.dev0", - "scheduler": [ - "diffusers", - "FlowMatchEulerDiscreteScheduler", - { - "repo": "hf-internal-testing/tiny-flux2", - "revision": null, - "subfolder": "scheduler", - "type_hint": [ - "diffusers", - "FlowMatchEulerDiscreteScheduler" - ], - "variant": null - } - ], - "text_encoder": [ - "transformers", - "Mistral3ForConditionalGeneration", - { - "repo": "hf-internal-testing/tiny-flux2", - "revision": null, - "subfolder": "text_encoder", - "type_hint": [ - "transformers", - "Mistral3ForConditionalGeneration" - ], - "variant": null - } - ], - "tokenizer": [ - "transformers", - "AutoProcessor", - { - "repo": "hf-internal-testing/Mistral-Small-3.1-24B-Instruct-2503-only-processor", - "revision": null, - "subfolder": "", - "type_hint": [ - "transformers", - "AutoProcessor" - ], - "variant": null - } - ], - "transformer": [ - "diffusers", - "Flux2Transformer2DModel", - { - "repo": "hf-internal-testing/tiny-flux2", - "revision": null, - "subfolder": "transformer", - "type_hint": [ - "diffusers", - "Flux2Transformer2DModel" - ], - "variant": null - } - ], - 
"vae": [ - "diffusers", - "AutoencoderKLFlux2", - { - "repo": "hf-internal-testing/tiny-flux2", - "revision": null, - "subfolder": "vae", - "type_hint": [ - "diffusers", - "AutoencoderKLFlux2" - ], - "variant": null - } - ] -} diff --git a/custom_model_automodel_guide.md b/custom_model_automodel_guide.md deleted file mode 100644 index 66343023e644..000000000000 --- a/custom_model_automodel_guide.md +++ /dev/null @@ -1,239 +0,0 @@ -# Loading Custom Models with `AutoModel` and `trust_remote_code` - -This guide shows how to create a custom model class that lives outside the `diffusers` library and load it via `AutoModel` with `trust_remote_code=True`. - -## How It Works - -When `AutoModel.from_pretrained()` (or `from_config()`) is called with `trust_remote_code=True`, it: - -1. Loads the `config.json` from the model repository. -2. Checks for an `"auto_map"` key in the config that maps `"AutoModel"` to a `"."` reference. -3. Downloads the referenced Python module from the repository. -4. Dynamically imports and instantiates the class from that module. - -This allows anyone to define and share completely custom model architectures without requiring changes to the `diffusers` library itself. - -## Step 1: Define Your Custom Model - -Create a Python file (e.g., `modeling_my_model.py`) that defines your model class. The class must inherit from `ModelMixin` and `ConfigMixin`, and use the `@register_to_config` decorator on `__init__`. 
- -```python -# modeling_my_model.py - -import torch -from torch import nn -from diffusers import ModelMixin, ConfigMixin -from diffusers.configuration_utils import register_to_config - - -class MyCustomModel(ModelMixin, ConfigMixin): - @register_to_config - def __init__(self, in_channels: int = 3, hidden_dim: int = 64, out_channels: int = 3): - super().__init__() - self.net = nn.Sequential( - nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1), - nn.SiLU(), - nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), - nn.SiLU(), - nn.Conv2d(hidden_dim, out_channels, kernel_size=3, padding=1), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.net(x) -``` - -Key requirements: - -- **`ModelMixin`** provides `save_pretrained()` / `from_pretrained()` for weight serialization. -- **`ConfigMixin`** provides `save_config()` / `from_config()` and the `config.json` machinery. -- **`@register_to_config`** automatically captures all `__init__` parameters into `config.json` so the model can be reconstructed from config alone. - -## Step 2: Save the Model Locally - -```python -from modeling_my_model import MyCustomModel - -model = MyCustomModel(in_channels=3, hidden_dim=128, out_channels=3) -model.save_pretrained("./my-custom-model") -``` - -This creates a directory with: - -``` -my-custom-model/ -├── config.json -└── diffusion_pytorch_model.safetensors -``` - -The generated `config.json` will look like: - -```json -{ - "_class_name": "MyCustomModel", - "_diffusers_version": "0.32.0", - "in_channels": 3, - "hidden_dim": 128, - "out_channels": 3 -} -``` - -## Step 3: Add the `auto_map` and Model File to the Repository - -To make `AutoModel` aware of your custom class, you need to: - -1. **Copy `modeling_my_model.py` into the saved model directory.** -2. **Add an `"auto_map"` entry to `config.json`** that points `AutoModel` to your class. 
- -The `auto_map` value format is `"."`: - -```json -{ - "_class_name": "MyCustomModel", - "_diffusers_version": "0.32.0", - "in_channels": 3, - "hidden_dim": 128, - "out_channels": 3, - "auto_map": { - "AutoModel": "modeling_my_model.MyCustomModel" - } -} -``` - -Your final directory structure should be: - -``` -my-custom-model/ -├── config.json # with auto_map added -├── diffusion_pytorch_model.safetensors -└── modeling_my_model.py # your custom model code -``` - -## Step 4: Load with `AutoModel` - -### From a Local Directory - -```python -from diffusers import AutoModel - -model = AutoModel.from_pretrained("./my-custom-model", trust_remote_code=True) -print(model) -``` - -### From the Hugging Face Hub - -First, push the model directory to a Hub repository: - -```python -from huggingface_hub import HfApi - -api = HfApi() -api.create_repo("your-username/my-custom-model", exist_ok=True) -api.upload_folder( - folder_path="./my-custom-model", - repo_id="your-username/my-custom-model", -) -``` - -Then load it: - -```python -from diffusers import AutoModel - -model = AutoModel.from_pretrained( - "your-username/my-custom-model", - trust_remote_code=True, -) -``` - -### Initializing from Config (Random Weights) - -```python -from diffusers import AutoModel - -model = AutoModel.from_config("./my-custom-model", trust_remote_code=True) -``` - -## Complete Example - -```python -import torch -from torch import nn -from diffusers import ModelMixin, ConfigMixin, AutoModel -from diffusers.configuration_utils import register_to_config - - -# 1. 
Define -class MyCustomModel(ModelMixin, ConfigMixin): - @register_to_config - def __init__(self, in_channels: int = 3, hidden_dim: int = 64, out_channels: int = 3): - super().__init__() - self.net = nn.Sequential( - nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1), - nn.SiLU(), - nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), - nn.SiLU(), - nn.Conv2d(hidden_dim, out_channels, kernel_size=3, padding=1), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.net(x) - - -# 2. Save -model = MyCustomModel(in_channels=3, hidden_dim=128, out_channels=3) -model.save_pretrained("./my-custom-model") - -# 3. Manually add auto_map to config.json and copy modeling file -import json, shutil - -config_path = "./my-custom-model/config.json" -with open(config_path) as f: - config = json.load(f) - -config["auto_map"] = {"AutoModel": "modeling_my_model.MyCustomModel"} - -with open(config_path, "w") as f: - json.dump(config, f, indent=2) - -shutil.copy("modeling_my_model.py", "./my-custom-model/modeling_my_model.py") - -# 4. Load via AutoModel -loaded_model = AutoModel.from_pretrained("./my-custom-model", trust_remote_code=True) - -# 5. Verify -x = torch.randn(1, 3, 32, 32) -with torch.no_grad(): - out_original = model(x) - out_loaded = loaded_model(x) - -assert torch.allclose(out_original, out_loaded) -print("Models produce identical outputs!") -``` - -## Using Relative Imports in Custom Code - -If your custom model depends on additional modules, you can use relative imports. For example, if your model uses a custom attention layer defined in a separate file: - -``` -my-custom-model/ -├── config.json -├── diffusion_pytorch_model.safetensors -├── modeling_my_model.py # imports from .my_attention -└── my_attention.py # custom attention implementation -``` - -In `modeling_my_model.py`: - -```python -from .my_attention import MyAttention -``` - -The dynamic module loader will automatically resolve and download all relatively imported files. 
- -## Security Note - -`trust_remote_code=True` executes arbitrary Python code from the model repository. Only use it with repositories you trust. You can globally disable remote code execution by setting the environment variable: - -```bash -export DIFFUSERS_DISABLE_REMOTE_CODE=1 -``` diff --git a/example.py b/example.py deleted file mode 100644 index bb0a5b430e3a..000000000000 --- a/example.py +++ /dev/null @@ -1,120 +0,0 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import torch - -from diffusers import QwenImageTransformer2DModel -from diffusers.utils.torch_utils import randn_tensor - -from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import LoraHotSwappingForModelTesterMixin -from ..testing_utils import ( - AttentionTesterMixin, - ContextParallelTesterMixin, - LoraTesterMixin, - MemoryTesterMixin, - ModelTesterMixin, - TorchCompileTesterMixin, - TrainingTesterMixin, -) - - -enable_full_determinism() - - -class QwenImageTransformerTesterConfig: - model_class = QwenImageTransformer2DModel - pretrained_model_name_or_path = "" - pretrained_model_kwargs = {"subfolder": "transformer"} - - @property - def generator(self): - return torch.Generator("cpu").manual_seed(0) - - def get_init_dict(self) -> dict[str, int | list[int]]: - # __init__ parameters: - # patch_size: int = 2 - # in_channels: int = 64 - # out_channels: Optional[int] = 16 - # num_layers: int = 60 - # attention_head_dim: int = 128 - # num_attention_heads: int = 24 - # joint_attention_dim: int = 3584 - # guidance_embeds: bool = False - # axes_dims_rope: Tuple[int, int, int] = - return {} - - def get_dummy_inputs(self) -> dict[str, torch.Tensor]: - # forward() parameters: - # hidden_states: torch.Tensor - # encoder_hidden_states: torch.Tensor - # encoder_hidden_states_mask: torch.Tensor - # timestep: torch.LongTensor - # img_shapes: Optional[List[Tuple[int, int, int]]] - # txt_seq_lens: Optional[List[int]] - # guidance: torch.Tensor - # attention_kwargs: Optional[Dict[str, Any]] - # controlnet_block_samples - # return_dict: bool = True - # TODO: Fill in dummy inputs - return {} - - @property - def input_shape(self) -> tuple[int, ...]: - return (1, 1) - - @property - def output_shape(self) -> tuple[int, ...]: - return (1, 1) - - -class TestQwenImageTransformerModel(QwenImageTransformerTesterConfig, ModelTesterMixin): - pass - - -class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin): - 
pass - - -class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin): - pass - - -class TestQwenImageTransformerTorchCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin): - different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] - - def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: - # TODO: Implement dynamic input generation - return {} - - -class TestQwenImageTransformerLora(QwenImageTransformerTesterConfig, LoraTesterMixin): - pass - - -class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin): - pass - - -class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin): - pass - - -class TestQwenImageTransformerLoraHotSwappingForModel(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin): - different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] - - def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: - # TODO: Implement dynamic input generation - return {} diff --git a/modular_model_index.json b/modular_model_index.json deleted file mode 100644 index 48872b1c98d7..000000000000 --- a/modular_model_index.json +++ /dev/null @@ -1,73 +0,0 @@ -{ - "_blocks_class_name": "SequentialPipelineBlocks", - "_class_name": "Flux2ModularPipeline", - "_diffusers_version": "0.36.0.dev0", - "scheduler": [ - null, - null, - { - "pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev", - "revision": null, - "subfolder": "scheduler", - "type_hint": [ - "diffusers", - "FlowMatchEulerDiscreteScheduler" - ], - "variant": null - } - ], - "text_encoder": [ - null, - null, - { - "revision": null, - "subfolder": "text_encoder", - "type_hint": [ - "transformers", - "Mistral3ForConditionalGeneration" - ], - "variant": null - } - ], - "tokenizer": [ - null, - null, - { - "revision": null, - "subfolder": "tokenizer", - "type_hint": [ - "transformers", - 
"AutoProcessor" - ], - "variant": null - } - ], - "transformer": [ - null, - null, - { - "pretrained_model_name_or_path": "diffusers/FLUX.2-dev-bnb-4bit", - "revision": null, - "subfolder": "transformer", - "type_hint": [ - "diffusers", - "Flux2Transformer2DModel" - ], - "variant": null - } - ], - "vae": [ - null, - null, - { - "pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev", - "revision": null, - "subfolder": "vae", - "type_hint": [ - "diffusers", - "AutoencoderKLFlux2" - ], - "variant": null - } - ] -} diff --git a/pr_review/12498.md b/pr_review/12498.md deleted file mode 100644 index 1f52d1f7acc9..000000000000 --- a/pr_review/12498.md +++ /dev/null @@ -1,56 +0,0 @@ -Now let me look at what `dequantize_gguf_tensor` does to understand the fix better: - -## Code Review: GGUF fix for unquantized types - -### 1. Summary of Changes - -The PR fixes a bug in the `_fused_mul_mat_gguf` function (line 79-105) where unquantized GGUF tensor types (F32, F16, BF16) were incorrectly handled. - -**Before:** When `qweight_type` was an unquantized type, the code directly performed matrix multiplication: `x @ qweight.T` - -**After:** It now calls `dequantize_gguf_tensor(qweight)` first, then performs the matrix multiplication: `x @ weight.T` - -The issue was that even "unquantized" GGUF tensors are stored in an 8-bit tensor format and need to be converted to their proper data type representation before use. - -### 2. Potential Issues or Bugs - -**None identified.** The fix is correct and addresses a real bug: - -- The `dequantize_gguf_tensor` function (lines 509-527) checks if the tensor has a `quant_type` attribute and handles the appropriate conversion -- For BF16 specifically, there's a dedicated `dequantize_blocks_BF16` function (lines 428-429) that properly converts the 8-bit storage format -- The fix aligns with how the native path already works in `forward_native` (lines 593-599), which always calls `dequantize_gguf_tensor` - -### 3. 
Code Quality Observations - -**Strengths:** -- The fix is minimal and surgical - only changes what's necessary -- Maintains consistency with the `forward_native` path which already uses `dequantize_gguf_tensor` -- The variable naming (`weight` instead of reusing `qweight`) makes it clear a transformation occurred - -**Minor observation:** -- The comment on line 80 "there is no need to call any kernel for fp16/bf16" is now slightly misleading since we DO need to call dequantization logic. Consider updating it to something like: "no need to call specialized GGUF kernel for fp16/bf16, but still need to dequantize from 8-bit storage" - -### 4. Security Considerations - -**No security concerns.** The change: -- Doesn't introduce any external input handling -- Doesn't modify control flow in a way that could bypass security checks -- Only fixes a data type conversion issue - -### 5. Suggestions for Improvement - -1. **Update the comment** on line 80 in `src/diffusers/quantizers/gguf/utils.py:80`: - ```python - # unquantized types still need dequantization from 8-bit storage, but don't need specialized kernels - if qweight_type in UNQUANTIZED_TYPES: - weight = dequantize_gguf_tensor(qweight) - return x @ weight.T - ``` - -2. **Consider adding a test** to prevent regression of this issue. A test should verify that unquantized GGUF tensors produce correct output shapes and values. - -3. **Documentation:** The PR description mentions torch 2.8/2.9 build availability. This might be worth tracking in a GitHub issue if not already done. - -### Verdict - -**Approve with minor comment update suggestion.** The fix correctly addresses a real shape mismatch bug where GGUF's 8-bit storage format wasn't being properly converted for unquantized types. The logic is sound and aligns with the existing native implementation path. 
diff --git a/pr_review/12744.md b/pr_review/12744.md deleted file mode 100644 index 30e2ac1a71b6..000000000000 --- a/pr_review/12744.md +++ /dev/null @@ -1,186 +0,0 @@ -I'll provide a comprehensive code review of this MagCache PR. - -## Summary of Changes - -This PR implements MagCache (Magnitude-aware Cache), a training-free inference acceleration technique for diffusion transformers. The implementation: - -- Adds a `MagCacheConfig` class for configuration -- Implements `MagCacheHeadHook` and `MagCacheBlockHook` following the existing ModelHook pattern -- Includes calibration mode to compute magnitude ratios for any transformer model -- Provides pre-computed `FLUX_MAG_RATIOS` for Flux models -- Adds comprehensive documentation and tests - -## Potential Issues and Bugs - -### 1. **Critical: Missing Hook Removal in `disable_cache()`** -```python -# In cache_utils.py, line ~127 -elif isinstance(self._cache_config, MagCacheConfig): - registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, recurse=True) -``` - -**Issue**: The code only removes the leader/head hook but not the block hooks (`_MAG_CACHE_BLOCK_HOOK`). This will leave hooks attached when disabling the cache. - -**Fix**: Add removal of block hooks: -```python -elif isinstance(self._cache_config, MagCacheConfig): - registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, recurse=True) - registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True) -``` - -### 2. **Shape Mismatch Handling Logic Issue** -In `mag_cache.py` lines 224-248, the shape mismatch handling has a potential issue: - -```python -elif ( - output.ndim == 3 - and res.ndim == 3 - and output.shape[0] == res.shape[0] - and output.shape[2] == res.shape[2] -): - diff = output.shape[1] - res.shape[1] - if diff > 0: - output = output.clone() - output[:, diff:, :] = output[:, diff:, :] + res -``` - -**Issue**: This assumes text tokens come first and image tokens come last. 
This may not be universal across all models (e.g., some models interleave tokens differently). - -**Suggestion**: Add a comment explaining this assumption or add configuration to specify the concatenation strategy. - -### 3. **Residual Calculation Fallback is Unsafe** -In `mag_cache.py` line 343: - -```python -else: - # Fallback for completely mismatched shapes - residual = out_hidden -``` - -**Issue**: This fallback doesn't compute a residual at all—it just uses the output. This will cause incorrect behavior in subsequent steps. - -**Suggestion**: Either raise an error or add a warning that calibration is required for this model architecture. - -### 4. **Device Mismatch Handling is Incomplete** -```python -if res.device != output.device: - res = res.to(output.device) -``` - -**Issue**: This only handles device mismatch for the residual, but doesn't handle dtype mismatches which could occur with mixed precision training. - -**Suggestion**: Add dtype handling: -```python -if res.device != output.device or res.dtype != output.dtype: - res = res.to(device=output.device, dtype=output.dtype) -``` - -### 5. **Calibration Logging Could Be Missed** -The calibration results are printed to stdout (line 380) and logged. However, if the user has logging disabled or redirected, they might miss this critical information. - -**Suggestion**: Consider returning calibration results from the pipeline or raising a more visible notification. - -### 6. **Test Suite is Skipped** -```python -@unittest.skip("MagCache unit tests are skipped.") -class MagCacheTests(unittest.TestCase): -``` - -**Issue**: All unit tests are skipped, which means the core logic isn't being validated in CI. - -**Action Required**: Remove the skip decorator before merging or add a comment explaining why it's temporarily skipped. - -## Code Quality Observations - -### Strengths: -1. **Well-structured**: Follows existing patterns (ModelHook, StateManager) consistently -2. 
**Good documentation**: Comprehensive docstrings and inline comments -3. **Calibration mode**: Clever design allowing model-agnostic usage -4. **Error handling**: Validates configuration upfront -5. **Interpolation logic**: Smart handling of different step counts via `nearest_interp()` - -### Areas for Improvement: - -1. **Magic Numbers**: Several hardcoded values could be constants: - ```python - eps = 1e-8 # Line 335 in _perform_calibration_step - expected_atol = 0.1 # Line 2989 in test - ``` - -2. **Code Duplication**: The logic for handling tuple returns appears multiple times. Consider extracting to a helper method. - -3. **Type Hints**: Some methods lack return type hints (e.g., `nearest_interp`) - -4. **Compiler Disable Decorator**: The `@torch.compiler.disable` decorator is used but not explained. Add a comment about why compilation is disabled. - -## Security Considerations - -### Low Risk: -- No external network calls -- No file system access beyond logging -- No execution of arbitrary code -- Tensor operations are standard PyTorch - -### Observations: -1. **Device Transfer**: The `.to(device)` calls are safe but could consume unexpected memory if tensors are large -2. **State Management**: The state is properly isolated and reset between inference runs - -## Suggestions for Improvement - -### 1. Add Configuration Validation -```python -def __post_init__(self): - # Existing checks... - - # Add bounds checking - if not 0.0 <= self.retention_ratio <= 1.0: - raise ValueError(f"retention_ratio must be in [0, 1], got {self.retention_ratio}") - if self.max_skip_steps < 1: - raise ValueError(f"max_skip_steps must be >= 1, got {self.max_skip_steps}") - if self.threshold <= 0: - raise ValueError(f"threshold must be positive, got {self.threshold}") -``` - -### 2. 
Add Metrics/Statistics -Consider adding optional statistics collection: -- How many blocks were skipped per step -- Average accumulated error -- Total compute savings - -This would help users optimize their thresholds. - -### 3. Improve Documentation Example -The documentation example could show expected speedup or quality metrics to set user expectations. - -### 4. Add Gradient Mode Check -```python -if torch.is_grad_enabled(): - logger.warning("MagCache is designed for inference only. Gradients are enabled but will not flow correctly through cached blocks.") -``` - -### 5. Consider Memory Cleanup -The `previous_residual` is held in state indefinitely. Consider adding explicit cleanup: -```python -def cleanup(self): - if self.previous_residual is not None: - del self.previous_residual - self.previous_residual = None -``` - -## Minor Issues - -1. **Line 26**: Unused import or should be used in logger initialization -2. **Line 332**: Comment says "Fallback to matching tail" but logic is unclear -3. **Documentation**: The TIP about batched CFG could include more detail about why this works - -## Conclusion - -This is a **well-implemented feature** with good design patterns and documentation. The main concerns are: - -1. **Critical**: Fix the missing block hook removal in `disable_cache()` (Line 127) -2. **Important**: Unskip and fix the unit tests -3. **Recommended**: Improve shape mismatch handling with better error messages - -The implementation is production-ready once these issues are addressed. The calibration mode is particularly clever and makes this genuinely model-agnostic. - -**Recommendation**: Request changes for items #1 and #2, then approve once fixed. 
diff --git a/pr_review/13028.md b/pr_review/13028.md deleted file mode 100644 index 7988498aecf1..000000000000 --- a/pr_review/13028.md +++ /dev/null @@ -1,99 +0,0 @@ -# PR #13028: [Modular] add explicit workflow support - -**Author:** @yiyixuxu -**Branch:** `modular-workflow` -> `main` -**Files changed:** `modular_pipeline.py`, `modular_pipeline_utils.py`, `qwenimage/modular_blocks_qwenimage.py` -**+298 / -165** - ---- - -## Summary - -This PR adds a `_workflow_map` class attribute to `SequentialPipelineBlocks` that maps named workflows (e.g., `"text2image"`, `"inpainting"`) to their trigger inputs. Users can then call `get_workflow("text2image")` to get the execution blocks for that workflow. The PR also refactors `get_execution_blocks` into `ConditionalPipelineBlocks` and `SequentialPipelineBlocks`, moves `combine_inputs`/`combine_outputs` to module-level functions, and improves docstrings. - -## Main Concern: "Workflow" as a New Concept - -Modular Diffusers already requires users to learn: **Pipelines**, **Blocks** (Sequential, Conditional, Auto, Loop), **Steps**, **Components**, **Inputs/Outputs**, **Trigger Inputs**, **Execution Blocks**, **PipelineState**, and **BlockState**. Adding "workflow" as yet another term increases cognitive overhead. - -The underlying feature is useful — named presets for trigger inputs are genuinely helpful for discoverability. But "workflow" may not be the right label: - -1. **Overloaded term**: "Workflow" is heavily used in the AI/ML ecosystem (ComfyUI workflows, orchestration workflows, CI/CD workflows). Users may expect something more complex than what this is. - -2. **It's really a task/mode, not a workflow**: `"text2image"`, `"inpainting"`, `"image2image"` are *tasks* or *modes*. The rest of diffusers already uses "task" terminology — `AutoPipelineForText2Image`, `AutoPipelineForInpainting`, etc. Calling the same concept "workflow" in Modular Diffusers creates inconsistency. - -3. 
**It's a thin wrapper**: `get_workflow("text2image")` is just `get_execution_blocks(prompt=True)`. Users still need to understand `get_execution_blocks` and trigger inputs to do anything beyond the predefined workflows. The abstraction doesn't save much complexity. - -**Suggestion**: Consider `_task_map` / `get_task()` / `task_names` to align with existing diffusers terminology, or `_mode_map` / `get_mode()` / `mode_names` for something more neutral. The existing `auto_pipeline.py` already uses "task" internally — `_get_task_class()` maps pipeline class names to task-specific variants (text2image, image2image, inpainting), and the public API follows the `AutoPipelineFor` naming pattern. These are the exact same concepts this PR calls "workflows." Alternatively, this could simply be better documentation on `get_execution_blocks` with named examples, rather than a new API surface. - -## Code Issues - -### Behavioral change: `outputs` -> `intermediate_outputs` in traversal - -`modular_pipeline.py` — In `SequentialPipelineBlocks.get_execution_blocks`, the old `_traverse_trigger_blocks` tracked `block.outputs` to propagate available values to downstream blocks. The new code tracks `block.intermediate_outputs` instead: - -```python -# Old -if hasattr(block, "outputs"): - for out in block.outputs: - active_inputs[out.name] = True - -# New -if hasattr(block, "intermediate_outputs"): - for out in block.intermediate_outputs: - active_inputs[out.name] = True -``` - -`intermediate_outputs` and `outputs` can differ — `intermediate_outputs` includes values passed between blocks in the pipeline state, while `outputs` are the final outputs. This could change which downstream conditional blocks get triggered. If this is intentional, it should be called out explicitly in the PR description since it affects existing behavior. 
- -### `_workflow_map` on base class, implementations only on `SequentialPipelineBlocks` - -`_workflow_map = None` is defined on `ModularPipelineBlocks` (the base class), but `workflow_names` and `get_workflow()` are only implemented on `SequentialPipelineBlocks`. The base class stubs raise `NotImplementedError`. This is misleading — it suggests workflows *could* be implemented for other block types. If workflows are intentionally only for `SequentialPipelineBlocks`, define `_workflow_map` there and don't add stubs to the base class. - -### `get_execution_blocks` no longer filters None values - -Old code: -```python -active_inputs = {k: v for k, v in kwargs.items() if v is not None} -``` - -New code: -```python -active_inputs = dict(kwargs) -``` - -This is a behavioral change to the public `get_execution_blocks` API. The old code explicitly stripped `None` values so users could write `get_execution_blocks(prompt="a cat", image=None)` and `image` wouldn't trigger anything. The new code passes `None` through. It happens to still work because `select_block` checks `is not None` internally, but callers can no longer rely on the documented filtering behavior. This should be noted. - -### `default_block_name` changed from property to instance attribute - -In `AutoPipelineBlocks`, `default_block_name` was a `@property` that derived the default from `block_trigger_inputs` on every access. It's now set as an instance attribute in `__init__`. This is mostly fine, but the new code also adds a validation that raises an error if `default_block_name` is not `None` before `__init__` sets it — so subclasses that accidentally set `default_block_name` as a class attribute will now break. This is a stricter contract that should be documented. - -### Typo - -`modular_pipeline.py` — `# currentlyonly ConditionalPipelineBlocks` should be `# currently only`. 
- -### `_get_trigger_inputs()` called multiple times in `__repr__` - -In `SequentialPipelineBlocks.__repr__`, `self._get_trigger_inputs()` is called 3 times (condition check, trigger inputs display, example input). This recursively traverses all blocks each time. Should be computed once and reused. - -### Duplicate `format_workflow` calls in `__repr__` and `doc` - -Both `SequentialPipelineBlocks.__repr__` and `SequentialPipelineBlocks.doc` build the description + workflow string independently with identical logic: - -```python -description = self.description -if self._workflow_map is not None: - workflow_str = format_workflow(self._workflow_map) - description = f"{self.description}\n\n{workflow_str}" -``` - -This should be extracted into a property or helper. - -### No tests - -The PR description mentions "I will add a test suite for this too!" but there are no tests included. Workflow resolution, edge cases (empty workflow map, missing workflow name, workflows with overlapping triggers), and the `get_execution_blocks` refactoring should all be tested before merge. - -## Refactoring Quality - -The refactoring of `get_execution_blocks` from a monolithic method on `SequentialPipelineBlocks` into separate implementations on `ConditionalPipelineBlocks` and `SequentialPipelineBlocks` is a good separation of concerns. Moving `combine_inputs`/`combine_outputs` to module-level functions is also reasonable since they don't depend on instance state. - -The improved `AutoPipelineBlocks` docstring with the example is a significant documentation improvement. diff --git a/pr_review/13075.md b/pr_review/13075.md deleted file mode 100644 index 23b042b0cbfa..000000000000 --- a/pr_review/13075.md +++ /dev/null @@ -1,97 +0,0 @@ -I'll review this PR that addresses PyTorch version compatibility for distributed operations. 
- -## Summary of Changes - -The PR refactors the `gather_size_by_comm` function in `_modeling_parallel.py` to handle PyTorch versions prior to 2.6 that don't have the `torch.accelerator` API. The changes replace a single ternary expression with a multi-level conditional that: - -1. First checks if "cpu" is in the backend string -2. Then checks if `torch.accelerator` exists (PyTorch >= 2.6) -3. Falls back to CUDA as a default device - -## Potential Issues or Bugs - -**1. Device Type Inconsistency** -The original code returns a string `"cpu"` but the new code returns `torch.device("cuda")` objects. This inconsistency could cause issues: - -```python -gather_device = "cpu" # str -# vs -gather_device = torch.device("cuda") # torch.device object -``` - -**Recommendation:** Use `torch.device()` consistently: -```python -if "cpu" in comm_backends: - gather_device = torch.device("cpu") -elif hasattr(torch, "accelerator"): - acc = torch.accelerator.current_accelerator() - gather_device = torch.device(acc if acc is not None else "cuda") -else: - gather_device = torch.device("cuda") -``` - -**2. Unclear Accelerator Return Behavior** -The comment states "Fall back to CUDA when no accelerator is returned" but it's unclear when `torch.accelerator.current_accelerator()` would return `None`. This should be verified or documented. - -**3. Missing Type Information** -What type does `torch.accelerator.current_accelerator()` return? If it returns a string like `"cuda"` or `"mps"`, the code should handle it consistently. If it returns a device object, the logic might need adjustment. - -## Code Quality Observations - -**Positive:** -- Clear comments explaining the fallback logic -- Proper use of `hasattr()` for backward compatibility -- Addresses the reported issue #13074 - -**Areas for Improvement:** - -1. **Device type consistency** (mentioned above) - -2. 
**Consider alternative hardware accelerators:** The fallback to CUDA might not be appropriate for all systems (e.g., MPS on macOS, XPU on Intel). Consider: - ```python - else: - # Fallback for PyTorch < 2.6 - if torch.cuda.is_available(): - gather_device = torch.device("cuda") - else: - gather_device = torch.device("cpu") - ``` - -3. **Code style:** The expanded conditional is more readable but could benefit from extracting into a helper function if this pattern appears elsewhere: - ```python - def _get_gather_device(comm_backends: str) -> torch.device: - """Determine device for distributed gather operations.""" - # ... implementation - ``` - -## Security Considerations - -No significant security issues identified. This is primarily a compatibility fix for internal device selection logic. - -## Suggestions for Improvement - -1. **Add a test case** to verify behavior on PyTorch < 2.6 (if not already covered) - -2. **Document the behavior** more explicitly: - ```python - # Determine gather device based on backend and PyTorch version - # Priority: CPU backend > torch.accelerator (>= 2.6) > CUDA fallback (< 2.6) - ``` - -3. **Consider this more defensive approach:** - ```python - if "cpu" in comm_backends: - gather_device = torch.device("cpu") - elif hasattr(torch, "accelerator"): - acc = torch.accelerator.current_accelerator() - gather_device = torch.device(acc if acc else "cuda") - elif torch.cuda.is_available(): - gather_device = torch.device("cuda") - else: - # Fallback to CPU if no GPU available - gather_device = torch.device("cpu") - ``` - -## Verdict - -The PR addresses the compatibility issue but has a **type inconsistency bug** that should be fixed before merging. The string vs `torch.device` object mismatch could cause runtime errors. Once that's addressed, the change is sound for backward compatibility. 
diff --git a/pr_review/13116.md b/pr_review/13116.md deleted file mode 100644 index 664550cc45c5..000000000000 --- a/pr_review/13116.md +++ /dev/null @@ -1,66 +0,0 @@ -# PR #13116: [tests] tests for `modules_to_not_convert` - -**Author:** @sayakpaul -**Branch:** `fix-modules-no-convert-torchao` -> `main` -**Files changed:** `tests/models/testing_utils/quantization.py`, `tests/models/transformers/test_models_transformer_flux.py` - ---- - -## Summary - -This PR fixes the `modules_to_not_convert` tests that were effectively dead code. They existed in the base `QuantizationTesterMixin` but never ran because no test class defined `modules_to_not_convert_for_test`. The PR activates these tests for Flux and fixes several underlying bugs that would have caused them to fail. - -## Key Changes - -1. **BnB config key fix**: `BitsAndBytesConfig` uses `llm_int8_skip_modules`, not `modules_to_not_convert`. The base test was setting the wrong key, so modules were never actually excluded. - -2. **TorchAO `_verify_if_layer_quantized` fix**: Previously only checked `isinstance(module, torch.nn.Linear)`, which is always true for TorchAO (it doesn't replace the module class). Now properly checks weight tensor types (`AffineQuantizedTensor`, `LinearActivationQuantizedTensor`). - -3. **`_is_module_quantized` fix**: Now passes `quant_config_kwargs` to `_verify_if_layer_quantized`. Previously it passed `{}`, which caused BnB to always check for `Int8Params` even on 4-bit models. - -4. **Cleanup**: Removes unused guard blocks (`is_gguf_available`, `is_torchao_available`) that only contained `pass`. - -5. **Activates tests**: Adds `modules_to_not_convert_for_test` returning `["norm_out.linear"]` to BnB, Quanto, TorchAo, and ModelOpt Flux test classes. 
- -## Issues - -### `to_not_convert_key` parameter pollutes the base class interface - -`quantization.py:271-273` — The new `to_not_convert_key` parameter on `_test_quantization_modules_to_not_convert` exists solely for BnB's naming quirk (`llm_int8_skip_modules` vs `modules_to_not_convert`). Every other backend uses the default. This leaks a BnB-specific detail into the shared base method. - -BnB already has its own `test_bnb_modules_to_not_convert` that could handle the key translation locally — either by building the correct `config_kwargs` with `llm_int8_skip_modules` before calling `_create_quantized_model` directly, or by overriding the test. This keeps the base method clean and isolates BnB's naming quirk in `BitsAndBytesTesterMixin` where it belongs. - -### Code duplication in TorchAO `test_torchao_modules_to_not_convert` - -`quantization.py:915-950` — The TorchAO test inlines ~30 lines from `_test_quantization_modules_to_not_convert` to skip the memory footprint comparison. If the base method is updated in the future, this copy won't get the fix. Consider parameterizing the base method instead: - -```python -def _test_quantization_modules_to_not_convert( - self, config_kwargs, modules_to_not_convert, check_memory_footprint=True, -): - # ... existing module-walking logic ... - - if check_memory_footprint: - # Compare memory footprint with fully quantized model - ... -``` - -Then TorchAO could simply call: -```python -self._test_quantization_modules_to_not_convert( - TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude, - check_memory_footprint=False, -) -``` - -### TorchAO imports inside method body - -`quantization.py:822-823` — The `torchao` imports are placed inside `_verify_if_layer_quantized`. While functional (avoids import errors when torchao isn't installed), these could be placed at module level under the existing `is_torchao_available()` guard for consistency with how `bnb` and `QLinear` imports are handled. Minor style point. 
- -### `_is_module_quantized` callers not updated - -`quantization.py:368` — The `_test_dequantize` method still calls `self._is_module_quantized(module)` without `quant_config_kwargs`. This happens to work correctly (for BnB, checking `Int8Params` after dequantization correctly returns False; for TorchAO, the weight won't be an `AffineQuantizedTensor`), but it means BnB dequantize for 4-bit models asserts the weight is not `Int8Params` rather than asserting it's not `Params4bit`. Consider updating for correctness. - -### Missing GGUF test coverage - -GGUF's `GGUFTesterMixin` doesn't have a `test_gguf_modules_to_not_convert` method. If GGUF is expected to support `modules_to_not_convert`, a test should be added. If not, a comment explaining why would be helpful. diff --git a/pr_review/pr_12700_flashpack.md b/pr_review/pr_12700_flashpack.md deleted file mode 100644 index 975fbd6ca18c..000000000000 --- a/pr_review/pr_12700_flashpack.md +++ /dev/null @@ -1,144 +0,0 @@ -# PR #12700 — FlashPack Integration Review - -**URL**: https://github.com/huggingface/diffusers/pull/12700 -**State**: OPEN -**Branch**: `flashpack` → `main` - -## Summary - -Adds FlashPack as a new weight serialization format for faster model loading. FlashPack packs model weights into a single contiguous file (`model.flashpack`) that can be loaded efficiently, especially for larger models. The PR integrates it across `ModelMixin` (save/load), `DiffusionPipeline` (save/load/download), and supporting utilities. 
- -## Files Changed - -- `setup.py` / `dependency_versions_table.py` — add `flashpack` dependency -- `src/diffusers/utils/constants.py` — `FLASHPACK_WEIGHTS_NAME`, `FLASHPACK_FILE_EXTENSION` -- `src/diffusers/utils/import_utils.py` — `is_flashpack_available()` -- `src/diffusers/utils/__init__.py` — re-exports -- `src/diffusers/models/model_loading_utils.py` — `load_flashpack_checkpoint()`, dispatch in `load_state_dict()` -- `src/diffusers/models/modeling_utils.py` — `save_pretrained(use_flashpack=...)`, `from_pretrained(use_flashpack=..., flashpack_kwargs=...)` -- `src/diffusers/pipelines/pipeline_utils.py` — pipeline-level `save_pretrained`, `from_pretrained`, `download` with `use_flashpack` -- `src/diffusers/pipelines/pipeline_loading_utils.py` — `load_sub_model`, `_get_ignore_patterns`, `get_class_obj_and_candidates`, `filter_model_files` - ---- - -## Issues - -### 1. `use_flashpack=True` default in `DiffusionPipeline.download()` - -```python -# pipeline_utils.py, in download() -use_flashpack = kwargs.pop("use_flashpack", True) -``` - -This defaults to `True`, meaning `download()` will always try to download FlashPack files by default. Every other call site defaults to `False`. This looks like a bug — it would change download behavior for all users even if they never asked for FlashPack. Should be `False`. - -### 2. `load_flashpack_checkpoint` is unused in the `from_pretrained` path - -`load_flashpack_checkpoint()` is added to `model_loading_utils.py` and wired into `load_state_dict()`. However, in `ModelMixin.from_pretrained`, when `use_flashpack=True`, the code **early-returns** after calling `flashpack.mixin.assign_from_file()` directly — it never goes through `load_state_dict()`. So `load_flashpack_checkpoint` is dead code in the `from_pretrained` flow. Either: -- Remove it if FlashPack always uses its own assign path, or -- Use it consistently (load state dict → assign to model, like safetensors/pickle) - -### 3. 
`resolved_model_file` may be undefined when `use_flashpack=True` and file fetch fails - -```python -# modeling_utils.py, from_pretrained -elif use_flashpack: - try: - resolved_model_file = _get_model_file(...) - except IOError as e: - logger.error(...) - if not allow_pickle: - raise - logger.warning("Defaulting to unsafe serialization...") -``` - -If the `IOError` is caught and `allow_pickle` is truthy, `resolved_model_file` is never set but is used later at `flashpack.mixin.assign_from_file(model=model, path=resolved_model_file[0], ...)`. This would crash with `NameError` or `UnboundLocalError`. The fallback logic (copied from the safetensors block) doesn't make sense for FlashPack — there's no pickle fallback for FlashPack. The `except` block should just re-raise unconditionally. - -### 4. `resolved_model_file[0]` assumes a list, but `_get_model_file` returns a string - -```python -flashpack.mixin.assign_from_file( - model=model, - path=resolved_model_file[0], # indexing into a string - ... -) -``` - -`_get_model_file` returns a single file path (string), not a list. `resolved_model_file[0]` would give the first character of the path. Should be just `resolved_model_file`. - -### 5. `device_map` handling assumes `device_map[""]` exists - -```python -flashpack_device = device_map[""] -``` - -`device_map` can be a dict with arbitrary keys (layer names, module names), not just `{"": device}`. This would raise `KeyError` for any non-trivial device map. Should handle the general case or document the constraint. - -### 6. `FlashPack` prefix stripping in `get_class_obj_and_candidates` is unexplained - -```python -if class_name.startswith("FlashPack"): - class_name = class_name.removeprefix("FlashPack") -``` - -This is injected into a general-purpose utility function with no explanation of when/why a class name would have a `FlashPack` prefix. 
This seems like it handles a specific config format but there's no corresponding code that writes `FlashPack`-prefixed class names. If this is for some external convention, it should be documented. If not needed, remove it. - -### 7. Duplicated availability check pattern - -The `is_flashpack_available()` check + import + error message pattern is repeated 3 times: -- `load_flashpack_checkpoint()` in `model_loading_utils.py` -- `save_pretrained()` in `modeling_utils.py` -- `from_pretrained()` in `modeling_utils.py` - -Each has slightly different wording. Should be consolidated — e.g., a helper or just use a single `require_flashpack()` function, consistent with how other optional deps are handled. - -### 8. `save_pretrained` error message says "load" instead of "save" - -```python -# modeling_utils.py, save_pretrained, use_flashpack=True branch -raise ImportError("Please install torch and flashpack to load a FlashPack checkpoint in PyTorch.") -``` - -This is in the **save** path, but the message says "load". Should say "save". - -### 9. No `config.json` saved alongside FlashPack weights in `save_pretrained` - -When `use_flashpack=True` in `ModelMixin.save_pretrained`, the model config is saved normally at the top of the method, but the FlashPack branch calls `flashpack.serialization.pack_to_file()` with `target_dtype=self.dtype`. It's not clear if FlashPack's own `config.json` (mentioned in the benchmark script as `flashpack_config.json`) is the same as diffusers' `config.json`. If they're different files, loading back with `from_pretrained(use_flashpack=True)` might fail to reconstruct the model architecture since `from_config` needs the diffusers config. - -### 10. `output_loading_info` warning placement - -```python -if output_loading_info: - logger.warning("`output_loading_info` is not supported with FlashPack.") - return model, {} -``` - -This returns an empty dict silently. 
The warning is fine, but returning `{}` instead of a proper `loading_info` structure (with `missing_keys`, `unexpected_keys`, etc.) could break code that destructures the result. - -### 11. No tests included - -The PR has no test files. At minimum there should be: -- Unit tests for `load_flashpack_checkpoint` (mocking `flashpack`) -- Unit tests for save/load roundtrip with `use_flashpack=True` -- Integration test for pipeline save/load - -### 12. FlashPack doesn't support sharding - -The `save_pretrained` FlashPack branch ignores `max_shard_size` entirely and always saves a single file. This is fine for the format but should either: -- Log a warning if `max_shard_size` is explicitly set alongside `use_flashpack=True` -- Document this limitation - ---- - -## Minor Issues - -- The benchmark in the PR description shows FlashPack is actually **slower** for fp16 SD v1.5 (0.95x). The claimed benefit is only for bf16. This should be prominently noted. -- `FLASHPACK_WEIGHTS_NAME = "model.flashpack"` breaks the diffusers naming convention (`diffusion_pytorch_model.*` for other formats). -- The PR modifies `_get_ignore_patterns` but doesn't handle the case where both `use_safetensors` and `use_flashpack` are True. -- `filter_model_files` adds `FLASHPACK_WEIGHTS_NAME` to the known weights list but there are no corresponding tests for this filtering. - ---- - -## Verdict - -The PR needs significant work before it's mergeable. The critical issues are the `use_flashpack=True` default in `download()`, the `resolved_model_file[0]` indexing bug, the dead code path with `load_flashpack_checkpoint`, and the lack of tests. The integration pattern also doesn't feel consistent with how other formats (safetensors, GGUF) are integrated — FlashPack bypasses the standard state dict loading path entirely via its own `assign_from_file`, making it a special case that's harder to maintain. 
diff --git a/pr_review/teacache_pr_12652_review.md b/pr_review/teacache_pr_12652_review.md deleted file mode 100644 index 1cd76e9637f1..000000000000 --- a/pr_review/teacache_pr_12652_review.md +++ /dev/null @@ -1,286 +0,0 @@ -# TeaCache PR #12652 Review Notes - -## PR Overview - -- **PR**: https://github.com/huggingface/diffusers/pull/12652 -- **Title**: Implement TeaCache -- **Author**: LawJarp-A (Prajwal A) -- **Status**: Open -- **Changes**: +1335 / -22 lines across 6 files - -### What is TeaCache? - -[TeaCache](https://huggingface.co/papers/2411.19108) (Timestep Embedding Aware Cache) is a training-free caching technique that speeds up diffusion model inference by **1.5x-2.6x** by reusing transformer block computations when consecutive timestep embeddings are similar. - -### Algorithm - -1. Extract modulated input from first transformer block (after norm1 + timestep embedding) -2. Compute relative L1 distance vs previous timestep -3. Apply model-specific polynomial rescaling: `c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]` -4. Accumulate rescaled distance across timesteps -5. If accumulated < threshold → Reuse cached residual (FAST) -6. If accumulated >= threshold → Full transformer pass (SLOW, update cache) - ---- - -## The Mid-Forward Intercept Problem - -### Why TeaCache is Model-Specific - -TeaCache needs to intercept **within** a model's forward method, not just at module boundaries: - -``` -┌─────────────────────────────────────────────────────────────┐ -│ Model Forward │ -│ │ -│ PREPROCESSING (must always run) │ -│ ├── x_embedder(hidden_states) │ -│ ├── time_text_embed(timestep, ...) 
│ -│ └── context_embedder(encoder_hidden_states) │ -│ │ -│ ═══════════════════════════════════════════════════════════│ -│ DECISION POINT ◄── TeaCache needs to intercept HERE │ -│ └── Extract: transformer_blocks[0].norm1(hs, temb)[0] │ -│ ═══════════════════════════════════════════════════════════│ -│ │ -│ CACHEABLE REGION (can be skipped if cached) │ -│ ├── for block in transformer_blocks: ... │ -│ └── for block in single_transformer_blocks: ... │ -│ │ -│ POSTPROCESSING (must always run) │ -│ ├── norm_out(hidden_states, temb) │ -│ └── proj_out(hidden_states) │ -└─────────────────────────────────────────────────────────────┘ -``` - -PyTorch hooks only intercept at **module boundaries** (before/after `forward()`), not within a forward method. The `for` loop over blocks is Python control flow - there's no hook point to skip it. - -### Workaround: Custom Forward Replacement - -The PR replaces the entire model forward with a custom implementation that has cache logic inserted at the right point. This works but requires maintaining separate forward functions for each model. 
- ---- - -## Comparison of Caching Approaches - -### TeaCache vs FirstBlockCache vs FasterCache - -| Aspect | TeaCache | FirstBlockCache | FasterCache | -|--------|----------|-----------------|-------------| -| **Hook target** | Model forward | Transformer blocks | Attention layers | -| **Decision signal** | Modulated input (norm1 output) | Block output residual | Iteration count | -| **Where signal is** | Inside first block | Block boundary | Attention output | -| **Model-specific needs** | norm1 structure | Block output format | Attention class type | -| **Model-agnostic?** | ❌ No | ✅ Yes | ✅ Yes | - -### Why FirstBlockCache is Model-Agnostic - -FirstBlockCache uses the **first block's output residual** as its signal: - -```python -# FirstBlockCache: hooks individual blocks -def new_forward(self, module, *args, **kwargs): - original_hidden_states = args[0] - output = self.fn_ref.original_forward(*args, **kwargs) # Run block fully - residual = output - original_hidden_states # Signal from OUTPUT - should_compute = self._compare_residual(residual) - ... -``` - -It doesn't need to understand block internals - just input and output. - -### Why FasterCache is Model-Agnostic - -FasterCache hooks **attention layers** (not blocks) using class type checking: - -```python -_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin) - -for name, submodule in module.named_modules(): - if isinstance(submodule, _ATTENTION_CLASSES): - # Hook this attention module -``` - -All transformer models use standardized attention classes. 
- ---- - -## Model Architecture Analysis - -### Models That Fit TeaCache Pattern - -Models with `norm1(hidden_states, temb)` returning modulated input: - -| Model | norm1 Signature | Modulation Location | Single Residual | -|-------|----------------|---------------------|-----------------| -| FLUX 1 | `norm1(hs, emb=temb) → (tensor, gate)` | Inside norm1 | ✅ | -| FLUX Kontext | `norm1(hs, emb=temb) → (tensor, gate)` | Inside norm1 | ✅ | -| Mochi | `norm1(hs, temb) → (tensor, g, s, g)` | Inside norm1 | ✅ | -| Lumina2 | `norm1(hs, temb) → (tensor, gate)` | Inside norm1 | ✅ | - -### Models That DON'T Fit Pattern - -| Model | norm1 Signature | Modulation Location | Issue | -|-------|----------------|---------------------|-------| -| **FLUX 2** | `norm1(hs) → tensor` | Outside norm1 | Plain LayerNorm | -| **Wan** | `norm1(hs) → tensor` | Outside norm1 | Plain LayerNorm | -| **ZImage** | `attention_norm1(x) → tensor` | Outside norm1 | Plain LayerNorm | -| **CogVideoX** | N/A (uses `emb` directly) | N/A | Dual residual needed | - -### FLUX 1 vs FLUX 2 Architecture Difference - -**FLUX 1** (AdaLayerNorm - modulation inside): -```python -class FluxTransformerBlock: - self.norm1 = AdaLayerNormZero(dim) # Takes temb! - - def forward(self, hidden_states, temb, ...): - norm_hs, gate = self.norm1(hidden_states, emb=temb) # Modulation inside -``` - -**FLUX 2** (Plain LayerNorm - modulation outside): -```python -class Flux2TransformerBlock: - self.norm1 = nn.LayerNorm(dim) # NO temb! - - def forward(self, hidden_states, temb_mod_params_img, ...): - (shift_msa, scale_msa, gate_msa), ... = temb_mod_params_img - norm_hs = self.norm1(hidden_states) # Plain norm - norm_hs = (1 + scale_msa) * norm_hs + shift_msa # Modulation outside -``` - -FLUX 2 follows the Wan/ZImage pattern and would need a separate custom forward. - ---- - -## CogVideoX: The Architectural Outlier - -CogVideoX has two unique requirements that don't fit the pattern: - -### 1. 
Different Modulated Input Source - -```python -# Other models: extract from norm1 -modulated_inp = block.norm1(hidden_states, temb)[0] - -# CogVideoX: uses timestep embedding directly -modulated_inp = emb # Just the embedding, computed before blocks! -``` - -### 2. Dual Residual Caching - -CogVideoX blocks return and modify TWO tensors: -```python -def forward(self, hidden_states, encoder_hidden_states, temb, ...): - # Both are modified! - return hidden_states, encoder_hidden_states -``` - -Requires caching two residuals: -```python -state.previous_residual = hs_output - hs_input -state.previous_residual_encoder = enc_output - enc_input # Extra! -``` - ---- - -## Recommendations - -### Simplification: FLUX-Only Support - -Given the architectural diversity, recommend supporting only FLUX 1 and FLUX Kontext initially: - -```python -_MODEL_CONFIG = { - "FluxKontext": { - "forward_func": _flux_teacache_forward, - "coefficients": [-1.04655119e03, 3.12563399e02, -1.69500694e01, 4.10995971e-01, 3.74537863e-02], - }, - "Flux": { - "forward_func": _flux_teacache_forward, - "coefficients": [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01], - }, -} -``` - -### What to Remove from PR - -1. **CogVideoX support** - Dual residual architecture doesn't fit -2. **Mochi support** - Can be added later if needed -3. **Lumina2 support** - Can be added later if needed -4. 
**FLUX 2 support** - Different architecture (plain LayerNorm) - -### Estimated Code Reduction - -| Component | Original (PR) | FLUX-Only | -|-----------|---------------|-----------| -| Forward functions | 4 (~400 lines) | 1 (~100 lines) | -| Model configs | 10 entries | 2 entries | -| State fields | 8 | 5 | -| Utility functions | 6 | 3 | -| **Total teacache.py** | ~900 lines | ~350 lines | - -### Simplified State - -```python -class TeaCacheState(BaseState): - def __init__(self): - self.cnt = 0 - self.num_steps = 0 - self.accumulated_rel_l1_distance = 0.0 - self.previous_modulated_input = None - self.previous_residual = None - # Removed: previous_residual_encoder (CogVideoX) - # Removed: cache_dict (Lumina2) - # Removed: uncond_seq_len (Lumina2) -``` - ---- - -## Why Custom Forwards Are Necessary - -Despite the maintenance burden, custom forwards are the pragmatic approach for TeaCache because: - -1. **Mid-forward intercept required** - Need to access `norm1` output before blocks run -2. **Architectural diversity** - Models differ in where/how modulation happens -3. **Block-level hooks insufficient** - Can't extract modulated input from block hooks -4. **Algorithm requirements** - TeaCache paper specifically uses modulated input as signal - -### Alternative Approaches Considered - -| Approach | Works? | Issue | -|----------|--------|-------| -| Block-level hooks (like FirstBlockCache) | ❌ | Can't access modulated input inside block | -| Attention-level hooks (like FasterCache) | ❌ | Different algorithm, not TeaCache | -| Hook norm1 directly | ⚠️ | norm1 interface varies per model | -| Hybrid (FirstBlockCache signal + TeaCache algorithm) | ⚠️ | Loses "optimal" signal per paper | - ---- - -## PR Code Quality Issues (From Review) - -1. **torch.compile incompatibility** - `.item()` calls in `_compute_rel_l1_distance` create graph breaks -2. **Boundary check bug** - `state.cnt == state.num_steps - 1` when `num_steps=0` evaluates to `-1` -3. 
**Incomplete Lumina2 state reset** - `cache_dict` and `uncond_seq_len` not reset -4. **Model auto-detection fragility** - Substring matching relies on iteration order - ---- - -## Extension Path - -If support for additional models is needed later: - -1. **Mochi** - Same pattern as FLUX, just add coefficients and reuse `_flux_teacache_forward` or create similar -2. **Lumina2** - Same pattern but needs per-sequence-length caching for CFG -3. **FLUX 2 / Wan / ZImage** - Need separate forwards that extract modulated input differently -4. **CogVideoX** - Needs dual residual support, significant additional complexity - ---- - -## Summary - -- **TeaCache requires custom forwards** due to mid-forward intercept requirement -- **FLUX 1 + FLUX Kontext only** is the recommended scope for initial implementation -- **~60% code reduction** possible by removing unsupported models -- **Clear extension path** for adding models later as needed -- **Maintenance burden** is acceptable given the architectural constraints diff --git a/release_notes/v0.37.0.md b/release_notes/v0.37.0.md deleted file mode 100644 index 4a06621e0154..000000000000 --- a/release_notes/v0.37.0.md +++ /dev/null @@ -1,129 +0,0 @@ -# Diffusers v0.37.0 Release Notes - -*Release based on 191 commits since v0.36.0* - ---- - -## Highlights - -- **Modular Pipelines overhaul**: Major investment in the modular pipeline system with explicit workflow support, improved loaders, documentation, and modular implementations for Wan, Flux2, Z-Image, Qwen, and Mellon pipelines. -- **New pipelines and models**: Cosmos Predict2.5, LTX 2.0 Video, LongCat-Image, Fibo Edit, Z-Image Omni Base, and more. -- **Distributed inference improvements**: Unified Sequence Parallel attention, Ulysses Anything Attention, and context parallel support in native flash attention. -- **Python 3.8 dropped**: Sunset Python 3.8 and cleaned up explicit `typing` exports. 
- ---- - -## New Pipelines and Models - -- **Cosmos Predict2.5**: Base inference pipeline, scheduler, and checkpoint conversion; 14B model support (#12852, #12863) -- **Cosmos Transfer2.5**: General transfer pipelines for segmentation, depth, blur, and edge (#13066) -- **LTX 2.0 Video Pipelines**: New video generation pipelines (#12915), distilled checkpoint support (#12934), single-file loading (#12983), LoRA support (#12933), long multi-prompt (#12614) -- **LongCat-Image**: New pipeline with offloading/quantization support and regional compile acceleration (#12828, #12963, #12699, #13019, #13021) -- **Fibo Edit Pipeline**: New editing pipeline (#12930) -- **Z-Image Omni Base**: New implementation (#12857) -- **Z-Image Turbo ControlNet**: ControlNet support for Z-Image Turbo (#12792) -- **Z-Image Inpaint Pipeline**: Inpainting support (#13006) -- **Z-Image ControlNet CFG**: CFG support for Z-Image ControlNet (#13080) -- **Chroma Inpaint Pipeline**: New inpainting pipeline for Chroma (#12848) -- **Flux2 Klein**: New model variant (#12982) -- **Qwen Image Edit 2511**: New editing support (#12839) -- **Qwen Image Layered Support** (#12853) - -## Modular Pipelines - -- Explicit workflow support for modular pipelines (#13028) -- Modular implementations for: Wan (#13063), Flux2 (#12763), Z-Image (#12808), Qwen (#12872), Mellon (#12978, #12924, #13051) -- Improved loader support (#13025) -- Custom block tests (#12557) -- Auto-docstring generation and documentation refactors (#12958) -- Quick start guide (#13029) -- Guard `ModularPipeline.blocks` attribute (#13014) -- Better docstrings and template pipeline card (#13072, #12932) - -## Core Improvements - -- **Device-type device maps with offloading support** (#12811) -- **`disable_mmap` in pipeline `from_pretrained`** (#12854) -- **`apply_lora_scale` helper** to remove boilerplate (#12994) -- **MagCache support**: Caching mechanism for faster inference (#12744) -- **Mambo-G Guidance**: New guider implementation (#12862) 
-- **Laplace Scheduler for DDPM** (#11320) -- **Custom sigmas in UniPCMultistepScheduler** (#12109) -- **Control-LoRA support** (#10686) -- **Latent Perceptual Loss (LPL) for SDXL** (#11573) -- **MultiControlNet support for SD3 Inpainting** (#11251) -- Remove 8-bit device restriction (#12972) -- Graceful error for unsupported attn-backend / context-parallel combos (#12832) -- Handle progress bar and logging in distributed environments (#12806) -- Remove unneeded autoencoder methods from `AutoencoderMixin` subclasses (#12873) -- Remove k-diffusion support (#13152) -- Flag Flax schedulers as deprecated (#13031) - -## Distributed Inference - -- **Unified Sequence Parallel attention** (#12693) -- **Ulysses Anything Attention** (#12996) -- **Context parallel in native flash attention** (#12829) -- NPU Ulysses attention support (#12919) -- Fix Wan 2.1 I2V context parallel (#12909) -- Fix Qwen-Image context parallel (#12970) - -## LoRA - -- Z-Image LoRA training (#13056) -- Fix non-diffusers LoRA key handling for Flux2 (#13119) -- Fix LoRA loading for Flux2 Klein with adaptive block enumeration (#13030) -- Fix wrong LTX2 LoRA mixin (#13144) - -## Bug Fixes - -- Fix QwenImageEditPlus on NPU (#13017) -- Fix MT5Tokenizer → use `T5Tokenizer` for Transformers v5.0+ compatibility (#12877) -- Fix Wan/WanI2V patchification (#13038) -- Fix LTX-2 inference with `num_videos_per_prompt > 1` and CFG (#13121) -- Fix Flux2 img2img prediction (#12855) -- Fix QwenImage `txt_seq_lens` handling (#12702) -- Fix `prefix_token_len` bug (#12845) -- Fix ftfy imports in Wan and SkyReels-V2 (#12314, #13113) -- Fix `is_fsdp` determination (#12960) -- Fix GLM-Image `get_image_features` API (#13052) -- Fix Wan 2.2 when either transformer isn't present (#13055) -- Fix guider issue (#13147) -- Fix torchao quantizer for new versions (#12901) -- Fix GGUF for unquantized types with unquantize kernels (#12498) -- Make Qwen hidden states contiguous for torchao (#13081) -- Make Flux hidden states contiguous 
(#13068) -- Fix Kandinsky 5 hardcoded CUDA autocast (#12814) -- Fix `aiter` availability check (#13059) -- Fix attention mask check for unsupported backends (#12892) -- Allow `prompt` and `prior_token_ids` simultaneously in `GlmImagePipeline` (#13092) -- GLM-Image batch support (#13007) -- Cosmos 2.5 Video2World frame extraction fix (#13018) -- ResNet: only use contiguous in training mode (#12977) - -## Testing and CI - -- Refactor model tests (#12822) -- Refactor Wan model tests (#13082) -- Accept `recompile_limit` from user in tests (#13150) -- CodeQL workflow for security analysis (#12917) -- Upgrade GitHub Actions for Node 24 compatibility (#12865, #12866) -- Fix `setuptools` / `pkg_resources` CI bugs (#13129, #13132) -- CUDA 12.9 upgrade (#13045) -- FSDP option for Flux2 (#12860) - -## Documentation - -- Custom code AutoModel guide (#13099) -- Remote inference docs (#12372) -- Improved distributed inference docs (#12810, #12827, #12971) -- Improved caching docs (#12684) -- Numerous scheduler docstring improvements (#12798, #12871, #12928, #12931, #12936, #12992, #13010, #13020, #13023, #13024, #13027, #13044, #13083, #13085, #13122, #13127, #13130) -- Various typo and syntax fixes - -## Breaking Changes - -- **Python 3.8 support removed** (#12524) -- **k-diffusion removed** (#13152) -- **Flax schedulers flagged as deprecated** (#13031) -- ControlNet implementations outside the controlnet module removed (#12152) diff --git a/scripts/compare_test_coverage.py b/scripts/compare_test_coverage.py deleted file mode 100644 index 1a002fc16813..000000000000 --- a/scripts/compare_test_coverage.py +++ /dev/null @@ -1,183 +0,0 @@ -#!/usr/bin/env python3 -""" -Compare test coverage between main and model-test-refactor branches -for the Flux transformer tests. 
- -Usage: - python scripts/compare_test_coverage.py -""" - -import subprocess - - -TEST_FILE = "tests/models/transformers/test_models_transformer_flux.py" -BRANCHES = ["main", "model-test-refactor"] - - -def run_command(cmd, capture=True): - """Run a shell command and return output.""" - result = subprocess.run(cmd, shell=True, capture_output=capture, text=True) - return result.stdout, result.stderr, result.returncode - - -def get_current_branch(): - """Get the current git branch name.""" - stdout, _, _ = run_command("git branch --show-current") - return stdout.strip() - - -def stash_changes(): - """Stash any uncommitted changes.""" - run_command("git stash") - - -def pop_stash(): - """Pop stashed changes.""" - run_command("git stash pop") - - -def checkout_branch(branch): - """Checkout a git branch.""" - _, stderr, code = run_command(f"git checkout {branch}") - if code != 0: - print(f"Failed to checkout {branch}: {stderr}") - return False - return True - - -def collect_tests(test_file): - """Collect tests from a test file and return test info.""" - cmd = f"python -m pytest {test_file} --collect-only -q 2>/dev/null" - stdout, stderr, code = run_command(cmd) - - tests = [] - for line in stdout.strip().split("\n"): - if "::" in line and not line.startswith("="): - tests.append(line.strip()) - - return tests - - -def run_tests_verbose(test_file): - """Run tests and capture pass/skip/fail status.""" - cmd = f"python -m pytest {test_file} -v --tb=no 2>&1" - stdout, _, _ = run_command(cmd) - - results = {"passed": [], "skipped": [], "failed": [], "errors": []} - - for line in stdout.split("\n"): - if " PASSED" in line: - test_name = line.split(" PASSED")[0].strip() - results["passed"].append(test_name) - elif " SKIPPED" in line: - test_name = line.split(" SKIPPED")[0].strip() - reason = "" - if "SKIPPED" in line and "[" in line: - reason = line.split("[")[-1].rstrip("]") if "[" in line else "" - results["skipped"].append((test_name, reason)) - elif " FAILED" in line: - 
test_name = line.split(" FAILED")[0].strip() - results["failed"].append(test_name) - elif " ERROR" in line: - test_name = line.split(" ERROR")[0].strip() - results["errors"].append(test_name) - - return results - - -def compare_results(main_results, pr_results): - """Compare test results between branches.""" - print("\n" + "=" * 70) - print("COVERAGE COMPARISON REPORT") - print("=" * 70) - - print("\n## Test Counts") - print(f"{'Category':<20} {'main':<15} {'PR':<15} {'Diff':<10}") - print("-" * 60) - - for category in ["passed", "skipped", "failed", "errors"]: - main_count = len(main_results[category]) - pr_count = len(pr_results[category]) - diff = pr_count - main_count - diff_str = f"+{diff}" if diff > 0 else str(diff) - print(f"{category:<20} {main_count:<15} {pr_count:<15} {diff_str:<10}") - - main_tests = set(main_results["passed"] + [t[0] for t in main_results["skipped"]]) - pr_tests = set(pr_results["passed"] + [t[0] for t in pr_results["skipped"]]) - - missing_in_pr = main_tests - pr_tests - new_in_pr = pr_tests - main_tests - - if missing_in_pr: - print("\n## Tests in main but MISSING in PR:") - for test in sorted(missing_in_pr): - print(f" - {test}") - - if new_in_pr: - print("\n## NEW tests in PR (not in main):") - for test in sorted(new_in_pr): - print(f" + {test}") - - print("\n## Skipped Tests Comparison") - main_skipped = {t[0]: t[1] for t in main_results["skipped"]} - pr_skipped = {t[0]: t[1] for t in pr_results["skipped"]} - - newly_skipped = set(pr_skipped.keys()) - set(main_skipped.keys()) - no_longer_skipped = set(main_skipped.keys()) - set(pr_skipped.keys()) - - if newly_skipped: - print("\nNewly skipped in PR:") - for test in sorted(newly_skipped): - print(f" - {test}: {pr_skipped.get(test, 'unknown reason')}") - - if no_longer_skipped: - print("\nNo longer skipped in PR (now running):") - for test in sorted(no_longer_skipped): - print(f" + {test}") - - if not newly_skipped and not no_longer_skipped: - print("\nNo changes in skipped tests.") 
- - print("\n" + "=" * 70) - - -def main(): - original_branch = get_current_branch() - print(f"Current branch: {original_branch}") - - results = {} - - print("Stashing uncommitted changes...") - stash_changes() - - try: - for branch in BRANCHES: - print(f"\n--- Analyzing branch: {branch} ---") - - if not checkout_branch(branch): - print(f"Skipping {branch}") - continue - - print(f"Collecting and running tests from {TEST_FILE}...") - results[branch] = run_tests_verbose(TEST_FILE) - - print(f" Passed: {len(results[branch]['passed'])}") - print(f" Skipped: {len(results[branch]['skipped'])}") - print(f" Failed: {len(results[branch]['failed'])}") - - checkout_branch(original_branch) - - if "main" in results and "model-test-refactor" in results: - compare_results(results["main"], results["model-test-refactor"]) - else: - print("Could not compare - missing results from one or both branches") - - finally: - print("\nRestoring stashed changes...") - pop_stash() - - checkout_branch(original_branch) - - -if __name__ == "__main__": - main() diff --git a/test_automodel_meta.py b/test_automodel_meta.py deleted file mode 100644 index f0dbe7f4a3b9..000000000000 --- a/test_automodel_meta.py +++ /dev/null @@ -1,14 +0,0 @@ -import torch -from diffusers import AutoModel - -repo = "meituan-longcat/LongCat-Image" -subfolder = "transformer" - -config = AutoModel.load_config(repo, subfolder=subfolder) - -with torch.device("meta"): - model = AutoModel.from_config(config) -print(f"model.config:") -for k, v in dict(model.config).items(): - if not k.startswith("_"): - print(f" {k}: {v}") diff --git a/test_dataclass_config.py b/test_dataclass_config.py deleted file mode 100644 index ab7eb48eb7bd..000000000000 --- a/test_dataclass_config.py +++ /dev/null @@ -1,11 +0,0 @@ -import dataclasses -from diffusers import AutoModel, LongCatImageTransformer2DModel - -config_dict = AutoModel.load_config( - "meituan-longcat/LongCat-Image", - subfolder="transformer", -) -# import DiT based on _class_name 
-typed_config = LongCatImageTransformer2DModel._get_dataclass_from_config(config_dict) -for f in dataclasses.fields(typed_config): - print(f"{f.name}: {f.type}") diff --git a/test_pretrained_config.py b/test_pretrained_config.py deleted file mode 100644 index 40b871d4163d..000000000000 --- a/test_pretrained_config.py +++ /dev/null @@ -1,29 +0,0 @@ -import dataclasses -import torch -from diffusers import FluxTransformer2DModel -from diffusers.models import AutoModel - -repo = "black-forest-labs/FLUX.2-dev" -subfolder = "transformer" - -print("=== From load_config (no model instantiation) ===") -config_dict = FluxTransformer2DModel.load_config(repo, subfolder=subfolder) -tc = FluxTransformer2DModel._get_dataclass_from_config(config_dict) -print(f"Type: {type(tc).__name__}") -for k, v in dataclasses.asdict(tc).items(): - print(f" {k}: {v}") - -print() -print("=== From AutoModel.from_config on meta device ===") -with torch.device("meta"): - model = AutoModel.from_config(repo, subfolder=subfolder) -print(f"model.config:") -for k, v in dict(model.config).items(): - if not k.startswith("_"): - print(f" {k}: {v}") - -print() -print("=== Comparison ===") -dc_dict = dataclasses.asdict(tc) -config = {k: v for k, v in dict(model.config).items() if not k.startswith("_")} -print(f"Match: {dc_dict == config}") From 2fc5d56674f49eda8ccb7edc083bf8924ada9bec Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Mon, 2 Mar 2026 20:50:07 -1000 Subject: [PATCH 014/215] [modular]Update model card to include workflow (#13195) * up * up * update * remove test --------- Co-authored-by: yiyi@huggingface.co Co-authored-by: yiyi@huggingface.co --- .../modular_pipelines/modular_pipeline.py | 2 + .../modular_pipeline_utils.py | 199 +++++++++++------- src/diffusers/utils/hub_utils.py | 12 ++ .../test_modular_pipelines_common.py | 52 ++--- 4 files changed, 150 insertions(+), 115 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py 
b/src/diffusers/modular_pipelines/modular_pipeline.py index c1ac7f3aab4c..2daf62c5023b 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1883,6 +1883,7 @@ def save_pretrained( private = kwargs.pop("private", None) create_pr = kwargs.pop("create_pr", False) token = kwargs.pop("token", None) + update_model_card = kwargs.pop("update_model_card", False) repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id for component_name, component_spec in self._component_specs.items(): @@ -1957,6 +1958,7 @@ def save_pretrained( is_pipeline=True, model_description=MODULAR_MODEL_CARD_TEMPLATE.format(**card_content), is_modular=True, + update_model_card=update_model_card, ) model_card = populate_model_card(model_card, tags=card_content["tags"]) model_card.save(os.path.join(save_directory, "README.md")) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index fa81d81920eb..68bb1fe2fd0c 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -50,11 +50,7 @@ {components_description} {configs_section} -## Input/Output Specification - -### Inputs {inputs_description} - -### Outputs {outputs_description} +{io_specification_section} """ @@ -811,6 +807,46 @@ def format_output_params(output_params, indent_level=4, max_line_length=115): return format_params(output_params, "Outputs", indent_level, max_line_length) +def format_params_markdown(params, header="Inputs"): + """Format a list of InputParam or OutputParam objects as a markdown bullet-point list. + + Suitable for model cards rendered on Hugging Face Hub. + + Args: + params: list of InputParam or OutputParam objects to format + header: Header text (e.g. "Inputs" or "Outputs") + + Returns: + A formatted markdown string, or empty string if params is empty. 
+ """ + if not params: + return "" + + def get_type_str(type_hint): + if isinstance(type_hint, UnionType) or get_origin(type_hint) is Union: + type_strs = [t.__name__ if hasattr(t, "__name__") else str(t) for t in get_args(type_hint)] + return " | ".join(type_strs) + return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) + + lines = [f"**{header}:**\n"] if header else [] + for param in params: + type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" + name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name + param_str = f"- `{name}` (`{type_str}`" + + if hasattr(param, "required") and not param.required: + param_str += ", *optional*" + if param.default is not None: + param_str += f", defaults to `{param.default}`" + param_str += ")" + + desc = param.description if param.description else "No description provided" + param_str += f": {desc}" + lines.append(param_str) + + return "\n".join(lines) + + def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True): """Format a list of ComponentSpec objects into a readable string representation. 
@@ -1067,8 +1103,7 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]: - blocks_description: Detailed architecture of blocks - components_description: List of required components - configs_section: Configuration parameters section - - inputs_description: Input parameters specification - - outputs_description: Output parameters specification + - io_specification_section: Input/Output specification (per-workflow or unified) - trigger_inputs_section: Conditional execution information - tags: List of relevant tags for the model card """ @@ -1087,15 +1122,6 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]: if block_desc: blocks_desc_parts.append(f" - {block_desc}") - # add sub-blocks if any - if hasattr(block, "sub_blocks") and block.sub_blocks: - for sub_name, sub_block in block.sub_blocks.items(): - sub_class = sub_block.__class__.__name__ - sub_desc = sub_block.description.split("\n")[0] if getattr(sub_block, "description", "") else "" - blocks_desc_parts.append(f" - *{sub_name}*: `{sub_class}`") - if sub_desc: - blocks_desc_parts.append(f" - {sub_desc}") - blocks_description = "\n".join(blocks_desc_parts) if blocks_desc_parts else "No blocks defined." 
components = getattr(blocks, "expected_components", []) @@ -1121,63 +1147,76 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]: if configs_description: configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}" - inputs = blocks.inputs - outputs = blocks.outputs - - # format inputs as markdown list - inputs_parts = [] - required_inputs = [inp for inp in inputs if inp.required] - optional_inputs = [inp for inp in inputs if not inp.required] - - if required_inputs: - inputs_parts.append("**Required:**\n") - for inp in required_inputs: - if hasattr(inp.type_hint, "__name__"): - type_str = inp.type_hint.__name__ - elif inp.type_hint is not None: - type_str = str(inp.type_hint).replace("typing.", "") - else: - type_str = "Any" - desc = inp.description or "No description provided" - inputs_parts.append(f"- `{inp.name}` (`{type_str}`): {desc}") + # Branch on whether workflows are defined + has_workflows = getattr(blocks, "_workflow_map", None) is not None - if optional_inputs: - if required_inputs: - inputs_parts.append("") - inputs_parts.append("**Optional:**\n") - for inp in optional_inputs: - if hasattr(inp.type_hint, "__name__"): - type_str = inp.type_hint.__name__ - elif inp.type_hint is not None: - type_str = str(inp.type_hint).replace("typing.", "") - else: - type_str = "Any" - desc = inp.description or "No description provided" - default_str = f", default: `{inp.default}`" if inp.default is not None else "" - inputs_parts.append(f"- `{inp.name}` (`{type_str}`){default_str}: {desc}") - - inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined." 
- - # format outputs as markdown list - outputs_parts = [] - for out in outputs: - if hasattr(out.type_hint, "__name__"): - type_str = out.type_hint.__name__ - elif out.type_hint is not None: - type_str = str(out.type_hint).replace("typing.", "") - else: - type_str = "Any" - desc = out.description or "No description provided" - outputs_parts.append(f"- `{out.name}` (`{type_str}`): {desc}") - - outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs." - - trigger_inputs_section = "" - if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs: - trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None]) - if trigger_inputs_list: - trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list) - trigger_inputs_section = f""" + if has_workflows: + workflow_map = blocks._workflow_map + parts = [] + + # If blocks overrides outputs (e.g. to return just "images" instead of all intermediates), + # use that as the shared output for all workflows + blocks_outputs = blocks.outputs + blocks_intermediate = getattr(blocks, "intermediate_outputs", None) + shared_outputs = ( + blocks_outputs if blocks_intermediate is not None and blocks_outputs != blocks_intermediate else None + ) + + parts.append("## Workflow Input Specification\n") + + # Per-workflow details: show trigger inputs with full param descriptions + for wf_name, trigger_inputs in workflow_map.items(): + trigger_input_names = set(trigger_inputs.keys()) + try: + workflow_blocks = blocks.get_workflow(wf_name) + except Exception: + parts.append(f"
\n{wf_name}\n") + parts.append("*Could not resolve workflow blocks.*\n") + parts.append("
\n") + continue + + wf_inputs = workflow_blocks.inputs + # Show only trigger inputs with full parameter descriptions + trigger_params = [p for p in wf_inputs if p.name in trigger_input_names] + + parts.append(f"
\n{wf_name}\n") + + inputs_str = format_params_markdown(trigger_params, header=None) + parts.append(inputs_str if inputs_str else "No additional inputs required.") + parts.append("") + + parts.append("
\n") + + # Common Inputs & Outputs section (like non-workflow pipelines) + all_inputs = blocks.inputs + all_outputs = shared_outputs if shared_outputs is not None else blocks.outputs + + inputs_str = format_params_markdown(all_inputs, "Inputs") + outputs_str = format_params_markdown(all_outputs, "Outputs") + inputs_description = inputs_str if inputs_str else "No specific inputs defined." + outputs_description = outputs_str if outputs_str else "Standard pipeline outputs." + + parts.append(f"\n## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}") + + io_specification_section = "\n".join(parts) + # Suppress trigger_inputs_section when workflows are shown (it's redundant) + trigger_inputs_section = "" + else: + # Unified I/O section (original behavior) + inputs = blocks.inputs + outputs = blocks.outputs + inputs_str = format_params_markdown(inputs, "Inputs") + outputs_str = format_params_markdown(outputs, "Outputs") + inputs_description = inputs_str if inputs_str else "No specific inputs defined." + outputs_description = outputs_str if outputs_str else "Standard pipeline outputs." 
+ io_specification_section = f"## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}" + + trigger_inputs_section = "" + if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs: + trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None]) + if trigger_inputs_list: + trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list) + trigger_inputs_section = f""" ### Conditional Execution This pipeline contains blocks that are selected at runtime based on inputs: @@ -1190,7 +1229,18 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]: if hasattr(blocks, "model_name") and blocks.model_name: tags.append(blocks.model_name) - if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs: + if has_workflows: + # Derive tags from workflow names + workflow_names = set(blocks._workflow_map.keys()) + if any("inpainting" in wf for wf in workflow_names): + tags.append("inpainting") + if any("image2image" in wf for wf in workflow_names): + tags.append("image-to-image") + if any("controlnet" in wf for wf in workflow_names): + tags.append("controlnet") + if any("text2image" in wf for wf in workflow_names): + tags.append("text-to-image") + elif hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs: triggers = blocks.trigger_inputs if any(t in triggers for t in ["mask", "mask_image"]): tags.append("inpainting") @@ -1218,8 +1268,7 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]: "blocks_description": blocks_description, "components_description": components_description, "configs_section": configs_section, - "inputs_description": inputs_description, - "outputs_description": outputs_description, + "io_specification_section": io_specification_section, "trigger_inputs_section": trigger_inputs_section, "tags": tags, } diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index ad1ce988870c..b5eb9ab2e17f 100644 --- a/src/diffusers/utils/hub_utils.py +++ 
b/src/diffusers/utils/hub_utils.py @@ -107,6 +107,7 @@ def load_or_create_model_card( widget: list[dict] | None = None, inference: bool | None = None, is_modular: bool = False, + update_model_card: bool = False, ) -> ModelCard: """ Loads or creates a model card. @@ -133,6 +134,9 @@ def load_or_create_model_card( `load_or_create_model_card` from a training script. is_modular: (`bool`, optional): Boolean flag to denote if the model card is for a modular pipeline. When True, uses model_description as-is without additional template formatting. + update_model_card: (`bool`, optional): When True, regenerates the model card content even if one + already exists on the remote repo. Existing card metadata (tags, license, etc.) is preserved. Only + supported for modular pipelines (i.e., `is_modular=True`). """ if not is_jinja_available(): raise ValueError( @@ -141,9 +145,17 @@ def load_or_create_model_card( " To install it, please run `pip install Jinja2`." ) + if update_model_card and not is_modular: + raise ValueError("`update_model_card=True` is only supported for modular pipelines (`is_modular=True`).") + try: # Check if the model card is present on the remote repo model_card = ModelCard.load(repo_id_or_path, token=token) + # For modular pipelines, regenerate card content when requested (preserve existing metadata) + if update_model_card and is_modular and model_description is not None: + existing_data = model_card.data + model_card = ModelCard(model_description) + model_card.data = existing_data except (EntryNotFoundError, RepositoryNotFoundError): # Otherwise create a model card from template if from_training: diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index bd96516785d9..486b2c3f4166 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -483,8 +483,7 @@ def test_basic_model_card_content_structure(self): 
"blocks_description", "components_description", "configs_section", - "inputs_description", - "outputs_description", + "io_specification_section", "trigger_inputs_section", "tags", ] @@ -581,18 +580,19 @@ def test_inputs_description_required_and_optional(self): blocks = self.create_mock_blocks(inputs=inputs) content = generate_modular_model_card_content(blocks) - assert "**Required:**" in content["inputs_description"] - assert "**Optional:**" in content["inputs_description"] - assert "prompt" in content["inputs_description"] - assert "num_steps" in content["inputs_description"] - assert "default: `50`" in content["inputs_description"] + io_section = content["io_specification_section"] + assert "**Inputs:**" in io_section + assert "prompt" in io_section + assert "num_steps" in io_section + assert "*optional*" in io_section + assert "defaults to `50`" in io_section def test_inputs_description_empty(self): """Test handling of pipelines without specific inputs.""" blocks = self.create_mock_blocks(inputs=[]) content = generate_modular_model_card_content(blocks) - assert "No specific inputs defined" in content["inputs_description"] + assert "No specific inputs defined" in content["io_specification_section"] def test_outputs_description_formatting(self): """Test that outputs are correctly formatted.""" @@ -602,15 +602,16 @@ def test_outputs_description_formatting(self): blocks = self.create_mock_blocks(outputs=outputs) content = generate_modular_model_card_content(blocks) - assert "images" in content["outputs_description"] - assert "Generated images" in content["outputs_description"] + io_section = content["io_specification_section"] + assert "images" in io_section + assert "Generated images" in io_section def test_outputs_description_empty(self): """Test handling of pipelines without specific outputs.""" blocks = self.create_mock_blocks(outputs=[]) content = generate_modular_model_card_content(blocks) - assert "Standard pipeline outputs" in content["outputs_description"] 
+ assert "Standard pipeline outputs" in content["io_specification_section"] def test_trigger_inputs_section_with_triggers(self): """Test that trigger inputs section is generated when present.""" @@ -628,35 +629,6 @@ def test_trigger_inputs_section_empty(self): assert content["trigger_inputs_section"] == "" - def test_blocks_description_with_sub_blocks(self): - """Test that blocks with sub-blocks are correctly described.""" - - class MockBlockWithSubBlocks: - def __init__(self): - self.__class__.__name__ = "ParentBlock" - self.description = "Parent block" - self.sub_blocks = { - "child1": self.create_child_block("ChildBlock1", "Child 1 description"), - "child2": self.create_child_block("ChildBlock2", "Child 2 description"), - } - - def create_child_block(self, name, desc): - class ChildBlock: - def __init__(self): - self.__class__.__name__ = name - self.description = desc - - return ChildBlock() - - blocks = self.create_mock_blocks() - blocks.sub_blocks["parent"] = MockBlockWithSubBlocks() - - content = generate_modular_model_card_content(blocks) - - assert "parent" in content["blocks_description"] - assert "child1" in content["blocks_description"] - assert "child2" in content["blocks_description"] - def test_model_description_includes_block_count(self): """Test that model description includes the number of blocks.""" blocks = self.create_mock_blocks(num_blocks=5) From 561549b06213e7f4651efcbe8988ea02ed65875d Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 3 Mar 2026 02:36:36 -1000 Subject: [PATCH 015/215] [modular] not pass trust_remote_code to external repos (#13204) * add * update warn * add a test * updaqte * update_component with custom model * add more tests * Apply suggestion from @DN6 Co-authored-by: Dhruv Nair * up --------- Co-authored-by: yiyi@huggingface.co Co-authored-by: Dhruv Nair --- .../modular_pipelines/modular_pipeline.py | 55 ++++++- .../test_modular_pipelines_common.py | 42 +++++ .../test_modular_pipelines_custom_blocks.py | 150 
++++++++++++++++++ 3 files changed, 239 insertions(+), 8 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 2daf62c5023b..8d662080124c 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1707,6 +1707,8 @@ def __init__( _blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None ) + self._pretrained_model_name_or_path = pretrained_model_name_or_path + @property def default_call_parameters(self) -> dict[str, Any]: """ @@ -2254,6 +2256,11 @@ def update_components(self, **kwargs): new_component_spec = current_component_spec if hasattr(self, name) and getattr(self, name) is not None: logger.warning(f"ModularPipeline.update_components: setting {name} to None (spec unchanged)") + elif ( + current_component_spec.default_creation_method == "from_pretrained" + and getattr(component, "_diffusers_load_id", None) is None + ): + new_component_spec = ComponentSpec(name=name, type_hint=type(component)) else: new_component_spec = ComponentSpec.from_component(name, component) @@ -2325,17 +2332,49 @@ def load_components(self, names: list[str] | str | None = None, **kwargs): elif "default" in value: # check if the default is specified component_load_kwargs[key] = value["default"] + # Only pass trust_remote_code to components from the same repo as the pipeline. + # When a user passes trust_remote_code=True, they intend to trust code from the + # pipeline's repo, not from external repos referenced in modular_model_index.json. 
+ trust_remote_code_stripped = False + if ( + "trust_remote_code" in component_load_kwargs + and self._pretrained_model_name_or_path is not None + and spec.pretrained_model_name_or_path != self._pretrained_model_name_or_path + ): + component_load_kwargs.pop("trust_remote_code") + trust_remote_code_stripped = True + + if not spec.pretrained_model_name_or_path: + logger.info(f"Skipping component `{name}`: no pretrained model path specified.") + continue + try: components_to_register[name] = spec.load(**component_load_kwargs) except Exception: - logger.warning( - f"\nFailed to create component {name}:\n" - f"- Component spec: {spec}\n" - f"- load() called with kwargs: {component_load_kwargs}\n" - "If this component is not required for your workflow you can safely ignore this message.\n\n" - "Traceback:\n" - f"{traceback.format_exc()}" - ) + tb = traceback.format_exc() + if trust_remote_code_stripped and "trust_remote_code" in tb: + warning_msg = ( + f"Failed to load component `{name}` from external repository " + f"`{spec.pretrained_model_name_or_path}`.\n\n" + f"`trust_remote_code=True` was not forwarded to `{name}` because it comes from " + f"a different repository than the pipeline (`{self._pretrained_model_name_or_path}`). " + f"For safety, `trust_remote_code` is only forwarded to components from the same " + f"repository as the pipeline.\n\n" + f"You need to load this component manually with `trust_remote_code=True` and pass it " + f"to the pipeline via `pipe.update_components()`. 
For example, if it is a custom model:\n\n" + f' {name} = AutoModel.from_pretrained("{spec.pretrained_model_name_or_path}", trust_remote_code=True)\n' + f" pipe.update_components({name}={name})\n" + ) + else: + warning_msg = ( + f"Failed to create component {name}:\n" + f"- Component spec: {spec}\n" + f"- load() called with kwargs: {component_load_kwargs}\n" + "If this component is not required for your workflow you can safely ignore this message.\n\n" + "Traceback:\n" + f"{tb}" + ) + logger.warning(warning_msg) # Register all components at once self.register_components(**components_to_register) diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index 486b2c3f4166..589698ffc73b 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -687,6 +687,18 @@ def test_load_components_selective_loading(self): assert pipe.unet is not None assert getattr(pipe, "vae", None) is None + def test_load_components_selective_loading_incremental(self): + """Loading a subset of components should not affect already-loaded components.""" + pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe") + + pipe.load_components(names="unet", torch_dtype=torch.float32) + pipe.load_components(names="text_encoder", torch_dtype=torch.float32) + + assert hasattr(pipe, "unet") + assert pipe.unet is not None + assert hasattr(pipe, "text_encoder") + assert pipe.text_encoder is not None + def test_load_components_skips_invalid_pretrained_path(self): pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe") @@ -749,6 +761,36 @@ def test_save_pretrained_roundtrip_with_local_model(self, tmp_path): for key in original_state_dict: assert torch.equal(original_state_dict[key], loaded_state_dict[key]), f"Mismatch in {key}" + def 
test_save_pretrained_updates_index_for_model_with_no_load_id(self, tmp_path): + """testing the workflow of update the pipeline with a custom model and save the pipeline, + the modular_model_index.json should point to the save directory.""" + import json + + from diffusers import UNet2DConditionModel + + pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe") + pipe.load_components(torch_dtype=torch.float32) + + unet = UNet2DConditionModel.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-xl-pipe", subfolder="unet" + ) + assert not hasattr(unet, "_diffusers_load_id") + + pipe.update_components(unet=unet) + + save_dir = str(tmp_path / "my-pipeline") + pipe.save_pretrained(save_dir) + + with open(os.path.join(save_dir, "modular_model_index.json")) as f: + index = json.load(f) + + _library, _cls, unet_spec = index["unet"] + assert unet_spec["pretrained_model_name_or_path"] == save_dir + assert unet_spec["subfolder"] == "unet" + + _library, _cls, vae_spec = index["vae"] + assert vae_spec["pretrained_model_name_or_path"] == "hf-internal-testing/tiny-stable-diffusion-xl-pipe" + def test_save_pretrained_overwrite_modular_index(self, tmp_path): """With overwrite_modular_index=True, all component references should point to the save directory.""" import json diff --git a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py index 9c5fd5be326d..766ca0c16f86 100644 --- a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py +++ b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py @@ -192,6 +192,156 @@ def test_custom_block_supported_components(self): assert len(pipe.components) == 1 assert pipe.component_names[0] == "transformer" + def test_trust_remote_code_not_propagated_to_external_repo(self): + """When a modular pipeline repo references a component from an external repo that has custom + code (auto_map in config), calling 
load_components(trust_remote_code=True) should NOT + propagate trust_remote_code to that external component. The external component should fail + to load.""" + + from diffusers import ModularPipeline + + CUSTOM_MODEL_CODE = ( + "import torch\n" + "from diffusers import ModelMixin, ConfigMixin\n" + "from diffusers.configuration_utils import register_to_config\n" + "\n" + "class CustomModel(ModelMixin, ConfigMixin):\n" + " @register_to_config\n" + " def __init__(self, hidden_size=8):\n" + " super().__init__()\n" + " self.linear = torch.nn.Linear(hidden_size, hidden_size)\n" + "\n" + " def forward(self, x):\n" + " return self.linear(x)\n" + ) + + with tempfile.TemporaryDirectory() as external_repo_dir, tempfile.TemporaryDirectory() as pipeline_repo_dir: + # Step 1: Create an external model repo with custom code (requires trust_remote_code) + with open(os.path.join(external_repo_dir, "modeling.py"), "w") as f: + f.write(CUSTOM_MODEL_CODE) + + config = { + "_class_name": "CustomModel", + "_diffusers_version": "0.0.0", + "auto_map": {"AutoModel": "modeling.CustomModel"}, + "hidden_size": 8, + } + with open(os.path.join(external_repo_dir, "config.json"), "w") as f: + json.dump(config, f) + + torch.save({}, os.path.join(external_repo_dir, "diffusion_pytorch_model.bin")) + + # Step 2: Create a custom block that references the external repo. + # Define both the class (for direct use) and its code string (for block.py). 
+ class ExternalRefBlock(ModularPipelineBlocks): + @property + def expected_components(self): + return [ + ComponentSpec( + "custom_model", + AutoModel, + pretrained_model_name_or_path=external_repo_dir, + ) + ] + + @property + def inputs(self) -> List[InputParam]: + return [InputParam("prompt", type_hint=str, required=True)] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam("output", type_hint=str)] + + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.output = "test" + self.set_block_state(state, block_state) + return components, state + + EXTERNAL_REF_BLOCK_CODE_STR = ( + "from typing import List\n" + "from diffusers import AutoModel\n" + "from diffusers.modular_pipelines import (\n" + " ComponentSpec,\n" + " InputParam,\n" + " ModularPipelineBlocks,\n" + " OutputParam,\n" + " PipelineState,\n" + ")\n" + "\n" + "class ExternalRefBlock(ModularPipelineBlocks):\n" + " @property\n" + " def expected_components(self):\n" + " return [\n" + " ComponentSpec(\n" + ' "custom_model",\n' + " AutoModel,\n" + f' pretrained_model_name_or_path="{external_repo_dir}",\n' + " )\n" + " ]\n" + "\n" + " @property\n" + " def inputs(self) -> List[InputParam]:\n" + ' return [InputParam("prompt", type_hint=str, required=True)]\n' + "\n" + " @property\n" + " def intermediate_inputs(self) -> List[InputParam]:\n" + " return []\n" + "\n" + " @property\n" + " def intermediate_outputs(self) -> List[OutputParam]:\n" + ' return [OutputParam("output", type_hint=str)]\n' + "\n" + " def __call__(self, components, state: PipelineState) -> PipelineState:\n" + " block_state = self.get_block_state(state)\n" + ' block_state.output = "test"\n' + " self.set_block_state(state, block_state)\n" + " return components, state\n" + ) + + # Save the block config, write block.py, then load back via from_pretrained + 
block = ExternalRefBlock() + block.save_pretrained(pipeline_repo_dir) + + # auto_map will reference the module name derived from ExternalRefBlock.__module__, + # which is "test_modular_pipelines_custom_blocks". Write the code file with that name. + code_path = os.path.join(pipeline_repo_dir, "test_modular_pipelines_custom_blocks.py") + with open(code_path, "w") as f: + f.write(EXTERNAL_REF_BLOCK_CODE_STR) + + block = ModularPipelineBlocks.from_pretrained(pipeline_repo_dir, trust_remote_code=True) + pipe = block.init_pipeline() + pipe.save_pretrained(pipeline_repo_dir) + + # Step 3: Load the pipeline from the saved directory. + loaded_pipe = ModularPipeline.from_pretrained(pipeline_repo_dir, trust_remote_code=True) + + assert loaded_pipe._pretrained_model_name_or_path == pipeline_repo_dir + assert loaded_pipe._component_specs["custom_model"].pretrained_model_name_or_path == external_repo_dir + assert getattr(loaded_pipe, "custom_model", None) is None + + # Step 4a: load_components WITHOUT trust_remote_code. + # It should still fail + loaded_pipe.load_components() + assert getattr(loaded_pipe, "custom_model", None) is None + + # Step 4b: load_components with trust_remote_code=True. + # trust_remote_code should be stripped for the external component, so it fails. + # The warning should contain guidance about manually loading with trust_remote_code. + loaded_pipe.load_components(trust_remote_code=True) + assert getattr(loaded_pipe, "custom_model", None) is None + + # Step 4c: Manually load with AutoModel and update_components — this should work. 
+ from diffusers import AutoModel + + custom_model = AutoModel.from_pretrained(external_repo_dir, trust_remote_code=True) + loaded_pipe.update_components(custom_model=custom_model) + assert getattr(loaded_pipe, "custom_model", None) is not None + def test_custom_block_loads_from_hub(self): repo_id = "hf-internal-testing/tiny-modular-diffusers-block" block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True) From 5321c68befca006764299658464b79348b867bce Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 4 Mar 2026 12:19:08 +0530 Subject: [PATCH 016/215] [Modular] implement requirements validation for custom blocks (#12196) * feat: implement requirements validation for custom blocks. * up * unify. * up * add tests * Apply suggestions from code review Co-authored-by: Dhruv Nair * reviewer feedback. * [docs] validation for custom blocks (#13156) validation * move to tmp_path fixture. * propagate to conditional and loopsequential blocks. * up * remove collected tests --------- Co-authored-by: Dhruv Nair Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- .../en/modular_diffusers/custom_blocks.md | 47 ++++++- src/diffusers/commands/custom_blocks.py | 2 - .../modular_pipelines/modular_pipeline.py | 38 +++++- .../modular_pipeline_utils.py | 85 ++++++++++++ .../test_modular_pipelines_common.py | 124 +++++++++++++++++- 5 files changed, 291 insertions(+), 5 deletions(-) diff --git a/docs/source/en/modular_diffusers/custom_blocks.md b/docs/source/en/modular_diffusers/custom_blocks.md index b412e0e58abc..66e1de172b34 100644 --- a/docs/source/en/modular_diffusers/custom_blocks.md +++ b/docs/source/en/modular_diffusers/custom_blocks.md @@ -332,4 +332,49 @@ Make your custom block work with Mellon's visual interface. See the [Mellon Cust Browse the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for inspiration and ready-to-use blocks. 
- \ No newline at end of file + + +## Dependencies + +Declaring package dependencies in custom blocks prevents runtime import errors later on. Diffusers validates the dependencies and logs a warning if a package is missing or incompatible. + +Set a `_requirements` attribute in your block class, mapping package names to version specifiers. + +```py +from diffusers.modular_pipelines import PipelineBlock + +class MyCustomBlock(PipelineBlock): + _requirements = { + "transformers": ">=4.44.0", + "sentencepiece": ">=0.2.0" + } +``` + +When there are blocks with different requirements, Diffusers merges their requirements. + +```py +from diffusers.modular_pipelines import PipelineBlock, SequentialPipelineBlocks + +class BlockA(PipelineBlock): + _requirements = {"transformers": ">=4.44.0"} + # ... + +class BlockB(PipelineBlock): + _requirements = {"sentencepiece": ">=0.2.0"} + # ... + +pipe = SequentialPipelineBlocks.from_blocks_dict({ + "block_a": BlockA, + "block_b": BlockB, +}) +``` + +When this block is saved with [`~ModularPipelineBlocks.save_pretrained`], the requirements are saved to the `modular_config.json` file. When this block is loaded, Diffusers checks each requirement against the current environment. If there is a mismatch or a package isn't found, Diffusers logs one of the following warnings. + +```md +# missing package +xyz-package was specified in the requirements but wasn't found in the current environment. + +# version mismatch +xyz requirement 'specific-version' is not satisfied by the installed version 'actual-version'. Things might work unexpected. 
+``` diff --git a/src/diffusers/commands/custom_blocks.py b/src/diffusers/commands/custom_blocks.py index 43d9ea88577a..953240c5a2c3 100644 --- a/src/diffusers/commands/custom_blocks.py +++ b/src/diffusers/commands/custom_blocks.py @@ -89,8 +89,6 @@ def run(self): # automap = self._create_automap(parent_class=parent_class, child_class=child_class) # with open(CONFIG, "w") as f: # json.dump(automap, f) - with open("requirements.txt", "w") as f: - f.write("") def _choose_block(self, candidates, chosen=None): for cls, base in candidates: diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 8d662080124c..a563d2aa99eb 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -47,6 +47,7 @@ InputParam, InsertableDict, OutputParam, + _validate_requirements, combine_inputs, combine_outputs, format_components, @@ -297,6 +298,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): config_name = "modular_config.json" model_name = None + _requirements: dict[str, str] | None = None _workflow_map = None @classmethod @@ -411,6 +413,9 @@ def from_pretrained( "Selected model repository does not happear to have any custom code or does not have a valid `config.json` file." 
) + if "requirements" in config and config["requirements"] is not None: + _ = _validate_requirements(config["requirements"]) + class_ref = config["auto_map"][cls.__name__] module_file, class_name = class_ref.split(".") module_file = module_file + ".py" @@ -435,8 +440,13 @@ def save_pretrained(self, save_directory, push_to_hub=False, **kwargs): module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "") parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0] auto_map = {f"{parent_module}": f"{module}.{cls_name}"} - self.register_to_config(auto_map=auto_map) + + # resolve requirements + requirements = _validate_requirements(getattr(self, "_requirements", None)) + if requirements: + self.register_to_config(requirements=requirements) + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) config = dict(self.config) self._internal_dict = FrozenDict(config) @@ -658,6 +668,15 @@ def outputs(self) -> list[str]: combined_outputs = combine_outputs(*named_outputs) return combined_outputs + @property + # Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks._requirements + def _requirements(self) -> dict[str, str]: + requirements = {} + for block_name, block in self.sub_blocks.items(): + if getattr(block, "_requirements", None): + requirements[block_name] = block._requirements + return requirements + # used for `__repr__` def _get_trigger_inputs(self) -> set: """ @@ -1247,6 +1266,14 @@ def doc(self): expected_configs=self.expected_configs, ) + @property + def _requirements(self) -> dict[str, str]: + requirements = {} + for block_name, block in self.sub_blocks.items(): + if getattr(block, "_requirements", None): + requirements[block_name] = block._requirements + return requirements + class LoopSequentialPipelineBlocks(ModularPipelineBlocks): """ @@ -1385,6 +1412,15 @@ def intermediate_outputs(self) -> list[str]: def outputs(self) -> list[str]: return 
next(reversed(self.sub_blocks.values())).intermediate_outputs + @property + # Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks._requirements + def _requirements(self) -> dict[str, str]: + requirements = {} + for block_name, block in self.sub_blocks.items(): + if getattr(block, "_requirements", None): + requirements[block_name] = block._requirements + return requirements + def __init__(self): sub_blocks = InsertableDict() for block_name, block in zip(self.block_names, self.block_classes): diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index 68bb1fe2fd0c..fa82f17a9108 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -22,10 +22,12 @@ import PIL.Image import torch +from packaging.specifiers import InvalidSpecifier, SpecifierSet from ..configuration_utils import ConfigMixin, FrozenDict from ..loaders.single_file_utils import _is_single_file_path_or_url from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging +from ..utils.import_utils import _is_package_available if is_torch_available(): @@ -1020,6 +1022,89 @@ def make_doc_string( return output +def _validate_requirements(reqs): + if reqs is None: + normalized_reqs = {} + else: + if not isinstance(reqs, dict): + raise ValueError( + "Requirements must be provided as a dictionary mapping package names to version specifiers." 
+ ) + normalized_reqs = _normalize_requirements(reqs) + + if not normalized_reqs: + return {} + + final: dict[str, str] = {} + for req, specified_ver in normalized_reqs.items(): + req_available, req_actual_ver = _is_package_available(req) + if not req_available: + logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.") + + if specified_ver: + try: + specifier = SpecifierSet(specified_ver) + except InvalidSpecifier as err: + raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err + + if req_actual_ver == "N/A": + logger.warning( + f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpected." + ) + elif not specifier.contains(req_actual_ver, prereleases=True): + logger.warning( + f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpected." + ) + + final[req] = specified_ver + + return final + + +def _normalize_requirements(reqs): + if not reqs: + return {} + + normalized: "OrderedDict[str, str]" = OrderedDict() + + def _accumulate(mapping: dict[str, Any]): + for pkg, spec in mapping.items(): + if isinstance(spec, dict): + # This is recursive because blocks are composable. This way, we can merge requirements + # from multiple blocks. 
+ _accumulate(spec) + continue + + pkg_name = str(pkg).strip() + if not pkg_name: + raise ValueError("Requirement package name cannot be empty.") + + spec_str = "" if spec is None else str(spec).strip() + if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")): + spec_str = f"=={spec_str}" + + existing_spec = normalized.get(pkg_name) + if existing_spec is not None: + if not existing_spec and spec_str: + normalized[pkg_name] = spec_str + elif existing_spec and spec_str and existing_spec != spec_str: + try: + combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str]))) + except InvalidSpecifier: + logger.warning( + f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'." + ) + else: + normalized[pkg_name] = str(combined_spec) + continue + + normalized[pkg_name] = spec_str + + _accumulate(reqs) + + return normalized + + def combine_inputs(*named_input_lists: list[tuple[str, list[InputParam]]]) -> list[InputParam]: """ Combines multiple lists of InputParam objects from different blocks. 
For duplicate inputs, updates only if current diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index 589698ffc73b..c1a402a2fd8f 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -10,6 +10,11 @@ import diffusers from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks from diffusers.guiders import ClassifierFreeGuidance +from diffusers.modular_pipelines import ( + ConditionalPipelineBlocks, + LoopSequentialPipelineBlocks, + SequentialPipelineBlocks, +) from diffusers.modular_pipelines.modular_pipeline_utils import ( ComponentSpec, ConfigSpec, @@ -19,7 +24,13 @@ ) from diffusers.utils import logging -from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device +from ..testing_utils import ( + CaptureLogger, + backend_empty_cache, + numpy_cosine_similarity_distance, + require_accelerator, + torch_device, +) class ModularPipelineTesterMixin: @@ -429,6 +440,117 @@ def test_guider_cfg(self, expected_max_diff=1e-2): assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference" +class TestCustomBlockRequirements: + def get_dummy_block_pipe(self): + class DummyBlockOne: + # keep two arbitrary deps so that we can test warnings. + _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"} + + class DummyBlockTwo: + # keep two dependencies that will be available during testing. 
+ _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"} + + pipe = SequentialPipelineBlocks.from_blocks_dict( + {"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo} + ) + return pipe + + def get_dummy_conditional_block_pipe(self): + class DummyBlockOne: + _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"} + + class DummyBlockTwo: + _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"} + + class DummyConditionalBlocks(ConditionalPipelineBlocks): + block_classes = [DummyBlockOne, DummyBlockTwo] + block_names = ["block_one", "block_two"] + block_trigger_inputs = [] + + def select_block(self, **kwargs): + return "block_one" + + return DummyConditionalBlocks() + + def get_dummy_loop_block_pipe(self): + class DummyBlockOne: + _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"} + + class DummyBlockTwo: + _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"} + + return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo}) + + def test_sequential_block_requirements_save_load(self, tmp_path): + pipe = self.get_dummy_block_pipe() + pipe.save_pretrained(tmp_path) + + config_path = tmp_path / "modular_config.json" + + with open(config_path, "r") as f: + config = json.load(f) + + assert "requirements" in config + requirements = config["requirements"] + + expected_requirements = { + "xyz": ">=0.8.0", + "abc": ">=10.0.0", + "transformers": ">=4.44.0", + "diffusers": ">=0.2.0", + } + assert expected_requirements == requirements + + def test_sequential_block_requirements_warnings(self, tmp_path): + pipe = self.get_dummy_block_pipe() + + logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils") + logger.setLevel(30) + + with CaptureLogger(logger) as cap_logger: + pipe.save_pretrained(tmp_path) + + template = "{req} was specified in the requirements but wasn't found in the current environment" + msg_xyz = template.format(req="xyz") + msg_abc = 
template.format(req="abc") + assert msg_xyz in str(cap_logger.out) + assert msg_abc in str(cap_logger.out) + + def test_conditional_block_requirements_save_load(self, tmp_path): + pipe = self.get_dummy_conditional_block_pipe() + pipe.save_pretrained(tmp_path) + + config_path = tmp_path / "modular_config.json" + with open(config_path, "r") as f: + config = json.load(f) + + assert "requirements" in config + expected_requirements = { + "xyz": ">=0.8.0", + "abc": ">=10.0.0", + "transformers": ">=4.44.0", + "diffusers": ">=0.2.0", + } + assert expected_requirements == config["requirements"] + + def test_loop_block_requirements_save_load(self, tmp_path): + pipe = self.get_dummy_loop_block_pipe() + pipe.save_pretrained(tmp_path) + + config_path = tmp_path / "modular_config.json" + with open(config_path, "r") as f: + config = json.load(f) + + assert "requirements" in config + expected_requirements = { + "xyz": ">=0.8.0", + "abc": ">=10.0.0", + "transformers": ">=4.44.0", + "diffusers": ">=0.2.0", + } + assert expected_requirements == config["requirements"] + + class TestModularModelCardContent: def create_mock_block(self, name="TestBlock", description="Test block description"): class MockBlock: From 7bffed85664998d99b32b125d2f4aaecae2fcc47 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 4 Mar 2026 17:39:01 +0800 Subject: [PATCH 017/215] cogvideo example: Distribute VAE video encoding across processes in CogVideoX LoRA training (#13207) * Distribute VAE video encoding across processes in CogVideoX LoRA training Signed-off-by: jiqing-feng * Apply style fixes --------- Signed-off-by: jiqing-feng Co-authored-by: github-actions[bot] --- examples/cogvideo/train_cogvideox_lora.py | 41 +++++++++++++++++++---- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index 4e22d3f8727d..e08143f98a5c 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ 
b/examples/cogvideo/train_cogvideox_lora.py @@ -1232,22 +1232,49 @@ def load_model_hook(models, input_dir): id_token=args.id_token, ) - def encode_video(video, bar): - bar.update(1) + def encode_video(video): video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0) video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] latent_dist = vae.encode(video).latent_dist return latent_dist + # Distribute video encoding across processes: each process only encodes its own shard + num_videos = len(train_dataset.instance_videos) + num_procs = accelerator.num_processes + local_rank = accelerator.process_index + local_count = len(range(local_rank, num_videos, num_procs)) + progress_encode_bar = tqdm( - range(0, len(train_dataset.instance_videos)), - desc="Loading Encode videos", + range(local_count), + desc="Encoding videos", + disable=not accelerator.is_local_main_process, ) - train_dataset.instance_videos = [ - encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos - ] + + encoded_videos = [None] * num_videos + for i, video in enumerate(train_dataset.instance_videos): + if i % num_procs == local_rank: + encoded_videos[i] = encode_video(video) + progress_encode_bar.update(1) progress_encode_bar.close() + # Broadcast encoded latent distributions so every process has the full set + if num_procs > 1: + import torch.distributed as dist + + from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution + + ref_params = next(v for v in encoded_videos if v is not None).parameters + for i in range(num_videos): + src = i % num_procs + if encoded_videos[i] is not None: + params = encoded_videos[i].parameters.contiguous() + else: + params = torch.empty_like(ref_params) + dist.broadcast(params, src=src) + encoded_videos[i] = DiagonalGaussianDistribution(params) + + train_dataset.instance_videos = encoded_videos + def collate_fn(examples): videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples] 
prompts = [example["instance_prompt"] for example in examples] From 1cdfccbec74f72feab4148e483b397e8fb36c70b Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Wed, 4 Mar 2026 23:04:13 +0800 Subject: [PATCH 018/215] Fix group-offloading bug (#13211) * Implement synchronous onload for offloaded parameters Add fallback synchronous onload for conditionally-executed modules. * add test for new code path about group-offloading * Update tests/hooks/test_group_offloading.py Co-authored-by: Sayak Paul * use unittest.skipIf and update the comment --------- Co-authored-by: Sayak Paul --- src/diffusers/hooks/group_offloading.py | 11 +++ tests/hooks/test_group_offloading.py | 124 ++++++++++++++++++++++++ 2 files changed, 135 insertions(+) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 15c9fa44f5ff..891ac28455af 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -307,6 +307,17 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs): if self.group.onload_leader == module: if self.group.onload_self: self.group.onload_() + else: + # onload_self=False means this group relies on prefetching from a previous group. + # However, for conditionally-executed modules (e.g. patch_short/patch_mid/patch_long in Helios), + # the prefetch chain may not cover them if they were absent during the first forward pass + # when the execution order was traced. In that case, their weights remain on offload_device, + # so we fall back to a synchronous onload here. 
+ params = [p for m in self.group.modules for p in m.parameters()] + list(self.group.parameters) + if params and params[0].device == self.group.offload_device: + self.group.onload_() + if self.group.stream is not None: + self.group.stream.synchronize() should_onload_next_group = self.next_group is not None and not self.next_group.onload_self if should_onload_next_group: diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py index 236094109d07..108a7247bcc6 100644 --- a/tests/hooks/test_group_offloading.py +++ b/tests/hooks/test_group_offloading.py @@ -566,3 +566,127 @@ def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=Non "layers_per_block": 1, } return init_dict + + +# Model with conditionally-executed modules, simulating Helios patch_short/patch_mid/patch_long behavior. +# These modules are only called when optional inputs are provided, which means the lazy prefetch +# execution order tracer may not see them on the first forward pass. This can cause a device mismatch +# on subsequent calls when the modules ARE invoked but their weights were never onloaded. +# See: https://github.com/huggingface/diffusers/pull/13211 +class DummyModelWithConditionalModules(ModelMixin): + def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None: + super().__init__() + + self.linear_1 = torch.nn.Linear(in_features, hidden_features) + self.activation = torch.nn.ReLU() + self.blocks = torch.nn.ModuleList( + [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)] + ) + self.linear_2 = torch.nn.Linear(hidden_features, out_features) + + # These modules are only invoked when optional_input is not None. + # Output dimension matches hidden_features so they can be added after linear_1. 
+ self.optional_proj_1 = torch.nn.Linear(in_features, hidden_features) + self.optional_proj_2 = torch.nn.Linear(in_features, hidden_features) + + def forward(self, x: torch.Tensor, optional_input: torch.Tensor | None = None) -> torch.Tensor: + x = self.linear_1(x) + x = self.activation(x) + if optional_input is not None: + # Add optional projections after linear_1 so dimensions match (both hidden_features) + x = x + self.optional_proj_1(optional_input) + x = x + self.optional_proj_2(optional_input) + for block in self.blocks: + x = block(x) + x = self.linear_2(x) + return x + + +class ConditionalModuleGroupOffloadTests(GroupOffloadTests): + """Tests for conditionally-executed modules under group offloading with streams. + + Regression tests for the case where a module is not executed during the first forward pass + (when the lazy prefetch execution order is traced), but IS executed on subsequent passes. + Without the fix, the weights of such modules remain on CPU while the input is on GPU, + causing a RuntimeError about tensor device mismatch. + """ + + def get_model(self): + torch.manual_seed(0) + return DummyModelWithConditionalModules( + in_features=self.in_features, + hidden_features=self.hidden_features, + out_features=self.out_features, + num_layers=self.num_layers, + ) + + @parameterized.expand([("leaf_level",), ("block_level",)]) + @unittest.skipIf( + torch.device(torch_device).type not in ["cuda", "xpu"], + "Test requires a CUDA or XPU device.", + ) + def test_conditional_modules_with_stream(self, offload_type: str): + """Regression test: conditionally-executed modules must not cause device mismatch when using streams. + + The model contains two optional Linear layers (optional_proj_1, optional_proj_2) that are only + executed when `optional_input` is provided. This simulates modules like patch_short/patch_mid/ + patch_long in HeliosTransformer3DModel, which are only called when history latents are present. 
+ + When using streams, `LazyPrefetchGroupOffloadingHook` traces the execution order on the first + forward pass and sets up a prefetch chain so each module pre-loads the next one's weights. + Modules not executed during this tracing pass are excluded from the prefetch chain. + + The bug: if a module was absent from the first (tracing) pass, its `onload_self` flag gets set + to False (meaning "someone else will onload me"). But since it's not in the prefetch chain, + nobody ever does — so its weights remain on CPU. When the module is eventually called in a + subsequent pass, the input is on GPU but the weights are on CPU, causing a RuntimeError. + + We therefore must invoke the model multiple times: + 1. First pass WITHOUT optional_input: triggers the lazy prefetch tracing. optional_proj_1/2 + are absent, so they are excluded from the prefetch chain. + 2. Second pass WITH optional_input: the regression case. Without the fix, this raises a + RuntimeError because optional_proj_1/2 weights are still on CPU. + 3. Third pass WITHOUT optional_input: verifies the model remains stable after having seen + both code paths. + """ + + model = self.get_model() + model_ref = self.get_model() + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload( + torch_device, + offload_type=offload_type, + num_blocks_per_group=1, + use_stream=True, + ) + + x = torch.randn(4, self.in_features).to(torch_device) + optional_input = torch.randn(4, self.in_features).to(torch_device) + + with torch.no_grad(): + # First forward pass WITHOUT optional_input — this is when the lazy prefetch + # execution order is traced. optional_proj_1/2 are NOT in the traced order. 
+ out_ref_no_opt = model_ref(x, optional_input=None) + out_no_opt = model(x, optional_input=None) + self.assertTrue( + torch.allclose(out_ref_no_opt, out_no_opt, atol=1e-5), + f"[{offload_type}] Outputs do not match on first pass (no optional_input).", + ) + + # Second forward pass WITH optional_input — optional_proj_1/2 ARE now called. + out_ref_with_opt = model_ref(x, optional_input=optional_input) + out_with_opt = model(x, optional_input=optional_input) + self.assertTrue( + torch.allclose(out_ref_with_opt, out_with_opt, atol=1e-5), + f"[{offload_type}] Outputs do not match on second pass (with optional_input).", + ) + + # Third pass again without optional_input — verify stable behavior. + out_ref_no_opt2 = model_ref(x, optional_input=None) + out_no_opt2 = model(x, optional_input=None) + self.assertTrue( + torch.allclose(out_ref_no_opt2, out_no_opt2, atol=1e-5), + f"[{offload_type}] Outputs do not match on third pass (back to no optional_input).", + ) From d0c9e60d844e3a84137f68e1f0894e0d342d276b Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Wed, 4 Mar 2026 08:01:43 -0800 Subject: [PATCH 019/215] Add Helios-14B Video Generation Pipelines (#13208) * [1/N] add helios * fix test * make fix-copies * change script path * fix cus script * update docs * fix documented check * update links for docs and examples * change default config * small refactor * add test * Update src/diffusers/models/transformers/transformer_helios.py Co-authored-by: YiYi Xu * remove register_buffer for _scale_cache * fix non-cuda devices error * remove "handle the case when timestep is 2D" * refactor HeliosMultiTermMemoryPatch and process_input_hidden_states * Update src/diffusers/pipelines/helios/pipeline_helios.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/models/transformers/transformer_helios.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update 
src/diffusers/pipelines/helios/pipeline_helios.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * fix calculate_shift * Update src/diffusers/pipelines/helios/pipeline_helios.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * rewritten `einops` in pure `torch` * fix: pass patch_size to apply_schedule_shift instead of hardcoding * remove the logics of 'vae_decode_type' * move some validation into check_inputs() * rename helios scheduler & merge all into one step() * add some details to doc * move dmd step() logics from pipeline to scheduler * change to Python 3.9+ style type * fix NoneType error * refactor DMD scheduler's set_timestep * change rope related vars name * fix stage2 sample * fix dmd sample * Update src/diffusers/models/transformers/transformer_helios.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_helios.py Co-authored-by: YiYi Xu * remove redundant & refactor norm_out * Update src/diffusers/pipelines/helios/pipeline_helios.py Co-authored-by: YiYi Xu * change "is_keep_x0" to "keep_first_frame" * use a more intuitive name * refactor dynamic_time_shifting * remove use_dynamic_shifting args * remove usage of UniPCMultistepScheduler * separate stage2 sample to HeliosPyramidPipeline * Update src/diffusers/models/transformers/transformer_helios.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_helios.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_helios.py Co-authored-by: YiYi Xu * Update src/diffusers/models/transformers/transformer_helios.py Co-authored-by: YiYi Xu * fix transformer * use a more intuitive name * update example script * fix requirements * remove redudant attention mask * fix * optimize pipelines * make style . 
* update TYPE_CHECKING * change to use torch.split Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * derive memory patch sizes from patch_size multiples * remove some hardcoding * move some checks into check_inputs * refactor sample_block_noise * optimize encoding chunks logits for v2v * use num_history_latent_frames = sum(history_sizes) * Update src/diffusers/pipelines/helios/pipeline_helios.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * remove redudant optimized_scale * Update src/diffusers/pipelines/helios/pipeline_helios_pyramid.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * use more descriptive name * optimize history_latents * remove not used "num_inference_steps" * removed redudant "pyramid_num_stages" * add "is_cfg_zero_star" and "is_distilled" to HeliosPyramidPipeline * remove redudant * change example scripts name * change example scripts name * correct docs * update example * update docs * Update tests/models/transformers/test_models_transformer_helios.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update tests/models/transformers/test_models_transformer_helios.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * separate HeliosDMDScheduler * fix numerical stability issue: * Update src/diffusers/schedulers/scheduling_helios_dmd.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/schedulers/scheduling_helios_dmd.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/schedulers/scheduling_helios_dmd.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/schedulers/scheduling_helios_dmd.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/schedulers/scheduling_helios_dmd.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * remove redudant * small refactor * remove use_interpolate_prompt logits * 
simplified model test * fallbackt to BaseModelTesterConfig * remove _maybe_expand_t2v_lora_for_i2v * fix HeliosLoraLoaderMixin * update docs * use randn_tensor for test * fix doc typo * optimize code * mark torch.compile xfail * change paper name * Make get_dummy_inputs deterministic using self.generator * Set less strict threshold for test_save_load_float16 test for Helios pipeline * make style and make quality * Preparation for merging * add torch.Generator * Fix HeliosPipelineOutput doc path * Fix Helios related (optimize docs & remove redudant) (#13210) * fix docs * remove redudant * remove redudant * fix group offload * Removed fixes for group offload --------- Co-authored-by: yuanshenghai Co-authored-by: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Co-authored-by: YiYi Xu Co-authored-by: SHYuanBest Co-authored-by: Sayak Paul --- docs/source/en/_toctree.yml | 11 +- docs/source/en/api/loaders/lora.md | 5 + .../en/api/models/helios_transformer3d.md | 35 + docs/source/en/api/pipelines/helios.md | 465 +++++++ docs/source/en/api/schedulers/helios.md | 20 + docs/source/en/api/schedulers/helios_dmd.md | 20 + docs/source/en/using-diffusers/consisid.md | 2 +- docs/source/en/using-diffusers/helios.md | 133 ++ docs/source/zh/_toctree.yml | 2 + docs/source/zh/community_projects.md | 8 + docs/source/zh/using-diffusers/helios.md | 134 +++ src/diffusers/__init__.py | 10 + src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/lora_pipeline.py | 201 ++++ src/diffusers/loaders/peft.py | 1 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../models/transformers/transformer_helios.py | 814 +++++++++++++ src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/auto_pipeline.py | 3 + src/diffusers/pipelines/helios/__init__.py | 48 + .../pipelines/helios/pipeline_helios.py | 916 ++++++++++++++ .../helios/pipeline_helios_pyramid.py | 1065 +++++++++++++++++ .../pipelines/helios/pipeline_output.py | 20 + 
src/diffusers/schedulers/__init__.py | 4 + src/diffusers/schedulers/scheduling_helios.py | 867 ++++++++++++++ .../schedulers/scheduling_helios_dmd.py | 331 +++++ src/diffusers/utils/dummy_pt_objects.py | 45 + .../dummy_torch_and_transformers_objects.py | 30 + tests/lora/test_lora_layers_helios.py | 120 ++ .../test_models_transformer_helios.py | 168 +++ tests/pipelines/helios/__init__.py | 0 tests/pipelines/helios/test_helios.py | 172 +++ 33 files changed, 5655 insertions(+), 2 deletions(-) create mode 100644 docs/source/en/api/models/helios_transformer3d.md create mode 100644 docs/source/en/api/pipelines/helios.md create mode 100644 docs/source/en/api/schedulers/helios.md create mode 100644 docs/source/en/api/schedulers/helios_dmd.md create mode 100644 docs/source/en/using-diffusers/helios.md create mode 100644 docs/source/zh/using-diffusers/helios.md create mode 100644 src/diffusers/models/transformers/transformer_helios.py create mode 100644 src/diffusers/pipelines/helios/__init__.py create mode 100644 src/diffusers/pipelines/helios/pipeline_helios.py create mode 100644 src/diffusers/pipelines/helios/pipeline_helios_pyramid.py create mode 100644 src/diffusers/pipelines/helios/pipeline_output.py create mode 100644 src/diffusers/schedulers/scheduling_helios.py create mode 100644 src/diffusers/schedulers/scheduling_helios_dmd.py create mode 100644 tests/lora/test_lora_layers_helios.py create mode 100644 tests/models/transformers/test_models_transformer_helios.py create mode 100644 tests/pipelines/helios/__init__.py create mode 100644 tests/pipelines/helios/test_helios.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 098660ec3f39..ea06f35a0343 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -194,6 +194,8 @@ title: Model accelerators and hardware - isExpanded: false sections: + - local: using-diffusers/helios + title: Helios - local: using-diffusers/consisid title: ConsisID - local: using-diffusers/sdxl @@ 
-350,6 +352,8 @@ title: FluxTransformer2DModel - local: api/models/glm_image_transformer2d title: GlmImageTransformer2DModel + - local: api/models/helios_transformer3d + title: HeliosTransformer3DModel - local: api/models/hidream_image_transformer title: HiDreamImageTransformer2DModel - local: api/models/hunyuan_transformer2d @@ -625,7 +629,6 @@ title: Image-to-image - local: api/pipelines/stable_diffusion/inpaint title: Inpainting - - local: api/pipelines/stable_diffusion/latent_upscale title: Latent upscaler - local: api/pipelines/stable_diffusion/ldm3d_diffusion @@ -674,6 +677,8 @@ title: ConsisID - local: api/pipelines/framepack title: Framepack + - local: api/pipelines/helios + title: Helios - local: api/pipelines/hunyuan_video title: HunyuanVideo - local: api/pipelines/hunyuan_video15 @@ -745,6 +750,10 @@ title: FlowMatchEulerDiscreteScheduler - local: api/schedulers/flow_match_heun_discrete title: FlowMatchHeunDiscreteScheduler + - local: api/schedulers/helios_dmd + title: HeliosDMDScheduler + - local: api/schedulers/helios + title: HeliosScheduler - local: api/schedulers/heun title: HeunDiscreteScheduler - local: api/schedulers/ipndm diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index bbae6a9020af..db1ea884558f 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -23,6 +23,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi - [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow). - [`LTXVideoLoraLoaderMixin`] provides similar functions for [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video). - [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana). +- [`HeliosLoraLoaderMixin`] provides similar functions for [Helios](https://huggingface.co/docs/diffusers/main/en/api/pipelines/helios).
- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video). - [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2). - [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan). @@ -86,6 +87,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi [[autodoc]] loaders.lora_pipeline.SanaLoraLoaderMixin +## HeliosLoraLoaderMixin + +[[autodoc]] loaders.lora_pipeline.HeliosLoraLoaderMixin + ## HunyuanVideoLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin diff --git a/docs/source/en/api/models/helios_transformer3d.md b/docs/source/en/api/models/helios_transformer3d.md new file mode 100644 index 000000000000..5aa2826c32ec --- /dev/null +++ b/docs/source/en/api/models/helios_transformer3d.md @@ -0,0 +1,35 @@ + + +# HeliosTransformer3DModel + +A 14B Real-Time Autoregressive Diffusion Transformer model (supports T2V, I2V and V2V) for 3D video-like data from [Helios](https://github.com/PKU-YuanGroup/Helios) was introduced in [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) by Peking University & ByteDance & etc. + +The model can be loaded with the following code snippet.
+ +```python +from diffusers import HeliosTransformer3DModel + +# Best Quality +transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="transformer", torch_dtype=torch.bfloat16) +# Intermediate Weight +transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Mid", subfolder="transformer", torch_dtype=torch.bfloat16) +# Best Efficiency +transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Distilled", subfolder="transformer", torch_dtype=torch.bfloat16) +``` + +## HeliosTransformer3DModel + +[[autodoc]] HeliosTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md new file mode 100644 index 000000000000..81559b24c071 --- /dev/null +++ b/docs/source/en/api/pipelines/helios.md @@ -0,0 +1,465 @@ + + +
+
+ + LoRA + +
+
+ +# Helios + +[Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) from Peking University & ByteDance & etc, by Shenghai Yuan, Yuanyang Yin, Zongjian Li, Xinwei Huang, Xiao Yang, Li Yuan. + +* We introduce Helios, the first 14B video generation model that runs at 17 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching a strong baseline in quality. We make breakthroughs along three key dimensions: (1) robustness to long-video drifting without commonly used anti-drift heuristics such as self-forcing, error banks, or keyframe sampling; (2) real-time generation without standard acceleration techniques such as KV-cache, causal masking, or sparse attention; and (3) training without parallelism or sharding frameworks, enabling image-diffusion-scale batch sizes while fitting up to four 14B models within 80 GB of GPU memory. Specifically, Helios is a 14B autoregressive diffusion model with a unified input representation that natively supports T2V, I2V, and V2V tasks. To mitigate drifting in long-video generation, we characterize its typical failure modes and propose simple yet effective training strategies that explicitly simulate drifting during training, while eliminating repetitive motion at its source. For efficiency, we heavily compress the historical and noisy context and reduce the number of sampling steps, yielding computational costs comparable to—or lower than—those of 1.3B video generative models. Moreover, we introduce infrastructure-level optimizations that accelerate both inference and training while reducing memory consumption. Extensive experiments demonstrate that Helios consistently outperforms prior methods on both short- and long-video generation. All the code and models are available at [this https URL](https://pku-yuangroup.github.io/Helios-Page). 
+ +The following Helios models are supported in Diffusers: + +- [Helios-Base](https://huggingface.co/BestWishYsh/Helios-Base): Best Quality, with v-prediction, standard CFG and custom HeliosScheduler. +- [Helios-Mid](https://huggingface.co/BestWishYsh/Helios-Mid): Intermediate Weight, with v-prediction, CFG-Zero* and custom HeliosScheduler. +- [Helios-Distilled](https://huggingface.co/BestWishYsh/Helios-Distilled): Best Efficiency, with x0-prediction and custom HeliosDMDScheduler. + +> [!TIP] +> Click on the Helios models in the right sidebar for more examples of video generation. + +### Optimizing Memory and Inference Speed + +The example below demonstrates how to generate a video from text optimized for memory or inference speed. + + + + +Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques. + +The Helios model below requires ~19GB of VRAM. + +```py +import torch +from diffusers import AutoModel, HeliosPipeline +from diffusers.hooks.group_offloading import apply_group_offloading +from diffusers.utils import export_to_video + +vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32) + +# group-offloading +pipeline = HeliosPipeline.from_pretrained( + "BestWishYsh/Helios-Base", + vae=vae, + torch_dtype=torch.bfloat16 +) +pipeline.enable_group_offload( + onload_device=torch.device("cuda"), + offload_device=torch.device("cpu"), + offload_type="block_level", + num_blocks_per_group=1, + use_stream=True, + record_stream=True, +) + +prompt = """ +A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue +and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with +a variety of marine life, including small schools of colorful fish and sea turtles gliding by. 
The water is crystal clear, +allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades +of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and +the vivid colors of its surroundings. A close-up shot with dynamic movement. +""" +negative_prompt = """ +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards +""" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=99, + num_inference_steps=50, + guidance_scale=5.0, + generator=torch.Generator("cuda").manual_seed(42), +).frames[0] +export_to_video(output, "helios_base_t2v_output.mp4", fps=24) +``` + + + + +[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Attention Backends](../../optimization/attention_backends) such as FlashAttention and SageAttention can significantly increase speed by optimizing the computation of the attention mechanism. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs. 
+ +```py +import torch +from diffusers import AutoModel, HeliosPipeline +from diffusers.utils import export_to_video + +vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32) + +pipeline = HeliosPipeline.from_pretrained( + "BestWishYsh/Helios-Base", + vae=vae, + torch_dtype=torch.bfloat16 +) +pipeline.to("cuda") + +# attention backend +# pipeline.transformer.set_attention_backend("flash") +pipeline.transformer.set_attention_backend("_flash_3_hub") # For Hopper GPUs + +# torch.compile +torch.backends.cudnn.benchmark = True +pipeline.text_encoder.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +pipeline.vae.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +pipeline.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=False) + +prompt = """ +A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue +and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with +a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, +allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades +of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and +the vivid colors of its surroundings. A close-up shot with dynamic movement. 
+""" +negative_prompt = """ +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards +""" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=99, + num_inference_steps=50, + guidance_scale=5.0, + generator=torch.Generator("cuda").manual_seed(42), +).frames[0] +export_to_video(output, "helios_base_t2v_output.mp4", fps=24) +``` + + + + + +### Generation with Helios-Base + +The example below demonstrates how to use Helios-Base to generate video based on text, image or video. + + + + +```python +import torch +from diffusers import AutoModel, HeliosPipeline +from diffusers.utils import export_to_video, load_video, load_image + +vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32) + +pipeline = HeliosPipeline.from_pretrained( + "BestWishYsh/Helios-Base", + vae=vae, + torch_dtype=torch.bfloat16 +) +pipeline.to("cuda") + +negative_prompt = """ +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards +""" + +# For Text-to-Video +prompt = """ +A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue +and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. 
The coral reefs are alive with +a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, +allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades +of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and +the vivid colors of its surroundings. A close-up shot with dynamic movement. +""" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=99, + num_inference_steps=50, + guidance_scale=5.0, + generator=torch.Generator("cuda").manual_seed(42), +).frames[0] +export_to_video(output, "helios_base_t2v_output.mp4", fps=24) + +# For Image-to-Video +prompt = """ +A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, +illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, +casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes +apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and +relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and +respect for nature’s might. 
+""" +image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + image=load_image(image_path).resize((640, 384)), + num_frames=99, + num_inference_steps=50, + guidance_scale=5.0, + generator=torch.Generator("cuda").manual_seed(42), +).frames[0] +export_to_video(output, "helios_base_i2v_output.mp4", fps=24) + +# For Video-to-Video +prompt = """ +A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees +under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, +emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to +the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. +A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery. +""" +video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + video=load_video(video_path), + num_frames=99, + num_inference_steps=50, + guidance_scale=5.0, + generator=torch.Generator("cuda").manual_seed(42), +).frames[0] +export_to_video(output, "helios_base_v2v_output.mp4", fps=24) +``` + + + + + +### Generation with Helios-Mid + +The example below demonstrates how to use Helios-Mid to generate video based on text, image or video. 
+ + + + +```python +import torch +from diffusers import AutoModel, HeliosPyramidPipeline +from diffusers.utils import export_to_video, load_video, load_image + +vae = AutoModel.from_pretrained("BestWishYsh/Helios-Mid", subfolder="vae", torch_dtype=torch.float32) + +pipeline = HeliosPyramidPipeline.from_pretrained( + "BestWishYsh/Helios-Mid", + vae=vae, + torch_dtype=torch.bfloat16 +) +pipeline.to("cuda") + +negative_prompt = """ +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards +""" + +# For Text-to-Video +prompt = """ +A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue +and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with +a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, +allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades +of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and +the vivid colors of its surroundings. A close-up shot with dynamic movement. 
+""" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=99, + pyramid_num_inference_steps_list=[20, 20, 20], + guidance_scale=5.0, + use_zero_init=True, + zero_steps=1, + generator=torch.Generator("cuda").manual_seed(42), +).frames[0] +export_to_video(output, "helios_pyramid_t2v_output.mp4", fps=24) + +# For Image-to-Video +prompt = """ +A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, +illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, +casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes +apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and +relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and +respect for nature’s might. +""" +image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + image=load_image(image_path).resize((640, 384)), + num_frames=99, + pyramid_num_inference_steps_list=[20, 20, 20], + guidance_scale=5.0, + use_zero_init=True, + zero_steps=1, + generator=torch.Generator("cuda").manual_seed(42), +).frames[0] +export_to_video(output, "helios_pyramid_i2v_output.mp4", fps=24) + +# For Video-to-Video +prompt = """ +A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees +under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, +emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to +the scene. 
The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. +A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery. +""" +video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + video=load_video(video_path), + num_frames=99, + pyramid_num_inference_steps_list=[20, 20, 20], + guidance_scale=5.0, + use_zero_init=True, + zero_steps=1, + generator=torch.Generator("cuda").manual_seed(42), +).frames[0] +export_to_video(output, "helios_pyramid_v2v_output.mp4", fps=24) +``` + + + + + +### Generation with Helios-Distilled + +The example below demonstrates how to use Helios-Distilled to generate video based on text, image or video. + + + + +```python +import torch +from diffusers import AutoModel, HeliosPyramidPipeline +from diffusers.utils import export_to_video, load_video, load_image + +vae = AutoModel.from_pretrained("BestWishYsh/Helios-Distilled", subfolder="vae", torch_dtype=torch.float32) + +pipeline = HeliosPyramidPipeline.from_pretrained( + "BestWishYsh/Helios-Distilled", + vae=vae, + torch_dtype=torch.bfloat16 +) +pipeline.to("cuda") + +negative_prompt = """ +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards +""" + +# For Text-to-Video +prompt = """ +A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue +and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. 
The coral reefs are alive with +a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, +allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades +of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and +the vivid colors of its surroundings. A close-up shot with dynamic movement. +""" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=240, + pyramid_num_inference_steps_list=[2, 2, 2], + guidance_scale=1.0, + is_amplify_first_chunk=True, + generator=torch.Generator("cuda").manual_seed(42), +).frames[0] +export_to_video(output, "helios_distilled_t2v_output.mp4", fps=24) + +# For Image-to-Video +prompt = """ +A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, +illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, +casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes +apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and +relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and +respect for nature’s might. 
+""" +image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + image=load_image(image_path).resize((640, 384)), + num_frames=240, + pyramid_num_inference_steps_list=[2, 2, 2], + guidance_scale=1.0, + is_amplify_first_chunk=True, + generator=torch.Generator("cuda").manual_seed(42), +).frames[0] +export_to_video(output, "helios_distilled_i2v_output.mp4", fps=24) + +# For Video-to-Video +prompt = """ +A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees +under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, +emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to +the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. +A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery. 
+""" +video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + video=load_video(video_path), + num_frames=240, + pyramid_num_inference_steps_list=[2, 2, 2], + guidance_scale=1.0, + is_amplify_first_chunk=True, + generator=torch.Generator("cuda").manual_seed(42), +).frames[0] +export_to_video(output, "helios_distilled_v2v_output.mp4", fps=24) +``` + + + + + +## HeliosPipeline + +[[autodoc]] HeliosPipeline + + - all + - __call__ + +## HeliosPyramidPipeline + +[[autodoc]] HeliosPyramidPipeline + + - all + - __call__ + +## HeliosPipelineOutput + +[[autodoc]] pipelines.helios.pipeline_output.HeliosPipelineOutput diff --git a/docs/source/en/api/schedulers/helios.md b/docs/source/en/api/schedulers/helios.md new file mode 100644 index 000000000000..14c2be60bc89 --- /dev/null +++ b/docs/source/en/api/schedulers/helios.md @@ -0,0 +1,20 @@ + + +# HeliosScheduler + +`HeliosScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers). + +## HeliosScheduler +[[autodoc]] HeliosScheduler + +scheduling_helios diff --git a/docs/source/en/api/schedulers/helios_dmd.md b/docs/source/en/api/schedulers/helios_dmd.md new file mode 100644 index 000000000000..4f075e8a7dfc --- /dev/null +++ b/docs/source/en/api/schedulers/helios_dmd.md @@ -0,0 +1,20 @@ + + +# HeliosDMDScheduler + +`HeliosDMDScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers). 
+ +## HeliosDMDScheduler +[[autodoc]] HeliosDMDScheduler + +scheduling_helios_dmd diff --git a/docs/source/en/using-diffusers/consisid.md b/docs/source/en/using-diffusers/consisid.md index b6b04ddaf57e..96ece5b20c3a 100644 --- a/docs/source/en/using-diffusers/consisid.md +++ b/docs/source/en/using-diffusers/consisid.md @@ -60,7 +60,7 @@ export_to_video(video.frames[0], "output.mp4", fps=8) Face Image Video - DescriptionDescription diff --git a/docs/source/en/using-diffusers/helios.md b/docs/source/en/using-diffusers/helios.md new file mode 100644 index 000000000000..8106f1c568f8 --- /dev/null +++ b/docs/source/en/using-diffusers/helios.md @@ -0,0 +1,133 @@ + +# Helios + +[Helios](https://github.com/PKU-YuanGroup/Helios) is the first 14B video generation model that runs at 19.5 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching the quality of a strong baseline, natively integrating T2V, I2V, and V2V tasks within a unified architecture. The main features of Helios are: + +- Without commonly used anti-drifting strategies (e.g., self-forcing, error-banks, keyframe sampling, or inverted sampling), Helios generates minute-scale videos with high quality and strong coherence. +- Without standard acceleration techniques (e.g., KV-cache, causal masking, sparse/linear attention, TinyVAE, progressive noise schedules, hidden-state caching, or quantization), Helios achieves 19.5 FPS in end-to-end inference for a 14B video generation model on a single H100 GPU. +- Helios introduces optimizations that improve both training and inference throughput while reducing memory consumption. These changes enable training a 14B video generation model without parallelism or sharding infrastructure, with batch sizes comparable to image models. + +This guide will walk you through using Helios for different use cases.
+ +## Load Model Checkpoints + +Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method. + +```python +import torch +from diffusers import HeliosPipeline, HeliosPyramidPipeline +from huggingface_hub import snapshot_download + +# For Best Quality +snapshot_download(repo_id="BestWishYsh/Helios-Base", local_dir="BestWishYsh/Helios-Base") +pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Base", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# Intermediate Weight +snapshot_download(repo_id="BestWishYsh/Helios-Mid", local_dir="BestWishYsh/Helios-Mid") +pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# For Best Efficiency +snapshot_download(repo_id="BestWishYsh/Helios-Distilled", local_dir="BestWishYsh/Helios-Distilled") +pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16) +pipe.to("cuda") +``` + +## Text-to-Video Showcases + + + + + + + + + + + + + + +
PromptGenerated Video
A Viking warrior driving a modern city bus filled with passengers. The Viking has long blonde hair tied back, a beard, and is adorned with a fur-lined helmet and armor. He wears a traditional tunic and trousers, but also sports a seatbelt as he focuses on navigating the busy streets. The interior of the bus is typical, with rows of seats occupied by diverse passengers going about their daily routines. The exterior shots show the bustling urban environment, including tall buildings and traffic. Medium shot focusing on the Viking at the wheel, with occasional close-ups of his determined expression. + + +
A documentary-style nature photography shot from a camera truck moving to the left, capturing a crab quickly scurrying into its burrow. The crab has a hard, greenish-brown shell and long claws, moving with determined speed across the sandy ground. Its body is slightly arched as it burrows into the sand, leaving a small trail behind. The background shows a shallow beach with scattered rocks and seashells, and the horizon features a gentle curve of the coastline. The photo has a natural and realistic texture, emphasizing the crab's natural movement and the texture of the sand. A close-up shot from a slightly elevated angle. + + +
+ +## Image-to-Video Showcases + + + + + + + + + + + + + + + + + +
ImagePromptGenerated Video
A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads "KIA 626," and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees. + + +
A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. A close-up shot from a slightly elevated perspective. + + +
+ +## Interactive-Video Showcases + + + + + + + + + + + + + + +
PromptGenerated Video
The prompt can be found here + +
The prompt can be found here + +
+ +## Resources + +Learn more about Helios with the following resources. +- Watch [video1](https://www.youtube.com/watch?v=vd_AgHtOUFQ) and [video2](https://www.youtube.com/watch?v=1GeIU2Dn7UY) for a demonstration of Helios's key features. +- Read the research paper, [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/), for more details. diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml index 337d010fc74d..ab9eaf6ec7fb 100644 --- a/docs/source/zh/_toctree.yml +++ b/docs/source/zh/_toctree.yml @@ -132,6 +132,8 @@ sections: - local: using-diffusers/consisid title: ConsisID + - local: using-diffusers/helios + title: Helios - title: Resources isExpanded: false diff --git a/docs/source/zh/community_projects.md b/docs/source/zh/community_projects.md index 0440142452f1..ffa45f1e9bb0 100644 --- a/docs/source/zh/community_projects.md +++ b/docs/source/zh/community_projects.md @@ -26,6 +26,14 @@ http://www.apache.org/licenses/LICENSE-2.0 项目名称 描述 + + helios + Helios:比1.3B更低开销、更快且更强的14B的实时长视频生成模型 + + + consisid + ConsisID:零样本身份保持的文本到视频生成模型 + dream-textures Stable Diffusion内置到Blender diff --git a/docs/source/zh/using-diffusers/helios.md b/docs/source/zh/using-diffusers/helios.md new file mode 100644 index 000000000000..5c4faed2ca2a --- /dev/null +++ b/docs/source/zh/using-diffusers/helios.md @@ -0,0 +1,134 @@ + +# Helios + +[Helios](https://github.com/PKU-YuanGroup/Helios) 是首个能够在单张 NVIDIA H100 GPU 上以 19.5 FPS 运行的 14B 视频生成模型。它在支持分钟级视频生成的同时,拥有媲美强大基线模型的生成质量,并在统一架构下原生集成了文生视频(T2V)、图生视频(I2V)和视频生视频(V2V)任务。Helios 的主要特性包括: + +- 无需常用的防漂移策略(例如:自强制/self-forcing、误差库/error-banks、关键帧采样或逆采样),我们的模型即可生成高质量且高度连贯的分钟级视频。 +- 无需标准的加速技术(例如:KV 缓存、因果掩码、稀疏/线性注意力机制、TinyVAE、渐进式噪声调度、隐藏状态缓存或量化),作为一款 14B 规模的视频生成模型,我们在单张 H100 GPU 上的端到端推理速度便达到了 19.5 FPS。 +- 引入了多项优化方案,在降低显存消耗的同时,显著提升了训练与推理的吞吐量。这些改进使得我们无需借助并行或分片(sharding)等基础设施,即可使用与图像模型相当的批大小(batch sizes)来训练 14B 的视频生成模型。 + +本指南将引导您完成 Helios 在不同场景下的使用。 + +## Load Model Checkpoints +
+模型权重可以存储在Hub上或本地的单独子文件夹中,在这种情况下,您应该使用 [`~DiffusionPipeline.from_pretrained`] 方法。 + +```python +import torch +from diffusers import HeliosPipeline, HeliosPyramidPipeline +from huggingface_hub import snapshot_download + +# For Best Quality +snapshot_download(repo_id="BestWishYsh/Helios-Base", local_dir="BestWishYsh/Helios-Base") +pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Base", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# Intermediate Weight +snapshot_download(repo_id="BestWishYsh/Helios-Mid", local_dir="BestWishYsh/Helios-Mid") +pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# For Best Efficiency +snapshot_download(repo_id="BestWishYsh/Helios-Distilled", local_dir="BestWishYsh/Helios-Distilled") +pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16) +pipe.to("cuda") +``` + +## Text-to-Video Showcases + + + + + + + + + + + + + + +
PromptGenerated Video
A Viking warrior driving a modern city bus filled with passengers. The Viking has long blonde hair tied back, a beard, and is adorned with a fur-lined helmet and armor. He wears a traditional tunic and trousers, but also sports a seatbelt as he focuses on navigating the busy streets. The interior of the bus is typical, with rows of seats occupied by diverse passengers going about their daily routines. The exterior shots show the bustling urban environment, including tall buildings and traffic. Medium shot focusing on the Viking at the wheel, with occasional close-ups of his determined expression. + + +
A documentary-style nature photography shot from a camera truck moving to the left, capturing a crab quickly scurrying into its burrow. The crab has a hard, greenish-brown shell and long claws, moving with determined speed across the sandy ground. Its body is slightly arched as it burrows into the sand, leaving a small trail behind. The background shows a shallow beach with scattered rocks and seashells, and the horizon features a gentle curve of the coastline. The photo has a natural and realistic texture, emphasizing the crab's natural movement and the texture of the sand. A close-up shot from a slightly elevated angle. + + +
+ +## Image-to-Video Showcases + + + + + + + + + + + + + + + + + +
ImagePromptGenerated Video
A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads "KIA 626," and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees. + + +
A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. A close-up shot from a slightly elevated perspective. + + +
+ +## Interactive-Video Showcases + + + + + + + + + + + + + + +
PromptGenerated Video
The prompt can be found here + +
The prompt can be found here + +
+ +## Resources + +通过以下资源了解有关 Helios 的更多信息: + +- [视频1](https://www.youtube.com/watch?v=vd_AgHtOUFQ)和[视频2](https://www.youtube.com/watch?v=1GeIU2Dn7UY)演示了 Helios 的主要功能; +- 有关更多详细信息,请参阅研究论文 [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/)。 diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1fc0914fe09e..1458164191df 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -227,6 +227,7 @@ "FluxMultiControlNetModel", "FluxTransformer2DModel", "GlmImageTransformer2DModel", + "HeliosTransformer3DModel", "HiDreamImageTransformer2DModel", "HunyuanDiT2DControlNetModel", "HunyuanDiT2DModel", @@ -359,6 +360,8 @@ "FlowMatchEulerDiscreteScheduler", "FlowMatchHeunDiscreteScheduler", "FlowMatchLCMScheduler", + "HeliosDMDScheduler", + "HeliosScheduler", "HeunDiscreteScheduler", "IPNDMScheduler", "KarrasVeScheduler", @@ -515,6 +518,8 @@ "FluxPipeline", "FluxPriorReduxPipeline", "GlmImagePipeline", + "HeliosPipeline", + "HeliosPyramidPipeline", "HiDreamImagePipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", @@ -994,6 +999,7 @@ FluxMultiControlNetModel, FluxTransformer2DModel, GlmImageTransformer2DModel, + HeliosTransformer3DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, @@ -1122,6 +1128,8 @@ FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler, FlowMatchLCMScheduler, + HeliosDMDScheduler, + HeliosScheduler, HeunDiscreteScheduler, IPNDMScheduler, KarrasVeScheduler, @@ -1257,6 +1265,8 @@ FluxPipeline, FluxPriorReduxPipeline, GlmImagePipeline, + HeliosPipeline, + HeliosPyramidPipeline, HiDreamImagePipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index bdd4dbbcd4b5..ed0d2a07336f 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -78,6 +78,7 @@ def text_encoder_attn_modules(text_encoder): "SanaLoraLoaderMixin", 
"Lumina2LoraLoaderMixin", "WanLoraLoaderMixin", + "HeliosLoraLoaderMixin", "KandinskyLoraLoaderMixin", "HiDreamImageLoraLoaderMixin", "SkyReelsV2LoraLoaderMixin", @@ -118,6 +119,7 @@ def text_encoder_attn_modules(text_encoder): CogView4LoraLoaderMixin, Flux2LoraLoaderMixin, FluxLoraLoaderMixin, + HeliosLoraLoaderMixin, HiDreamImageLoraLoaderMixin, HunyuanVideoLoraLoaderMixin, KandinskyLoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 3423a88d3368..5d10f596f2e6 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -3440,6 +3440,207 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) +class HeliosLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`HeliosTransformer3DModel`]. Specific to [`HeliosPipeline`] and [`HeliosPyramidPipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. 
+ cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + if any(k.startswith("diffusion_model.") for k in state_dict): + state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) + elif any(k.startswith("lora_unet_") for k in state_dict): + state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
+ logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], + adapter_name: str | None = None, + hotswap: bool = False, + **kwargs, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint. 
Make sure all LoRA param names contain `'lora'` substring.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel + def load_lora_into_transformer( + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: str | os.PathLike, + transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + transformer_lora_adapter_metadata: dict | None = None, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information. 
+ """ + lora_layers = {} + lora_metadata = {} + + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata + + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") + + cls._save_lora_weights( + save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: list[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: list[str] | None = None, + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details. + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details. + """ + super().unfuse_lora(components=components, **kwargs) + + class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`]. 
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 80fb6a72869a..a96542c2a50c 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -51,6 +51,7 @@ "FluxTransformer2DModel": lambda model_cls, weights: weights, "CogVideoXTransformer3DModel": lambda model_cls, weights: weights, "ConsisIDTransformer3DModel": lambda model_cls, weights: weights, + "HeliosTransformer3DModel": lambda model_cls, weights: weights, "MochiTransformer3DModel": lambda model_cls, weights: weights, "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights, "LTXVideoTransformer3DModel": lambda model_cls, weights: weights, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 96953afa4f4a..8b8d9c52659e 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -100,6 +100,7 @@ _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"] _import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"] + _import_structure["transformers.transformer_helios"] = ["HeliosTransformer3DModel"] _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"] @@ -212,6 +213,7 @@ Flux2Transformer2DModel, FluxTransformer2DModel, GlmImageTransformer2DModel, + HeliosTransformer3DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanImageTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index d9d1b27a1e40..45157ee91808 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -28,6 +28,7 @@ from 
.transformer_flux import FluxTransformer2DModel from .transformer_flux2 import Flux2Transformer2DModel from .transformer_glm_image import GlmImageTransformer2DModel + from .transformer_helios import HeliosTransformer3DModel from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py new file mode 100644 index 000000000000..9f3ef047d98d --- /dev/null +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -0,0 +1,814 @@ +# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from functools import lru_cache +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import apply_lora_scale, logging +from ...utils.torch_utils import maybe_allow_in_graph +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def pad_for_3d_conv(x, kernel_size): + b, c, t, h, w = x.shape + pt, ph, pw = kernel_size + pad_t = (pt - (t % pt)) % pt + pad_h = (ph - (h % ph)) % ph + pad_w = (pw - (w % pw)) % pw + return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate") + + +def center_down_sample_3d(x, kernel_size): + return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size) + + +def apply_rotary_emb_transposed( + hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, +): + x_1, x_2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1) + out = torch.empty_like(hidden_states) + out[..., 0::2] = x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2] + out[..., 1::2] = x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2] + return out.type_as(hidden_states) + + +def _get_qkv_projections(attn: "HeliosAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): + # encoder_hidden_states is only passed for cross-attention + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states 
+ + if attn.fused_projections: + if not attn.is_cross_attention: + # In self-attention layers, we can fuse the entire QKV projection into a single linear + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + else: + # In cross-attention layers, we can only fuse the KV projections into a single linear + query = attn.to_q(hidden_states) + key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1) + else: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + return query, key, value + + +class HeliosOutputNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False): + super().__init__() + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + self.norm = FP32LayerNorm(dim, eps, elementwise_affine=False) + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor, original_context_length: int): + temb = temb[:, -original_context_length:, :] + shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2) + shift, scale = shift.squeeze(2).to(hidden_states.device), scale.squeeze(2).to(hidden_states.device) + hidden_states = hidden_states[:, -original_context_length:, :] + hidden_states = (self.norm(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + return hidden_states + + +class HeliosAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "HeliosAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." 
+ ) + + def __call__( + self, + attn: "HeliosAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + original_context_length: int = None, + ) -> torch.Tensor: + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if rotary_emb is not None: + query = apply_rotary_emb_transposed(query, rotary_emb) + key = apply_rotary_emb_transposed(key, rotary_emb) + + if not attn.is_cross_attention and attn.is_amplify_history: + history_seq_len = hidden_states.shape[1] - original_context_length + + if history_seq_len > 0: + scale_key = 1.0 + torch.sigmoid(attn.history_key_scale) * (attn.max_scale - 1.0) + if attn.history_scale_mode == "per_head": + scale_key = scale_key.view(1, 1, -1, 1) + key = torch.cat([key[:, :history_seq_len] * scale_key, key[:, history_seq_len:]], dim=1) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + # Reference: https://github.com/huggingface/diffusers/pull/12909 + parallel_config=(self._parallel_config if encoder_hidden_states is None else None), + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class HeliosAttention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = HeliosAttnProcessor + _available_processors = [HeliosAttnProcessor] + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + eps: float = 1e-5, + dropout: float = 0.0, + added_kv_proj_dim: 
int | None = None, + cross_attention_dim_head: int | None = None, + processor=None, + is_cross_attention=None, + is_amplify_history=False, + history_scale_mode="per_head", # [scalar, per_head] + ): + super().__init__() + + self.inner_dim = dim_head * heads + self.heads = heads + self.added_kv_proj_dim = added_kv_proj_dim + self.cross_attention_dim_head = cross_attention_dim_head + self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads + + self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True) + self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_out = torch.nn.ModuleList( + [ + torch.nn.Linear(self.inner_dim, dim, bias=True), + torch.nn.Dropout(dropout), + ] + ) + self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + + self.add_k_proj = self.add_v_proj = None + if added_kv_proj_dim is not None: + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps) + + if is_cross_attention is not None: + self.is_cross_attention = is_cross_attention + else: + self.is_cross_attention = cross_attention_dim_head is not None + + self.set_processor(processor) + + self.is_amplify_history = is_amplify_history + if is_amplify_history: + if history_scale_mode == "scalar": + self.history_key_scale = nn.Parameter(torch.ones(1)) + elif history_scale_mode == "per_head": + self.history_key_scale = nn.Parameter(torch.ones(heads)) + else: + raise ValueError(f"Unknown history_scale_mode: {history_scale_mode}") + self.history_scale_mode = history_scale_mode + self.max_scale = 10.0 + + def fuse_projections(self): + if getattr(self, "fused_projections", False): + return + + 
if not self.is_cross_attention: + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_qkv = nn.Linear(in_features, out_features, bias=True) + self.to_qkv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_kv = nn.Linear(in_features, out_features, bias=True) + self.to_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + if self.added_kv_proj_dim is not None: + concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data]) + concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_added_kv = nn.Linear(in_features, out_features, bias=True) + self.to_added_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + self.fused_projections = True + + @torch.no_grad() + def unfuse_projections(self): + if not getattr(self, "fused_projections", False): + return + + if hasattr(self, "to_qkv"): + delattr(self, "to_qkv") + if hasattr(self, "to_kv"): + delattr(self, "to_kv") + if hasattr(self, "to_added_kv"): + delattr(self, "to_added_kv") + + self.fused_projections = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + rotary_emb: 
tuple[torch.Tensor, torch.Tensor] | None = None, + original_context_length: int = None, + **kwargs, + ) -> torch.Tensor: + return self.processor( + self, + hidden_states, + encoder_hidden_states, + attention_mask, + rotary_emb, + original_context_length, + **kwargs, + ) + + +class HeliosTimeTextEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + ): + super().__init__() + + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + is_return_encoder_hidden_states: bool = True, + ): + timestep = self.timesteps_proj(timestep) + + time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) + + if encoder_hidden_states is not None and is_return_encoder_hidden_states: + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + + return temb, timestep_proj, encoder_hidden_states + + +class HeliosRotaryPosEmbed(nn.Module): + def __init__(self, rope_dim, theta): + super().__init__() + self.DT, self.DY, self.DX = rope_dim + self.theta = theta + self.register_buffer("freqs_base_t", self._get_freqs_base(self.DT), persistent=False) + self.register_buffer("freqs_base_y", self._get_freqs_base(self.DY), persistent=False) + self.register_buffer("freqs_base_x", self._get_freqs_base(self.DX), persistent=False) + + def _get_freqs_base(self, dim): 
+ return 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim)) + + @torch.no_grad() + def get_frequency_batched(self, freqs_base, pos): + freqs = torch.einsum("d,bthw->dbthw", freqs_base, pos) + freqs = freqs.repeat_interleave(2, dim=0) + return freqs.cos(), freqs.sin() + + @torch.no_grad() + @lru_cache(maxsize=32) + def _get_spatial_meshgrid(self, height, width, device_str): + device = torch.device(device_str) + grid_y_coords = torch.arange(height, device=device, dtype=torch.float32) + grid_x_coords = torch.arange(width, device=device, dtype=torch.float32) + grid_y, grid_x = torch.meshgrid(grid_y_coords, grid_x_coords, indexing="ij") + return grid_y, grid_x + + @torch.no_grad() + def forward(self, frame_indices, height, width, device): + batch_size = frame_indices.shape[0] + num_frames = frame_indices.shape[1] + + frame_indices = frame_indices.to(device=device, dtype=torch.float32) + grid_y, grid_x = self._get_spatial_meshgrid(height, width, str(device)) + + grid_t = frame_indices[:, :, None, None].expand(batch_size, num_frames, height, width) + grid_y_batch = grid_y[None, None, :, :].expand(batch_size, num_frames, -1, -1) + grid_x_batch = grid_x[None, None, :, :].expand(batch_size, num_frames, -1, -1) + + freqs_cos_t, freqs_sin_t = self.get_frequency_batched(self.freqs_base_t, grid_t) + freqs_cos_y, freqs_sin_y = self.get_frequency_batched(self.freqs_base_y, grid_y_batch) + freqs_cos_x, freqs_sin_x = self.get_frequency_batched(self.freqs_base_x, grid_x_batch) + + result = torch.cat([freqs_cos_t, freqs_cos_y, freqs_cos_x, freqs_sin_t, freqs_sin_y, freqs_sin_x], dim=0) + + return result.permute(1, 0, 2, 3, 4) + + +@maybe_allow_in_graph +class HeliosTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: int | None = None, + guidance_cross_attn: bool = False, + 
is_amplify_history: bool = False, + history_scale_mode: str = "per_head", # [scalar, per_head] + ): + super().__init__() + + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.attn1 = HeliosAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + cross_attention_dim_head=None, + processor=HeliosAttnProcessor(), + is_amplify_history=is_amplify_history, + history_scale_mode=history_scale_mode, + ) + + # 2. Cross-attention + self.attn2 = HeliosAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + cross_attention_dim_head=dim // num_heads, + processor=HeliosAttnProcessor(), + ) + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + + # 3. Feed-forward + self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + # 4. Guidance cross-attention + self.guidance_cross_attn = guidance_cross_attn + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + rotary_emb: torch.Tensor, + original_context_length: int = None, + ) -> torch.Tensor: + if temb.ndim == 4: + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.unsqueeze(0) + temb.float() + ).chunk(6, dim=2) + # batch_size, seq_len, 1, inner_dim + shift_msa = shift_msa.squeeze(2) + scale_msa = scale_msa.squeeze(2) + gate_msa = gate_msa.squeeze(2) + c_shift_msa = c_shift_msa.squeeze(2) + c_scale_msa = c_scale_msa.squeeze(2) + c_gate_msa = c_gate_msa.squeeze(2) + else: + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table + temb.float() + ).chunk(6, dim=1) + + # 1. 
Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + attn_output = self.attn1( + norm_hidden_states, + None, + None, + rotary_emb, + original_context_length, + ) + hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + + # 2. Cross-attention + if self.guidance_cross_attn: + history_seq_len = hidden_states.shape[1] - original_context_length + + history_hidden_states, hidden_states = torch.split( + hidden_states, [history_seq_len, original_context_length], dim=1 + ) + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states, + None, + None, + original_context_length, + ) + hidden_states = hidden_states + attn_output + hidden_states = torch.cat([history_hidden_states, hidden_states], dim=1) + else: + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states, + None, + None, + original_context_length, + ) + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( + hidden_states + ) + ff_output = self.ffn(norm_hidden_states) + hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + + return hidden_states + + +class HeliosTransformer3DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): + r""" + A Transformer model for video-like data used in the Helios model. + + Args: + patch_size (`tuple[int]`, defaults to `(1, 2, 2)`): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). + num_attention_heads (`int`, defaults to `40`): + Fixed length for text embeddings. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. 
+ in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + text_dim (`int`, defaults to `4096`): + Input dimension for text embeddings. + freq_dim (`int`, defaults to `256`): + Dimension for sinusoidal time embeddings. + ffn_dim (`int`, defaults to `13824`): + Intermediate dimension in feed-forward network. + num_layers (`int`, defaults to `40`): + The number of layers of transformer blocks to use. + window_size (`tuple[int]`, defaults to `(-1, -1)`): + Window size for local attention (-1 indicates global attention). + cross_attn_norm (`bool`, defaults to `True`): + Enable cross-attention normalization. + qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`): + Name of the query/key normalization scheme, forwarded to every `HeliosTransformerBlock`. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + add_img_emb (`bool`, defaults to `False`): + Whether to use img_emb. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = [ + "patch_embedding", + "patch_short", + "patch_mid", + "patch_long", + "condition_embedder", + "norm", + ] + _no_split_modules = ["HeliosTransformerBlock", "HeliosOutputNorm"] + _keep_in_fp32_modules = [ + "time_embedder", + "scale_shift_table", + "norm1", + "norm2", + "norm3", + "history_key_scale", + ] + _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + _repeated_blocks = ["HeliosTransformerBlock"] + _cp_plan = { + "blocks.0": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + "blocks.*": { + "temb": ContextParallelInput(split_dim=1, expected_dims=4, split_output=False), + "rotary_emb": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + "blocks.39": ContextParallelOutput(gather_dim=1, expected_dims=3), + } + + @register_to_config + def __init__( + self, + patch_size: tuple[int, ...] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 16, + out_channels: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + qk_norm: str | None = "rms_norm_across_heads", + eps: float = 1e-6, + added_kv_proj_dim: int | None = None, + rope_dim: tuple[int, ...] = (44, 42, 42), + rope_theta: float = 10000.0, + guidance_cross_attn: bool = True, + zero_history_timestep: bool = True, + has_multi_term_memory_patch: bool = True, + is_amplify_history: bool = False, + history_scale_mode: str = "per_head", # [scalar, per_head] + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Patch & position embedding + self.rope = HeliosRotaryPosEmbed(rope_dim=rope_dim, theta=rope_theta) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + + # 2. 
Initial Multi Term Memory Patch + self.zero_history_timestep = zero_history_timestep + if has_multi_term_memory_patch: + self.patch_short = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + self.patch_mid = nn.Conv3d( + in_channels, + inner_dim, + kernel_size=tuple(2 * p for p in patch_size), + stride=tuple(2 * p for p in patch_size), + ) + self.patch_long = nn.Conv3d( + in_channels, + inner_dim, + kernel_size=tuple(4 * p for p in patch_size), + stride=tuple(4 * p for p in patch_size), + ) + + # 3. Condition embeddings + self.condition_embedder = HeliosTimeTextEmbedding( + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + ) + + # 4. Transformer blocks + self.blocks = nn.ModuleList( + [ + HeliosTransformerBlock( + inner_dim, + ffn_dim, + num_attention_heads, + qk_norm, + cross_attn_norm, + eps, + added_kv_proj_dim, + guidance_cross_attn=guidance_cross_attn, + is_amplify_history=is_amplify_history, + history_scale_mode=history_scale_mode, + ) + for _ in range(num_layers) + ] + ) + + # 5. Output norm & projection + self.norm_out = HeliosOutputNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + + self.gradient_checkpointing = False + + @apply_lora_scale("attention_kwargs") + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + # ------------ Stage 1 ------------ + indices_hidden_states=None, + indices_latents_history_short=None, + indices_latents_history_mid=None, + indices_latents_history_long=None, + latents_history_short=None, + latents_history_mid=None, + latents_history_long=None, + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + ) -> torch.Tensor | dict[str, torch.Tensor]: + # 1. Input + batch_size = hidden_states.shape[0] + p_t, p_h, p_w = self.config.patch_size + + # 2. 
Process noisy latents + hidden_states = self.patch_embedding(hidden_states) + _, _, post_patch_num_frames, post_patch_height, post_patch_width = hidden_states.shape + + if indices_hidden_states is None: + indices_hidden_states = torch.arange(0, post_patch_num_frames).unsqueeze(0).expand(batch_size, -1) + + hidden_states = hidden_states.flatten(2).transpose(1, 2) + rotary_emb = self.rope( + frame_indices=indices_hidden_states, + height=post_patch_height, + width=post_patch_width, + device=hidden_states.device, + ) + rotary_emb = rotary_emb.flatten(2).transpose(1, 2) + original_context_length = hidden_states.shape[1] + + # 3. Process short history latents + if latents_history_short is not None and indices_latents_history_short is not None: + latents_history_short = self.patch_short(latents_history_short) + _, _, _, H1, W1 = latents_history_short.shape + latents_history_short = latents_history_short.flatten(2).transpose(1, 2) + + rotary_emb_history_short = self.rope( + frame_indices=indices_latents_history_short, + height=H1, + width=W1, + device=latents_history_short.device, + ) + rotary_emb_history_short = rotary_emb_history_short.flatten(2).transpose(1, 2) + + hidden_states = torch.cat([latents_history_short, hidden_states], dim=1) + rotary_emb = torch.cat([rotary_emb_history_short, rotary_emb], dim=1) + + # 4. 
Process mid history latents + if latents_history_mid is not None and indices_latents_history_mid is not None: + latents_history_mid = pad_for_3d_conv(latents_history_mid, (2, 4, 4)) + latents_history_mid = self.patch_mid(latents_history_mid) + latents_history_mid = latents_history_mid.flatten(2).transpose(1, 2) + + rotary_emb_history_mid = self.rope( + frame_indices=indices_latents_history_mid, + height=H1, + width=W1, + device=latents_history_mid.device, + ) + rotary_emb_history_mid = pad_for_3d_conv(rotary_emb_history_mid, (2, 2, 2)) + rotary_emb_history_mid = center_down_sample_3d(rotary_emb_history_mid, (2, 2, 2)) + rotary_emb_history_mid = rotary_emb_history_mid.flatten(2).transpose(1, 2) + + hidden_states = torch.cat([latents_history_mid, hidden_states], dim=1) + rotary_emb = torch.cat([rotary_emb_history_mid, rotary_emb], dim=1) + + # 5. Process long history latents + if latents_history_long is not None and indices_latents_history_long is not None: + latents_history_long = pad_for_3d_conv(latents_history_long, (4, 8, 8)) + latents_history_long = self.patch_long(latents_history_long) + latents_history_long = latents_history_long.flatten(2).transpose(1, 2) + + rotary_emb_history_long = self.rope( + frame_indices=indices_latents_history_long, + height=H1, + width=W1, + device=latents_history_long.device, + ) + rotary_emb_history_long = pad_for_3d_conv(rotary_emb_history_long, (4, 4, 4)) + rotary_emb_history_long = center_down_sample_3d(rotary_emb_history_long, (4, 4, 4)) + rotary_emb_history_long = rotary_emb_history_long.flatten(2).transpose(1, 2) + + hidden_states = torch.cat([latents_history_long, hidden_states], dim=1) + rotary_emb = torch.cat([rotary_emb_history_long, rotary_emb], dim=1) + + history_context_length = hidden_states.shape[1] - original_context_length + + if indices_hidden_states is not None and self.zero_history_timestep: + timestep_t0 = torch.zeros((1), dtype=timestep.dtype, device=timestep.device) + temb_t0, timestep_proj_t0, _ = 
self.condition_embedder( + timestep_t0, encoder_hidden_states, is_return_encoder_hidden_states=False + ) + temb_t0 = temb_t0.unsqueeze(1).expand(batch_size, history_context_length, -1) + timestep_proj_t0 = ( + timestep_proj_t0.unflatten(-1, (6, -1)) + .view(1, 6, 1, -1) + .expand(batch_size, -1, history_context_length, -1) + ) + + temb, timestep_proj, encoder_hidden_states = self.condition_embedder(timestep, encoder_hidden_states) + timestep_proj = timestep_proj.unflatten(-1, (6, -1)) + + if indices_hidden_states is not None and not self.zero_history_timestep: + main_repeat_size = hidden_states.shape[1] + else: + main_repeat_size = original_context_length + temb = temb.view(batch_size, 1, -1).expand(batch_size, main_repeat_size, -1) + timestep_proj = timestep_proj.view(batch_size, 6, 1, -1).expand(batch_size, 6, main_repeat_size, -1) + + if indices_hidden_states is not None and self.zero_history_timestep: + temb = torch.cat([temb_t0, temb], dim=1) + timestep_proj = torch.cat([timestep_proj_t0, timestep_proj], dim=2) + + if timestep_proj.ndim == 4: + timestep_proj = timestep_proj.permute(0, 2, 1, 3) + + # 6. Transformer blocks + hidden_states = hidden_states.contiguous() + encoder_hidden_states = encoder_hidden_states.contiguous() + rotary_emb = rotary_emb.contiguous() + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.blocks: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + original_context_length, + ) + else: + for block in self.blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + original_context_length, + ) + + # 7. Normalization + hidden_states = self.norm_out(hidden_states, temb, original_context_length) + hidden_states = self.proj_out(hidden_states) + + # 8. 
Unpatchify + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 638598051d64..08cb28a6237a 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -237,6 +237,7 @@ "EasyAnimateInpaintPipeline", "EasyAnimateControlPipeline", ] + _import_structure["helios"] = ["HeliosPipeline", "HeliosPyramidPipeline"] _import_structure["hidream_image"] = ["HiDreamImagePipeline"] _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"] _import_structure["hunyuan_video"] = [ @@ -667,6 +668,7 @@ ) from .flux2 import Flux2KleinPipeline, Flux2Pipeline from .glm_image import GlmImagePipeline + from .helios import HeliosPipeline, HeliosPyramidPipeline from .hidream_image import HiDreamImagePipeline from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline from .hunyuan_video import ( diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 7247ca0d161c..72151dc40a53 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -54,6 +54,7 @@ ) from .flux2 import Flux2KleinPipeline, Flux2Pipeline from .glm_image import GlmImagePipeline +from .helios import HeliosPipeline, HeliosPyramidPipeline from .hunyuandit import HunyuanDiTPipeline from .kandinsky import ( KandinskyCombinedPipeline, @@ -174,6 +175,8 @@ ("cogview3", CogView3PlusPipeline), ("cogview4", CogView4Pipeline), ("glm_image", GlmImagePipeline), + ("helios", HeliosPipeline), + ("helios-pyramid", HeliosPyramidPipeline), ("cogview4-control", CogView4ControlPipeline), ("qwenimage", 
QwenImagePipeline), ("qwenimage-controlnet", QwenImageControlNetPipeline), diff --git a/src/diffusers/pipelines/helios/__init__.py b/src/diffusers/pipelines/helios/__init__.py new file mode 100644 index 000000000000..ae08f5997279 --- /dev/null +++ b/src/diffusers/pipelines/helios/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_helios"] = ["HeliosPipeline"] + _import_structure["pipeline_helios_pyramid"] = ["HeliosPyramidPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_helios import HeliosPipeline + from .pipeline_helios_pyramid import HeliosPyramidPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py new file mode 100644 index 000000000000..87a8600badab --- /dev/null +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -0,0 +1,916 @@ +# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Any, Callable + +import numpy as np +import regex as re +import torch +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import HeliosLoraLoaderMixin +from ...models import AutoencoderKLWan, HeliosTransformer3DModel +from ...schedulers import HeliosScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HeliosPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers.utils import export_to_video + >>> from diffusers import AutoencoderKLWan, HeliosPipeline + + >>> # Available models: BestWishYsh/Helios-Base, BestWishYsh/Helios-Mid, BestWishYsh/Helios-Distilled + >>> model_id = "BestWishYsh/Helios-Base" + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = 
HeliosPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=384, + ... width=640, + ... num_frames=132, + ... guidance_scale=5.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=24) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +class HeliosPipeline(DiffusionPipeline, HeliosLoraLoaderMixin): + r""" + Pipeline for text-to-video / image-to-video / video-to-video generation using Helios. + + This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`HeliosTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`HeliosScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _optional_components = ["transformer"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: HeliosScheduler, + transformer: HeliosTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, 
+ dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, text_inputs.attention_mask.bool() + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. 
If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, _ = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." 
+ ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, _ = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + image=None, + video=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if image is not None and video is not None: + raise ValueError("image and video cannot be provided simultaneously") + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 384, + width: int = 640, + num_frames: int = 33, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_image_latents( + self, + image: torch.Tensor, + latents_mean: torch.Tensor, + latents_std: torch.Tensor, + num_latent_frames_per_chunk: int, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + fake_latents: torch.Tensor | None = None, + ) -> torch.Tensor: + device = device or self._execution_device + if latents is None: + image = image.unsqueeze(2).to(device=device, dtype=self.vae.dtype) + latents = self.vae.encode(image).latent_dist.sample(generator=generator) + latents = (latents - latents_mean) * latents_std + if fake_latents is None: + min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 + fake_video = image.repeat(1, 1, min_frames, 1, 1).to(device=device, dtype=self.vae.dtype) + fake_latents_full = self.vae.encode(fake_video).latent_dist.sample(generator=generator) + fake_latents_full = (fake_latents_full - latents_mean) * latents_std + fake_latents = fake_latents_full[:, :, -1:, :, :] + return latents.to(device=device, dtype=dtype), fake_latents.to(device=device, dtype=dtype) + + def prepare_video_latents( + self, + video: torch.Tensor, + latents_mean: torch.Tensor, + latents_std: torch.Tensor, + num_latent_frames_per_chunk: int, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + device = device or self._execution_device + video = video.to(device=device, dtype=self.vae.dtype) + if latents is None: + num_frames = video.shape[2] + min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 + num_chunks = num_frames // min_frames + if num_chunks == 0: + raise ValueError( + f"Video must have at least {min_frames} 
frames " + f"(got {num_frames} frames). " + f"Required: (num_latent_frames_per_chunk - 1) * {self.vae_scale_factor_temporal} + 1 = ({num_latent_frames_per_chunk} - 1) * {self.vae_scale_factor_temporal} + 1 = {min_frames}" + ) + total_valid_frames = num_chunks * min_frames + start_frame = num_frames - total_valid_frames + + first_frame = video[:, :, 0:1, :, :] + first_frame_latent = self.vae.encode(first_frame).latent_dist.sample(generator=generator) + first_frame_latent = (first_frame_latent - latents_mean) * latents_std + + latents_chunks = [] + for i in range(num_chunks): + chunk_start = start_frame + i * min_frames + chunk_end = chunk_start + min_frames + video_chunk = video[:, :, chunk_start:chunk_end, :, :] + chunk_latents = self.vae.encode(video_chunk).latent_dist.sample(generator=generator) + chunk_latents = (chunk_latents - latents_mean) * latents_std + latents_chunks.append(chunk_latents) + latents = torch.cat(latents_chunks, dim=2) + return first_frame_latent.to(device=device, dtype=dtype), latents.to(device=device, dtype=dtype) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 384, + width: int = 640, + num_frames: int = 132, + num_inference_steps: int = 50, + sigmas: list[float] = None, + guidance_scale: float = 5.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: 
torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + # ------------ I2V ------------ + image: PipelineImageInput | None = None, + image_latents: torch.Tensor | None = None, + fake_image_latents: torch.Tensor | None = None, + add_noise_to_image_latents: bool = True, + image_noise_sigma_min: float = 0.111, + image_noise_sigma_max: float = 0.135, + # ------------ V2V ------------ + video: PipelineImageInput | None = None, + video_latents: torch.Tensor | None = None, + add_noise_to_video_latents: bool = True, + video_noise_sigma_min: float = 0.111, + video_noise_sigma_max: float = 0.135, + # ------------ Stage 1 ------------ + history_sizes: list = [16, 2, 1], + num_latent_frames_per_chunk: int = 9, + keep_first_frame: bool = True, + is_skip_first_chunk: bool = False, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `384`): + The height in pixels of the generated image. + width (`int`, defaults to `640`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `132`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HeliosPipelineOutput`] instead of a plain tuple. 
+ attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~HeliosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HeliosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + history_sizes = sorted(history_sizes, reverse=True) # From big to small + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + image, + video, + ) + + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + vae_dtype = self.vae.dtype + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(device, self.vae.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, self.vae.dtype + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. 
Prepare image or video + if image is not None: + image = self.video_processor.preprocess(image, height=height, width=width) + image_latents, fake_image_latents = self.prepare_image_latents( + image, + latents_mean=latents_mean, + latents_std=latents_std, + num_latent_frames_per_chunk=num_latent_frames_per_chunk, + dtype=torch.float32, + device=device, + generator=generator, + latents=image_latents, + fake_latents=fake_image_latents, + ) + + if image_latents is not None and add_noise_to_image_latents: + image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min) + + image_noise_sigma_min + ) + image_latents = ( + image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device) + + (1 - image_noise_sigma) * image_latents + ) + fake_image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (video_noise_sigma_max - video_noise_sigma_min) + + video_noise_sigma_min + ) + fake_image_latents = ( + fake_image_noise_sigma * randn_tensor(fake_image_latents.shape, generator=generator, device=device) + + (1 - fake_image_noise_sigma) * fake_image_latents + ) + + if video is not None: + video = self.video_processor.preprocess_video(video, height=height, width=width) + image_latents, video_latents = self.prepare_video_latents( + video, + latents_mean=latents_mean, + latents_std=latents_std, + num_latent_frames_per_chunk=num_latent_frames_per_chunk, + dtype=torch.float32, + device=device, + generator=generator, + latents=video_latents, + ) + + if video_latents is not None and add_noise_to_video_latents: + image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min) + + image_noise_sigma_min + ) + image_latents = ( + image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device) + + (1 - image_noise_sigma) * image_latents + ) + + noisy_latents_chunks = [] + num_latent_chunks = 
video_latents.shape[2] // num_latent_frames_per_chunk + for i in range(num_latent_chunks): + chunk_start = i * num_latent_frames_per_chunk + chunk_end = chunk_start + num_latent_frames_per_chunk + latent_chunk = video_latents[:, :, chunk_start:chunk_end, :, :] + + chunk_frames = latent_chunk.shape[2] + frame_sigmas = ( + torch.rand(chunk_frames, device=device, generator=generator) + * (video_noise_sigma_max - video_noise_sigma_min) + + video_noise_sigma_min + ) + frame_sigmas = frame_sigmas.view(1, 1, chunk_frames, 1, 1) + + noisy_chunk = ( + frame_sigmas * randn_tensor(latent_chunk.shape, generator=generator, device=device) + + (1 - frame_sigmas) * latent_chunk + ) + noisy_latents_chunks.append(noisy_chunk) + video_latents = torch.cat(noisy_latents_chunks, dim=2) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + window_num_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 + num_latent_chunk = max(1, (num_frames + window_num_frames - 1) // window_num_frames) + num_history_latent_frames = sum(history_sizes) + history_video = None + total_generated_latent_frames = 0 + + if not keep_first_frame: + history_sizes[-1] = history_sizes[-1] + 1 + history_latents = torch.zeros( + batch_size, + num_channels_latents, + num_history_latent_frames, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + device=device, + dtype=torch.float32, + ) + if fake_image_latents is not None: + history_latents = torch.cat([history_latents[:, :, :-1, :, :], fake_image_latents], dim=2) + total_generated_latent_frames += 1 + if video_latents is not None: + history_frames = history_latents.shape[2] + video_frames = video_latents.shape[2] + if video_frames < history_frames: + keep_frames = history_frames - video_frames + history_latents = torch.cat([history_latents[:, :, :keep_frames, :, :], video_latents], dim=2) + else: + history_latents = video_latents + total_generated_latent_frames += 
video_latents.shape[2] + + if keep_first_frame: + indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk])) + ( + indices_prefix, + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_1x, + indices_hidden_states, + ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0) + indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) + else: + indices = torch.arange(0, sum([*history_sizes, num_latent_frames_per_chunk])) + ( + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_short, + indices_hidden_states, + ) = indices.split([*history_sizes, num_latent_frames_per_chunk], dim=0) + indices_hidden_states = indices_hidden_states.unsqueeze(0) + indices_latents_history_short = indices_latents_history_short.unsqueeze(0) + indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0) + indices_latents_history_long = indices_latents_history_long.unsqueeze(0) + + # 6. 
Denoising loop + patch_size = self.transformer.config.patch_size + image_seq_len = ( + num_latent_frames_per_chunk + * (height // self.vae_scale_factor_spatial) + * (width // self.vae_scale_factor_spatial) + // (patch_size[0] * patch_size[1] * patch_size[2]) + ) + sigmas = np.linspace(0.999, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + + for k in range(num_latent_chunk): + is_first_chunk = k == 0 + is_second_chunk = k == 1 + if keep_first_frame: + latents_history_long, latents_history_mid, latents_history_1x = history_latents[ + :, :, -num_history_latent_frames: + ].split(history_sizes, dim=2) + if image_latents is None and is_first_chunk: + latents_prefix = torch.zeros( + ( + batch_size, + num_channels_latents, + 1, + latents_history_1x.shape[-2], + latents_history_1x.shape[-1], + ), + device=device, + dtype=latents_history_1x.dtype, + ) + else: + latents_prefix = image_latents + latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) + else: + latents_history_long, latents_history_mid, latents_history_short = history_latents[ + :, :, -num_history_latent_frames: + ].split(history_sizes, dim=2) + + latents = self.prepare_latents( + batch_size, + num_channels_latents, + height, + width, + window_num_frames, + dtype=torch.float32, + device=device, + generator=generator, + latents=None, + ) + + self.scheduler.set_timesteps(num_inference_steps, device=device, sigmas=sigmas, mu=mu) + timesteps = self.scheduler.timesteps + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + 
self._current_timestep = t + timestep = t.expand(latents.shape[0]) + + latent_model_input = latents.to(transformer_dtype) + latents_history_short = latents_history_short.to(transformer_dtype) + latents_history_mid = latents_history_mid.to(transformer_dtype) + latents_history_long = latents_history_long.to(transformer_dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + indices_hidden_states=indices_hidden_states, + indices_latents_history_short=indices_latents_history_short, + indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=latents_history_short, + latents_history_mid=latents_history_mid, + latents_history_long=latents_history_long, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + indices_hidden_states=indices_hidden_states, + indices_latents_history_short=indices_latents_history_short, + indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=latents_history_short, + latents_history_mid=latents_history_mid, + latents_history_long=latents_history_long, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + return_dict=False, + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 
+ + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if keep_first_frame and ( + (is_first_chunk and image_latents is None) or (is_skip_first_chunk and is_second_chunk) + ): + image_latents = latents[:, :, 0:1, :, :] + + total_generated_latent_frames += latents.shape[2] + history_latents = torch.cat([history_latents, latents], dim=2) + real_history_latents = history_latents[:, :, -total_generated_latent_frames:] + current_latents = ( + real_history_latents[:, :, -num_latent_frames_per_chunk:].to(vae_dtype) / latents_std + + latents_mean + ) + current_video = self.vae.decode(current_latents, return_dict=False)[0] + + if history_video is None: + history_video = current_video + else: + history_video = torch.cat([history_video, current_video], dim=2) + + self._current_timestep = None + + if output_type != "latent": + generated_frames = history_video.size(2) + generated_frames = ( + generated_frames - 1 + ) // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + history_video = history_video[:, :, :generated_frames] + video = self.video_processor.postprocess_video(history_video, output_type=output_type) + else: + video = real_history_latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HeliosPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py new file mode 100644 index 000000000000..40c1d65825ff --- /dev/null +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -0,0 +1,1065 @@ +# Copyright 2025 The Helios Team and The HuggingFace 
Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import math +from typing import Any, Callable + +import regex as re +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import HeliosLoraLoaderMixin +from ...models import AutoencoderKLWan, HeliosTransformer3DModel +from ...schedulers import HeliosDMDScheduler, HeliosScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HeliosPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers.utils import export_to_video + >>> from diffusers import AutoencoderKLWan, HeliosPyramidPipeline + + >>> # Available models: BestWishYsh/Helios-Base, BestWishYsh/Helios-Mid, BestWishYsh/Helios-Distilled + >>> model_id = "BestWishYsh/Helios-Base" + >>> vae = AutoencoderKLWan.from_pretrained(model_id, 
subfolder="vae", torch_dtype=torch.float32) + >>> pipe = HeliosPyramidPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=384, + ... width=640, + ... num_frames=132, + ... guidance_scale=5.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=24) + ``` +""" + + +def optimized_scale(positive_flat, negative_flat): + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8 + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + return st_star + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - 
base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +class HeliosPyramidPipeline(DiffusionPipeline, HeliosLoraLoaderMixin): + r""" + Pipeline for text-to-video / image-to-video / video-to-video generation using Helios. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`HeliosTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`HeliosScheduler`, `HeliosDMDScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _optional_components = ["transformer"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: HeliosScheduler | HeliosDMDScheduler, + transformer: HeliosTransformer3DModel, + is_cfg_zero_star: bool = False, + is_distilled: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.register_to_config(is_cfg_zero_star=is_cfg_zero_star) + self.register_to_config(is_distilled=is_distilled) + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.helios.pipeline_helios.HeliosPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), 
mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, text_inputs.attention_mask.bool() + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. 
+ negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, _ = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds, _ = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + image=None, + video=None, + guidance_scale=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if image is not None and video is not None: + raise ValueError("image and video cannot be provided simultaneously") + + if guidance_scale > 1.0 and self.config.is_distilled: + logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.") + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 384, + width: int = 640, + num_frames: int = 33, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_image_latents( + self, + image: torch.Tensor, + latents_mean: torch.Tensor, + latents_std: torch.Tensor, + num_latent_frames_per_chunk: int, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + fake_latents: torch.Tensor | None = None, + ) -> torch.Tensor: + device = device or self._execution_device + if latents is None: + image = image.unsqueeze(2).to(device=device, dtype=self.vae.dtype) + latents = self.vae.encode(image).latent_dist.sample(generator=generator) + latents = (latents - latents_mean) * latents_std + if fake_latents is None: + min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 + fake_video = image.repeat(1, 1, min_frames, 1, 1).to(device=device, dtype=self.vae.dtype) + fake_latents_full = self.vae.encode(fake_video).latent_dist.sample(generator=generator) + fake_latents_full = (fake_latents_full - latents_mean) * latents_std + fake_latents = fake_latents_full[:, :, -1:, :, :] + return latents.to(device=device, dtype=dtype), fake_latents.to(device=device, dtype=dtype) + + def prepare_video_latents( + self, + video: torch.Tensor, + latents_mean: torch.Tensor, + latents_std: torch.Tensor, + num_latent_frames_per_chunk: int, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + device = device or self._execution_device + video = video.to(device=device, dtype=self.vae.dtype) + if latents is None: + num_frames = video.shape[2] + min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 + num_chunks = num_frames // min_frames + if num_chunks == 0: + raise ValueError( + f"Video must have at least {min_frames} 
frames " + f"(got {num_frames} frames). " + f"Required: (num_latent_frames_per_chunk - 1) * {self.vae_scale_factor_temporal} + 1 = ({num_latent_frames_per_chunk} - 1) * {self.vae_scale_factor_temporal} + 1 = {min_frames}" + ) + total_valid_frames = num_chunks * min_frames + start_frame = num_frames - total_valid_frames + + first_frame = video[:, :, 0:1, :, :] + first_frame_latent = self.vae.encode(first_frame).latent_dist.sample(generator=generator) + first_frame_latent = (first_frame_latent - latents_mean) * latents_std + + latents_chunks = [] + for i in range(num_chunks): + chunk_start = start_frame + i * min_frames + chunk_end = chunk_start + min_frames + video_chunk = video[:, :, chunk_start:chunk_end, :, :] + chunk_latents = self.vae.encode(video_chunk).latent_dist.sample(generator=generator) + chunk_latents = (chunk_latents - latents_mean) * latents_std + latents_chunks.append(chunk_latents) + latents = torch.cat(latents_chunks, dim=2) + return first_frame_latent.to(device=device, dtype=dtype), latents.to(device=device, dtype=dtype) + + def sample_block_noise( + self, + batch_size, + channel, + num_frames, + height, + width, + patch_size: tuple[int, ...] 
= (1, 2, 2), + device: torch.device | None = None, + ): + gamma = self.scheduler.config.gamma + _, ph, pw = patch_size + block_size = ph * pw + + cov = ( + torch.eye(block_size, device=device) * (1 + gamma) + - torch.ones(block_size, block_size, device=device) * gamma + ) + cov += torch.eye(block_size, device=device) * 1e-6 + dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=device), covariance_matrix=cov) + block_number = batch_size * channel * num_frames * (height // ph) * (width // pw) + + noise = dist.sample((block_number,)) # [block number, block_size] + noise = noise.view(batch_size, channel, num_frames, height // ph, width // pw, ph, pw) + noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(batch_size, channel, num_frames, height, width) + return noise + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 384, + width: int = 640, + num_frames: int = 132, + sigmas: list[float] = None, + guidance_scale: float = 5.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | 
None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + # ------------ I2V ------------ + image: PipelineImageInput | None = None, + image_latents: torch.Tensor | None = None, + fake_image_latents: torch.Tensor | None = None, + add_noise_to_image_latents: bool = True, + image_noise_sigma_min: float = 0.111, + image_noise_sigma_max: float = 0.135, + # ------------ V2V ------------ + video: PipelineImageInput | None = None, + video_latents: torch.Tensor | None = None, + add_noise_to_video_latents: bool = True, + video_noise_sigma_min: float = 0.111, + video_noise_sigma_max: float = 0.135, + # ------------ Stage 1 ------------ + history_sizes: list = [16, 2, 1], + num_latent_frames_per_chunk: int = 9, + keep_first_frame: bool = True, + is_skip_first_chunk: bool = False, + # ------------ Stage 2 ------------ + pyramid_num_inference_steps_list: list = [10, 10, 10], + # ------------ CFG Zero ------------ + use_zero_init: bool | None = True, + zero_steps: int | None = 1, + # ------------ DMD ------------ + is_amplify_first_chunk: bool = False, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `384`): + The height in pixels of the generated image. + width (`int`, defaults to `640`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `132`): + The number of frames in the generated video. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). 
`guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HeliosPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. 
with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~HeliosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HeliosPipelineOutput`] is returned, otherwise a `tuple` is returned whose + only element is the generated video: the post-processed frames, or the raw latents when + `output_type="latent"`. + """ + + history_sizes = sorted(history_sizes, reverse=True) # From big to small + pyramid_num_stages = len(pyramid_num_inference_steps_list) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs.
Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + image, + video, + guidance_scale, + ) + + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + vae_dtype = self.vae.dtype + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(device, self.vae.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, self.vae.dtype + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. 
Prepare image or video + if image is not None: + image = self.video_processor.preprocess(image, height=height, width=width) + image_latents, fake_image_latents = self.prepare_image_latents( + image, + latents_mean=latents_mean, + latents_std=latents_std, + num_latent_frames_per_chunk=num_latent_frames_per_chunk, + dtype=torch.float32, + device=device, + generator=generator, + latents=image_latents, + fake_latents=fake_image_latents, + ) + + if image_latents is not None and add_noise_to_image_latents: + image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min) + + image_noise_sigma_min + ) + image_latents = ( + image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device) + + (1 - image_noise_sigma) * image_latents + ) + fake_image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (video_noise_sigma_max - video_noise_sigma_min) + + video_noise_sigma_min + ) + fake_image_latents = ( + fake_image_noise_sigma * randn_tensor(fake_image_latents.shape, generator=generator, device=device) + + (1 - fake_image_noise_sigma) * fake_image_latents + ) + + if video is not None: + video = self.video_processor.preprocess_video(video, height=height, width=width) + image_latents, video_latents = self.prepare_video_latents( + video, + latents_mean=latents_mean, + latents_std=latents_std, + num_latent_frames_per_chunk=num_latent_frames_per_chunk, + dtype=torch.float32, + device=device, + generator=generator, + latents=video_latents, + ) + + if video_latents is not None and add_noise_to_video_latents: + image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min) + + image_noise_sigma_min + ) + image_latents = ( + image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device) + + (1 - image_noise_sigma) * image_latents + ) + + noisy_latents_chunks = [] + num_latent_chunks = 
video_latents.shape[2] // num_latent_frames_per_chunk + for i in range(num_latent_chunks): + chunk_start = i * num_latent_frames_per_chunk + chunk_end = chunk_start + num_latent_frames_per_chunk + latent_chunk = video_latents[:, :, chunk_start:chunk_end, :, :] + + chunk_frames = latent_chunk.shape[2] + frame_sigmas = ( + torch.rand(chunk_frames, device=device, generator=generator) + * (video_noise_sigma_max - video_noise_sigma_min) + + video_noise_sigma_min + ) + frame_sigmas = frame_sigmas.view(1, 1, chunk_frames, 1, 1) + + noisy_chunk = ( + frame_sigmas * randn_tensor(latent_chunk.shape, generator=generator, device=device) + + (1 - frame_sigmas) * latent_chunk + ) + noisy_latents_chunks.append(noisy_chunk) + video_latents = torch.cat(noisy_latents_chunks, dim=2) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + window_num_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 + num_latent_chunk = max(1, (num_frames + window_num_frames - 1) // window_num_frames) + num_history_latent_frames = sum(history_sizes) + history_video = None + total_generated_latent_frames = 0 + + if not keep_first_frame: + history_sizes[-1] = history_sizes[-1] + 1 + history_latents = torch.zeros( + batch_size, + num_channels_latents, + num_history_latent_frames, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + device=device, + dtype=torch.float32, + ) + if fake_image_latents is not None: + history_latents = torch.cat([history_latents[:, :, :-1, :, :], fake_image_latents], dim=2) + total_generated_latent_frames += 1 + if video_latents is not None: + history_frames = history_latents.shape[2] + video_frames = video_latents.shape[2] + if video_frames < history_frames: + keep_frames = history_frames - video_frames + history_latents = torch.cat([history_latents[:, :, :keep_frames, :, :], video_latents], dim=2) + else: + history_latents = video_latents + total_generated_latent_frames += 
video_latents.shape[2] + + if keep_first_frame: + indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk])) + ( + indices_prefix, + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_1x, + indices_hidden_states, + ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0) + indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) + else: + indices = torch.arange(0, sum([*history_sizes, num_latent_frames_per_chunk])) + ( + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_short, + indices_hidden_states, + ) = indices.split([*history_sizes, num_latent_frames_per_chunk], dim=0) + indices_hidden_states = indices_hidden_states.unsqueeze(0) + indices_latents_history_short = indices_latents_history_short.unsqueeze(0) + indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0) + indices_latents_history_long = indices_latents_history_long.unsqueeze(0) + + # 6. 
Denoising loop + for k in range(num_latent_chunk): + is_first_chunk = k == 0 + is_second_chunk = k == 1 + if keep_first_frame: + latents_history_long, latents_history_mid, latents_history_1x = history_latents[ + :, :, -num_history_latent_frames: + ].split(history_sizes, dim=2) + if image_latents is None and is_first_chunk: + latents_prefix = torch.zeros( + ( + batch_size, + num_channels_latents, + 1, + latents_history_1x.shape[-2], + latents_history_1x.shape[-1], + ), + device=device, + dtype=latents_history_1x.dtype, + ) + else: + latents_prefix = image_latents + latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) + else: + latents_history_long, latents_history_mid, latents_history_short = history_latents[ + :, :, -num_history_latent_frames: + ].split(history_sizes, dim=2) + + latents = self.prepare_latents( + batch_size, + num_channels_latents, + height, + width, + window_num_frames, + dtype=torch.float32, + device=device, + generator=generator, + latents=None, + ) + + num_inference_steps = ( + sum(pyramid_num_inference_steps_list) * 2 + if is_amplify_first_chunk and self.config.is_distilled and is_first_chunk + else sum(pyramid_num_inference_steps_list) + ) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + _, _, _, pyramid_height, pyramid_width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_latent_frames_per_chunk, num_channels_latents, pyramid_height, pyramid_width + ) + for _ in range(pyramid_num_stages - 1): + pyramid_height //= 2 + pyramid_width //= 2 + latents = ( + F.interpolate( + latents, + size=(pyramid_height, pyramid_width), + mode="bilinear", + ) + * 2 + ) + latents = latents.reshape( + batch_size, num_latent_frames_per_chunk, num_channels_latents, pyramid_height, pyramid_width + ).permute(0, 2, 1, 3, 4) + + start_point_list = None + if self.config.is_distilled: + start_point_list = [latents] + + for stage_idx in range(pyramid_num_stages): + patch_size = 
self.transformer.config.patch_size + image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // ( + patch_size[0] * patch_size[1] * patch_size[2] + ) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.set_timesteps( + pyramid_num_inference_steps_list[stage_idx], + stage_idx, + device=device, + mu=mu, + is_amplify_first_chunk=is_amplify_first_chunk and is_first_chunk, + ) + timesteps = self.scheduler.timesteps + num_warmup_steps = 0 + self._num_timesteps = len(timesteps) + + if stage_idx > 0: + pyramid_height *= 2 + pyramid_width *= 2 + num_frames = latents.shape[2] + latents = latents.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_latent_frames_per_chunk, + num_channels_latents, + pyramid_height // 2, + pyramid_width // 2, + ) + latents = F.interpolate(latents, size=(pyramid_height, pyramid_width), mode="nearest") + latents = latents.reshape( + batch_size, + num_latent_frames_per_chunk, + num_channels_latents, + pyramid_height, + pyramid_width, + ).permute(0, 2, 1, 3, 4) + # Fix the stage + ori_sigma = 1 - self.scheduler.ori_start_sigmas[stage_idx] # the original coeff of signal + gamma = self.scheduler.config.gamma + alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma) + beta = alpha * (1 - ori_sigma) / math.sqrt(gamma) + + batch_size, channel, num_frames, pyramid_height, pyramid_width = latents.shape + noise = self.sample_block_noise( + batch_size, channel, num_frames, pyramid_height, pyramid_width, patch_size, device + ) + noise = noise.to(device=device, dtype=transformer_dtype) + latents = alpha * latents + beta * noise # To fix the block artifact + + if self.config.is_distilled: + start_point_list.append(latents) + + for i, t in enumerate(timesteps): + timestep = 
t.expand(latents.shape[0]).to(torch.int64) + + latent_model_input = latents.to(transformer_dtype) + latents_history_short = latents_history_short.to(transformer_dtype) + latents_history_mid = latents_history_mid.to(transformer_dtype) + latents_history_long = latents_history_long.to(transformer_dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + indices_hidden_states=indices_hidden_states, + indices_latents_history_short=indices_latents_history_short, + indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=latents_history_short, + latents_history_mid=latents_history_mid, + latents_history_long=latents_history_long, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + indices_hidden_states=indices_hidden_states, + indices_latents_history_short=indices_latents_history_short, + indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=latents_history_short, + latents_history_mid=latents_history_mid, + latents_history_long=latents_history_long, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.config.is_cfg_zero_star: + noise_pred_text = noise_pred + positive_flat = noise_pred_text.view(batch_size, -1) + negative_flat = noise_uncond.view(batch_size, -1) + + alpha = optimized_scale(positive_flat, negative_flat) + alpha = alpha.view(batch_size, *([1] * (len(noise_pred_text.shape) - 1))) + alpha = alpha.to(noise_pred_text.dtype) + + if (stage_idx == 0 and i <= zero_steps) and use_zero_init: + noise_pred = 
noise_pred_text * 0.0 + else: + noise_pred = noise_uncond * alpha + guidance_scale * ( + noise_pred_text - noise_uncond * alpha + ) + else: + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + extra_kwargs = ( + { + "cur_sampling_step": i, + "dmd_noisy_tensor": start_point_list[stage_idx] + if start_point_list is not None + else None, + "dmd_sigmas": self.scheduler.sigmas, + "dmd_timesteps": self.scheduler.timesteps, + "all_timesteps": timesteps, + } + if self.config.is_distilled + else {} + ) + + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + return_dict=False, + **extra_kwargs, + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if keep_first_frame and ( + (is_first_chunk and image_latents is None) or (is_skip_first_chunk and is_second_chunk) + ): + image_latents = latents[:, :, 0:1, :, :] + + total_generated_latent_frames += latents.shape[2] + history_latents = torch.cat([history_latents, latents], dim=2) + real_history_latents = history_latents[:, :, -total_generated_latent_frames:] + current_latents = ( + real_history_latents[:, :, -num_latent_frames_per_chunk:].to(vae_dtype) / latents_std + + latents_mean + ) + current_video = self.vae.decode(current_latents, return_dict=False)[0] + + if history_video is None: + history_video = current_video + else: + history_video = torch.cat([history_video, current_video], dim=2) + + 
self._current_timestep = None + + if output_type != "latent": + generated_frames = history_video.size(2) + generated_frames = ( + generated_frames - 1 + ) // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + history_video = history_video[:, :, :generated_frames] + video = self.video_processor.postprocess_video(history_video, output_type=output_type) + else: + video = real_history_latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HeliosPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/helios/pipeline_output.py b/src/diffusers/pipelines/helios/pipeline_output.py new file mode 100644 index 000000000000..08546289ef4c --- /dev/null +++ b/src/diffusers/pipelines/helios/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class HeliosPipelineOutput(BaseOutput): + r""" + Output class for Helios pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. 
+ """ + + frames: torch.Tensor diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 4199e75bf331..c7101d1b0401 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -61,6 +61,8 @@ _import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"] _import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"] _import_structure["scheduling_flow_match_lcm"] = ["FlowMatchLCMScheduler"] + _import_structure["scheduling_helios"] = ["HeliosScheduler"] + _import_structure["scheduling_helios_dmd"] = ["HeliosDMDScheduler"] _import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"] _import_structure["scheduling_ipndm"] = ["IPNDMScheduler"] _import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"] @@ -164,6 +166,8 @@ from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler from .scheduling_flow_match_lcm import FlowMatchLCMScheduler + from .scheduling_helios import HeliosScheduler + from .scheduling_helios_dmd import HeliosDMDScheduler from .scheduling_heun_discrete import HeunDiscreteScheduler from .scheduling_ipndm import IPNDMScheduler from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler diff --git a/src/diffusers/schedulers/scheduling_helios.py b/src/diffusers/schedulers/scheduling_helios.py new file mode 100644 index 000000000000..ed35245c9db3 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_helios.py @@ -0,0 +1,867 @@ +# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Literal + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..schedulers.scheduling_utils import SchedulerMixin +from ..utils import BaseOutput, deprecate + + +@dataclass +class HeliosSchedulerOutput(BaseOutput): + prev_sample: torch.FloatTensor + model_outputs: torch.FloatTensor | None = None + last_sample: torch.FloatTensor | None = None + this_order: int | None = None + + +class HeliosScheduler(SchedulerMixin, ConfigMixin): + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, # Following Stable diffusion 3, + stages: int = 3, + stage_range: list = [0, 1 / 3, 2 / 3, 1], + gamma: float = 1 / 3, + # For UniPC + thresholding: bool = False, + prediction_type: str = "flow_prediction", + solver_order: int = 2, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: list[int] = [], + solver_p: SchedulerMixin = None, + use_flow_sigmas: bool = True, + scheduler_type: str = "unipc", # ["euler", "unipc"] + use_dynamic_shifting: bool = False, + time_shift_type: Literal["exponential", "linear"] = "exponential", + ): + self.timestep_ratios = {} # The timestep ratio for each stage + self.timesteps_per_stage = {} # The detailed timesteps per stage (fix max and min per stage) + self.sigmas_per_stage = {} # always uniform [1000, 0] + self.start_sigmas = {} # for start point / upsample renoise + self.end_sigmas 
= {} # for end point + self.ori_start_sigmas = {} + + # self.init_sigmas() + self.init_sigmas_for_each_stage() + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + self.gamma = gamma + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + def init_sigmas(self): + """ + initialize the global timesteps and sigmas + """ + num_train_timesteps = self.config.num_train_timesteps + shift = self.config.shift + + alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps + 1) + sigmas = 1.0 - alphas + sigmas = np.flip(shift * sigmas / (1 + (shift - 1) * sigmas))[:-1].copy() + sigmas = torch.from_numpy(sigmas) + timesteps = (sigmas * num_train_timesteps).clone() + + self._step_index = None + self._begin_index = None + self.timesteps = timesteps + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + + def init_sigmas_for_each_stage(self): + """ + Init the timesteps for each stage + """ + self.init_sigmas() + + stage_distance = [] + stages = self.config.stages + training_steps = self.config.num_train_timesteps + stage_range = self.config.stage_range + + # Init the start and end point of each stage + for i_s in range(stages): + # To decide the start and ends point + start_indice = int(stage_range[i_s] * training_steps) + start_indice = max(start_indice, 0) + end_indice = int(stage_range[i_s + 1] * training_steps) + end_indice = min(end_indice, training_steps) + start_sigma = self.sigmas[start_indice].item() + end_sigma = 
self.sigmas[end_indice].item() if end_indice < training_steps else 0.0 + self.ori_start_sigmas[i_s] = start_sigma + + if i_s != 0: + ori_sigma = 1 - start_sigma + gamma = self.config.gamma + corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma + # corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma + start_sigma = 1 - corrected_sigma + + stage_distance.append(start_sigma - end_sigma) + self.start_sigmas[i_s] = start_sigma + self.end_sigmas[i_s] = end_sigma + + # Determine the ratio of each stage according to flow length + tot_distance = sum(stage_distance) + for i_s in range(stages): + if i_s == 0: + start_ratio = 0.0 + else: + start_ratio = sum(stage_distance[:i_s]) / tot_distance + if i_s == stages - 1: + end_ratio = 0.9999999999999999 + else: + end_ratio = sum(stage_distance[: i_s + 1]) / tot_distance + + self.timestep_ratios[i_s] = (start_ratio, end_ratio) + + # Determine the timesteps and sigmas for each stage + for i_s in range(stages): + timestep_ratio = self.timestep_ratios[i_s] + # timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)] + timestep_max = min(self.timesteps[int(timestep_ratio[0] * training_steps)], 999) + timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)] + timesteps = np.linspace(timestep_max, timestep_min, training_steps + 1) + self.timesteps_per_stage[i_s] = ( + timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1]) + ) + stage_sigmas = np.linspace(0.999, 0, training_steps + 1) + self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1]) + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. 
+ """ + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def set_timesteps( + self, + num_inference_steps: int, + stage_index: int | None = None, + device: str | torch.device = None, + sigmas: bool | None = None, + mu: bool | None = None, + is_amplify_first_chunk: bool = False, + ): + """ + Setting the timesteps and sigmas for each stage + """ + if self.config.scheduler_type == "dmd": + if is_amplify_first_chunk: + num_inference_steps = num_inference_steps * 2 + 1 + else: + num_inference_steps = num_inference_steps + 1 + + self.num_inference_steps = num_inference_steps + self.init_sigmas() + + if self.config.stages == 1: + if sigmas is None: + sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1].astype( + np.float32 + ) + if self.config.shift != 1.0: + assert not self.config.use_dynamic_shifting + sigmas = self.time_shift(self.config.shift, 1.0, sigmas) + timesteps = (sigmas * self.config.num_train_timesteps).copy() + sigmas = torch.from_numpy(sigmas) + else: + stage_timesteps = self.timesteps_per_stage[stage_index] + timesteps = np.linspace( + stage_timesteps[0].item(), + stage_timesteps[-1].item(), + num_inference_steps, + ) + + stage_sigmas = self.sigmas_per_stage[stage_index] + ratios = np.linspace(stage_sigmas[0].item(), stage_sigmas[-1].item(), num_inference_steps) + sigmas = torch.from_numpy(ratios) + + self.timesteps = torch.from_numpy(timesteps).to(device=device) + self.sigmas = torch.cat([sigmas, torch.zeros(1)]).to(device=device) + + self._step_index = None + self.reset_scheduler_history() + + if self.config.scheduler_type == "dmd": + self.timesteps = self.timesteps[:-1] + self.sigmas 
= torch.cat([self.sigmas[:-2], self.sigmas[-1:]]) + + if self.config.use_dynamic_shifting: + assert self.config.shift == 1.0 + self.sigmas = self.time_shift(mu, 1.0, self.sigmas) + if self.config.stages == 1: + self.timesteps = self.sigmas[:-1] * self.config.num_train_timesteps + else: + self.timesteps = self.timesteps_per_stage[stage_index].min() + self.sigmas[:-1] * ( + self.timesteps_per_stage[stage_index].max() - self.timesteps_per_stage[stage_index].min() + ) + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + """ + Apply time shifting to the sigmas. + + Args: + mu (`float`): + The mu parameter for the time shift. + sigma (`float`): + The sigma parameter for the time shift. + t (`torch.Tensor`): + The input timesteps. + + Returns: + `torch.Tensor`: + The time-shifted timesteps. + """ + if self.config.time_shift_type == "exponential": + return self._time_shift_exponential(mu, sigma, t) + elif self.config.time_shift_type == "linear": + return self._time_shift_linear(mu, sigma, t) + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential + def _time_shift_exponential(self, mu, sigma, t): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear + def _time_shift_linear(self, mu, sigma, t): + return mu / (mu + (1 / t - 1) ** sigma) + + # ---------------------------------- Euler ---------------------------------- + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is 
only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step_euler( + self, + model_output: torch.FloatTensor, + timestep: float | torch.FloatTensor = None, + sample: torch.FloatTensor = None, + generator: torch.Generator | None = None, + sigma: torch.FloatTensor | None = None, + sigma_next: torch.FloatTensor | None = None, + return_dict: bool = True, + ) -> HeliosSchedulerOutput | tuple: + assert (sigma is None) == (sigma_next is None), "sigma and sigma_next must both be None or both be not None" + + if sigma is None and sigma_next is None: + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." 
+ ), + ) + + if self.step_index is None: + self._step_index = 0 + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + if sigma is None and sigma_next is None: + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + + prev_sample = sample + (sigma_next - sigma) * model_output + + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return HeliosSchedulerOutput(prev_sample=prev_sample) + + # ---------------------------------- UniPC ---------------------------------- + def _sigma_to_alpha_sigma_t(self, sigma): + if self.config.use_flow_sigmas: + alpha_t = 1 - sigma + sigma_t = torch.clamp(sigma, min=1e-8) + else: + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + sigma: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + r""" + Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. 
+ """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyword argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + flag = False + if sigma is None: + flag = True + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "epsilon": + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + x0_pred = alpha_t * sample - sigma_t * model_output + elif self.config.prediction_type == "flow_prediction": + if flag: + sigma_t = self.sigmas[self.step_index] + else: + sigma_t = sigma + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " + "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "epsilon": + return model_output + elif self.config.prediction_type == "sample": + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." 
def multistep_uni_p_bh_update(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    order: int = None,
    sigma: torch.Tensor = None,
    sigma_next: torch.Tensor = None,
    **kwargs,
) -> torch.Tensor:
    """
    One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model at the current timestep. It is only forwarded to
            `self.solver_p`; the B(h) update itself reads the stored history in `self.model_outputs`.
        prev_timestep (`int`, *deprecated*):
            Accepted as the first positional argument or as a keyword for backward compatibility. Passing it only
            triggers a deprecation notice; the value is not used.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process. May also be given as the second
            positional argument.
        order (`int`):
            The order of UniP at this timestep (corresponds to the *p* in UniPC-p). May also be given as the third
            positional argument.
        sigma (`torch.Tensor`, *optional*):
            Noise level of the current step. When both `sigma` and `sigma_next` are `None`,
            `self.sigmas[self.step_index]` is used.
        sigma_next (`torch.Tensor`, *optional*):
            Noise level of the step being predicted. When both `sigma` and `sigma_next` are `None`,
            `self.sigmas[self.step_index + 1]` is used.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.

    Raises:
        `ValueError`: If `sample` or `order` is supplied neither by keyword nor positionally.
        `NotImplementedError`: If `config.solver_type` is neither `"bh1"` nor `"bh2"`.
    """
    # Legacy calling convention: (model_output, prev_timestep, sample, order).
    prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 1:
            sample = args[1]
        else:
            raise ValueError("missing `sample` as a required keyword argument")
    if order is None:
        if len(args) > 2:
            order = args[2]
        else:
            raise ValueError("missing `order` as a required keyword argument")
    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )
    model_output_list = self.model_outputs

    s0 = self.timestep_list[-1]
    m0 = model_output_list[-1]
    x = sample

    # A user-supplied predictor replaces the built-in B(h) update entirely.
    if self.solver_p:
        x_t = self.solver_p.step(model_output, s0, x).prev_sample
        return x_t

    # NOTE(review): only the "both None" case falls back to `self.sigmas`; passing exactly one of
    # `sigma` / `sigma_next` forwards a `None` into `_sigma_to_alpha_sigma_t` — confirm callers never do that.
    if sigma_next is None and sigma is None:
        sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
    else:
        sigma_t, sigma_s0 = sigma_next, sigma
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

    # lambda = log(alpha) - log(sigma); `h` is the step size measured in lambda.
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

    h = lambda_t - lambda_s0
    device = sample.device

    rks = []
    D1s = []
    # Differences against older model outputs, each scaled by its relative lambda distance `rk`.
    # NOTE(review): the history sigmas are always read from `self.sigmas`, even when `sigma` / `sigma_next`
    # were passed explicitly — confirm this is intended for order > 1.
    for i in range(1, order):
        si = self.step_index - i
        mi = model_output_list[-(i + 1)]
        alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
        lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
        rk = (lambda_si - lambda_s0) / h
        rks.append(rk)
        D1s.append((mi - m0) / rk)

    rks.append(1.0)
    rks = torch.tensor(rks, device=device)

    R = []
    b = []

    # The sign of the step flips when the solver works on x0 predictions.
    hh = -h if self.predict_x0 else h
    h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
    h_phi_k = h_phi_1 / hh - 1

    factorial_i = 1

    if self.config.solver_type == "bh1":
        B_h = hh
    elif self.config.solver_type == "bh2":
        B_h = torch.expm1(hh)
    else:
        raise NotImplementedError()

    # Build the linear system R @ rhos = b: row i holds rks**(i - 1).
    for i in range(1, order + 1):
        R.append(torch.pow(rks, i - 1))
        b.append(h_phi_k * factorial_i / B_h)
        factorial_i *= i + 1
        h_phi_k = h_phi_k / hh - 1 / factorial_i

    R = torch.stack(R)
    b = torch.tensor(b, device=device)

    # `rhos_p` is only defined when there is at least one history difference (order >= 2).
    if len(D1s) > 0:
        D1s = torch.stack(D1s, dim=1)  # (B, K)
        # for order 2, we use a simplified version
        if order == 2:
            rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
    else:
        D1s = None

    if self.predict_x0:
        x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
        if D1s is not None:
            pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
        else:
            pred_res = 0
        x_t = x_t_ - alpha_t * B_h * pred_res
    else:
        x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
        if D1s is not None:
            pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
        else:
            pred_res = 0
        x_t = x_t_ - sigma_t * B_h * pred_res

    # Return in the dtype of the incoming sample.
    x_t = x_t.to(x.dtype)
    return x_t
= None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.Tensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. + + Returns: + `torch.Tensor`: + The corrected sample tensor at the current timestep. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError("missing `last_sample` as a required keyword argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError("missing `this_sample` as a required keyword argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError("missing `order` as a required keyword argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + if sigma_before is None and sigma is None: + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] + else: + sigma_t, sigma_s0 = sigma, sigma_before + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in 
range(1, order): + si = self.step_index - (i + 1) + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def step_unipc( + self, + model_output: torch.Tensor, + timestep: int | torch.Tensor = None, + sample: torch.Tensor = None, + return_dict: bool = True, + model_outputs: list = None, + timestep_list: list = None, + sigma_before: torch.Tensor = None, + sigma: 
torch.Tensor = None, + sigma_next: torch.Tensor = None, + cus_step_index: int = None, + cus_lower_order_num: int = None, + cus_this_order: int = None, + cus_last_sample: torch.Tensor = None, + ) -> HeliosSchedulerOutput | tuple: + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if cus_step_index is None: + if self.step_index is None: + self._step_index = 0 + else: + self._step_index = cus_step_index + + if cus_lower_order_num is not None: + self.lower_order_nums = cus_lower_order_num + + if cus_this_order is not None: + self.this_order = cus_this_order + + if cus_last_sample is not None: + self.last_sample = cus_last_sample + + use_corrector = ( + self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None + ) + + # Convert model output using the proper conversion method + model_output_convert = self.convert_model_output(model_output, sample=sample, sigma=sigma) + + if model_outputs is not None and timestep_list is not None: + self.model_outputs = model_outputs[:-1] + self.timestep_list = timestep_list[:-1] + + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + sigma_before=sigma_before, + sigma=sigma, + ) + + if model_outputs is not None and timestep_list is not None: + model_outputs[-1] = model_output_convert + self.model_outputs = model_outputs[1:] + self.timestep_list = timestep_list[1:] + else: + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) + else: + this_order 
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: float | torch.FloatTensor = None,
    sample: torch.FloatTensor = None,
    generator: torch.Generator | None = None,
    return_dict: bool = True,
) -> HeliosSchedulerOutput | tuple:
    """
    Run one scheduler step with the solver selected by `config.scheduler_type`.

    `"euler"` forwards to `self.step_euler` (including `generator`); `"unipc"` forwards to `self.step_unipc`
    (which is not given `generator`). All arguments are passed by keyword.

    Args:
        model_output (`torch.FloatTensor`):
            The direct output from the learned diffusion model.
        timestep (`float` or `torch.FloatTensor`):
            The current timestep, forwarded unchanged.
        sample (`torch.FloatTensor`):
            The current sample, forwarded unchanged.
        generator (`torch.Generator`, *optional*):
            Forwarded to the Euler solver only.
        return_dict (`bool`, defaults to `True`):
            Forwarded unchanged; the selected solver decides between an output object and a tuple.

    Returns:
        Whatever the selected solver returns.

    Raises:
        `NotImplementedError`: If `config.scheduler_type` is neither `"euler"` nor `"unipc"`.
    """
    scheduler_type = self.config.scheduler_type

    if scheduler_type == "euler":
        return self.step_euler(
            model_output=model_output,
            timestep=timestep,
            sample=sample,
            generator=generator,
            return_dict=return_dict,
        )
    if scheduler_type == "unipc":
        return self.step_unipc(
            model_output=model_output,
            timestep=timestep,
            sample=sample,
            return_dict=return_dict,
        )

    # Name the offending value so a misconfigured scheduler is easy to diagnose.
    raise NotImplementedError(
        f"scheduler_type given as {scheduler_type!r} is not supported; expected one of 'euler' or 'unipc'."
    )
class HeliosDMDScheduler(SchedulerMixin, ConfigMixin):
    """
    Scheduler for Helios DMD sampling.

    It keeps a global table of flow-matching sigmas/timesteps plus per-stage tables, and its `step()` turns a flow
    prediction into an `x0` estimate that is re-noised to the next sampling timestep unless the current step is the
    last one.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            Number of entries in the global sigma/timestep tables.
        shift (`float`, defaults to 1.0):
            Shift applied to the global sigmas as `shift * s / (1 + (shift - 1) * s)`.
        stages (`int`, defaults to 3):
            Number of stages for which separate timestep/sigma tables are built.
        stage_range (`list`, defaults to `[0, 1/3, 2/3, 1]`):
            Stage boundaries as fractions of `num_train_timesteps`; entry `i` and `i + 1` delimit stage `i`.
        gamma (`float`, defaults to 1/3):
            Used to correct the start sigma of every stage after the first.
        prediction_type (`str`, defaults to `"flow_prediction"`):
            Stored in the config; not read by this class.
        use_flow_sigmas (`bool`, defaults to `True`):
            Stored in the config; not read by this class.
        use_dynamic_shifting (`bool`, defaults to `False`):
            Whether `set_timesteps` applies `time_shift` with the caller-provided `mu`.
        time_shift_type (`"exponential"` or `"linear"`, defaults to `"linear"`):
            Which time-shift formula `time_shift` dispatches to.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,  # Following Stable diffusion 3,
        stages: int = 3,
        # NOTE: list default kept so the registered config value is unchanged; it is only read, never mutated.
        stage_range: list = [0, 1 / 3, 2 / 3, 1],
        gamma: float = 1 / 3,
        prediction_type: str = "flow_prediction",
        use_flow_sigmas: bool = True,
        use_dynamic_shifting: bool = False,
        time_shift_type: Literal["exponential", "linear"] = "linear",
    ):
        self.timestep_ratios = {}  # The timestep ratio for each stage
        self.timesteps_per_stage = {}  # The detailed timesteps per stage (fix max and min per stage)
        self.sigmas_per_stage = {}  # always uniform [1000, 0]
        self.start_sigmas = {}  # for start point / upsample renoise
        self.end_sigmas = {}  # for end point
        self.ori_start_sigmas = {}

        # self.init_sigmas()
        self.init_sigmas_for_each_stage()
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
        self.gamma = gamma

        self.last_sample = None
        self._step_index = None
        self._begin_index = None

    def init_sigmas(self):
        """
        Initialize the global timesteps and sigmas.

        Builds `num_train_timesteps` shifted sigmas in descending order, sets `self.timesteps` to
        `sigmas * num_train_timesteps`, and resets the step/begin indices.
        """
        num_train_timesteps = self.config.num_train_timesteps
        shift = self.config.shift

        alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps + 1)
        sigmas = 1.0 - alphas
        # Apply the shift, reverse to descending order and drop the final (zero) sigma.
        sigmas = np.flip(shift * sigmas / (1 + (shift - 1) * sigmas))[:-1].copy()
        sigmas = torch.from_numpy(sigmas)
        timesteps = (sigmas * num_train_timesteps).clone()

        self._step_index = None
        self._begin_index = None
        self.timesteps = timesteps
        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    def init_sigmas_for_each_stage(self):
        """
        Init the timesteps for each stage.

        Fills `ori_start_sigmas`, `start_sigmas`, `end_sigmas`, `timestep_ratios`, `timesteps_per_stage` and
        `sigmas_per_stage`, keyed by stage index.
        """
        self.init_sigmas()

        stage_distance = []
        stages = self.config.stages
        training_steps = self.config.num_train_timesteps
        stage_range = self.config.stage_range

        # Init the start and end point of each stage
        for i_s in range(stages):
            # To decide the start and ends point
            start_indice = int(stage_range[i_s] * training_steps)
            start_indice = max(start_indice, 0)
            end_indice = int(stage_range[i_s + 1] * training_steps)
            end_indice = min(end_indice, training_steps)
            start_sigma = self.sigmas[start_indice].item()
            # The global table has no entry at index `training_steps`; that position means sigma 0.
            end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0
            self.ori_start_sigmas[i_s] = start_sigma

            if i_s != 0:
                # Stages after the first start from a gamma-corrected sigma.
                ori_sigma = 1 - start_sigma
                gamma = self.config.gamma
                corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
                # corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma
                start_sigma = 1 - corrected_sigma

            stage_distance.append(start_sigma - end_sigma)
            self.start_sigmas[i_s] = start_sigma
            self.end_sigmas[i_s] = end_sigma

        # Determine the ratio of each stage according to flow length
        tot_distance = sum(stage_distance)
        for i_s in range(stages):
            if i_s == 0:
                start_ratio = 0.0
            else:
                start_ratio = sum(stage_distance[:i_s]) / tot_distance
            if i_s == stages - 1:
                end_ratio = 0.9999999999999999
            else:
                end_ratio = sum(stage_distance[: i_s + 1]) / tot_distance

            self.timestep_ratios[i_s] = (start_ratio, end_ratio)

        # Determine the timesteps and sigmas for each stage
        for i_s in range(stages):
            timestep_ratio = self.timestep_ratios[i_s]
            # timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)]
            timestep_max = min(self.timesteps[int(timestep_ratio[0] * training_steps)], 999)
            timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)]
            timesteps = np.linspace(timestep_max, timestep_min, training_steps + 1)
            self.timesteps_per_stage[i_s] = (
                timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1])
            )
            stage_sigmas = np.linspace(0.999, 0, training_steps + 1)
            self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def _sigma_to_t(self, sigma):
        """Map a sigma to a timestep by scaling with `config.num_train_timesteps`."""
        return sigma * self.config.num_train_timesteps

    def set_timesteps(
        self,
        num_inference_steps: int,
        stage_index: int | None = None,
        device: str | torch.device = None,
        sigmas: np.ndarray | None = None,
        mu: float | None = None,
        is_amplify_first_chunk: bool = False,
    ):
        """
        Set `self.timesteps` and `self.sigmas` for one sampling run (of one stage when `config.stages > 1`).

        After this call `self.timesteps` has `num_inference_steps` entries (`2 * num_inference_steps` when
        `is_amplify_first_chunk` is set) and `self.sigmas` has one more entry, the last being 0. The step and begin
        indices are reset.

        Args:
            num_inference_steps (`int`):
                Number of sampling steps requested by the caller.
            stage_index (`int`, *optional*):
                Stage whose tables are used. Required when `config.stages != 1`.
            device (`str` or `torch.device`, *optional*):
                Device the resulting tensors are moved to.
            sigmas (`np.ndarray`, *optional*):
                Custom sigmas, only used when `config.stages == 1`.
            mu (`float`, *optional*):
                Shift parameter passed to `time_shift`. Required when `config.use_dynamic_shifting` is enabled.
            is_amplify_first_chunk (`bool`, defaults to `False`):
                Doubles the number of sampling steps.

        Raises:
            `ValueError`: If `stage_index` does not name a known stage in multi-stage mode, or if `mu` is missing
                while dynamic shifting is enabled.
        """
        # Validate first so that a bad call does not leave the scheduler half-updated.
        if self.config.stages != 1 and stage_index not in self.timesteps_per_stage:
            raise ValueError(
                f"`stage_index` must be one of {sorted(self.timesteps_per_stage)} when `stages` is "
                f"{self.config.stages}, but got {stage_index!r}."
            )
        if self.config.use_dynamic_shifting and mu is None:
            raise ValueError("`mu` must be provided when `use_dynamic_shifting` is enabled.")

        # One extra point is generated here and trimmed again after the tables are built.
        if is_amplify_first_chunk:
            num_inference_steps = num_inference_steps * 2 + 1
        else:
            num_inference_steps = num_inference_steps + 1

        self.num_inference_steps = num_inference_steps
        self.init_sigmas()

        if self.config.stages == 1:
            if sigmas is None:
                sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1].astype(
                    np.float32
                )
            if self.config.shift != 1.0:
                assert not self.config.use_dynamic_shifting
                sigmas = self.time_shift(self.config.shift, 1.0, sigmas)
            timesteps = (sigmas * self.config.num_train_timesteps).copy()
            sigmas = torch.from_numpy(sigmas)
        else:
            stage_timesteps = self.timesteps_per_stage[stage_index]
            timesteps = np.linspace(
                stage_timesteps[0].item(),
                stage_timesteps[-1].item(),
                num_inference_steps,
            )

            stage_sigmas = self.sigmas_per_stage[stage_index]
            ratios = np.linspace(stage_sigmas[0].item(), stage_sigmas[-1].item(), num_inference_steps)
            sigmas = torch.from_numpy(ratios)

        self.timesteps = torch.from_numpy(timesteps).to(device=device)
        self.sigmas = torch.cat([sigmas, torch.zeros(1)]).to(device=device)

        self._step_index = None
        self.reset_scheduler_history()

        # Drop the extra point: the last timestep, and the sigma just before the trailing zero.
        self.timesteps = self.timesteps[:-1]
        self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas[-1:]])

        if self.config.use_dynamic_shifting:
            assert self.config.shift == 1.0
            self.sigmas = self.time_shift(mu, 1.0, self.sigmas)
            if self.config.stages == 1:
                self.timesteps = self.sigmas[:-1] * self.config.num_train_timesteps
            else:
                # Rescale the shifted sigmas into the timestep range of this stage.
                self.timesteps = self.timesteps_per_stage[stage_index].min() + self.sigmas[:-1] * (
                    self.timesteps_per_stage[stage_index].max() - self.timesteps_per_stage[stage_index].min()
                )

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        """
        Apply time shifting to the sigmas.

        Args:
            mu (`float`):
                The mu parameter for the time shift.
            sigma (`float`):
                The sigma parameter for the time shift.
            t (`torch.Tensor`):
                The input timesteps.

        Returns:
            `torch.Tensor`:
                The time-shifted timesteps.
        """
        if self.config.time_shift_type == "exponential":
            return self._time_shift_exponential(mu, sigma, t)
        elif self.config.time_shift_type == "linear":
            return self._time_shift_linear(mu, sigma, t)

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
    def _time_shift_exponential(self, mu, sigma, t):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
    def _time_shift_linear(self, mu, sigma, t):
        return mu / (mu + (1 / t - 1) ** sigma)

    # ---------------------------------- For DMD ----------------------------------
    def add_noise(self, original_samples, noise, timestep, sigmas, timesteps):
        """
        Blend `original_samples` with `noise` as `(1 - sigma) * original_samples + sigma * noise`.

        For each entry of `timestep`, `sigma` is the entry of `sigmas` whose `timesteps` value is closest. The
        sigma is reshaped to `(-1, 1, 1, 1, 1)`, so the samples are expected to broadcast against a 5-D shape. The
        result has the dtype of `noise`.
        """
        sigmas = sigmas.to(noise.device)
        timesteps = timesteps.to(noise.device)
        # Nearest-timestep lookup, one index per batch element.
        timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
        sigma = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1)
        sample = (1 - sigma) * original_samples + sigma * noise
        return sample.type_as(noise)

    def convert_flow_pred_to_x0(self, flow_pred, xt, timestep, sigmas, timesteps):
        """
        Convert a flow prediction into an `x0` estimate as `xt - sigma_t * flow_pred`.

        `sigma_t` is looked up by nearest timestep, as in `add_noise`. The computation runs in float64 on the
        device of `flow_pred` and the result is cast back to the original dtype of `flow_pred`.
        """
        # use higher precision for calculations
        original_dtype = flow_pred.dtype
        device = flow_pred.device
        flow_pred, xt, sigmas, timesteps = (x.double().to(device) for x in (flow_pred, xt, sigmas, timesteps))

        timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
        sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1)
        x0_pred = xt - sigma_t * flow_pred
        return x0_pred.to(original_dtype)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: float | torch.FloatTensor = None,
        sample: torch.FloatTensor = None,
        generator: torch.Generator | None = None,
        return_dict: bool = True,
        cur_sampling_step: int = 0,
        dmd_noisy_tensor: torch.FloatTensor | None = None,
        dmd_sigmas: torch.FloatTensor | None = None,
        dmd_timesteps: torch.FloatTensor | None = None,
        all_timesteps: torch.FloatTensor | None = None,
    ) -> HeliosDMDSchedulerOutput | tuple:
        """
        Run one DMD sampling step.

        The flow prediction is converted to an `x0` estimate at `timestep`. Unless `cur_sampling_step` is the last
        index of `all_timesteps`, the estimate is re-noised with `dmd_noisy_tensor` to
        `all_timesteps[cur_sampling_step + 1]`; otherwise the estimate itself is returned.

        Args:
            model_output (`torch.FloatTensor`):
                The flow prediction of the model.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep. It is broadcast to the batch as an integer (`torch.long`) tensor.
            sample (`torch.FloatTensor`):
                The current noisy sample.
            generator (`torch.Generator`, *optional*):
                Not used by this method.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `HeliosDMDSchedulerOutput` instead of a plain tuple.
            cur_sampling_step (`int`, defaults to 0):
                Index of the current step within `all_timesteps`.
            dmd_noisy_tensor (`torch.FloatTensor`, *optional*):
                Noise used for re-noising. Required for every step except the last.
            dmd_sigmas (`torch.FloatTensor`):
                Sigma table used for the nearest-timestep lookups.
            dmd_timesteps (`torch.FloatTensor`):
                Timestep table aligned with `dmd_sigmas`.
            all_timesteps (`torch.FloatTensor`):
                The timesteps of the whole sampling run.

        Returns:
            `HeliosDMDSchedulerOutput` or `tuple`: The next sample, wrapped according to `return_dict`.

        Raises:
            `ValueError`: If a required argument is `None`.
        """
        # These parameters default to None only to keep the generic `step` signature; they are all dereferenced.
        required = {
            "timestep": timestep,
            "sample": sample,
            "dmd_sigmas": dmd_sigmas,
            "dmd_timesteps": dmd_timesteps,
            "all_timesteps": all_timesteps,
        }
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise ValueError(
                "`HeliosDMDScheduler.step` requires the following arguments, which were `None`: "
                + ", ".join(f"`{name}`" for name in missing)
            )

        batch_size = model_output.shape[0]
        pred_image_or_video = self.convert_flow_pred_to_x0(
            flow_pred=model_output,
            xt=sample,
            timestep=torch.full((batch_size,), timestep, dtype=torch.long, device=model_output.device),
            sigmas=dmd_sigmas,
            timesteps=dmd_timesteps,
        )
        if cur_sampling_step < len(all_timesteps) - 1:
            if dmd_noisy_tensor is None:
                raise ValueError(
                    "`dmd_noisy_tensor` must be provided for every sampling step except the last one."
                )
            prev_sample = self.add_noise(
                pred_image_or_video,
                dmd_noisy_tensor,
                torch.full(
                    (batch_size,),
                    all_timesteps[cur_sampling_step + 1],
                    dtype=torch.long,
                    device=model_output.device,
                ),
                sigmas=dmd_sigmas,
                timesteps=dmd_timesteps,
            )
        else:
            prev_sample = pred_image_or_video

        if not return_dict:
            return (prev_sample,)

        return HeliosDMDSchedulerOutput(prev_sample=prev_sample)

    def reset_scheduler_history(self):
        """Clear the step and begin indices."""
        self._step_index = None
        self._begin_index = None

    def __len__(self):
        return self.config.num_train_timesteps
class HeliosScheduler(metaclass=DummyObject):
    # Placeholder registered under the `HeliosScheduler` name: the constructor and both classmethods
    # do nothing except call `requires_backends` with the "torch" backend.
    # NOTE(review): presumably `requires_backends` raises when the backend is missing — confirm in its definition.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
class HeliosPyramidPipeline(metaclass=DummyObject):
    # Placeholder registered under the `HeliosPyramidPipeline` name: the constructor and both classmethods
    # do nothing except call `requires_backends` with the "torch" and "transformers" backends.
    # NOTE(review): presumably `requires_backends` raises when a backend is missing — confirm in its definition.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])
@require_peft_backend
@skip_mps
class HeliosLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    """
    LoRA loader tests for `HeliosPipeline`.

    The class only declares the components and their constructor kwargs; the test logic comes from
    `PeftLoraLoaderMixinTests`.
    """

    pipeline_class = HeliosPipeline
    scheduler_cls = FlowMatchEulerDiscreteScheduler
    scheduler_kwargs = {}

    # Constructor kwargs for a small `HeliosTransformer3DModel`.
    transformer_kwargs = {
        "patch_size": (1, 2, 2),
        "num_attention_heads": 2,
        "attention_head_dim": 12,
        "in_channels": 16,
        "out_channels": 16,
        "text_dim": 32,
        "freq_dim": 256,
        "ffn_dim": 32,
        "num_layers": 2,
        "cross_attn_norm": True,
        "qk_norm": "rms_norm_across_heads",
        "rope_dim": (4, 4, 4),
        "has_multi_term_memory_patch": True,
        "guidance_cross_attn": True,
        "zero_history_timestep": True,
        "is_amplify_history": False,
    }
    transformer_cls = HeliosTransformer3DModel
    # Constructor kwargs for a small `AutoencoderKLWan`; `z_dim` matches the transformer's `in_channels`.
    vae_kwargs = {
        "base_dim": 3,
        "z_dim": 16,
        "dim_mult": [1, 1, 1, 1],
        "num_res_blocks": 1,
        "temperal_downsample": [False, True, True],
    }
    vae_cls = AutoencoderKLWan
    # NOTE(review): only one tokenizer/text-encoder pair is declared below — confirm how the mixin reads this flag.
    has_two_text_encoders = True
    tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
    text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"

    text_encoder_target_modules = ["q", "k", "v", "o"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        """Expected shape of the pipeline output for the inputs built by `get_dummy_inputs`."""
        return (1, 33, 32, 32, 3)

    def get_dummy_inputs(self, with_generator=True):
        """
        Build the inputs used by the mixin tests.

        Args:
            with_generator (`bool`, defaults to `True`):
                Whether to add the seeded generator to the pipeline kwargs.

        Returns:
            `tuple`: `(noise, input_ids, pipeline_inputs)` where `noise` has shape `(1, 3, 4, 4, 4)`, `input_ids`
            has shape `(1, 16)` and `pipeline_inputs` is the kwargs dict for the pipeline call.
        """
        batch_size = 1
        sequence_length = 16
        num_channels = 4
        num_frames = 9
        num_latent_frames = 3  # (num_frames - 1) // temporal_compression_ratio + 1
        sizes = (4, 4)

        # `torch.manual_seed` seeds the default generator and returns it.
        generator = torch.manual_seed(0)
        noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes)
        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

        pipeline_inputs = {
            "prompt": "",
            "num_frames": num_frames,
            "num_inference_steps": 1,
            "guidance_scale": 6.0,
            "height": 32,
            "width": 32,
            "max_sequence_length": sequence_length,
            "output_type": "np",
        }
        if with_generator:
            pipeline_inputs.update({"generator": generator})

        return noise, input_ids, pipeline_inputs

    # The two overrides below run the inherited tests with a fixed absolute tolerance of 9e-3.
    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)

    def test_simple_inference_with_text_denoiser_lora_unfused(self):
        super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)

    @unittest.skip("Not supported in Helios.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
        pass

    @unittest.skip("Not supported in Helios.")
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        pass

    @unittest.skip("Not supported in Helios.")
    def test_modify_padding_mode(self):
        pass
+ + +import pytest +import torch + +from diffusers import HeliosTransformer3DModel +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + MemoryTesterMixin, + ModelTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) + + +enable_full_determinism() + + +class HeliosTransformer3DTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return HeliosTransformer3DModel + + @property + def pretrained_model_name_or_path(self): + return "hf-internal-testing/tiny-helios-base-transformer" + + @property + def output_shape(self) -> tuple[int, ...]: + return (4, 2, 16, 16) + + @property + def input_shape(self) -> tuple[int, ...]: + return (4, 2, 16, 16) + + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]: + return { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 12, + "in_channels": 4, + "out_channels": 4, + "text_dim": 16, + "freq_dim": 256, + "ffn_dim": 32, + "num_layers": 2, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "rope_dim": (4, 4, 4), + "has_multi_term_memory_patch": True, + "guidance_cross_attn": True, + "zero_history_timestep": True, + "is_amplify_history": False, + } + + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + batch_size = 1 + num_channels = 4 + num_frames = 2 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + sequence_length = 12 + + hidden_states = randn_tensor( + (batch_size, num_channels, num_frames, height, width), + generator=self.generator, + device=torch_device, + ) + timestep = torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device) + encoder_hidden_states = randn_tensor( + 
(batch_size, sequence_length, text_encoder_embedding_dim), + generator=self.generator, + device=torch_device, + ) + indices_hidden_states = torch.ones((batch_size, num_frames)).to(torch_device) + indices_latents_history_short = torch.ones((batch_size, num_frames - 1)).to(torch_device) + indices_latents_history_mid = torch.ones((batch_size, num_frames - 1)).to(torch_device) + indices_latents_history_long = torch.ones((batch_size, (num_frames - 1) * 4)).to(torch_device) + latents_history_short = randn_tensor( + (batch_size, num_channels, num_frames - 1, height, width), + generator=self.generator, + device=torch_device, + ) + latents_history_mid = randn_tensor( + (batch_size, num_channels, num_frames - 1, height, width), + generator=self.generator, + device=torch_device, + ) + latents_history_long = randn_tensor( + (batch_size, num_channels, (num_frames - 1) * 4, height, width), + generator=self.generator, + device=torch_device, + ) + + return { + "hidden_states": hidden_states, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "indices_hidden_states": indices_hidden_states, + "indices_latents_history_short": indices_latents_history_short, + "indices_latents_history_mid": indices_latents_history_mid, + "indices_latents_history_long": indices_latents_history_long, + "latents_history_short": latents_history_short, + "latents_history_mid": latents_history_mid, + "latents_history_long": latents_history_long, + } + + +class TestHeliosTransformer3D(HeliosTransformer3DTesterConfig, ModelTesterMixin): + """Core model tests for Helios Transformer 3D.""" + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): + # Skip: fp16/bf16 require very high atol to pass, providing little signal. + # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules. 
+ pytest.skip("Tolerance requirements too high for meaningful test") + + +class TestHeliosTransformer3DMemory(HeliosTransformer3DTesterConfig, MemoryTesterMixin): + """Memory optimization tests for Helios Transformer 3D.""" + + +class TestHeliosTransformer3DTraining(HeliosTransformer3DTesterConfig, TrainingTesterMixin): + """Training tests for Helios Transformer 3D.""" + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"HeliosTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestHeliosTransformer3DAttention(HeliosTransformer3DTesterConfig, AttentionTesterMixin): + """Attention processor tests for Helios Transformer 3D.""" + + +class TestHeliosTransformer3DCompile(HeliosTransformer3DTesterConfig, TorchCompileTesterMixin): + """Torch compile tests for Helios Transformer 3D.""" + + @pytest.mark.xfail( + reason="Helios DiT does not compile when deterministic algorithms are used due to https://github.com/pytorch/pytorch/issues/170079" + ) + def test_torch_compile_recompilation_and_graph_break(self): + super().test_torch_compile_recompilation_and_graph_break() diff --git a/tests/pipelines/helios/__init__.py b/tests/pipelines/helios/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/helios/test_helios.py b/tests/pipelines/helios/test_helios.py new file mode 100644 index 000000000000..b8ee99085036 --- /dev/null +++ b/tests/pipelines/helios/test_helios.py @@ -0,0 +1,172 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import torch +from transformers import AutoConfig, AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLWan, HeliosPipeline, HeliosScheduler, HeliosTransformer3DModel + +from ...testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + slow, + torch_device, +) +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class HeliosPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = HeliosPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = HeliosScheduler(stage_range=[0, 1], stages=1, use_dynamic_shifting=True) + config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5") + text_encoder = T5EncoderModel(config) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = HeliosTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + 
qk_norm="rms_norm_across_heads", + rope_dim=(4, 4, 4), + has_multi_term_memory_patch=True, + guidance_cross_attn=True, + zero_history_timestep=True, + is_amplify_history=False, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "negative", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + self.assertEqual(generated_video.shape, (33, 3, 16, 16)) + + # fmt: off + expected_slice = torch.tensor([0.4529, 0.4527, 0.4499, 0.4542, 0.4528, 0.4524, 0.4531, 0.4534, 0.5328, + 0.5340, 0.5012, 0.5135, 0.5322, 0.5203, 0.5144, 0.5101]) + # fmt: on + + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) + + # Override to set a more lenient max diff threshold. 
+ def test_save_load_float16(self): + super().test_save_load_float16(expected_max_diff=0.03) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + @unittest.skip("Optional components not applicable for Helios") + def test_save_load_optional_components(self): + pass + + +@slow +@require_torch_accelerator +class HeliosPipelineIntegrationTests(unittest.TestCase): + prompt = "A painting of a squirrel eating a burger." + + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + @unittest.skip("TODO: test needs to be implemented") + def test_helios(self): + pass From 21ea2b708a9a656b80179b4b4aaae9726e6f2e55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Wed, 4 Mar 2026 17:11:55 -0300 Subject: [PATCH 020/215] [Z-Image] Fix more `do_classifier_free_guidance` thresholds (#13212) fix --- src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py | 2 +- .../pipelines/z_image/pipeline_z_image_controlnet_inpaint.py | 2 +- src/diffusers/pipelines/z_image/pipeline_z_image_img2img.py | 2 +- src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py | 2 +- src/diffusers/pipelines/z_image/pipeline_z_image_omni.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index 288aee039ed8..fe09c6f073f9 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -365,7 +365,7 @@ def guidance_scale(self): @property def do_classifier_free_guidance(self): - return self._guidance_scale > 1 + return self._guidance_scale > 0 @property def joint_attention_kwargs(self): diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py 
b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py index d278da912b5a..0404b9dbabc1 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py @@ -372,7 +372,7 @@ def guidance_scale(self): @property def do_classifier_free_guidance(self): - return self._guidance_scale > 1 + return self._guidance_scale > 0 @property def joint_attention_kwargs(self): diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_img2img.py b/src/diffusers/pipelines/z_image/pipeline_z_image_img2img.py index 94b98a5f8580..ee57f51dd957 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_img2img.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_img2img.py @@ -347,7 +347,7 @@ def guidance_scale(self): @property def do_classifier_free_guidance(self): - return self._guidance_scale > 1 + return self._guidance_scale > 0 @property def joint_attention_kwargs(self): diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py b/src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py index 2f842182edc6..e740a48e65ec 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py @@ -462,7 +462,7 @@ def guidance_scale(self): @property def do_classifier_free_guidance(self): - return self._guidance_scale > 1 + return self._guidance_scale > 0 @property def joint_attention_kwargs(self): diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py index a0c3ec03ef80..6d04202162f9 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py @@ -339,7 +339,7 @@ def guidance_scale(self): @property def do_classifier_free_guidance(self): - return self._guidance_scale > 1 + return self._guidance_scale > 0 @property def 
joint_attention_kwargs(self): From 9e40bafd3150cc39ee78e631cf21a97056685610 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 5 Mar 2026 08:24:20 +0530 Subject: [PATCH 021/215] [lora] fix zimage lora conversion to support for more lora. (#13209) fix zimage lora conversion to support for more lora. --- .../loaders/lora_conversion_utils.py | 94 +++++++++++++++---- 1 file changed, 78 insertions(+), 16 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 8b0f95b905e4..0895d5223e13 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2519,6 +2519,13 @@ def normalize_out_key(k: str) -> str: if has_default: state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()} + # Normalize ZImage-specific dot-separated module names to underscore form so they + # match the diffusers model parameter names (context_refiner, noise_refiner). + state_dict = { + k.replace("context.refiner.", "context_refiner.").replace("noise.refiner.", "noise_refiner."): v + for k, v in state_dict.items() + } + converted_state_dict = {} all_keys = list(state_dict.keys()) down_key = ".lora_down.weight" @@ -2529,19 +2536,18 @@ def normalize_out_key(k: str) -> str: has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys) has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys) - if has_non_diffusers_lora_id: - - def get_alpha_scales(down_weight, alpha_key): - rank = down_weight.shape[0] - alpha = state_dict.pop(alpha_key).item() - scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here - scale_down = scale - scale_up = 1.0 - while scale_down * 2 < scale_up: - scale_down *= 2 - scale_up /= 2 - return scale_down, scale_up + def get_alpha_scales(down_weight, alpha_key): + rank = down_weight.shape[0] + alpha = state_dict.pop(alpha_key).item() + scale = alpha / rank # 
LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + return scale_down, scale_up + if has_non_diffusers_lora_id: for k in all_keys: if k.endswith(down_key): diffusers_down_key = k.replace(down_key, ".lora_A.weight") @@ -2554,13 +2560,69 @@ def get_alpha_scales(down_weight, alpha_key): converted_state_dict[diffusers_down_key] = down_weight * scale_down converted_state_dict[diffusers_up_key] = up_weight * scale_up - # Already in diffusers format (lora_A/lora_B), just pop + # Already in diffusers format (lora_A/lora_B), apply alpha scaling and pop. elif has_diffusers_lora_id: for k in all_keys: - if a_key in k or b_key in k: - converted_state_dict[k] = state_dict.pop(k) - elif ".alpha" in k: + if k.endswith(a_key): + diffusers_up_key = k.replace(a_key, b_key) + alpha_key = k.replace(a_key, ".alpha") + + down_weight = state_dict.pop(k) + up_weight = state_dict.pop(diffusers_up_key) + scale_down, scale_up = get_alpha_scales(down_weight, alpha_key) + converted_state_dict[k] = down_weight * scale_down + converted_state_dict[diffusers_up_key] = up_weight * scale_up + + # Handle dot-format LoRA keys: ".lora.down.weight" / ".lora.up.weight". + # Some external ZImage trainers (e.g. 
Anime-Z) use dots instead of underscores in + # lora weight names and also include redundant keys: + # - "qkv.lora.*" duplicates individual "to.q/k/v.lora.*" keys → skip qkv + # - "out.lora.*" duplicates "to_out.0.lora.*" keys → skip bare out + # - "to.q/k/v.lora.*" → normalise to "to_q/k/v.lora_A/B.weight" + lora_dot_down_key = ".lora.down.weight" + lora_dot_up_key = ".lora.up.weight" + has_lora_dot_format = any(lora_dot_down_key in k for k in state_dict) + + if has_lora_dot_format: + dot_keys = list(state_dict.keys()) + for k in dot_keys: + if lora_dot_down_key not in k: + continue + if k not in state_dict: + continue # already popped by a prior iteration + + base = k[: -len(lora_dot_down_key)] + + # Skip combined "qkv" projection — individual to.q/k/v keys are also present. + if base.endswith(".qkv"): + state_dict.pop(k) + state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None) + state_dict.pop(base + ".alpha", None) + continue + + # Skip bare "out.lora.*" — "to_out.0.lora.*" covers the same projection. + if re.search(r"\.out$", base) and ".to_out" not in base: state_dict.pop(k) + state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None) + continue + + # Normalise "to.q/k/v" → "to_q/k/v" for the diffusers output key. 
+ norm_k = re.sub( + r"\.to\.([qkv])" + re.escape(lora_dot_down_key) + r"$", + r".to_\1" + lora_dot_down_key, + k, + ) + norm_base = norm_k[: -len(lora_dot_down_key)] + alpha_key = norm_base + ".alpha" + + diffusers_down = norm_k.replace(lora_dot_down_key, ".lora_A.weight") + diffusers_up = norm_k.replace(lora_dot_down_key, ".lora_B.weight") + + down_weight = state_dict.pop(k) + up_weight = state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key)) + scale_down, scale_up = get_alpha_scales(down_weight, alpha_key) + converted_state_dict[diffusers_down] = down_weight * scale_down + converted_state_dict[diffusers_up] = up_weight * scale_up if len(state_dict) > 0: raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}") From a483a7784ce2aab3cc4fe07e9e22571c48bb90fa Mon Sep 17 00:00:00 2001 From: Christopher Date: Thu, 5 Mar 2026 05:35:53 +0100 Subject: [PATCH 022/215] adding lora support to z-image controlnet pipelines (#13200) adding lora to z-image controlnet pipelines --- .../pipelines/z_image/pipeline_z_image_controlnet.py | 4 ++-- .../pipelines/z_image/pipeline_z_image_controlnet_inpaint.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index fe09c6f073f9..1e49737bb5b0 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -19,7 +19,7 @@ from transformers import AutoTokenizer, PreTrainedModel from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin +from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnets import ZImageControlNetModel from ...models.transformers import ZImageTransformer2DModel @@ -185,7 +185,7 @@ def retrieve_timesteps( return 
timesteps, num_inference_steps -class ZImageControlNetPipeline(DiffusionPipeline, FromSingleFileMixin): +class ZImageControlNetPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin): model_cpu_offload_seq = "text_encoder->transformer->vae" _optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds"] diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py index 0404b9dbabc1..09f9b2395458 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py @@ -20,7 +20,7 @@ from transformers import AutoTokenizer, PreTrainedModel from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin +from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnets import ZImageControlNetModel from ...models.transformers import ZImageTransformer2DModel @@ -185,7 +185,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class ZImageControlNetInpaintPipeline(DiffusionPipeline, FromSingleFileMixin): +class ZImageControlNetInpaintPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin): model_cpu_offload_seq = "text_encoder->transformer->vae" _optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds"] From 88205353391bbc69644000a92d325c1dcd082fa2 Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Thu, 5 Mar 2026 00:42:55 -0800 Subject: [PATCH 023/215] Add LTX2 Condition Pipeline (#13058) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * LTX2 condition pipeline initial commit * Fix pipeline import error * Implement LTX-2-style general image conditioning * Blend denoising output and clean latents in sample 
space instead of velocity space * make style and make quality * make fix-copies * Rename LTX2VideoCondition image to frames * Update LTX2ConditionPipeline example * Remove support for image and video in __call__ * Put latent_idx_from_index logic inline * Improve comment on using the conditioning mask in denoising loop * Apply suggestions from code review Co-authored-by: Álvaro Somoza * make fix-copies * Migrate to Python 3.9+ style type annotations without explicit typing imports * Forward kwargs from preprocess/postprocess_video to preprocess/postprocess resp. * Center crop LTX-2 conditions following original code * Duplicate video and audio position ids if using CFG * make style and make quality * Remove unused index_type arg to preprocess_conditions * Add # Copied from for _normalize_latents * Fix _normalize_latents # Copied from statement * Add LTX-2 condition pipeline docs * Remove TODOs * Support only unpacked latents (5D for video, 4D for audio) * Remove # Copied from for prepare_audio_latents --------- Co-authored-by: Sayak Paul Co-authored-by: Álvaro Somoza --- docs/source/en/api/pipelines/ltx2.md | 179 ++ src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 9 +- src/diffusers/pipelines/ltx2/__init__.py | 2 + .../pipelines/ltx2/pipeline_ltx2_condition.py | 1474 +++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + src/diffusers/video_processor.py | 17 +- 7 files changed, 1690 insertions(+), 8 deletions(-) create mode 100644 src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md index c77efa09f594..85b0f9691891 100644 --- a/docs/source/en/api/pipelines/ltx2.md +++ b/docs/source/en/api/pipelines/ltx2.md @@ -193,6 +193,179 @@ encode_video( ) ``` +## Condition Pipeline Generation + +You can use `LTX2ConditionPipeline` to specify image and/or video conditions at arbitrary latent indices. 
For example, we can specify both a first-frame and last-frame condition to perform first-last-frame-to-video (FLF2V) generation: + +```py +import torch +from diffusers import LTX2ConditionPipeline, LTX2LatentUpsamplePipeline +from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel +from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition +from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES +from diffusers.pipelines.ltx2.export_utils import encode_video +from diffusers.utils import load_image + +device = "cuda" +width = 768 +height = 512 +random_seed = 42 +generator = torch.Generator(device).manual_seed(random_seed) +model_path = "rootonchair/LTX-2-19b-distilled" + +pipe = LTX2ConditionPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16) +pipe.enable_sequential_cpu_offload(device=device) +pipe.vae.enable_tiling() + +prompt = ( + "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are " + "delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright " + "sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, " + "low-angle perspective." 
+) + +first_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png", +) +last_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png", +) +first_cond = LTX2VideoCondition(frames=first_image, index=0, strength=1.0) +last_cond = LTX2VideoCondition(frames=last_image, index=-1, strength=1.0) +conditions = [first_cond, last_cond] + +frame_rate = 24.0 +video_latent, audio_latent = pipe( + conditions=conditions, + prompt=prompt, + width=width, + height=height, + num_frames=121, + frame_rate=frame_rate, + num_inference_steps=8, + sigmas=DISTILLED_SIGMA_VALUES, + guidance_scale=1.0, + generator=generator, + output_type="latent", + return_dict=False, +) + +latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained( + model_path, + subfolder="latent_upsampler", + torch_dtype=torch.bfloat16, +) +upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler) +upsample_pipe.enable_model_cpu_offload(device=device) +upscaled_video_latent = upsample_pipe( + latents=video_latent, + output_type="latent", + return_dict=False, +)[0] + +video, audio = pipe( + latents=upscaled_video_latent, + audio_latents=audio_latent, + prompt=prompt, + width=width * 2, + height=height * 2, + num_inference_steps=3, + sigmas=STAGE_2_DISTILLED_SIGMA_VALUES, + generator=generator, + guidance_scale=1.0, + output_type="np", + return_dict=False, +) + +encode_video( + video[0], + fps=frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=pipe.vocoder.config.output_sampling_rate, + output_path="ltx2_distilled_flf2v.mp4", +) +``` + +You can use both image and video conditions: + +```py +import torch +from diffusers import LTX2ConditionPipeline +from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition +from diffusers.pipelines.ltx2.export_utils import encode_video +from diffusers.utils 
import load_image, load_video + +device = "cuda" +width = 768 +height = 512 +random_seed = 42 +generator = torch.Generator(device).manual_seed(random_seed) +model_path = "rootonchair/LTX-2-19b-distilled" + +pipe = LTX2ConditionPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16) +pipe.enable_sequential_cpu_offload(device=device) +pipe.vae.enable_tiling() + +prompt = ( + "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is " + "divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features " + "dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered " + "clouds, suggesting a bright, sunny day. And then the camera switch to a winding mountain road covered in snow, " + "with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The " + "landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the " + "solitude and beauty of a winter drive through a mountainous region." 
+) +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
+) + +cond_video = load_video( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4" +) +cond_image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg" +) +video_cond = LTX2VideoCondition(frames=cond_video, index=0, strength=1.0) +image_cond = LTX2VideoCondition(frames=cond_image, index=8, strength=1.0) +conditions = [video_cond, image_cond] + +frame_rate = 24.0 +video, audio = pipe( + conditions=conditions, + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + num_frames=121, + frame_rate=frame_rate, + num_inference_steps=40, + guidance_scale=4.0, + generator=generator, + output_type="np", + return_dict=False, +) + +encode_video( + video[0], + fps=frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=pipe.vocoder.config.output_sampling_rate, + output_path="ltx2_cond_video.mp4", +) +``` + +Because the conditioning is done via latent frames, the 8 data space frames corresponding to the specified latent frame for an image condition will tend to be static. 
+ ## LTX2Pipeline [[autodoc]] LTX2Pipeline @@ -205,6 +378,12 @@ encode_video( - all - __call__ +## LTX2ConditionPipeline + +[[autodoc]] LTX2ConditionPipeline + - all + - __call__ + ## LTX2LatentUpsamplePipeline [[autodoc]] LTX2LatentUpsamplePipeline diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1458164191df..ec0347750816 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -571,6 +571,7 @@ "LEditsPPPipelineStableDiffusionXL", "LongCatImageEditPipeline", "LongCatImagePipeline", + "LTX2ConditionPipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline", "LTX2Pipeline", @@ -1318,6 +1319,7 @@ LEditsPPPipelineStableDiffusionXL, LongCatImageEditPipeline, LongCatImagePipeline, + LTX2ConditionPipeline, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 08cb28a6237a..8007035338b0 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -292,7 +292,12 @@ "LTXLatentUpsamplePipeline", "LTXI2VLongMultiPromptPipeline", ] - _import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline"] + _import_structure["ltx2"] = [ + "LTX2Pipeline", + "LTX2ConditionPipeline", + "LTX2ImageToVideoPipeline", + "LTX2LatentUpsamplePipeline", + ] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] @@ -731,7 +736,7 @@ LTXLatentUpsamplePipeline, LTXPipeline, ) - from .ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline + from .ltx2 import LTX2ConditionPipeline, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline from .lucy import LucyEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline diff 
--git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py index 115e83e827a4..d6a408d5c546 100644 --- a/src/diffusers/pipelines/ltx2/__init__.py +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -25,6 +25,7 @@ _import_structure["connectors"] = ["LTX2TextConnectors"] _import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"] _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] + _import_structure["pipeline_ltx2_condition"] = ["LTX2ConditionPipeline"] _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] _import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"] _import_structure["vocoder"] = ["LTX2Vocoder"] @@ -40,6 +41,7 @@ from .connectors import LTX2TextConnectors from .latent_upsampler import LTX2LatentUpsamplerModel from .pipeline_ltx2 import LTX2Pipeline + from .pipeline_ltx2_condition import LTX2ConditionPipeline from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline from .vocoder import LTX2Vocoder diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py new file mode 100644 index 000000000000..4c451330f439 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -0,0 +1,1474 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import inspect +from dataclasses import dataclass +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.transformers import LTX2VideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .connectors import LTX2TextConnectors +from .pipeline_output import LTX2PipelineOutput +from .vocoder import LTX2Vocoder + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2ConditionPipeline + >>> from diffusers.pipelines.ltx2.export_utils import encode_video + >>> from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition + >>> from diffusers.utils import load_image + + >>> pipe = LTX2ConditionPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() + + >>> first_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png" + ... ) + >>> last_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png" + ... 
) + >>> first_cond = LTX2VideoCondition(frames=first_image, index=0, strength=1.0) + >>> last_cond = LTX2VideoCondition(frames=last_image, index=-1, strength=1.0) + >>> conditions = [first_cond, last_cond] + >>> prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings." + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted, static" + + >>> frame_rate = 24.0 + >>> video, audio = pipe( + ... conditions=conditions, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=768, + ... height=512, + ... num_frames=121, + ... frame_rate=frame_rate, + ... num_inference_steps=40, + ... guidance_scale=4.0, + ... output_type="np", + ... return_dict=False, + ... ) + >>> video = (video * 255).round().astype("uint8") + >>> video = torch.from_numpy(video) + + >>> encode_video( + ... video[0], + ... fps=frame_rate, + ... audio=audio[0].float().cpu(), + ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000 + ... output_path="video.mp4", + ... ) + ``` +""" + + +@dataclass +class LTX2VideoCondition: + """ + Defines a single frame-conditioning item for LTX-2 Video - a single frame or a sequence of frames. + + Attributes: + frames (`PIL.Image.Image` or `List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`): + The image (or video) to condition the video on. Accepts any type that can be handled by + VideoProcessor.preprocess_video. + index (`int`, defaults to `0`): + The index at which the image or video will conditionally affect the video generation. + strength (`float`, defaults to `1.0`): + The strength of the conditioning effect. A value of `1.0` means the conditioning effect is fully applied. 
+ """ + + frames: PIL.Image.Image | list[PIL.Image.Image] | np.ndarray | torch.Tensor + index: int = 0 + strength: float = 1.0 + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. 
If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): + r""" + Pipeline for video generation which allows image conditions to be inserted at arbitrary parts of the video. 
+ + Reference: https://github.com/Lightricks/LTX-Video + + TODO + """ + + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + text_encoder: Gemma3ForConditionalGeneration, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + connectors: LTX2TextConnectors, + transformer: LTX2VideoTransformer3DModel, + vocoder: LTX2Vocoder, + ): + super().__init__() + + self.register_modules( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 + ) + + self.video_processor = 
VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio, resample="bilinear") + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 + ) + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds + def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: str | torch.device, + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + """ + Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and + per-layer in a masked fashion (only over non-padded positions). + + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. 
+ """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return 
normalized_hidden_states + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`str` or `torch.device`): + torch device to place the resulting embeddings on + dtype: (`torch.dtype`): + torch dtype to cast the prompt embeds to + max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if getattr(self, "tokenizer", None) is not None: + # Gemma expects left padding for chat-style prompts + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + prompt = [p.strip() for p in prompt] + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) + + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + sequence_lengths = prompt_attention_mask.sum(dim=-1) + + prompt_embeds = self._pack_text_embeds( + 
text_encoder_hidden_states, + sequence_lengths, + device=device, + padding_side=self.tokenizer.padding_side, + scale_factor=scale_factor, + ) + prompt_embeds = prompt_embeds.to(dtype=dtype) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. 
torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + latents=None, + audio_latents=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + if latents is not None and latents.ndim != 5: + raise ValueError( + f"Only unpacked (5D) video latents of shape `[batch_size, latent_channels, latent_frames," + f" latent_height, latent_width] are supported, but got {latents.ndim} dims. If you have packed (3D)" + f" latents, please unpack them (e.g. using the `_unpack_latents` method)." + ) + if audio_latents is not None and audio_latents.ndim != 4: + raise ValueError( + f"Only unpacked (4D) audio latents of shape `[batch_size, num_channels, audio_length, mel_bins] are" + f" supported, but got {audio_latents.ndim} dims. If you have packed (3D) latents, please unpack them (e.g." + f" using the `_unpack_audio_latents` method)." 
+ ) + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. 
+ batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state + def _create_noised_state( + latents: torch.Tensor, noise_scale: float | torch.Tensor, generator: torch.Generator | None = None + ): + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + noised_latents = noise_scale * noise + (1 - noise_scale) * latents + return noised_latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents + def 
_pack_audio_latents( + latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None + ) -> torch.Tensor: + # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins + if patch_size is not None and patch_size_t is not None: + # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (a ndim=3 tnesor). + # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size. + batch_size, num_channels, latent_length, latent_mel_bins = latents.shape + post_patch_latent_length = latent_length / patch_size_t + post_patch_mel_bins = latent_mel_bins / patch_size + latents = latents.reshape( + batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size + ) + latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + else: + # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel) + # patch_size of M (all mel bins constitutes a single patch) and a patch_size_t of 1. + latents = latents.transpose(1, 2).flatten(2, 3) # [B, C, L, M] --> [B, L, C * M] + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_audio_latents + def _unpack_audio_latents( + latents: torch.Tensor, + latent_length: int, + num_mel_bins: int, + patch_size: int | None = None, + patch_size_t: int | None = None, + ) -> torch.Tensor: + # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M], + # where L is the latent audio length and M is the number of mel bins. 
+ if patch_size is not None and patch_size_t is not None: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size) + latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + else: + # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1. + latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._normalize_audio_latents + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents + def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents * latents_std) + latents_mean + + # Copied from diffusers.pipelines.ltx.pipeline_ltx_condition.LTXConditionPipeline.trim_conditioning_sequence + def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, target_num_frames: int) -> int: + """ + Trim a conditioning sequence to the allowed number of frames. + + Args: + start_frame (int): The target frame number of the first frame in the sequence. + sequence_num_frames (int): The number of frames in the sequence. + target_num_frames (int): The target number of frames in the generated video. 
+ Returns: + int: updated sequence length + """ + scale_factor = self.vae_temporal_compression_ratio + num_frames = min(sequence_num_frames, target_num_frames - start_frame) + # Trim down to a multiple of temporal_scale_factor frames plus 1 + num_frames = (num_frames - 1) // scale_factor * scale_factor + 1 + return num_frames + + def preprocess_conditions( + self, + conditions: LTX2VideoCondition | list[LTX2VideoCondition] | None = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + device: torch.device | None = None, + ) -> tuple[list[torch.Tensor], list[float], list[int]]: + """ + Preprocesses the condition images/videos to torch tensors. + + Args: + conditions (`LTX2VideoCondition` or `List[LTX2VideoCondition]`, *optional*, defaults to `None`): + A list of image/video condition instances. + height (`int`, *optional*, defaults to `512`): + The desired height in pixels. + width (`int`, *optional*, defaults to `768`): + The desired width in pixels. + num_frames (`int`, *optional*, defaults to `121`): + The desired number of frames in the generated video. + device (`torch.device`, *optional*, defaults to `None`): + The device on which to put the preprocessed image/video tensors. + + Returns: + `Tuple[List[torch.Tensor], List[float], List[int]]`: + Returns a 3-tuple of lists of length `len(conditions)` as follows: + 1. The first list is a list of preprocessed video tensors of shape [batch_size=1, num_channels, + num_frames, height, width]. + 2. The second list is a list of conditioning strengths. + 3. The third list is a list of indices in latent space to insert the corresponding condition. 
+ """ + conditioning_frames, conditioning_strengths, conditioning_indices = [], [], [] + + if conditions is None: + conditions = [] + if isinstance(conditions, LTX2VideoCondition): + conditions = [conditions] + + frame_scale_factor = self.vae_temporal_compression_ratio + latent_num_frames = (num_frames - 1) // frame_scale_factor + 1 + for i, condition in enumerate(conditions): + if isinstance(condition.frames, PIL.Image.Image): + # Single image, convert to List[PIL.Image.Image] + video_like_cond = [condition.frames] + elif isinstance(condition.frames, np.ndarray) and condition.frames.ndim == 3: + # Image-like ndarray of shape (H, W, C), insert frame dim in first axis + video_like_cond = np.expand_dims(condition.frames, axis=0) + elif isinstance(condition.frames, torch.Tensor) and condition.frames.ndim == 3: + # Image-like tensor of shape (C, H, W), insert frame dim in first dim + video_like_cond = condition.frames.unsqueeze(0) + else: + # Treat all other as videos. Note that this means 4D ndarrays and tensors will be treated as videos of + # shape (F, H, W, C) and (F, C, H, W), respectively. + video_like_cond = condition.frames + condition_pixels = self.video_processor.preprocess_video( + video_like_cond, height, width, resize_mode="crop" + ) + + # Interpret the index as a latent index, following the original LTX-2 code. + latent_start_idx = condition.index + # Support negative latent indices (e.g. -1 for the last latent index) + if latent_start_idx < 0: + # latent_start_idx will be positive because latent_num_frames is positive + latent_start_idx = latent_start_idx % latent_num_frames + if latent_start_idx >= latent_num_frames: + logger.warning( + f"The starting latent index {latent_start_idx} of condition {i} is too big for the specified number" + f" of latent frames {latent_num_frames}. This condition will be skipped." 
+ ) + continue + + cond_num_frames = condition_pixels.size(2) + start_idx = max((latent_start_idx - 1) * frame_scale_factor + 1, 0) + truncated_cond_frames = self.trim_conditioning_sequence(start_idx, cond_num_frames, num_frames) + condition_pixels = condition_pixels[:, :, :truncated_cond_frames] + + conditioning_frames.append(condition_pixels.to(dtype=self.vae.dtype, device=device)) + conditioning_strengths.append(condition.strength) + conditioning_indices.append(latent_start_idx) + + return conditioning_frames, conditioning_strengths, conditioning_indices + + def apply_visual_conditioning( + self, + latents: torch.Tensor, + conditioning_mask: torch.Tensor, + condition_latents: list[torch.Tensor], + condition_strengths: list[float], + condition_indices: list[int], + latent_height: int, + latent_width: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Applies visual conditioning frames to an initial latent. + + Args: + latents (`torch.Tensor`): + Initial packed (patchified) latents of shape [batch_size, patch_seq_len, hidden_dim]. + conditioning_mask (`torch.Tensor`, *optional*): + Initial packed (patchified) conditioning mask of shape [batch_size, patch_seq_len, 1] with values in + [0, 1] where 0 means that the denoising model output will be fully used and 1 means that the condition + will be fully used (with intermediate values specifying a blend of the denoised and latent values). + + Returns: + `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: + Returns a 3-tuple of tensors where: + 1. The first element is the packed video latents (with unchanged shape [batch_size, patch_seq_len, + hidden_dim]) with the conditions applied + 2. The second element is the packed conditioning mask with conditioning strengths applied + 3. The third element holds the clean conditioning latents. 
+ """ + # Latents-like tensor which holds the clean conditioning latents + clean_latents = torch.zeros_like(latents) + for cond, strength, latent_idx in zip(condition_latents, condition_strengths, condition_indices): + num_cond_tokens = cond.size(1) + start_token_idx = latent_idx * latent_height * latent_width + end_token_idx = start_token_idx + num_cond_tokens + + # Overwrite the portion of latents starting with start_token_idx with the condition + latents[:, start_token_idx:end_token_idx] = cond + conditioning_mask[:, start_token_idx:end_token_idx] = strength + clean_latents[:, start_token_idx:end_token_idx] = cond + + return latents, conditioning_mask, clean_latents + + def prepare_latents( + self, + conditions: LTX2VideoCondition | list[LTX2VideoCondition] | None = None, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 768, + num_frames: int = 121, + noise_scale: float = 1.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width) + mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width) + + if latents is not None: + # Latents are expected to be unpacked (5D) with shape [B, F, C, H, W] + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + else: + # NOTE: we set the initial latents to zeros rather a sample from the standard Gaussian prior because we + # will sample from the prior later once we have calculated the conditioning mask + latents = torch.zeros(shape, 
device=device, dtype=dtype) + + conditioning_mask = latents.new_zeros(mask_shape) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) # [B, seq_len, 1] + + if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape[:2]: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape[:2] + (num_channels_latents,)}." + ) + + if isinstance(generator, list): + logger.warning( + f"{self.__class__.__name__} does not support using a list of generators. The first generator in the" + f" list will be used for all (pseudo-)random operations." + ) + generator = generator[0] + + condition_frames, condition_strengths, condition_indices = self.preprocess_conditions( + conditions, height, width, num_frames, device=device + ) + condition_latents = [] + for condition_tensor in condition_frames: + condition_latent = retrieve_latents( + self.vae.encode(condition_tensor), generator=generator, sample_mode="argmax" + ) + condition_latent = self._normalize_latents( + condition_latent, self.vae.latents_mean, self.vae.latents_std + ).to(device=device, dtype=dtype) + condition_latent = self._pack_latents( + condition_latent, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + condition_latents.append(condition_latent) + + # NOTE: following the I2V pipeline, we return a conditioning mask. 
The original LTX 2 code uses a denoising + # mask, which is the inverse of the conditioning mask (`denoise_mask = 1 - conditioning_mask`) + latents, conditioning_mask, clean_latents = self.apply_visual_conditioning( + latents, + conditioning_mask, + condition_latents, + condition_strengths, + condition_indices, + latent_height=latent_height, + latent_width=latent_width, + ) + + # Sample from the standard Gaussian prior (or an intermediate Gaussian distribution if noise_scale < 1.0). + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + scaled_mask = (1.0 - conditioning_mask) * noise_scale + # Add noise to the `latents` so that it is at the noise level specified by `noise_scale`. + latents = noise * scaled_mask + latents * (1 - scaled_mask) + + return latents, conditioning_mask, clean_latents + + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value + num_mel_bins: int = 64, + noise_scale: float = 0.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + # latents expected to be unpacked (4D) with shape [B, C, L, M] + latents = self._pack_audio_latents(latents) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. 
Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_audio_latents(latents) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + conditions: LTX2VideoCondition | list[LTX2VideoCondition] | None = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: float = 24.0, + num_inference_steps: int = 40, + sigmas: list[float] | None = None, + timesteps: list[float] | None = None, + guidance_scale: float = 4.0, + guidance_rescale: float = 0.0, + noise_scale: float | None = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + audio_latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | list[float] | None = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + 
callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 1024, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + conditions (`List[LTXVideoCondition], *optional*`): + The list of frame-conditioning items for the video generation. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to `768`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, *optional*, defaults to `121`): + The number of video frames to generate + frame_rate (`float`, *optional*, defaults to `24.0`): + The frames per second (FPS) of the generated video. + num_inference_steps (`int`, *optional*, defaults to 40): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to `4.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). 
Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + noise_scale (`float`, *optional*, defaults to `None`): + The interpolation factor between random noise and denoised latents at each timestep. Applying noise to + the `latents` and `audio_latents` before continue denoising. If not set, will be inferred from the + sigma schedule. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + audio_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTX2PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. 
+ callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `1024`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTX2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTX2PipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + latents=latents, + audio_latents=audio_latents, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if conditions is not None and not isinstance(conditions, list): + conditions = [conditions] + + # Infer noise scale: first (largest) sigma value if using custom sigmas, else 1.0 + if noise_scale is None: + noise_scale = sigmas[0] if sigmas is not None else 1.0 + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + + # 4. 
Prepare latent variables + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + video_sequence_length = latent_num_frames * latent_height * latent_width + + num_channels_latents = self.transformer.config.in_channels + latents, conditioning_mask, clean_latents = self.prepare_latents( + conditions, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + noise_scale, + torch.float32, + device, + generator, + latents, + ) + if self.do_classifier_free_guidance: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_num_frames, mel_bins], `audio_num_frames` will be inferred." 
+ ) + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + audio_latents = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, + num_mel_bins=num_mel_bins, + noise_scale=noise_scale, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + + # For now, duplicate the scheduler for use with the audio latents + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. 
Prepare micro-conditions + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio / frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + # Duplicate the positional ids as well if using CFG + if self.do_classifier_free_guidance: + video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) # Repeat twice in batch dim + audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1)) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + + timestep = t.expand(latent_model_input.shape[0]) + video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask.squeeze(-1)) + + with self.transformer.cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + encoder_attention_mask=connector_attention_mask, + 
audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + # rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video = noise_pred_video.float() + noise_pred_audio = noise_pred_audio.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) + + noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) + + if self.guidance_rescale > 0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred_video = rescale_noise_cfg( + noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + ) + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + ) + + # NOTE: use only the first chunk of conditioning mask in case it is duplicated for CFG + bsz = noise_pred_video.size(0) + sigma = self.scheduler.sigmas[i] + # Convert the noise_pred_video velocity model prediction into a sample (x0) prediction + denoised_sample = latents - noise_pred_video * sigma + # Apply the (packed) conditioning mask to the denoised (x0) sample and clean conditioning. The + # conditioning mask contains conditioning strengths from 0 (always use denoised sample) to 1 (always + # use conditions), with intermediate values specifying how strongly to follow the conditions. 
+ denoised_sample_cond = ( + denoised_sample * (1 - conditioning_mask[:bsz]) + clean_latents.float() * conditioning_mask[:bsz] + ).to(noise_pred_video.dtype) + # Convert the denoised (x0) sample back to a velocity for the scheduler + denoised_latents_cond = ((latents - denoised_sample_cond) / sigma).to(noise_pred_video.dtype) + + # Compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(denoised_latents_cond, t, latents, return_dict=False)[0] + + # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in + # the step method (such as _step_index) + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + + if output_type == "latent": + video = latents + audio = audio_latents + else: + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = 
None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + audio_latents = audio_latents.to(self.audio_vae.dtype) + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video, audio) + + return LTX2PipelineOutput(frames=video, audio=audio) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index b86b5d2c6f4d..157b04ef266a 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2147,6 +2147,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LTX2ConditionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + 
requires_backends(cls, ["torch", "transformers"]) + + class LTX2ImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index 34427686394d..0c51b4b38f23 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -25,9 +25,9 @@ class VideoProcessor(VaeImageProcessor): r"""Simple video processor.""" - def preprocess_video(self, video, height: int | None = None, width: int | None = None) -> torch.Tensor: + def preprocess_video(self, video, height: int | None = None, width: int | None = None, **kwargs) -> torch.Tensor: r""" - Preprocesses input video(s). + Preprocesses input video(s). Keyword arguments will be forwarded to `VaeImageProcessor.preprocess`. Args: video (`list[PIL.Image]`, `list[list[PIL.Image]]`, `torch.Tensor`, `np.array`, `list[torch.Tensor]`, `list[np.array]`): @@ -49,6 +49,10 @@ def preprocess_video(self, video, height: int | None = None, width: int | None = width (`int`, *optional*`, defaults to `None`): The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get the default width. + + Returns: + `torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`: + A 5D tensor holding the batched channels-first video(s). """ if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: warnings.warn( @@ -79,7 +83,7 @@ def preprocess_video(self, video, height: int | None = None, width: int | None = "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" ) - video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0) + video = torch.stack([self.preprocess(img, height=height, width=width, **kwargs) for img in video], dim=0) # move the number of channels before the number of frames. 
video = video.permute(0, 2, 1, 3, 4) @@ -87,10 +91,11 @@ def preprocess_video(self, video, height: int | None = None, width: int | None = return video def postprocess_video( - self, video: torch.Tensor, output_type: str = "np" + self, video: torch.Tensor, output_type: str = "np", **kwargs ) -> np.ndarray | torch.Tensor | list[PIL.Image.Image]: r""" - Converts a video tensor to a list of frames for export. + Converts a video tensor to a list of frames for export. Keyword arguments will be forwarded to + `VaeImageProcessor.postprocess`. Args: video (`torch.Tensor`): The video as a tensor. @@ -100,7 +105,7 @@ def postprocess_video( outputs = [] for batch_idx in range(batch_size): batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = self.postprocess(batch_vid, output_type) + batch_output = self.postprocess(batch_vid, output_type, **kwargs) outputs.append(batch_output) if output_type == "np": From 607d147d4a04ff23ed7d124c4276b83b3228bc3d Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Thu, 5 Mar 2026 21:28:13 +0800 Subject: [PATCH 024/215] Fix Helios paper link in documentation (#13213) * Fix Helios paper link in documentation Updated the link to the Helios paper for accuracy. * Fix reference link in HeliosTransformer3DModel documentation Updated the reference link for the Helios Transformer model paper. 
* Update Helios research paper link in documentation * Update Helios research paper link in documentation --- docs/source/en/api/models/helios_transformer3d.md | 2 +- docs/source/en/api/pipelines/helios.md | 2 +- docs/source/en/using-diffusers/helios.md | 2 +- docs/source/zh/using-diffusers/helios.md | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/en/api/models/helios_transformer3d.md b/docs/source/en/api/models/helios_transformer3d.md index 5aa2826c32ec..302b91d6c829 100644 --- a/docs/source/en/api/models/helios_transformer3d.md +++ b/docs/source/en/api/models/helios_transformer3d.md @@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License. --> # HeliosTransformer3DModel -A 14B Real-Time Autogressive Diffusion Transformer model (support T2V, I2V and V2V) for 3D video-like data from [Helios](https://github.com/PKU-YuanGroup/Helios) was introduced in [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) by Peking University & ByteDance & etc. +A 14B Real-Time Autogressive Diffusion Transformer model (support T2V, I2V and V2V) for 3D video-like data from [Helios](https://github.com/PKU-YuanGroup/Helios) was introduced in [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/2603.04379) by Peking University & ByteDance & etc. The model can be loaded with the following code snippet. diff --git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md index 81559b24c071..54a08240001c 100644 --- a/docs/source/en/api/pipelines/helios.md +++ b/docs/source/en/api/pipelines/helios.md @@ -22,7 +22,7 @@ # Helios -[Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) from Peking University & ByteDance & etc, by Shenghai Yuan, Yuanyang Yin, Zongjian Li, Xinwei Huang, Xiao Yang, Li Yuan. 
+[Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/2603.04379) from Peking University & ByteDance & etc, by Shenghai Yuan, Yuanyang Yin, Zongjian Li, Xinwei Huang, Xiao Yang, Li Yuan. * We introduce Helios, the first 14B video generation model that runs at 17 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching a strong baseline in quality. We make breakthroughs along three key dimensions: (1) robustness to long-video drifting without commonly used anti-drift heuristics such as self-forcing, error banks, or keyframe sampling; (2) real-time generation without standard acceleration techniques such as KV-cache, causal masking, or sparse attention; and (3) training without parallelism or sharding frameworks, enabling image-diffusion-scale batch sizes while fitting up to four 14B models within 80 GB of GPU memory. Specifically, Helios is a 14B autoregressive diffusion model with a unified input representation that natively supports T2V, I2V, and V2V tasks. To mitigate drifting in long-video generation, we characterize its typical failure modes and propose simple yet effective training strategies that explicitly simulate drifting during training, while eliminating repetitive motion at its source. For efficiency, we heavily compress the historical and noisy context and reduce the number of sampling steps, yielding computational costs comparable to—or lower than—those of 1.3B video generative models. Moreover, we introduce infrastructure-level optimizations that accelerate both inference and training while reducing memory consumption. Extensive experiments demonstrate that Helios consistently outperforms prior methods on both short- and long-video generation. All the code and models are available at [this https URL](https://pku-yuangroup.github.io/Helios-Page). 
diff --git a/docs/source/en/using-diffusers/helios.md b/docs/source/en/using-diffusers/helios.md index 8106f1c568f8..ced7c6298f23 100644 --- a/docs/source/en/using-diffusers/helios.md +++ b/docs/source/en/using-diffusers/helios.md @@ -130,4 +130,4 @@ pipe.to("cuda") Learn more about Helios with the following resources. - Watch [video1](https://www.youtube.com/watch?v=vd_AgHtOUFQ) and [video2](https://www.youtube.com/watch?v=1GeIU2Dn7UY) for a demonstration of Helios's key features. -- The research paper, [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) for more details. +- The research paper, [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/2603.04379) for more details. diff --git a/docs/source/zh/using-diffusers/helios.md b/docs/source/zh/using-diffusers/helios.md index 5c4faed2ca2a..5f7f067eb781 100644 --- a/docs/source/zh/using-diffusers/helios.md +++ b/docs/source/zh/using-diffusers/helios.md @@ -131,4 +131,4 @@ pipe.to("cuda") 通过以下资源了解有关 Helios 的更多信息: - [视频1](https://www.youtube.com/watch?v=vd_AgHtOUFQ)和[视频2](https://www.youtube.com/watch?v=1GeIU2Dn7UY)演示了 Helios 的主要功能; -- 有关更多详细信息,请参阅研究论文 [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/)。 +- 有关更多详细信息,请参阅研究论文 [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/2603.04379)。 From e83d98d50059d84d5e73b1ada85ef08daa85e90e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 5 Mar 2026 19:23:07 +0530 Subject: [PATCH 025/215] [attention backends] change to updated repo and version. (#13161) * change to updated repo and version. * fix version and force updated kernels. * propagate version. 
--- src/diffusers/models/attention_dispatch.py | 23 +++++++++++++++------- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 16 +++++++++++++++ 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 1f3d7e072ab1..5b1f831ed060 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -38,6 +38,7 @@ is_flash_attn_available, is_flash_attn_version, is_kernels_available, + is_kernels_version, is_sageattention_available, is_sageattention_version, is_torch_npu_available, @@ -318,6 +319,7 @@ class _HubKernelConfig: repo_id: str function_attr: str revision: str | None = None + version: int | None = None kernel_fn: Callable | None = None wrapped_forward_attr: str | None = None wrapped_backward_attr: str | None = None @@ -327,31 +329,34 @@ class _HubKernelConfig: # Registry for hub-based attention kernels _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = { - # TODO: temporary revision for now. Remove when merged upstream into `main`. 
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig( repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", - revision="fake-ops-return-probs", wrapped_forward_attr="flash_attn_interface._flash_attn_forward", wrapped_backward_attr="flash_attn_interface._flash_attn_backward", + version=1, ), AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig( repo_id="kernels-community/flash-attn3", function_attr="flash_attn_varlen_func", - # revision="fake-ops-return-probs", + version=1, ), AttentionBackendName.FLASH_HUB: _HubKernelConfig( repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", - revision=None, wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward", wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward", + version=1, ), AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig( - repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None + repo_id="kernels-community/flash-attn2", + function_attr="flash_attn_varlen_func", + version=1, ), AttentionBackendName.SAGE_HUB: _HubKernelConfig( - repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None + repo_id="kernels-community/sage-attention", + function_attr="sageattn", + version=1, ), } @@ -521,6 +526,10 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None raise RuntimeError( f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`." ) + if not is_kernels_version(">=", "0.12"): + raise RuntimeError( + f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`." 
+ ) elif backend == AttentionBackendName.AITER: if not _CAN_USE_AITER_ATTN: @@ -694,7 +703,7 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None: try: from kernels import get_kernel - kernel_module = get_kernel(config.repo_id, revision=config.revision) + kernel_module = get_kernel(config.repo_id, revision=config.revision, version=config.version) if needs_kernel: config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index dd50405c74b2..23d7ac7c6c2d 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -86,6 +86,7 @@ is_inflect_available, is_invisible_watermark_available, is_kernels_available, + is_kernels_version, is_kornia_available, is_librosa_available, is_matplotlib_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 8fb481946ebf..551fa358a28d 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -724,6 +724,22 @@ def is_transformers_version(operation: str, version: str): return compare_versions(parse(_transformers_version), operation, version) +@cache +def is_kernels_version(operation: str, version: str): + """ + Compares the current Kernels version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _kernels_available: + return False + return compare_versions(parse(_kernels_version), operation, version) + + @cache def is_hf_hub_version(operation: str, version: str): """ From 6eab5de5fd2920564a7df9bf2afa1ca05743929a Mon Sep 17 00:00:00 2001 From: Ando <2974376016@qq.com> Date: Thu, 5 Mar 2026 22:47:14 +0800 Subject: [PATCH 026/215] feat: implement rae autoencoder. 
(#13046) * feat: implement three RAE encoders(dinov2, siglip2, mae) * feat: finish first version of autoencoder_rae * fix formatting * make fix-copies * initial doc * fix latent_mean / latent_var init types to accept config-friendly inputs * use mean and std convention * cleanup * add rae to diffusers script * use imports * use attention * remove unneeded class * example training script * input and ground truth sizes have to be the same * fix argument * move loss to training script * cleanup * simplify mixins * fix training script * fix entrypoint for instantiating the AutoencoderRAE * added encoder_image_size config * undo last change * fixes from pretrained weights * cleanups * address reviews * fix train script to use pretrained * fix conversion script review * latent normalization buffers are now always registered with no-op defaults * Update examples/research_projects/autoencoder_rae/README.md Co-authored-by: Sayak Paul * Update src/diffusers/models/autoencoders/autoencoder_rae.py Co-authored-by: Sayak Paul * use image url * Encoder is frozen * fix slow test * remove config * use ModelTesterMixin and AutoencoderTesterMixin * make quality * strip final layernorm when converting * _strip_final_layernorm_affine for training script * fix test * add dispatch forward and update conversion script * update training script * error out as soon as possible and add comments * Update src/diffusers/models/autoencoders/autoencoder_rae.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * use buffer * inline * Update src/diffusers/models/autoencoders/autoencoder_rae.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * remove optional * _noising takes a generator * Update src/diffusers/models/autoencoders/autoencoder_rae.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * fix api * rename * remove unittest * use randn_tensor * fix device map on multigpu * check if the key is missing in the original state dict and only then add to
the allow_missing set * remove initialize_weights --------- Co-authored-by: wangyuqi Co-authored-by: Kashif Rasul Co-authored-by: Sayak Paul Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/models/autoencoder_rae.md | 89 +++ .../autoencoder_rae/README.md | 66 ++ .../autoencoder_rae/train_autoencoder_rae.py | 405 ++++++++++ scripts/convert_rae_to_diffusers.py | 406 +++++++++++ src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/autoencoders/__init__.py | 1 + .../models/autoencoders/autoencoder_rae.py | 689 ++++++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 15 + .../test_models_autoencoder_rae.py | 300 ++++++++ 11 files changed, 1977 insertions(+) create mode 100644 docs/source/en/api/models/autoencoder_rae.md create mode 100644 examples/research_projects/autoencoder_rae/README.md create mode 100644 examples/research_projects/autoencoder_rae/train_autoencoder_rae.py create mode 100644 scripts/convert_rae_to_diffusers.py create mode 100644 src/diffusers/models/autoencoders/autoencoder_rae.py create mode 100644 tests/models/autoencoders/test_models_autoencoder_rae.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index ea06f35a0343..e0b7af4898b2 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -460,6 +460,8 @@ title: AutoencoderKLQwenImage - local: api/models/autoencoder_kl_wan title: AutoencoderKLWan + - local: api/models/autoencoder_rae + title: AutoencoderRAE - local: api/models/consistency_decoder_vae title: ConsistencyDecoderVAE - local: api/models/autoencoder_oobleck diff --git a/docs/source/en/api/models/autoencoder_rae.md b/docs/source/en/api/models/autoencoder_rae.md new file mode 100644 index 000000000000..a8c00dd4fde2 --- /dev/null +++ b/docs/source/en/api/models/autoencoder_rae.md @@ -0,0 +1,89 @@ + + +# AutoencoderRAE + +The Representation Autoencoder (RAE) model introduced in 
[Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690) by Boyang Zheng, Nanye Ma, Shengbang Tong, Saining Xie from NYU VISIONx. + +RAE combines a frozen pretrained vision encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT-MAE-style decoder. In the two-stage RAE training recipe, the autoencoder is trained in stage 1 (reconstruction), and then a diffusion model is trained on the resulting latent space in stage 2 (generation). + +The following RAE models are released and supported in Diffusers: + +| Model | Encoder | Latent shape (224px input) | +|:------|:--------|:---------------------------| +| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08) | DINOv2-base | 768 x 16 x 16 | +| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512) | DINOv2-base (512px) | 768 x 32 x 32 | +| [`nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08) | DINOv2-small | 384 x 16 x 16 | +| [`nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08) | DINOv2-large | 1024 x 16 x 16 | +| [`nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08) | SigLIP2-base | 768 x 16 x 16 | +| [`nyu-visionx/RAE-mae-base-p16-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-mae-base-p16-ViTXL-n08) | MAE-base | 768 x 16 x 16 | + +## Loading a pretrained model + +```python +from diffusers import AutoencoderRAE + +model = AutoencoderRAE.from_pretrained( + "nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08" +).to("cuda").eval() +``` + +## Encoding and decoding a real image + +```python +import torch +from diffusers import AutoencoderRAE +from diffusers.utils import load_image +from torchvision.transforms.functional import to_tensor, to_pil_image + +model = 
AutoencoderRAE.from_pretrained( + "nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08" +).to("cuda").eval() + +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png") +image = image.convert("RGB").resize((224, 224)) +x = to_tensor(image).unsqueeze(0).to("cuda") # (1, 3, 224, 224), values in [0, 1] + +with torch.no_grad(): + latents = model.encode(x).latent # (1, 768, 16, 16) + recon = model.decode(latents).sample # (1, 3, 256, 256) + +recon_image = to_pil_image(recon[0].clamp(0, 1).cpu()) +recon_image.save("recon.png") +``` + +## Latent normalization + +Some pretrained checkpoints include per-channel `latents_mean` and `latents_std` statistics for normalizing the latent space. When present, `encode` and `decode` automatically apply the normalization and denormalization, respectively. + +```python +model = AutoencoderRAE.from_pretrained( + "nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08" +).to("cuda").eval() + +# Latent normalization is handled automatically inside encode/decode +# when the checkpoint config includes latents_mean/latents_std. +with torch.no_grad(): + latents = model.encode(x).latent # normalized latents + recon = model.decode(latents).sample +``` + +## AutoencoderRAE + +[[autodoc]] AutoencoderRAE + - encode + - decode + - all + +## DecoderOutput + +[[autodoc]] models.autoencoders.vae.DecoderOutput diff --git a/examples/research_projects/autoencoder_rae/README.md b/examples/research_projects/autoencoder_rae/README.md new file mode 100644 index 000000000000..9ade979090d9 --- /dev/null +++ b/examples/research_projects/autoencoder_rae/README.md @@ -0,0 +1,66 @@ +# Training AutoencoderRAE + +This example trains the decoder of `AutoencoderRAE` (stage-1 style), while keeping the representation encoder frozen. 
+ +It follows the same high-level training recipe as the official RAE stage-1 setup: +- frozen encoder +- train decoder +- pixel reconstruction loss +- optional encoder feature consistency loss + +## Quickstart + +### Resume or finetune from pretrained weights + +```bash +accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \ + --pretrained_model_name_or_path nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08 \ + --train_data_dir /path/to/imagenet_like_folder \ + --output_dir /tmp/autoencoder-rae \ + --resolution 256 \ + --train_batch_size 8 \ + --learning_rate 1e-4 \ + --num_train_epochs 10 \ + --report_to wandb \ + --reconstruction_loss_type l1 \ + --use_encoder_loss \ + --encoder_loss_weight 0.1 +``` + +### Train from scratch with a pretrained encoder +The following command launches RAE training with "facebook/dinov2-with-registers-base" as the base. + +```bash +accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \ + --train_data_dir /path/to/imagenet_like_folder \ + --output_dir /tmp/autoencoder-rae \ + --resolution 256 \ + --encoder_type dinov2 \ + --encoder_name_or_path facebook/dinov2-with-registers-base \ + --encoder_input_size 224 \ + --patch_size 16 \ + --image_size 256 \ + --decoder_hidden_size 1152 \ + --decoder_num_hidden_layers 28 \ + --decoder_num_attention_heads 16 \ + --decoder_intermediate_size 4096 \ + --train_batch_size 8 \ + --learning_rate 1e-4 \ + --num_train_epochs 10 \ + --report_to wandb \ + --reconstruction_loss_type l1 \ + --use_encoder_loss \ + --encoder_loss_weight 0.1 +``` + +Note: stage-1 reconstruction loss assumes matching target/output spatial size, so `--resolution` must equal `--image_size`. 
+ +Dataset format is expected to be `ImageFolder`-compatible: + +```text +train_data_dir/ + class_a/ + img_0001.jpg + class_b/ + img_0002.jpg +``` diff --git a/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py b/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py new file mode 100644 index 000000000000..ea02c674bc0c --- /dev/null +++ b/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py @@ -0,0 +1,405 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import logging +import math +import os +from pathlib import Path + +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from torch.utils.data import DataLoader +from torchvision import transforms +from torchvision.datasets import ImageFolder +from tqdm.auto import tqdm + +from diffusers import AutoencoderRAE +from diffusers.optimization import get_scheduler + + +logger = get_logger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train a stage-1 Representation Autoencoder (RAE) decoder.") + parser.add_argument( + "--train_data_dir", + type=str, + required=True, + help="Path to an ImageFolder-style dataset root.", + ) + parser.add_argument( + "--output_dir", type=str, default="autoencoder-rae", help="Directory to save checkpoints/model." + ) + parser.add_argument("--logging_dir", type=str, default="logs", help="Accelerate logging directory.") + parser.add_argument("--seed", type=int, default=42) + + parser.add_argument("--resolution", type=int, default=256) + parser.add_argument("--center_crop", action="store_true") + parser.add_argument("--random_flip", action="store_true") + + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--dataloader_num_workers", type=int, default=4) + parser.add_argument("--num_train_epochs", type=int, default=10) + parser.add_argument("--max_train_steps", type=int, default=None) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--max_grad_norm", type=float, default=1.0) + + parser.add_argument("--learning_rate", type=float, default=1e-4) + parser.add_argument("--adam_beta1", type=float, default=0.9) + parser.add_argument("--adam_beta2", type=float, default=0.999) + parser.add_argument("--adam_weight_decay", type=float, default=1e-2) + parser.add_argument("--adam_epsilon", 
type=float, default=1e-8) + parser.add_argument("--lr_scheduler", type=str, default="cosine") + parser.add_argument("--lr_warmup_steps", type=int, default=500) + + parser.add_argument("--checkpointing_steps", type=int, default=1000) + parser.add_argument("--validation_steps", type=int, default=500) + + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + help="Path to a pretrained AutoencoderRAE model (or HF Hub id) to resume training from.", + ) + parser.add_argument( + "--encoder_name_or_path", + type=str, + default=None, + help=( + "HF Hub id or local path of the pretrained encoder (e.g. 'facebook/dinov2-with-registers-base'). " + "When --pretrained_model_name_or_path is not set, the encoder weights are loaded from this path " + "into a freshly constructed AutoencoderRAE. Ignored when --pretrained_model_name_or_path is set." + ), + ) + + parser.add_argument("--encoder_type", type=str, choices=["dinov2", "siglip2", "mae"], default="dinov2") + parser.add_argument("--encoder_hidden_size", type=int, default=768) + parser.add_argument("--encoder_patch_size", type=int, default=14) + parser.add_argument("--encoder_num_hidden_layers", type=int, default=12) + parser.add_argument("--encoder_input_size", type=int, default=224) + parser.add_argument("--patch_size", type=int, default=16) + parser.add_argument("--image_size", type=int, default=256) + parser.add_argument("--num_channels", type=int, default=3) + + parser.add_argument("--decoder_hidden_size", type=int, default=1152) + parser.add_argument("--decoder_num_hidden_layers", type=int, default=28) + parser.add_argument("--decoder_num_attention_heads", type=int, default=16) + parser.add_argument("--decoder_intermediate_size", type=int, default=4096) + + parser.add_argument("--noise_tau", type=float, default=0.0) + parser.add_argument("--scaling_factor", type=float, default=1.0) + parser.add_argument("--reshape_to_2d", action=argparse.BooleanOptionalAction, default=True) + + 
parser.add_argument( + "--reconstruction_loss_type", + type=str, + choices=["l1", "mse"], + default="l1", + help="Pixel reconstruction loss.", + ) + parser.add_argument( + "--encoder_loss_weight", + type=float, + default=0.0, + help="Weight for encoder feature consistency loss in the training loop.", + ) + parser.add_argument( + "--use_encoder_loss", + action="store_true", + help="Enable encoder feature consistency loss term in the training loop.", + ) + parser.add_argument("--report_to", type=str, default="tensorboard") + + return parser.parse_args() + + +def build_transforms(args): + image_transforms = [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + ] + if args.random_flip: + image_transforms.append(transforms.RandomHorizontalFlip()) + image_transforms.append(transforms.ToTensor()) + return transforms.Compose(image_transforms) + + +def compute_losses( + model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float +): + decoded = model(pixel_values).sample + + if decoded.shape[-2:] != pixel_values.shape[-2:]: + raise ValueError( + "Training requires matching reconstruction and target sizes, got " + f"decoded={tuple(decoded.shape[-2:])}, target={tuple(pixel_values.shape[-2:])}." 
+ ) + + if reconstruction_loss_type == "l1": + reconstruction_loss = F.l1_loss(decoded.float(), pixel_values.float()) + else: + reconstruction_loss = F.mse_loss(decoded.float(), pixel_values.float()) + + encoder_loss = torch.zeros_like(reconstruction_loss) + if use_encoder_loss and encoder_loss_weight > 0: + base_model = model.module if hasattr(model, "module") else model + target_encoder_input = base_model._resize_and_normalize(pixel_values) + reconstructed_encoder_input = base_model._resize_and_normalize(decoded) + + encoder_forward_kwargs = {"model": base_model.encoder} + if base_model.config.encoder_type == "mae": + encoder_forward_kwargs["patch_size"] = base_model.config.encoder_patch_size + with torch.no_grad(): + target_tokens = base_model._encoder_forward_fn(images=target_encoder_input, **encoder_forward_kwargs) + reconstructed_tokens = base_model._encoder_forward_fn( + images=reconstructed_encoder_input, **encoder_forward_kwargs + ) + encoder_loss = F.mse_loss(reconstructed_tokens.float(), target_tokens.float()) + + loss = reconstruction_loss + float(encoder_loss_weight) * encoder_loss + return decoded, loss, reconstruction_loss, encoder_loss + + +def _strip_final_layernorm_affine(state_dict, prefix=""): + """Remove final layernorm weight/bias so the model keeps its default init (identity).""" + keys_to_strip = {f"{prefix}weight", f"{prefix}bias"} + return {k: v for k, v in state_dict.items() if k not in keys_to_strip} + + +def _load_pretrained_encoder_weights(model, encoder_type, encoder_name_or_path): + """Load pretrained HF transformers encoder weights into the model's encoder.""" + if encoder_type == "dinov2": + from transformers import Dinov2WithRegistersModel + + hf_encoder = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path) + state_dict = hf_encoder.state_dict() + state_dict = _strip_final_layernorm_affine(state_dict, prefix="layernorm.") + elif encoder_type == "siglip2": + from transformers import SiglipModel + + hf_encoder = 
SiglipModel.from_pretrained(encoder_name_or_path).vision_model + state_dict = {f"vision_model.{k}": v for k, v in hf_encoder.state_dict().items()} + state_dict = _strip_final_layernorm_affine(state_dict, prefix="vision_model.post_layernorm.") + elif encoder_type == "mae": + from transformers import ViTMAEForPreTraining + + hf_encoder = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit + state_dict = hf_encoder.state_dict() + state_dict = _strip_final_layernorm_affine(state_dict, prefix="layernorm.") + else: + raise ValueError(f"Unknown encoder_type: {encoder_type}") + + model.encoder.load_state_dict(state_dict, strict=False) + + +def main(): + args = parse_args() + if args.resolution != args.image_size: + raise ValueError( + f"`--resolution` ({args.resolution}) must match `--image_size` ({args.image_size}) " + "for stage-1 reconstruction loss." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + project_config=accelerator_project_config, + log_with=args.report_to, + ) + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + + if args.seed is not None: + set_seed(args.seed) + + if accelerator.is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + dataset = ImageFolder(args.train_data_dir, transform=build_transforms(args)) + + def collate_fn(examples): + pixel_values = torch.stack([example[0] for example in examples]).float() + return {"pixel_values": pixel_values} + + train_dataloader = DataLoader( + dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + pin_memory=True, + 
drop_last=True, + ) + + if args.pretrained_model_name_or_path is not None: + model = AutoencoderRAE.from_pretrained(args.pretrained_model_name_or_path) + logger.info(f"Loaded pretrained AutoencoderRAE from {args.pretrained_model_name_or_path}") + else: + model = AutoencoderRAE( + encoder_type=args.encoder_type, + encoder_hidden_size=args.encoder_hidden_size, + encoder_patch_size=args.encoder_patch_size, + encoder_num_hidden_layers=args.encoder_num_hidden_layers, + decoder_hidden_size=args.decoder_hidden_size, + decoder_num_hidden_layers=args.decoder_num_hidden_layers, + decoder_num_attention_heads=args.decoder_num_attention_heads, + decoder_intermediate_size=args.decoder_intermediate_size, + patch_size=args.patch_size, + encoder_input_size=args.encoder_input_size, + image_size=args.image_size, + num_channels=args.num_channels, + noise_tau=args.noise_tau, + reshape_to_2d=args.reshape_to_2d, + use_encoder_loss=args.use_encoder_loss, + scaling_factor=args.scaling_factor, + ) + if args.encoder_name_or_path is not None: + _load_pretrained_encoder_weights(model, args.encoder_type, args.encoder_name_or_path) + logger.info(f"Loaded pretrained encoder weights from {args.encoder_name_or_path}") + model.encoder.requires_grad_(False) + model.decoder.requires_grad_(True) + model.train() + + optimizer = torch.optim.AdamW( + (p for p in model.parameters() if p.requires_grad), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * 
accelerator.num_processes, + ) + + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + if overrode_max_train_steps: + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + if accelerator.is_main_process: + accelerator.init_trackers("train_autoencoder_rae", config=vars(args)) + + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + global_step = 0 + + for epoch in range(args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(model): + pixel_values = batch["pixel_values"] + + _, loss, reconstruction_loss, encoder_loss = compute_losses( + model, + pixel_values, + reconstruction_loss_type=args.reconstruction_loss_type, + use_encoder_loss=args.use_encoder_loss, + encoder_loss_weight=args.encoder_loss_weight, + ) + + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + logs = { + "loss": loss.detach().item(), + 
"reconstruction_loss": reconstruction_loss.detach().item(), + "encoder_loss": encoder_loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step % args.validation_steps == 0: + with torch.no_grad(): + _, val_loss, val_reconstruction_loss, val_encoder_loss = compute_losses( + model, + pixel_values, + reconstruction_loss_type=args.reconstruction_loss_type, + use_encoder_loss=args.use_encoder_loss, + encoder_loss_weight=args.encoder_loss_weight, + ) + accelerator.log( + { + "val/loss": val_loss.detach().item(), + "val/reconstruction_loss": val_reconstruction_loss.detach().item(), + "val/encoder_loss": val_encoder_loss.detach().item(), + }, + step=global_step, + ) + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained(save_path) + logger.info(f"Saved checkpoint to {save_path}") + + if global_step >= args.max_train_steps: + break + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained(args.output_dir) + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/scripts/convert_rae_to_diffusers.py b/scripts/convert_rae_to_diffusers.py new file mode 100644 index 000000000000..0502e49ef30c --- /dev/null +++ b/scripts/convert_rae_to_diffusers.py @@ -0,0 +1,406 @@ +import argparse +from pathlib import Path +from typing import Any + +import torch +from huggingface_hub import HfApi, hf_hub_download + +from diffusers import AutoencoderRAE + + +DECODER_CONFIGS = { + "ViTB": { + "decoder_hidden_size": 768, + "decoder_intermediate_size": 3072, + "decoder_num_attention_heads": 12, + "decoder_num_hidden_layers": 12, + }, + 
"ViTL": { + "decoder_hidden_size": 1024, + "decoder_intermediate_size": 4096, + "decoder_num_attention_heads": 16, + "decoder_num_hidden_layers": 24, + }, + "ViTXL": { + "decoder_hidden_size": 1152, + "decoder_intermediate_size": 4096, + "decoder_num_attention_heads": 16, + "decoder_num_hidden_layers": 28, + }, +} + +ENCODER_DEFAULT_NAME_OR_PATH = { + "dinov2": "facebook/dinov2-with-registers-base", + "siglip2": "google/siglip2-base-patch16-256", + "mae": "facebook/vit-mae-base", +} + +ENCODER_HIDDEN_SIZE = { + "dinov2": 768, + "siglip2": 768, + "mae": 768, +} + +ENCODER_PATCH_SIZE = { + "dinov2": 14, + "siglip2": 16, + "mae": 16, +} + +DEFAULT_DECODER_SUBDIR = { + "dinov2": "decoders/dinov2/wReg_base", + "mae": "decoders/mae/base_p16", + "siglip2": "decoders/siglip2/base_p16_i256", +} + +DEFAULT_STATS_SUBDIR = { + "dinov2": "stats/dinov2/wReg_base", + "mae": "stats/mae/base_p16", + "siglip2": "stats/siglip2/base_p16_i256", +} + +DECODER_FILE_CANDIDATES = ("dinov2_decoder.pt", "model.pt") +STATS_FILE_CANDIDATES = ("stat.pt",) + + +def dataset_case_candidates(name: str) -> tuple[str, ...]: + return (name, name.lower(), name.upper(), name.title(), "imagenet1k", "ImageNet1k") + + +class RepoAccessor: + def __init__(self, repo_or_path: str, cache_dir: str | None = None): + self.repo_or_path = repo_or_path + self.cache_dir = cache_dir + self.local_root: Path | None = None + self.repo_id: str | None = None + self.repo_files: set[str] | None = None + + root = Path(repo_or_path) + if root.exists() and root.is_dir(): + self.local_root = root + else: + self.repo_id = repo_or_path + self.repo_files = set(HfApi().list_repo_files(repo_or_path)) + + def exists(self, relative_path: str) -> bool: + relative_path = relative_path.replace("\\", "/") + if self.local_root is not None: + return (self.local_root / relative_path).is_file() + return relative_path in self.repo_files + + def fetch(self, relative_path: str) -> Path: + relative_path = relative_path.replace("\\", "/") + if 
self.local_root is not None: + return self.local_root / relative_path + downloaded = hf_hub_download(repo_id=self.repo_id, filename=relative_path, cache_dir=self.cache_dir) + return Path(downloaded) + + +def unwrap_state_dict(maybe_wrapped: dict[str, Any]) -> dict[str, Any]: + state_dict = maybe_wrapped + for k in ("model", "module", "state_dict"): + if isinstance(state_dict, dict) and k in state_dict and isinstance(state_dict[k], dict): + state_dict = state_dict[k] + + out = dict(state_dict) + if len(out) > 0 and all(key.startswith("module.") for key in out): + out = {key[len("module.") :]: value for key, value in out.items()} + if len(out) > 0 and all(key.startswith("decoder.") for key in out): + out = {key[len("decoder.") :]: value for key, value in out.items()} + return out + + +def remap_decoder_attention_keys_for_diffusers(state_dict: dict[str, Any]) -> dict[str, Any]: + """ + Map official RAE decoder attention key layout to diffusers Attention layout used by AutoencoderRAE decoder. 
+ + Example mappings: + - `...attention.attention.query.*` -> `...attention.to_q.*` + - `...attention.attention.key.*` -> `...attention.to_k.*` + - `...attention.attention.value.*` -> `...attention.to_v.*` + - `...attention.output.dense.*` -> `...attention.to_out.0.*` + """ + remapped: dict[str, Any] = {} + for key, value in state_dict.items(): + new_key = key + new_key = new_key.replace(".attention.attention.query.", ".attention.to_q.") + new_key = new_key.replace(".attention.attention.key.", ".attention.to_k.") + new_key = new_key.replace(".attention.attention.value.", ".attention.to_v.") + new_key = new_key.replace(".attention.output.dense.", ".attention.to_out.0.") + remapped[new_key] = value + return remapped + + +def resolve_decoder_file( + accessor: RepoAccessor, encoder_type: str, variant: str, decoder_checkpoint: str | None +) -> str: + if decoder_checkpoint is not None: + if accessor.exists(decoder_checkpoint): + return decoder_checkpoint + raise FileNotFoundError(f"Decoder checkpoint not found: {decoder_checkpoint}") + + base = f"{DEFAULT_DECODER_SUBDIR[encoder_type]}/{variant}" + for name in DECODER_FILE_CANDIDATES: + candidate = f"{base}/{name}" + if accessor.exists(candidate): + return candidate + + raise FileNotFoundError( + f"Could not find decoder checkpoint under `{base}`. 
Tried: {list(DECODER_FILE_CANDIDATES)}" + ) + + +def resolve_stats_file( + accessor: RepoAccessor, + encoder_type: str, + dataset_name: str, + stats_checkpoint: str | None, +) -> str | None: + if stats_checkpoint is not None: + if accessor.exists(stats_checkpoint): + return stats_checkpoint + raise FileNotFoundError(f"Stats checkpoint not found: {stats_checkpoint}") + + base = DEFAULT_STATS_SUBDIR[encoder_type] + for dataset in dataset_case_candidates(dataset_name): + for name in STATS_FILE_CANDIDATES: + candidate = f"{base}/{dataset}/{name}" + if accessor.exists(candidate): + return candidate + + return None + + +def extract_latent_stats(stats_obj: Any) -> tuple[Any | None, Any | None]: + if not isinstance(stats_obj, dict): + return None, None + + if "latents_mean" in stats_obj or "latents_std" in stats_obj: + return stats_obj.get("latents_mean", None), stats_obj.get("latents_std", None) + + mean = stats_obj.get("mean", None) + var = stats_obj.get("var", None) + if mean is None and var is None: + return None, None + + latents_std = None + if var is not None: + if isinstance(var, torch.Tensor): + latents_std = torch.sqrt(var + 1e-5) + else: + latents_std = torch.sqrt(torch.tensor(var) + 1e-5) + return mean, latents_std + + +def _strip_final_layernorm_affine(state_dict: dict[str, Any], prefix: str = "") -> dict[str, Any]: + """Remove final layernorm weight/bias from encoder state dict. + + RAE uses non-affine layernorm (weight=1, bias=0 is the default identity). + Stripping these keys means the model keeps its default init values, which + is functionally equivalent to setting elementwise_affine=False. 
+ """ + keys_to_strip = {f"{prefix}weight", f"{prefix}bias"} + return {k: v for k, v in state_dict.items() if k not in keys_to_strip} + + +def _load_hf_encoder_state_dict(encoder_type: str, encoder_name_or_path: str) -> dict[str, Any]: + """Download the HF encoder and extract the state dict for the inner model.""" + if encoder_type == "dinov2": + from transformers import Dinov2WithRegistersModel + + hf_model = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path) + sd = hf_model.state_dict() + return _strip_final_layernorm_affine(sd, prefix="layernorm.") + elif encoder_type == "siglip2": + from transformers import SiglipModel + + # SiglipModel.vision_model is a SiglipVisionTransformer. + # Our Siglip2Encoder wraps it inside SiglipVisionModel which nests it + # under .vision_model, so we add the prefix to match the diffusers key layout. + hf_model = SiglipModel.from_pretrained(encoder_name_or_path).vision_model + sd = {f"vision_model.{k}": v for k, v in hf_model.state_dict().items()} + return _strip_final_layernorm_affine(sd, prefix="vision_model.post_layernorm.") + elif encoder_type == "mae": + from transformers import ViTMAEForPreTraining + + hf_model = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit + sd = hf_model.state_dict() + return _strip_final_layernorm_affine(sd, prefix="layernorm.") + else: + raise ValueError(f"Unknown encoder_type: {encoder_type}") + + +def convert(args: argparse.Namespace) -> None: + accessor = RepoAccessor(args.repo_or_path, cache_dir=args.cache_dir) + encoder_name_or_path = args.encoder_name_or_path or ENCODER_DEFAULT_NAME_OR_PATH[args.encoder_type] + + decoder_relpath = resolve_decoder_file(accessor, args.encoder_type, args.variant, args.decoder_checkpoint) + stats_relpath = resolve_stats_file(accessor, args.encoder_type, args.dataset_name, args.stats_checkpoint) + + print(f"Using decoder checkpoint: {decoder_relpath}") + if stats_relpath is not None: + print(f"Using stats checkpoint: {stats_relpath}") + else: 
+ print("No stats checkpoint found; conversion will proceed without latent stats.") + + if args.dry_run: + return + + decoder_path = accessor.fetch(decoder_relpath) + decoder_obj = torch.load(decoder_path, map_location="cpu") + decoder_state_dict = unwrap_state_dict(decoder_obj) + decoder_state_dict = remap_decoder_attention_keys_for_diffusers(decoder_state_dict) + + latents_mean, latents_std = None, None + if stats_relpath is not None: + stats_path = accessor.fetch(stats_relpath) + stats_obj = torch.load(stats_path, map_location="cpu") + latents_mean, latents_std = extract_latent_stats(stats_obj) + + decoder_cfg = DECODER_CONFIGS[args.decoder_config_name] + + # Read encoder normalization stats from the HF image processor (only place that downloads encoder info) + from transformers import AutoConfig, AutoImageProcessor + + proc = AutoImageProcessor.from_pretrained(encoder_name_or_path) + encoder_norm_mean = list(proc.image_mean) + encoder_norm_std = list(proc.image_std) + + # Read encoder hidden size and patch size from HF config + encoder_hidden_size = ENCODER_HIDDEN_SIZE[args.encoder_type] + encoder_patch_size = ENCODER_PATCH_SIZE[args.encoder_type] + try: + hf_config = AutoConfig.from_pretrained(encoder_name_or_path) + # For models like SigLIP that nest vision config + if hasattr(hf_config, "vision_config"): + hf_config = hf_config.vision_config + encoder_hidden_size = hf_config.hidden_size + encoder_patch_size = hf_config.patch_size + except Exception: + pass + + # Load the actual encoder weights from HF to include in the saved model + encoder_state_dict = _load_hf_encoder_state_dict(args.encoder_type, encoder_name_or_path) + + # Build model on meta device to avoid double init overhead + with torch.device("meta"): + model = AutoencoderRAE( + encoder_type=args.encoder_type, + encoder_hidden_size=encoder_hidden_size, + encoder_patch_size=encoder_patch_size, + encoder_input_size=args.encoder_input_size, + patch_size=args.patch_size, + image_size=args.image_size, + 
num_channels=args.num_channels, + encoder_norm_mean=encoder_norm_mean, + encoder_norm_std=encoder_norm_std, + decoder_hidden_size=decoder_cfg["decoder_hidden_size"], + decoder_num_hidden_layers=decoder_cfg["decoder_num_hidden_layers"], + decoder_num_attention_heads=decoder_cfg["decoder_num_attention_heads"], + decoder_intermediate_size=decoder_cfg["decoder_intermediate_size"], + latents_mean=latents_mean, + latents_std=latents_std, + scaling_factor=args.scaling_factor, + ) + + # Assemble full state dict and load with assign=True + full_state_dict = {} + + # Encoder weights (prefixed with "encoder.") + for k, v in encoder_state_dict.items(): + full_state_dict[f"encoder.{k}"] = v + + # Decoder weights (prefixed with "decoder.") + for k, v in decoder_state_dict.items(): + full_state_dict[f"decoder.{k}"] = v + + # Buffers from config + full_state_dict["encoder_mean"] = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1) + full_state_dict["encoder_std"] = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1) + if latents_mean is not None: + latents_mean_t = latents_mean if isinstance(latents_mean, torch.Tensor) else torch.tensor(latents_mean) + full_state_dict["_latents_mean"] = latents_mean_t + else: + full_state_dict["_latents_mean"] = torch.zeros(1) + if latents_std is not None: + latents_std_t = latents_std if isinstance(latents_std, torch.Tensor) else torch.tensor(latents_std) + full_state_dict["_latents_std"] = latents_std_t + else: + full_state_dict["_latents_std"] = torch.ones(1) + + model.load_state_dict(full_state_dict, strict=False, assign=True) + + # Verify no critical keys are missing + model_keys = {name for name, _ in model.named_parameters()} + model_keys |= {name for name, _ in model.named_buffers()} + loaded_keys = set(full_state_dict.keys()) + missing = model_keys - loaded_keys + # decoder_pos_embed is initialized in-model. 
trainable_cls_token is only + # allowed to be missing if it was absent in the source decoder checkpoint. + allowed_missing = {"decoder.decoder_pos_embed"} + if "trainable_cls_token" not in decoder_state_dict: + allowed_missing.add("decoder.trainable_cls_token") + if missing - allowed_missing: + print(f"Warning: missing keys after conversion: {sorted(missing - allowed_missing)}") + + output_path = Path(args.output_path) + output_path.mkdir(parents=True, exist_ok=True) + model.save_pretrained(output_path) + + if args.verify_load: + print("Verifying converted checkpoint with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False)...") + loaded_model = AutoencoderRAE.from_pretrained(output_path, low_cpu_mem_usage=False) + if not isinstance(loaded_model, AutoencoderRAE): + raise RuntimeError("Verification failed: loaded object is not AutoencoderRAE.") + print("Verification passed.") + + print(f"Saved converted AutoencoderRAE to: {output_path}") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Convert RAE decoder checkpoints to diffusers AutoencoderRAE format") + parser.add_argument( + "--repo_or_path", type=str, required=True, help="Hub repo id (e.g. 
nyu-visionx/RAE-collections) or local path" + ) + parser.add_argument("--output_path", type=str, required=True, help="Directory to save converted model") + + parser.add_argument("--encoder_type", type=str, choices=["dinov2", "mae", "siglip2"], required=True) + parser.add_argument( + "--encoder_name_or_path", type=str, default=None, help="Optional encoder HF model id or local path override" + ) + + parser.add_argument("--variant", type=str, default="ViTXL_n08", help="Decoder variant folder name") + parser.add_argument("--dataset_name", type=str, default="imagenet1k", help="Stats dataset folder name") + + parser.add_argument( + "--decoder_checkpoint", type=str, default=None, help="Relative path to decoder checkpoint inside repo/path" + ) + parser.add_argument( + "--stats_checkpoint", type=str, default=None, help="Relative path to stats checkpoint inside repo/path" + ) + + parser.add_argument("--decoder_config_name", type=str, choices=list(DECODER_CONFIGS.keys()), default="ViTXL") + parser.add_argument("--encoder_input_size", type=int, default=224) + parser.add_argument("--patch_size", type=int, default=16) + parser.add_argument("--image_size", type=int, default=None) + parser.add_argument("--num_channels", type=int, default=3) + parser.add_argument("--scaling_factor", type=float, default=1.0) + + parser.add_argument("--cache_dir", type=str, default=None) + parser.add_argument("--dry_run", action="store_true", help="Only resolve and print selected files") + parser.add_argument( + "--verify_load", + action="store_true", + help="After conversion, load back with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False).", + ) + + return parser.parse_args() + + +def main() -> None: + args = parse_args() + convert(args) + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ec0347750816..f1285aa9daa8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -202,6 +202,7 @@ 
"AutoencoderKLTemporalDecoder", "AutoencoderKLWan", "AutoencoderOobleck", + "AutoencoderRAE", "AutoencoderTiny", "AutoModel", "BriaFiboTransformer2DModel", @@ -975,6 +976,7 @@ AutoencoderKLTemporalDecoder, AutoencoderKLWan, AutoencoderOobleck, + AutoencoderRAE, AutoencoderTiny, AutoModel, BriaFiboTransformer2DModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 8b8d9c52659e..b5b9805d4c96 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -49,6 +49,7 @@ _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] + _import_structure["autoencoders.autoencoder_rae"] = ["AutoencoderRAE"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.vq_model"] = ["VQModel"] @@ -168,6 +169,7 @@ AutoencoderKLTemporalDecoder, AutoencoderKLWan, AutoencoderOobleck, + AutoencoderRAE, AutoencoderTiny, ConsistencyDecoderVAE, VQModel, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 8e7a9c81d2ad..23665ee0532e 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -18,6 +18,7 @@ from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_kl_wan import AutoencoderKLWan from .autoencoder_oobleck import AutoencoderOobleck +from .autoencoder_rae import AutoencoderRAE from .autoencoder_tiny import AutoencoderTiny from .consistency_decoder_vae import ConsistencyDecoderVAE from .vq_model import VQModel diff --git a/src/diffusers/models/autoencoders/autoencoder_rae.py b/src/diffusers/models/autoencoders/autoencoder_rae.py new file mode 
100644 index 000000000000..58ea66f8d18d --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_rae.py @@ -0,0 +1,689 @@ +# Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from math import sqrt +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput, logging +from ...utils.accelerate_utils import apply_forward_hook +from ...utils.import_utils import is_transformers_available +from ...utils.torch_utils import randn_tensor + + +if is_transformers_available(): + from transformers import ( + Dinov2WithRegistersConfig, + Dinov2WithRegistersModel, + SiglipVisionConfig, + SiglipVisionModel, + ViTMAEConfig, + ViTMAEModel, + ) + +from ..activations import get_activation +from ..attention import AttentionMixin +from ..attention_processor import Attention +from ..embeddings import get_2d_sincos_pos_embed +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput + + +logger = logging.get_logger(__name__) + + +# --------------------------------------------------------------------------- +# Per-encoder forward functions +# --------------------------------------------------------------------------- +# Each function takes the raw transformers model + images and returns patch 
+# tokens of shape (B, N, C), stripping CLS / register tokens as needed. + + +def _dinov2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor: + outputs = model(images, output_hidden_states=True) + unused_token_num = 5 # 1 CLS + 4 register tokens + return outputs.last_hidden_state[:, unused_token_num:] + + +def _siglip2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor: + outputs = model(images, output_hidden_states=True, interpolate_pos_encoding=True) + return outputs.last_hidden_state + + +def _mae_encoder_forward(model: nn.Module, images: torch.Tensor, patch_size: int) -> torch.Tensor: + h, w = images.shape[2], images.shape[3] + patch_num = int(h * w // patch_size**2) + if patch_num * patch_size**2 != h * w: + raise ValueError("Image size should be divisible by patch size.") + noise = torch.arange(patch_num).unsqueeze(0).expand(images.shape[0], -1).to(images.device).to(images.dtype) + outputs = model(images, noise, interpolate_pos_encoding=True) + return outputs.last_hidden_state[:, 1:] # remove cls token + + +# --------------------------------------------------------------------------- +# Encoder construction helpers +# --------------------------------------------------------------------------- + + +def _build_encoder( + encoder_type: str, hidden_size: int, patch_size: int, num_hidden_layers: int, head_dim: int = 64 +) -> nn.Module: + """Build a frozen encoder from config (no pretrained download).""" + num_attention_heads = hidden_size // head_dim # all supported encoders use head_dim=64 + + if encoder_type == "dinov2": + config = Dinov2WithRegistersConfig( + hidden_size=hidden_size, + patch_size=patch_size, + image_size=518, + num_attention_heads=num_attention_heads, + num_hidden_layers=num_hidden_layers, + ) + model = Dinov2WithRegistersModel(config) + # RAE strips the final layernorm affine params (identity LN). Remove them from + # the architecture so `from_pretrained` doesn't leave them on the meta device. 
+ model.layernorm.weight = None + model.layernorm.bias = None + elif encoder_type == "siglip2": + config = SiglipVisionConfig( + hidden_size=hidden_size, + patch_size=patch_size, + image_size=256, + num_attention_heads=num_attention_heads, + num_hidden_layers=num_hidden_layers, + ) + model = SiglipVisionModel(config) + # See dinov2 comment above. + model.vision_model.post_layernorm.weight = None + model.vision_model.post_layernorm.bias = None + elif encoder_type == "mae": + config = ViTMAEConfig( + hidden_size=hidden_size, + patch_size=patch_size, + image_size=224, + num_attention_heads=num_attention_heads, + num_hidden_layers=num_hidden_layers, + mask_ratio=0.0, + ) + model = ViTMAEModel(config) + # See dinov2 comment above. + model.layernorm.weight = None + model.layernorm.bias = None + else: + raise ValueError(f"Unknown encoder_type='{encoder_type}'. Available: dinov2, siglip2, mae") + + model.requires_grad_(False) + return model + + +_ENCODER_FORWARD_FNS = { + "dinov2": _dinov2_encoder_forward, + "siglip2": _siglip2_encoder_forward, + "mae": _mae_encoder_forward, +} + + +@dataclass +class RAEDecoderOutput(BaseOutput): + """ + Output of `RAEDecoder`. + + Args: + logits (`torch.Tensor`): + Patch reconstruction logits of shape `(batch_size, num_patches, patch_size**2 * num_channels)`. 
+ """ + + logits: torch.Tensor + + +class ViTMAEIntermediate(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str = "gelu"): + super().__init__() + self.dense = nn.Linear(hidden_size, intermediate_size) + self.intermediate_act_fn = get_activation(hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class ViTMAEOutput(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, hidden_dropout_prob: float = 0.0): + super().__init__() + self.dense = nn.Linear(intermediate_size, hidden_size) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + input_tensor + return hidden_states + + +class ViTMAELayer(nn.Module): + """ + This matches the naming/parameter structure used in RAE-main (ViTMAE decoder block). 
+ """ + + def __init__( + self, + *, + hidden_size: int, + num_attention_heads: int, + intermediate_size: int, + qkv_bias: bool = True, + layer_norm_eps: float = 1e-12, + hidden_dropout_prob: float = 0.0, + attention_probs_dropout_prob: float = 0.0, + hidden_act: str = "gelu", + ): + super().__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError( + f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_attention_heads}" + ) + self.attention = Attention( + query_dim=hidden_size, + heads=num_attention_heads, + dim_head=hidden_size // num_attention_heads, + dropout=attention_probs_dropout_prob, + bias=qkv_bias, + ) + self.intermediate = ViTMAEIntermediate( + hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act + ) + self.output = ViTMAEOutput( + hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_dropout_prob=hidden_dropout_prob + ) + self.layernorm_before = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.layernorm_after = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + attention_output = self.attention(self.layernorm_before(hidden_states)) + hidden_states = attention_output + hidden_states + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = self.output(layer_output, hidden_states) + return layer_output + + +class RAEDecoder(nn.Module): + """ + Decoder implementation ported from RAE-main to keep checkpoint compatibility. 
+ + Key attributes (must match checkpoint keys): + - decoder_embed + - decoder_pos_embed + - decoder_layers + - decoder_norm + - decoder_pred + - trainable_cls_token + """ + + def __init__( + self, + hidden_size: int = 768, + decoder_hidden_size: int = 512, + decoder_num_hidden_layers: int = 8, + decoder_num_attention_heads: int = 16, + decoder_intermediate_size: int = 2048, + num_patches: int = 256, + patch_size: int = 16, + num_channels: int = 3, + image_size: int = 256, + qkv_bias: bool = True, + layer_norm_eps: float = 1e-12, + hidden_dropout_prob: float = 0.0, + attention_probs_dropout_prob: float = 0.0, + hidden_act: str = "gelu", + ): + super().__init__() + self.decoder_hidden_size = decoder_hidden_size + self.patch_size = patch_size + self.num_channels = num_channels + self.image_size = image_size + self.num_patches = num_patches + + self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size, bias=True) + grid_size = int(num_patches**0.5) + pos_embed = get_2d_sincos_pos_embed( + decoder_hidden_size, grid_size, cls_token=True, extra_tokens=1, output_type="pt" + ) + self.register_buffer("decoder_pos_embed", pos_embed.unsqueeze(0).float(), persistent=False) + + self.decoder_layers = nn.ModuleList( + [ + ViTMAELayer( + hidden_size=decoder_hidden_size, + num_attention_heads=decoder_num_attention_heads, + intermediate_size=decoder_intermediate_size, + qkv_bias=qkv_bias, + layer_norm_eps=layer_norm_eps, + hidden_dropout_prob=hidden_dropout_prob, + attention_probs_dropout_prob=attention_probs_dropout_prob, + hidden_act=hidden_act, + ) + for _ in range(decoder_num_hidden_layers) + ] + ) + + self.decoder_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps) + self.decoder_pred = nn.Linear(decoder_hidden_size, patch_size**2 * num_channels, bias=True) + self.gradient_checkpointing = False + + self.trainable_cls_token = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size)) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor: + 
embeddings_positions = embeddings.shape[1] - 1 + num_positions = self.decoder_pos_embed.shape[1] - 1 + + class_pos_embed = self.decoder_pos_embed[:, 0, :] + patch_pos_embed = self.decoder_pos_embed[:, 1:, :] + dim = self.decoder_pos_embed.shape[-1] + + patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim).permute(0, 3, 1, 2) + patch_pos_embed = F.interpolate( + patch_pos_embed, + scale_factor=(1, embeddings_positions / num_positions), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def interpolate_latent(self, x: torch.Tensor) -> torch.Tensor: + b, l, c = x.shape + if l == self.num_patches: + return x + h = w = int(l**0.5) + x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) + target_size = (int(self.num_patches**0.5), int(self.num_patches**0.5)) + x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False) + x = x.permute(0, 2, 3, 1).contiguous().view(b, self.num_patches, c) + return x + + def unpatchify(self, patchified_pixel_values: torch.Tensor, original_image_size: tuple[int, int] | None = None): + patch_size, num_channels = self.patch_size, self.num_channels + original_image_size = ( + original_image_size if original_image_size is not None else (self.image_size, self.image_size) + ) + original_height, original_width = original_image_size + num_patches_h = original_height // patch_size + num_patches_w = original_width // patch_size + if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]: + raise ValueError( + f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}" + ) + + batch_size = patchified_pixel_values.shape[0] + patchified_pixel_values = patchified_pixel_values.reshape( + batch_size, + num_patches_h, + num_patches_w, + patch_size, + patch_size, + 
num_channels, + ) + patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values) + pixel_values = patchified_pixel_values.reshape( + batch_size, + num_channels, + num_patches_h * patch_size, + num_patches_w * patch_size, + ) + return pixel_values + + def forward( + self, + hidden_states: torch.Tensor, + *, + interpolate_pos_encoding: bool = False, + drop_cls_token: bool = False, + return_dict: bool = True, + ) -> RAEDecoderOutput | tuple[torch.Tensor]: + x = self.decoder_embed(hidden_states) + if drop_cls_token: + x_ = x[:, 1:, :] + x_ = self.interpolate_latent(x_) + else: + x_ = self.interpolate_latent(x) + + cls_token = self.trainable_cls_token.expand(x_.shape[0], -1, -1) + x = torch.cat([cls_token, x_], dim=1) + + if interpolate_pos_encoding: + if not drop_cls_token: + raise ValueError("interpolate_pos_encoding only supports drop_cls_token=True") + decoder_pos_embed = self.interpolate_pos_encoding(x) + else: + decoder_pos_embed = self.decoder_pos_embed + + hidden_states = x + decoder_pos_embed.to(device=x.device, dtype=x.dtype) + + for layer_module in self.decoder_layers: + hidden_states = layer_module(hidden_states) + + hidden_states = self.decoder_norm(hidden_states) + logits = self.decoder_pred(hidden_states) + logits = logits[:, 1:, :] + + if not return_dict: + return (logits,) + return RAEDecoderOutput(logits=logits) + + +class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin): + r""" + Representation Autoencoder (RAE) model for encoding images to latents and decoding latents to images. + + This model uses a frozen pretrained encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT decoder to reconstruct + images from learned representations. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for + all models (such as downloading or saving). + + Args: + encoder_type (`str`, *optional*, defaults to `"dinov2"`): + Type of frozen encoder to use. 
One of `"dinov2"`, `"siglip2"`, or `"mae"`. + encoder_hidden_size (`int`, *optional*, defaults to `768`): + Hidden size of the encoder model. + encoder_patch_size (`int`, *optional*, defaults to `14`): + Patch size of the encoder model. + encoder_num_hidden_layers (`int`, *optional*, defaults to `12`): + Number of hidden layers in the encoder model. + patch_size (`int`, *optional*, defaults to `16`): + Decoder patch size (used for unpatchify and decoder head). + encoder_input_size (`int`, *optional*, defaults to `224`): + Input size expected by the encoder. + image_size (`int`, *optional*): + Decoder output image size. If `None`, it is derived from encoder token count and `patch_size` like + RAE-main: `image_size = patch_size * sqrt(num_patches)`, where `num_patches = (encoder_input_size // + encoder_patch_size) ** 2`. + num_channels (`int`, *optional*, defaults to `3`): + Number of input/output channels. + encoder_norm_mean (`list`, *optional*, defaults to `[0.485, 0.456, 0.406]`): + Channel-wise mean for encoder input normalization (ImageNet defaults). + encoder_norm_std (`list`, *optional*, defaults to `[0.229, 0.224, 0.225]`): + Channel-wise std for encoder input normalization (ImageNet defaults). + latents_mean (`list` or `tuple`, *optional*): + Optional mean for latent normalization. Tensor inputs are accepted and converted to config-serializable + lists. + latents_std (`list` or `tuple`, *optional*): + Optional standard deviation for latent normalization. Tensor inputs are accepted and converted to + config-serializable lists. + noise_tau (`float`, *optional*, defaults to `0.0`): + Noise level for training (adds noise to latents during training). + reshape_to_2d (`bool`, *optional*, defaults to `True`): + Whether to reshape latents to 2D (B, C, H, W) format. + use_encoder_loss (`bool`, *optional*, defaults to `False`): + Whether to use encoder hidden states in the loss (for advanced training). 
+ """ + + # NOTE: gradient checkpointing is not wired up for this model yet. + _supports_gradient_checkpointing = False + _no_split_modules = ["ViTMAELayer"] + _keys_to_ignore_on_load_unexpected = ["decoder.decoder_pos_embed"] + + @register_to_config + def __init__( + self, + encoder_type: str = "dinov2", + encoder_hidden_size: int = 768, + encoder_patch_size: int = 14, + encoder_num_hidden_layers: int = 12, + decoder_hidden_size: int = 512, + decoder_num_hidden_layers: int = 8, + decoder_num_attention_heads: int = 16, + decoder_intermediate_size: int = 2048, + patch_size: int = 16, + encoder_input_size: int = 224, + image_size: int | None = None, + num_channels: int = 3, + encoder_norm_mean: list | None = None, + encoder_norm_std: list | None = None, + latents_mean: list | tuple | torch.Tensor | None = None, + latents_std: list | tuple | torch.Tensor | None = None, + noise_tau: float = 0.0, + reshape_to_2d: bool = True, + use_encoder_loss: bool = False, + scaling_factor: float = 1.0, + ): + super().__init__() + + if encoder_type not in _ENCODER_FORWARD_FNS: + raise ValueError( + f"Unknown encoder_type='{encoder_type}'. Available: {sorted(_ENCODER_FORWARD_FNS.keys())}" + ) + + def _to_config_compatible(value: Any) -> Any: + if isinstance(value, torch.Tensor): + return value.detach().cpu().tolist() + if isinstance(value, tuple): + return [_to_config_compatible(v) for v in value] + if isinstance(value, list): + return [_to_config_compatible(v) for v in value] + return value + + def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tensor | None: + if value is None: + return None + if isinstance(value, torch.Tensor): + return value.detach().clone() + return torch.tensor(value, dtype=torch.float32) + + latents_std_tensor = _as_optional_tensor(latents_std) + + # Ensure config values are JSON-serializable (list/None), even if caller passes torch.Tensors. 
+ self.register_to_config( + latents_mean=_to_config_compatible(latents_mean), + latents_std=_to_config_compatible(latents_std), + ) + + self.encoder_input_size = encoder_input_size + self.noise_tau = float(noise_tau) + self.reshape_to_2d = bool(reshape_to_2d) + self.use_encoder_loss = bool(use_encoder_loss) + + # Validate early, before building the (potentially large) encoder/decoder. + encoder_patch_size = int(encoder_patch_size) + if self.encoder_input_size % encoder_patch_size != 0: + raise ValueError( + f"encoder_input_size={self.encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}." + ) + decoder_patch_size = int(patch_size) + if decoder_patch_size <= 0: + raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).") + + # Frozen representation encoder (built from config, no downloads) + self.encoder: nn.Module = _build_encoder( + encoder_type=encoder_type, + hidden_size=encoder_hidden_size, + patch_size=encoder_patch_size, + num_hidden_layers=encoder_num_hidden_layers, + ) + self._encoder_forward_fn = _ENCODER_FORWARD_FNS[encoder_type] + num_patches = (self.encoder_input_size // encoder_patch_size) ** 2 + + grid = int(sqrt(num_patches)) + if grid * grid != num_patches: + raise ValueError(f"Computed num_patches={num_patches} must be a perfect square.") + + derived_image_size = decoder_patch_size * grid + if image_size is None: + image_size = derived_image_size + else: + image_size = int(image_size) + if image_size != derived_image_size: + raise ValueError( + f"image_size={image_size} must equal decoder_patch_size*sqrt(num_patches)={derived_image_size} " + f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}." 
+ ) + + # Encoder input normalization stats (ImageNet defaults) + if encoder_norm_mean is None: + encoder_norm_mean = [0.485, 0.456, 0.406] + if encoder_norm_std is None: + encoder_norm_std = [0.229, 0.224, 0.225] + encoder_mean_tensor = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1) + encoder_std_tensor = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1) + + self.register_buffer("encoder_mean", encoder_mean_tensor, persistent=True) + self.register_buffer("encoder_std", encoder_std_tensor, persistent=True) + + # Latent normalization buffers (defaults are no-ops; actual values come from checkpoint) + latents_mean_tensor = _as_optional_tensor(latents_mean) + if latents_mean_tensor is None: + latents_mean_tensor = torch.zeros(1) + self.register_buffer("_latents_mean", latents_mean_tensor, persistent=True) + + if latents_std_tensor is None: + latents_std_tensor = torch.ones(1) + self.register_buffer("_latents_std", latents_std_tensor, persistent=True) + + # ViT-MAE style decoder + self.decoder = RAEDecoder( + hidden_size=int(encoder_hidden_size), + decoder_hidden_size=int(decoder_hidden_size), + decoder_num_hidden_layers=int(decoder_num_hidden_layers), + decoder_num_attention_heads=int(decoder_num_attention_heads), + decoder_intermediate_size=int(decoder_intermediate_size), + num_patches=int(num_patches), + patch_size=int(decoder_patch_size), + num_channels=int(num_channels), + image_size=int(image_size), + ) + self.num_patches = int(num_patches) + self.decoder_patch_size = int(decoder_patch_size) + self.decoder_image_size = int(image_size) + + # Slicing support (batch dimension) similar to other diffusers autoencoders + self.use_slicing = False + + def _noising(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor: + # Per-sample random sigma in [0, noise_tau] + noise_sigma = self.noise_tau * torch.rand( + (x.size(0),) + (1,) * (x.ndim - 1), device=x.device, dtype=x.dtype, generator=generator + ) + 
return x + noise_sigma * randn_tensor(x.shape, generator=generator, device=x.device, dtype=x.dtype) + + def _resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor: + _, _, h, w = x.shape + if h != self.encoder_input_size or w != self.encoder_input_size: + x = F.interpolate( + x, size=(self.encoder_input_size, self.encoder_input_size), mode="bicubic", align_corners=False + ) + mean = self.encoder_mean.to(device=x.device, dtype=x.dtype) + std = self.encoder_std.to(device=x.device, dtype=x.dtype) + return (x - mean) / std + + def _denormalize_image(self, x: torch.Tensor) -> torch.Tensor: + mean = self.encoder_mean.to(device=x.device, dtype=x.dtype) + std = self.encoder_std.to(device=x.device, dtype=x.dtype) + return x * std + mean + + def _normalize_latents(self, z: torch.Tensor) -> torch.Tensor: + latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype) + latents_std = self._latents_std.to(device=z.device, dtype=z.dtype) + return (z - latents_mean) / (latents_std + 1e-5) + + def _denormalize_latents(self, z: torch.Tensor) -> torch.Tensor: + latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype) + latents_std = self._latents_std.to(device=z.device, dtype=z.dtype) + return z * (latents_std + 1e-5) + latents_mean + + def _encode(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor: + x = self._resize_and_normalize(x) + + if self.config.encoder_type == "mae": + tokens = self._encoder_forward_fn(self.encoder, x, self.config.encoder_patch_size) + else: + tokens = self._encoder_forward_fn(self.encoder, x) # (B, N, C) + + if self.training and self.noise_tau > 0: + tokens = self._noising(tokens, generator=generator) + + if self.reshape_to_2d: + b, n, c = tokens.shape + side = int(sqrt(n)) + if side * side != n: + raise ValueError(f"Token length n={n} is not a perfect square; cannot reshape to 2D.") + z = tokens.transpose(1, 2).contiguous().view(b, c, side, side) # (B, C, h, w) + else: + z = tokens + + z = 
self._normalize_latents(z) + + # Follow diffusers convention: optionally scale latents for diffusion + if self.config.scaling_factor != 1.0: + z = z * self.config.scaling_factor + + return z + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None + ) -> EncoderOutput | tuple[torch.Tensor]: + if self.use_slicing and x.shape[0] > 1: + latents = torch.cat([self._encode(x_slice, generator=generator) for x_slice in x.split(1)], dim=0) + else: + latents = self._encode(x, generator=generator) + + if not return_dict: + return (latents,) + return EncoderOutput(latent=latents) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + # Undo scaling factor if applied at encode time + if self.config.scaling_factor != 1.0: + z = z / self.config.scaling_factor + + z = self._denormalize_latents(z) + + if self.reshape_to_2d: + b, c, h, w = z.shape + tokens = z.view(b, c, h * w).transpose(1, 2).contiguous() # (B, N, C) + else: + tokens = z + + logits = self.decoder(tokens, return_dict=True).logits + x_rec = self.decoder.unpatchify(logits) + x_rec = self._denormalize_image(x_rec) + return x_rec.to(device=z.device) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]: + if self.use_slicing and z.shape[0] > 1: + decoded = torch.cat([self._decode(z_slice) for z_slice in z.split(1)], dim=0) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def forward( + self, sample: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None + ) -> DecoderOutput | tuple[torch.Tensor]: + latents = self.encode(sample, return_dict=False, generator=generator)[0] + decoded = self.decode(latents, return_dict=False)[0] + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) diff --git a/src/diffusers/utils/dummy_pt_objects.py 
b/src/diffusers/utils/dummy_pt_objects.py index 3a4aecd24f90..5244273cb596 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -656,6 +656,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderRAE(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderTiny(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/autoencoders/test_models_autoencoder_rae.py b/tests/models/autoencoders/test_models_autoencoder_rae.py new file mode 100644 index 000000000000..cc8869737bcc --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_rae.py @@ -0,0 +1,300 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc + +import pytest +import torch +import torch.nn.functional as F +from torchvision.transforms.functional import to_tensor + +import diffusers.models.autoencoders.autoencoder_rae as _rae_module +from diffusers.models.autoencoders.autoencoder_rae import ( + _ENCODER_FORWARD_FNS, + AutoencoderRAE, + _build_encoder, +) +from diffusers.utils import load_image + +from ...testing_utils import ( + backend_empty_cache, + enable_full_determinism, + slow, + torch_all_close, + torch_device, +) +from ..testing_utils import BaseModelTesterConfig, ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin + + +enable_full_determinism() + + +# --------------------------------------------------------------------------- +# Tiny test encoder for fast unit tests (no transformers dependency) +# --------------------------------------------------------------------------- + + +class _TinyTestEncoderModule(torch.nn.Module): + """Minimal encoder that mimics the patch-token interface without any HF model.""" + + def __init__(self, hidden_size: int = 16, patch_size: int = 8, **kwargs): + super().__init__() + self.patch_size = patch_size + self.hidden_size = hidden_size + + def forward(self, images: torch.Tensor) -> torch.Tensor: + pooled = F.avg_pool2d(images.mean(dim=1, keepdim=True), kernel_size=self.patch_size, stride=self.patch_size) + tokens = pooled.flatten(2).transpose(1, 2).contiguous() + return tokens.repeat(1, 1, self.hidden_size) + + +def _tiny_test_encoder_forward(model, images): + return model(images) + + +def _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers): + return _TinyTestEncoderModule(hidden_size=hidden_size, patch_size=patch_size) + + +# Monkey-patch the dispatch tables so "tiny_test" is recognised by AutoencoderRAE +_ENCODER_FORWARD_FNS["tiny_test"] = _tiny_test_encoder_forward +_original_build_encoder = _build_encoder + + +def _patched_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers): + if 
encoder_type == "tiny_test": + return _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers) + return _original_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers) + + +_rae_module._build_encoder = _patched_build_encoder + + +# --------------------------------------------------------------------------- +# Test config +# --------------------------------------------------------------------------- + + +class AutoencoderRAETesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return AutoencoderRAE + + @property + def output_shape(self): + return (3, 16, 16) + + def get_init_dict(self): + return { + "encoder_type": "tiny_test", + "encoder_hidden_size": 16, + "encoder_patch_size": 8, + "encoder_input_size": 32, + "patch_size": 4, + "image_size": 16, + "decoder_hidden_size": 32, + "decoder_num_hidden_layers": 1, + "decoder_num_attention_heads": 4, + "decoder_intermediate_size": 64, + "num_channels": 3, + "encoder_norm_mean": [0.5, 0.5, 0.5], + "encoder_norm_std": [0.5, 0.5, 0.5], + "noise_tau": 0.0, + "reshape_to_2d": True, + "scaling_factor": 1.0, + } + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_dummy_inputs(self): + return {"sample": torch.randn(2, 3, 32, 32, generator=self.generator, device="cpu").to(torch_device)} + + # Bridge for AutoencoderTesterMixin which still uses the old interface + def prepare_init_args_and_inputs_for_common(self): + return self.get_init_dict(), self.get_dummy_inputs() + + def _make_model(self, **overrides) -> AutoencoderRAE: + config = self.get_init_dict() + config.update(overrides) + return AutoencoderRAE(**config).to(torch_device) + + +class TestAutoEncoderRAE(AutoencoderRAETesterConfig, ModelTesterMixin): + """Core model tests for AutoencoderRAE.""" + + @pytest.mark.skip(reason="AutoencoderRAE does not support torch dynamo yet") + def test_from_save_pretrained_dynamo(self): ... 
+ + def test_fast_encode_decode_and_forward_shapes(self): + model = self._make_model().eval() + x = torch.rand(2, 3, 32, 32, device=torch_device) + + with torch.no_grad(): + z = model.encode(x).latent + decoded = model.decode(z).sample + recon = model(x).sample + + assert z.shape == (2, 16, 4, 4) + assert decoded.shape == (2, 3, 16, 16) + assert recon.shape == (2, 3, 16, 16) + assert torch.isfinite(recon).all().item() + + def test_fast_scaling_factor_encode_and_decode_consistency(self): + torch.manual_seed(0) + model_base = self._make_model(scaling_factor=1.0).eval() + torch.manual_seed(0) + model_scaled = self._make_model(scaling_factor=2.0).eval() + + x = torch.rand(2, 3, 32, 32, device=torch_device) + with torch.no_grad(): + z_base = model_base.encode(x).latent + z_scaled = model_scaled.encode(x).latent + recon_base = model_base.decode(z_base).sample + recon_scaled = model_scaled.decode(z_scaled).sample + + assert torch.allclose(z_scaled, z_base * 2.0, atol=1e-5, rtol=1e-4) + assert torch.allclose(recon_scaled, recon_base, atol=1e-5, rtol=1e-4) + + def test_fast_latents_normalization_matches_formula(self): + latents_mean = torch.full((1, 16, 1, 1), 0.25, dtype=torch.float32) + latents_std = torch.full((1, 16, 1, 1), 2.0, dtype=torch.float32) + + model_raw = self._make_model().eval() + model_norm = self._make_model(latents_mean=latents_mean, latents_std=latents_std).eval() + x = torch.rand(1, 3, 32, 32, device=torch_device) + + with torch.no_grad(): + z_raw = model_raw.encode(x).latent + z_norm = model_norm.encode(x).latent + + expected = (z_raw - latents_mean.to(z_raw.device, z_raw.dtype)) / ( + latents_std.to(z_raw.device, z_raw.dtype) + 1e-5 + ) + assert torch.allclose(z_norm, expected, atol=1e-5, rtol=1e-4) + + def test_fast_slicing_matches_non_slicing(self): + model = self._make_model().eval() + x = torch.rand(3, 3, 32, 32, device=torch_device) + + with torch.no_grad(): + model.use_slicing = False + z_no_slice = model.encode(x).latent + out_no_slice = 
model.decode(z_no_slice).sample + + model.use_slicing = True + z_slice = model.encode(x).latent + out_slice = model.decode(z_slice).sample + + assert torch.allclose(z_slice, z_no_slice, atol=1e-6, rtol=1e-5) + assert torch.allclose(out_slice, out_no_slice, atol=1e-6, rtol=1e-5) + + def test_fast_noise_tau_applies_only_in_train(self): + model = self._make_model(noise_tau=0.5).to(torch_device) + x = torch.rand(2, 3, 32, 32, device=torch_device) + + model.train() + torch.manual_seed(0) + z_train_1 = model.encode(x).latent + torch.manual_seed(1) + z_train_2 = model.encode(x).latent + + model.eval() + torch.manual_seed(0) + z_eval_1 = model.encode(x).latent + torch.manual_seed(1) + z_eval_2 = model.encode(x).latent + + assert z_train_1.shape == z_eval_1.shape + assert not torch.allclose(z_train_1, z_train_2) + assert torch.allclose(z_eval_1, z_eval_2, atol=1e-6, rtol=1e-5) + + +class TestAutoEncoderRAESlicingTiling(AutoencoderRAETesterConfig, AutoencoderTesterMixin): + """Slicing and tiling tests for AutoencoderRAE.""" + + +@slow +@pytest.mark.skip(reason="Not enough model usage to justify slow tests yet.") +class AutoencoderRAEEncoderIntegrationTests: + def teardown_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def test_dinov2_encoder_forward_shape(self): + encoder = _build_encoder("dinov2", hidden_size=768, patch_size=14, num_hidden_layers=12).to(torch_device) + x = torch.rand(1, 3, 224, 224, device=torch_device) + y = _ENCODER_FORWARD_FNS["dinov2"](encoder, x) + + assert y.ndim == 3 + assert y.shape[0] == 1 + assert y.shape[1] == 256 # (224/14)^2 - 5 (CLS + 4 register) = 251? 
Actually dinov2 has 256 patches + assert y.shape[2] == 768 + + def test_siglip2_encoder_forward_shape(self): + encoder = _build_encoder("siglip2", hidden_size=768, patch_size=16, num_hidden_layers=12).to(torch_device) + x = torch.rand(1, 3, 224, 224, device=torch_device) + y = _ENCODER_FORWARD_FNS["siglip2"](encoder, x) + + assert y.ndim == 3 + assert y.shape[0] == 1 + assert y.shape[1] == 196 # (224/16)^2 + assert y.shape[2] == 768 + + def test_mae_encoder_forward_shape(self): + encoder = _build_encoder("mae", hidden_size=768, patch_size=16, num_hidden_layers=12).to(torch_device) + x = torch.rand(1, 3, 224, 224, device=torch_device) + y = _ENCODER_FORWARD_FNS["mae"](encoder, x, patch_size=16) + + assert y.ndim == 3 + assert y.shape[0] == 1 + assert y.shape[1] == 196 # (224/16)^2 + assert y.shape[2] == 768 + + +@slow +@pytest.mark.skip(reason="Not enough model usage to justify slow tests yet.") +class AutoencoderRAEIntegrationTests: + def teardown_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def test_autoencoder_rae_from_pretrained_dinov2(self): + model = AutoencoderRAE.from_pretrained("nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08").to(torch_device) + model.eval() + + image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png" + ) + image = image.convert("RGB").resize((224, 224)) + x = to_tensor(image).unsqueeze(0).to(torch_device) + + with torch.no_grad(): + latents = model.encode(x).latent + assert latents.shape == (1, 768, 16, 16) + + recon = model.decode(latents).sample + assert recon.shape == (1, 3, 256, 256) + assert torch.isfinite(recon).all().item() + + # fmt: off + expected_latent_slice = torch.tensor([0.7617, 0.8824, -0.4891]) + expected_recon_slice = torch.tensor([0.1263, 0.1355, 0.1435]) + # fmt: on + + assert torch_all_close(latents[0, :3, 0, 0].float().cpu(), expected_latent_slice, atol=1e-3) + assert torch_all_close(recon[0, 0, 0, :3].float().cpu(), 
expected_recon_slice, atol=1e-3) From 93d001731cc2f18fd12e96e28c134538b7eaee84 Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Fri, 6 Mar 2026 11:08:44 +0800 Subject: [PATCH 027/215] Convert tensors to float in Helios's optimized_scale function (#13214) Convert tensors to float in optimized_scale function --- src/diffusers/pipelines/helios/pipeline_helios_pyramid.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 40c1d65825ff..0e08b2c6e958 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -76,6 +76,8 @@ def optimized_scale(positive_flat, negative_flat): + positive_flat = positive_flat.float() + negative_flat = negative_flat.float() # Calculate dot production dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) # Squared norm of uncondition From 21bfa092ef2f278379a93dfeec1717bfe8845728 Mon Sep 17 00:00:00 2001 From: tcaimm <93749364+tcaimm@users.noreply.github.com> Date: Fri, 6 Mar 2026 22:17:51 +0800 Subject: [PATCH 028/215] Fix wrapped transformer config access in Flux2 Klein training (#13219) --- examples/dreambooth/train_dreambooth_lora_flux2_klein.py | 2 +- .../dreambooth/train_dreambooth_lora_flux2_klein_img2img.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py index 278c25900a3a..30f4f4e5d219 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -1715,7 +1715,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): packed_noisy_model_input = Flux2KleinPipeline._pack_latents(noisy_model_input) # handle guidance - if transformer.config.guidance_embeds: + 
if unwrap_model(transformer).config.guidance_embeds: guidance = torch.full([1], args.guidance_scale, device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) else: diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py index 28cbaf8f72e7..7edf8c0f194d 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -1682,7 +1682,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1) # handle guidance - if transformer.config.guidance_embeds: + if unwrap_model(transformer).config.guidance_embeds: guidance = torch.full([1], args.guidance_scale, device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) else: From d239b6f0ebe9710c093bfda6f42255d70b965f2c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 7 Mar 2026 11:03:47 +0530 Subject: [PATCH 029/215] post release 0.37.0 (#13215) * post release 0.37.0 * style --- .../train_dreambooth_lora_flux_advanced.py | 2 +- .../train_dreambooth_lora_sd15_advanced.py | 2 +- .../train_dreambooth_lora_sdxl_advanced.py | 2 +- examples/cogvideo/train_cogvideox_image_to_video_lora.py | 2 +- examples/cogvideo/train_cogvideox_lora.py | 2 +- examples/cogview4-control/train_control_cogview4.py | 2 +- examples/community/marigold_depth_estimation.py | 2 +- .../consistency_distillation/train_lcm_distill_lora_sd_wds.py | 2 +- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 +- .../consistency_distillation/train_lcm_distill_lora_sdxl_wds.py | 2 +- examples/consistency_distillation/train_lcm_distill_sd_wds.py | 2 +- examples/consistency_distillation/train_lcm_distill_sdxl_wds.py | 2 +- examples/controlnet/train_controlnet.py | 2 +- examples/controlnet/train_controlnet_flax.py | 2 +- examples/controlnet/train_controlnet_flux.py | 2 +- 
examples/controlnet/train_controlnet_sd3.py | 2 +- examples/controlnet/train_controlnet_sdxl.py | 2 +- examples/custom_diffusion/train_custom_diffusion.py | 2 +- examples/dreambooth/train_dreambooth.py | 2 +- examples/dreambooth/train_dreambooth_flax.py | 2 +- examples/dreambooth/train_dreambooth_flux.py | 2 +- examples/dreambooth/train_dreambooth_lora.py | 2 +- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- examples/dreambooth/train_dreambooth_lora_flux2.py | 2 +- examples/dreambooth/train_dreambooth_lora_flux2_img2img.py | 2 +- examples/dreambooth/train_dreambooth_lora_flux2_klein.py | 2 +- .../dreambooth/train_dreambooth_lora_flux2_klein_img2img.py | 2 +- examples/dreambooth/train_dreambooth_lora_flux_kontext.py | 2 +- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- examples/dreambooth/train_dreambooth_lora_lumina2.py | 2 +- examples/dreambooth/train_dreambooth_lora_qwen_image.py | 2 +- examples/dreambooth/train_dreambooth_lora_sana.py | 2 +- examples/dreambooth/train_dreambooth_lora_sd3.py | 2 +- examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 +- examples/dreambooth/train_dreambooth_lora_z_image.py | 2 +- examples/dreambooth/train_dreambooth_sd3.py | 2 +- examples/flux-control/train_control_flux.py | 2 +- examples/flux-control/train_control_lora_flux.py | 2 +- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py | 2 +- .../kandinsky2_2/text_to_image/train_text_to_image_decoder.py | 2 +- .../text_to_image/train_text_to_image_lora_decoder.py | 2 +- .../text_to_image/train_text_to_image_lora_prior.py | 2 +- .../kandinsky2_2/text_to_image/train_text_to_image_prior.py | 2 +- examples/t2i_adapter/train_t2i_adapter_sdxl.py | 2 +- examples/text_to_image/train_text_to_image.py | 2 +- examples/text_to_image/train_text_to_image_flax.py | 2 +- examples/text_to_image/train_text_to_image_lora.py | 2 +- examples/text_to_image/train_text_to_image_lora_sdxl.py | 2 +- 
examples/text_to_image/train_text_to_image_sdxl.py | 2 +- examples/textual_inversion/textual_inversion.py | 2 +- examples/textual_inversion/textual_inversion_flax.py | 2 +- examples/textual_inversion/textual_inversion_sdxl.py | 2 +- examples/unconditional_image_generation/train_unconditional.py | 2 +- examples/vqgan/train_vqgan.py | 2 +- setup.py | 2 +- src/diffusers/__init__.py | 2 +- 57 files changed, 57 insertions(+), 57 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 05f2b1ee17f3..608ab3ef3135 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -94,7 +94,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index 8fba00afc39e..a47e4dd96dcb 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -88,7 +88,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 8fb749d328c9..dcaa5a38fc37 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -95,7 +95,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py index c59986d2fde7..17a9dd47d3ba 100644 --- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py +++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py @@ -61,7 +61,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index e08143f98a5c..984ed697d7c7 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ b/examples/cogvideo/train_cogvideox_lora.py @@ -52,7 +52,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index 6f06ed749635..d381a7902723 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -60,7 +60,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/community/marigold_depth_estimation.py b/examples/community/marigold_depth_estimation.py index cd4473264e41..e1026cbafb06 100644 --- a/examples/community/marigold_depth_estimation.py +++ b/examples/community/marigold_depth_estimation.py @@ -43,7 +43,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") class MarigoldDepthOutput(BaseOutput): diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index 26a3ecc87935..38885d4bdf11 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -74,7 +74,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index ef50e8eb2da4..4dd7cbb60ce1 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -67,7 +67,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index a3302d7147b9..f4eb70e61e0f 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -80,7 +80,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index 79bc706bcca3..ef1c57bb9e18 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -73,7 +73,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index d6b2dd895766..6f6fcdfa286a 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -79,7 +79,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index 198501da725e..690325e24eb8 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -61,7 +61,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 588f6b1f4ca0..4d60598104ba 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -61,7 +61,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py index 5d54e34eaa06..70355870e9e8 100644 --- a/examples/controlnet/train_controlnet_flux.py +++ b/examples/controlnet/train_controlnet_flux.py @@ -66,7 +66,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) if is_torch_npu_available(): diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index 1d130a38c97e..66f2bc2eadce 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -63,7 +63,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index b853a32c4483..62757c7f6eb2 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -62,7 +62,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) if is_torch_npu_available(): diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index 5922b7443c10..2ce451917709 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -64,7 +64,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 2e66e1f724e7..d3a2b32aaef5 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -64,7 +64,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index e68d9df5e424..0580fb4b96b0 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -35,7 +35,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") # Cache compiled models across invocations of this script. cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache")) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 468f6fce3ecb..c7e0c290fa8e 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -80,7 +80,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 2d15684f9107..b6baccc4bc99 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -75,7 +75,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 8ae2ddd9796b..e0e7d2e40e56 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -92,7 +92,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 317ed2c2b2e1..24d098add017 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -104,7 +104,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 16a3863c881d..e18909e6dfd7 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -104,7 +104,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py index 30f4f4e5d219..268d0148e446 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -104,7 +104,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py index 7edf8c0f194d..0205f2e9e65f 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -104,7 +104,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index 1a6757810a80..dee65761e92b 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -92,7 +92,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 3abc7afcad2c..bd2fb8db2d21 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -75,7 +75,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_lumina2.py b/examples/dreambooth/train_dreambooth_lora_lumina2.py index a13c579718c7..48eba4c5041d 100644 --- a/examples/dreambooth/train_dreambooth_lora_lumina2.py +++ b/examples/dreambooth/train_dreambooth_lora_lumina2.py @@ -73,7 +73,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py index 33a1054effaf..a1e2fa0f6052 100644 --- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py +++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py @@ -93,7 +93,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py index 0afc31cf8a9a..3b295163b73d 100644 --- a/examples/dreambooth/train_dreambooth_lora_sana.py +++ b/examples/dreambooth/train_dreambooth_lora_sana.py @@ -91,7 +91,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index d6770c805d25..4f49ef4bd801 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -73,7 +73,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 51bac5d59667..502ce1a3f1ec 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -80,7 +80,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_z_image.py b/examples/dreambooth/train_dreambooth_lora_z_image.py index c77953f16410..623ae4d2aca3 100644 --- a/examples/dreambooth/train_dreambooth_lora_z_image.py +++ b/examples/dreambooth/train_dreambooth_lora_z_image.py @@ -104,7 +104,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index e43e3178202a..98e7d2d66cbc 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -64,7 +64,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/flux-control/train_control_flux.py b/examples/flux-control/train_control_flux.py index 1e3be74464be..c5f93fa2e987 100644 --- a/examples/flux-control/train_control_flux.py +++ b/examples/flux-control/train_control_flux.py @@ -55,7 +55,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py index 3185f1b2ea6a..f5d3c822b3ef 100644 --- a/examples/flux-control/train_control_lora_flux.py +++ b/examples/flux-control/train_control_lora_flux.py @@ -58,7 +58,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 1bfe7aed30cb..55297b334cb9 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -58,7 +58,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index 9a5b23a8e623..5df0e22fe1cc 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -60,7 +60,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py index 158c3a6f0994..9b6cb0523d67 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py @@ -53,7 +53,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py index 30094f54827f..869b81ff5d33 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py @@ -46,7 +46,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py index 9c0a4c38504e..8600269dd0fe 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py @@ -46,7 +46,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py index caa8d96ef3ec..6cce862f95a5 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py @@ -52,7 +52,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py index 32128ebbd4df..eb393418c5d7 100644 --- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py +++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py @@ -61,7 +61,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 90dd06d33c5e..7b76594a8dd0 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -57,7 +57,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index e474445d9afe..4fe710089981 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -49,7 +49,7 @@ # Will error if the minimal version of diffusers is not installed. 
Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 310a50ac4e9a..55c2c42d74c0 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -61,7 +61,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 88f5c3cede6e..e211ad95ff43 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -68,7 +68,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) if is_torch_npu_available(): diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 4eafa8f28a19..95749d4dcde4 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -55,7 +55,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) if is_torch_npu_available(): diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 0d8c25349fca..1aaa701d8ceb 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -82,7 +82,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index 7fb394a1bd15..66a5da1fcd8f 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -56,7 +56,7 @@ # ------------------------------------------------------------------------------ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py index 3f482341ca4a..3e9151034eaa 100644 --- a/examples/textual_inversion/textual_inversion_sdxl.py +++ b/examples/textual_inversion/textual_inversion_sdxl.py @@ -77,7 +77,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index ed7d2db43700..649fc8c2facd 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -29,7 +29,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index d9ad2774e897..4684c9ce61c6 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -50,7 +50,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.37.0.dev0") +check_min_version("0.38.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/setup.py b/setup.py index 45bffd376351..d42da57920a0 100644 --- a/setup.py +++ b/setup.py @@ -276,7 +276,7 @@ def run(self): setup( name="diffusers", - version="0.37.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + version="0.38.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) description="State-of-the-art diffusion in PyTorch and JAX.", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f1285aa9daa8..1f368e2afcbd 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.37.0.dev0" +__version__ = "0.38.0.dev0" from typing import TYPE_CHECKING From 16e3dea6745e243258eeb2ba3d56ebb5ed89b4dd Mon Sep 17 00:00:00 2001 From: 
Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Mon, 9 Mar 2026 06:27:10 +0800 Subject: [PATCH 030/215] Fix Helios Context Parallelism (#13223) * fix Helios Context Parallelism * refactor * make style and quality --- .../models/transformers/transformer_helios.py | 17 ++++++++---- .../helios/pipeline_helios_pyramid.py | 26 ++++++++++++++++--- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 9f3ef047d98d..6d81f8a13af7 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -556,14 +556,21 @@ class HeliosTransformer3DModel( _keys_to_ignore_on_load_unexpected = ["norm_added_q"] _repeated_blocks = ["HeliosTransformerBlock"] _cp_plan = { - "blocks.0": { + # Input split at attn level and ffn level. + "blocks.*.attn1": { "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), - }, - "blocks.*": { - "temb": ContextParallelInput(split_dim=1, expected_dims=4, split_output=False), "rotary_emb": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), }, - "blocks.39": ContextParallelOutput(gather_dim=1, expected_dims=3), + "blocks.*.attn2": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + "blocks.*.ffn": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + # Output gather at attn level and ffn level. 
+ **{f"blocks.{i}.attn1": ContextParallelOutput(gather_dim=1, expected_dims=3) for i in range(40)}, + **{f"blocks.{i}.attn2": ContextParallelOutput(gather_dim=1, expected_dims=3) for i in range(40)}, + **{f"blocks.{i}.ffn": ContextParallelOutput(gather_dim=1, expected_dims=3) for i in range(40)}, } @register_to_config diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 0e08b2c6e958..d8f317a9a6f1 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -449,7 +449,14 @@ def sample_block_noise( width, patch_size: tuple[int, ...] = (1, 2, 2), device: torch.device | None = None, + generator: torch.Generator | None = None, ): + # NOTE: A generator must be provided to ensure correct and reproducible results. + # Creating a default generator here is a fallback only — without a fixed seed, + # the output will be non-deterministic and may produce incorrect results in CP context. + if generator is None: + generator = torch.Generator(device=device) + gamma = self.scheduler.config.gamma _, ph, pw = patch_size block_size = ph * pw @@ -458,13 +465,17 @@ def sample_block_noise( torch.eye(block_size, device=device) * (1 + gamma) - torch.ones(block_size, block_size, device=device) * gamma ) - cov += torch.eye(block_size, device=device) * 1e-6 - dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=device), covariance_matrix=cov) + cov += torch.eye(block_size, device=device) * 1e-8 + cov = cov.float() # Upcast to fp32 for numerical stability — cholesky is unreliable in fp16/bf16. 
+ + L = torch.linalg.cholesky(cov) block_number = batch_size * channel * num_frames * (height // ph) * (width // pw) + z = torch.randn(block_number, block_size, device=device, generator=generator) + noise = z @ L.T - noise = dist.sample((block_number,)) # [block number, block_size] noise = noise.view(batch_size, channel, num_frames, height // ph, width // pw, ph, pw) noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(batch_size, channel, num_frames, height, width) + return noise @property @@ -918,7 +929,14 @@ def __call__( batch_size, channel, num_frames, pyramid_height, pyramid_width = latents.shape noise = self.sample_block_noise( - batch_size, channel, num_frames, pyramid_height, pyramid_width, patch_size, device + batch_size, + channel, + num_frames, + pyramid_height, + pyramid_width, + patch_size, + device, + generator, ) noise = noise.to(device=device, dtype=transformer_dtype) latents = alpha * latents + beta * noise # To fix the block artifact From 3a470a68108029ddf09902579c287b89d1eab69d Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Mon, 9 Mar 2026 07:54:16 +0800 Subject: [PATCH 031/215] Optimize Helios docs (#13222) optimize helios docs --- docs/source/en/api/pipelines/helios.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md index 54a08240001c..b85e1dca56b0 100644 --- a/docs/source/en/api/pipelines/helios.md +++ b/docs/source/en/api/pipelines/helios.md @@ -44,7 +44,7 @@ The example below demonstrates how to generate a video from text optimized for m Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques. -The Helios model below requires ~19GB of VRAM. +The Helios model below requires ~6GB of VRAM. 
```py import torch @@ -63,8 +63,7 @@ pipeline = HeliosPipeline.from_pretrained( pipeline.enable_group_offload( onload_device=torch.device("cuda"), offload_device=torch.device("cpu"), - offload_type="block_level", - num_blocks_per_group=1, + offload_type="leaf_level", use_stream=True, record_stream=True, ) @@ -97,7 +96,7 @@ export_to_video(output, "helios_base_t2v_output.mp4", fps=24) -[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Attention Backends](../../optimization/attention_backends) such as FlashAttention and SageAttention can significantly increase speed by optimizing the computation of the attention mechanism. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs. +[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Attention Backends](../../optimization/attention_backends) such as FlashAttention and SageAttention can significantly increase speed by optimizing the computation of the attention mechanism. [Context Parallelism](../../training/distributed_inference#context-parallelism) splits the input sequence across multiple devices to enable processing of long contexts in parallel, reducing memory pressure and latency. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs. 
```py import torch From 9c0724a9b3ab72bb2ceb899a8f22cf0024ea86d8 Mon Sep 17 00:00:00 2001 From: annitang1997 Date: Mon, 9 Mar 2026 11:23:49 +0800 Subject: [PATCH 032/215] Add VidTok AutoEncoders (#11261) * add_autoencoder_vidtok * format standardization * remove small functions * making the code style more diffusers-like * Apply style fixes * Add dummy objects for AutoencoderVidTok * Fix AutoencoderVidTok avg_pool3d BFloat16 CPU compatibility * skip test_layerwise_casting_training test * Apply style fixes --------- Co-authored-by: annitang1997 Co-authored-by: github-actions[bot] Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/autoencoders/__init__.py | 1 + .../models/autoencoders/autoencoder_vidtok.py | 1488 +++++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 15 + .../test_models_autoencoder_vidtok.py | 163 ++ 6 files changed, 1671 insertions(+) create mode 100644 src/diffusers/models/autoencoders/autoencoder_vidtok.py create mode 100644 tests/models/autoencoders/test_models_autoencoder_vidtok.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1f368e2afcbd..d6d557a4c224 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -204,6 +204,7 @@ "AutoencoderOobleck", "AutoencoderRAE", "AutoencoderTiny", + "AutoencoderVidTok", "AutoModel", "BriaFiboTransformer2DModel", "BriaTransformer2DModel", @@ -978,6 +979,7 @@ AutoencoderOobleck, AutoencoderRAE, AutoencoderTiny, + AutoencoderVidTok, AutoModel, BriaFiboTransformer2DModel, BriaTransformer2DModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index b5b9805d4c96..e4bc95fdf884 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -51,6 +51,7 @@ _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] _import_structure["autoencoders.autoencoder_rae"] = ["AutoencoderRAE"] 
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] + _import_structure["autoencoders.autoencoder_vidtok"] = ["AutoencoderVidTok"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.vq_model"] = ["VQModel"] _import_structure["cache_utils"] = ["CacheMixin"] @@ -171,6 +172,7 @@ AutoencoderOobleck, AutoencoderRAE, AutoencoderTiny, + AutoencoderVidTok, ConsistencyDecoderVAE, VQModel, ) diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 23665ee0532e..b6a673f7f7a7 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -20,5 +20,6 @@ from .autoencoder_oobleck import AutoencoderOobleck from .autoencoder_rae import AutoencoderRAE from .autoencoder_tiny import AutoencoderTiny +from .autoencoder_vidtok import AutoencoderVidTok from .consistency_decoder_vae import ConsistencyDecoderVAE from .vq_model import VQModel diff --git a/src/diffusers/models/autoencoders/autoencoder_vidtok.py b/src/diffusers/models/autoencoders/autoencoder_vidtok.py new file mode 100644 index 000000000000..4f05afb8a21d --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_vidtok.py @@ -0,0 +1,1488 @@ +# Copyright 2025 The VidTok team, MSRA & Shanghai Jiao Tong University and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class FSQRegularizer(nn.Module): + r""" + Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 Code adapted from + https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/finite_scalar_quantization.py + + Args: + levels (`List[int]`): + A list of quantization levels. + dim (`int`, *optional*, defaults to `None`): + The dimension of latent codes. + num_codebooks (`int`, defaults to 1): + The number of codebooks. + keep_num_codebooks_dim (`bool`, *optional*, defaults to `None`): + Whether to keep the number of codebook dim. 
+ """ + + def __init__( + self, + levels: List[int], + dim: Optional[int] = None, + num_codebooks: int = 1, + keep_num_codebooks_dim: Optional[bool] = None, + ): + super().__init__() + + _levels = torch.tensor(levels, dtype=torch.int32) + self.register_buffer("_levels", _levels, persistent=False) + + _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32) + self.register_buffer("_basis", _basis, persistent=False) + + codebook_dim = len(levels) + self.codebook_dim = codebook_dim + + effective_codebook_dim = codebook_dim * num_codebooks + self.num_codebooks = num_codebooks + self.effective_codebook_dim = effective_codebook_dim + + if keep_num_codebooks_dim is None: + keep_num_codebooks_dim = num_codebooks > 1 + self.keep_num_codebooks_dim = keep_num_codebooks_dim + self.dim = len(_levels) * num_codebooks if dim is None else dim + + has_projections = self.dim != effective_codebook_dim + self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity() + self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity() + self.has_projections = has_projections + + self.codebook_size = self._levels.prod().item() + + implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False) + self.register_buffer("implicit_codebook", implicit_codebook, persistent=False) + self.register_buffer("zero", torch.tensor(0.0), persistent=False) + + self.global_codebook_usage = torch.zeros([2**self.codebook_dim, self.num_codebooks], dtype=torch.long) + + def quantize(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor: + r"""Quantizes z, returns quantized zhat, same shape as z.""" + half_l = (self._levels - 1) * (1 + eps) / 2 + offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) + shift = (offset / half_l).atanh() + z = (z + shift).tanh() * half_l - offset + zhat = z.round() + quantized = z + (zhat - z).detach() + half_width = self._levels // 2 + return 
quantized / half_width + + def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor: + r"""Converts a `code` to an index in the codebook.""" + half_width = self._levels // 2 + zhat = (zhat * half_width) + half_width + return (zhat * self._basis).sum(dim=-1).to(torch.int32) + + def indices_to_codes(self, indices: torch.Tensor, project_out: bool = True) -> torch.Tensor: + r"""Inverse of `codes_to_indices`.""" + is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) + indices = indices.unsqueeze(-1) + codes_non_centered = (indices // self._basis) % self._levels + half_width = self._levels // 2 + codes = (codes_non_centered - half_width) / half_width + if self.keep_num_codebooks_dim: + codes = codes.reshape(*codes.shape[:-2], -1) + if project_out: + codes = self.project_out(codes) + if is_img_or_video: + codes = codes.permute(0, -1, *range(1, codes.dim() - 1)) + return codes + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + einstein notation b - batch n - sequence (or flattened spatial dimensions) d - feature dimension c - number of + codebook dim + """ + is_img_or_video = z.ndim >= 4 + + if is_img_or_video: + if z.ndim == 5: + b, d, t, h, w = z.shape + is_video = True + else: + b, d, h, w = z.shape + is_video = False + z = z.reshape(b, d, -1).permute(0, 2, 1) + + z = self.project_in(z) + b, n, _ = z.shape + z = z.reshape(b, n, self.num_codebooks, -1) + + orig_dtype = z.dtype + z = z.float() + codes = self.quantize(z) + indices = self.codes_to_indices(codes) + codes = codes.type(orig_dtype) + + codes = codes.reshape(b, n, -1) + out = self.project_out(codes) + + # reconstitute image or video dimensions + if is_img_or_video: + if is_video: + out = out.reshape(b, t, h, w, d).permute(0, 4, 1, 2, 3) + indices = indices.reshape(b, t, h, w, 1) + else: + out = out.reshape(b, h, w, d).permute(0, 3, 1, 2) + indices = indices.reshape(b, h, w, 1) + + if not self.keep_num_codebooks_dim: + indices = indices.squeeze(-1) + + 
return out, indices + + +class VidTokDownsample2D(nn.Module): + r"""A 2D downsampling layer used in VidTok Model.""" + + def __init__(self, in_channels: int): + super().__init__() + + self.in_channels = in_channels + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class VidTokUpsample2D(nn.Module): + r"""A 2D upsampling layer used in VidTok Model.""" + + def __init__(self, in_channels: int): + super().__init__() + + self.in_channels = in_channels + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.interpolate(x.to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype) + x = self.conv(x) + return x + + +class VidTokLayerNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + + self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.dim() == 5: + x = x.permute(0, 2, 3, 4, 1) + x = self.norm(x) + x = x.permute(0, 4, 1, 2, 3) + elif x.dim() == 4: + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = x.permute(0, 3, 1, 2) + else: + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) + return x + + +class VidTokCausalConv1d(nn.Module): + r"""A 1D causal convolution layer that pads the input tensor to ensure causality in VidTok Model.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + padding: int = 0, + ): + super().__init__() + + self.time_pad = dilation * (kernel_size - 1) + (1 - stride) + + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation) + + self.is_first_chunk = True + self.causal_cache = None + self.cache_offset = 0 + + def 
forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_first_chunk: + first_frame_pad = x[:, :, :1].repeat((1, 1, self.time_pad)) + else: + first_frame_pad = self.causal_cache + if self.time_pad != 0: + first_frame_pad = first_frame_pad[:, :, -self.time_pad :] + else: + first_frame_pad = first_frame_pad[:, :, 0:0] + x = torch.concatenate((first_frame_pad, x), dim=2) + if self.cache_offset == 0: + self.causal_cache = x.clone() + else: + self.causal_cache = x[:, :, : -self.cache_offset].clone() + return self.conv(x) + + +class VidTokCausalConv3d(nn.Module): + r"""A 3D causal convolution layer that pads the input tensor to ensure causality in VidTok Model.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + dilation: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + pad_mode: str = "constant", + ): + super().__init__() + self.pad_mode = pad_mode + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(dilation, int): + dilation = (dilation,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + time_pad = dilation[0] * (time_kernel_size - 1) + (1 - stride[0]) + height_pad = dilation[1] * (height_kernel_size - 1) + (1 - stride[1]) + width_pad = dilation[2] * (width_kernel_size - 1) + (1 - stride[2]) + + self.time_pad = time_pad + self.spatial_padding = ( + width_pad // 2, + width_pad - width_pad // 2, + height_pad // 2, + height_pad - height_pad // 2, + 0, + 0, + ) + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation) + + self.is_first_chunk = True + self.causal_cache = None + self.cache_offset = 0 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_first_chunk: + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_pad, 1, 1)) + else: + 
first_frame_pad = self.causal_cache + if self.time_pad != 0: + first_frame_pad = first_frame_pad[:, :, -self.time_pad :] + else: + first_frame_pad = first_frame_pad[:, :, 0:0] + x = torch.concatenate((first_frame_pad, x), dim=2) + if self.cache_offset == 0: + self.causal_cache = x.clone() + else: + self.causal_cache = x[:, :, : -self.cache_offset].clone() + x = F.pad(x, self.spatial_padding, mode=self.pad_mode) + return self.conv(x) + + +class VidTokDownsample3D(nn.Module): + r"""A 3D downsampling layer used in VidTok Model.""" + + def __init__(self, in_channels: int, out_channels: int, mix_factor: float = 2.0, is_causal: bool = True): + super().__init__() + self.is_causal = is_causal + self.kernel_size = (3, 3, 3) + self.avg_pool = nn.AvgPool3d((3, 1, 1), stride=(2, 1, 1)) + make_conv_cls = VidTokCausalConv3d if self.is_causal else nn.Conv3d + self.conv = make_conv_cls(in_channels, out_channels, 3, stride=(2, 1, 1), padding=(0, 1, 1)) + self.mix_factor = nn.Parameter(torch.Tensor([mix_factor])) + if self.is_causal: + self.is_first_chunk = True + self.causal_cache = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + alpha = torch.sigmoid(self.mix_factor) + if self.is_causal: + pad = (0, 0, 0, 0, 1, 0) + if self.is_first_chunk: + x_pad = torch.nn.functional.pad(x, pad, mode="replicate") + else: + x_pad = torch.concatenate((self.causal_cache, x), dim=2) + self.causal_cache = x_pad[:, :, -1:].clone() + if x_pad.device.type == "cpu" and x_pad.dtype == torch.bfloat16: + # PyTorch's avg_pool3d lacks CPU support for BFloat16. + # To avoid errors, we cast to float32, perform the pooling, + # and then cast back to BFloat16 to maintain the expected dtype. + x1 = self.avg_pool(x_pad.float()).to(torch.bfloat16) + else: + x1 = self.avg_pool(x_pad) + else: + pad = (0, 0, 0, 0, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + if x.device.type == "cpu" and x.dtype == torch.bfloat16: + # PyTorch's avg_pool3d lacks CPU support for BFloat16. 
+ # To avoid errors, we cast to float32, perform the pooling, + # and then cast back to BFloat16 to maintain the expected dtype. + x1 = self.avg_pool(x.float()).to(torch.bfloat16) + else: + x1 = self.avg_pool(x) + x2 = self.conv(x) + return alpha * x1 + (1 - alpha) * x2 + + +class VidTokUpsample3D(nn.Module): + r"""A 3D upsampling layer used in VidTok Model.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + mix_factor: float = 2.0, + num_temp_upsample: int = 1, + is_causal: bool = True, + ): + super().__init__() + make_conv_cls = VidTokCausalConv3d if is_causal else nn.Conv3d + self.conv = make_conv_cls(in_channels, out_channels, 3, padding=1) + self.mix_factor = nn.Parameter(torch.Tensor([mix_factor])) + + self.is_causal = is_causal + if self.is_causal: + self.enable_cached = True + self.interpolation_mode = "trilinear" + self.is_first_chunk = True + self.causal_cache = None + self.num_temp_upsample = num_temp_upsample + else: + self.enable_cached = False + self.interpolation_mode = "nearest" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + alpha = torch.sigmoid(self.mix_factor) + if not self.is_causal: + xlst = [ + F.interpolate( + sx.unsqueeze(0).to(torch.float32), scale_factor=[2.0, 1.0, 1.0], mode=self.interpolation_mode + ).to(x.dtype) + for sx in x + ] + x = torch.cat(xlst, dim=0) + else: + if not self.enable_cached: + x = F.interpolate(x.to(torch.float32), scale_factor=[2.0, 1.0, 1.0], mode=self.interpolation_mode).to( + x.dtype + ) + elif not self.is_first_chunk: + x = torch.cat([self.causal_cache, x], dim=2) + self.causal_cache = x[:, :, -2 * self.num_temp_upsample : -self.num_temp_upsample].clone() + x = F.interpolate(x.to(torch.float32), scale_factor=[2.0, 1.0, 1.0], mode=self.interpolation_mode).to( + x.dtype + ) + x = x[:, :, 2 * self.num_temp_upsample :] + else: + self.causal_cache = x[:, :, -self.num_temp_upsample :].clone() + x, _x = x[:, :, : self.num_temp_upsample], x[:, :, self.num_temp_upsample :] + x = 
F.interpolate(x.to(torch.float32), scale_factor=[2.0, 1.0, 1.0], mode=self.interpolation_mode).to( + x.dtype + ) + if _x.shape[-3] > 0: + _x = F.interpolate( + _x.to(torch.float32), scale_factor=[2.0, 1.0, 1.0], mode=self.interpolation_mode + ).to(_x.dtype) + x = torch.concat([x, _x], dim=2) + x_ = self.conv(x) + return alpha * x + (1 - alpha) * x_ + + +class VidTokAttnBlock(nn.Module): + r"""A 3D self-attention block used in VidTok Model.""" + + def __init__(self, in_channels: int, is_causal: bool = True): + super().__init__() + make_conv_cls = VidTokCausalConv3d if is_causal else nn.Conv3d + self.norm = VidTokLayerNorm(dim=in_channels, eps=1e-6) + self.q = make_conv_cls(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = make_conv_cls(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = make_conv_cls(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = make_conv_cls(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def attention(self, hidden_states: torch.Tensor) -> torch.Tensor: + r"""Implement self-attention.""" + hidden_states = self.norm(hidden_states) + q = self.q(hidden_states) + k = self.k(hidden_states) + v = self.v(hidden_states) + b, c, t, h, w = q.shape + q, k, v = [x.permute(0, 2, 3, 4, 1).reshape(b, t, -1, c).contiguous() for x in [q, k, v]] + hidden_states = F.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default + return hidden_states.reshape(b, t, h, w, c).permute(0, 4, 1, 2, 3) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + hidden_states = x + hidden_states = self.attention(hidden_states) + hidden_states = self.proj_out(hidden_states) + return x + hidden_states + + +class VidTokResnetBlock(nn.Module): + r"""A versatile ResNet block used in VidTok Model.""" + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + btype: str = 
"3d", + is_causal: bool = True, + ): + super().__init__() + assert btype in ["1d", "2d", "3d"], f"Invalid btype: {btype}" + if btype == "2d": + make_conv_cls = nn.Conv2d + elif btype == "1d": + make_conv_cls = VidTokCausalConv1d if is_causal else nn.Conv1d + else: + make_conv_cls = VidTokCausalConv3d if is_causal else nn.Conv3d + + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.nonlinearity = nn.SiLU() + + self.norm1 = VidTokLayerNorm(dim=in_channels, eps=1e-6) + self.conv1 = make_conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = nn.Linear(temb_channels, out_channels) + self.norm2 = VidTokLayerNorm(dim=out_channels, eps=1e-6) + self.dropout = nn.Dropout(dropout) + self.conv2 = make_conv_cls(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = make_conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = make_conv_cls(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor]) -> torch.Tensor: + hidden_states = x + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if temb is not None: + hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None] + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + return x + hidden_states + + +class VidTokEncoder3D(nn.Module): 
+ r""" + The `VidTokEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`): + The number of input channels. + ch (`int`): + The number of the basic channel. + ch_mult (`List[int]`, defaults to `[1, 2, 4, 8]`): + The multiple of the basic channel for each block. + num_res_blocks (`int`, defaults to 2): + The number of resblocks. + dropout (`float`, defaults to 0.0): + Dropout rate. + z_channels (`int`, defaults to 4): + The number of latent channels. + double_z (`bool`, defaults to `True`): + Whether or not to double the z_channels. + spatial_ds (`List`, *optional*, defaults to `None`): + Spatial downsample layers. + tempo_ds (`List`, *optional*, defaults to `None`): + Temporal downsample layers. + is_causal (`bool`, defaults to `True`): + Whether it is a causal module. + """ + + def __init__( + self, + in_channels: int, + ch: int, + ch_mult: List[int] = [1, 2, 4, 8], + num_res_blocks: int = 2, + dropout: float = 0.0, + z_channels: int = 4, + double_z: bool = True, + spatial_ds: Optional[List] = None, + tempo_ds: Optional[List] = None, + is_causal: bool = True, + ): + super().__init__() + self.is_causal = is_causal + + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_channels + self.nonlinearity = nn.SiLU() + + make_conv_cls = VidTokCausalConv3d if self.is_causal else nn.Conv3d + + self.conv_in = make_conv_cls(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.spatial_ds = list(range(0, self.num_resolutions - 1)) if spatial_ds is None else spatial_ds + self.tempo_ds = [self.num_resolutions - 2, self.num_resolutions - 3] if tempo_ds is None else tempo_ds + self.down = nn.ModuleList() + self.down_temporal = nn.ModuleList() + for i_level in range(self.num_resolutions): + block_in = ch * in_ch_mult[i_level] + block_out = ch * 
ch_mult[i_level] + + block = nn.ModuleList() + attn = nn.ModuleList() + block_temporal = nn.ModuleList() + attn_temporal = nn.ModuleList() + + for i_block in range(self.num_res_blocks): + block.append( + VidTokResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + btype="2d", + ) + ) + block_temporal.append( + VidTokResnetBlock( + in_channels=block_out, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + btype="1d", + is_causal=self.is_causal, + ) + ) + block_in = block_out + + down = nn.Module() + down.block = block + down.attn = attn + + down_temporal = nn.Module() + down_temporal.block = block_temporal + down_temporal.attn = attn_temporal + + if i_level in self.spatial_ds: + down.downsample = VidTokDownsample2D(block_in) + if i_level in self.tempo_ds: + down_temporal.downsample = VidTokDownsample3D(block_in, block_in, is_causal=self.is_causal) + + self.down.append(down) + self.down_temporal.append(down_temporal) + + # middle + self.mid = nn.Module() + self.mid.block_1 = VidTokResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + btype="3d", + is_causal=self.is_causal, + ) + self.mid.attn_1 = VidTokAttnBlock(block_in, is_causal=self.is_causal) + self.mid.block_2 = VidTokResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + btype="3d", + is_causal=self.is_causal, + ) + + # end + self.norm_out = VidTokLayerNorm(dim=block_in, eps=1e-6) + self.conv_out = make_conv_cls( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + self.gradient_checkpointing = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + temb = None + B, _, T, H, W = x.shape + hs = [self.conv_in(x)] + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for i_level in range(self.num_resolutions): + for i_block in 
range(self.num_res_blocks): + hidden_states = hs[-1].permute(0, 2, 1, 3, 4).reshape(B * T, -1, H, W) + hidden_states = self._gradient_checkpointing_func( + self.down[i_level].block[i_block], hidden_states, temb + ) + hidden_states = ( + hidden_states.reshape(B, T, -1, H, W).permute(0, 3, 4, 2, 1).reshape(B * H * W, -1, T) + ) + hidden_states = self._gradient_checkpointing_func( + self.down_temporal[i_level].block[i_block], hidden_states, temb + ) + hidden_states = hidden_states.reshape(B, H, W, -1, T).permute(0, 3, 4, 1, 2) + hs.append(hidden_states) + + if i_level in self.spatial_ds: + # spatial downsample + hidden_states = hs[-1].permute(0, 2, 1, 3, 4).reshape(B * T, -1, H, W) + hidden_states = self._gradient_checkpointing_func(self.down[i_level].downsample, hidden_states) + hidden_states = hidden_states.reshape(B, T, -1, *hidden_states.shape[-2:]).permute(0, 2, 1, 3, 4) + if i_level in self.tempo_ds: + # temporal downsample + hidden_states = self._gradient_checkpointing_func( + self.down_temporal[i_level].downsample, hidden_states + ) + hs.append(hidden_states) + B, _, T, H, W = hidden_states.shape + # middle + hidden_states = hs[-1] + hidden_states = self._gradient_checkpointing_func(self.mid.block_1, hidden_states, temb) + hidden_states = self._gradient_checkpointing_func(self.mid.attn_1, hidden_states) + hidden_states = self._gradient_checkpointing_func(self.mid.block_2, hidden_states, temb) + + else: + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + hidden_states = hs[-1].permute(0, 2, 1, 3, 4).reshape(B * T, -1, H, W) + hidden_states = self.down[i_level].block[i_block](hidden_states, temb) + hidden_states = ( + hidden_states.reshape(B, T, -1, H, W).permute(0, 3, 4, 2, 1).reshape(B * H * W, -1, T) + ) + hidden_states = self.down_temporal[i_level].block[i_block](hidden_states, temb) + hidden_states = hidden_states.reshape(B, H, W, -1, T).permute(0, 3, 4, 1, 2) + hs.append(hidden_states) + + if i_level in 
self.spatial_ds: + # spatial downsample + hidden_states = hs[-1].permute(0, 2, 1, 3, 4).reshape(B * T, -1, H, W) + hidden_states = self.down[i_level].downsample(hidden_states) + hidden_states = hidden_states.reshape(B, T, -1, *hidden_states.shape[-2:]).permute(0, 2, 1, 3, 4) + if i_level in self.tempo_ds: + # temporal downsample + hidden_states = self.down_temporal[i_level].downsample(hidden_states) + hs.append(hidden_states) + B, _, T, H, W = hidden_states.shape + # middle + hidden_states = hs[-1] + hidden_states = self.mid.block_1(hidden_states, temb) + hidden_states = self.mid.attn_1(hidden_states) + hidden_states = self.mid.block_2(hidden_states, temb) + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv_out(hidden_states) + return hidden_states + + +class VidTokDecoder3D(nn.Module): + r""" + The `VidTokDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output + video. + + Args: + ch (`int`): + The number of the basic channel. + ch_mult (`List[int]`, defaults to `[1, 2, 4, 8]`): + The multiple of the basic channel for each block. + num_res_blocks (`int`, defaults to 2): + The number of resblocks. + dropout (`float`, defaults to 0.0): + Dropout rate. + z_channels (`int`, defaults to 4): + The number of latent channels. + out_channels (`int`, defaults to 3): + The number of output channels. + spatial_us (`List`, *optional*, defaults to `None`): + Spatial upsample layers. + tempo_us (`List`, *optional*, defaults to `None`): + Temporal upsample layers. + is_causal (`bool`, defaults to `True`): + Whether it is a causal module. 
+ """ + + def __init__( + self, + ch: int, + ch_mult: List[int] = [1, 2, 4, 8], + num_res_blocks: int = 2, + dropout: float = 0.0, + z_channels: int = 4, + out_channels: int = 3, + spatial_us: Optional[List] = None, + tempo_us: Optional[List] = None, + is_causal: bool = True, + ): + super().__init__() + + self.is_causal = is_causal + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.nonlinearity = nn.SiLU() + + block_in = ch * ch_mult[self.num_resolutions - 1] + + make_conv_cls = VidTokCausalConv3d if self.is_causal else nn.Conv3d + + self.conv_in = make_conv_cls(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = VidTokResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + btype="3d", + is_causal=self.is_causal, + ) + self.mid.attn_1 = VidTokAttnBlock(block_in, is_causal=self.is_causal) + self.mid.block_2 = VidTokResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + btype="3d", + is_causal=self.is_causal, + ) + + # upsampling + self.spatial_us = list(range(1, self.num_resolutions)) if spatial_us is None else spatial_us + self.tempo_us = [1, 2] if tempo_us is None else tempo_us + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + VidTokResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + btype="2d", + ) + ) + block_in = block_out + + up = nn.Module() + up.block = block + up.attn = attn + if i_level in self.spatial_us: + up.upsample = VidTokUpsample2D(block_in) + self.up.insert(0, up) + + num_temp_upsample = 1 + self.up_temporal = nn.ModuleList() + for i_level in 
reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + VidTokResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + btype="1d", + is_causal=self.is_causal, + ) + ) + block_in = block_out + up_temporal = nn.Module() + up_temporal.block = block + up_temporal.attn = attn + if i_level in self.tempo_us: + up_temporal.upsample = VidTokUpsample3D( + block_in, block_in, num_temp_upsample=num_temp_upsample, is_causal=self.is_causal + ) + num_temp_upsample *= 2 + + self.up_temporal.insert(0, up_temporal) + + # end + self.norm_out = VidTokLayerNorm(dim=block_in, eps=1e-6) + self.conv_out = make_conv_cls(block_in, out_channels, kernel_size=3, stride=1, padding=1) + + self.gradient_checkpointing = False + + def forward(self, z: torch.Tensor) -> torch.Tensor: + temb = None + B, _, T, H, W = z.shape + hidden_states = self.conv_in(z) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + # middle + hidden_states = self._gradient_checkpointing_func(self.mid.block_1, hidden_states, temb) + hidden_states = self._gradient_checkpointing_func(self.mid.attn_1, hidden_states) + hidden_states = self._gradient_checkpointing_func(self.mid.block_2, hidden_states, temb) + + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(B * T, -1, H, W) + hidden_states = self._gradient_checkpointing_func( + self.up[i_level].block[i_block], hidden_states, temb + ) + hidden_states = ( + hidden_states.reshape(B, T, -1, H, W).permute(0, 3, 4, 2, 1).reshape(B * H * W, -1, T) + ) + hidden_states = self._gradient_checkpointing_func( + self.up_temporal[i_level].block[i_block], hidden_states, temb + ) + hidden_states = hidden_states.reshape(B, H, W, -1, 
T).permute(0, 3, 4, 1, 2) + + if i_level in self.spatial_us: + # spatial upsample + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(B * T, -1, H, W) + hidden_states = self._gradient_checkpointing_func(self.up[i_level].upsample, hidden_states) + hidden_states = hidden_states.reshape(B, T, -1, *hidden_states.shape[-2:]).permute(0, 2, 1, 3, 4) + if i_level in self.tempo_us: + # temporal upsample + hidden_states = self._gradient_checkpointing_func( + self.up_temporal[i_level].upsample, hidden_states + ) + B, _, T, H, W = hidden_states.shape + + else: + # middle + hidden_states = self.mid.block_1(hidden_states, temb) + hidden_states = self.mid.attn_1(hidden_states) + hidden_states = self.mid.block_2(hidden_states, temb) + + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(B * T, -1, H, W) + hidden_states = self.up[i_level].block[i_block](hidden_states, temb) + hidden_states = ( + hidden_states.reshape(B, T, -1, H, W).permute(0, 3, 4, 2, 1).reshape(B * H * W, -1, T) + ) + hidden_states = self.up_temporal[i_level].block[i_block](hidden_states, temb) + hidden_states = hidden_states.reshape(B, H, W, -1, T).permute(0, 3, 4, 1, 2) + + if i_level in self.spatial_us: + # spatial upsample + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(B * T, -1, H, W) + hidden_states = self.up[i_level].upsample(hidden_states) + hidden_states = hidden_states.reshape(B, T, -1, *hidden_states.shape[-2:]).permute(0, 2, 1, 3, 4) + if i_level in self.tempo_us: + # temporal upsample + hidden_states = self.up_temporal[i_level].upsample(hidden_states) + B, _, T, H, W = hidden_states.shape + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + out = self.conv_out(hidden_states) + return out + + +class AutoencoderVidTok(ModelMixin, ConfigMixin): + r""" + A VAE model for encoding videos into latents and 
decoding latent representations into videos, supporting both + continuous and discrete latent representations. Used in [VidTok](https://github.com/microsoft/VidTok). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented + for all models (such as downloading or saving). + + Args: + in_channels (`int`, defaults to 3): + The number of input channels. + out_channels (`int`, defaults to 3): + The number of output channels. + ch (`int`, defaults to 128): + The number of the basic channel. + ch_mult (`List[int]`, defaults to `[1, 2, 4, 4]`): + The multiple of the basic channel for each block. + z_channels (`int`, defaults to 4): + The number of latent channels. + double_z (`bool`, defaults to `True`): + Whether or not to double the z_channels. + num_res_blocks (`int`, defaults to 2): + The number of resblocks. + spatial_ds (`List`, *optional*, defaults to `None`): + Spatial downsample layers. + spatial_us (`List`, *optional*, defaults to `None`): + Spatial upsample layers. + tempo_ds (`List`, *optional*, defaults to `None`): + Temporal downsample layers. + tempo_us (`List`, *optional*, defaults to `None`): + Temporal upsample layers. + dropout (`float`, defaults to 0.0): + Dropout rate. + regularizer (`str`, defaults to `"kl"`): + The regularizer type - "kl" for continuous cases and "fsq" for discrete cases. + codebook_size (`int`, defaults to 262144): + The codebook size used only in discrete cases. + is_causal (`bool`, defaults to `True`): + Whether it is a causal module.
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + ch: int = 128, + ch_mult: List[int] = [1, 2, 4, 4], + z_channels: int = 4, + double_z: bool = True, + num_res_blocks: int = 2, + spatial_ds: Optional[List] = None, + spatial_us: Optional[List] = None, + tempo_ds: Optional[List] = None, + tempo_us: Optional[List] = None, + dropout: float = 0.0, + regularizer: str = "kl", + codebook_size: int = 262144, + is_causal: bool = True, + ): + super().__init__() + self.is_causal = is_causal + + self.encoder = VidTokEncoder3D( + in_channels=in_channels, + ch=ch, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + dropout=dropout, + z_channels=z_channels, + double_z=double_z, + spatial_ds=spatial_ds, + tempo_ds=tempo_ds, + is_causal=self.is_causal, + ) + self.decoder = VidTokDecoder3D( + ch=ch, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + dropout=dropout, + z_channels=z_channels, + out_channels=out_channels, + spatial_us=spatial_us, + tempo_us=tempo_us, + is_causal=self.is_causal, + ) + self.temporal_compression_ratio = 2 ** len(self.encoder.tempo_ds) + + self.regularizer = regularizer + if self.regularizer not in ["kl", "fsq"]: + raise ValueError(f"Invalid regularizer: {self.regularizer}. Only `kl` and `fsq` are supported.") + + if self.regularizer == "fsq": + if z_channels != int(math.log(codebook_size, 8)): + raise ValueError( + f"When using the `fsq` regularizer, `z_channels` must be {int(math.log(codebook_size, 8))}, the" + f" log base 8 of the `codebook_size` {codebook_size}, but got {z_channels}." 
+ ) + if double_z: + raise ValueError("When using the `fsq` regularizer, `double_z` must be `False`.") + + self.regularization = FSQRegularizer(levels=[8] * z_channels) + + self.use_slicing = False + self.use_tiling = False + + # Decode more latent frames at once + self.num_sample_frames_batch_size = 16 + self.num_latent_frames_batch_size = self.num_sample_frames_batch_size // self.temporal_compression_ratio + + # We make the minimum height and width of sample for tiling half that of the generally supported + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + self.tile_latent_min_height = int(self.tile_sample_min_height / (2 ** len(self.encoder.spatial_ds))) + self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** len(self.encoder.spatial_ds))) + self.tile_overlap_factor_height = 0.0 # 1 / 8 + self.tile_overlap_factor_width = 0.0 # 1 / 8 + + @staticmethod + def _pad_at_dim( + t: torch.Tensor, pad: Tuple[int], dim: int = -1, pad_mode: str = "constant", value: float = 0.0 + ) -> torch.Tensor: + r"""Pad function. Supported pad_mode: `constant`, `replicate`, `reflect`.""" + dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1) + zeros = (0, 0) * dims_from_right + if pad_mode == "constant": + return F.pad(t, (*zeros, *pad), value=value) + return F.pad(t, (*zeros, *pad), mode=pad_mode) + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_overlap_factor_height: Optional[float] = None, + tile_overlap_factor_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. 
+ + Args: + tile_sample_min_height (`int`, *optional*, defaults to `None`): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*, defaults to `None`): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_overlap_factor_height (`float`, *optional*, defaults to `None`): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + tile_overlap_factor_width (`float`, *optional*, defaults to `None`): + The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there + are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_latent_min_height = int(self.tile_sample_min_height / (2 ** len(self.encoder.spatial_ds))) + self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** len(self.encoder.spatial_ds))) + self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height + self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. 
When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + self._empty_causal_cached(self.encoder) + self._set_first_chunk(True) + + if self.use_tiling: + return self.tiled_encode(x) + return self.encoder(x) + + @apply_forward_hook + def encode(self, x: torch.Tensor) -> Union[AutoencoderKLOutput, Tuple[torch.Tensor, torch.Tensor]]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `AutoencoderKLOutput` or `Tuple[torch.Tensor]`: + The latent representations of the encoded videos. If the regularizer is `kl`, an `AutoencoderKLOutput` + is returned, otherwise a tuple of `torch.Tensor` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + z = torch.cat(encoded_slices) + else: + z = self._encode(x) + + if self.regularizer == "kl": + posterior = DiagonalGaussianDistribution(z) + return AutoencoderKLOutput(latent_dist=posterior) + else: + quant_z, indices = self.regularization(z) + return quant_z, indices + + def _decode(self, z: torch.Tensor, decode_from_indices: bool = False) -> torch.Tensor: + self._empty_causal_cached(self.decoder) + self._set_first_chunk(True) + if not self.is_causal and z.shape[-3] % self.num_latent_frames_batch_size != 0: + assert z.shape[-3] >= self.num_latent_frames_batch_size, ( + f"Too short latent frames. At least {self.num_latent_frames_batch_size} frames." 
+ ) + z = z[..., : (z.shape[-3] // self.num_latent_frames_batch_size * self.num_latent_frames_batch_size), :, :] + if decode_from_indices: + z = self.tile_indices_to_latent(z) if self.use_tiling else self.indices_to_latent(z) + dec = self.tiled_decode(z) if self.use_tiling else self.decoder(z) + return dec + + @apply_forward_hook + def decode(self, z: torch.Tensor, decode_from_indices: bool = False) -> torch.Tensor: + r""" + Decode a batch of images from latents. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + decode_from_indices (`bool`): If decode from indices or decode from latent code. + Returns: + `torch.Tensor`: The decoded images. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice, decode_from_indices=decode_from_indices) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z, decode_from_indices=decode_from_indices) + if self.is_causal: + decoded = decoded[:, :, self.temporal_compression_ratio - 1 :, :, :] + return decoded + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def build_chunk_start_end(self, t, decoder_mode=False): + if self.is_causal: + start_end = [[0, self.temporal_compression_ratio]] if not decoder_mode else [[0, 1]] + start = start_end[0][-1] + else: + start_end, start = [], 0 + end = start + while True: + if start >= t: + break + end = min( + t, end + 
(self.num_latent_frames_batch_size if decoder_mode else self.num_sample_frames_batch_size) + ) + start_end.append([start, end]) + start = end + if len(start_end) > (2 if self.is_causal else 1): + if start_end[-1][1] - start_end[-1][0] < ( + self.num_latent_frames_batch_size if decoder_mode else self.num_sample_frames_batch_size + ): + start_end[-2] = [start_end[-2][0], start_end[-1][1]] + start_end = start_end[:-1] + return start_end + + def _set_first_chunk(self, is_first_chunk=True): + for module in self.modules(): + if hasattr(module, "is_first_chunk"): + module.is_first_chunk = is_first_chunk + + def _empty_causal_cached(self, parent): + for name, module in parent.named_modules(): + if hasattr(module, "causal_cache"): + module.causal_cache = None + + def _set_cache_offset(self, modules, cache_offset=0): + for module in modules: + for submodule in module.modules(): + if hasattr(submodule, "cache_offset"): + submodule.cache_offset = cache_offset + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r""" + Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: The latent representation of the encoded videos. 
+ """ + num_frames, height, width = x.shape[-3:] + + overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_latent_min_height - blend_extent_height + row_limit_width = self.tile_latent_min_width - blend_extent_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + start_end = self.build_chunk_start_end(num_frames) + time = [] + for idx, (start_frame, end_frame) in enumerate(start_end): + self._set_first_chunk(idx == 0) + tile = x[ + :, + :, + start_frame:end_frame, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile) + time.append(tile) + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + enc = torch.cat(result_rows, dim=3) + return enc + + def indices_to_latent(self, token_indices: torch.Tensor) -> torch.Tensor: + r""" + Transform indices to latent code. + + Args: + token_indices (`torch.Tensor`): Token indices. + + Returns: + `torch.Tensor`: Latent code corresponding to the input token indices. 
+ """ + b, t, h, w = token_indices.shape + token_indices = token_indices.unsqueeze(-1).reshape(b, -1, 1) + codes = self.regularization.indices_to_codes(token_indices) + codes = codes.permute(0, 2, 3, 1).reshape(b, codes.shape[2], -1) + z = self.regularization.project_out(codes) + return z.reshape(b, t, h, w, -1).permute(0, 4, 1, 2, 3) + + def tile_indices_to_latent(self, token_indices: torch.Tensor) -> torch.Tensor: + r""" + Transform indices to latent code with tiling inference. + + Args: + token_indices (`torch.Tensor`): Token indices. + + Returns: + `torch.Tensor`: Latent code corresponding to the input token indices. + """ + num_frames = token_indices.shape[1] + start_end = self.build_chunk_start_end(num_frames, decoder_mode=True) + result_z = [] + for start, end in start_end: + chunk_z = self.indices_to_latent(token_indices[:, start:end, :, :]) + result_z.append(chunk_z.clone()) + return torch.cat(result_z, dim=2) + + def tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + + Returns: + `torch.Tensor`: Reconstructed batch of videos. + """ + num_frames, height, width = z.shape[-3:] + + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_sample_min_height - blend_extent_height + row_limit_width = self.tile_sample_min_width - blend_extent_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
+ rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + if self.is_causal: + assert self.temporal_compression_ratio in [ + 2, + 4, + 8, + ], "Only support 2x, 4x or 8x temporal downsampling now." + if self.temporal_compression_ratio == 4: + self._set_cache_offset([self.decoder], 1) + self._set_cache_offset([self.decoder.up_temporal[2].upsample, self.decoder.up_temporal[1]], 2) + self._set_cache_offset( + [self.decoder.up_temporal[1].upsample, self.decoder.up_temporal[0], self.decoder.conv_out], + 4, + ) + elif self.temporal_compression_ratio == 2: + self._set_cache_offset([self.decoder], 1) + self._set_cache_offset( + [ + self.decoder.up_temporal[2].upsample, + self.decoder.up_temporal[1], + self.decoder.up_temporal[0], + self.decoder.conv_out, + ], + 2, + ) + else: + self._set_cache_offset([self.decoder], 1) + self._set_cache_offset([self.decoder.up_temporal[3].upsample, self.decoder.up_temporal[2]], 2) + self._set_cache_offset([self.decoder.up_temporal[2].upsample, self.decoder.up_temporal[1]], 4) + self._set_cache_offset( + [self.decoder.up_temporal[1].upsample, self.decoder.up_temporal[0], self.decoder.conv_out], + 8, + ) + + start_end = self.build_chunk_start_end(num_frames, decoder_mode=True) + time = [] + for idx, (start_frame, end_frame) in enumerate(start_end): + self._set_first_chunk(idx == 0) + tile = z[ + :, + :, + start_frame : (end_frame + 1 if self.is_causal and end_frame + 1 <= num_frames else end_frame), + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ] + tile = self.decoder(tile) + if self.is_causal and end_frame + 1 <= num_frames: + tile = tile[:, :, : -self.temporal_compression_ratio] + time.append(tile) + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the 
result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + return dec + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = True, + encoder_mode: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[torch.Tensor, DecoderOutput]: + x = sample + res = 1 if self.is_causal else 0 + if self.is_causal: + if x.shape[2] % self.temporal_compression_ratio != res: + time_padding = self.temporal_compression_ratio - x.shape[2] % self.temporal_compression_ratio + res + x = self._pad_at_dim(x, (0, time_padding), dim=2, pad_mode="replicate") + else: + time_padding = 0 + else: + if x.shape[2] % self.num_sample_frames_batch_size != res: + if not encoder_mode: + time_padding = ( + self.num_sample_frames_batch_size - x.shape[2] % self.num_sample_frames_batch_size + res + ) + x = self._pad_at_dim(x, (0, time_padding), dim=2, pad_mode="replicate") + else: + assert x.shape[2] >= self.num_sample_frames_batch_size, ( + f"Too short video. At least {self.num_sample_frames_batch_size} frames." 
+ ) + x = x[:, :, : x.shape[2] // self.num_sample_frames_batch_size * self.num_sample_frames_batch_size] + else: + time_padding = 0 + + if self.is_causal: + x = self._pad_at_dim(x, (self.temporal_compression_ratio - 1, 0), dim=2, pad_mode="replicate") + + if self.regularizer == "kl": + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + if encoder_mode: + return z + else: + z, indices = self.encode(x) + if encoder_mode: + return z, indices + + dec = self.decode(z) + if time_padding != 0: + dec = dec[:, :, :-time_padding, :, :] + + if not return_dict: + return dec + return DecoderOutput(sample=dec) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 5244273cb596..3425cc8d2b61 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -686,6 +686,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderVidTok(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/autoencoders/test_models_autoencoder_vidtok.py b/tests/models/autoencoders/test_models_autoencoder_vidtok.py new file mode 100644 index 000000000000..70932f2b55aa --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_vidtok.py @@ -0,0 +1,163 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import AutoencoderVidTok +from diffusers.utils.testing_utils import ( + floats_tensor, + torch_device, +) + +from ...testing_utils import IS_GITHUB_ACTIONS +from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin + + +class AutoencoderVidTokTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): + model_class = AutoencoderVidTok + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_vidtok_config(self): + return { + "is_causal": False, + "in_channels": 3, + "out_channels": 3, + "ch": 128, + "ch_mult": [1, 2, 4, 4, 4], + "z_channels": 6, + "double_z": False, + "num_res_blocks": 2, + "regularizer": "fsq", + "codebook_size": 262144, + } + + @property + def dummy_input(self): + batch_size = 4 + num_frames = 16 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + + return {"sample": image} + + @property + def input_shape(self): + return (3, 16, 32, 32) + + @property + def output_shape(self): + return (3, 16, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_vidtok_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_enable_disable_tiling(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + torch.manual_seed(0) + output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + 
torch.manual_seed(0) + model.enable_tiling() + output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertLess( + (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(), + 0.5, + "VAE tiling should not affect the inference results", + ) + + torch.manual_seed(0) + model.disable_tiling() + output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertEqual( + output_without_tiling.detach().cpu().numpy().all(), + output_without_tiling_2.detach().cpu().numpy().all(), + "Without tiling outputs should match with the outputs when tiling is manually disabled.", + ) + + def test_enable_disable_slicing(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict.update({"return_dict": False}) + + torch.manual_seed(0) + output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + torch.manual_seed(0) + model.enable_slicing() + output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertLess( + (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(), + 0.5, + "VAE slicing should not affect the inference results", + ) + + torch.manual_seed(0) + model.disable_slicing() + output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertEqual( + output_without_slicing.detach().cpu().numpy().all(), + output_without_slicing_2.detach().cpu().numpy().all(), + "Without slicing outputs should match with the outputs when slicing is manually disabled.", + ) + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "VidTokEncoder3D", + "VidTokDecoder3D", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + def test_forward_with_norm_groups(self): + r"""VidTok uses layernorm instead of 
groupnorm.""" + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.to_tuple()[0] + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + @unittest.skip("Unsupported test.") + def test_outputs_equivalence(self): + pass + + @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment") + def test_layerwise_casting_training(self): + super().test_layerwise_casting_training() From 9abea3969e8b10857a184a4a2eca6856ad931f47 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 9 Mar 2026 15:17:59 +0530 Subject: [PATCH 033/215] [tests] Use `tmp_path` fixture modular tests (#13194) * add a test to check modular index consistency * check for compulsory keys. * use fixture for tmp_path in modular tests. * remove unneeded test. * fix code quality. * up * up --- .../flux/test_modular_pipeline_flux.py | 29 ++++------ .../test_modular_pipelines_common.py | 57 +++++++++---------- .../test_modular_pipelines_custom_blocks.py | 27 +++++---- 3 files changed, 52 insertions(+), 61 deletions(-) diff --git a/tests/modular_pipelines/flux/test_modular_pipeline_flux.py b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py index 9a6b4b9b6fb4..05fe16e372ec 100644 --- a/tests/modular_pipelines/flux/test_modular_pipeline_flux.py +++ b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py @@ -14,7 +14,6 @@ # limitations under the License. 
import random -import tempfile import numpy as np import PIL @@ -129,18 +128,16 @@ def get_dummy_inputs(self, seed=0): return inputs - def test_save_from_pretrained(self): + def test_save_from_pretrained(self, tmp_path): pipes = [] base_pipe = self.get_pipeline().to(torch_device) pipes.append(base_pipe) - with tempfile.TemporaryDirectory() as tmpdirname: - base_pipe.save_pretrained(tmpdirname) - - pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device) - pipe.load_components(torch_dtype=torch.float32) - pipe.to(torch_device) - pipe.image_processor = VaeImageProcessor(vae_scale_factor=2) + base_pipe.save_pretrained(str(tmp_path)) + pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device) + pipe.load_components(torch_dtype=torch.float32) + pipe.to(torch_device) + pipe.image_processor = VaeImageProcessor(vae_scale_factor=2) pipes.append(pipe) @@ -212,18 +209,16 @@ def get_dummy_inputs(self, seed=0): return inputs - def test_save_from_pretrained(self): + def test_save_from_pretrained(self, tmp_path): pipes = [] base_pipe = self.get_pipeline().to(torch_device) pipes.append(base_pipe) - with tempfile.TemporaryDirectory() as tmpdirname: - base_pipe.save_pretrained(tmpdirname) - - pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device) - pipe.load_components(torch_dtype=torch.float32) - pipe.to(torch_device) - pipe.image_processor = VaeImageProcessor(vae_scale_factor=2) + base_pipe.save_pretrained(str(tmp_path)) + pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device) + pipe.load_components(torch_dtype=torch.float32) + pipe.to(torch_device) + pipe.image_processor = VaeImageProcessor(vae_scale_factor=2) pipes.append(pipe) diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index c1a402a2fd8f..d897ed793376 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -1,7 +1,6 @@ 
import gc import json import os -import tempfile from typing import Callable import pytest @@ -341,16 +340,15 @@ def test_components_auto_cpu_offload_inference_consistent(self): assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 - def test_save_from_pretrained(self): + def test_save_from_pretrained(self, tmp_path): pipes = [] base_pipe = self.get_pipeline().to(torch_device) pipes.append(base_pipe) - with tempfile.TemporaryDirectory() as tmpdirname: - base_pipe.save_pretrained(tmpdirname) - pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device) - pipe.load_components(torch_dtype=torch.float32) - pipe.to(torch_device) + base_pipe.save_pretrained(str(tmp_path)) + pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device) + pipe.load_components(torch_dtype=torch.float32) + pipe.to(torch_device) pipes.append(pipe) @@ -362,32 +360,31 @@ def test_save_from_pretrained(self): assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 - def test_modular_index_consistency(self): + def test_modular_index_consistency(self, tmp_path): pipe = self.get_pipeline() components_spec = pipe._component_specs components = sorted(components_spec.keys()) - with tempfile.TemporaryDirectory() as tmpdir: - pipe.save_pretrained(tmpdir) - index_file = os.path.join(tmpdir, "modular_model_index.json") - assert os.path.exists(index_file) + pipe.save_pretrained(str(tmp_path)) + index_file = tmp_path / "modular_model_index.json" + assert index_file.exists() - with open(index_file) as f: - index_contents = json.load(f) + with open(index_file) as f: + index_contents = json.load(f) - compulsory_keys = {"_blocks_class_name", "_class_name", "_diffusers_version"} - for k in compulsory_keys: - assert k in index_contents + compulsory_keys = {"_blocks_class_name", "_class_name", "_diffusers_version"} + for k in compulsory_keys: + assert k in index_contents - to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"} - for component in components: - 
spec = components_spec[component] - for attr in to_check_attrs: - if getattr(spec, "pretrained_model_name_or_path", None) is not None: - for attr in to_check_attrs: - assert component in index_contents, f"{component} should be present in index but isn't." - attr_value_from_index = index_contents[component][2][attr] - assert getattr(spec, attr) == attr_value_from_index + to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"} + for component in components: + spec = components_spec[component] + for attr in to_check_attrs: + if getattr(spec, "pretrained_model_name_or_path", None) is not None: + for attr in to_check_attrs: + assert component in index_contents, f"{component} should be present in index but isn't." + attr_value_from_index = index_contents[component][2][attr] + assert getattr(spec, attr) == attr_value_from_index def test_workflow_map(self): blocks = self.pipeline_blocks_class() @@ -483,7 +480,7 @@ class DummyBlockTwo: def test_sequential_block_requirements_save_load(self, tmp_path): pipe = self.get_dummy_block_pipe() - pipe.save_pretrained(tmp_path) + pipe.save_pretrained(str(tmp_path)) config_path = tmp_path / "modular_config.json" @@ -508,7 +505,7 @@ def test_sequential_block_requirements_warnings(self, tmp_path): logger.setLevel(30) with CaptureLogger(logger) as cap_logger: - pipe.save_pretrained(tmp_path) + pipe.save_pretrained(str(tmp_path)) template = "{req} was specified in the requirements but wasn't found in the current environment" msg_xyz = template.format(req="xyz") @@ -518,7 +515,7 @@ def test_sequential_block_requirements_warnings(self, tmp_path): def test_conditional_block_requirements_save_load(self, tmp_path): pipe = self.get_dummy_conditional_block_pipe() - pipe.save_pretrained(tmp_path) + pipe.save_pretrained(str(tmp_path)) config_path = tmp_path / "modular_config.json" with open(config_path, "r") as f: @@ -535,7 +532,7 @@ def test_conditional_block_requirements_save_load(self, tmp_path): def 
test_loop_block_requirements_save_load(self, tmp_path): pipe = self.get_dummy_loop_block_pipe() - pipe.save_pretrained(tmp_path) + pipe.save_pretrained(str(tmp_path)) config_path = tmp_path / "modular_config.json" with open(config_path, "r") as f: diff --git a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py index 766ca0c16f86..7c6e97a36eb7 100644 --- a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py +++ b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py @@ -153,25 +153,24 @@ def test_custom_block_output(self): output_prompt = output.values["output_prompt"] assert output_prompt.startswith("Modular diffusers + ") - def test_custom_block_saving_loading(self): + def test_custom_block_saving_loading(self, tmp_path): custom_block = DummyCustomBlockSimple() - with tempfile.TemporaryDirectory() as tmpdir: - custom_block.save_pretrained(tmpdir) - assert any("modular_config.json" in k for k in os.listdir(tmpdir)) + custom_block.save_pretrained(tmp_path) + assert any("modular_config.json" in k for k in os.listdir(tmp_path)) - with open(os.path.join(tmpdir, "modular_config.json"), "r") as f: - config = json.load(f) - auto_map = config["auto_map"] - assert auto_map == {"ModularPipelineBlocks": "test_modular_pipelines_custom_blocks.DummyCustomBlockSimple"} + with open(os.path.join(tmp_path, "modular_config.json"), "r") as f: + config = json.load(f) + auto_map = config["auto_map"] + assert auto_map == {"ModularPipelineBlocks": "test_modular_pipelines_custom_blocks.DummyCustomBlockSimple"} - # For now, the Python script that implements the custom block has to be manually pushed to the Hub. - # This is why, we have to separately save the Python script here. 
- code_path = os.path.join(tmpdir, "test_modular_pipelines_custom_blocks.py") - with open(code_path, "w") as f: - f.write(CODE_STR) + # For now, the Python script that implements the custom block has to be manually pushed to the Hub. + # This is why, we have to separately save the Python script here. + code_path = os.path.join(tmp_path, "test_modular_pipelines_custom_blocks.py") + with open(code_path, "w") as f: + f.write(CODE_STR) - loaded_custom_block = ModularPipelineBlocks.from_pretrained(tmpdir, trust_remote_code=True) + loaded_custom_block = ModularPipelineBlocks.from_pretrained(tmp_path, trust_remote_code=True) pipe = loaded_custom_block.init_pipeline() prompt = "Diffusers is nice" From 134f0bdc4d4d83ad8bfebe2feafd7ed233fcc593 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 9 Mar 2026 17:51:42 +0530 Subject: [PATCH 034/215] [CI] Add Workflow permissions to PR tests (#13233) [CI] Add workflow permissions to PR tests Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> --- .github/workflows/pr_tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index 3e69dfd05eae..02dee7d541b7 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -16,6 +16,9 @@ on: branches: - ci-* +permissions: + contents: read + concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true From 0d920d4b831f8c23cdd916c9c746438fbb6f0914 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 10 Mar 2026 04:30:56 +0800 Subject: [PATCH 035/215] fix: allow pass cpu generator for helios (#13228) * allow pass cpu generator for helios * allow pass cpu generator for helios * allow pass cpu generator for helios * patch --- src/diffusers/pipelines/helios/pipeline_helios_pyramid.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git 
a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index d8f317a9a6f1..1791da11b490 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -456,6 +456,8 @@ def sample_block_noise( # the output will be non-deterministic and may produce incorrect results in CP context. if generator is None: generator = torch.Generator(device=device) + elif isinstance(generator, list): + generator = generator[0] gamma = self.scheduler.config.gamma _, ph, pw = patch_size @@ -470,7 +472,8 @@ def sample_block_noise( L = torch.linalg.cholesky(cov) block_number = batch_size * channel * num_frames * (height // ph) * (width // pw) - z = torch.randn(block_number, block_size, device=device, generator=generator) + z = torch.randn(block_number, block_size, generator=generator, device=generator.device) + z = z.to(device=device) noise = z @ L.T noise = noise.view(batch_size, channel, num_frames, height // ph, width // pw, ph, pw) From c39f53eca00866860137d19a6f643aefee1b142f Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Mon, 9 Mar 2026 10:37:56 -1000 Subject: [PATCH 036/215] [modular] helios (#13216) * add helios modular * upup * revert change in guider * up * fix for real * fix batch test * Apply suggestion from @yiyixuxu --------- Co-authored-by: yiyi@huggingface.co --- src/diffusers/__init__.py | 12 + src/diffusers/modular_pipelines/__init__.py | 16 + .../modular_pipelines/helios/__init__.py | 59 + .../helios/before_denoise.py | 836 +++++++++++++ .../modular_pipelines/helios/decoders.py | 110 ++ .../modular_pipelines/helios/denoise.py | 1069 +++++++++++++++++ .../modular_pipelines/helios/encoders.py | 392 ++++++ .../helios/modular_blocks_helios.py | 542 +++++++++ .../helios/modular_blocks_helios_pyramid.py | 520 ++++++++ ...modular_blocks_helios_pyramid_distilled.py | 530 ++++++++ .../helios/modular_pipeline.py | 87 ++ 
.../modular_pipelines/modular_pipeline.py | 12 + .../dummy_torch_and_transformers_objects.py | 90 ++ tests/modular_pipelines/helios/__init__.py | 0 .../helios/test_modular_pipeline_helios.py | 166 +++ 15 files changed, 4441 insertions(+) create mode 100644 src/diffusers/modular_pipelines/helios/__init__.py create mode 100644 src/diffusers/modular_pipelines/helios/before_denoise.py create mode 100644 src/diffusers/modular_pipelines/helios/decoders.py create mode 100644 src/diffusers/modular_pipelines/helios/denoise.py create mode 100644 src/diffusers/modular_pipelines/helios/encoders.py create mode 100644 src/diffusers/modular_pipelines/helios/modular_blocks_helios.py create mode 100644 src/diffusers/modular_pipelines/helios/modular_blocks_helios_pyramid.py create mode 100644 src/diffusers/modular_pipelines/helios/modular_blocks_helios_pyramid_distilled.py create mode 100644 src/diffusers/modular_pipelines/helios/modular_pipeline.py create mode 100644 tests/modular_pipelines/helios/__init__.py create mode 100644 tests/modular_pipelines/helios/test_modular_pipeline_helios.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d6d557a4c224..546fbe57be9e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -434,6 +434,12 @@ "FluxKontextAutoBlocks", "FluxKontextModularPipeline", "FluxModularPipeline", + "HeliosAutoBlocks", + "HeliosModularPipeline", + "HeliosPyramidAutoBlocks", + "HeliosPyramidDistilledAutoBlocks", + "HeliosPyramidDistilledModularPipeline", + "HeliosPyramidModularPipeline", "QwenImageAutoBlocks", "QwenImageEditAutoBlocks", "QwenImageEditModularPipeline", @@ -1188,6 +1194,12 @@ FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline, + HeliosAutoBlocks, + HeliosModularPipeline, + HeliosPyramidAutoBlocks, + HeliosPyramidDistilledAutoBlocks, + HeliosPyramidDistilledModularPipeline, + HeliosPyramidModularPipeline, QwenImageAutoBlocks, QwenImageEditAutoBlocks, QwenImageEditModularPipeline, diff --git 
a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index c9bebd8644f7..fd9bd691ca87 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -56,6 +56,14 @@ "WanImage2VideoModularPipeline", "Wan22Image2VideoModularPipeline", ] + _import_structure["helios"] = [ + "HeliosAutoBlocks", + "HeliosModularPipeline", + "HeliosPyramidAutoBlocks", + "HeliosPyramidDistilledAutoBlocks", + "HeliosPyramidDistilledModularPipeline", + "HeliosPyramidModularPipeline", + ] _import_structure["flux"] = [ "FluxAutoBlocks", "FluxModularPipeline", @@ -103,6 +111,14 @@ Flux2KleinModularPipeline, Flux2ModularPipeline, ) + from .helios import ( + HeliosAutoBlocks, + HeliosModularPipeline, + HeliosPyramidAutoBlocks, + HeliosPyramidDistilledAutoBlocks, + HeliosPyramidDistilledModularPipeline, + HeliosPyramidModularPipeline, + ) from .modular_pipeline import ( AutoPipelineBlocks, BlockState, diff --git a/src/diffusers/modular_pipelines/helios/__init__.py b/src/diffusers/modular_pipelines/helios/__init__.py new file mode 100644 index 000000000000..26551399a3e8 --- /dev/null +++ b/src/diffusers/modular_pipelines/helios/__init__.py @@ -0,0 +1,59 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_helios"] = ["HeliosAutoBlocks"] + _import_structure["modular_blocks_helios_pyramid"] = ["HeliosPyramidAutoBlocks"] + 
_import_structure["modular_blocks_helios_pyramid_distilled"] = ["HeliosPyramidDistilledAutoBlocks"] + _import_structure["modular_pipeline"] = [ + "HeliosModularPipeline", + "HeliosPyramidDistilledModularPipeline", + "HeliosPyramidModularPipeline", + ] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_helios import HeliosAutoBlocks + from .modular_blocks_helios_pyramid import HeliosPyramidAutoBlocks + from .modular_blocks_helios_pyramid_distilled import HeliosPyramidDistilledAutoBlocks + from .modular_pipeline import ( + HeliosModularPipeline, + HeliosPyramidDistilledModularPipeline, + HeliosPyramidModularPipeline, + ) +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/helios/before_denoise.py b/src/diffusers/modular_pipelines/helios/before_denoise.py new file mode 100644 index 000000000000..6d317fa737f4 --- /dev/null +++ b/src/diffusers/modular_pipelines/helios/before_denoise.py @@ -0,0 +1,836 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch + +from ...models import HeliosTransformer3DModel +from ...schedulers import HeliosScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import HeliosModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +class HeliosTextInputStep(ModularPipelineBlocks): + model_name = "helios" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_videos_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_videos_per_prompt." 
+ ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "num_videos_per_prompt", + default=1, + type_hint=int, + description="Number of videos to generate per prompt.", + ), + InputParam.template("prompt_embeds"), + InputParam.template("negative_prompt_embeds"), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds.dtype`)", + ), + ] + + def check_inputs(self, components, block_state): + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" + f" {block_state.negative_prompt_embeds.shape}." 
+ ) + + @torch.no_grad() + def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1 + ) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_videos_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1 + ) + + self.set_block_state(state, block_state) + + return components, state + + +# Copied from diffusers.modular_pipelines.wan.before_denoise.repeat_tensor_to_batch_size +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_videos_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat tensor elements to match the final batch size. + + This function expands a tensor's batch dimension to match the final batch size (batch_size * num_videos_per_prompt) + by repeating each element along dimension 0. + + The input tensor must have batch size 1 or batch_size. The function will: + - If batch size is 1: repeat each element (batch_size * num_videos_per_prompt) times + - If batch size equals batch_size: repeat each element num_videos_per_prompt times + + Args: + input_name (str): Name of the input tensor (used for error messages) + input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size. 
+ batch_size (int): The base batch size (number of prompts) + num_videos_per_prompt (int, optional): Number of videos to generate per prompt. Defaults to 1. + + Returns: + torch.Tensor: The repeated tensor with final batch size (batch_size * num_videos_per_prompt) + + Raises: + ValueError: If input_tensor is not a torch.Tensor or has invalid batch size + + Examples: + tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor, + batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: + [4, 3] + + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image", + tensor, batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) + - shape: [4, 3] + """ + # make sure input is a tensor + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts + if input_tensor.shape[0] == 1: + repeat_by = batch_size * num_videos_per_prompt + elif input_tensor.shape[0] == batch_size: + repeat_by = num_videos_per_prompt + else: + raise ValueError( + f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}" + ) + + # expand the tensor to match the batch_size * num_videos_per_prompt + input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0) + + return input_tensor + + +# Copied from diffusers.modular_pipelines.wan.before_denoise.calculate_dimension_from_latents +def calculate_dimension_from_latents( + latents: torch.Tensor, vae_scale_factor_temporal: int, vae_scale_factor_spatial: int +) -> tuple[int, int]: + """Calculate image dimensions from latent tensor dimensions. 
+ + This function converts latent temporal and spatial dimensions to image temporal and spatial dimensions by + multiplying the latent num_frames/height/width by the VAE scale factor. + + Args: + latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions. + Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width] + vae_scale_factor_temporal (int): The scale factor used by the VAE to compress temporal dimension. + Typically 4 for most VAEs (video is 4x larger than latents in temporal dimension) + vae_scale_factor_spatial (int): The scale factor used by the VAE to compress spatial dimension. + Typically 8 for most VAEs (image is 8x larger than latents in each dimension) + + Returns: + tuple[int, int]: The calculated image dimensions as (height, width) + + Raises: + ValueError: If latents tensor doesn't have 4 or 5 dimensions + + """ + if latents.ndim != 5: + raise ValueError(f"latents must have 5 dimensions, but got {latents.ndim}") + + _, _, num_latent_frames, latent_height, latent_width = latents.shape + + num_frames = (num_latent_frames - 1) * vae_scale_factor_temporal + 1 + height = latent_height * vae_scale_factor_spatial + width = latent_width * vae_scale_factor_spatial + + return num_frames, height, width + + +class HeliosAdditionalInputsStep(ModularPipelineBlocks): + """Configurable step that standardizes inputs for the denoising step. + + This step handles: + 1. For encoded image latents: Computes height/width from latents and expands batch size + 2. 
For additional_batch_inputs: Expands batch dimensions to match final batch size + """ + + model_name = "helios" + + def __init__( + self, + image_latent_inputs: list[InputParam] | None = None, + additional_batch_inputs: list[InputParam] | None = None, + ): + if image_latent_inputs is None: + image_latent_inputs = [InputParam.template("image_latents")] + if additional_batch_inputs is None: + additional_batch_inputs = [] + + if not isinstance(image_latent_inputs, list): + raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}") + else: + for input_param in image_latent_inputs: + if not isinstance(input_param, InputParam): + raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}") + + if not isinstance(additional_batch_inputs, list): + raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}") + else: + for input_param in additional_batch_inputs: + if not isinstance(input_param, InputParam): + raise ValueError( + f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}" + ) + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + summary_section = ( + "Input processing step that:\n" + " 1. For image latent inputs: Computes height/width from latents and expands batch size\n" + " 2. 
For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}" + + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam(name="num_videos_per_prompt", default=1), + InputParam(name="batch_size", required=True), + ] + inputs += self._image_latent_inputs + self._additional_batch_inputs + + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + outputs = [ + OutputParam("height", type_hint=int), + OutputParam("width", type_hint=int), + ] + + for input_param in self._image_latent_inputs: + outputs.append(OutputParam(input_param.name, type_hint=torch.Tensor)) + + for input_param in self._additional_batch_inputs: + outputs.append(OutputParam(input_param.name, type_hint=torch.Tensor)) + + return outputs + + def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + for input_param in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, input_param.name) + if image_latent_tensor is None: + continue + + # Calculate height/width from latents + _, height, width = calculate_dimension_from_latents( + image_latent_tensor, components.vae_scale_factor_temporal, components.vae_scale_factor_spatial + ) + block_state.height = height + block_state.width = width + + # Expand batch size + image_latent_tensor = repeat_tensor_to_batch_size( + input_name=input_param.name, 
class HeliosAddNoiseToImageLatentsStep(ModularPipelineBlocks):
    """Adds noise to image_latents and fake_image_latents for I2V conditioning.

    For each of the two tensors one sigma is drawn uniformly from its configured range (`image_noise_sigma_*` for
    image_latents, `video_noise_sigma_*` for fake_image_latents) and the tensor is replaced by
    `sigma * noise + (1 - sigma) * latents`. Both results are written back as float32 tensors on the execution
    device.
    """

    model_name = "helios"

    @property
    def description(self) -> str:
        return (
            "Adds noise to image_latents and fake_image_latents for I2V conditioning. "
            "Uses random sigma from configured ranges for each."
        )

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("image_latents"),
            InputParam(
                "fake_image_latents",
                required=True,
                type_hint=torch.Tensor,
                description="Fake image latents used as history seed for I2V generation.",
            ),
            InputParam(
                "image_noise_sigma_min",
                default=0.111,
                type_hint=float,
                description="Minimum sigma for image latent noise.",
            ),
            InputParam(
                "image_noise_sigma_max",
                default=0.135,
                type_hint=float,
                description="Maximum sigma for image latent noise.",
            ),
            InputParam(
                "video_noise_sigma_min",
                default=0.111,
                type_hint=float,
                description="Minimum sigma for video/fake-image latent noise.",
            ),
            InputParam(
                "video_noise_sigma_max",
                default=0.135,
                type_hint=float,
                description="Maximum sigma for video/fake-image latent noise.",
            ),
            InputParam.template("generator"),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam.template("image_latents"),
            OutputParam("fake_image_latents", type_hint=torch.Tensor, description="Noisy fake image latents"),
        ]

    @staticmethod
    def _sample_sigma(sigma_min, sigma_max, generator, device) -> torch.Tensor:
        """Draw one sigma uniformly from [sigma_min, sigma_max) as a 1-element tensor on `device`.

        `torch.rand` only accepts a single generator that lives on the sampling device, so the value is drawn on
        the generator's own device and moved afterwards. A list of generators is reduced to its first entry.
        When the generator already lives on `device` (or is None) the draw is the same as sampling on `device`
        directly.
        """
        if isinstance(generator, list):
            generator = generator[0]
        draw_device = device if generator is None else generator.device
        unit = torch.rand(1, device=draw_device, generator=generator).to(device)
        return unit * (sigma_max - sigma_min) + sigma_min

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        device = components._execution_device
        generator = block_state.generator
        image_latents = block_state.image_latents
        fake_image_latents = block_state.fake_image_latents

        # Random draws happen in a fixed order (image sigma, image noise, fake sigma, fake noise) so results for a
        # given generator state stay reproducible.

        # Add noise to image_latents
        image_noise_sigma = self._sample_sigma(
            block_state.image_noise_sigma_min, block_state.image_noise_sigma_max, generator, device
        )
        image_latents = (
            image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device)
            + (1 - image_noise_sigma) * image_latents
        )

        # Add noise to fake_image_latents
        fake_image_noise_sigma = self._sample_sigma(
            block_state.video_noise_sigma_min, block_state.video_noise_sigma_max, generator, device
        )
        fake_image_latents = (
            fake_image_noise_sigma * randn_tensor(fake_image_latents.shape, generator=generator, device=device)
            + (1 - fake_image_noise_sigma) * fake_image_latents
        )

        block_state.image_latents = image_latents.to(device=device, dtype=torch.float32)
        block_state.fake_image_latents = fake_image_latents.to(device=device, dtype=torch.float32)

        self.set_block_state(state, block_state)
        return components, state
+ ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("image_latents"), + InputParam( + "video_latents", + required=True, + type_hint=torch.Tensor, + description="Encoded video latents for V2V generation.", + ), + InputParam( + "num_latent_frames_per_chunk", + default=9, + type_hint=int, + description="Number of latent frames per temporal chunk.", + ), + InputParam( + "image_noise_sigma_min", + default=0.111, + type_hint=float, + description="Minimum sigma for image latent noise.", + ), + InputParam( + "image_noise_sigma_max", + default=0.135, + type_hint=float, + description="Maximum sigma for image latent noise.", + ), + InputParam( + "video_noise_sigma_min", + default=0.111, + type_hint=float, + description="Minimum sigma for video latent noise.", + ), + InputParam( + "video_noise_sigma_max", + default=0.135, + type_hint=float, + description="Maximum sigma for video latent noise.", + ), + InputParam.template("generator"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("image_latents"), + OutputParam("video_latents", type_hint=torch.Tensor, description="Noisy video latents"), + ] + + @torch.no_grad() + def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + image_latents = block_state.image_latents + video_latents = block_state.video_latents + num_latent_frames_per_chunk = block_state.num_latent_frames_per_chunk + + # Add noise to first frame (single sigma) + image_noise_sigma = ( + torch.rand(1, device=device, generator=block_state.generator) + * (block_state.image_noise_sigma_max - block_state.image_noise_sigma_min) + + block_state.image_noise_sigma_min + ) + image_latents = ( + image_noise_sigma * randn_tensor(image_latents.shape, generator=block_state.generator, device=device) + + (1 - image_noise_sigma) * image_latents + ) + + # Add 
class HeliosPrepareHistoryStep(ModularPipelineBlocks):
    """Sets up everything the chunk loop needs before the first chunk is denoised.

    Derives the number of chunks and the per-chunk latent shape, builds the index tensors for the
    long/mid/short history buffers and the current chunk, and allocates an all-zero `history_latents` tensor.
    """

    model_name = "helios"

    @property
    def description(self) -> str:
        return (
            "Prepares the chunk loop by computing latent dimensions, number of chunks, "
            "history indices, and initializing history state (history_latents, image_latents, latent_chunks)."
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [ComponentSpec("transformer", HeliosTransformer3DModel)]

    @property
    def inputs(self) -> list[InputParam]:
        params = [
            InputParam.template("height", default=384),
            InputParam.template("width", default=640),
            InputParam(
                "num_frames", default=132, type_hint=int, description="Total number of video frames to generate."
            ),
            InputParam("batch_size", required=True, type_hint=int),
        ]
        params.append(
            InputParam(
                "num_latent_frames_per_chunk",
                default=9,
                type_hint=int,
                description="Number of latent frames per temporal chunk.",
            )
        )
        params.append(
            InputParam(
                "history_sizes",
                default=[16, 2, 1],
                type_hint=list,
                description="Sizes of long/mid/short history buffers for temporal context.",
            )
        )
        params.append(
            InputParam(
                "keep_first_frame",
                default=True,
                type_hint=bool,
                description="Whether to keep the first frame as a prefix in history.",
            )
        )
        return params

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        outputs = [
            OutputParam("num_latent_chunk", type_hint=int, description="Number of temporal chunks"),
            OutputParam("latent_shape", type_hint=tuple, description="Shape of latent tensor per chunk"),
            OutputParam("history_sizes", type_hint=list, description="Adjusted history sizes (sorted, descending)"),
        ]
        # The four index tensors are all routed to the denoiser through `denoiser_input_fields`.
        for index_name in (
            "indices_hidden_states",
            "indices_latents_history_short",
            "indices_latents_history_mid",
            "indices_latents_history_long",
        ):
            outputs.append(OutputParam(index_name, type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"))
        outputs.append(
            OutputParam("history_latents", type_hint=torch.Tensor, description="Initialized zero history latents")
        )
        return outputs

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        chunk_latent_frames = block_state.num_latent_frames_per_chunk
        keep_prefix = block_state.keep_first_frame
        n_batch = block_state.batch_size
        n_channels = components.num_channels_latents
        latent_h = block_state.height // components.vae_scale_factor_spatial
        latent_w = block_state.width // components.vae_scale_factor_spatial

        # At least one frame is always requested.
        block_state.num_frames = max(block_state.num_frames, 1)

        # Chunk count: ceiling division of the requested frames by the frames one chunk covers, never below one.
        window = (chunk_latent_frames - 1) * components.vae_scale_factor_temporal + 1
        block_state.window_num_frames = window
        block_state.num_latent_chunk = max(1, (block_state.num_frames + window - 1) // window)

        # `sorted` returns a fresh list, so the adjustment below never touches the caller's list.
        sizes = sorted(block_state.history_sizes, reverse=True)
        if not keep_prefix:
            # Without a first-frame prefix the last (smallest) buffer grows by one (matching pipeline behavior).
            sizes[-1] += 1

        # The index layout is identical for every chunk, so it is computed once here.
        if keep_prefix:
            split_sizes = [1, *sizes, chunk_latent_frames]
            prefix_idx, long_idx, mid_idx, recent_idx, hidden_idx = torch.arange(0, sum(split_sizes)).split(
                split_sizes, dim=0
            )
            short_idx = torch.cat([prefix_idx, recent_idx], dim=0)
        else:
            split_sizes = [*sizes, chunk_latent_frames]
            long_idx, mid_idx, short_idx, hidden_idx = torch.arange(0, sum(split_sizes)).split(split_sizes, dim=0)

        block_state.latent_shape = (n_batch, n_channels, chunk_latent_frames, latent_h, latent_w)
        block_state.history_sizes = sizes

        # Each index tensor gets a leading dimension of size one.
        block_state.indices_hidden_states = hidden_idx.unsqueeze(0)
        block_state.indices_latents_history_short = short_idx.unsqueeze(0)
        block_state.indices_latents_history_mid = mid_idx.unsqueeze(0)
        block_state.indices_latents_history_long = long_idx.unsqueeze(0)

        block_state.history_latents = torch.zeros(
            n_batch,
            n_channels,
            sum(sizes),
            latent_h,
            latent_w,
            device=components._execution_device,
            dtype=torch.float32,
        )

        self.set_block_state(state, block_state)

        return components, state
class HeliosSetTimestepsStep(ModularPipelineBlocks):
    """Derives the scheduler shift (`mu`) and a default sigma schedule for the chunk loop.

    `mu` comes from `calculate_shift`, fed with the number of transformer patches in one chunk and the
    scheduler's shift configuration. `sigmas` is only filled in when the caller did not provide one.
    """

    model_name = "helios"

    @property
    def description(self) -> str:
        return "Computes scheduler shift parameter (mu) and default sigmas for the Helios chunk loop."

    @property
    def expected_components(self) -> list[ComponentSpec]:
        transformer_spec = ComponentSpec("transformer", HeliosTransformer3DModel)
        scheduler_spec = ComponentSpec("scheduler", HeliosScheduler)
        return [transformer_spec, scheduler_spec]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("latent_shape", required=True, type_hint=tuple),
            InputParam.template("num_inference_steps"),
            InputParam.template("sigmas"),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        mu_out = OutputParam("mu", type_hint=float, description="Scheduler shift parameter")
        sigmas_out = OutputParam("sigmas", type_hint=list, description="Sigma schedule for diffusion")
        return [mu_out, sigmas_out]

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        # Sequence length = elements in the last three latent dims divided by elements per patch.
        shape = block_state.latent_shape
        patch = components.transformer.config.patch_size
        latent_elements = shape[-1] * shape[-2] * shape[-3]
        patch_elements = patch[0] * patch[1] * patch[2]
        image_seq_len = latent_elements // patch_elements

        if block_state.sigmas is None:
            # Default schedule: `num_inference_steps` evenly spaced values from 0.999 down towards 0.0, with the
            # final 0.0 endpoint dropped.
            block_state.sigmas = np.linspace(0.999, 0.0, block_state.num_inference_steps + 1)[:-1]

        scheduler_config = components.scheduler.config
        block_state.mu = calculate_shift(
            image_seq_len,
            scheduler_config.get("base_image_seq_len", 256),
            scheduler_config.get("max_image_seq_len", 4096),
            scheduler_config.get("base_shift", 0.5),
            scheduler_config.get("max_shift", 1.15),
        )

        self.set_block_state(state, block_state)

        return components, state
class HeliosDecodeStep(ModularPipelineBlocks):
    """Decode all chunk latents with VAE, trim frames, and postprocess into final video output.

    Each chunk is de-normalized with the VAE's `latents_mean` / `latents_std`, decoded separately, and the decoded
    chunks are concatenated along dim 2. The result is trimmed to the largest length of the form
    `n * vae_scale_factor_temporal + 1` and handed to the video processor.

    Raises:
        ValueError: If `latent_chunks` is empty.
    """

    model_name = "helios"

    @property
    def description(self) -> str:
        return (
            "Decodes all chunk latents with the VAE, concatenates them, "
            "trims to the target frame count, and postprocesses into the final video output."
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLWan),
            ComponentSpec(
                "video_processor",
                VideoProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam(
                "latent_chunks", required=True, type_hint=list, description="List of per-chunk denoised latent tensors"
            ),
            InputParam("num_frames", required=True, type_hint=int, description="The target number of output frames"),
            InputParam.template("output_type", default="np"),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam(
                "videos",
                type_hint=list[list[PIL.Image.Image]] | list[torch.Tensor] | list[np.ndarray],
                description="The generated videos, can be a PIL.Image.Image, torch.Tensor or a numpy array",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        latent_chunks = block_state.latent_chunks
        if len(latent_chunks) == 0:
            # Fail with a clear message instead of an AttributeError on `None` further down.
            raise ValueError("`latent_chunks` is empty; there is nothing to decode.")

        vae = components.vae

        latents_mean = (
            torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1).to(vae.device, vae.dtype)
        )
        # Stored as the reciprocal of the std, so dividing by it below multiplies by the std.
        latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
            vae.device, vae.dtype
        )

        # Decode chunk by chunk and concatenate once at the end; concatenating inside the loop would re-copy
        # the whole accumulated video for every chunk.
        decoded_chunks = [
            vae.decode(chunk_latents.to(vae.dtype) / latents_std + latents_mean, return_dict=False)[0]
            for chunk_latents in latent_chunks
        ]
        history_video = decoded_chunks[0] if len(decoded_chunks) == 1 else torch.cat(decoded_chunks, dim=2)

        # Trim to proper frame count.
        # NOTE: the cut-off is derived from the decoded length only; the `num_frames` input is not read here.
        generated_frames = history_video.size(2)
        generated_frames = (
            generated_frames - 1
        ) // components.vae_scale_factor_temporal * components.vae_scale_factor_temporal + 1
        history_video = history_video[:, :, :generated_frames]

        block_state.videos = components.video_processor.postprocess_video(
            history_video, output_type=block_state.output_type
        )

        self.set_block_state(state, block_state)

        return components, state
def sample_block_noise(
    batch_size,
    channel,
    num_frames,
    height,
    width,
    gamma,
    patch_size=(1, 2, 2),
    device=None,
    generator=None,
):
    """Generate spatially-correlated block noise for pyramid upsampling correction.

    Every `ph x pw` spatial patch is one draw from a zero-mean multivariate normal whose covariance has
    `1 + gamma` on the diagonal and `-gamma` everywhere else (plus a small jitter for the Cholesky
    factorization). The patches are then tiled back into a `(batch_size, channel, num_frames, height, width)`
    float32 tensor.

    Args:
        batch_size, channel, num_frames, height, width: Dimensions of the returned tensor.
        gamma: Strength of the negative correlation between the entries of one patch.
        patch_size: `(pt, ph, pw)`; only the two spatial entries are used.
        device: Device of the returned noise. `None` uses PyTorch's default device.
        generator: A `torch.Generator`, a list of generators (only the first one is used), or `None` to draw
            from PyTorch's global RNG.

    Returns:
        The noise tensor of shape `(batch_size, channel, num_frames, height, width)`.

    Raises:
        ValueError: If `height` or `width` is not divisible by the matching spatial patch size.
    """
    # NOTE: Pass a generator whenever results must be reproducible or identical across processes.
    # With `generator=None` the global RNG is used. A freshly constructed `torch.Generator` must not be used as a
    # fallback: it always starts from the same fixed default seed, so every call would return identical noise.
    if isinstance(generator, list):
        generator = generator[0]

    _, ph, pw = patch_size
    if height % ph != 0 or width % pw != 0:
        raise ValueError(
            f"`height` ({height}) and `width` ({width}) must be divisible by the spatial patch size ({ph}, {pw})."
        )
    block_size = ph * pw

    identity = torch.eye(block_size, device=device)
    cov = identity * (1 + gamma) - torch.ones(block_size, block_size, device=device) * gamma
    # Small diagonal jitter keeps the matrix positive definite when it would otherwise be singular.
    cov += identity * 1e-8
    cov = cov.float()  # Upcast to fp32 for numerical stability — cholesky is unreliable in fp16/bf16.

    L = torch.linalg.cholesky(cov)
    block_number = batch_size * channel * num_frames * (height // ph) * (width // pw)
    if generator is None:
        z = torch.randn(block_number, block_size, device=device)
    else:
        # Sample on the generator's own device (it may differ from `device`), then move.
        z = torch.randn(block_number, block_size, device=generator.device, generator=generator)
        if device is not None:
            z = z.to(device)
    # Rows of `z @ L.T` have covariance `L @ L.T == cov`.
    noise = z @ L.T

    # (blocks, ph * pw) -> (B, C, T, H/ph, W/pw, ph, pw) -> (B, C, T, H/ph, ph, W/pw, pw) -> (B, C, T, H, W)
    noise = noise.view(batch_size, channel, num_frames, height // ph, width // pw, ph, pw)
    noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(batch_size, channel, num_frames, height, width)
    return noise
class HeliosI2VChunkHistorySliceStep(ModularPipelineBlocks):
    """Splits the accumulated history into long/mid/short buffers for one I2V chunk.

    The last `sum(history_sizes)` frames of `history_latents` are split according to `history_sizes`. With
    `keep_first_frame` enabled, `image_latents` is prepended to the last split to form the short history
    (assumes history pre-seeded with fake_image_latents).
    """

    model_name = "helios"

    @property
    def description(self) -> str:
        return (
            "I2V history slice: splits pre-seeded history into long/mid/short, "
            "always using image_latents as prefix for short history."
        )

    @property
    def inputs(self) -> list[InputParam]:
        keep_first_frame_param = InputParam(
            "keep_first_frame",
            default=True,
            type_hint=bool,
            description="Whether to keep the first frame as a prefix in history.",
        )
        history_sizes_param = InputParam(
            "history_sizes",
            required=True,
            type_hint=list,
            description="Sizes of long/mid/short history buffers for temporal context.",
        )
        history_latents_param = InputParam(
            "history_latents",
            required=True,
            type_hint=torch.Tensor,
            description="Accumulated history latents from previous chunks.",
        )
        image_latents_param = InputParam(
            "image_latents",
            required=True,
            type_hint=torch.Tensor,
            description="First-frame latents used as prefix for short history.",
        )
        return [keep_first_frame_param, history_sizes_param, history_latents_param, image_latents_param]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return []

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        sizes = block_state.history_sizes
        first_frame = block_state.image_latents

        # Only the most recent `sum(sizes)` frames (dim 2) take part in the split.
        recent_history = block_state.history_latents[:, :, -sum(sizes) :]
        long_part, mid_part, short_part = recent_history.split(sizes, dim=2)

        if block_state.keep_first_frame:
            # Prepend the first-frame latents to the last split.
            short_part = torch.cat([first_frame, short_part], dim=2)

        block_state.latents_history_short = short_part
        block_state.latents_history_mid = mid_part
        block_state.latents_history_long = long_part

        return components, block_state
+ ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("latent_shape", required=True, type_hint=tuple), + InputParam( + "pyramid_num_inference_steps_list", + default=[10, 10, 10], + type_hint=list, + description="Number of denoising steps per pyramid stage.", + ), + InputParam.template("generator"), + ] + + @torch.no_grad() + def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int): + device = components._execution_device + batch_size, num_channels_latents, num_latent_frames, h_latent, w_latent = block_state.latent_shape + + latents = randn_tensor( + block_state.latent_shape, generator=block_state.generator, device=device, dtype=torch.float32 + ) + + # Downsample to smallest pyramid level + h, w = h_latent, w_latent + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_latent_frames, num_channels_latents, h, w) + for _ in range(len(block_state.pyramid_num_inference_steps_list) - 1): + h //= 2 + w //= 2 + latents = F.interpolate(latents, size=(h, w), mode="bilinear") * 2 + block_state.latents = latents.reshape(batch_size, num_latent_frames, num_channels_latents, h, w).permute( + 0, 2, 1, 3, 4 + ) + + return components, block_state + + +class HeliosChunkSchedulerResetStep(ModularPipelineBlocks): + """Resets the scheduler with timesteps for a single chunk.""" + + model_name = "helios" + + @property + def description(self) -> str: + return "Resets the scheduler with the correct timesteps and shift parameter (mu) for this chunk." 
class HeliosChunkDenoiseInner(ModularPipelineBlocks):
    """Runs the timestep loop that denoises one chunk.

    For every timestep the guider splits the prompt embeddings into one or more batches, the transformer is run
    once per batch inside its own cache context, the guider combines the predictions, and the scheduler advances
    the latents by one step.
    """

    model_name = "helios"

    @property
    def description(self) -> str:
        return (
            "Inner denoising loop that iterates over timesteps for a single chunk. "
            "Uses the guider to manage conditional/unconditional forward passes with cache_context, "
            "applies guidance, and runs scheduler step."
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        guider_spec = ComponentSpec(
            "guider",
            ClassifierFreeGuidance,
            config=FrozenDict({"guidance_scale": 5.0}),
            default_creation_method="from_config",
        )
        return [
            ComponentSpec("transformer", HeliosTransformer3DModel),
            ComponentSpec("scheduler", HeliosScheduler),
            guider_spec,
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("latents"),
            InputParam.template("timesteps"),
            InputParam("prompt_embeds", type_hint=torch.Tensor),
            InputParam("negative_prompt_embeds", type_hint=torch.Tensor),
            InputParam.template("denoiser_input_fields"),
            InputParam.template("num_inference_steps"),
            InputParam.template("attention_kwargs"),
            InputParam.template("generator"),
        ]

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        transformer = components.transformer
        guider = components.guider
        scheduler = components.scheduler
        model_dtype = transformer.dtype

        latents = block_state.latents
        timesteps = block_state.timesteps
        total_steps = block_state.num_inference_steps
        last_index = len(timesteps) - 1
        num_warmup_steps = len(timesteps) - total_steps * scheduler.order

        # Guider inputs: only encoder_hidden_states differs between cond/uncond
        guider_inputs = {
            "encoder_hidden_states": (block_state.prompt_embeds, block_state.negative_prompt_embeds),
        }

        # Forward only those denoiser fields the transformer's `forward` accepts and the guider does not manage.
        accepted_names = set(inspect.signature(transformer.forward).parameters.keys())
        shared_kwargs = {
            name: value
            for name, value in block_state.denoiser_input_fields.items()
            if name in accepted_names and name not in guider_inputs
        }
        # History latents produced by the slice step, cast to the transformer dtype.
        shared_kwargs["latents_history_short"] = block_state.latents_history_short.to(model_dtype)
        shared_kwargs["latents_history_mid"] = block_state.latents_history_mid.to(model_dtype)
        shared_kwargs["latents_history_long"] = block_state.latents_history_long.to(model_dtype)
        shared_kwargs["attention_kwargs"] = block_state.attention_kwargs

        with tqdm(total=total_steps) as progress_bar:
            for step_index, t in enumerate(timesteps):
                # One int64 timestep per batch entry.
                timestep = t.expand(latents.shape[0]).to(torch.int64)
                model_input = latents.to(model_dtype)

                guider.set_state(step=step_index, num_inference_steps=total_steps, timestep=t)
                guider_state = guider.prepare_inputs(guider_inputs)

                for guider_batch in guider_state:
                    guider.prepare_models(transformer)
                    cond_kwargs = {name: getattr(guider_batch, name) for name in guider_inputs}

                    # Each guider batch runs under its own named cache context.
                    context_name = getattr(guider_batch, guider._identifier_key)
                    with transformer.cache_context(context_name):
                        prediction = transformer(
                            hidden_states=model_input,
                            timestep=timestep,
                            return_dict=False,
                            **cond_kwargs,
                            **shared_kwargs,
                        )
                        guider_batch.noise_pred = prediction[0]
                    guider.cleanup_models(transformer)

                noise_pred = guider(guider_state)[0]

                # Scheduler step
                latents = scheduler.step(
                    noise_pred,
                    t,
                    latents,
                    generator=block_state.generator,
                    return_dict=False,
                )[0]

                # Tick on the last timestep, or after warmup on every `scheduler.order`-th step.
                finished_step = step_index + 1
                if step_index == last_index or (
                    finished_step > num_warmup_steps and finished_step % scheduler.order == 0
                ):
                    progress_bar.update()

        block_state.latents = latents
        return components, block_state
Run timestep denoising loop (same logic as HeliosChunkDenoiseInner) + """ + + model_name = "helios-pyramid" + + @property + def description(self) -> str: + return ( + "Pyramid denoising inner block: loops over pyramid stages from smallest to full resolution. " + "Each stage upsamples latents (with block noise correction), recomputes scheduler parameters, " + "and runs the timestep denoising loop." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("transformer", HeliosTransformer3DModel), + ComponentSpec("scheduler", HeliosScheduler), + ComponentSpec( + "guider", + ClassifierFreeZeroStarGuidance, + config=FrozenDict({"guidance_scale": 5.0, "zero_init_steps": 2}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("latents"), + InputParam("prompt_embeds", type_hint=torch.Tensor), + InputParam("negative_prompt_embeds", type_hint=torch.Tensor), + InputParam.template("denoiser_input_fields"), + InputParam( + "pyramid_num_inference_steps_list", + default=[10, 10, 10], + type_hint=list, + description="Number of denoising steps per pyramid stage.", + ), + InputParam.template("attention_kwargs"), + InputParam.template("generator"), + ] + + @torch.no_grad() + def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int): + device = components._execution_device + transformer_dtype = components.transformer.dtype + latents = block_state.latents + pyramid_num_stages = len(block_state.pyramid_num_inference_steps_list) + + # Guider inputs: only encoder_hidden_states differs between cond/uncond + guider_inputs = { + "encoder_hidden_states": (block_state.prompt_embeds, block_state.negative_prompt_embeds), + } + + # Build shared kwargs from denoiser_input_fields (excludes guider-managed ones) + transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) + shared_kwargs = {} + for field_name, 
field_value in block_state.denoiser_input_fields.items(): + if field_name in transformer_args and field_name not in guider_inputs: + shared_kwargs[field_name] = field_value + + # Add loop-internal history latents with dtype casting + shared_kwargs["latents_history_short"] = block_state.latents_history_short.to(transformer_dtype) + shared_kwargs["latents_history_mid"] = block_state.latents_history_mid.to(transformer_dtype) + shared_kwargs["latents_history_long"] = block_state.latents_history_long.to(transformer_dtype) + shared_kwargs["attention_kwargs"] = block_state.attention_kwargs + + # Save original zero_init_steps if the guider supports it (e.g. ClassifierFreeZeroStarGuidance). + # Helios only applies zero init in pyramid stage 0 (lowest resolution), so we disable it + # for subsequent stages by temporarily setting zero_init_steps=0. + orig_zero_init_steps = getattr(components.guider, "zero_init_steps", None) + + for i_s in range(pyramid_num_stages): + # --- Stage setup --- + + # Disable zero init for stages > 0 (only stage 0 should have zero init) + if orig_zero_init_steps is not None and i_s > 0: + components.guider.zero_init_steps = 0 + + # a. Compute mu from current resolution (before upsample, matching standard pipeline) + patch_size = components.transformer.config.patch_size + image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // ( + patch_size[0] * patch_size[1] * patch_size[2] + ) + mu = calculate_shift( + image_seq_len, + components.scheduler.config.get("base_image_seq_len", 256), + components.scheduler.config.get("max_image_seq_len", 4096), + components.scheduler.config.get("base_shift", 0.5), + components.scheduler.config.get("max_shift", 1.15), + ) + + # b. Set scheduler timesteps for this stage + num_inference_steps = block_state.pyramid_num_inference_steps_list[i_s] + components.scheduler.set_timesteps( + num_inference_steps, + i_s, + device=device, + mu=mu, + ) + timesteps = components.scheduler.timesteps + + # c. 
Upsample + block noise correction for stages > 0 + if i_s > 0: + batch_size, num_channels_latents, num_frames, current_h, current_w = latents.shape + new_h = current_h * 2 + new_w = current_w * 2 + + latents = latents.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frames, num_channels_latents, current_h, current_w + ) + latents = F.interpolate(latents, size=(new_h, new_w), mode="nearest") + latents = latents.reshape(batch_size, num_frames, num_channels_latents, new_h, new_w).permute( + 0, 2, 1, 3, 4 + ) + + # Block noise correction + ori_sigma = 1 - components.scheduler.ori_start_sigmas[i_s] + gamma = components.scheduler.config.gamma + alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma) + beta = alpha * (1 - ori_sigma) / math.sqrt(gamma) + + batch_size, num_channels_latents, num_frames, h, w = latents.shape + noise = sample_block_noise( + batch_size, + num_channels_latents, + num_frames, + h, + w, + gamma, + patch_size, + device=device, + generator=block_state.generator, + ) + noise = noise.to(dtype=transformer_dtype) + latents = alpha * latents + beta * noise + + # --- Timestep denoising loop --- + num_warmup_steps = len(timesteps) - num_inference_steps * components.scheduler.order + + with tqdm(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + timestep = t.expand(latents.shape[0]).to(torch.int64) + latent_model_input = latents.to(transformer_dtype) + + components.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(guider_inputs) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = {kk: getattr(guider_state_batch, kk) for kk in guider_inputs.keys()} + + context_name = getattr(guider_state_batch, components.guider._identifier_key) + with components.transformer.cache_context(context_name): + guider_state_batch.noise_pred = components.transformer( + 
hidden_states=latent_model_input, + timestep=timestep, + return_dict=False, + **cond_kwargs, + **shared_kwargs, + )[0] + components.guider.cleanup_models(components.transformer) + + noise_pred = components.guider(guider_state)[0] + + # Scheduler step + latents = components.scheduler.step( + noise_pred, + t, + latents, + generator=block_state.generator, + return_dict=False, + )[0] + + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + # Restore original zero_init_steps + if orig_zero_init_steps is not None: + components.guider.zero_init_steps = orig_zero_init_steps + + block_state.latents = latents + return components, block_state + + +# ======================================== +# Post-Denoise Update +# ======================================== + + +class HeliosChunkUpdateStep(ModularPipelineBlocks): + """Updates chunk collection and history after denoising a single chunk.""" + + model_name = "helios" + + @property + def description(self) -> str: + return ( + "Post-denoising update step: appends the denoised latents to the chunk list, " + "captures image_latents from the first chunk if needed, and extends history_latents." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("latents", type_hint=torch.Tensor), + InputParam("history_latents", type_hint=torch.Tensor), + InputParam("keep_first_frame", default=True, type_hint=bool), + ] + + @torch.no_grad() + def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int): + # e. Collect denoised latents for this chunk + block_state.latent_chunks.append(block_state.latents) + + # f. 
class HeliosChunkLoopWrapper(LoopSequentialPipelineBlocks):
    """Outer loop that runs the configured sub-blocks once per temporal chunk.

    History indices, scheduler params, and history state are expected to have been prepared by
    HeliosPrepareHistoryStep and HeliosSetTimestepsStep before this block runs. The sub-blocks take care of
    per-chunk preparation, denoising, and history updates.
    """

    model_name = "helios"

    @property
    def description(self) -> str:
        overview = "Pipeline block that iterates over temporal chunks for progressive video generation. "
        per_chunk = "At each chunk iteration, it runs sub-blocks for preparation, denoising, and history updates."
        return overview + per_chunk

    @property
    def loop_inputs(self) -> list[InputParam]:
        chunk_count = InputParam("num_latent_chunk", required=True, type_hint=int)
        return [chunk_count]

    @property
    def loop_intermediate_outputs(self) -> list[OutputParam]:
        collected_chunks = OutputParam(
            "latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors"
        )
        return [collected_chunks]

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        # Start every run with an empty collection of denoised chunks.
        block_state.latent_chunks = []

        # Give `image_latents` a value when the state does not carry one yet.
        if not hasattr(block_state, "image_latents"):
            block_state.image_latents = None

        # Each iteration hands the chunk index to the sub-blocks through the `k` keyword.
        for chunk_idx in range(block_state.num_latent_chunk):
            components, block_state = self.loop_step(components, block_state, k=chunk_idx)

        self.set_block_state(state, block_state)
        return components, state
class HeliosPyramidDistilledChunkDenoiseInner(ModularPipelineBlocks):
    """Nested pyramid stage loop with DMD denoising for distilled checkpoints.

    Same progressive multi-resolution strategy as HeliosPyramidChunkDenoiseInner, but:
    - Guidance is disabled (guidance_scale=1.0, no unconditional pass)
    - Supports is_amplify_first_chunk (doubles first chunk's timesteps via scheduler)
    - Tracks start_point_list and passes DMD-specific args to scheduler.step()
    """

    model_name = "helios-pyramid"

    @property
    def description(self) -> str:
        return (
            "Distilled pyramid denoising inner block for DMD checkpoints. Loops over pyramid stages "
            "from smallest to full resolution with guidance disabled and DMD scheduler support."
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("transformer", HeliosTransformer3DModel),
            ComponentSpec("scheduler", HeliosScheduler),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 1.0}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("latents"),
            InputParam("prompt_embeds", type_hint=torch.Tensor),
            InputParam("negative_prompt_embeds", type_hint=torch.Tensor),
            InputParam.template("denoiser_input_fields"),
            InputParam(
                "pyramid_num_inference_steps_list",
                default=[2, 2, 2],
                type_hint=list,
                description="Number of denoising steps per pyramid stage.",
            ),
            InputParam(
                "is_amplify_first_chunk",
                default=True,
                type_hint=bool,
                description="Whether to double the first chunk's timesteps via the scheduler for amplified generation.",
            ),
            InputParam.template("attention_kwargs"),
            InputParam.template("generator"),
        ]

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        # Denoises one chunk through every pyramid stage and writes the result to `block_state.latents`.
        # `k` is the chunk index; it is only used to detect the first chunk (k == 0).
        device = components._execution_device
        transformer_dtype = components.transformer.dtype
        latents = block_state.latents
        pyramid_num_stages = len(block_state.pyramid_num_inference_steps_list)
        is_first_chunk = k == 0

        # Track start points for DMD scheduler
        # Entry 0 is the incoming latents; one more entry is appended per upsampled stage, so
        # `start_point_list[i_s]` below is the latents stage `i_s` started from.
        start_point_list = [latents]

        # Guider inputs: only encoder_hidden_states differs between cond/uncond
        guider_inputs = {
            "encoder_hidden_states": (block_state.prompt_embeds, block_state.negative_prompt_embeds),
        }

        # Build shared kwargs from denoiser_input_fields (excludes guider-managed ones)
        # Only fields that are actual parameters of `transformer.forward` are passed through.
        transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
        shared_kwargs = {}
        for field_name, field_value in block_state.denoiser_input_fields.items():
            if field_name in transformer_args and field_name not in guider_inputs:
                shared_kwargs[field_name] = field_value

        # Add loop-internal history latents with dtype casting
        shared_kwargs["latents_history_short"] = block_state.latents_history_short.to(transformer_dtype)
        shared_kwargs["latents_history_mid"] = block_state.latents_history_mid.to(transformer_dtype)
        shared_kwargs["latents_history_long"] = block_state.latents_history_long.to(transformer_dtype)
        shared_kwargs["attention_kwargs"] = block_state.attention_kwargs

        for i_s in range(pyramid_num_stages):
            # --- Stage setup ---
            patch_size = components.transformer.config.patch_size

            # a. Compute mu from current resolution (before upsample, matching standard pipeline)
            image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // (
                patch_size[0] * patch_size[1] * patch_size[2]
            )
            mu = calculate_shift(
                image_seq_len,
                components.scheduler.config.get("base_image_seq_len", 256),
                components.scheduler.config.get("max_image_seq_len", 4096),
                components.scheduler.config.get("base_shift", 0.5),
                components.scheduler.config.get("max_shift", 1.15),
            )

            # b. Set scheduler timesteps for this stage (with DMD amplification)
            # Amplification is requested only when the input flag is set AND this is the first chunk.
            num_inference_steps = block_state.pyramid_num_inference_steps_list[i_s]
            components.scheduler.set_timesteps(
                num_inference_steps,
                i_s,
                device=device,
                mu=mu,
                is_amplify_first_chunk=block_state.is_amplify_first_chunk and is_first_chunk,
            )
            timesteps = components.scheduler.timesteps

            # c. Upsample + block noise correction for stages > 0
            if i_s > 0:
                batch_size, num_channels_latents, num_frames, current_h, current_w = latents.shape
                new_h = current_h * 2
                new_w = current_w * 2

                # Frames are folded into the batch axis so the nearest-neighbour resize is spatial only.
                latents = latents.permute(0, 2, 1, 3, 4).reshape(
                    batch_size * num_frames, num_channels_latents, current_h, current_w
                )
                latents = F.interpolate(latents, size=(new_h, new_w), mode="nearest")
                latents = latents.reshape(batch_size, num_frames, num_channels_latents, new_h, new_w).permute(
                    0, 2, 1, 3, 4
                )

                # Block noise correction
                ori_sigma = 1 - components.scheduler.ori_start_sigmas[i_s]
                gamma = components.scheduler.config.gamma
                alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)
                beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)

                batch_size, num_channels_latents, num_frames, h, w = latents.shape
                # NOTE(review): `sample_block_noise` is defined outside this block; its exact noise
                # structure is not visible here.
                noise = sample_block_noise(
                    batch_size,
                    num_channels_latents,
                    num_frames,
                    h,
                    w,
                    gamma,
                    patch_size,
                    device=device,
                    generator=block_state.generator,
                )
                noise = noise.to(dtype=transformer_dtype)
                latents = alpha * latents + beta * noise

                # Record the re-noised latents as this stage's start point for the scheduler step.
                start_point_list.append(latents)

            # --- Timestep denoising loop ---
            num_warmup_steps = len(timesteps) - num_inference_steps * components.scheduler.order

            with tqdm(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    timestep = t.expand(latents.shape[0]).to(torch.int64)
                    latent_model_input = latents.to(transformer_dtype)

                    components.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
                    guider_state = components.guider.prepare_inputs(guider_inputs)

                    for guider_state_batch in guider_state:
                        components.guider.prepare_models(components.transformer)
                        # The comprehension variable `k` is local to the comprehension and does not
                        # rebind the `k` argument of this method.
                        cond_kwargs = {k: getattr(guider_state_batch, k) for k in guider_inputs.keys()}

                        context_name = getattr(guider_state_batch, components.guider._identifier_key)
                        with components.transformer.cache_context(context_name):
                            guider_state_batch.noise_pred = components.transformer(
                                hidden_states=latent_model_input,
                                timestep=timestep,
                                return_dict=False,
                                **cond_kwargs,
                                **shared_kwargs,
                            )[0]
                        components.guider.cleanup_models(components.transformer)

                    noise_pred = components.guider(guider_state)[0]

                    # Scheduler step with DMD args
                    latents = components.scheduler.step(
                        noise_pred,
                        t,
                        latents,
                        generator=block_state.generator,
                        return_dict=False,
                        cur_sampling_step=i,
                        dmd_noisy_tensor=start_point_list[i_s],
                        dmd_sigmas=components.scheduler.sigmas,
                        dmd_timesteps=components.scheduler.timesteps,
                        all_timesteps=timesteps,
                    )[0]

                    # Advance the bar on the last step, or after warmup on every `scheduler.order`-th step.
                    if i == len(timesteps) - 1 or (
                        (i + 1) > num_warmup_steps and (i + 1) % components.scheduler.order == 0
                    ):
                        progress_bar.update()

        block_state.latents = latents
        return components, block_state
class HeliosPyramidI2VChunkDenoiseStep(HeliosChunkLoopWrapper):
    """Chunk loop for I2V pyramid denoising.

    Per chunk the sub-blocks run in this order: I2V history slice, pyramid noise generation,
    pyramid denoise inner loop, chunk update.
    """

    # Sub-block classes and their names, listed in execution order (one name per class).
    block_classes = [
        HeliosI2VChunkHistorySliceStep,
        HeliosPyramidChunkNoiseGenStep,
        HeliosPyramidChunkDenoiseInner,
        HeliosChunkUpdateStep,
    ]
    block_names = ["history_slice", "noise_gen", "denoise_inner", "update_chunk"]

    @property
    def description(self) -> str:
        summary_lines = [
            "I2V pyramid chunk denoise step that iterates over temporal chunks.",
            "At each chunk: history_slice (I2V) -> noise_gen (pyramid) -> denoise_inner (pyramid stages) -> update_chunk.",
            "Denoising starts at the smallest resolution and progressively upsamples.",
        ]
        return "\n".join(summary_lines)
class HeliosPyramidDistilledI2VChunkDenoiseStep(HeliosChunkLoopWrapper):
    """Chunk loop for I2V distilled pyramid denoising (DMD scheduler, no CFG).

    Per chunk the sub-blocks run in this order: I2V history slice, pyramid noise generation,
    distilled pyramid denoise inner loop, chunk update.
    """

    # Sub-block classes and their names, listed in execution order (one name per class).
    block_classes = [
        HeliosI2VChunkHistorySliceStep,
        HeliosPyramidChunkNoiseGenStep,
        HeliosPyramidDistilledChunkDenoiseInner,
        HeliosChunkUpdateStep,
    ]
    block_names = ["history_slice", "noise_gen", "denoise_inner", "update_chunk"]

    @property
    def description(self) -> str:
        summary_lines = [
            "I2V distilled pyramid chunk denoise step with DMD scheduler.",
            "At each chunk: history_slice (I2V) -> noise_gen (pyramid) -> denoise_inner (distilled/DMD) -> update_chunk.",
        ]
        return "\n".join(summary_lines)
def basic_clean(text):
    """Run `ftfy.fix_text` on *text*, HTML-unescape it twice, and strip surrounding whitespace."""
    fixed = ftfy.fix_text(text)
    # Unescaping twice also resolves doubly-escaped entities such as "&amp;amp;".
    once_unescaped = html.unescape(fixed)
    twice_unescaped = html.unescape(once_unescaped)
    return twice_unescaped.strip()


def whitespace_clean(text):
    """Collapse every whitespace run in *text* to a single space and strip the ends."""
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()


def prompt_clean(text):
    """Apply `basic_clean` followed by `whitespace_clean` to *text*."""
    cleaned = basic_clean(text)
    return whitespace_clean(cleaned)
class HeliosTextEncoderStep(ModularPipelineBlocks):
    """Encodes `prompt` (and, when needed, `negative_prompt`) into text embeddings.

    Writes `prompt_embeds` and `negative_prompt_embeds` to the block state. `negative_prompt_embeds` stays `None`
    unless `components.requires_unconditional_embeds` is truthy.
    """

    model_name = "helios"

    @property
    def description(self) -> str:
        return "Text Encoder step that generates text embeddings to guide the video generation"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", UMT5EncoderModel),
            ComponentSpec("tokenizer", AutoTokenizer),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("prompt"),
            InputParam.template("negative_prompt"),
            InputParam.template("max_sequence_length"),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam.template("prompt_embeds"),
            OutputParam.template("negative_prompt_embeds"),
        ]

    @staticmethod
    def check_inputs(prompt, negative_prompt):
        """Validate the types and batch sizes of `prompt` and `negative_prompt`.

        A single `negative_prompt` string is accepted for any prompt batch size, because `__call__` repeats it
        once per prompt. A `negative_prompt` list must contain exactly one entry per prompt.

        Raises:
            ValueError: If either argument is neither `str` nor `list` (and not `None`), or if a list
                `negative_prompt` does not match the batch size of `prompt`.
        """
        if prompt is not None and not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and not isinstance(negative_prompt, (str, list)):
            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

        if prompt is None or negative_prompt is None:
            return

        # A string negative prompt is broadcast to the prompt batch in `__call__`, so it can never mismatch.
        # (Previously it was compared as a batch of one and rejected for multi-prompt batches.)
        if isinstance(negative_prompt, str):
            return

        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        if len(negative_prompt) != batch_size:
            raise ValueError(
                f"`negative_prompt` has batch size {len(negative_prompt)}, but `prompt` has batch size"
                f" {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        prompt = block_state.prompt
        negative_prompt = block_state.negative_prompt
        max_sequence_length = block_state.max_sequence_length
        device = components._execution_device

        self.check_inputs(prompt, negative_prompt)

        # Encode prompt (the returned attention mask is not needed downstream)
        block_state.prompt_embeds, _ = get_t5_prompt_embeds(
            text_encoder=components.text_encoder,
            tokenizer=components.tokenizer,
            prompt=prompt,
            max_sequence_length=max_sequence_length,
            device=device,
        )

        # Encode negative prompt
        block_state.negative_prompt_embeds = None
        if components.requires_unconditional_embeds:
            # A missing/empty negative prompt falls back to the empty string.
            negative_prompt = negative_prompt or ""
            # Repeat a single negative prompt so there is one per prompt in the batch.
            if isinstance(prompt, list) and isinstance(negative_prompt, str):
                negative_prompt = len(prompt) * [negative_prompt]

            block_state.negative_prompt_embeds, _ = get_t5_prompt_embeds(
                text_encoder=components.text_encoder,
                tokenizer=components.tokenizer,
                prompt=negative_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        self.set_block_state(state, block_state)
        return components, state
class HeliosVideoVaeEncoderStep(ModularPipelineBlocks):
    """Encodes an input video into VAE latent space for video-to-video generation.

    Produces `image_latents` (the first frame, encoded on its own) and `video_latents` (the trailing
    `num_chunks * min_frames` frames encoded chunk by chunk and concatenated along the frame axis), where
    `min_frames = (num_latent_frames_per_chunk - 1) * vae_scale_factor_temporal + 1`. Leading frames that do not
    fill a whole chunk are left out of `video_latents`.
    """

    model_name = "helios"

    @property
    def description(self) -> str:
        return (
            "Video Encoder step that encodes an input video into VAE latent space, "
            "producing image_latents (first frame) and video_latents (chunked video frames) "
            "for video-to-video generation."
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLWan),
            ComponentSpec(
                "video_processor",
                VideoProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("video", required=True, description="Input video for video-to-video generation"),
            InputParam.template("height", default=384),
            InputParam.template("width", default=640),
            InputParam(
                "num_latent_frames_per_chunk",
                default=9,
                type_hint=int,
                description="Number of latent frames per temporal chunk.",
            ),
            InputParam.template("generator"),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam.template("image_latents"),
            OutputParam("video_latents", type_hint=torch.Tensor, description="Encoded video latents (chunked)"),
        ]

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState:
        """Encode `block_state.video` and store `image_latents` / `video_latents` on the state.

        Raises:
            ValueError: If the preprocessed video has fewer frames than one full chunk requires.
        """
        block_state = self.get_block_state(state)

        vae = components.vae
        device = components._execution_device
        num_latent_frames_per_chunk = block_state.num_latent_frames_per_chunk
        # Use the pipeline's temporal compression factor instead of a hard-coded 4, consistent with
        # the image encoder step in this module.
        temporal_factor = components.vae_scale_factor_temporal

        latents_mean = (
            torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1).to(vae.device, vae.dtype)
        )
        # NOTE: despite its name this holds the reciprocal of the std, so normalization is a multiply.
        latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
            vae.device, vae.dtype
        )

        # Preprocess video
        video = components.video_processor.preprocess_video(
            block_state.video, height=block_state.height, width=block_state.width
        )
        video = video.to(device=device, dtype=vae.dtype)

        # Encode video into latents
        num_frames = video.shape[2]
        min_frames = (num_latent_frames_per_chunk - 1) * temporal_factor + 1
        num_chunks = num_frames // min_frames
        if num_chunks == 0:
            raise ValueError(
                f"Video must have at least {min_frames} frames "
                f"(got {num_frames} frames). "
                f"Required: (num_latent_frames_per_chunk - 1) * {temporal_factor} + 1 = "
                f"({num_latent_frames_per_chunk} - 1) * {temporal_factor} + 1 = {min_frames}"
            )
        # Chunks are aligned to the END of the video; frames before `start_frame` are not chunk-encoded.
        total_valid_frames = num_chunks * min_frames
        start_frame = num_frames - total_valid_frames

        # Encode first frame
        first_frame = video[:, :, 0:1, :, :]
        image_latents = vae.encode(first_frame).latent_dist.sample(generator=block_state.generator)
        image_latents = (image_latents - latents_mean) * latents_std

        # Encode the trailing frames chunk by chunk
        latents_chunks = []
        for i in range(num_chunks):
            chunk_start = start_frame + i * min_frames
            chunk_end = chunk_start + min_frames
            video_chunk = video[:, :, chunk_start:chunk_end, :, :]
            chunk_latents = vae.encode(video_chunk).latent_dist.sample(generator=block_state.generator)
            chunk_latents = (chunk_latents - latents_mean) * latents_std
            latents_chunks.append(chunk_latents)
        video_latents = torch.cat(latents_chunks, dim=2)

        block_state.image_latents = image_latents.to(device=device, dtype=torch.float32)
        block_state.video_latents = video_latents.to(device=device, dtype=torch.float32)

        self.set_block_state(state, block_state)
        return components, state
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam +from .before_denoise import ( + HeliosAdditionalInputsStep, + HeliosAddNoiseToImageLatentsStep, + HeliosAddNoiseToVideoLatentsStep, + HeliosI2VSeedHistoryStep, + HeliosPrepareHistoryStep, + HeliosSetTimestepsStep, + HeliosTextInputStep, + HeliosV2VSeedHistoryStep, +) +from .decoders import HeliosDecodeStep +from .denoise import HeliosChunkDenoiseStep, HeliosI2VChunkDenoiseStep +from .encoders import HeliosImageVaeEncoderStep, HeliosTextEncoderStep, HeliosVideoVaeEncoderStep + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# ==================== +# 1. Vae Encoder +# ==================== + + +# auto_docstring +class HeliosAutoVaeEncoderStep(AutoPipelineBlocks): + """ + Encoder step that encodes video or image inputs. This is an auto pipeline block. + - `HeliosVideoVaeEncoderStep` (video_encoder) is used when `video` is provided. + - `HeliosImageVaeEncoderStep` (image_encoder) is used when `image` is provided. + - If neither is provided, step will be skipped. + + Components: + vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`) + + Inputs: + video (`None`, *optional*): + Input video for video-to-video generation + height (`int`, *optional*, defaults to 384): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 640): + The width in pixels of the generated image. + num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + image (`Image | list`, *optional*): + Reference image(s) for denoising. 
Can be a single image or list of images. + + Outputs: + image_latents (`Tensor`): + The latent representation of the input image. + video_latents (`Tensor`): + Encoded video latents (chunked) + fake_image_latents (`Tensor`): + Fake image latents for history seeding + """ + + block_classes = [HeliosVideoVaeEncoderStep, HeliosImageVaeEncoderStep] + block_names = ["video_encoder", "image_encoder"] + block_trigger_inputs = ["video", "image"] + + @property + def description(self): + return ( + "Encoder step that encodes video or image inputs. This is an auto pipeline block.\n" + " - `HeliosVideoVaeEncoderStep` (video_encoder) is used when `video` is provided.\n" + " - `HeliosImageVaeEncoderStep` (image_encoder) is used when `image` is provided.\n" + " - If neither is provided, step will be skipped." + ) + + +# ==================== +# 2. DENOISE +# ==================== + + +# DENOISE (T2V) +# auto_docstring +class HeliosCoreDenoiseStep(SequentialPipelineBlocks): + """ + Denoise block that takes encoded conditions and runs the chunk-based denoising process. + + Components: + transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider (`ClassifierFreeGuidance`) + + Inputs: + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + height (`int`, *optional*, defaults to 384): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 640): + The width in pixels of the generated image. + num_frames (`int`, *optional*, defaults to 132): + Total number of video frames to generate. + num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. 
+ history_sizes (`list`, *optional*, defaults to [16, 2, 1]): + Sizes of long/mid/short history buffers for temporal context. + keep_first_frame (`bool`, *optional*, defaults to True): + Whether to keep the first frame as a prefix in history. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + timesteps (`Tensor`, *optional*): + Timesteps for the denoising process. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + + Outputs: + latent_chunks (`list`): + List of per-chunk denoised latent tensors + """ + + model_name = "helios" + block_classes = [ + HeliosTextInputStep, + HeliosPrepareHistoryStep, + HeliosSetTimestepsStep, + HeliosChunkDenoiseStep, + ] + block_names = ["input", "prepare_history", "set_timesteps", "chunk_denoise"] + + @property + def description(self): + return "Denoise block that takes encoded conditions and runs the chunk-based denoising process." + + @property + def outputs(self): + return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")] + + +# DENOISE (I2V) +# auto_docstring +class HeliosI2VCoreDenoiseStep(SequentialPipelineBlocks): + """ + I2V denoise block that seeds history with image latents and uses I2V-aware chunk preparation. + + Components: + transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider (`ClassifierFreeGuidance`) + + Inputs: + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. 
+ prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + fake_image_latents (`Tensor`, *optional*): + Fake image latents used as history seed for I2V generation. + image_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for image latent noise. + image_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for image latent noise. + video_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for video/fake-image latent noise. + video_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for video/fake-image latent noise. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_frames (`int`, *optional*, defaults to 132): + Total number of video frames to generate. + num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. + history_sizes (`list`, *optional*, defaults to [16, 2, 1]): + Sizes of long/mid/short history buffers for temporal context. + keep_first_frame (`bool`, *optional*, defaults to True): + Whether to keep the first frame as a prefix in history. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + timesteps (`Tensor`, *optional*): + Timesteps for the denoising process. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. 
+ attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + + Outputs: + latent_chunks (`list`): + List of per-chunk denoised latent tensors + """ + + model_name = "helios" + block_classes = [ + HeliosTextInputStep, + HeliosAdditionalInputsStep( + image_latent_inputs=[InputParam.template("image_latents")], + additional_batch_inputs=[ + InputParam( + "fake_image_latents", + type_hint=torch.Tensor, + description="Fake image latents used as history seed for I2V generation.", + ), + ], + ), + HeliosAddNoiseToImageLatentsStep, + HeliosPrepareHistoryStep, + HeliosI2VSeedHistoryStep, + HeliosSetTimestepsStep, + HeliosI2VChunkDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "add_noise_image", + "prepare_history", + "seed_history", + "set_timesteps", + "chunk_denoise", + ] + + @property + def description(self): + return "I2V denoise block that seeds history with image latents and uses I2V-aware chunk preparation." + + @property + def outputs(self): + return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")] + + +# DENOISE (V2V) +# auto_docstring +class HeliosV2VCoreDenoiseStep(SequentialPipelineBlocks): + """ + V2V denoise block that seeds history with video latents and uses I2V-aware chunk preparation. + + Components: + transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider (`ClassifierFreeGuidance`) + + Inputs: + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. 
+ video_latents (`Tensor`, *optional*): + Encoded video latents for V2V generation. + num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. + image_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for image latent noise. + image_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for image latent noise. + video_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for video latent noise. + video_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for video latent noise. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_frames (`int`, *optional*, defaults to 132): + Total number of video frames to generate. + history_sizes (`list`, *optional*, defaults to [16, 2, 1]): + Sizes of long/mid/short history buffers for temporal context. + keep_first_frame (`bool`, *optional*, defaults to True): + Whether to keep the first frame as a prefix in history. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + timesteps (`Tensor`, *optional*): + Timesteps for the denoising process. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + + Outputs: + latent_chunks (`list`): + List of per-chunk denoised latent tensors + """ + + model_name = "helios" + block_classes = [ + HeliosTextInputStep, + HeliosAdditionalInputsStep( + image_latent_inputs=[InputParam.template("image_latents")], + additional_batch_inputs=[ + InputParam( + "video_latents", type_hint=torch.Tensor, description="Encoded video latents for V2V generation." 
+ ), + ], + ), + HeliosAddNoiseToVideoLatentsStep, + HeliosPrepareHistoryStep, + HeliosV2VSeedHistoryStep, + HeliosSetTimestepsStep, + HeliosI2VChunkDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "add_noise_video", + "prepare_history", + "seed_history", + "set_timesteps", + "chunk_denoise", + ] + + @property + def description(self): + return "V2V denoise block that seeds history with video latents and uses I2V-aware chunk preparation." + + @property + def outputs(self): + return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")] + + +# AUTO DENOISE +# auto_docstring +class HeliosAutoCoreDenoiseStep(ConditionalPipelineBlocks): + """ + Core denoise step that selects the appropriate denoising block. + - `HeliosV2VCoreDenoiseStep` (video2video) for video-to-video tasks. + - `HeliosI2VCoreDenoiseStep` (image2video) for image-to-video tasks. + - `HeliosCoreDenoiseStep` (text2video) for text-to-video tasks. + + Components: + transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider (`ClassifierFreeGuidance`) + + Inputs: + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + video_latents (`Tensor`, *optional*): + Encoded video latents for V2V generation. + num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. + image_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for image latent noise. 
+ image_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for image latent noise. + video_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for video latent noise. + video_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for video latent noise. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_frames (`int`, *optional*, defaults to 132): + Total number of video frames to generate. + history_sizes (`list`): + Sizes of long/mid/short history buffers for temporal context. + keep_first_frame (`bool`, *optional*, defaults to True): + Whether to keep the first frame as a prefix in history. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`): + Custom sigmas for the denoising process. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + timesteps (`Tensor`, *optional*): + Timesteps for the denoising process. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + fake_image_latents (`Tensor`, *optional*): + Fake image latents used as history seed for I2V generation. + height (`int`, *optional*, defaults to 384): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 640): + The width in pixels of the generated image. 
+ + Outputs: + latent_chunks (`list`): + List of per-chunk denoised latent tensors + """ + + block_classes = [HeliosV2VCoreDenoiseStep, HeliosI2VCoreDenoiseStep, HeliosCoreDenoiseStep] + block_names = ["video2video", "image2video", "text2video"] + block_trigger_inputs = ["video_latents", "fake_image_latents"] + default_block_name = "text2video" + + def select_block(self, video_latents=None, fake_image_latents=None): + if video_latents is not None: + return "video2video" + elif fake_image_latents is not None: + return "image2video" + return None + + @property + def description(self): + return ( + "Core denoise step that selects the appropriate denoising block.\n" + " - `HeliosV2VCoreDenoiseStep` (video2video) for video-to-video tasks.\n" + " - `HeliosI2VCoreDenoiseStep` (image2video) for image-to-video tasks.\n" + " - `HeliosCoreDenoiseStep` (text2video) for text-to-video tasks." + ) + + +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", HeliosTextEncoderStep()), + ("vae_encoder", HeliosAutoVaeEncoderStep()), + ("denoise", HeliosAutoCoreDenoiseStep()), + ("decode", HeliosDecodeStep()), + ] +) + +# ==================== +# 3. Auto Blocks +# ==================== + + +# auto_docstring +class HeliosAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for text-to-video, image-to-video, and video-to-video tasks using Helios. + + Supported workflows: + - `text2video`: requires `prompt` + - `image2video`: requires `prompt`, `image` + - `video2video`: requires `prompt`, `video` + + Components: + text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) vae + (`AutoencoderKLWan`) video_processor (`VideoProcessor`) transformer (`HeliosTransformer3DModel`) scheduler + (`HeliosScheduler`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. 
+ max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length for prompt encoding. + video (`None`, *optional*): + Input video for video-to-video generation + height (`int`, *optional*, defaults to 384): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 640): + The width in pixels of the generated image. + num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + image (`Image | list`, *optional*): + Reference image(s) for denoising. Can be a single image or list of images. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + video_latents (`Tensor`, *optional*): + Encoded video latents for V2V generation. + image_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for image latent noise. + image_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for image latent noise. + video_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for video latent noise. + video_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for video latent noise. + num_frames (`int`, *optional*, defaults to 132): + Total number of video frames to generate. + history_sizes (`list`): + Sizes of long/mid/short history buffers for temporal context. + keep_first_frame (`bool`, *optional*, defaults to True): + Whether to keep the first frame as a prefix in history. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`): + Custom sigmas for the denoising process. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. 
+ timesteps (`Tensor`, *optional*): + Timesteps for the denoising process. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + fake_image_latents (`Tensor`, *optional*): + Fake image latents used as history seed for I2V generation. + output_type (`str`, *optional*, defaults to np): + Output format: 'pil', 'np', 'pt'. + + Outputs: + videos (`list`): + The generated videos. + """ + + model_name = "helios" + + block_classes = AUTO_BLOCKS.values() + block_names = AUTO_BLOCKS.keys() + + _workflow_map = { + "text2video": {"prompt": True}, + "image2video": {"prompt": True, "image": True}, + "video2video": {"prompt": True, "video": True}, + } + + @property + def description(self): + return "Auto Modular pipeline for text-to-video, image-to-video, and video-to-video tasks using Helios." + + @property + def outputs(self): + return [OutputParam.template("videos")] diff --git a/src/diffusers/modular_pipelines/helios/modular_blocks_helios_pyramid.py b/src/diffusers/modular_pipelines/helios/modular_blocks_helios_pyramid.py new file mode 100644 index 000000000000..14f6bf80c221 --- /dev/null +++ b/src/diffusers/modular_pipelines/helios/modular_blocks_helios_pyramid.py @@ -0,0 +1,520 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam +from .before_denoise import ( + HeliosAdditionalInputsStep, + HeliosAddNoiseToImageLatentsStep, + HeliosAddNoiseToVideoLatentsStep, + HeliosI2VSeedHistoryStep, + HeliosPrepareHistoryStep, + HeliosTextInputStep, + HeliosV2VSeedHistoryStep, +) +from .decoders import HeliosDecodeStep +from .denoise import HeliosPyramidChunkDenoiseStep, HeliosPyramidI2VChunkDenoiseStep +from .encoders import HeliosImageVaeEncoderStep, HeliosTextEncoderStep, HeliosVideoVaeEncoderStep + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# ==================== +# 1. Vae Encoder +# ==================== + + +# auto_docstring +class HeliosPyramidAutoVaeEncoderStep(AutoPipelineBlocks): + """ + Encoder step that encodes video or image inputs. This is an auto pipeline block. + - `HeliosVideoVaeEncoderStep` (video_encoder) is used when `video` is provided. + - `HeliosImageVaeEncoderStep` (image_encoder) is used when `image` is provided. + - If neither is provided, step will be skipped. + + Components: + vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`) + + Inputs: + video (`None`, *optional*): + Input video for video-to-video generation + height (`int`, *optional*, defaults to 384): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 640): + The width in pixels of the generated image. + num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + image (`Image | list`, *optional*): + Reference image(s) for denoising. Can be a single image or list of images. + + Outputs: + image_latents (`Tensor`): + The latent representation of the input image. 
+ video_latents (`Tensor`): + Encoded video latents (chunked) + fake_image_latents (`Tensor`): + Fake image latents for history seeding + """ + + block_classes = [HeliosVideoVaeEncoderStep, HeliosImageVaeEncoderStep] + block_names = ["video_encoder", "image_encoder"] + block_trigger_inputs = ["video", "image"] + + @property + def description(self): + return ( + "Encoder step that encodes video or image inputs. This is an auto pipeline block.\n" + " - `HeliosVideoVaeEncoderStep` (video_encoder) is used when `video` is provided.\n" + " - `HeliosImageVaeEncoderStep` (image_encoder) is used when `image` is provided.\n" + " - If neither is provided, step will be skipped." + ) + + +# ==================== +# 2. DENOISE +# ==================== + + +# DENOISE (T2V) +# auto_docstring +class HeliosPyramidCoreDenoiseStep(SequentialPipelineBlocks): + """ + T2V pyramid denoise block with progressive multi-resolution denoising. + + Components: + transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider + (`ClassifierFreeZeroStarGuidance`) + + Inputs: + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + height (`int`, *optional*, defaults to 384): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 640): + The width in pixels of the generated image. + num_frames (`int`, *optional*, defaults to 132): + Total number of video frames to generate. + num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. + history_sizes (`list`, *optional*, defaults to [16, 2, 1]): + Sizes of long/mid/short history buffers for temporal context. 
+ keep_first_frame (`bool`, *optional*, defaults to True): + Whether to keep the first frame as a prefix in history. + pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]): + Number of denoising steps per pyramid stage. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + + Outputs: + latent_chunks (`list`): + List of per-chunk denoised latent tensors + """ + + model_name = "helios-pyramid" + block_classes = [ + HeliosTextInputStep, + HeliosPrepareHistoryStep, + HeliosPyramidChunkDenoiseStep, + ] + block_names = ["input", "prepare_history", "pyramid_chunk_denoise"] + + @property + def description(self): + return "T2V pyramid denoise block with progressive multi-resolution denoising." + + @property + def outputs(self): + return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")] + + +# DENOISE (I2V) +# auto_docstring +class HeliosPyramidI2VCoreDenoiseStep(SequentialPipelineBlocks): + """ + I2V pyramid denoise block with progressive multi-resolution denoising. + + Components: + transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider + (`ClassifierFreeZeroStarGuidance`) + + Inputs: + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. 
+ image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + fake_image_latents (`Tensor`, *optional*): + Fake image latents used as history seed for I2V generation. + image_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for image latent noise. + image_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for image latent noise. + video_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for video/fake-image latent noise. + video_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for video/fake-image latent noise. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_frames (`int`, *optional*, defaults to 132): + Total number of video frames to generate. + num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. + history_sizes (`list`, *optional*, defaults to [16, 2, 1]): + Sizes of long/mid/short history buffers for temporal context. + keep_first_frame (`bool`, *optional*, defaults to True): + Whether to keep the first frame as a prefix in history. + pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]): + Number of denoising steps per pyramid stage. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. 
+ + Outputs: + latent_chunks (`list`): + List of per-chunk denoised latent tensors + """ + + model_name = "helios-pyramid" + block_classes = [ + HeliosTextInputStep, + HeliosAdditionalInputsStep( + image_latent_inputs=[InputParam.template("image_latents")], + additional_batch_inputs=[ + InputParam( + "fake_image_latents", + type_hint=torch.Tensor, + description="Fake image latents used as history seed for I2V generation.", + ), + ], + ), + HeliosAddNoiseToImageLatentsStep, + HeliosPrepareHistoryStep, + HeliosI2VSeedHistoryStep, + HeliosPyramidI2VChunkDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "add_noise_image", + "prepare_history", + "seed_history", + "pyramid_chunk_denoise", + ] + + @property + def description(self): + return "I2V pyramid denoise block with progressive multi-resolution denoising." + + @property + def outputs(self): + return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")] + + +# DENOISE (V2V) +# auto_docstring +class HeliosPyramidV2VCoreDenoiseStep(SequentialPipelineBlocks): + """ + V2V pyramid denoise block with progressive multi-resolution denoising. + + Components: + transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider + (`ClassifierFreeZeroStarGuidance`) + + Inputs: + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + video_latents (`Tensor`, *optional*): + Encoded video latents for V2V generation. 
+ num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. + image_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for image latent noise. + image_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for image latent noise. + video_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for video latent noise. + video_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for video latent noise. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_frames (`int`, *optional*, defaults to 132): + Total number of video frames to generate. + history_sizes (`list`, *optional*, defaults to [16, 2, 1]): + Sizes of long/mid/short history buffers for temporal context. + keep_first_frame (`bool`, *optional*, defaults to True): + Whether to keep the first frame as a prefix in history. + pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]): + Number of denoising steps per pyramid stage. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + + Outputs: + latent_chunks (`list`): + List of per-chunk denoised latent tensors + """ + + model_name = "helios-pyramid" + block_classes = [ + HeliosTextInputStep, + HeliosAdditionalInputsStep( + image_latent_inputs=[InputParam.template("image_latents")], + additional_batch_inputs=[ + InputParam( + "video_latents", type_hint=torch.Tensor, description="Encoded video latents for V2V generation." 
+ ), + ], + ), + HeliosAddNoiseToVideoLatentsStep, + HeliosPrepareHistoryStep, + HeliosV2VSeedHistoryStep, + HeliosPyramidI2VChunkDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "add_noise_video", + "prepare_history", + "seed_history", + "pyramid_chunk_denoise", + ] + + @property + def description(self): + return "V2V pyramid denoise block with progressive multi-resolution denoising." + + @property + def outputs(self): + return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")] + + +# AUTO DENOISE +# auto_docstring +class HeliosPyramidAutoCoreDenoiseStep(ConditionalPipelineBlocks): + """ + Pyramid core denoise step that selects the appropriate denoising block. + - `HeliosPyramidV2VCoreDenoiseStep` (video2video) for video-to-video tasks. + - `HeliosPyramidI2VCoreDenoiseStep` (image2video) for image-to-video tasks. + - `HeliosPyramidCoreDenoiseStep` (text2video) for text-to-video tasks. + + Components: + transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider + (`ClassifierFreeZeroStarGuidance`) + + Inputs: + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + video_latents (`Tensor`, *optional*): + Encoded video latents for V2V generation. + num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. + image_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for image latent noise. 
+ image_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for image latent noise. + video_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for video latent noise. + video_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for video latent noise. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_frames (`int`, *optional*, defaults to 132): + Total number of video frames to generate. + history_sizes (`list`): + Sizes of long/mid/short history buffers for temporal context. + keep_first_frame (`bool`, *optional*, defaults to True): + Whether to keep the first frame as a prefix in history. + pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]): + Number of denoising steps per pyramid stage. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + fake_image_latents (`Tensor`, *optional*): + Fake image latents used as history seed for I2V generation. + height (`int`, *optional*, defaults to 384): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 640): + The width in pixels of the generated image. 
+ + Outputs: + latent_chunks (`list`): + List of per-chunk denoised latent tensors + """ + + block_classes = [HeliosPyramidV2VCoreDenoiseStep, HeliosPyramidI2VCoreDenoiseStep, HeliosPyramidCoreDenoiseStep] + block_names = ["video2video", "image2video", "text2video"] + block_trigger_inputs = ["video_latents", "fake_image_latents"] + default_block_name = "text2video" + + def select_block(self, video_latents=None, fake_image_latents=None): + if video_latents is not None: + return "video2video" + elif fake_image_latents is not None: + return "image2video" + return None + + @property + def description(self): + return ( + "Pyramid core denoise step that selects the appropriate denoising block.\n" + " - `HeliosPyramidV2VCoreDenoiseStep` (video2video) for video-to-video tasks.\n" + " - `HeliosPyramidI2VCoreDenoiseStep` (image2video) for image-to-video tasks.\n" + " - `HeliosPyramidCoreDenoiseStep` (text2video) for text-to-video tasks." + ) + + +# ==================== +# 3. Auto Blocks +# ==================== + +PYRAMID_AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", HeliosTextEncoderStep()), + ("vae_encoder", HeliosPyramidAutoVaeEncoderStep()), + ("denoise", HeliosPyramidAutoCoreDenoiseStep()), + ("decode", HeliosDecodeStep()), + ] +) + + +# auto_docstring +class HeliosPyramidAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for pyramid progressive generation (T2V/I2V/V2V) using Helios. + + Supported workflows: + - `text2video`: requires `prompt` + - `image2video`: requires `prompt`, `image` + - `video2video`: requires `prompt`, `video` + + Components: + text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) vae + (`AutoencoderKLWan`) video_processor (`VideoProcessor`) transformer (`HeliosTransformer3DModel`) scheduler + (`HeliosScheduler`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. 
+ negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length for prompt encoding. + video (`None`, *optional*): + Input video for video-to-video generation + height (`int`, *optional*, defaults to 384): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 640): + The width in pixels of the generated image. + num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + image (`Image | list`, *optional*): + Reference image(s) for denoising. Can be a single image or list of images. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + video_latents (`Tensor`, *optional*): + Encoded video latents for V2V generation. + image_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for image latent noise. + image_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for image latent noise. + video_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for video latent noise. + video_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for video latent noise. + num_frames (`int`, *optional*, defaults to 132): + Total number of video frames to generate. + history_sizes (`list`): + Sizes of long/mid/short history buffers for temporal context. + keep_first_frame (`bool`, *optional*, defaults to True): + Whether to keep the first frame as a prefix in history. + pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]): + Number of denoising steps per pyramid stage. 
+ latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + fake_image_latents (`Tensor`, *optional*): + Fake image latents used as history seed for I2V generation. + output_type (`str`, *optional*, defaults to np): + Output format: 'pil', 'np', 'pt'. + + Outputs: + videos (`list`): + The generated videos. + """ + + model_name = "helios-pyramid" + + block_classes = PYRAMID_AUTO_BLOCKS.values() + block_names = PYRAMID_AUTO_BLOCKS.keys() + + _workflow_map = { + "text2video": {"prompt": True}, + "image2video": {"prompt": True, "image": True}, + "video2video": {"prompt": True, "video": True}, + } + + @property + def description(self): + return "Auto Modular pipeline for pyramid progressive generation (T2V/I2V/V2V) using Helios." + + @property + def outputs(self): + return [OutputParam.template("videos")] diff --git a/src/diffusers/modular_pipelines/helios/modular_blocks_helios_pyramid_distilled.py b/src/diffusers/modular_pipelines/helios/modular_blocks_helios_pyramid_distilled.py new file mode 100644 index 000000000000..e9e37df5d00c --- /dev/null +++ b/src/diffusers/modular_pipelines/helios/modular_blocks_helios_pyramid_distilled.py @@ -0,0 +1,530 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam +from .before_denoise import ( + HeliosAdditionalInputsStep, + HeliosAddNoiseToImageLatentsStep, + HeliosAddNoiseToVideoLatentsStep, + HeliosI2VSeedHistoryStep, + HeliosPrepareHistoryStep, + HeliosTextInputStep, + HeliosV2VSeedHistoryStep, +) +from .decoders import HeliosDecodeStep +from .denoise import HeliosPyramidDistilledChunkDenoiseStep, HeliosPyramidDistilledI2VChunkDenoiseStep +from .encoders import HeliosImageVaeEncoderStep, HeliosTextEncoderStep, HeliosVideoVaeEncoderStep + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# ==================== +# 1. Vae Encoder +# ==================== + + +# auto_docstring +class HeliosPyramidDistilledAutoVaeEncoderStep(AutoPipelineBlocks): + """ + Encoder step for distilled pyramid pipeline. + - `HeliosVideoVaeEncoderStep` (video_encoder) is used when `video` is provided. + - `HeliosImageVaeEncoderStep` (image_encoder) is used when `image` is provided. + - If neither is provided, step will be skipped. + + Components: + vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`) + + Inputs: + video (`None`, *optional*): + Input video for video-to-video generation + height (`int`, *optional*, defaults to 384): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 640): + The width in pixels of the generated image. + num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + image (`Image | list`, *optional*): + Reference image(s) for denoising. Can be a single image or list of images. 
+ + Outputs: + image_latents (`Tensor`): + The latent representation of the input image. + video_latents (`Tensor`): + Encoded video latents (chunked) + fake_image_latents (`Tensor`): + Fake image latents for history seeding + """ + + block_classes = [HeliosVideoVaeEncoderStep, HeliosImageVaeEncoderStep] + block_names = ["video_encoder", "image_encoder"] + block_trigger_inputs = ["video", "image"] + + @property + def description(self): + return ( + "Encoder step for distilled pyramid pipeline.\n" + " - `HeliosVideoVaeEncoderStep` (video_encoder) is used when `video` is provided.\n" + " - `HeliosImageVaeEncoderStep` (image_encoder) is used when `image` is provided.\n" + " - If neither is provided, step will be skipped." + ) + + +# ==================== +# 2. DENOISE +# ==================== + + +# DENOISE (T2V) +# auto_docstring +class HeliosPyramidDistilledCoreDenoiseStep(SequentialPipelineBlocks): + """ + T2V distilled pyramid denoise block with DMD scheduler and no CFG. + + Components: + transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider (`ClassifierFreeGuidance`) + + Inputs: + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + height (`int`, *optional*, defaults to 384): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 640): + The width in pixels of the generated image. + num_frames (`int`, *optional*, defaults to 132): + Total number of video frames to generate. + num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. 
+ history_sizes (`list`, *optional*, defaults to [16, 2, 1]): + Sizes of long/mid/short history buffers for temporal context. + keep_first_frame (`bool`, *optional*, defaults to True): + Whether to keep the first frame as a prefix in history. + pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]): + Number of denoising steps per pyramid stage. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + is_amplify_first_chunk (`bool`, *optional*, defaults to True): + Whether to double the first chunk's timesteps via the scheduler for amplified generation. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + + Outputs: + latent_chunks (`list`): + List of per-chunk denoised latent tensors + """ + + model_name = "helios-pyramid" + block_classes = [ + HeliosTextInputStep, + HeliosPrepareHistoryStep, + HeliosPyramidDistilledChunkDenoiseStep, + ] + block_names = ["input", "prepare_history", "pyramid_chunk_denoise"] + + @property + def description(self): + return "T2V distilled pyramid denoise block with DMD scheduler and no CFG." + + @property + def outputs(self): + return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")] + + +# DENOISE (I2V) +# auto_docstring +class HeliosPyramidDistilledI2VCoreDenoiseStep(SequentialPipelineBlocks): + """ + I2V distilled pyramid denoise block with DMD scheduler and no CFG. + + Components: + transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider (`ClassifierFreeGuidance`) + + Inputs: + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. 
+ prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + fake_image_latents (`Tensor`, *optional*): + Fake image latents used as history seed for I2V generation. + image_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for image latent noise. + image_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for image latent noise. + video_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for video/fake-image latent noise. + video_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for video/fake-image latent noise. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_frames (`int`, *optional*, defaults to 132): + Total number of video frames to generate. + num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. + history_sizes (`list`, *optional*, defaults to [16, 2, 1]): + Sizes of long/mid/short history buffers for temporal context. + keep_first_frame (`bool`, *optional*, defaults to True): + Whether to keep the first frame as a prefix in history. + pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]): + Number of denoising steps per pyramid stage. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. 
+ is_amplify_first_chunk (`bool`, *optional*, defaults to True): + Whether to double the first chunk's timesteps via the scheduler for amplified generation. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + + Outputs: + latent_chunks (`list`): + List of per-chunk denoised latent tensors + """ + + model_name = "helios-pyramid" + block_classes = [ + HeliosTextInputStep, + HeliosAdditionalInputsStep( + image_latent_inputs=[InputParam.template("image_latents")], + additional_batch_inputs=[ + InputParam( + "fake_image_latents", + type_hint=torch.Tensor, + description="Fake image latents used as history seed for I2V generation.", + ), + ], + ), + HeliosAddNoiseToImageLatentsStep, + HeliosPrepareHistoryStep, + HeliosI2VSeedHistoryStep, + HeliosPyramidDistilledI2VChunkDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "add_noise_image", + "prepare_history", + "seed_history", + "pyramid_chunk_denoise", + ] + + @property + def description(self): + return "I2V distilled pyramid denoise block with DMD scheduler and no CFG." + + @property + def outputs(self): + return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")] + + +# DENOISE (V2V) +# auto_docstring +class HeliosPyramidDistilledV2VCoreDenoiseStep(SequentialPipelineBlocks): + """ + V2V distilled pyramid denoise block with DMD scheduler and no CFG. + + Components: + transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider (`ClassifierFreeGuidance`) + + Inputs: + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. 
+ image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + video_latents (`Tensor`, *optional*): + Encoded video latents for V2V generation. + num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. + image_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for image latent noise. + image_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for image latent noise. + video_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for video latent noise. + video_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for video latent noise. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_frames (`int`, *optional*, defaults to 132): + Total number of video frames to generate. + history_sizes (`list`, *optional*, defaults to [16, 2, 1]): + Sizes of long/mid/short history buffers for temporal context. + keep_first_frame (`bool`, *optional*, defaults to True): + Whether to keep the first frame as a prefix in history. + pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]): + Number of denoising steps per pyramid stage. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + is_amplify_first_chunk (`bool`, *optional*, defaults to True): + Whether to double the first chunk's timesteps via the scheduler for amplified generation. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. 
+ + Outputs: + latent_chunks (`list`): + List of per-chunk denoised latent tensors + """ + + model_name = "helios-pyramid" + block_classes = [ + HeliosTextInputStep, + HeliosAdditionalInputsStep( + image_latent_inputs=[InputParam.template("image_latents")], + additional_batch_inputs=[ + InputParam( + "video_latents", type_hint=torch.Tensor, description="Encoded video latents for V2V generation." + ), + ], + ), + HeliosAddNoiseToVideoLatentsStep, + HeliosPrepareHistoryStep, + HeliosV2VSeedHistoryStep, + HeliosPyramidDistilledI2VChunkDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "add_noise_video", + "prepare_history", + "seed_history", + "pyramid_chunk_denoise", + ] + + @property + def description(self): + return "V2V distilled pyramid denoise block with DMD scheduler and no CFG." + + @property + def outputs(self): + return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")] + + +# AUTO DENOISE +# auto_docstring +class HeliosPyramidDistilledAutoCoreDenoiseStep(ConditionalPipelineBlocks): + """ + Distilled pyramid core denoise step that selects the appropriate denoising block. + - `HeliosPyramidDistilledV2VCoreDenoiseStep` (video2video) for video-to-video tasks. + - `HeliosPyramidDistilledI2VCoreDenoiseStep` (image2video) for image-to-video tasks. + - `HeliosPyramidDistilledCoreDenoiseStep` (text2video) for text-to-video tasks. + + Components: + transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider (`ClassifierFreeGuidance`) + + Inputs: + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. 
+ image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + video_latents (`Tensor`, *optional*): + Encoded video latents for V2V generation. + num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. + image_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for image latent noise. + image_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for image latent noise. + video_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for video latent noise. + video_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for video latent noise. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_frames (`int`, *optional*, defaults to 132): + Total number of video frames to generate. + history_sizes (`list`): + Sizes of long/mid/short history buffers for temporal context. + keep_first_frame (`bool`, *optional*, defaults to True): + Whether to keep the first frame as a prefix in history. + pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]): + Number of denoising steps per pyramid stage. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + is_amplify_first_chunk (`bool`, *optional*, defaults to True): + Whether to double the first chunk's timesteps via the scheduler for amplified generation. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + fake_image_latents (`Tensor`, *optional*): + Fake image latents used as history seed for I2V generation. + height (`int`, *optional*, defaults to 384): + The height in pixels of the generated image. 
+ width (`int`, *optional*, defaults to 640): + The width in pixels of the generated image. + + Outputs: + latent_chunks (`list`): + List of per-chunk denoised latent tensors + """ + + block_classes = [ + HeliosPyramidDistilledV2VCoreDenoiseStep, + HeliosPyramidDistilledI2VCoreDenoiseStep, + HeliosPyramidDistilledCoreDenoiseStep, + ] + block_names = ["video2video", "image2video", "text2video"] + block_trigger_inputs = ["video_latents", "fake_image_latents"] + default_block_name = "text2video" + + def select_block(self, video_latents=None, fake_image_latents=None): + if video_latents is not None: + return "video2video" + elif fake_image_latents is not None: + return "image2video" + return None + + @property + def description(self): + return ( + "Distilled pyramid core denoise step that selects the appropriate denoising block.\n" + " - `HeliosPyramidDistilledV2VCoreDenoiseStep` (video2video) for video-to-video tasks.\n" + " - `HeliosPyramidDistilledI2VCoreDenoiseStep` (image2video) for image-to-video tasks.\n" + " - `HeliosPyramidDistilledCoreDenoiseStep` (text2video) for text-to-video tasks." + ) + + +# ==================== +# 3. Auto Blocks +# ==================== + +DISTILLED_PYRAMID_AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", HeliosTextEncoderStep()), + ("vae_encoder", HeliosPyramidDistilledAutoVaeEncoderStep()), + ("denoise", HeliosPyramidDistilledAutoCoreDenoiseStep()), + ("decode", HeliosDecodeStep()), + ] +) + + +# auto_docstring +class HeliosPyramidDistilledAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for distilled pyramid progressive generation (T2V/I2V/V2V) using Helios. 
+ + Supported workflows: + - `text2video`: requires `prompt` + - `image2video`: requires `prompt`, `image` + - `video2video`: requires `prompt`, `video` + + Components: + text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) vae + (`AutoencoderKLWan`) video_processor (`VideoProcessor`) transformer (`HeliosTransformer3DModel`) scheduler + (`HeliosScheduler`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length for prompt encoding. + video (`None`, *optional*): + Input video for video-to-video generation + height (`int`, *optional*, defaults to 384): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 640): + The width in pixels of the generated image. + num_latent_frames_per_chunk (`int`, *optional*, defaults to 9): + Number of latent frames per temporal chunk. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + image (`Image | list`, *optional*): + Reference image(s) for denoising. Can be a single image or list of images. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + video_latents (`Tensor`, *optional*): + Encoded video latents for V2V generation. + image_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for image latent noise. + image_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for image latent noise. + video_noise_sigma_min (`float`, *optional*, defaults to 0.111): + Minimum sigma for video latent noise. 
+ video_noise_sigma_max (`float`, *optional*, defaults to 0.135): + Maximum sigma for video latent noise. + num_frames (`int`, *optional*, defaults to 132): + Total number of video frames to generate. + history_sizes (`list`): + Sizes of long/mid/short history buffers for temporal context. + keep_first_frame (`bool`, *optional*, defaults to True): + Whether to keep the first frame as a prefix in history. + pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]): + Number of denoising steps per pyramid stage. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + is_amplify_first_chunk (`bool`, *optional*, defaults to True): + Whether to double the first chunk's timesteps via the scheduler for amplified generation. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + fake_image_latents (`Tensor`, *optional*): + Fake image latents used as history seed for I2V generation. + output_type (`str`, *optional*, defaults to np): + Output format: 'pil', 'np', 'pt'. + + Outputs: + videos (`list`): + The generated videos. + """ + + model_name = "helios-pyramid" + + block_classes = DISTILLED_PYRAMID_AUTO_BLOCKS.values() + block_names = DISTILLED_PYRAMID_AUTO_BLOCKS.keys() + + _workflow_map = { + "text2video": {"prompt": True}, + "image2video": {"prompt": True, "image": True}, + "video2video": {"prompt": True, "video": True}, + } + + @property + def description(self): + return "Auto Modular pipeline for distilled pyramid progressive generation (T2V/I2V/V2V) using Helios." 
+ + @property + def outputs(self): + return [OutputParam.template("videos")] diff --git a/src/diffusers/modular_pipelines/helios/modular_pipeline.py b/src/diffusers/modular_pipelines/helios/modular_pipeline.py new file mode 100644 index 000000000000..fd3875381c56 --- /dev/null +++ b/src/diffusers/modular_pipelines/helios/modular_pipeline.py @@ -0,0 +1,87 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...loaders import HeliosLoraLoaderMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class HeliosModularPipeline( + ModularPipeline, + HeliosLoraLoaderMixin, +): + """ + A ModularPipeline for Helios text-to-video generation. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. 
+ """ + + default_blocks_name = "HeliosAutoBlocks" + + @property + def vae_scale_factor_spatial(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = self.vae.config.scale_factor_spatial + return vae_scale_factor + + @property + def vae_scale_factor_temporal(self): + vae_scale_factor = 4 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = self.vae.config.scale_factor_temporal + return vae_scale_factor + + @property + def num_channels_latents(self): + # YiYi TODO: find out default value + num_channels_latents = 16 + if hasattr(self, "transformer") and self.transformer is not None: + num_channels_latents = self.transformer.config.in_channels + return num_channels_latents + + @property + def requires_unconditional_embeds(self): + requires_unconditional_embeds = False + + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds + + +class HeliosPyramidModularPipeline(HeliosModularPipeline): + """ + A ModularPipeline for Helios pyramid (progressive resolution) video generation. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "HeliosPyramidAutoBlocks" + + +class HeliosPyramidDistilledModularPipeline(HeliosModularPipeline): + """ + A ModularPipeline for Helios distilled pyramid video generation using DMD scheduler. + + Uses guidance_scale=1.0 (no CFG) and supports is_amplify_first_chunk for the DMD scheduler. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. 
+ """ + + default_blocks_name = "HeliosPyramidDistilledAutoBlocks" diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index a563d2aa99eb..9cd2f9f5c6ae 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -106,6 +106,16 @@ def _wan_i2v_map_fn(config_dict=None): return "WanImage2VideoModularPipeline" +def _helios_pyramid_map_fn(config_dict=None): + if config_dict is None: + return "HeliosPyramidModularPipeline" + + if config_dict.get("is_distilled", False): + return "HeliosPyramidDistilledModularPipeline" + else: + return "HeliosPyramidModularPipeline" + + MODULAR_PIPELINE_MAPPING = OrderedDict( [ ("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")), @@ -120,6 +130,8 @@ def _wan_i2v_map_fn(config_dict=None): ("qwenimage-edit-plus", _create_default_map_fn("QwenImageEditPlusModularPipeline")), ("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")), ("z-image", _create_default_map_fn("ZImageModularPipeline")), + ("helios", _create_default_map_fn("HeliosModularPipeline")), + ("helios-pyramid", _helios_pyramid_map_fn), ] ) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 157b04ef266a..730a788ed1b8 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -152,6 +152,96 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class HeliosAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, 
**kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class HeliosModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class HeliosPyramidAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class HeliosPyramidDistilledAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class HeliosPyramidDistilledModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class HeliosPyramidModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, 
**kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class QwenImageAutoBlocks(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/modular_pipelines/helios/__init__.py b/tests/modular_pipelines/helios/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/modular_pipelines/helios/test_modular_pipeline_helios.py b/tests/modular_pipelines/helios/test_modular_pipeline_helios.py new file mode 100644 index 000000000000..44a01dad6525 --- /dev/null +++ b/tests/modular_pipelines/helios/test_modular_pipeline_helios.py @@ -0,0 +1,166 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from diffusers.modular_pipelines import ( + HeliosAutoBlocks, + HeliosModularPipeline, + HeliosPyramidAutoBlocks, + HeliosPyramidModularPipeline, +) + +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +HELIOS_WORKFLOWS = { + "text2video": [ + ("text_encoder", "HeliosTextEncoderStep"), + ("denoise.input", "HeliosTextInputStep"), + ("denoise.prepare_history", "HeliosPrepareHistoryStep"), + ("denoise.set_timesteps", "HeliosSetTimestepsStep"), + ("denoise.chunk_denoise", "HeliosChunkDenoiseStep"), + ("decode", "HeliosDecodeStep"), + ], + "image2video": [ + ("text_encoder", "HeliosTextEncoderStep"), + ("vae_encoder", "HeliosImageVaeEncoderStep"), + ("denoise.input", "HeliosTextInputStep"), + ("denoise.additional_inputs", "HeliosAdditionalInputsStep"), + ("denoise.add_noise_image", "HeliosAddNoiseToImageLatentsStep"), + ("denoise.prepare_history", "HeliosPrepareHistoryStep"), + ("denoise.seed_history", "HeliosI2VSeedHistoryStep"), + ("denoise.set_timesteps", "HeliosSetTimestepsStep"), + ("denoise.chunk_denoise", "HeliosI2VChunkDenoiseStep"), + ("decode", "HeliosDecodeStep"), + ], + "video2video": [ + ("text_encoder", "HeliosTextEncoderStep"), + ("vae_encoder", "HeliosVideoVaeEncoderStep"), + ("denoise.input", "HeliosTextInputStep"), + ("denoise.additional_inputs", "HeliosAdditionalInputsStep"), + ("denoise.add_noise_video", "HeliosAddNoiseToVideoLatentsStep"), + ("denoise.prepare_history", "HeliosPrepareHistoryStep"), + ("denoise.seed_history", "HeliosV2VSeedHistoryStep"), + ("denoise.set_timesteps", "HeliosSetTimestepsStep"), + ("denoise.chunk_denoise", "HeliosI2VChunkDenoiseStep"), + ("decode", "HeliosDecodeStep"), + ], +} + + +class TestHeliosModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = HeliosModularPipeline + pipeline_blocks_class = HeliosAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-helios-modular-pipe" + + params = frozenset(["prompt", "height", "width", "num_frames"]) + 
batch_params = frozenset(["prompt"]) + optional_params = frozenset(["num_inference_steps", "num_videos_per_prompt", "latents"]) + output_name = "videos" + expected_workflow_blocks = HELIOS_WORKFLOWS + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + @pytest.mark.skip(reason="num_videos_per_prompt") + def test_num_images_per_prompt(self): + pass + + +HELIOS_PYRAMID_WORKFLOWS = { + "text2video": [ + ("text_encoder", "HeliosTextEncoderStep"), + ("denoise.input", "HeliosTextInputStep"), + ("denoise.prepare_history", "HeliosPrepareHistoryStep"), + ("denoise.pyramid_chunk_denoise", "HeliosPyramidChunkDenoiseStep"), + ("decode", "HeliosDecodeStep"), + ], + "image2video": [ + ("text_encoder", "HeliosTextEncoderStep"), + ("vae_encoder", "HeliosImageVaeEncoderStep"), + ("denoise.input", "HeliosTextInputStep"), + ("denoise.additional_inputs", "HeliosAdditionalInputsStep"), + ("denoise.add_noise_image", "HeliosAddNoiseToImageLatentsStep"), + ("denoise.prepare_history", "HeliosPrepareHistoryStep"), + ("denoise.seed_history", "HeliosI2VSeedHistoryStep"), + ("denoise.pyramid_chunk_denoise", "HeliosPyramidI2VChunkDenoiseStep"), + ("decode", "HeliosDecodeStep"), + ], + "video2video": [ + ("text_encoder", "HeliosTextEncoderStep"), + ("vae_encoder", "HeliosVideoVaeEncoderStep"), + ("denoise.input", "HeliosTextInputStep"), + ("denoise.additional_inputs", "HeliosAdditionalInputsStep"), + ("denoise.add_noise_video", "HeliosAddNoiseToVideoLatentsStep"), + ("denoise.prepare_history", "HeliosPrepareHistoryStep"), + ("denoise.seed_history", "HeliosV2VSeedHistoryStep"), + ("denoise.pyramid_chunk_denoise", "HeliosPyramidI2VChunkDenoiseStep"), + ("decode", "HeliosDecodeStep"), + ], +} + + +class 
TestHeliosPyramidModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = HeliosPyramidModularPipeline + pipeline_blocks_class = HeliosPyramidAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-helios-pyramid-modular-pipe" + + params = frozenset(["prompt", "height", "width", "num_frames"]) + batch_params = frozenset(["prompt"]) + optional_params = frozenset(["pyramid_num_inference_steps_list", "num_videos_per_prompt", "latents"]) + output_name = "videos" + expected_workflow_blocks = HELIOS_PYRAMID_WORKFLOWS + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "pyramid_num_inference_steps_list": [2, 2], + "height": 64, + "width": 64, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference_batch_single_identical(self): + # Pyramid pipeline injects noise at each stage, so batch vs single can differ more + super().test_inference_batch_single_identical(expected_max_diff=5e-1) + + @pytest.mark.skip(reason="Pyramid multi-stage noise makes offload comparison unreliable with tiny models") + def test_components_auto_cpu_offload_inference_consistent(self): + pass + + @pytest.mark.skip(reason="Pyramid multi-stage noise makes save/load comparison unreliable with tiny models") + def test_save_from_pretrained(self): + pass + + @pytest.mark.skip(reason="num_videos_per_prompt") + def test_num_images_per_prompt(self): + pass From 625b1d1314acf9e6b96038b3ac959af1b7f3d49b Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 10 Mar 2026 09:07:16 +0530 Subject: [PATCH 037/215] [CI] Potential fix for code scanning alert no. 2150: Workflow does not contain permissions (#13230) Potential fix for code scanning alert no. 
2150: Workflow does not contain permissions Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: Sayak Paul --- .github/workflows/pr_tests_gpu.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml index b79d80f71c09..9c63ad755f3b 100644 --- a/.github/workflows/pr_tests_gpu.yml +++ b/.github/workflows/pr_tests_gpu.yml @@ -1,5 +1,8 @@ name: Fast GPU Tests on PR +permissions: + contents: read + on: pull_request: branches: main From abcef878fd7f607a81473e1b79cfcf27f42ef11b Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 11 Mar 2026 09:26:46 +0530 Subject: [PATCH 038/215] [Quantization] Deprecate Quanto (#13180) * update * update --- src/diffusers/quantizers/quantization_config.py | 4 +++- src/diffusers/quantizers/quanto/quanto_quantizer.py | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 138ec7b7e989..9a467e6b21ee 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -36,7 +36,7 @@ from packaging import version -from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging +from ..utils import deprecate, is_torch_available, is_torchao_available, is_torchao_version, logging if is_torch_available(): @@ -844,6 +844,8 @@ def __init__( modules_to_not_convert: list[str] | None = None, **kwargs, ): + deprecation_message = "`QuantoConfig` is deprecated and will be removed in version 1.0.0." 
+ deprecate("QuantoConfig", "1.0.0", deprecation_message) self.quant_method = QuantizationMethod.QUANTO self.weights_dtype = weights_dtype self.modules_to_not_convert = modules_to_not_convert diff --git a/src/diffusers/quantizers/quanto/quanto_quantizer.py b/src/diffusers/quantizers/quanto/quanto_quantizer.py index a036dabfe6f4..9a04291c883a 100644 --- a/src/diffusers/quantizers/quanto/quanto_quantizer.py +++ b/src/diffusers/quantizers/quanto/quanto_quantizer.py @@ -3,6 +3,7 @@ from diffusers.utils.import_utils import is_optimum_quanto_version from ...utils import ( + deprecate, get_module_from_name, is_accelerate_available, is_accelerate_version, @@ -42,6 +43,9 @@ def __init__(self, quantization_config, **kwargs): super().__init__(quantization_config, **kwargs) def validate_environment(self, *args, **kwargs): + deprecation_message = "The Quanto quantizer is deprecated and will be removed in version 1.0.0." + deprecate("QuantoQuantizer", "1.0.0", deprecation_message) + if not is_optimum_quanto_available(): raise ImportError( "Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)" From fd5e8673894d127047295825ad271c23827ab17b Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 11 Mar 2026 16:42:11 +0530 Subject: [PATCH 039/215] [Context Parallel] Add support for custom device mesh (#13064) * add custom mesh support * update --------- Co-authored-by: Sayak Paul --- src/diffusers/models/_modeling_parallel.py | 13 ++- src/diffusers/models/modeling_utils.py | 2 +- tests/models/testing_utils/parallelism.py | 105 +++++++++++++++++++-- 3 files changed, 112 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index ed966dc8fe98..8573c01ca4c7 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -60,6 +60,16 @@ class ContextParallelConfig: rotate_method (`str`, *optional*, defaults to 
`"allgather"`): Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"` is supported. + ulysses_anything (`bool`, *optional*, defaults to `False`): + Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that + are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and + `ring_degree` must be 1. + mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*): + A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of + creating a new one. This is useful when combining context parallelism with other parallelism strategies + (e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and + "ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with + `mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP). """ @@ -68,6 +78,7 @@ class ContextParallelConfig: convert_to_fp32: bool = True # TODO: support alltoall rotate_method: Literal["allgather", "alltoall"] = "allgather" + mesh: torch.distributed.device_mesh.DeviceMesh | None = None # Whether to enable ulysses anything attention to support # any sequence lengths and any head numbers. ulysses_anything: bool = False @@ -124,7 +135,7 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})." 
) - self._flattened_mesh = self._mesh._flatten() + self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten() self._ring_mesh = self._mesh["ring"] self._ulysses_mesh = self._mesh["ulysses"] self._ring_local_rank = self._ring_mesh.get_local_rank() diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0901840679e3..401074050333 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1567,7 +1567,7 @@ def enable_parallelism( mesh = None if config.context_parallel_config is not None: cp_config = config.context_parallel_config - mesh = torch.distributed.device_mesh.init_device_mesh( + mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh( device_type=device_type, mesh_shape=cp_config.mesh_shape, mesh_dim_names=cp_config.mesh_dim_names, diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index e05b36799e66..3858acf71ec5 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -60,12 +60,7 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di model.eval() # Move inputs to device - inputs_on_device = {} - for key, value in inputs_dict.items(): - if isinstance(value, torch.Tensor): - inputs_on_device[key] = value.to(device) - else: - inputs_on_device[key] = value + inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()} # Enable context parallelism cp_config = ContextParallelConfig(**cp_dict) @@ -89,6 +84,59 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di dist.destroy_process_group() +def _custom_mesh_worker( + rank, + world_size, + master_port, + model_class, + init_dict, + cp_dict, + mesh_shape, + mesh_dim_names, + inputs_dict, + return_dict, +): + """Worker function for context parallel testing with a user-provided custom DeviceMesh.""" + try: + 
os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(master_port) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + + torch.cuda.set_device(rank) + device = torch.device(f"cuda:{rank}") + + model = model_class(**init_dict) + model.to(device) + model.eval() + + inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()} + + # DeviceMesh must be created after init_process_group, inside each worker process. + mesh = torch.distributed.device_mesh.init_device_mesh( + "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names + ) + cp_config = ContextParallelConfig(**cp_dict, mesh=mesh) + model.enable_parallelism(config=cp_config) + + with torch.no_grad(): + output = model(**inputs_on_device, return_dict=False)[0] + + if rank == 0: + return_dict["status"] = "success" + return_dict["output_shape"] = list(output.shape) + + except Exception as e: + if rank == 0: + return_dict["status"] = "error" + return_dict["error"] = str(e) + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + @is_context_parallel @require_torch_multi_accelerator class ContextParallelTesterMixin: @@ -126,3 +174,48 @@ def test_context_parallel_inference(self, cp_type): assert return_dict.get("status") == "success", ( f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}" ) + + @pytest.mark.parametrize( + "cp_type,mesh_shape,mesh_dim_names", + [ + ("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")), + ("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")), + ], + ids=["ring-3d-fsdp", "ulysses-3d-fsdp"], + ) + def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names): + if not torch.distributed.is_available(): + pytest.skip("torch.distributed is not available.") + + if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None: + 
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") + + world_size = 2 + init_dict = self.get_init_dict() + inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()} + cp_dict = {cp_type: world_size} + + master_port = _find_free_port() + manager = mp.Manager() + return_dict = manager.dict() + + mp.spawn( + _custom_mesh_worker, + args=( + world_size, + master_port, + self.model_class, + init_dict, + cp_dict, + mesh_shape, + mesh_dim_names, + inputs_dict, + return_dict, + ), + nprocs=world_size, + join=True, + ) + + assert return_dict.get("status") == "success", ( + f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}" + ) From c05afff10d9d51c115b49461bf2f29d7408e99b3 Mon Sep 17 00:00:00 2001 From: Miguel Martin Date: Wed, 11 Mar 2026 09:14:56 -0700 Subject: [PATCH 040/215] Update Documentation for NVIDIA Cosmos (#13251) * fix docs * update main example --- docs/source/en/_toctree.yml | 4 +-- docs/source/en/api/pipelines/cosmos.md | 32 ++++++++++--------- docs/source/en/api/pipelines/overview.md | 1 + .../cosmos/pipeline_cosmos2_5_transfer.py | 5 ++- 4 files changed, 24 insertions(+), 18 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index e0b7af4898b2..c69bcd340b27 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -532,8 +532,6 @@ title: ControlNet-XS with Stable Diffusion XL - local: api/pipelines/controlnet_union title: ControlNetUnion - - local: api/pipelines/cosmos - title: Cosmos - local: api/pipelines/ddim title: DDIM - local: api/pipelines/ddpm @@ -677,6 +675,8 @@ title: CogVideoX - local: api/pipelines/consisid title: ConsisID + - local: api/pipelines/cosmos + title: Cosmos - local: api/pipelines/framepack title: Framepack - local: api/pipelines/helios diff --git a/docs/source/en/api/pipelines/cosmos.md b/docs/source/en/api/pipelines/cosmos.md index 
2302ed2c4a6c..d4851997b9ce 100644 --- a/docs/source/en/api/pipelines/cosmos.md +++ b/docs/source/en/api/pipelines/cosmos.md @@ -21,29 +21,31 @@ > [!TIP] > Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. -## Loading original format checkpoints - -Original format checkpoints that have not been converted to diffusers-expected format can be loaded using the `from_single_file` method. +## Basic usage ```python import torch -from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel - -model_id = "nvidia/Cosmos-Predict2-2B-Text2Image" -transformer = CosmosTransformer3DModel.from_single_file( - "https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt", - torch_dtype=torch.bfloat16, -).to("cuda") -pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16) +from diffusers import Cosmos2_5_PredictBasePipeline +from diffusers.utils import export_to_video + +model_id = "nvidia/Cosmos-Predict2.5-2B" +pipe = Cosmos2_5_PredictBasePipeline.from_pretrained( + model_id, revision="diffusers/base/post-trained", torch_dtype=torch.bfloat16 +) pipe.to("cuda") -prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. 
As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess." +prompt = "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow advance of traffic through the frosty city corridor." negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality." 
output = pipe( - prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1) -).images[0] -output.save("output.png") + image=None, + video=None, + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=93, + generator=torch.Generator().manual_seed(1), +).frames[0] +export_to_video(output, "text2world.mp4", fps=16) ``` ## Cosmos2_5_TransferPipeline diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md index 22fcf560eaca..cf5950686f22 100644 --- a/docs/source/en/api/pipelines/overview.md +++ b/docs/source/en/api/pipelines/overview.md @@ -44,6 +44,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an | [ControlNet with Stable Diffusion XL](controlnet_sdxl) | text2image | | [ControlNet-XS](controlnetxs) | text2image | | [ControlNet-XS with Stable Diffusion XL](controlnetxs_sdxl) | text2image | +| [Cosmos](cosmos) | text2video, video2video | | [Dance Diffusion](dance_diffusion) | unconditional audio generation | | [DDIM](ddim) | unconditional image generation | | [DDPM](ddpm) | unconditional image generation | diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py index bbe38c44355e..b04b921d596a 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py @@ -82,13 +82,16 @@ def retrieve_latents( ```python >>> import cv2 >>> import numpy as np + >>> from PIL import Image >>> import torch >>> from diffusers import Cosmos2_5_TransferPipeline, AutoModel >>> from diffusers.utils import export_to_video, load_video >>> model_id = "nvidia/Cosmos-Transfer2.5-2B" >>> # Load a Transfer2.5 controlnet variant (edge, depth, seg, or blur) - >>> controlnet = AutoModel.from_pretrained(model_id, revision="diffusers/controlnet/general/edge") + >>> controlnet = AutoModel.from_pretrained( + ... 
model_id, revision="diffusers/controlnet/general/edge", torch_dtype=torch.bfloat16 + ... ) >>> pipe = Cosmos2_5_TransferPipeline.from_pretrained( ... model_id, controlnet=controlnet, revision="diffusers/general", torch_dtype=torch.bfloat16 ... ) From 068ea9a8574b190c41bb7c83603c5b5e80378ae6 Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Thu, 12 Mar 2026 02:39:24 +0900 Subject: [PATCH 041/215] Add `PRXPipeline` in `AUTO_TEXT2IMAGE_PIPELINES_MAPPING` (#13257) --- src/diffusers/pipelines/auto_pipeline.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 72151dc40a53..7f8ebd06cef1 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -95,6 +95,7 @@ StableDiffusionXLPAGPipeline, ) from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline +from .prx import PRXPipeline from .qwenimage import ( QwenImageControlNetPipeline, QwenImageEditInpaintPipeline, @@ -185,6 +186,7 @@ ("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline), ("z-image-omni", ZImageOmniPipeline), ("ovis", OvisImagePipeline), + ("prx", PRXPipeline), ] ) From b694d8dcd1e6931b834f798d11605b51e1ed42d3 Mon Sep 17 00:00:00 2001 From: huemin <100716027+huemin-art@users.noreply.github.com> Date: Thu, 12 Mar 2026 09:53:56 -0700 Subject: [PATCH 042/215] klein 9b kv (#13262) * klein 9b kv * Apply style fixes * fix typo inline modulation split * make fix-copies --------- Co-authored-by: github-actions[bot] --- src/diffusers/__init__.py | 2 + .../models/transformers/transformer_flux2.py | 498 +++++++++- src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/flux2/__init__.py | 2 + .../flux2/pipeline_flux2_klein_kv.py | 887 ++++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + 6 files changed, 1386 insertions(+), 22 deletions(-) create mode 100644 
src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 546fbe57be9e..0be7b8166a37 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -510,6 +510,7 @@ "EasyAnimateControlPipeline", "EasyAnimateInpaintPipeline", "EasyAnimatePipeline", + "Flux2KleinKVPipeline", "Flux2KleinPipeline", "Flux2Pipeline", "FluxControlImg2ImgPipeline", @@ -1266,6 +1267,7 @@ EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, EasyAnimatePipeline, + Flux2KleinKVPipeline, Flux2KleinPipeline, Flux2Pipeline, FluxControlImg2ImgPipeline, diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index f77498c74fc1..b2b6ac168703 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -40,6 +40,200 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class Flux2KVLayerCache: + """Per-layer KV cache for reference image tokens in the Flux2 Klein KV model. + + Stores the K and V projections (post-RoPE) for reference tokens extracted during the first denoising step. Tensor + format: (batch_size, num_ref_tokens, num_heads, head_dim). + """ + + def __init__(self): + self.k_ref: torch.Tensor | None = None + self.v_ref: torch.Tensor | None = None + + def store(self, k_ref: torch.Tensor, v_ref: torch.Tensor): + """Store reference token K/V.""" + self.k_ref = k_ref + self.v_ref = v_ref + + def get(self) -> tuple[torch.Tensor, torch.Tensor]: + """Retrieve cached reference token K/V.""" + if self.k_ref is None: + raise RuntimeError("KV cache has not been populated yet.") + return self.k_ref, self.v_ref + + def clear(self): + self.k_ref = None + self.v_ref = None + + +class Flux2KVCache: + """Container for all layers' reference-token KV caches. + + Holds separate cache lists for double-stream and single-stream transformer blocks. 
+ """ + + def __init__(self, num_double_layers: int, num_single_layers: int): + self.double_block_caches = [Flux2KVLayerCache() for _ in range(num_double_layers)] + self.single_block_caches = [Flux2KVLayerCache() for _ in range(num_single_layers)] + self.num_ref_tokens: int = 0 + + def get_double(self, layer_idx: int) -> Flux2KVLayerCache: + return self.double_block_caches[layer_idx] + + def get_single(self, layer_idx: int) -> Flux2KVLayerCache: + return self.single_block_caches[layer_idx] + + def clear(self): + for cache in self.double_block_caches: + cache.clear() + for cache in self.single_block_caches: + cache.clear() + self.num_ref_tokens = 0 + + +def _flux2_kv_causal_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + num_txt_tokens: int, + num_ref_tokens: int, + kv_cache: Flux2KVLayerCache | None = None, + backend=None, +) -> torch.Tensor: + """Causal attention for KV caching where reference tokens only self-attend. + + All tensors use the diffusers convention: (batch_size, seq_len, num_heads, head_dim). + + Without cache (extract mode): sequence layout is [txt, ref, img]. txt+img tokens attend to all tokens, ref tokens + only attend to themselves. With cache (cached mode): sequence layout is [txt, img]. Cached ref K/V are injected + between txt and img. 
+ """ + # No ref tokens and no cache — standard full attention + if num_ref_tokens == 0 and kv_cache is None: + return dispatch_attention_fn(query, key, value, backend=backend) + + if kv_cache is not None: + # Cached mode: inject ref K/V between txt and img + k_ref, v_ref = kv_cache.get() + + k_all = torch.cat([key[:, :num_txt_tokens], k_ref, key[:, num_txt_tokens:]], dim=1) + v_all = torch.cat([value[:, :num_txt_tokens], v_ref, value[:, num_txt_tokens:]], dim=1) + + return dispatch_attention_fn(query, k_all, v_all, backend=backend) + + # Extract mode: ref tokens self-attend, txt+img attend to all + ref_start = num_txt_tokens + ref_end = num_txt_tokens + num_ref_tokens + + q_txt = query[:, :ref_start] + q_ref = query[:, ref_start:ref_end] + q_img = query[:, ref_end:] + + k_txt = key[:, :ref_start] + k_ref = key[:, ref_start:ref_end] + k_img = key[:, ref_end:] + + v_txt = value[:, :ref_start] + v_ref = value[:, ref_start:ref_end] + v_img = value[:, ref_end:] + + # txt+img attend to all tokens + q_txt_img = torch.cat([q_txt, q_img], dim=1) + k_all = torch.cat([k_txt, k_ref, k_img], dim=1) + v_all = torch.cat([v_txt, v_ref, v_img], dim=1) + attn_txt_img = dispatch_attention_fn(q_txt_img, k_all, v_all, backend=backend) + attn_txt = attn_txt_img[:, :ref_start] + attn_img = attn_txt_img[:, ref_start:] + + # ref tokens self-attend only + attn_ref = dispatch_attention_fn(q_ref, k_ref, v_ref, backend=backend) + + return torch.cat([attn_txt, attn_ref, attn_img], dim=1) + + +def _blend_mod_params( + img_params: tuple[torch.Tensor, ...], + ref_params: tuple[torch.Tensor, ...], + num_ref: int, + seq_len: int, +) -> tuple[torch.Tensor, ...]: + """Blend modulation parameters so that the first `num_ref` positions use `ref_params`.""" + blended = [] + for im, rm in zip(img_params, ref_params): + if im.ndim == 2: + im = im.unsqueeze(1) + rm = rm.unsqueeze(1) + B = im.shape[0] + blended.append( + torch.cat( + [rm.expand(B, num_ref, -1), im.expand(B, seq_len, -1)[:, num_ref:, :]], + 
dim=1, + ) + ) + return tuple(blended) + + +def _blend_double_block_mods( + img_mod: torch.Tensor, + ref_mod: torch.Tensor, + num_ref: int, + seq_len: int, +) -> torch.Tensor: + """Blend double-block image-stream modulations for a [ref, img] sequence layout. + + Takes raw modulation tensors (before `Flux2Modulation.split`) and returns a blended raw tensor that is compatible + with `Flux2Modulation.split(mod, 2)`. + """ + if img_mod.ndim == 2: + img_mod = img_mod.unsqueeze(1) + ref_mod = ref_mod.unsqueeze(1) + img_chunks = torch.chunk(img_mod, 6, dim=-1) + ref_chunks = torch.chunk(ref_mod, 6, dim=-1) + img_mods = (img_chunks[0:3], img_chunks[3:6]) + ref_mods = (ref_chunks[0:3], ref_chunks[3:6]) + + all_params = [] + for img_set, ref_set in zip(img_mods, ref_mods): + blended = _blend_mod_params(img_set, ref_set, num_ref, seq_len) + all_params.extend(blended) + return torch.cat(all_params, dim=-1) + + +def _blend_single_block_mods( + single_mod: torch.Tensor, + ref_mod: torch.Tensor, + num_txt: int, + num_ref: int, + seq_len: int, +) -> torch.Tensor: + """Blend single-block modulations for a [txt, ref, img] sequence layout. + + Takes raw modulation tensors and returns a blended raw tensor compatible with `Flux2Modulation.split(mod, 1)`. 
+ """ + if single_mod.ndim == 2: + single_mod = single_mod.unsqueeze(1) + ref_mod = ref_mod.unsqueeze(1) + img_params = torch.chunk(single_mod, 3, dim=-1) + ref_params = torch.chunk(ref_mod, 3, dim=-1) + + blended = [] + for im, rm in zip(img_params, ref_params): + if im.ndim == 2: + im = im.unsqueeze(1) + rm = rm.unsqueeze(1) + B = im.shape[0] + im_expanded = im.expand(B, seq_len, -1) + rm_expanded = rm.expand(B, num_ref, -1) + blended.append( + torch.cat( + [im_expanded[:, :num_txt, :], rm_expanded, im_expanded[:, num_txt + num_ref :, :]], + dim=1, + ) + ) + return torch.cat(blended, dim=-1) + + def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) @@ -181,9 +375,108 @@ def __call__( return hidden_states +class Flux2KVAttnProcessor: + """ + Attention processor for Flux2 double-stream blocks with KV caching support for reference image tokens. + + When `kv_cache_mode` is "extract", reference token K/V are stored in the cache after RoPE and causal attention is + used (ref tokens self-attend only, txt+img attend to all). When `kv_cache_mode` is "cached", cached ref K/V are + injected during attention. When no KV args are provided, behaves identically to `Flux2AttnProcessor`. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") + + def __call__( + self, + attn: "Flux2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None, + kv_cache: Flux2KVLayerCache | None = None, + kv_cache_mode: str | None = None, + num_ref_tokens: int = 0, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + num_txt_tokens = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0 + + # Extract ref K/V from the combined sequence + if kv_cache_mode == "extract" and kv_cache is not None and num_ref_tokens > 0: + ref_start = num_txt_tokens + ref_end = num_txt_tokens + num_ref_tokens + kv_cache.store(key[:, ref_start:ref_end].clone(), value[:, ref_start:ref_end].clone()) + + # Dispatch attention + if kv_cache_mode == "extract" and num_ref_tokens > 0: + hidden_states = _flux2_kv_causal_attention( + query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend + ) + elif kv_cache_mode 
== "cached" and kv_cache is not None: + hidden_states = _flux2_kv_causal_attention( + query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend + ) + else: + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + class Flux2Attention(torch.nn.Module, AttentionModuleMixin): _default_processor_cls = Flux2AttnProcessor - _available_processors = [Flux2AttnProcessor] + _available_processors = [Flux2AttnProcessor, Flux2KVAttnProcessor] def __init__( self, @@ -312,6 +605,90 @@ def __call__( return hidden_states +class Flux2KVParallelSelfAttnProcessor: + """ + Attention processor for Flux2 single-stream blocks with KV caching support for reference image tokens. + + When `kv_cache_mode` is "extract", reference token K/V are stored and causal attention is used. When + `kv_cache_mode` is "cached", cached ref K/V are injected during attention. When no KV args are provided, behaves + identically to `Flux2ParallelSelfAttnProcessor`. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") + + def __call__( + self, + attn: "Flux2ParallelSelfAttention", + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None, + kv_cache: Flux2KVLayerCache | None = None, + kv_cache_mode: str | None = None, + num_txt_tokens: int = 0, + num_ref_tokens: int = 0, + ) -> torch.Tensor: + # Parallel in (QKV + MLP in) projection + hidden_states_proj = attn.to_qkv_mlp_proj(hidden_states) + qkv, mlp_hidden_states = torch.split( + hidden_states_proj, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1 + ) + + query, key, value = qkv.chunk(3, dim=-1) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + # Extract ref K/V from the combined sequence + if kv_cache_mode == "extract" and kv_cache is not None and num_ref_tokens > 0: + ref_start = num_txt_tokens + ref_end = num_txt_tokens + num_ref_tokens + kv_cache.store(key[:, ref_start:ref_end].clone(), value[:, ref_start:ref_end].clone()) + + # Dispatch attention + if kv_cache_mode == "extract" and num_ref_tokens > 0: + attn_output = _flux2_kv_causal_attention( + query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend + ) + elif kv_cache_mode == "cached" and kv_cache is not None: + attn_output = _flux2_kv_causal_attention( + query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend + ) + else: + attn_output = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + attn_output = attn_output.flatten(2, 3) + attn_output = 
attn_output.to(query.dtype) + + # Handle the feedforward (FF) logic + mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states) + + # Concatenate and parallel output projection + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=-1) + hidden_states = attn.to_out(hidden_states) + + return hidden_states + + class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin): """ Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks. @@ -322,7 +699,7 @@ class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin): """ _default_processor_cls = Flux2ParallelSelfAttnProcessor - _available_processors = [Flux2ParallelSelfAttnProcessor] + _available_processors = [Flux2ParallelSelfAttnProcessor, Flux2KVParallelSelfAttnProcessor] # Does not support QKV fusion as the QKV projections are always fused _supports_qkv_fusion = False @@ -780,6 +1157,8 @@ def __init__( self.gradient_checkpointing = False + _skip_keys = ["kv_cache"] + @apply_lora_scale("joint_attention_kwargs") def forward( self, @@ -791,19 +1170,21 @@ def forward( guidance: torch.Tensor = None, joint_attention_kwargs: dict[str, Any] | None = None, return_dict: bool = True, + kv_cache: "Flux2KVCache | None" = None, + kv_cache_mode: str | None = None, + num_ref_tokens: int = 0, + ref_fixed_timestep: float = 0.0, ) -> torch.Tensor | Transformer2DModelOutput: """ - The [`FluxTransformer2DModel`] forward method. + The [`Flux2Transformer2DModel`] forward method. Args: hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): Input `hidden_states`. encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - timestep ( `torch.LongTensor`): + timestep (`torch.LongTensor`): Used to indicate denoising step. 
- block_controlnet_hidden_states: (`list` of `torch.Tensor`): - A list of tensors that if specified are added to the residuals of transformer blocks. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in @@ -811,13 +1192,23 @@ def forward( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. + kv_cache (`Flux2KVCache`, *optional*): + KV cache for reference image tokens. When `kv_cache_mode` is "extract", a new cache is created and + returned. When "cached", the provided cache is used to inject ref K/V during attention. + kv_cache_mode (`str`, *optional*): + One of "extract" (first step with ref tokens) or "cached" (subsequent steps using cached ref K/V). When + `None`, standard forward pass without KV caching. + num_ref_tokens (`int`, defaults to `0`): + Number of reference image tokens prepended to `hidden_states` (only used when + `kv_cache_mode="extract"`). + ref_fixed_timestep (`float`, defaults to `0.0`): + Fixed timestep for reference token modulation (only used when `kv_cache_mode="extract"`). Returns: If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. + `tuple` where the first element is the sample tensor. When `kv_cache_mode="extract"`, also returns the + populated `Flux2KVCache`. """ - # 0. Handle input arguments - num_txt_tokens = encoder_hidden_states.shape[1] # 1. 
Calculate timestep embedding and modulation parameters @@ -832,13 +1223,33 @@ def forward( double_stream_mod_txt = self.double_stream_modulation_txt(temb) single_stream_mod = self.single_stream_modulation(temb) + # KV extract mode: create cache and blend modulations for ref tokens + if kv_cache_mode == "extract" and num_ref_tokens > 0: + num_img_tokens = hidden_states.shape[1] # includes ref tokens + + kv_cache = Flux2KVCache( + num_double_layers=len(self.transformer_blocks), + num_single_layers=len(self.single_transformer_blocks), + ) + kv_cache.num_ref_tokens = num_ref_tokens + + # Ref tokens use a fixed timestep for modulation + ref_timestep = torch.full_like(timestep, ref_fixed_timestep * 1000) + ref_temb = self.time_guidance_embed(ref_timestep, guidance) + + ref_double_mod_img = self.double_stream_modulation_img(ref_temb) + ref_single_mod = self.single_stream_modulation(ref_temb) + + # Blend double block img modulation: [ref_mod, img_mod] + double_stream_mod_img = _blend_double_block_mods( + double_stream_mod_img, ref_double_mod_img, num_ref_tokens, num_img_tokens + ) + # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states) hidden_states = self.x_embedder(hidden_states) encoder_hidden_states = self.context_embedder(encoder_hidden_states) # 3. Calculate RoPE embeddings from image and text tokens - # NOTE: the below logic means that we can't support batched inference with images of different resolutions or - # text prompts of differents lengths. Is this a use case we want to support? if img_ids.ndim == 3: img_ids = img_ids[0] if txt_ids.ndim == 3: @@ -851,8 +1262,29 @@ def forward( torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), ) - # 4. Double Stream Transformer Blocks + # 4. 
Build joint_attention_kwargs with KV cache info + if kv_cache_mode == "extract": + kv_attn_kwargs = { + **(joint_attention_kwargs or {}), + "kv_cache": None, + "kv_cache_mode": "extract", + "num_ref_tokens": num_ref_tokens, + } + elif kv_cache_mode == "cached" and kv_cache is not None: + kv_attn_kwargs = { + **(joint_attention_kwargs or {}), + "kv_cache": None, + "kv_cache_mode": "cached", + "num_ref_tokens": kv_cache.num_ref_tokens, + } + else: + kv_attn_kwargs = joint_attention_kwargs + + # 5. Double Stream Transformer Blocks for index_block, block in enumerate(self.transformer_blocks): + if kv_cache_mode is not None and kv_cache is not None: + kv_attn_kwargs["kv_cache"] = kv_cache.get_double(index_block) + if torch.is_grad_enabled() and self.gradient_checkpointing: encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, @@ -861,7 +1293,7 @@ def forward( double_stream_mod_img, double_stream_mod_txt, concat_rotary_emb, - joint_attention_kwargs, + kv_attn_kwargs, ) else: encoder_hidden_states, hidden_states = block( @@ -870,13 +1302,30 @@ def forward( temb_mod_img=double_stream_mod_img, temb_mod_txt=double_stream_mod_txt, image_rotary_emb=concat_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, + joint_attention_kwargs=kv_attn_kwargs, ) + # Concatenate text and image streams for single-block inference hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - # 5. 
Single Stream Transformer Blocks + # Blend single block modulation for extract mode: [txt_mod, ref_mod, img_mod] + if kv_cache_mode == "extract" and num_ref_tokens > 0: + total_single_len = hidden_states.shape[1] + single_stream_mod = _blend_single_block_mods( + single_stream_mod, ref_single_mod, num_txt_tokens, num_ref_tokens, total_single_len + ) + + # Build single-block KV kwargs (single blocks need num_txt_tokens) + if kv_cache_mode is not None: + kv_attn_kwargs_single = {**kv_attn_kwargs, "num_txt_tokens": num_txt_tokens} + else: + kv_attn_kwargs_single = kv_attn_kwargs + + # 6. Single Stream Transformer Blocks for index_block, block in enumerate(self.single_transformer_blocks): + if kv_cache_mode is not None and kv_cache is not None: + kv_attn_kwargs_single["kv_cache"] = kv_cache.get_single(index_block) + if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func( block, @@ -884,7 +1333,7 @@ def forward( None, single_stream_mod, concat_rotary_emb, - joint_attention_kwargs, + kv_attn_kwargs_single, ) else: hidden_states = block( @@ -892,15 +1341,24 @@ def forward( encoder_hidden_states=None, temb_mod=single_stream_mod, image_rotary_emb=concat_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, + joint_attention_kwargs=kv_attn_kwargs_single, ) - # Remove text tokens from concatenated stream - hidden_states = hidden_states[:, num_txt_tokens:, ...] - # 6. Output layers + # Remove text tokens (and ref tokens in extract mode) from concatenated stream + if kv_cache_mode == "extract" and num_ref_tokens > 0: + hidden_states = hidden_states[:, num_txt_tokens + num_ref_tokens :, ...] + else: + hidden_states = hidden_states[:, num_txt_tokens:, ...] + + # 7. 
Output layers hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) + if kv_cache_mode == "extract": + if not return_dict: + return (output,), kv_cache + return Transformer2DModelOutput(sample=output), kv_cache + if not return_dict: return (output,) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 8007035338b0..b9596f4b7952 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -129,7 +129,7 @@ ] _import_structure["bria"] = ["BriaPipeline"] _import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"] - _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"] + _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline", "Flux2KleinKVPipeline"] _import_structure["flux"] = [ "FluxControlPipeline", "FluxControlInpaintPipeline", @@ -671,7 +671,7 @@ FluxPriorReduxPipeline, ReduxImageEncoder, ) - from .flux2 import Flux2KleinPipeline, Flux2Pipeline + from .flux2 import Flux2KleinKVPipeline, Flux2KleinPipeline, Flux2Pipeline from .glm_image import GlmImagePipeline from .helios import HeliosPipeline, HeliosPyramidPipeline from .hidream_image import HiDreamImagePipeline diff --git a/src/diffusers/pipelines/flux2/__init__.py b/src/diffusers/pipelines/flux2/__init__.py index f6e1d5206630..52a8f464b0ce 100644 --- a/src/diffusers/pipelines/flux2/__init__.py +++ b/src/diffusers/pipelines/flux2/__init__.py @@ -24,6 +24,7 @@ else: _import_structure["pipeline_flux2"] = ["Flux2Pipeline"] _import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"] + _import_structure["pipeline_flux2_klein_kv"] = ["Flux2KleinKVPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -33,6 +34,7 @@ else: from .pipeline_flux2 import Flux2Pipeline from .pipeline_flux2_klein import Flux2KleinPipeline + from .pipeline_flux2_klein_kv import Flux2KleinKVPipeline else: import sys 
diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py new file mode 100644 index 000000000000..62f83f8a11f8 --- /dev/null +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py @@ -0,0 +1,887 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL +import torch +from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM + +from ...loaders import Flux2LoraLoaderMixin +from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel +from ...models.transformers.transformer_flux2 import Flux2KVAttnProcessor, Flux2KVParallelSelfAttnProcessor +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import Flux2ImageProcessor +from .pipeline_output import Flux2PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from diffusers import Flux2KleinKVPipeline + + >>> pipe = 
Flux2KleinKVPipeline.from_pretrained( + ... "black-forest-labs/FLUX.2-klein-9b-kv", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + >>> ref_image = Image.open("reference.png") + >>> image = pipe("A cat dressed like a wizard", image=ref_image, num_inference_steps=4).images[0] + >>> image.save("flux2_kv_output.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. 
+ sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Flux2KleinKVPipeline(DiffusionPipeline, Flux2LoraLoaderMixin): + r""" + The Flux2 Klein KV pipeline for text-to-image generation with KV-cached reference image conditioning. + + On the first denoising step, reference image tokens are included in the forward pass and their attention K/V + projections are cached. On subsequent steps, the cached K/V are reused without recomputing, providing faster + inference when using reference images. + + Reference: + [https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence) + + Args: + transformer ([`Flux2Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLFlux2`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
+ text_encoder ([`Qwen3ForCausalLM`]): + [Qwen3ForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3ForCausalLM) + tokenizer (`Qwen2TokenizerFast`): + Tokenizer of class + [Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLFlux2, + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + transformer: Flux2Transformer2DModel, + is_distilled: bool = True, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 512 + self.default_sample_size = 128 + + # Set KV-cache-aware attention processors + self._set_kv_attn_processors() + + @staticmethod + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: str | list[str], + dtype: torch.dtype | None = None, + device: torch.device | None = None, + max_sequence_length: int = 512, + hidden_states_layers: list[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @staticmethod + # 
Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids + def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: torch.Tensor | None = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids + def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) + ): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, l) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids + def _prepare_image_ids( + image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, + ): + r""" + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + This function creates a unique coordinate for every pixel/patch across all input latent with different + dimensions. + + Args: + image_latents (list[torch.Tensor]): + A list of image latent feature tensors, typically of shape (C, H, W). 
+ scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th + latent is: 'scale + scale * i'. Defaults to 10. + + Returns: + torch.Tensor: + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all + input latents. + + Coordinate Components (Dimension 4): + - T (Time): The unique index indicating which latent image the coordinate belongs to. + - H (Height): The row index within that latent image. + - W (Width): The column index within that latent image. + - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1) + """ + + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents + def _patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents + def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = 
latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents + def _pack_latents(latents): + """ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) + """ + + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + def _set_kv_attn_processors(self): + """Replace default attention processors with KV-cache-aware variants.""" + for block in self.transformer.transformer_blocks: + block.attn.set_processor(Flux2KVAttnProcessor()) + for block in self.transformer.single_transformer_blocks: + block.attn.set_processor(Flux2KVParallelSelfAttnProcessor()) + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor 
| None = None, + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int] = (9, 18, 27), + ): + device = device or self._execution_device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self._prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_latents_channels, + height, + width, + dtype, + device, + generator: torch.Generator, + latents: torch.Tensor | None = None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(device) + + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] + return latents, latent_ids + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents + def prepare_image_latents( + self, + images: list[torch.Tensor], + batch_size, + generator: torch.Generator, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + imagge_latent = self._encode_vae_image(image=image, generator=generator) + image_latents.append(imagge_latent) # (1, 128, 32, 32) + + image_latent_ids = self._prepare_image_ids(image_latents) + + # Pack each latent and concatenate + packed_latents = [] + for latent in image_latents: + # latent: (1, 128, 32, 32) + packed = self._pack_latents(latent) # (1, 1024, 128) + packed = packed.squeeze(0) # (1024, 128) - remove batch dim + packed_latents.append(packed) + + # Concatenate all reference tokens along sequence dimension + image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128) + image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + return 
image_latents, image_latent_ids + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * 2) != 0 + or width is not None + and width % (self.vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: list[PIL.Image.Image] | PIL.Image.Image | None = None, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 4, + sigmas: list[float] | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int] = (9, 18, 27), + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*): + Reference image(s) for conditioning. On the first denoising step, reference tokens are included in the + forward pass and their attention K/V are cached. On subsequent steps, the cached K/V are reused without + recomputing. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. 
+ num_inference_steps (`int`, *optional*, defaults to 4): + The number of denoising steps. + sigmas (`List[float]`, *optional*): + Custom sigmas for the denoising schedule. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + Generator(s) for deterministic generation. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + Output format: `"pil"` or `"np"`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a `Flux2PipelineOutput` or a plain tuple. + attention_kwargs (`dict`, *optional*): + Extra kwargs passed to attention processors. + callback_on_step_end (`Callable`, *optional*): + Callback function called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List`, *optional*): + Tensor inputs for the callback function. + max_sequence_length (`int`, defaults to 512): + Maximum sequence length for the prompt. + text_encoder_out_layers (`tuple[int]`): + Layer indices for text encoder hidden state extraction. + + Examples: + + Returns: + [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`. + """ + + # 1. Check inputs + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. 
prepare text embeddings + prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + # 4. process images + if image is not None and not isinstance(image, list): + image = [image] + + condition_images = None + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + + condition_images = [] + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + height = height or image_height + width = width or image_width + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 5. prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_ids = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_latents_channels=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + image_latents = None + image_latent_ids = None + if condition_images is not None: + image_latents, image_latent_ids = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + generator=generator, + device=device, + dtype=self.vae.dtype, + ) + + # 6. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 7. Denoising loop with KV caching + # Step 0 with ref images: forward_kv_extract (full pass, cache ref K/V) + # Steps 1+: forward_kv_cached (reuse cached ref K/V) + # No ref images: standard forward + self.scheduler.set_begin_index(0) + kv_cache = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + if i == 0 and image_latents is not None: + # Step 0: include ref tokens, extract KV cache + latent_model_input = torch.cat([image_latents, latents], dim=1).to(self.transformer.dtype) + latent_image_ids = torch.cat([image_latent_ids, latent_ids], dim=1) + + output, kv_cache = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.attention_kwargs, + return_dict=False, + kv_cache_mode="extract", + num_ref_tokens=image_latents.shape[1], + ) + noise_pred = output[0] + + elif kv_cache is not None: + # Steps 1+: use cached ref KV, no ref tokens in input + noise_pred = self.transformer( + hidden_states=latents.to(self.transformer.dtype), + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=prompt_embeds, + 
txt_ids=text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.attention_kwargs, + return_dict=False, + kv_cache=kv_cache, + kv_cache_mode="cached", + )[0] + + else: + # No reference images: standard forward + noise_pred = self.transformer( + hidden_states=latents.to(self.transformer.dtype), + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # Clean up KV cache + if kv_cache is not None: + kv_cache.clear() + + self._current_timestep = None + + latents = self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) + if output_type == "latent": + image = latents + else: + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + 
# Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return Flux2PipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 730a788ed1b8..2ec5bc002f41 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1202,6 +1202,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Flux2KleinKVPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Flux2KleinPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From c98649cce6e7afcada86d3993556387c95f93fc2 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 13 Mar 2026 06:58:53 +0530 Subject: [PATCH 043/215] [lora] fix z-image non-diffusers lora loading. (#13255) fix z-image non-diffusers lora loading. 
--- src/diffusers/loaders/lora_conversion_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 0895d5223e13..6e43aef51ce0 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2538,8 +2538,12 @@ def normalize_out_key(k: str) -> str: def get_alpha_scales(down_weight, alpha_key): rank = down_weight.shape[0] - alpha = state_dict.pop(alpha_key).item() - scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here + alpha_tensor = state_dict.pop(alpha_key, None) + if alpha_tensor is None: + return 1.0, 1.0 + scale = ( + alpha_tensor.item() / rank + ) # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here scale_down = scale scale_up = 1.0 while scale_down * 2 < scale_up: From cd96d60307794779d32d89da0d0b2f1d8c275422 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 13 Mar 2026 10:05:11 +0530 Subject: [PATCH 044/215] [core] Flux2 klein kv followups (#13264) * implement Flux2Transformer2DModelOutput. * add output class to docs. * add Flux2KleinKV to docs. * add pipeline tests for klein kv. 
--- .../source/en/api/models/flux2_transformer.md | 4 + docs/source/en/api/pipelines/flux2.md | 6 + .../models/transformers/transformer_flux2.py | 28 ++- .../flux2/pipeline_flux2_klein_kv.py | 3 +- .../flux2/test_pipeline_flux2_klein_kv.py | 174 ++++++++++++++++++ 5 files changed, 207 insertions(+), 8 deletions(-) create mode 100644 tests/pipelines/flux2/test_pipeline_flux2_klein_kv.py diff --git a/docs/source/en/api/models/flux2_transformer.md b/docs/source/en/api/models/flux2_transformer.md index c85681d2b011..d0f0545e6a31 100644 --- a/docs/source/en/api/models/flux2_transformer.md +++ b/docs/source/en/api/models/flux2_transformer.md @@ -17,3 +17,7 @@ A Transformer model for image-like data from [Flux2](https://hf.co/black-forest- ## Flux2Transformer2DModel [[autodoc]] Flux2Transformer2DModel + +## Flux2Transformer2DModelOutput + +[[autodoc]] models.transformers.transformer_flux2.Flux2Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/flux2.md b/docs/source/en/api/pipelines/flux2.md index 4ace2f3b3aa0..2a2b39b95630 100644 --- a/docs/source/en/api/pipelines/flux2.md +++ b/docs/source/en/api/pipelines/flux2.md @@ -41,5 +41,11 @@ The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a ## Flux2KleinPipeline [[autodoc]] Flux2KleinPipeline + - all + - __call__ + +## Flux2KleinKVPipeline + +[[autodoc]] Flux2KleinKVPipeline - all - __call__ \ No newline at end of file diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index b2b6ac168703..5c90f3a46a98 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -13,6 +13,7 @@ # limitations under the License. 
import inspect +from dataclasses import dataclass from typing import Any import torch @@ -21,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin -from ...utils import apply_lora_scale, logging +from ...utils import BaseOutput, apply_lora_scale, logging from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn @@ -32,7 +33,6 @@ apply_rotary_emb, get_1d_rotary_pos_embed, ) -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous @@ -40,6 +40,22 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +@dataclass +class Flux2Transformer2DModelOutput(BaseOutput): + """ + The output of [`Flux2Transformer2DModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on the `encoder_hidden_states` input. + kv_cache (`Flux2KVCache`, *optional*): + The populated KV cache for reference image tokens. Only returned when `kv_cache_mode="extract"`. + """ + + sample: "torch.Tensor" # noqa: F821 + kv_cache: "Flux2KVCache | None" = None + + class Flux2KVLayerCache: """Per-layer KV cache for reference image tokens in the Flux2 Klein KV model. @@ -1174,7 +1190,7 @@ def forward( kv_cache_mode: str | None = None, num_ref_tokens: int = 0, ref_fixed_timestep: float = 0.0, - ) -> torch.Tensor | Transformer2DModelOutput: + ) -> torch.Tensor | Flux2Transformer2DModelOutput: """ The [`Flux2Transformer2DModel`] forward method. 
@@ -1356,10 +1372,10 @@ def forward( if kv_cache_mode == "extract": if not return_dict: - return (output,), kv_cache - return Transformer2DModelOutput(sample=output), kv_cache + return (output, kv_cache) + return Flux2Transformer2DModelOutput(sample=output, kv_cache=kv_cache) if not return_dict: return (output,) - return Transformer2DModelOutput(sample=output) + return Flux2Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py index 62f83f8a11f8..671953be63c1 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py @@ -793,7 +793,7 @@ def __call__( latent_model_input = torch.cat([image_latents, latents], dim=1).to(self.transformer.dtype) latent_image_ids = torch.cat([image_latent_ids, latent_ids], dim=1) - output, kv_cache = self.transformer( + noise_pred, kv_cache = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, guidance=None, @@ -805,7 +805,6 @@ def __call__( kv_cache_mode="extract", num_ref_tokens=image_latents.shape[1], ) - noise_pred = output[0] elif kv_cache is not None: # Steps 1+: use cached ref KV, no ref tokens in input diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein_kv.py b/tests/pipelines/flux2/test_pipeline_flux2_klein_kv.py new file mode 100644 index 000000000000..046364f9269d --- /dev/null +++ b/tests/pipelines/flux2/test_pipeline_flux2_klein_kv.py @@ -0,0 +1,174 @@ +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM + +from diffusers import ( + AutoencoderKLFlux2, + FlowMatchEulerDiscreteScheduler, + Flux2KleinKVPipeline, + Flux2Transformer2DModel, +) + +from ...testing_utils import torch_device +from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist + + +class 
Flux2KleinKVPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = Flux2KleinKVPipeline + params = frozenset(["prompt", "height", "width", "prompt_embeds", "image"]) + batch_params = frozenset(["prompt"]) + + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + supports_dduf = False + + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): + torch.manual_seed(0) + transformer = Flux2Transformer2DModel( + patch_size=1, + in_channels=4, + num_layers=num_layers, + num_single_layers=num_single_layers, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=16, + timestep_guidance_channels=256, + axes_dims_rope=[4, 4, 4, 4], + guidance_embeds=False, + ) + + # Create minimal Qwen3 config + config = Qwen3Config( + intermediate_size=16, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + vocab_size=151936, + max_position_embeddings=512, + ) + torch.manual_seed(0) + text_encoder = Qwen3ForCausalLM(config) + + # Use a simple tokenizer for testing + tokenizer = Qwen2TokenizerFast.from_pretrained( + "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" + ) + + torch.manual_seed(0) + vae = AutoencoderKLFlux2( + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "a dog is dancing", + 
"image": Image.new("RGB", (64, 64)), + "generator": generator, + "num_inference_steps": 2, + "height": 8, + "width": 8, + "max_sequence_length": 64, + "output_type": "np", + "text_encoder_out_layers": (1,), + } + return inputs + + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + original_image_slice = image[0, -3:, -3:, -1] + + pipe.transformer.fuse_qkv_projections() + self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), + ) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_fused = image[0, -3:, -3:, -1] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_disabled = image[0, -3:, -3:, -1] + + self.assertTrue( + np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), + ("Fusion of QKV projections shouldn't affect the outputs."), + ) + self.assertTrue( + np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), + ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."), + ) + self.assertTrue( + np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), + ("Original outputs should match when fused QKV projections are disabled."), + ) + + def test_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % 
(pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + self.assertEqual( + (output_height, output_width), + (expected_height, expected_width), + f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}", + ) + + def test_without_image(self): + device = "cpu" + pipe = self.pipeline_class(**self.get_dummy_components()).to(device) + inputs = self.get_dummy_inputs(device) + del inputs["image"] + image = pipe(**inputs).images + self.assertEqual(image.shape, (1, 8, 8, 3)) + + @unittest.skip("Needs to be revisited") + def test_encode_prompt_works_in_isolation(self): + pass From e722f7a39bdd2bc1b9b4c12c6618c6c6fd8eb9cd Mon Sep 17 00:00:00 2001 From: teith <123115827+teith@users.noreply.github.com> Date: Fri, 13 Mar 2026 19:56:38 +0100 Subject: [PATCH 045/215] fix: correct invalid type annotation for `image` in `Flux2Pipeline.__call__` (#13205) fix: correct invalid type annotation for image in Flux2Pipeline.__call__ --- src/diffusers/pipelines/flux2/pipeline_flux2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index 6cd0563fcc19..4b60c6042d4f 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -744,7 +744,7 @@ def interrupt(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - image: list[PIL.Image.Image, PIL.Image.Image] | None = None, + image: PIL.Image.Image | list[PIL.Image.Image] | None = None, prompt: str | list[str] = None, height: int | None = None, width: int | None = None, From 5ff24bafc333e0e3efdd86c3a712e17457fc0ab8 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sat, 14 Mar 2026 08:35:12 -1000 Subject: [PATCH 046/215] Add AGENTS.md (#13259) * add a 
draft * add * up * Apply suggestions from code review Co-authored-by: Sayak Paul Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --------- Co-authored-by: Sayak Paul Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- .ai/AGENTS.md | 77 +++++++++++++++++++++++ .gitignore | 6 +- Makefile | 13 +++- docs/source/en/conceptual/contribution.md | 14 ++++- 4 files changed, 107 insertions(+), 3 deletions(-) create mode 100644 .ai/AGENTS.md diff --git a/.ai/AGENTS.md b/.ai/AGENTS.md new file mode 100644 index 000000000000..9e93ae79df92 --- /dev/null +++ b/.ai/AGENTS.md @@ -0,0 +1,77 @@ +# Diffusers — Agent Guide + +## Coding style + +Strive to write code as simple and explicit as possible. + +- Minimize small helper/utility functions — inline the logic instead. A reader should be able to follow the full flow without jumping between functions. +- No defensive code or unused code paths — do not add fallback paths, safety checks, or configuration options "just in case". When porting from a research repo, delete training-time code paths, experimental flags, and ablation branches entirely — only keep the inference path you are actually integrating. +- Do not guess user intent and silently correct behavior. Make the expected inputs clear in the docstring, and raise a concise error for unsupported cases rather than adding complex fallback logic. + +--- + +### Dependencies +- No new mandatory dependency without discussion (e.g. 
`einops`) +- Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py` + +## Code formatting +- `make style` and `make fix-copies` should be run as the final step before opening a PR + +### Copied Code +- Many classes are kept in sync with a source via a `# Copied from ...` header comment +- Do not edit a `# Copied from` block directly — run `make fix-copies` to propagate changes from the source +- Remove the header to intentionally break the link + +### Models +- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls. +- Try to not introduce graph breaks as much as possible for better compatibility with `torch.compile`. For example, DO NOT arbitrarily insert operations from NumPy in the forward implementations. +- Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`. + +```python +# transformer_mymodel.py + +class MyModelAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __call__(self, attn, hidden_states, attention_mask=None, ...): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + # reshape, apply rope, etc. 
+ hidden_states = dispatch_attention_fn( + query, key, value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + return attn.to_out[0](hidden_states) + + +class MyModelAttention(nn.Module, AttentionModuleMixin): + _default_processor_cls = MyModelAttnProcessor + _available_processors = [MyModelAttnProcessor] + + def __init__(self, query_dim, heads=8, dim_head=64, ...): + super().__init__() + self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False) + self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False) + self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False) + self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)]) + self.set_processor(MyModelAttnProcessor()) + + def forward(self, hidden_states, attention_mask=None, **kwargs): + return self.processor(self, hidden_states, attention_mask, **kwargs) +``` + +Consult the implementations in `src/diffusers/models/transformers/` if you need further references. + +### Pipeline +- All pipelines must inherit from `DiffusionPipeline`. Consult implementations in `src/diffusers/pipelines` in case you need references. +- DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`). + + +### Tests +- Slow tests gated with `@slow` and `RUN_SLOW=1` +- All model-level tests must use the `BaseModelTesterConfig`, `ModelTesterMixin`, `MemoryTesterMixin`, `AttentionTesterMixin`, `LoraTesterMixin`, and `TrainingTesterMixin` classes initially to write the tests. Any additional tests should be added after discussions with the maintainers. Use `tests/models/transformers/test_models_transformer_flux.py` as a reference.
diff --git a/.gitignore b/.gitignore index a55026febd5a..d281b8d1511c 100644 --- a/.gitignore +++ b/.gitignore @@ -178,4 +178,8 @@ tags .ruff_cache # wandb -wandb \ No newline at end of file +wandb + +# AI agent generated symlinks +/AGENTS.md +/CLAUDE.md \ No newline at end of file diff --git a/Makefile b/Makefile index b90ff82ab268..491baba074c9 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples +.PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples codex claude clean-ai # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) export PYTHONPATH = src @@ -98,3 +98,14 @@ post-release: post-patch: python utils/release.py --post_release --patch + +# AI agent symlinks + +codex: + ln -snf .ai/AGENTS.md AGENTS.md + +claude: + ln -snf .ai/AGENTS.md CLAUDE.md + +clean-ai: + rm -f AGENTS.md CLAUDE.md diff --git a/docs/source/en/conceptual/contribution.md b/docs/source/en/conceptual/contribution.md index e39a6434f095..22bb265b1d79 100644 --- a/docs/source/en/conceptual/contribution.md +++ b/docs/source/en/conceptual/contribution.md @@ -565,4 +565,16 @@ $ git push --set-upstream origin your-branch-for-syncing ### Style guide -For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html). \ No newline at end of file +For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html). + + +## Coding with AI agents + +The repository keeps AI-agent configuration in `.ai/` and exposes local agent files via symlinks. 
+ +- **Source of truth** — edit `.ai/AGENTS.md` (and any future `.ai/skills/`) +- **Don't edit** generated root-level `AGENTS.md` or `CLAUDE.md` — they are symlinks +- Setup commands: + - `make codex` — symlink for OpenAI Codex + - `make claude` — symlink for Claude Code + - `make clean-ai` — remove generated symlinks \ No newline at end of file From 958a156ff88ad84690600ef45dfbef45cbc6fa02 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Mon, 16 Mar 2026 13:24:57 -0700 Subject: [PATCH 047/215] [docs] updates (#13248) * fixes * few more links * update zh * fix --- docs/source/en/_toctree.yml | 4 +- .../en/api/pipelines/hunyuan_video15.md | 2 +- .../source/en/api/pipelines/hunyuanimage21.md | 2 +- .../en/modular_diffusers/modular_pipeline.md | 2 +- docs/source/en/modular_diffusers/overview.md | 2 +- docs/source/en/optimization/memory.md | 140 +----------------- docs/source/en/optimization/xformers.md | 2 +- .../guiders.md | 0 docs/source/zh/_toctree.yml | 4 +- .../guiders.md | 0 10 files changed, 10 insertions(+), 148 deletions(-) rename docs/source/en/{modular_diffusers => using-diffusers}/guiders.md (100%) rename docs/source/zh/{modular_diffusers => using-diffusers}/guiders.md (100%) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c69bcd340b27..6b1a7288d60f 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -22,6 +22,8 @@ title: Reproducibility - local: using-diffusers/schedulers title: Schedulers + - local: using-diffusers/guiders + title: Guiders - local: using-diffusers/automodel title: AutoModel - local: using-diffusers/other-formats @@ -110,8 +112,6 @@ title: ModularPipeline - local: modular_diffusers/components_manager title: ComponentsManager - - local: modular_diffusers/guiders - title: Guiders - local: modular_diffusers/custom_blocks title: Building Custom Blocks - local: modular_diffusers/mellon diff --git a/docs/source/en/api/pipelines/hunyuan_video15.md 
b/docs/source/en/api/pipelines/hunyuan_video15.md index d77e72bb0f71..dfaeab6528f9 100644 --- a/docs/source/en/api/pipelines/hunyuan_video15.md +++ b/docs/source/en/api/pipelines/hunyuan_video15.md @@ -99,7 +99,7 @@ To update guider configuration, you can run `pipe.guider = pipe.guider.new(...)` pipe.guider = pipe.guider.new(guidance_scale=5.0) ``` -Read more on Guider [here](../../modular_diffusers/guiders). +Read more on Guider [here](../../using-diffusers/guiders). diff --git a/docs/source/en/api/pipelines/hunyuanimage21.md b/docs/source/en/api/pipelines/hunyuanimage21.md index f7ba40e23796..9e8ea2627e33 100644 --- a/docs/source/en/api/pipelines/hunyuanimage21.md +++ b/docs/source/en/api/pipelines/hunyuanimage21.md @@ -30,7 +30,7 @@ HunyuanImage-2.1 comes in the following variants: ## HunyuanImage-2.1 -HunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` has a `guider` component (read more about [Guider](../modular_diffusers/guiders.md)) and does not take a `guidance_scale` parameter at runtime. To change guider-related parameters, e.g., `guidance_scale`, you can update the `guider` configuration instead. +HunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` has a `guider` component (read more about [Guider](../../using-diffusers/guiders)) and does not take a `guidance_scale` parameter at runtime. To change guider-related parameters, e.g., `guidance_scale`, you can update the `guider` configuration instead. 
```python import torch diff --git a/docs/source/en/modular_diffusers/modular_pipeline.md b/docs/source/en/modular_diffusers/modular_pipeline.md index e28e13ed5655..27bc61634805 100644 --- a/docs/source/en/modular_diffusers/modular_pipeline.md +++ b/docs/source/en/modular_diffusers/modular_pipeline.md @@ -338,7 +338,7 @@ guider = ClassifierFreeGuidance(guidance_scale=5.0) pipeline.update_components(guider=guider) ``` -See the [Guiders](./guiders) guide for more details on available guiders and how to configure them. +See the [Guiders](../using-diffusers/guiders) guide for more details on available guiders and how to configure them. ## Splitting a pipeline into stages diff --git a/docs/source/en/modular_diffusers/overview.md b/docs/source/en/modular_diffusers/overview.md index 83975200d664..159a1e2ff9e6 100644 --- a/docs/source/en/modular_diffusers/overview.md +++ b/docs/source/en/modular_diffusers/overview.md @@ -39,7 +39,7 @@ The Modular Diffusers docs are organized as shown below. - [ModularPipeline](./modular_pipeline) shows you how to create and convert pipeline blocks into an executable [`ModularPipeline`]. - [ComponentsManager](./components_manager) shows you how to manage and reuse components across multiple pipelines. -- [Guiders](./guiders) shows you how to use different guidance methods in the pipeline. +- [Guiders](../using-diffusers/guiders) shows you how to use different guidance methods in the pipeline. 
## Mellon Integration diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index 611e07ec7655..5212b70c9cea 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -482,144 +482,6 @@ print( ) # (2880, 1, 960, 320) having a stride of 1 for the 2nd dimension proves that it works ``` -## torch.jit.trace - -[torch.jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) records the operations a model performs on a sample input and creates a new, optimized representation of the model based on the recorded execution path. During tracing, the model is optimized to reduce overhead from Python and dynamic control flows and operations are fused together for more efficiency. The returned executable or [ScriptFunction](https://pytorch.org/docs/stable/generated/torch.jit.ScriptFunction.html) can be compiled. - -```py -import time -import torch -from diffusers import StableDiffusionPipeline -import functools - -# torch disable grad -torch.set_grad_enabled(False) - -# set variables -n_experiments = 2 -unet_runs_per_experiment = 50 - -# load sample inputs -def generate_inputs(): - sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16) - timestep = torch.rand(1, device="cuda", dtype=torch.float16) * 999 - encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16) - return sample, timestep, encoder_hidden_states - - -pipeline = StableDiffusionPipeline.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - torch_dtype=torch.float16, - use_safetensors=True, -).to("cuda") -unet = pipeline.unet -unet.eval() -unet.to(memory_format=torch.channels_last) # use channels_last memory format -unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default - -# warmup -for _ in range(3): - with torch.inference_mode(): - inputs = generate_inputs() - orig_output = unet(*inputs) - -# trace -print("tracing..") 
-unet_traced = torch.jit.trace(unet, inputs) -unet_traced.eval() -print("done tracing") - -# warmup and optimize graph -for _ in range(5): - with torch.inference_mode(): - inputs = generate_inputs() - orig_output = unet_traced(*inputs) - -# benchmarking -with torch.inference_mode(): - for _ in range(n_experiments): - torch.cuda.synchronize() - start_time = time.time() - for _ in range(unet_runs_per_experiment): - orig_output = unet_traced(*inputs) - torch.cuda.synchronize() - print(f"unet traced inference took {time.time() - start_time:.2f} seconds") - for _ in range(n_experiments): - torch.cuda.synchronize() - start_time = time.time() - for _ in range(unet_runs_per_experiment): - orig_output = unet(*inputs) - torch.cuda.synchronize() - print(f"unet inference took {time.time() - start_time:.2f} seconds") - -# save the model -unet_traced.save("unet_traced.pt") -``` - -Replace the pipeline's UNet with the traced version. - -```py -import torch -from diffusers import StableDiffusionPipeline -from dataclasses import dataclass - -@dataclass -class UNet2DConditionOutput: - sample: torch.Tensor - -pipeline = StableDiffusionPipeline.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - torch_dtype=torch.float16, - use_safetensors=True, -).to("cuda") - -# use jitted unet -unet_traced = torch.jit.load("unet_traced.pt") - -# del pipeline.unet -class TracedUNet(torch.nn.Module): - def __init__(self): - super().__init__() - self.in_channels = pipe.unet.config.in_channels - self.device = pipe.unet.device - - def forward(self, latent_model_input, t, encoder_hidden_states): - sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0] - return UNet2DConditionOutput(sample=sample) - -pipeline.unet = TracedUNet() - -with torch.inference_mode(): - image = pipe([prompt] * 1, num_inference_steps=50).images[0] -``` - ## Memory-efficient attention -> [!TIP] -> Memory-efficient attention optimizes for memory usage *and* [inference 
speed](./fp16#scaled-dot-product-attention)! - -The Transformers attention mechanism is memory-intensive, especially for long sequences, so you can try using different and more memory-efficient attention types. - -By default, if PyTorch >= 2.0 is installed, [scaled dot-product attention (SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) is used. You don't need to make any additional changes to your code. - -SDPA supports [FlashAttention](https://github.com/Dao-AILab/flash-attention) and [xFormers](https://github.com/facebookresearch/xformers) as well as a native C++ PyTorch implementation. It automatically selects the most optimal implementation based on your input. - -You can explicitly use xFormers with the [`~ModelMixin.enable_xformers_memory_efficient_attention`] method. - -```py -# pip install xformers -import torch -from diffusers import StableDiffusionXLPipeline - -pipeline = StableDiffusionXLPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", - torch_dtype=torch.float16, -).to("cuda") -pipeline.enable_xformers_memory_efficient_attention() -``` - -Call [`~ModelMixin.disable_xformers_memory_efficient_attention`] to disable it. - -```py -pipeline.disable_xformers_memory_efficient_attention() -``` \ No newline at end of file +Diffusers supports multiple memory-efficient attention backends (FlashAttention, xFormers, SageAttention, and more) through [`~ModelMixin.set_attention_backend`]. Refer to the [Attention backends](./attention_backends) guide to learn how to switch between them. diff --git a/docs/source/en/optimization/xformers.md b/docs/source/en/optimization/xformers.md index 523e81559547..a5ef4c6fbdb9 100644 --- a/docs/source/en/optimization/xformers.md +++ b/docs/source/en/optimization/xformers.md @@ -23,7 +23,7 @@ pip install xformers > [!TIP] > The xFormers `pip` package requires the latest version of PyTorch. 
If you need to use a previous version of PyTorch, then we recommend [installing xFormers from the source](https://github.com/facebookresearch/xformers#installing-xformers). -After xFormers is installed, you can use `enable_xformers_memory_efficient_attention()` for faster inference and reduced memory consumption as shown in this [section](memory#memory-efficient-attention). +After xFormers is installed, you can use it with [`~ModelMixin.set_attention_backend`] as shown in the [Attention backends](./attention_backends) guide. > [!WARNING] > According to this [issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training (fine-tune or DreamBooth) in some GPUs. If you observe this problem, please install a development version as indicated in the issue comments. diff --git a/docs/source/en/modular_diffusers/guiders.md b/docs/source/en/using-diffusers/guiders.md similarity index 100% rename from docs/source/en/modular_diffusers/guiders.md rename to docs/source/en/using-diffusers/guiders.md diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml index ab9eaf6ec7fb..af51506746b2 100644 --- a/docs/source/zh/_toctree.yml +++ b/docs/source/zh/_toctree.yml @@ -14,6 +14,8 @@ sections: - local: using-diffusers/schedulers title: Load schedulers and models + - local: using-diffusers/guiders + title: Guiders - title: Inference isExpanded: false @@ -80,8 +82,6 @@ title: ModularPipeline - local: modular_diffusers/components_manager title: ComponentsManager - - local: modular_diffusers/guiders - title: Guiders - title: Training isExpanded: false diff --git a/docs/source/zh/modular_diffusers/guiders.md b/docs/source/zh/using-diffusers/guiders.md similarity index 100% rename from docs/source/zh/modular_diffusers/guiders.md rename to docs/source/zh/using-diffusers/guiders.md From 8f7f3f3324850ead72d6e63067e232c18053df81 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Tue, 17 Mar 2026 11:22:15 +0800 
Subject: [PATCH 048/215] fix parallelism case failure in xpu (#13270) * fix parallelism case failure in xpu Signed-off-by: Wang, Yi * updated Signed-off-by: Wang, Yi --------- Signed-off-by: Wang, Yi Co-authored-by: Sayak Paul --- tests/models/testing_utils/parallelism.py | 33 ++++++++++++++++++----- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index 3858acf71ec5..db9817c86995 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -26,9 +26,17 @@ from ...testing_utils import ( is_context_parallel, require_torch_multi_accelerator, + torch_device, ) +# Device configuration mapping +DEVICE_CONFIG = { + "cuda": {"backend": "nccl", "module": torch.cuda}, + "xpu": {"backend": "xccl", "module": torch.xpu}, +} + + def _find_free_port(): """Find a free port on localhost.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -47,12 +55,17 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di os.environ["RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) + # Get device configuration + device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"]) + backend = device_config["backend"] + device_module = device_config["module"] + # Initialize process group - dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + dist.init_process_group(backend=backend, rank=rank, world_size=world_size) # Set device for this process - torch.cuda.set_device(rank) - device = torch.device(f"cuda:{rank}") + device_module.set_device(rank) + device = torch.device(f"{torch_device}:{rank}") # Create model model = model_class(**init_dict) @@ -103,10 +116,16 @@ def _custom_mesh_worker( os.environ["RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) - dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + # Get device configuration + 
device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"]) + backend = device_config["backend"] + device_module = device_config["module"] - torch.cuda.set_device(rank) - device = torch.device(f"cuda:{rank}") + dist.init_process_group(backend=backend, rank=rank, world_size=world_size) + + # Set device for this process + device_module.set_device(rank) + device = torch.device(f"{torch_device}:{rank}") model = model_class(**init_dict) model.to(device) @@ -116,7 +135,7 @@ def _custom_mesh_worker( # DeviceMesh must be created after init_process_group, inside each worker process. mesh = torch.distributed.device_mesh.init_device_mesh( - "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names + torch_device, mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names ) cp_config = ContextParallelConfig(**cp_dict, mesh=mesh) model.enable_parallelism(config=cp_config) From b804c66364729ea22e8e238e2eae12f329d1131a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 17 Mar 2026 09:47:53 +0530 Subject: [PATCH 049/215] [Modular] Fix dtype assignment when type hint is AutoModel (#13271) * update * update --- .../modular_pipelines/modular_pipeline_utils.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index fa82f17a9108..656ab253ccc2 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -309,16 +309,16 @@ def load(self, **kwargs) -> Any: f"`type_hint` is required when loading a single file model but is missing for component: {self.name}" ) + from diffusers import AutoModel + # `torch_dtype` is not an accepted parameter for tokenizers and processors. # As a result, it gets stored in `init_kwargs`, which are written to the config # during save. This causes JSON serialization to fail when saving the component. 
- if self.type_hint is not None and not issubclass(self.type_hint, torch.nn.Module): + if self.type_hint is not None and not issubclass(self.type_hint, (torch.nn.Module, AutoModel)): kwargs.pop("torch_dtype", None) if self.type_hint is None: try: - from diffusers import AutoModel - component = AutoModel.from_pretrained(pretrained_model_name_or_path, **load_kwargs, **kwargs) except Exception as e: raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}") @@ -332,12 +332,6 @@ def load(self, **kwargs) -> Any: else getattr(self.type_hint, "from_pretrained") ) - # `torch_dtype` is not an accepted parameter for tokenizers and processors. - # As a result, it gets stored in `init_kwargs`, which are written to the config - # during save. This causes JSON serialization to fail when saving the component. - if not issubclass(self.type_hint, torch.nn.Module): - kwargs.pop("torch_dtype", None) - try: component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs) except Exception as e: From 2841a0540d84b3ef8244a39cd727d95350120b3d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 17 Mar 2026 10:11:47 +0530 Subject: [PATCH 050/215] [tests] fix llava kwargs in the hunyuan tests (#13275) fix llava kwargs in the hunyuan tests --- tests/pipelines/hunyuan_video/test_hunyuan_image2video.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py index 2a28e5e42f7d..1732ac06d1f1 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py @@ -139,7 +139,9 @@ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): num_hidden_layers=2, image_size=224, ) - llava_text_encoder_config = LlavaConfig(vision_config, text_config, pad_token_id=100, image_token_index=101) + llava_text_encoder_config = LlavaConfig( + 
vision_config=vision_config, text_config=text_config, pad_token_id=100, image_token_index=101 + ) clip_text_encoder_config = CLIPTextConfig( bos_token_id=0, From ac95d1eb6b4f8272f551b6b6e134ffb5c10163ba Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 17 Mar 2026 16:44:04 +0530 Subject: [PATCH 051/215] [CI] Qwen Image Model Test Refactor (#13069) * update * update * update --------- Co-authored-by: Sayak Paul --- .../test_models_transformer_qwenimage.py | 352 +++++++++++------- 1 file changed, 210 insertions(+), 142 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index e6b19377b14f..713a1bec70a5 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2025 HuggingFace Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,49 +12,84 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest +import warnings import torch from diffusers import QwenImageTransformer2DModel from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask +from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + BitsAndBytesTesterMixin, + ContextParallelTesterMixin, + LoraHotSwappingForModelTesterMixin, + LoraTesterMixin, + MemoryTesterMixin, + ModelTesterMixin, + TorchAoTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) enable_full_determinism() -class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase): - model_class = QwenImageTransformer2DModel - main_input_name = "hidden_states" - # We override the items here because the transformer under consideration is small. - model_split_percents = [0.7, 0.6, 0.6] - - # Skip setting testing with default: AttnProcessor - uses_custom_attn_processor = True - +class QwenImageTransformerTesterConfig(BaseModelTesterConfig): @property - def dummy_input(self): - return self.prepare_dummy_input() + def model_class(self): + return QwenImageTransformer2DModel @property - def input_shape(self): + def output_shape(self) -> tuple[int, int]: return (16, 16) @property - def output_shape(self): + def input_shape(self) -> tuple[int, int]: return (16, 16) - def prepare_dummy_input(self, height=4, width=4): + @property + def model_split_percents(self) -> list: + return [0.7, 0.6, 0.6] + + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict[str, int | list[int]]: + return { + "patch_size": 2, + "in_channels": 16, + "out_channels": 4, + "num_layers": 2, + "attention_head_dim": 16, + "num_attention_heads": 4, + 
"joint_attention_dim": 16, + "guidance_embeds": False, + "axes_dims_rope": (8, 4, 4), + } + + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: batch_size = 1 num_latent_channels = embedding_dim = 16 - sequence_length = 7 + height = width = 4 + sequence_length = 8 vae_scale_factor = 4 - hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + hidden_states = randn_tensor( + (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ) encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long) timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) orig_height = height * 2 * vae_scale_factor @@ -70,89 +104,57 @@ def prepare_dummy_input(self, height=4, width=4): "img_shapes": img_shapes, } - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "patch_size": 2, - "in_channels": 16, - "out_channels": 4, - "num_layers": 2, - "attention_head_dim": 16, - "num_attention_heads": 3, - "joint_attention_dim": 16, - "guidance_embeds": False, - "axes_dims_rope": (8, 4, 4), - } - - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_gradient_checkpointing_is_applied(self): - expected_set = {"QwenImageTransformer2DModel"} - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) +class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin): def test_infers_text_seq_len_from_mask(self): - """Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors.""" - init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() model = 
self.model_class(**init_dict).to(torch_device) - # Test 1: Contiguous mask with padding at the end (only first 2 tokens valid) encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() - encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid + encoder_hidden_states_mask[:, 2:] = 0 rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( inputs["encoder_hidden_states"], encoder_hidden_states_mask ) - # Verify rope_text_seq_len is returned as an int (for torch.compile compatibility) - self.assertIsInstance(rope_text_seq_len, int) - - # Verify per_sample_len is computed correctly (max valid position + 1 = 2) - self.assertIsInstance(per_sample_len, torch.Tensor) - self.assertEqual(int(per_sample_len.max().item()), 2) - - # Verify mask is normalized to bool dtype - self.assertTrue(normalized_mask.dtype == torch.bool) - self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values - - # Verify rope_text_seq_len is at least the sequence length - self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1]) + assert isinstance(rope_text_seq_len, int) + assert isinstance(per_sample_len, torch.Tensor) + assert int(per_sample_len.max().item()) == 2 + assert normalized_mask.dtype == torch.bool + assert normalized_mask.sum().item() == 2 + assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1] - # Test 2: Verify model runs successfully with inferred values inputs["encoder_hidden_states_mask"] = normalized_mask with torch.no_grad(): output = model(**inputs) - self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + assert output.sample.shape[1] == inputs["hidden_states"].shape[1] - # Test 3: Different mask pattern (padding at beginning) encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone() - encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding - encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid + 
encoder_hidden_states_mask2[:, :3] = 0 + encoder_hidden_states_mask2[:, 3:] = 1 rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask( inputs["encoder_hidden_states"], encoder_hidden_states_mask2 ) - # Max valid position is 6 (last token), so per_sample_len should be 7 - self.assertEqual(int(per_sample_len2.max().item()), 7) - self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values + assert int(per_sample_len2.max().item()) == 8 + assert normalized_mask2.sum().item() == 5 - # Test 4: No mask provided (None case) rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask( inputs["encoder_hidden_states"], None ) - self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1]) - self.assertIsInstance(rope_text_seq_len_none, int) - self.assertIsNone(per_sample_len_none) - self.assertIsNone(normalized_mask_none) + assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1] + assert isinstance(rope_text_seq_len_none, int) + assert per_sample_len_none is None + assert normalized_mask_none is None def test_non_contiguous_attention_mask(self): - """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])""" - init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) - # Create a non-contiguous mask pattern: valid, padding, valid, padding, etc. 
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() - # Pattern: [True, False, True, False, True, False, False] encoder_hidden_states_mask[:, 1] = 0 encoder_hidden_states_mask[:, 3] = 0 encoder_hidden_states_mask[:, 5:] = 0 @@ -160,95 +162,85 @@ def test_non_contiguous_attention_mask(self): inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( inputs["encoder_hidden_states"], encoder_hidden_states_mask ) - self.assertEqual(int(per_sample_len.max().item()), 5) - self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1]) - self.assertIsInstance(inferred_rope_len, int) - self.assertTrue(normalized_mask.dtype == torch.bool) + assert int(per_sample_len.max().item()) == 5 + assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1] + assert isinstance(inferred_rope_len, int) + assert normalized_mask.dtype == torch.bool inputs["encoder_hidden_states_mask"] = normalized_mask with torch.no_grad(): output = model(**inputs) - self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + assert output.sample.shape[1] == inputs["hidden_states"].shape[1] def test_txt_seq_lens_deprecation(self): - """Test that passing txt_seq_lens raises a deprecation warning.""" - init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) - # Prepare inputs with txt_seq_lens (deprecated parameter) txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]] - # Remove encoder_hidden_states_mask to use the deprecated path inputs_with_deprecated = inputs.copy() inputs_with_deprecated.pop("encoder_hidden_states_mask") inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens - # Test that deprecation warning is raised - with self.assertWarns(FutureWarning) as warning_context: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") with torch.no_grad(): 
output = model(**inputs_with_deprecated) - # Verify the warning message mentions the deprecation - warning_message = str(warning_context.warning) - self.assertIn("txt_seq_lens", warning_message) - self.assertIn("deprecated", warning_message) - self.assertIn("encoder_hidden_states_mask", warning_message) + future_warnings = [x for x in w if issubclass(x.category, FutureWarning)] + assert len(future_warnings) > 0, "Expected FutureWarning to be raised" - # Verify the model still works correctly despite the deprecation - self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + warning_message = str(future_warnings[0].message) + assert "txt_seq_lens" in warning_message + assert "deprecated" in warning_message + + assert output.sample.shape[1] == inputs["hidden_states"].shape[1] def test_layered_model_with_mask(self): - """Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model).""" - # Create layered model config init_dict = { "patch_size": 2, "in_channels": 16, "out_channels": 4, "num_layers": 2, "attention_head_dim": 16, - "num_attention_heads": 3, + "num_attention_heads": 4, "joint_attention_dim": 16, - "axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16) - "use_layer3d_rope": True, # Enable layered RoPE - "use_additional_t_cond": True, # Enable additional time conditioning + "axes_dims_rope": (8, 4, 4), + "use_layer3d_rope": True, + "use_additional_t_cond": True, } model = self.model_class(**init_dict).to(torch_device) - # Verify the model uses QwenEmbedLayer3DRope from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope - self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope) + assert isinstance(model.pos_embed, QwenEmbedLayer3DRope) - # Test single generation with layered structure batch_size = 1 - text_seq_len = 7 + text_seq_len = 8 img_h, img_w = 4, 4 layers = 4 - # For layered model: (layers + 1) because we have N layers + 1 combined image hidden_states = 
torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device) encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device) - # Create mask with some padding encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device) - encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens + encoder_hidden_states_mask[0, 5:] = 0 timestep = torch.tensor([1.0]).to(torch_device) - # additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding) addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device) - # Layer structure: 4 layers + 1 condition image img_shapes = [ [ - (1, img_h, img_w), # layer 0 - (1, img_h, img_w), # layer 1 - (1, img_h, img_w), # layer 2 - (1, img_h, img_w), # layer 3 - (1, img_h, img_w), # condition image (last one gets special treatment) + (1, img_h, img_w), + (1, img_h, img_w), + (1, img_h, img_w), + (1, img_h, img_w), + (1, img_h, img_w), ] ] @@ -262,37 +254,113 @@ def test_layered_model_with_mask(self): additional_t_cond=addition_t_cond, ) - self.assertEqual(output.sample.shape[1], hidden_states.shape[1]) + assert output.sample.shape[1] == hidden_states.shape[1] + + +class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin): + """Memory optimization tests for QwenImage Transformer.""" + + +class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin): + """Training tests for QwenImage Transformer.""" + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"QwenImageTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin): + """Attention processor tests for QwenImage Transformer.""" -class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = QwenImageTransformer2DModel +class 
TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin): + """Context Parallel inference tests for QwenImage Transformer.""" - def prepare_init_args_and_inputs_for_common(self): - return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common() - def prepare_dummy_input(self, height, width): - return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width) +class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin): + """LoRA adapter tests for QwenImage Transformer.""" - def test_torch_compile_recompilation_and_graph_break(self): - super().test_torch_compile_recompilation_and_graph_break() + +class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin): + """LoRA hot-swapping tests for QwenImage Transformer.""" + + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: + batch_size = 1 + num_latent_channels = embedding_dim = 16 + sequence_length = 8 + vae_scale_factor = 4 + + hidden_states = randn_tensor( + (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ) + encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + orig_height = height * 2 * vae_scale_factor + orig_width = width * 2 * vae_scale_factor + img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_mask": encoder_hidden_states_mask, + "timestep": timestep, + 
"img_shapes": img_shapes, + } + + +class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin): + """Torch compile tests for QwenImage Transformer.""" + + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: + batch_size = 1 + num_latent_channels = embedding_dim = 16 + sequence_length = 8 + vae_scale_factor = 4 + + hidden_states = randn_tensor( + (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ) + encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + orig_height = height * 2 * vae_scale_factor + orig_width = width * 2 * vae_scale_factor + img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_mask": encoder_hidden_states_mask, + "timestep": timestep, + "img_shapes": img_shapes, + } def test_torch_compile_with_and_without_mask(self): - """Test that torch.compile works with both None mask and padding mask.""" - init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) model.eval() model.compile(mode="default", fullgraph=True) - # Test 1: Run with None mask (no padding, all tokens are valid) inputs_no_mask = inputs.copy() inputs_no_mask["encoder_hidden_states_mask"] = None - # First run to allow compilation with torch.no_grad(): output_no_mask = model(**inputs_no_mask) - # Second run to verify no 
recompilation with ( torch._inductor.utils.fresh_inductor_cache(), torch._dynamo.config.patch(error_on_recompile=True), @@ -300,19 +368,15 @@ def test_torch_compile_with_and_without_mask(self): ): output_no_mask_2 = model(**inputs_no_mask) - self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1]) - self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1]) + assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1] + assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1] - # Test 2: Run with all-ones mask (should behave like None) inputs_all_ones = inputs.copy() - # Keep the all-ones mask - self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item()) + assert inputs_all_ones["encoder_hidden_states_mask"].all().item() - # First run to allow compilation with torch.no_grad(): output_all_ones = model(**inputs_all_ones) - # Second run to verify no recompilation with ( torch._inductor.utils.fresh_inductor_cache(), torch._dynamo.config.patch(error_on_recompile=True), @@ -320,21 +384,18 @@ def test_torch_compile_with_and_without_mask(self): ): output_all_ones_2 = model(**inputs_all_ones) - self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1]) - self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1]) + assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1] + assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1] - # Test 3: Run with actual padding mask (has zeros) inputs_with_padding = inputs.copy() mask_with_padding = inputs["encoder_hidden_states_mask"].clone() - mask_with_padding[:, 4:] = 0 # Last 3 tokens are padding + mask_with_padding[:, 4:] = 0 inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding - # First run to allow compilation with torch.no_grad(): output_with_padding = model(**inputs_with_padding) - # Second run to verify no recompilation 
with ( torch._inductor.utils.fresh_inductor_cache(), torch._dynamo.config.patch(error_on_recompile=True), @@ -342,8 +403,15 @@ def test_torch_compile_with_and_without_mask(self): ): output_with_padding_2 = model(**inputs_with_padding) - self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1]) - self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1]) + assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1] + assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1] + + assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3) + + +class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin): + """BitsAndBytes quantization tests for QwenImage Transformer.""" + - # Verify that outputs are different (mask should affect results) - self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)) +class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin): + """TorchAO quantization tests for QwenImage Transformer.""" From 4ec59b650299e4835e3a877f883a40b6b3114c9a Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 18 Mar 2026 12:09:52 +0800 Subject: [PATCH 052/215] add ltx2 vae in sana-video; (#13229) * add ltx2 vae in sana-video; * add ltx vae in conversion script; * Update src/diffusers/pipelines/sana_video/pipeline_sana_video.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/sana_video/pipeline_sana_video.py Co-authored-by: YiYi Xu * condition `vae_scale_factor_xxx` related settings on VAE types; * make the mean/std depends on vae class; --------- Co-authored-by: YiYi Xu --- scripts/convert_sana_video_to_diffusers.py | 49 +++++++++++--- .../sana_video/pipeline_sana_video.py | 44 +++++++++---- .../sana_video/pipeline_sana_video_i2v.py | 64 +++++++++++++------ 3 files changed, 115 insertions(+), 42 deletions(-) diff 
--git a/scripts/convert_sana_video_to_diffusers.py b/scripts/convert_sana_video_to_diffusers.py index a939a06cbd46..c6be52d455b8 100644 --- a/scripts/convert_sana_video_to_diffusers.py +++ b/scripts/convert_sana_video_to_diffusers.py @@ -12,6 +12,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from diffusers import ( + AutoencoderKLLTX2Video, AutoencoderKLWan, DPMSolverMultistepScheduler, FlowMatchEulerDiscreteScheduler, @@ -24,7 +25,10 @@ CTX = init_empty_weights if is_accelerate_available else nullcontext -ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"] +ckpt_ids = [ + "Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth", + "Efficient-Large-Model/SANA-Video_2B_720p/checkpoints/SANA_Video_2B_720p_LTXVAE.pth", +] # https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py @@ -92,12 +96,22 @@ def main(args): if args.video_size == 480: sample_size = 30 # Wan-VAE: 8xp2 downsample factor patch_size = (1, 2, 2) + in_channels = 16 + out_channels = 16 elif args.video_size == 720: - sample_size = 22 # Wan-VAE: 32xp1 downsample factor + sample_size = 22 # DC-AE-V: 32xp1 downsample factor patch_size = (1, 1, 1) + in_channels = 32 + out_channels = 32 else: raise ValueError(f"Video size {args.video_size} is not supported.") + if args.vae_type == "ltx2": + sample_size = 22 + patch_size = (1, 1, 1) + in_channels = 128 + out_channels = 128 + for depth in range(layer_num): # Transformer blocks. 
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop( @@ -182,8 +196,8 @@ def main(args): # Transformer with CTX(): transformer_kwargs = { - "in_channels": 16, - "out_channels": 16, + "in_channels": in_channels, + "out_channels": out_channels, "num_attention_heads": 20, "attention_head_dim": 112, "num_layers": 20, @@ -235,9 +249,12 @@ def main(args): else: print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"])) # VAE - vae = AutoencoderKLWan.from_pretrained( - "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32 - ) + if args.vae_type == "ltx2": + vae_path = args.vae_path or "Lightricks/LTX-2" + vae = AutoencoderKLLTX2Video.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32) + else: + vae_path = args.vae_path or "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" + vae = AutoencoderKLWan.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32) # Text Encoder text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it" @@ -314,7 +331,23 @@ def main(args): choices=["flow-dpm_solver", "flow-euler", "uni-pc"], help="Scheduler type to use.", ) - parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.") + parser.add_argument( + "--vae_type", + default="wan", + type=str, + choices=["wan", "ltx2"], + help="VAE type to use for saving full pipeline (ltx2 uses patchify 1x1x1).", + ) + parser.add_argument( + "--vae_path", + default=None, + type=str, + required=False, + help="Optional VAE path or repo id. If not set, a default is used per VAE type.", + ) + parser.add_argument( + "--task", default="t2v", type=str, required=True, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v." 
+ ) parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.") parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.") diff --git a/src/diffusers/pipelines/sana_video/pipeline_sana_video.py b/src/diffusers/pipelines/sana_video/pipeline_sana_video.py index 8b44dfc1143c..7ae85639e358 100644 --- a/src/diffusers/pipelines/sana_video/pipeline_sana_video.py +++ b/src/diffusers/pipelines/sana_video/pipeline_sana_video.py @@ -24,7 +24,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import SanaLoraLoaderMixin -from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel +from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel from ...schedulers import DPMSolverMultistepScheduler from ...utils import ( BACKENDS_MAPPING, @@ -194,7 +194,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin): The tokenizer used to tokenize the prompt. text_encoder ([`Gemma2PreTrainedModel`]): Text encoder model to encode the input prompts. - vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]): + vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. transformer ([`SanaVideoTransformer3DModel`]): Conditional Transformer to denoise the input latents. 
@@ -213,7 +213,7 @@ def __init__( self, tokenizer: GemmaTokenizer | GemmaTokenizerFast, text_encoder: Gemma2PreTrainedModel, - vae: AutoencoderDC | AutoencoderKLWan, + vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan, transformer: SanaVideoTransformer3DModel, scheduler: DPMSolverMultistepScheduler, ): @@ -223,8 +223,19 @@ def __init__( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) - self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 - self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + if getattr(self, "vae", None): + if isinstance(self.vae, AutoencoderKLLTX2Video): + self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio + self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio + elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)): + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial + else: + self.vae_scale_factor_temporal = 4 + self.vae_scale_factor_spatial = 8 + else: + self.vae_scale_factor_temporal = 4 + self.vae_scale_factor_spatial = 8 self.vae_scale_factor = self.vae_scale_factor_spatial @@ -985,14 +996,21 @@ def __call__( if is_torch_version(">=", "2.5.0") else torch_accelerator_module.OutOfMemoryError ) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype - ) + if isinstance(self.vae, AutoencoderKLLTX2Video): + latents_mean = self.vae.latents_mean + latents_std = self.vae.latents_std + z_dim = self.vae.config.latent_channels + elif isinstance(self.vae, AutoencoderKLWan): + latents_mean = 
torch.tensor(self.vae.config.latents_mean) + latents_std = torch.tensor(self.vae.config.latents_std) + z_dim = self.vae.config.z_dim + else: + latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype) + latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype) + z_dim = latents.shape[1] + + latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype) latents = latents / latents_std + latents_mean try: video = self.vae.decode(latents, return_dict=False)[0] diff --git a/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py b/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py index b90d7c6f5a60..81df1d0759da 100644 --- a/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py +++ b/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py @@ -26,7 +26,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput from ...loaders import SanaLoraLoaderMixin -from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel +from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( BACKENDS_MAPPING, @@ -184,7 +184,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin): The tokenizer used to tokenize the prompt. text_encoder ([`Gemma2PreTrainedModel`]): Text encoder model to encode the input prompts. - vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]): + vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. transformer ([`SanaVideoTransformer3DModel`]): Conditional Transformer to denoise the input latents. 
@@ -203,7 +203,7 @@ def __init__( self, tokenizer: GemmaTokenizer | GemmaTokenizerFast, text_encoder: Gemma2PreTrainedModel, - vae: AutoencoderDC | AutoencoderKLWan, + vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan, transformer: SanaVideoTransformer3DModel, scheduler: FlowMatchEulerDiscreteScheduler, ): @@ -213,8 +213,19 @@ def __init__( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) - self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 - self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + if getattr(self, "vae", None): + if isinstance(self.vae, AutoencoderKLLTX2Video): + self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio + self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio + elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)): + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial + else: + self.vae_scale_factor_temporal = 4 + self.vae_scale_factor_spatial = 8 + else: + self.vae_scale_factor_temporal = 4 + self.vae_scale_factor_spatial = 8 self.vae_scale_factor = self.vae_scale_factor_spatial @@ -687,14 +698,18 @@ def prepare_latents( image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax") image_latents = image_latents.repeat(batch_size, 1, 1, 1, 1) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, -1, 1, 1, 1) - .to(image_latents.device, image_latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to( - image_latents.device, image_latents.dtype - ) + if isinstance(self.vae, AutoencoderKLLTX2Video): + _latents_mean = self.vae.latents_mean + _latents_std = self.vae.latents_std + elif isinstance(self.vae, AutoencoderKLWan): + _latents_mean = 
torch.tensor(self.vae.config.latents_mean) + _latents_std = torch.tensor(self.vae.config.latents_std) + else: + _latents_mean = torch.zeros(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype) + _latents_std = torch.ones(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype) + + latents_mean = _latents_mean.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_std = 1.0 / _latents_std.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype) image_latents = (image_latents - latents_mean) * latents_std latents[:, :, 0:1] = image_latents.to(dtype) @@ -1034,14 +1049,21 @@ def __call__( if is_torch_version(">=", "2.5.0") else torch_accelerator_module.OutOfMemoryError ) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype - ) + if isinstance(self.vae, AutoencoderKLLTX2Video): + latents_mean = self.vae.latents_mean + latents_std = self.vae.latents_std + z_dim = self.vae.config.latent_channels + elif isinstance(self.vae, AutoencoderKLWan): + latents_mean = torch.tensor(self.vae.config.latents_mean) + latents_std = torch.tensor(self.vae.config.latents_std) + z_dim = self.vae.config.z_dim + else: + latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype) + latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype) + z_dim = latents.shape[1] + + latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype) latents = latents / latents_std + latents_mean try: video = self.vae.decode(latents, return_dict=False)[0] From 6fb3fa758cae8dca8ddb8115473463c01663d13f Mon Sep 17 00:00:00 2001 From: 
kaixuanliu Date: Wed, 18 Mar 2026 14:58:35 +0800 Subject: [PATCH 053/215] skip invalid test case for helios pipeline (#13218) * skip invalid test case for helio pipeline Signed-off-by: Liu, Kaixuan * update skip reason Signed-off-by: Liu, Kaixuan --------- Signed-off-by: Liu, Kaixuan --- tests/pipelines/helios/test_helios.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/helios/test_helios.py b/tests/pipelines/helios/test_helios.py index b8ee99085036..93f80b31c5fd 100644 --- a/tests/pipelines/helios/test_helios.py +++ b/tests/pipelines/helios/test_helios.py @@ -139,9 +139,9 @@ def test_inference(self): generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) - # Override to set a more lenient max diff threshold. + @unittest.skip("Helios uses a lot of mixed precision internally, which is not suitable for this test case") def test_save_load_float16(self): - super().test_save_load_float16(expected_max_diff=0.03) + pass @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self): From 7f457987b369d89d8b02eba9b9e37a7ec664d3ad Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Thu, 19 Mar 2026 02:11:58 +0800 Subject: [PATCH 054/215] [Helios] Remove lru_cache for better AoTI compatibility and cleaner code (#13282) fix: drop lru_cache for better AoTI compatibility --- src/diffusers/models/transformers/transformer_helios.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 6d81f8a13af7..922b0724c87e 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -13,7 +13,6 @@ # limitations under the License. 
import math -from functools import lru_cache from typing import Any import torch @@ -343,7 +342,6 @@ def get_frequency_batched(self, freqs_base, pos): return freqs.cos(), freqs.sin() @torch.no_grad() - @lru_cache(maxsize=32) def _get_spatial_meshgrid(self, height, width, device_str): device = torch.device(device_str) grid_y_coords = torch.arange(height, device=device, dtype=torch.float32) From 00efd936f05c51e55cea564f78cf935b10bf68e2 Mon Sep 17 00:00:00 2001 From: kaixuanliu Date: Thu, 19 Mar 2026 11:55:08 +0800 Subject: [PATCH 055/215] =?UTF-8?q?fix:=20'PaintByExampleImageEncoder'=20o?= =?UTF-8?q?bject=20has=20no=20attribute=20'all=5Ftied=5Fw=E2=80=A6=20(#132?= =?UTF-8?q?52)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: 'PaintByExampleImageEncoder' object has no attribute 'all_tied_weights_keys' Signed-off-by: Liu, Kaixuan * also fix LDMBertModel Signed-off-by: Liu, Kaixuan --------- Signed-off-by: Liu, Kaixuan Co-authored-by: YiYi Xu --- .../pipelines/latent_diffusion/pipeline_latent_diffusion.py | 1 + src/diffusers/pipelines/paint_by_example/image_encoder.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index ec43988f9389..458e6dbfe7d2 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -720,6 +720,7 @@ def __init__(self, config: LDMBertConfig): super().__init__(config) self.model = LDMBertEncoder(config) self.to_logits = nn.Linear(config.hidden_size, config.vocab_size) + self.post_init() def forward( self, diff --git a/src/diffusers/pipelines/paint_by_example/image_encoder.py b/src/diffusers/pipelines/paint_by_example/image_encoder.py index 74c575ed8653..da1273bcdd52 100644 --- a/src/diffusers/pipelines/paint_by_example/image_encoder.py +++ 
b/src/diffusers/pipelines/paint_by_example/image_encoder.py @@ -35,6 +35,8 @@ def __init__(self, config, proj_size=None): # uncondition for scaling self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size))) + self.post_init() + def forward(self, pixel_values, return_uncond_vector=False): clip_output = self.model(pixel_values=pixel_values) latent_states = clip_output.pooler_output From 9c5e41e07f3765f479312f3fbebcf4f86adb4398 Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Thu, 19 Mar 2026 14:58:29 -0700 Subject: [PATCH 056/215] Add Support for LTX-2.3 Models (#13217) * Initial implementation of perturbed attn processor for LTX 2.3 * Update DiT block for LTX 2.3 + add self_attention_mask * Add flag to control using perturbed attn processor for now * Add support for new video upsampling blocks used by LTX-2.3 * Support LTX-2.3 Big-VGAN V2-style vocoder * Initial implementation of LTX-2.3 vocoder with bandwidth extender * Initial support for LTX-2.3 per-modality feature extractor * Refactor so that text connectors own all text encoder hidden_states normalization logic * Fix some bugs for inference * Fix LTX-2.X DiT block forward pass * Support prompt timestep embeds and prompt cross attn modulation * Add LTX-2.3 configs to conversion script * Support converting LTX-2.3 DiT checkpoints * Support converting LTX-2.3 Video VAE checkpoints * Support converting LTX-2.3 Vocoder with bandwidth extender * Support converting LTX-2.3 text connectors * Don't convert any upsamplers for now * Support self attention mask for LTX2Pipeline * Fix some inference bugs * Support self attn mask and sigmas for LTX-2.3 I2V, Cond pipelines * Support STG and modality isolation guidance for LTX-2.3 * make style and make quality * Make audio guidance values default to video values by default * Update to LTX-2.3 style guidance rescaling * Support cross timesteps for LTX-2.3 cross attention modulation * Fix RMS norm bug for LTX-2.3 text connectors * 
Perform guidance rescale in sample (x0) space following original code * Support LTX-2.3 Latent Spatial Upsampler model * Support LTX-2.3 distilled LoRA * Support LTX-2.3 Distilled checkpoint * Support LTX-2.3 prompt enhancement * Make LTX-2.X processor non-required so that tests pass * Fix test_components_function tests for LTX2 T2V and I2V * Fix LTX-2.3 Video VAE configuration bug causing pixel jitter * Apply suggestions from code review Co-authored-by: Sayak Paul * Refactor LTX-2.X Video VAE upsampler block init logic * Refactor LTX-2.X guidance rescaling to use rescale_noise_cfg * Use generator initial seed to control prompt enhancement if available * Remove self attention mask logic as it is not used in any current pipelines * Commit fixes suggested by claude code (guidance in sample (x0) space, denormalize after timestep conditioning) * Use constant shift following original code --------- Co-authored-by: Sayak Paul --- scripts/convert_ltx2_to_diffusers.py | 378 +++++++++++-- .../loaders/lora_conversion_utils.py | 3 + .../autoencoders/autoencoder_kl_ltx2.py | 82 ++- .../models/transformers/transformer_ltx2.py | 529 ++++++++++++++---- src/diffusers/pipelines/ltx2/__init__.py | 4 +- src/diffusers/pipelines/ltx2/connectors.py | 224 ++++++-- .../pipelines/ltx2/latent_upsampler.py | 5 +- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 475 ++++++++++++---- .../pipelines/ltx2/pipeline_ltx2_condition.py | 408 ++++++++++---- .../ltx2/pipeline_ltx2_image2video.py | 487 ++++++++++++---- src/diffusers/pipelines/ltx2/vocoder.py | 463 ++++++++++++++- tests/pipelines/ltx2/test_ltx2.py | 1 + tests/pipelines/ltx2/test_ltx2_image2video.py | 1 + 13 files changed, 2494 insertions(+), 566 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 9dee954af6d0..f1556557889f 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -7,7 +7,7 @@ import torch from accelerate import init_empty_weights 
from huggingface_hub import hf_hub_download -from transformers import AutoTokenizer, Gemma3ForConditionalGeneration +from transformers import AutoTokenizer, Gemma3ForConditionalGeneration, Gemma3Processor from diffusers import ( AutoencoderKLLTX2Audio, @@ -17,7 +17,7 @@ LTX2Pipeline, LTX2VideoTransformer3DModel, ) -from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder +from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder, LTX2VocoderWithBWE from diffusers.utils.import_utils import is_accelerate_available @@ -44,6 +44,12 @@ "k_norm": "norm_k", } +LTX_2_3_TRANSFORMER_KEYS_RENAME_DICT = { + **LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT, + "audio_prompt_adaln_single": "audio_prompt_adaln", + "prompt_adaln_single": "prompt_adaln", +} + LTX_2_0_VIDEO_VAE_RENAME_DICT = { # Encoder "down_blocks.0": "down_blocks.0", @@ -72,6 +78,13 @@ "per_channel_statistics.std-of-means": "latents_std", } +LTX_2_3_VIDEO_VAE_RENAME_DICT = { + **LTX_2_0_VIDEO_VAE_RENAME_DICT, + # Decoder extra blocks + "up_blocks.7": "up_blocks.3.upsamplers.0", + "up_blocks.8": "up_blocks.3", +} + LTX_2_0_AUDIO_VAE_RENAME_DICT = { "per_channel_statistics.mean-of-means": "latents_mean", "per_channel_statistics.std-of-means": "latents_std", @@ -84,10 +97,34 @@ "conv_post": "conv_out", } -LTX_2_0_TEXT_ENCODER_RENAME_DICT = { +LTX_2_3_VOCODER_RENAME_DICT = { + # Handle upsamplers ("ups" --> "upsamplers") due to name clash + "resblocks": "resnets", + "conv_pre": "conv_in", + "conv_post": "conv_out", + "act_post": "act_out", + "downsample.lowpass": "downsample", +} + +LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = { + "connectors.": "", "video_embeddings_connector": "video_connector", "audio_embeddings_connector": "audio_connector", "transformer_1d_blocks": "transformer_blocks", + "text_embedding_projection.aggregate_embed": "text_proj_in", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", +} + +LTX_2_3_CONNECTORS_KEYS_RENAME_DICT 
= { + "connectors.": "", + "video_embeddings_connector": "video_connector", + "audio_embeddings_connector": "audio_connector", + "transformer_1d_blocks": "transformer_blocks", + # LTX-2.3 uses per-modality embedding projections + "text_embedding_projection.audio_aggregate_embed": "audio_text_proj_in", + "text_embedding_projection.video_aggregate_embed": "video_text_proj_in", # Attention QK Norms "q_norm": "norm_q", "k_norm": "norm_k", @@ -129,23 +166,24 @@ def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: dict[str return +def convert_ltx2_3_vocoder_upsamplers(key: str, state_dict: dict[str, Any]) -> None: + # Skip if not a weight, bias + if ".weight" not in key and ".bias" not in key: + return + + if ".ups." in key: + new_key = key.replace(".ups.", ".upsamplers.") + param = state_dict.pop(key) + state_dict[new_key] = param + return + + LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = { "video_embeddings_connector": remove_keys_inplace, "audio_embeddings_connector": remove_keys_inplace, "adaln_single": convert_ltx2_transformer_adaln_single, } -LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = { - "connectors.": "", - "video_embeddings_connector": "video_connector", - "audio_embeddings_connector": "audio_connector", - "transformer_1d_blocks": "transformer_blocks", - "text_embedding_projection.aggregate_embed": "text_proj_in", - # Attention QK Norms - "q_norm": "norm_q", - "k_norm": "norm_k", -} - LTX_2_0_VAE_SPECIAL_KEYS_REMAP = { "per_channel_statistics.channel": remove_keys_inplace, "per_channel_statistics.mean-of-stds": remove_keys_inplace, @@ -155,13 +193,19 @@ def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: dict[str LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {} +LTX_2_3_VOCODER_SPECIAL_KEYS_REMAP = { + ".ups.": convert_ltx2_3_vocoder_upsamplers, +} + +LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP = {} + def split_transformer_and_connector_state_dict(state_dict: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: connector_prefixes = ( 
"video_embeddings_connector", "audio_embeddings_connector", "transformer_1d_blocks", - "text_embedding_projection.aggregate_embed", + "text_embedding_projection", "connectors.", "video_connector", "audio_connector", @@ -225,7 +269,7 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str, special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP elif version == "2.0": config = { - "model_id": "diffusers-internal-dev/new-ltx-model", + "model_id": "Lightricks/LTX-2", "diffusers_config": { "in_channels": 128, "out_channels": 128, @@ -238,6 +282,8 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str, "pos_embed_max_pos": 20, "base_height": 2048, "base_width": 2048, + "gated_attn": False, + "cross_attn_mod": False, "audio_in_channels": 128, "audio_out_channels": 128, "audio_patch_size": 1, @@ -249,6 +295,8 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str, "audio_pos_embed_max_pos": 20, "audio_sampling_rate": 16000, "audio_hop_length": 160, + "audio_gated_attn": False, + "audio_cross_attn_mod": False, "num_layers": 48, "activation_fn": "gelu-approximate", "qk_norm": "rms_norm_across_heads", @@ -263,10 +311,62 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str, "timestep_scale_multiplier": 1000, "cross_attn_timestep_scale_multiplier": 1000, "rope_type": "split", + "use_prompt_embeddings": True, + "perturbed_attn": False, }, } rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP + elif version == "2.3": + config = { + "model_id": "Lightricks/LTX-2.3", + "diffusers_config": { + "in_channels": 128, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 32, + "attention_head_dim": 128, + "cross_attention_dim": 4096, + "vae_scale_factors": (8, 32, 32), + "pos_embed_max_pos": 20, + "base_height": 2048, + "base_width": 2048, + "gated_attn": True, + "cross_attn_mod": 
True, + "audio_in_channels": 128, + "audio_out_channels": 128, + "audio_patch_size": 1, + "audio_patch_size_t": 1, + "audio_num_attention_heads": 32, + "audio_attention_head_dim": 64, + "audio_cross_attention_dim": 2048, + "audio_scale_factor": 4, + "audio_pos_embed_max_pos": 20, + "audio_sampling_rate": 16000, + "audio_hop_length": 160, + "audio_gated_attn": True, + "audio_cross_attn_mod": True, + "num_layers": 48, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": 3840, + "attention_bias": True, + "attention_out_bias": True, + "rope_theta": 10000.0, + "rope_double_precision": True, + "causal_offset": 1, + "timestep_scale_multiplier": 1000, + "cross_attn_timestep_scale_multiplier": 1000, + "rope_type": "split", + "use_prompt_embeddings": False, + "perturbed_attn": True, + }, + } + rename_dict = LTX_2_3_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP return config, rename_dict, special_keys_remap @@ -293,7 +393,7 @@ def get_ltx2_connectors_config(version: str) -> tuple[dict[str, Any], dict[str, } elif version == "2.0": config = { - "model_id": "diffusers-internal-dev/new-ltx-model", + "model_id": "Lightricks/LTX-2", "diffusers_config": { "caption_channels": 3840, "text_proj_in_factor": 49, @@ -301,20 +401,52 @@ def get_ltx2_connectors_config(version: str) -> tuple[dict[str, Any], dict[str, "video_connector_attention_head_dim": 128, "video_connector_num_layers": 2, "video_connector_num_learnable_registers": 128, + "video_gated_attn": False, "audio_connector_num_attention_heads": 30, "audio_connector_attention_head_dim": 128, "audio_connector_num_layers": 2, "audio_connector_num_learnable_registers": 128, + "audio_gated_attn": False, "connector_rope_base_seq_len": 4096, "rope_theta": 10000.0, "rope_double_precision": True, "causal_temporal_positioning": False, "rope_type": "split", + "per_modality_projections": False, + 
"proj_bias": False, }, } - - rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT - special_keys_remap = {} + rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT + special_keys_remap = LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP + elif version == "2.3": + config = { + "model_id": "Lightricks/LTX-2.3", + "diffusers_config": { + "caption_channels": 3840, + "text_proj_in_factor": 49, + "video_connector_num_attention_heads": 32, + "video_connector_attention_head_dim": 128, + "video_connector_num_layers": 8, + "video_connector_num_learnable_registers": 128, + "video_gated_attn": True, + "audio_connector_num_attention_heads": 32, + "audio_connector_attention_head_dim": 64, + "audio_connector_num_layers": 8, + "audio_connector_num_learnable_registers": 128, + "audio_gated_attn": True, + "connector_rope_base_seq_len": 4096, + "rope_theta": 10000.0, + "rope_double_precision": True, + "causal_temporal_positioning": False, + "rope_type": "split", + "per_modality_projections": True, + "video_hidden_dim": 4096, + "audio_hidden_dim": 2048, + "proj_bias": True, + }, + } + rename_dict = LTX_2_3_CONNECTORS_KEYS_RENAME_DICT + special_keys_remap = LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP return config, rename_dict, special_keys_remap @@ -416,7 +548,7 @@ def get_ltx2_video_vae_config( special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP elif version == "2.0": config = { - "model_id": "diffusers-internal-dev/dummy-ltx2", + "model_id": "Lightricks/LTX-2", "diffusers_config": { "in_channels": 3, "out_channels": 3, @@ -435,6 +567,7 @@ def get_ltx2_video_vae_config( "decoder_spatio_temporal_scaling": (True, True, True), "decoder_inject_noise": (False, False, False, False), "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_type": ("spatiotemporal", "spatiotemporal", "spatiotemporal"), "upsample_residual": (True, True, True), "upsample_factor": (2, 2, 2), "timestep_conditioning": timestep_conditioning, @@ -451,6 +584,44 @@ def get_ltx2_video_vae_config( } 
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP + elif version == "2.3": + config = { + "model_id": "Lightricks/LTX-2.3", + "diffusers_config": { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (256, 512, 1024, 1024), + "down_block_types": ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 512, 1024), + "layers_per_block": (4, 6, 4, 2, 2), + "decoder_layers_per_block": (4, 6, 4, 2, 2), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True, True), + "decoder_inject_noise": (False, False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_type": ("spatiotemporal", "spatiotemporal", "temporal", "spatial"), + "upsample_residual": (False, False, False, False), + "upsample_factor": (2, 2, 1, 2), + "timestep_conditioning": timestep_conditioning, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "encoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "zeros", + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, + }, + } + rename_dict = LTX_2_3_VIDEO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP return config, rename_dict, special_keys_remap @@ -485,7 +656,7 @@ def convert_ltx2_video_vae( def get_ltx2_audio_vae_config(version: str) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: if version == "2.0": config = { - "model_id": "diffusers-internal-dev/new-ltx-model", + "model_id": "Lightricks/LTX-2", "diffusers_config": { "base_channels": 128, "output_channels": 2, @@ -508,6 +679,31 @@ def get_ltx2_audio_vae_config(version: str) -> tuple[dict[str, Any], dict[str, A } rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT special_keys_remap 
= LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP + elif version == "2.3": + config = { + "model_id": "Lightricks/LTX-2.3", + "diffusers_config": { + "base_channels": 128, + "output_channels": 2, + "ch_mult": (1, 2, 4), + "num_res_blocks": 2, + "attn_resolutions": None, + "in_channels": 2, + "resolution": 256, + "latent_channels": 8, + "norm_type": "pixel", + "causality_axis": "height", + "dropout": 0.0, + "mid_block_add_attention": False, + "sample_rate": 16000, + "mel_hop_length": 160, + "is_causal": True, + "mel_bins": 64, + "double_z": True, + }, # Same config as LTX-2.0 + } + rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP return config, rename_dict, special_keys_remap @@ -540,7 +736,7 @@ def convert_ltx2_audio_vae(original_state_dict: dict[str, Any], version: str) -> def get_ltx2_vocoder_config(version: str) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: if version == "2.0": config = { - "model_id": "diffusers-internal-dev/new-ltx-model", + "model_id": "Lightricks/LTX-2", "diffusers_config": { "in_channels": 128, "hidden_channels": 1024, @@ -549,21 +745,71 @@ def get_ltx2_vocoder_config(version: str) -> tuple[dict[str, Any], dict[str, Any "upsample_factors": [6, 5, 2, 2, 2], "resnet_kernel_sizes": [3, 7, 11], "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "act_fn": "leaky_relu", "leaky_relu_negative_slope": 0.1, + "antialias": False, + "final_act_fn": "tanh", + "final_bias": True, "output_sampling_rate": 24000, }, } rename_dict = LTX_2_0_VOCODER_RENAME_DICT special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP + elif version == "2.3": + config = { + "model_id": "Lightricks/LTX-2.3", + "diffusers_config": { + "in_channels": 128, + "hidden_channels": 1536, + "out_channels": 2, + "upsample_kernel_sizes": [11, 4, 4, 4, 4, 4], + "upsample_factors": [5, 2, 2, 2, 2, 2], + "resnet_kernel_sizes": [3, 7, 11], + "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "act_fn": "snakebeta", + 
"leaky_relu_negative_slope": 0.1, + "antialias": True, + "antialias_ratio": 2, + "antialias_kernel_size": 12, + "final_act_fn": None, + "final_bias": False, + "bwe_in_channels": 128, + "bwe_hidden_channels": 512, + "bwe_out_channels": 2, + "bwe_upsample_kernel_sizes": [12, 11, 4, 4, 4], + "bwe_upsample_factors": [6, 5, 2, 2, 2], + "bwe_resnet_kernel_sizes": [3, 7, 11], + "bwe_resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "bwe_act_fn": "snakebeta", + "bwe_leaky_relu_negative_slope": 0.1, + "bwe_antialias": True, + "bwe_antialias_ratio": 2, + "bwe_antialias_kernel_size": 12, + "bwe_final_act_fn": None, + "bwe_final_bias": False, + "filter_length": 512, + "hop_length": 80, + "window_length": 512, + "num_mel_channels": 64, + "input_sampling_rate": 16000, + "output_sampling_rate": 48000, + }, + } + rename_dict = LTX_2_3_VOCODER_RENAME_DICT + special_keys_remap = LTX_2_3_VOCODER_SPECIAL_KEYS_REMAP return config, rename_dict, special_keys_remap def convert_ltx2_vocoder(original_state_dict: dict[str, Any], version: str) -> dict[str, Any]: config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version) diffusers_config = config["diffusers_config"] + if version == "2.3": + vocoder_cls = LTX2VocoderWithBWE + else: + vocoder_cls = LTX2Vocoder with init_empty_weights(): - vocoder = LTX2Vocoder.from_config(diffusers_config) + vocoder = vocoder_cls.from_config(diffusers_config) # Handle official code --> diffusers key remapping via the remap dict for key in list(original_state_dict.keys()): @@ -594,6 +840,18 @@ def get_ltx2_spatial_latent_upsampler_config(version: str): "spatial_upsample": True, "temporal_upsample": False, "rational_spatial_scale": 2.0, + "use_rational_resampler": True, + } + elif version == "2.3": + config = { + "in_channels": 128, + "mid_channels": 1024, + "num_blocks_per_stage": 4, + "dims": 3, + "spatial_upsample": True, + "temporal_upsample": False, + "rational_spatial_scale": 2.0, + "use_rational_resampler": False, } else: raise 
ValueError(f"Unsupported version: {version}") @@ -651,13 +909,17 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: dict[str, Any], prefi model_state_dict = {} for param_name, param in combined_ckpt.items(): if param_name.startswith(prefix): - model_state_dict[param_name.replace(prefix, "")] = param + model_state_dict[param_name.removeprefix(prefix)] = param if prefix == "model.diffusion_model.": # Some checkpoints store the text connector projection outside the diffusion model prefix. - connector_key = "text_embedding_projection.aggregate_embed.weight" - if connector_key in combined_ckpt and connector_key not in model_state_dict: - model_state_dict[connector_key] = combined_ckpt[connector_key] + connector_prefixes = ["text_embedding_projection"] + for param_name, param in combined_ckpt.items(): + for prefix in connector_prefixes: + if param_name.startswith(prefix): + # Check to make sure we're not overwriting an existing key + if param_name not in model_state_dict: + model_state_dict[param_name] = combined_ckpt[param_name] return model_state_dict @@ -686,7 +948,7 @@ def none_or_str(value: str): "--version", type=str, default="2.0", - choices=["test", "2.0"], + choices=["test", "2.0", "2.3"], help="Version of the LTX 2.0 model", ) @@ -748,6 +1010,11 @@ def none_or_str(value: str): action="store_true", help="Whether to save a latent upsampling pipeline", ) + parser.add_argument( + "--add_processor", + action="store_true", + help="Whether to add a Gemma3Processor to the pipeline for prompt enhancement.", + ) parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) @@ -756,6 +1023,12 @@ def none_or_str(value: str): parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be 
saved") + parser.add_argument( + "--upsample_output_path", + type=str, + default=None, + help="Path where converted upsampling pipeline should be saved", + ) return parser.parse_args() @@ -787,7 +1060,7 @@ def main(args): args.audio_vae, args.dit, args.vocoder, - args.text_encoder, + args.connectors, args.full_pipeline, args.upsample_pipeline, ] @@ -852,7 +1125,12 @@ def main(args): if not args.full_pipeline: tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer")) - if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline: + if args.add_processor: + processor = Gemma3Processor.from_pretrained(args.text_encoder_model_id) + if not args.full_pipeline: + processor.save_pretrained(os.path.join(args.output_path, "processor")) + + if args.latent_upsampler or args.upsample_pipeline: original_latent_upsampler_ckpt = load_hub_or_local_checkpoint( repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename ) @@ -866,14 +1144,26 @@ def main(args): latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler")) if args.full_pipeline: - scheduler = FlowMatchEulerDiscreteScheduler( - use_dynamic_shifting=True, - base_shift=0.95, - max_shift=2.05, - base_image_seq_len=1024, - max_image_seq_len=4096, - shift_terminal=0.1, - ) + is_distilled_ckpt = "distilled" in args.combined_filename + if is_distilled_ckpt: + # Disable dynamic shifting and terminal shift so that distilled sigmas are used as-is + scheduler = FlowMatchEulerDiscreteScheduler( + use_dynamic_shifting=False, + base_shift=0.95, + max_shift=2.05, + base_image_seq_len=1024, + max_image_seq_len=4096, + shift_terminal=None, + ) + else: + scheduler = FlowMatchEulerDiscreteScheduler( + use_dynamic_shifting=True, + base_shift=0.95, + max_shift=2.05, + base_image_seq_len=1024, + max_image_seq_len=4096, + shift_terminal=0.1, + ) pipe = LTX2Pipeline( scheduler=scheduler, @@ -891,10 +1181,12 @@ def main(args): if args.upsample_pipeline: pipe = 
LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler) - # Put latent upsampling pipeline in its own subdirectory so it doesn't mess with the full pipeline - pipe.save_pretrained( - os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB" - ) + # As two diffusers pipelines cannot be in the same directory, save the upsampling pipeline to its own directory + if args.upsample_output_path: + upsample_output_path = args.upsample_output_path + else: + upsample_output_path = args.output_path + pipe.save_pretrained(upsample_output_path, safe_serialization=True, max_shard_size="5GB") if __name__ == "__main__": diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 6e43aef51ce0..298aa61d37ed 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2156,6 +2156,9 @@ def _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict, non_diffusers_pref "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table", "q_norm": "norm_q", "k_norm": "norm_k", + # LTX-2.3 + "audio_prompt_adaln_single": "audio_prompt_adaln", + "prompt_adaln_single": "prompt_adaln", } else: rename_dict = {"aggregate_embed": "text_proj_in"} diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py index 7c04bd715c25..f4f7d46628c8 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -237,7 +237,7 @@ def forward( # Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d -class LTXVideoDownsampler3d(nn.Module): +class LTX2VideoDownsampler3d(nn.Module): def __init__( self, in_channels: int, @@ -285,10 +285,11 @@ def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Ten # Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d -class 
LTXVideoUpsampler3d(nn.Module): +class LTX2VideoUpsampler3d(nn.Module): def __init__( self, in_channels: int, + out_channels: int | None = None, stride: int | tuple[int, int, int] = 1, residual: bool = False, upscale_factor: int = 1, @@ -300,7 +301,8 @@ def __init__( self.residual = residual self.upscale_factor = upscale_factor - out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor + out_channels = out_channels or in_channels + out_channels = (out_channels * stride[0] * stride[1] * stride[2]) // upscale_factor self.conv = LTX2VideoCausalConv3d( in_channels=in_channels, @@ -408,7 +410,7 @@ def __init__( ) elif downsample_type == "spatial": self.downsamplers.append( - LTXVideoDownsampler3d( + LTX2VideoDownsampler3d( in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), @@ -417,7 +419,7 @@ def __init__( ) elif downsample_type == "temporal": self.downsamplers.append( - LTXVideoDownsampler3d( + LTX2VideoDownsampler3d( in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), @@ -426,7 +428,7 @@ def __init__( ) elif downsample_type == "spatiotemporal": self.downsamplers.append( - LTXVideoDownsampler3d( + LTX2VideoDownsampler3d( in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), @@ -580,6 +582,7 @@ def __init__( resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", spatio_temporal_scale: bool = True, + upsample_type: str = "spatiotemporal", inject_noise: bool = False, timestep_conditioning: bool = False, upsample_residual: bool = False, @@ -609,16 +612,23 @@ def __init__( self.upsamplers = None if spatio_temporal_scale: - self.upsamplers = nn.ModuleList( - [ - LTXVideoUpsampler3d( - out_channels * upscale_factor, - stride=(2, 2, 2), - residual=upsample_residual, - upscale_factor=upscale_factor, - spatial_padding_mode=spatial_padding_mode, - ) - ] + self.upsamplers = nn.ModuleList() + + if upsample_type == "spatial": + upsample_stride = (1, 2, 2) + elif upsample_type == "temporal": + 
upsample_stride = (2, 1, 1) + elif upsample_type == "spatiotemporal": + upsample_stride = (2, 2, 2) + + self.upsamplers.append( + LTX2VideoUpsampler3d( + in_channels=out_channels * upscale_factor, + stride=upsample_stride, + residual=upsample_residual, + upscale_factor=upscale_factor, + spatial_padding_mode=spatial_padding_mode, + ) ) resnets = [] @@ -716,7 +726,7 @@ def __init__( "LTX2VideoDownBlock3D", "LTX2VideoDownBlock3D", ), - spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True), + spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True), layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2), downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), patch_size: int = 4, @@ -726,6 +736,9 @@ def __init__( spatial_padding_mode: str = "zeros", ): super().__init__() + num_encoder_blocks = len(layers_per_block) + if isinstance(spatio_temporal_scaling, bool): + spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1) self.patch_size = patch_size self.patch_size_t = patch_size_t @@ -860,19 +873,27 @@ def __init__( in_channels: int = 128, out_channels: int = 3, block_out_channels: tuple[int, ...] = (256, 512, 1024), - spatio_temporal_scaling: tuple[bool, ...] = (True, True, True), + spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True), layers_per_block: tuple[int, ...] = (5, 5, 5, 5), + upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"), patch_size: int = 4, patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, is_causal: bool = False, - inject_noise: tuple[bool, ...] = (False, False, False), + inject_noise: bool | tuple[bool, ...] = (False, False, False), timestep_conditioning: bool = False, - upsample_residual: tuple[bool, ...] = (True, True, True), + upsample_residual: bool | tuple[bool, ...] = (True, True, True), upsample_factor: tuple[bool, ...] 
= (2, 2, 2), spatial_padding_mode: str = "reflect", ) -> None: super().__init__() + num_decoder_blocks = len(layers_per_block) + if isinstance(spatio_temporal_scaling, bool): + spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_decoder_blocks - 1) + if isinstance(inject_noise, bool): + inject_noise = (inject_noise,) * num_decoder_blocks + if isinstance(upsample_residual, bool): + upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1) self.patch_size = patch_size self.patch_size_t = patch_size_t @@ -917,6 +938,7 @@ def __init__( num_layers=layers_per_block[i + 1], resnet_eps=resnet_norm_eps, spatio_temporal_scale=spatio_temporal_scaling[i], + upsample_type=upsample_type[i], inject_noise=inject_noise[i + 1], timestep_conditioning=timestep_conditioning, upsample_residual=upsample_residual[i], @@ -1058,11 +1080,12 @@ def __init__( decoder_block_out_channels: tuple[int, ...] = (256, 512, 1024), layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2), decoder_layers_per_block: tuple[int, ...] = (5, 5, 5, 5), - spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True), - decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True), - decoder_inject_noise: tuple[bool, ...] = (False, False, False, False), + spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True), + decoder_spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True), + decoder_inject_noise: bool | tuple[bool, ...] = (False, False, False, False), downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), - upsample_residual: tuple[bool, ...] = (True, True, True), + upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"), + upsample_residual: bool | tuple[bool, ...] = (True, True, True), upsample_factor: tuple[int, ...] 
= (2, 2, 2), timestep_conditioning: bool = False, patch_size: int = 4, @@ -1077,6 +1100,16 @@ def __init__( temporal_compression_ratio: int = None, ) -> None: super().__init__() + num_encoder_blocks = len(layers_per_block) + num_decoder_blocks = len(decoder_layers_per_block) + if isinstance(spatio_temporal_scaling, bool): + spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1) + if isinstance(decoder_spatio_temporal_scaling, bool): + decoder_spatio_temporal_scaling = (decoder_spatio_temporal_scaling,) * (num_decoder_blocks - 1) + if isinstance(decoder_inject_noise, bool): + decoder_inject_noise = (decoder_inject_noise,) * num_decoder_blocks + if isinstance(upsample_residual, bool): + upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1) self.encoder = LTX2VideoEncoder3d( in_channels=in_channels, @@ -1098,6 +1131,7 @@ def __init__( block_out_channels=decoder_block_out_channels, spatio_temporal_scaling=decoder_spatio_temporal_scaling, layers_per_block=decoder_layers_per_block, + upsample_type=upsample_type, patch_size=patch_size, patch_size_t=patch_size_t, resnet_norm_eps=resnet_norm_eps, diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index db949ca34a1f..a4915ccfb96a 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -178,6 +178,10 @@ def __call__( if encoder_hidden_states is None: encoder_hidden_states = hidden_states + if attn.to_gate_logits is not None: + # Calculate gate logits on original hidden_states + gate_logits = attn.to_gate_logits(hidden_states) + query = attn.to_q(hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -212,6 +216,112 @@ def __call__( hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) + if attn.to_gate_logits is not None: + hidden_states = hidden_states.unflatten(2, 
(attn.heads, -1)) # [B, T, H, D] + # The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1 + gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H] + hidden_states = hidden_states * gates.unsqueeze(-1) + hidden_states = hidden_states.flatten(2, 3) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class LTX2PerturbedAttnProcessor: + r""" + Processor which implements attention with perturbation masking and per-head gating for LTX-2.X models. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if is_torch_version("<", "2.0"): + raise ValueError( + "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation." + ) + + def __call__( + self, + attn: "LTX2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + perturbation_mask: torch.Tensor | None = None, + all_perturbed: bool | None = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + if attn.to_gate_logits is not None: + # Calculate gate logits on original hidden_states + gate_logits = attn.to_gate_logits(hidden_states) + + value = attn.to_v(encoder_hidden_states) + if all_perturbed is None: + all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False + + if 
all_perturbed: + # Skip attention, use the value projection value + hidden_states = value + else: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if query_rotary_emb is not None: + if attn.rope_type == "interleaved": + query = apply_interleaved_rotary_emb(query, query_rotary_emb) + key = apply_interleaved_rotary_emb( + key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + ) + elif attn.rope_type == "split": + query = apply_split_rotary_emb(query, query_rotary_emb) + key = apply_split_rotary_emb( + key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + ) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if perturbation_mask is not None: + value = value.flatten(2, 3) + hidden_states = torch.lerp(value, hidden_states, perturbation_mask) + + if attn.to_gate_logits is not None: + hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D] + # The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1 + gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H] + hidden_states = hidden_states * gates.unsqueeze(-1) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) return hidden_states @@ -224,7 +334,7 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin): """ _default_processor_cls = LTX2AudioVideoAttnProcessor - _available_processors = [LTX2AudioVideoAttnProcessor] + _available_processors = 
[LTX2AudioVideoAttnProcessor, LTX2PerturbedAttnProcessor] def __init__( self, @@ -240,6 +350,7 @@ def __init__( norm_eps: float = 1e-6, norm_elementwise_affine: bool = True, rope_type: str = "interleaved", + apply_gated_attention: bool = False, processor=None, ): super().__init__() @@ -266,6 +377,12 @@ def __init__( self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(torch.nn.Dropout(dropout)) + if apply_gated_attention: + # Per head gate values + self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True) + else: + self.to_gate_logits = None + if processor is None: processor = self._default_processor_cls() self.set_processor(processor) @@ -321,6 +438,10 @@ def __init__( audio_num_attention_heads: int, audio_attention_head_dim, audio_cross_attention_dim: int, + video_gated_attn: bool = False, + video_cross_attn_adaln: bool = False, + audio_gated_attn: bool = False, + audio_cross_attn_adaln: bool = False, qk_norm: str = "rms_norm_across_heads", activation_fn: str = "gelu-approximate", attention_bias: bool = True, @@ -328,9 +449,16 @@ def __init__( eps: float = 1e-6, elementwise_affine: bool = False, rope_type: str = "interleaved", + perturbed_attn: bool = False, ): super().__init__() + self.perturbed_attn = perturbed_attn + if perturbed_attn: + attn_processor_cls = LTX2PerturbedAttnProcessor + else: + attn_processor_cls = LTX2AudioVideoAttnProcessor + # 1. 
Self-Attention (video and audio) self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) self.attn1 = LTX2Attention( @@ -343,6 +471,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=video_gated_attn, + processor=attn_processor_cls(), ) self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) @@ -356,6 +486,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=audio_gated_attn, + processor=attn_processor_cls(), ) # 2. Prompt Cross-Attention @@ -370,6 +502,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=video_gated_attn, + processor=attn_processor_cls(), ) self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) @@ -383,6 +517,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=audio_gated_attn, + processor=attn_processor_cls(), ) # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention @@ -398,6 +534,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=video_gated_attn, + processor=attn_processor_cls(), ) # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video @@ -412,6 +550,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=audio_gated_attn, + processor=attn_processor_cls(), ) # 4. Feedforward layers @@ -422,14 +562,36 @@ def __init__( self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn) # 5. 
Per-Layer Modulation Parameters - # Self-Attention / Feedforward AdaLayerNorm-Zero mod params - self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) - self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5) + # Self-Attention (attn1) / Feedforward AdaLayerNorm-Zero mod params + # 6 base mod params for text cross-attn K,V; if cross_attn_adaln, also has mod params for Q + self.video_cross_attn_adaln = video_cross_attn_adaln + self.audio_cross_attn_adaln = audio_cross_attn_adaln + video_mod_param_num = 9 if self.video_cross_attn_adaln else 6 + audio_mod_param_num = 9 if self.audio_cross_attn_adaln else 6 + self.scale_shift_table = nn.Parameter(torch.randn(video_mod_param_num, dim) / dim**0.5) + self.audio_scale_shift_table = nn.Parameter(torch.randn(audio_mod_param_num, audio_dim) / audio_dim**0.5) + + # Prompt cross-attn (attn2) additional modulation params + self.cross_attn_adaln = video_cross_attn_adaln or audio_cross_attn_adaln + if self.cross_attn_adaln: + self.prompt_scale_shift_table = nn.Parameter(torch.randn(2, dim)) + self.audio_prompt_scale_shift_table = nn.Parameter(torch.randn(2, audio_dim)) # Per-layer a2v, v2a Cross-Attention mod params self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim)) self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim)) + @staticmethod + def get_mod_params( + scale_shift_table: torch.Tensor, temb: torch.Tensor, batch_size: int + ) -> tuple[torch.Tensor, ...]: + num_ada_params = scale_shift_table.shape[0] + ada_values = scale_shift_table[None, None].to(temb.device) + temb.reshape( + batch_size, temb.shape[1], num_ada_params, -1 + ) + ada_params = ada_values.unbind(dim=2) + return ada_params + def forward( self, hidden_states: torch.Tensor, @@ -442,143 +604,181 @@ def forward( temb_ca_audio_scale_shift: torch.Tensor, temb_ca_gate: torch.Tensor, temb_ca_audio_gate: torch.Tensor, + temb_prompt: torch.Tensor | None = None, + 
temb_prompt_audio: torch.Tensor | None = None, video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, ca_audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, encoder_attention_mask: torch.Tensor | None = None, audio_encoder_attention_mask: torch.Tensor | None = None, + self_attention_mask: torch.Tensor | None = None, + audio_self_attention_mask: torch.Tensor | None = None, a2v_cross_attention_mask: torch.Tensor | None = None, v2a_cross_attention_mask: torch.Tensor | None = None, + use_a2v_cross_attention: bool = True, + use_v2a_cross_attention: bool = True, + perturbation_mask: torch.Tensor | None = None, + all_perturbed: bool | None = None, ) -> torch.Tensor: batch_size = hidden_states.size(0) # 1. Video and Audio Self-Attention - norm_hidden_states = self.norm1(hidden_states) + # 1.1. Video Self-Attention + video_ada_params = self.get_mod_params(self.scale_shift_table, temb, batch_size) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = video_ada_params[:6] + if self.video_cross_attn_adaln: + shift_text_q, scale_text_q, gate_text_q = video_ada_params[6:9] - num_ada_params = self.scale_shift_table.shape[0] - ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape( - batch_size, temb.size(1), num_ada_params, -1 - ) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = self.norm1(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - attn_hidden_states = self.attn1( - hidden_states=norm_hidden_states, - encoder_hidden_states=None, - query_rotary_emb=video_rotary_emb, - ) + video_self_attn_args = { + "hidden_states": norm_hidden_states, + "encoder_hidden_states": None, + "query_rotary_emb": video_rotary_emb, + "attention_mask": self_attention_mask, + } + if self.perturbed_attn: 
+ video_self_attn_args["perturbation_mask"] = perturbation_mask + video_self_attn_args["all_perturbed"] = all_perturbed + + attn_hidden_states = self.attn1(**video_self_attn_args) hidden_states = hidden_states + attn_hidden_states * gate_msa - norm_audio_hidden_states = self.audio_norm1(audio_hidden_states) - - num_audio_ada_params = self.audio_scale_shift_table.shape[0] - audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape( - batch_size, temb_audio.size(1), num_audio_ada_params, -1 - ) + # 1.2. Audio Self-Attention + audio_ada_params = self.get_mod_params(self.audio_scale_shift_table, temb_audio, batch_size) audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = ( - audio_ada_values.unbind(dim=2) + audio_ada_params[:6] ) + if self.audio_cross_attn_adaln: + audio_shift_text_q, audio_scale_text_q, audio_gate_text_q = audio_ada_params[6:9] + + norm_audio_hidden_states = self.audio_norm1(audio_hidden_states) norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa - attn_audio_hidden_states = self.audio_attn1( - hidden_states=norm_audio_hidden_states, - encoder_hidden_states=None, - query_rotary_emb=audio_rotary_emb, - ) + audio_self_attn_args = { + "hidden_states": norm_audio_hidden_states, + "encoder_hidden_states": None, + "query_rotary_emb": audio_rotary_emb, + "attention_mask": audio_self_attention_mask, + } + if self.perturbed_attn: + audio_self_attn_args["perturbation_mask"] = perturbation_mask + audio_self_attn_args["all_perturbed"] = all_perturbed + + attn_audio_hidden_states = self.audio_attn1(**audio_self_attn_args) audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa - # 2. Video and Audio Cross-Attention with the text embeddings + # 2. 
Video and Audio Cross-Attention with the text embeddings (Q: Video or Audio; K,V: Text) + if self.cross_attn_adaln: + video_prompt_ada_params = self.get_mod_params(self.prompt_scale_shift_table, temb_prompt, batch_size) + shift_text_kv, scale_text_kv = video_prompt_ada_params + + audio_prompt_ada_params = self.get_mod_params( + self.audio_prompt_scale_shift_table, temb_prompt_audio, batch_size + ) + audio_shift_text_kv, audio_scale_text_kv = audio_prompt_ada_params + + # 2.1. Video-Text Cross-Attention (Q: Video; K,V: Text) norm_hidden_states = self.norm2(hidden_states) + if self.video_cross_attn_adaln: + norm_hidden_states = norm_hidden_states * (1 + scale_text_q) + shift_text_q + if self.cross_attn_adaln: + encoder_hidden_states = encoder_hidden_states * (1 + scale_text_kv) + shift_text_kv + attn_hidden_states = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, query_rotary_emb=None, attention_mask=encoder_attention_mask, ) + if self.video_cross_attn_adaln: + attn_hidden_states = attn_hidden_states * gate_text_q hidden_states = hidden_states + attn_hidden_states + # 2.2. Audio-Text Cross-Attention norm_audio_hidden_states = self.audio_norm2(audio_hidden_states) + if self.audio_cross_attn_adaln: + norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_text_q) + audio_shift_text_q + if self.cross_attn_adaln: + audio_encoder_hidden_states = audio_encoder_hidden_states * (1 + audio_scale_text_kv) + audio_shift_text_kv + attn_audio_hidden_states = self.audio_attn2( norm_audio_hidden_states, encoder_hidden_states=audio_encoder_hidden_states, query_rotary_emb=None, attention_mask=audio_encoder_attention_mask, ) + if self.audio_cross_attn_adaln: + attn_audio_hidden_states = attn_audio_hidden_states * audio_gate_text_q audio_hidden_states = audio_hidden_states + attn_audio_hidden_states # 3. 
Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention - norm_hidden_states = self.audio_to_video_norm(hidden_states) - norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states) - - # Combine global and per-layer cross attention modulation parameters - # Video - video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :] - video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] - - video_ca_scale_shift_table = ( - video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype) - + temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1) - ).unbind(dim=2) - video_ca_gate = ( - video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype) - + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1) - ).unbind(dim=2) - - video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table - a2v_gate = video_ca_gate[0].squeeze(2) - - # Audio - audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :] - audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :] - - audio_ca_scale_shift_table = ( - audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype) - + temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1) - ).unbind(dim=2) - audio_ca_gate = ( - audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype) - + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1) - ).unbind(dim=2) - - audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table - v2a_gate = audio_ca_gate[0].squeeze(2) - - # Audio-to-Video Cross Attention: Q: Video; K,V: Audio - mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze( - 2 - ) - mod_norm_audio_hidden_states = norm_audio_hidden_states * ( - 1 + audio_a2v_ca_scale.squeeze(2) - ) + 
audio_a2v_ca_shift.squeeze(2) - - a2v_attn_hidden_states = self.audio_to_video_attn( - mod_norm_hidden_states, - encoder_hidden_states=mod_norm_audio_hidden_states, - query_rotary_emb=ca_video_rotary_emb, - key_rotary_emb=ca_audio_rotary_emb, - attention_mask=a2v_cross_attention_mask, - ) + if use_a2v_cross_attention or use_v2a_cross_attention: + norm_hidden_states = self.audio_to_video_norm(hidden_states) + norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states) - hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states + # 3.1. Combine global and per-layer cross attention modulation parameters + # Video + video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :] + video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] - # Video-to-Audio Cross Attention: Q: Audio; K,V: Video - mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze( - 2 - ) - mod_norm_audio_hidden_states = norm_audio_hidden_states * ( - 1 + audio_v2a_ca_scale.squeeze(2) - ) + audio_v2a_ca_shift.squeeze(2) - - v2a_attn_hidden_states = self.video_to_audio_attn( - mod_norm_audio_hidden_states, - encoder_hidden_states=mod_norm_hidden_states, - query_rotary_emb=ca_audio_rotary_emb, - key_rotary_emb=ca_video_rotary_emb, - attention_mask=v2a_cross_attention_mask, - ) + video_ca_ada_params = self.get_mod_params(video_per_layer_ca_scale_shift, temb_ca_scale_shift, batch_size) + video_ca_gate_param = self.get_mod_params(video_per_layer_ca_gate, temb_ca_gate, batch_size) + + video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_ada_params + a2v_gate = video_ca_gate_param[0].squeeze(2) - audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states + # Audio + audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :] + audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :] + + 
audio_ca_ada_params = self.get_mod_params( + audio_per_layer_ca_scale_shift, temb_ca_audio_scale_shift, batch_size + ) + audio_ca_gate_param = self.get_mod_params(audio_per_layer_ca_gate, temb_ca_audio_gate, batch_size) + + audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_ada_params + v2a_gate = audio_ca_gate_param[0].squeeze(2) + + # 3.2. Audio-to-Video Cross Attention: Q: Video; K,V: Audio + if use_a2v_cross_attention: + mod_norm_hidden_states = norm_hidden_states * ( + 1 + video_a2v_ca_scale.squeeze(2) + ) + video_a2v_ca_shift.squeeze(2) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_a2v_ca_scale.squeeze(2) + ) + audio_a2v_ca_shift.squeeze(2) + + a2v_attn_hidden_states = self.audio_to_video_attn( + mod_norm_hidden_states, + encoder_hidden_states=mod_norm_audio_hidden_states, + query_rotary_emb=ca_video_rotary_emb, + key_rotary_emb=ca_audio_rotary_emb, + attention_mask=a2v_cross_attention_mask, + ) + + hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states + + # 3.3. Video-to-Audio Cross Attention: Q: Audio; K,V: Video + if use_v2a_cross_attention: + mod_norm_hidden_states = norm_hidden_states * ( + 1 + video_v2a_ca_scale.squeeze(2) + ) + video_v2a_ca_shift.squeeze(2) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_v2a_ca_scale.squeeze(2) + ) + audio_v2a_ca_shift.squeeze(2) + + v2a_attn_hidden_states = self.video_to_audio_attn( + mod_norm_audio_hidden_states, + encoder_hidden_states=mod_norm_hidden_states, + query_rotary_emb=ca_audio_rotary_emb, + key_rotary_emb=ca_video_rotary_emb, + attention_mask=v2a_cross_attention_mask, + ) + + audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states # 4. 
Feedforward norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp @@ -918,6 +1118,8 @@ def __init__( pos_embed_max_pos: int = 20, base_height: int = 2048, base_width: int = 2048, + gated_attn: bool = False, + cross_attn_mod: bool = False, audio_in_channels: int = 128, # Audio Arguments audio_out_channels: int | None = 128, audio_patch_size: int = 1, @@ -929,6 +1131,8 @@ def __init__( audio_pos_embed_max_pos: int = 20, audio_sampling_rate: int = 16000, audio_hop_length: int = 160, + audio_gated_attn: bool = False, + audio_cross_attn_mod: bool = False, num_layers: int = 48, # Shared arguments activation_fn: str = "gelu-approximate", qk_norm: str = "rms_norm_across_heads", @@ -943,6 +1147,8 @@ def __init__( timestep_scale_multiplier: int = 1000, cross_attn_timestep_scale_multiplier: int = 1000, rope_type: str = "interleaved", + use_prompt_embeddings=True, + perturbed_attn: bool = False, ) -> None: super().__init__() @@ -956,17 +1162,25 @@ def __init__( self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim) # 2. Prompt embeddings - self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) - self.audio_caption_projection = PixArtAlphaTextProjection( - in_features=caption_channels, hidden_size=audio_inner_dim - ) + if use_prompt_embeddings: + # LTX-2.0; LTX-2.3 uses per-modality feature projections in the connector instead + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=audio_inner_dim + ) # 3. Timestep Modulation Params and Embedding + self.prompt_modulation = cross_attn_mod or audio_cross_attn_mod # used by LTX-2.3 + # 3.1. 
Global Timestep Modulation Parameters (except for cross-attention) and timestep + size embedding # time_embed and audio_time_embed calculate both the timestep embedding and (global) modulation parameters - self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False) + video_time_emb_mod_params = 9 if cross_attn_mod else 6 + audio_time_emb_mod_params = 9 if audio_cross_attn_mod else 6 + self.time_embed = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=video_time_emb_mod_params, use_additional_conditions=False + ) self.audio_time_embed = LTX2AdaLayerNormSingle( - audio_inner_dim, num_mod_params=6, use_additional_conditions=False + audio_inner_dim, num_mod_params=audio_time_emb_mod_params, use_additional_conditions=False ) # 3.2. Global Cross Attention Modulation Parameters @@ -995,6 +1209,13 @@ def __init__( self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5) + # 3.4. Prompt Scale/Shift Modulation parameters (LTX-2.3) + if self.prompt_modulation: + self.prompt_adaln = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=2, use_additional_conditions=False) + self.audio_prompt_adaln = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=2, use_additional_conditions=False + ) + # 4. 
Rotary Positional Embeddings (RoPE) # Self-Attention self.rope = LTX2AudioVideoRotaryPosEmbed( @@ -1071,6 +1292,10 @@ def __init__( audio_num_attention_heads=audio_num_attention_heads, audio_attention_head_dim=audio_attention_head_dim, audio_cross_attention_dim=audio_cross_attention_dim, + video_gated_attn=gated_attn, + video_cross_attn_adaln=cross_attn_mod, + audio_gated_attn=audio_gated_attn, + audio_cross_attn_adaln=audio_cross_attn_mod, qk_norm=qk_norm, activation_fn=activation_fn, attention_bias=attention_bias, @@ -1078,6 +1303,7 @@ def __init__( eps=norm_eps, elementwise_affine=norm_elementwise_affine, rope_type=rope_type, + perturbed_attn=perturbed_attn, ) for _ in range(num_layers) ] @@ -1101,6 +1327,8 @@ def forward( audio_encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, audio_timestep: torch.LongTensor | None = None, + sigma: torch.Tensor | None = None, + audio_sigma: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, audio_encoder_attention_mask: torch.Tensor | None = None, num_frames: int | None = None, @@ -1110,6 +1338,10 @@ def forward( audio_num_frames: int | None = None, video_coords: torch.Tensor | None = None, audio_coords: torch.Tensor | None = None, + isolate_modalities: bool = False, + spatio_temporal_guidance_blocks: list[int] | None = None, + perturbation_mask: torch.Tensor | None = None, + use_cross_timestep: bool = False, attention_kwargs: dict[str, Any] | None = None, return_dict: bool = True, ) -> torch.Tensor: @@ -1131,6 +1363,13 @@ def forward( audio_timestep (`torch.Tensor`, *optional*): Input timestep of shape `(batch_size,)` or `(batch_size, num_audio_tokens)` for audio modulation params. This is only used by certain pipelines such as the I2V pipeline. + sigma (`torch.Tensor`, *optional*): + Input scaled timestep of shape (batch_size,). Used for video prompt cross attention modulation in + models such as LTX-2.3. 
+ audio_sigma (`torch.Tensor`, *optional*): + Input scaled timestep of shape (batch_size,). Used for audio prompt cross attention modulation in + models such as LTX-2.3. If `sigma` is supplied but `audio_sigma` is not, `audio_sigma` will be set to + the provided `sigma` value. encoder_attention_mask (`torch.Tensor`, *optional*): Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`. audio_encoder_attention_mask (`torch.Tensor`, *optional*): @@ -1152,6 +1391,21 @@ def forward( audio_coords (`torch.Tensor`, *optional*): The audio coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape `(batch_size, 1, num_audio_tokens, 2)`. If not supplied, this will be calculated inside `forward`. + isolate_modalities (`bool`, *optional*, defaults to `False`): + Whether to isolate each modality by turning off cross-modality (audio-to-video and video-to-audio) + cross attention (for all blocks). Use for modality guidance in LTX-2.3. + spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`): + The transformer block indices at which to apply spatio-temporal guidance (STG), which shortcuts the + self-attention operations by simply using the values rather than the full scaled dot-product attention + (SDPA) operation. If `None` or empty, STG will not be applied to any block. + perturbation_mask (`torch.Tensor`, *optional*): + Perturbation mask for STG of shape `(batch_size,)` or `(batch_size, 1, 1)`. Should be 0 at batch + elements where STG should be applied and 1 elsewhere. If STG is being used but `perturbation_mask` is + not supplied, will default to applying STG (perturbing) all batch elements. + use_cross_timestep (`bool`, *optional*, defaults to `False`): + Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when + calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior; + `False` is the legacy LTX-2.0 behavior.
attention_kwargs (`dict[str, Any]`, *optional*): Optional dict of keyword args to be passed to the attention processor. return_dict (`bool`, *optional*, defaults to `True`): @@ -1165,6 +1419,7 @@ def forward( """ # Determine timestep for audio. audio_timestep = audio_timestep if audio_timestep is not None else timestep + audio_sigma = audio_sigma if audio_sigma is not None else sigma # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: @@ -1223,14 +1478,28 @@ def forward( temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1)) audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1)) + if self.prompt_modulation: + # LTX-2.3 + temb_prompt, _ = self.prompt_adaln( + sigma.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + temb_prompt_audio, _ = self.audio_prompt_adaln( + audio_sigma.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype + ) + temb_prompt = temb_prompt.view(batch_size, -1, temb_prompt.size(-1)) + temb_prompt_audio = temb_prompt_audio.view(batch_size, -1, temb_prompt_audio.size(-1)) + else: + temb_prompt = temb_prompt_audio = None + # 3.2. 
Prepare global modality cross attention modulation parameters + video_ca_timestep = audio_sigma.flatten() if use_cross_timestep else timestep.flatten() video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift( - timestep.flatten(), + video_ca_timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype, ) video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate( - timestep.flatten() * timestep_cross_attn_gate_scale_factor, + video_ca_timestep * timestep_cross_attn_gate_scale_factor, batch_size=batch_size, hidden_dtype=hidden_states.dtype, ) @@ -1239,13 +1508,14 @@ def forward( ) video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1]) + audio_ca_timestep = sigma.flatten() if use_cross_timestep else audio_timestep.flatten() audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( - audio_timestep.flatten(), + audio_ca_timestep, batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, ) audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate( - audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor, + audio_ca_timestep * timestep_cross_attn_gate_scale_factor, batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, ) @@ -1254,15 +1524,30 @@ def forward( ) audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]) - # 4. Prepare prompt embeddings - encoder_hidden_states = self.caption_projection(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + # 4. 
Prepare prompt embeddings (LTX-2.0) + if self.config.use_prompt_embeddings: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) - audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) - audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1)) + audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) + audio_encoder_hidden_states = audio_encoder_hidden_states.view( + batch_size, -1, audio_hidden_states.size(-1) + ) # 5. Run transformer blocks - for block in self.transformer_blocks: + spatio_temporal_guidance_blocks = spatio_temporal_guidance_blocks or [] + if len(spatio_temporal_guidance_blocks) > 0 and perturbation_mask is None: + # If STG is being used and perturbation_mask is not supplied, default to perturbing all batch elements. + perturbation_mask = torch.zeros((batch_size,)) + if perturbation_mask is not None and perturbation_mask.ndim == 1: + perturbation_mask = perturbation_mask[:, None, None] # unsqueeze to 3D to broadcast with hidden_states + all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False + stg_blocks = set(spatio_temporal_guidance_blocks) + + for block_idx, block in enumerate(self.transformer_blocks): + block_perturbation_mask = perturbation_mask if block_idx in stg_blocks else None + block_all_perturbed = all_perturbed if block_idx in stg_blocks else False + if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states, audio_hidden_states = self._gradient_checkpointing_func( block, @@ -1276,12 +1561,22 @@ def forward( audio_cross_attn_scale_shift, video_cross_attn_a2v_gate, audio_cross_attn_v2a_gate, + temb_prompt, + temb_prompt_audio, video_rotary_emb, audio_rotary_emb, video_cross_attn_rotary_emb, audio_cross_attn_rotary_emb, encoder_attention_mask, 
audio_encoder_attention_mask, + None, # self_attention_mask + None, # audio_self_attention_mask + None, # a2v_cross_attention_mask + None, # v2a_cross_attention_mask + not isolate_modalities, # use_a2v_cross_attention + not isolate_modalities, # use_v2a_cross_attention + block_perturbation_mask, + block_all_perturbed, ) else: hidden_states, audio_hidden_states = block( @@ -1295,12 +1590,22 @@ def forward( temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, temb_ca_gate=video_cross_attn_a2v_gate, temb_ca_audio_gate=audio_cross_attn_v2a_gate, + temb_prompt=temb_prompt, + temb_prompt_audio=temb_prompt_audio, video_rotary_emb=video_rotary_emb, audio_rotary_emb=audio_rotary_emb, ca_video_rotary_emb=video_cross_attn_rotary_emb, ca_audio_rotary_emb=audio_cross_attn_rotary_emb, encoder_attention_mask=encoder_attention_mask, audio_encoder_attention_mask=audio_encoder_attention_mask, + self_attention_mask=None, + audio_self_attention_mask=None, + a2v_cross_attention_mask=None, + v2a_cross_attention_mask=None, + use_a2v_cross_attention=not isolate_modalities, + use_v2a_cross_attention=not isolate_modalities, + perturbation_mask=block_perturbation_mask, + all_perturbed=block_all_perturbed, ) # 6. 
Output layers (including unpatchification) diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py index d6a408d5c546..7177faaf3486 100644 --- a/src/diffusers/pipelines/ltx2/__init__.py +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -28,7 +28,7 @@ _import_structure["pipeline_ltx2_condition"] = ["LTX2ConditionPipeline"] _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] _import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"] - _import_structure["vocoder"] = ["LTX2Vocoder"] + _import_structure["vocoder"] = ["LTX2Vocoder", "LTX2VocoderWithBWE"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -44,7 +44,7 @@ from .pipeline_ltx2_condition import LTX2ConditionPipeline from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline - from .vocoder import LTX2Vocoder + from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE else: import sys diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py index 4b2a81a9dc2c..a49de4083342 100644 --- a/src/diffusers/pipelines/ltx2/connectors.py +++ b/src/diffusers/pipelines/ltx2/connectors.py @@ -1,3 +1,5 @@ +import math + import torch import torch.nn as nn import torch.nn.functional as F @@ -9,6 +11,79 @@ from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor +def per_layer_masked_mean_norm( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: str | torch.device, + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, +): + """ + Performs per-batch per-layer normalization using a masked mean and range on per-layer text encoder hidden_states. + Respects the padding of the hidden states. 
+ + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. + """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = 
masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + + +def per_token_rms_norm(text_encoder_hidden_states: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + variance = torch.mean(text_encoder_hidden_states**2, dim=2, keepdim=True) + norm_text_encoder_hidden_states = text_encoder_hidden_states * torch.rsqrt(variance + eps) + return norm_text_encoder_hidden_states + + class LTX2RotaryPosEmbed1d(nn.Module): """ 1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors. 
@@ -106,6 +181,7 @@ def __init__( activation_fn: str = "gelu-approximate", eps: float = 1e-6, rope_type: str = "interleaved", + apply_gated_attention: bool = False, ): super().__init__() @@ -115,8 +191,9 @@ def __init__( heads=num_attention_heads, kv_heads=num_attention_heads, dim_head=attention_head_dim, - processor=LTX2AudioVideoAttnProcessor(), rope_type=rope_type, + apply_gated_attention=apply_gated_attention, + processor=LTX2AudioVideoAttnProcessor(), ) self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) @@ -160,6 +237,7 @@ def __init__( eps: float = 1e-6, causal_temporal_positioning: bool = False, rope_type: str = "interleaved", + gated_attention: bool = False, ): super().__init__() self.num_attention_heads = num_attention_heads @@ -188,6 +266,7 @@ def __init__( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, rope_type=rope_type, + apply_gated_attention=gated_attention, ) for _ in range(num_layers) ] @@ -260,24 +339,36 @@ class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin): @register_to_config def __init__( self, - caption_channels: int, - text_proj_in_factor: int, - video_connector_num_attention_heads: int, - video_connector_attention_head_dim: int, - video_connector_num_layers: int, - video_connector_num_learnable_registers: int | None, - audio_connector_num_attention_heads: int, - audio_connector_attention_head_dim: int, - audio_connector_num_layers: int, - audio_connector_num_learnable_registers: int | None, - connector_rope_base_seq_len: int, - rope_theta: float, - rope_double_precision: bool, - causal_temporal_positioning: bool, + caption_channels: int = 3840, # default Gemma-3-12B text encoder hidden_size + text_proj_in_factor: int = 49, # num_layers + 1 for embedding layer = 48 + 1 for Gemma-3-12B + video_connector_num_attention_heads: int = 30, + video_connector_attention_head_dim: int = 128, + video_connector_num_layers: int = 2, + video_connector_num_learnable_registers: int | None = 
128, + video_gated_attn: bool = False, + audio_connector_num_attention_heads: int = 30, + audio_connector_attention_head_dim: int = 128, + audio_connector_num_layers: int = 2, + audio_connector_num_learnable_registers: int | None = 128, + audio_gated_attn: bool = False, + connector_rope_base_seq_len: int = 4096, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + causal_temporal_positioning: bool = False, rope_type: str = "interleaved", + per_modality_projections: bool = False, + video_hidden_dim: int = 4096, + audio_hidden_dim: int = 2048, + proj_bias: bool = False, ): super().__init__() - self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False) + text_encoder_dim = caption_channels * text_proj_in_factor + if per_modality_projections: + self.video_text_proj_in = nn.Linear(text_encoder_dim, video_hidden_dim, bias=proj_bias) + self.audio_text_proj_in = nn.Linear(text_encoder_dim, audio_hidden_dim, bias=proj_bias) + else: + self.text_proj_in = nn.Linear(text_encoder_dim, caption_channels, bias=proj_bias) + self.video_connector = LTX2ConnectorTransformer1d( num_attention_heads=video_connector_num_attention_heads, attention_head_dim=video_connector_attention_head_dim, @@ -288,6 +379,7 @@ def __init__( rope_double_precision=rope_double_precision, causal_temporal_positioning=causal_temporal_positioning, rope_type=rope_type, + gated_attention=video_gated_attn, ) self.audio_connector = LTX2ConnectorTransformer1d( num_attention_heads=audio_connector_num_attention_heads, @@ -299,26 +391,86 @@ def __init__( rope_double_precision=rope_double_precision, causal_temporal_positioning=causal_temporal_positioning, rope_type=rope_type, + gated_attention=audio_gated_attn, ) def forward( - self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False - ): - # Convert to additive attention mask, if necessary - if not additive_mask: - text_dtype = text_encoder_hidden_states.dtype - 
attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) - attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max - - text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states) - - video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask) - - attn_mask = (new_attn_mask < 1e-6).to(torch.int64) - attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1) - video_text_embedding = video_text_embedding * attn_mask - new_attn_mask = attn_mask.squeeze(-1) - - audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask) - - return video_text_embedding, audio_text_embedding, new_attn_mask + self, + text_encoder_hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + padding_side: str = "left", + scale_factor: int = 8, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Given per-layer text encoder hidden_states, extracts features and runs per-modality connectors to get text + embeddings for the LTX-2.X DiT models. + + Args: + text_encoder_hidden_states (`torch.Tensor`)): + Per-layer text encoder hidden_states. Can either be 4D with shape `(batch_size, seq_len, + caption_channels, text_proj_in_factor) or 3D with the last two dimensions flattened. + attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`): + Multiplicative binary attention mask where 1s indicate unmasked positions and 0s indicate masked + positions. + padding_side (`str`, *optional*, defaults to `"left"`): + The padding side used by the text encoder's text encoder (either `"left"` or `"right"`). Defaults to + `"left"` as this is what the default Gemma3-12B text encoder uses. Only used if + `per_modality_projections` is `False` (LTX-2.0 models). + scale_factor (`int`, *optional*, defaults to `8`): + Scale factor for masked mean/range normalization. 
Only used if `per_modality_projections` is `False` + (LTX-2.0 models). + """ + if text_encoder_hidden_states.ndim == 3: + # Ensure shape is [batch_size, seq_len, caption_channels, text_proj_in_factor] + text_encoder_hidden_states = text_encoder_hidden_states.unflatten(2, (self.config.caption_channels, -1)) + + if self.config.per_modality_projections: + # LTX-2.3 + norm_text_encoder_hidden_states = per_token_rms_norm(text_encoder_hidden_states) + + norm_text_encoder_hidden_states = norm_text_encoder_hidden_states.flatten(2, 3) + bool_mask = attention_mask.bool().unsqueeze(-1) + norm_text_encoder_hidden_states = torch.where( + bool_mask, norm_text_encoder_hidden_states, torch.zeros_like(norm_text_encoder_hidden_states) + ) + + # Rescale norms with respect to video and audio dims for feature extractors + video_scale_factor = math.sqrt(self.config.video_hidden_dim / self.config.caption_channels) + video_norm_text_emb = norm_text_encoder_hidden_states * video_scale_factor + audio_scale_factor = math.sqrt(self.config.audio_hidden_dim / self.config.caption_channels) + audio_norm_text_emb = norm_text_encoder_hidden_states * audio_scale_factor + + # Per-Modality Feature extractors + video_text_emb_proj = self.video_text_proj_in(video_norm_text_emb) + audio_text_emb_proj = self.audio_text_proj_in(audio_norm_text_emb) + else: + # LTX-2.0 + sequence_lengths = attention_mask.sum(dim=-1) + norm_text_encoder_hidden_states = per_layer_masked_mean_norm( + text_hidden_states=text_encoder_hidden_states, + sequence_lengths=sequence_lengths, + device=text_encoder_hidden_states.device, + padding_side=padding_side, + scale_factor=scale_factor, + ) + + text_emb_proj = self.text_proj_in(norm_text_encoder_hidden_states) + video_text_emb_proj = text_emb_proj + audio_text_emb_proj = text_emb_proj + + # Convert to additive attention mask for connectors + text_dtype = video_text_emb_proj.dtype + attention_mask = (attention_mask.to(torch.int64) - 1).to(text_dtype) + attention_mask = 
attention_mask.reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + add_attn_mask = attention_mask * torch.finfo(text_dtype).max + + video_text_embedding, video_attn_mask = self.video_connector(video_text_emb_proj, add_attn_mask) + + # Convert video attn mask to binary (multiplicative) mask and mask video text embedding + binary_attn_mask = (video_attn_mask < 1e-6).to(torch.int64) + binary_attn_mask = binary_attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1) + video_text_embedding = video_text_embedding * binary_attn_mask + + audio_text_embedding, _ = self.audio_connector(audio_text_emb_proj, add_attn_mask) + + return video_text_embedding, audio_text_embedding, binary_attn_mask.squeeze(-1) diff --git a/src/diffusers/pipelines/ltx2/latent_upsampler.py b/src/diffusers/pipelines/ltx2/latent_upsampler.py index f6c589a70ab6..329ced36d45b 100644 --- a/src/diffusers/pipelines/ltx2/latent_upsampler.py +++ b/src/diffusers/pipelines/ltx2/latent_upsampler.py @@ -195,7 +195,8 @@ def __init__( dims: int = 3, spatial_upsample: bool = True, temporal_upsample: bool = False, - rational_spatial_scale: float | None = 2.0, + rational_spatial_scale: float = 2.0, + use_rational_resampler: bool = True, ): super().__init__() @@ -220,7 +221,7 @@ def __init__( PixelShuffleND(3), ) elif spatial_upsample: - if rational_spatial_scale is not None: + if use_rational_resampler: self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale) else: self.upsampler = torch.nn.Sequential( diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 037840360137..73ebac0f173c 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -18,7 +18,7 @@ import numpy as np import torch -from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast +from transformers import 
Gemma3ForConditionalGeneration, Gemma3Processor, GemmaTokenizer, GemmaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin @@ -31,7 +31,7 @@ from ..pipeline_utils import DiffusionPipeline from .connectors import LTX2TextConnectors from .pipeline_output import LTX2PipelineOutput -from .vocoder import LTX2Vocoder +from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE if is_torch_xla_available(): @@ -209,7 +209,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): """ model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" - _optional_components = [] + _optional_components = ["processor"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( @@ -221,7 +221,8 @@ def __init__( tokenizer: GemmaTokenizer | GemmaTokenizerFast, connectors: LTX2TextConnectors, transformer: LTX2VideoTransformer3DModel, - vocoder: LTX2Vocoder, + vocoder: LTX2Vocoder | LTX2VocoderWithBWE, + processor: Gemma3Processor | None = None, ): super().__init__() @@ -234,6 +235,7 @@ def __init__( transformer=transformer, vocoder=vocoder, scheduler=scheduler, + processor=processor, ) self.vae_spatial_compression_ratio = ( @@ -268,73 +270,6 @@ def __init__( self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 ) - @staticmethod - def _pack_text_embeds( - text_hidden_states: torch.Tensor, - sequence_lengths: torch.Tensor, - device: str | torch.device, - padding_side: str = "left", - scale_factor: int = 8, - eps: float = 1e-6, - ) -> torch.Tensor: - """ - Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and - per-layer in a masked fashion (only over non-padded positions). 
- - Args: - text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): - Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). - sequence_lengths (`torch.Tensor of shape `(batch_size,)`): - The number of valid (non-padded) tokens for each batch instance. - device: (`str` or `torch.device`, *optional*): - torch device to place the resulting embeddings on - padding_side: (`str`, *optional*, defaults to `"left"`): - Whether the text tokenizer performs padding on the `"left"` or `"right"`. - scale_factor (`int`, *optional*, defaults to `8`): - Scaling factor to multiply the normalized hidden states by. - eps (`float`, *optional*, defaults to `1e-6`): - A small positive value for numerical stability when performing normalization. - - Returns: - `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: - Normed and flattened text encoder hidden states. - """ - batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape - original_dtype = text_hidden_states.dtype - - # Create padding mask - token_indices = torch.arange(seq_len, device=device).unsqueeze(0) - if padding_side == "right": - # For right padding, valid tokens are from 0 to sequence_length-1 - mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] - elif padding_side == "left": - # For left padding, valid tokens are from (T - sequence_length) to T-1 - start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] - mask = token_indices >= start_indices # [B, T] - else: - raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") - mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] - - # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) - masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) - num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) - masked_mean = 
masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) - - # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) - x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) - x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) - - # Normalization - normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) - normalized_hidden_states = normalized_hidden_states * scale_factor - - # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.flatten(2) - mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) - normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) - return normalized_hidden_states - def _get_gemma_prompt_embeds( self, prompt: str | list[str], @@ -387,16 +322,7 @@ def _get_gemma_prompt_embeds( ) text_encoder_hidden_states = text_encoder_outputs.hidden_states text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) - sequence_lengths = prompt_attention_mask.sum(dim=-1) - - prompt_embeds = self._pack_text_embeds( - text_encoder_hidden_states, - sequence_lengths, - device=device, - padding_side=self.tokenizer.padding_side, - scale_factor=scale_factor, - ) - prompt_embeds = prompt_embeds.to(dtype=dtype) + prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape @@ -494,6 +420,50 @@ def encode_prompt( return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + @torch.no_grad() + def enhance_prompt( + self, + prompt: str, + system_prompt: str, + max_new_tokens: int = 512, + seed: int = 10, + 
generator: torch.Generator | None = None, + generation_kwargs: dict[str, Any] | None = None, + device: str | torch.device | None = None, + ): + """ + Enhances the supplied `prompt` by generating a new prompt using the current text encoder (default is a + `transformers.Gemma3ForConditionalGeneration` model) from it and a system prompt. + """ + device = device or self._execution_device + if generation_kwargs is None: + # Set to default generation kwargs + generation_kwargs = {"do_sample": True, "temperature": 0.7} + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"user prompt: {prompt}"}, + ] + template = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + model_inputs = self.processor(text=template, images=None, return_tensors="pt").to(device) + self.text_encoder.to(device) + + # `transformers.GenerationMixin.generate` does not support using a `torch.Generator` to control randomness, + # so manually apply a seed for reproducible generation. 
+ if generator is not None: + # Overwrite seed to generator's initial seed + seed = generator.initial_seed() + torch.manual_seed(seed) + generated_sequences = self.text_encoder.generate( + **model_inputs, + max_new_tokens=max_new_tokens, + **generation_kwargs, + ) # tensor of shape [batch_size, seq_len] + + generated_ids = [seq[len(model_inputs.input_ids[i]) :] for i, seq in enumerate(generated_sequences)] + enhanced_prompt = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + return enhanced_prompt + def check_inputs( self, prompt, @@ -504,6 +474,9 @@ def check_inputs( negative_prompt_embeds=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, + spatio_temporal_guidance_blocks=None, + stg_scale=None, + audio_stg_scale=None, ): if height % 32 != 0 or width % 32 != 0: raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") @@ -547,6 +520,12 @@ def check_inputs( f" {negative_prompt_attention_mask.shape}." ) + if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks: + raise ValueError( + "Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of" + "block indices at which to apply STG in `spatio_temporal_guidance_blocks`" + ) + @staticmethod def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. 
@@ -734,7 +713,6 @@ def prepare_audio_latents( latents = self._create_noised_state(latents, noise_scale, generator) return latents.to(device=device, dtype=dtype) - # TODO: confirm whether this logic is correct latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) @@ -749,6 +727,24 @@ def prepare_audio_latents( latents = self._pack_audio_latents(latents) return latents + def convert_velocity_to_x0( + self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None + ) -> torch.Tensor: + if scheduler is None: + scheduler = self.scheduler + + sample_x0 = sample - denoised_output * scheduler.sigmas[step_idx] + return sample_x0 + + def convert_x0_to_velocity( + self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None + ) -> torch.Tensor: + if scheduler is None: + scheduler = self.scheduler + + sample_v = (sample - denoised_output) / scheduler.sigmas[step_idx] + return sample_v + @property def guidance_scale(self): return self._guidance_scale @@ -757,9 +753,41 @@ def guidance_scale(self): def guidance_rescale(self): return self._guidance_rescale + @property + def stg_scale(self): + return self._stg_scale + + @property + def modality_scale(self): + return self._modality_scale + + @property + def audio_guidance_scale(self): + return self._audio_guidance_scale + + @property + def audio_guidance_rescale(self): + return self._audio_guidance_rescale + + @property + def audio_stg_scale(self): + return self._audio_stg_scale + + @property + def audio_modality_scale(self): + return self._audio_modality_scale + @property def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 + return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0) + + @property + def do_spatio_temporal_guidance(self): + return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0) + + @property + 
def do_modality_isolation_guidance(self): + return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0) @property def num_timesteps(self): @@ -791,7 +819,14 @@ def __call__( sigmas: list[float] | None = None, timesteps: list[int] = None, guidance_scale: float = 4.0, + stg_scale: float = 0.0, + modality_scale: float = 1.0, guidance_rescale: float = 0.0, + audio_guidance_scale: float | None = None, + audio_stg_scale: float | None = None, + audio_modality_scale: float | None = None, + audio_guidance_rescale: float | None = None, + spatio_temporal_guidance_blocks: list[int] | None = None, noise_scale: float = 0.0, num_videos_per_prompt: int = 1, generator: torch.Generator | list[torch.Generator] | None = None, @@ -803,6 +838,11 @@ def __call__( negative_prompt_attention_mask: torch.Tensor | None = None, decode_timestep: float | list[float] = 0.0, decode_noise_scale: float | list[float] | None = None, + use_cross_timestep: bool = False, + system_prompt: str | None = None, + prompt_max_new_tokens: int = 512, + prompt_enhancement_kwargs: dict[str, Any] | None = None, + prompt_enhancement_seed: int = 10, output_type: str = "pil", return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, @@ -841,13 +881,47 @@ def __call__( Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. + the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is + a separate value `audio_guidance_scale` for the audio modality). 
+ stg_scale (`float`, *optional*, defaults to `0.0`): + Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for + Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate + where we move the sample away from a weak sample from a perturbed version of the denoising model. + Enabling STG will result in an additional denoising model forward pass; the default value of `0.0` + means that STG is disabled. + modality_scale (`float`, *optional*, defaults to `1.0`): + Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a + weaker sample generated by the denoising model with cross-modality (audio-to-video and video-to-audio) + cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an + additional denoising model forward pass; the default value of `1.0` means that modality guidance is + disabled. guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when - using zero terminal SNR. + using zero terminal SNR. Used for the video modality. + audio_guidance_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for + video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest + that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for + LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value + `guidance_scale`. 
+ audio_stg_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and + audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the + video value `stg_scale`. + audio_modality_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule + is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and + audio. If `None`, defaults to the video value `modality_scale`. + audio_guidance_rescale (`float`, *optional*, defaults to `None`): + A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value + `guidance_rescale`. + spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`): + The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used + (`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0 + and `[28]` is recommended for LTX-2.3. noise_scale (`float`, *optional*, defaults to `0.0`): The interpolation factor between random noise and denoised latents at each timestep. Applying noise to the `latents` and `audio_latents` before continue denoising. @@ -878,6 +952,24 @@ def __call__( The timestep at which generated video is decoded. decode_noise_scale (`float`, defaults to `None`): The interpolation factor between random noise and denoised latents at the decode timestep. + use_cross_timestep (`bool` *optional*, defaults to `False`): + Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when + calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior; + `False` is the legacy LTX-2.0 behavior. 
+ system_prompt (`str`, *optional*, defaults to `None`): + Optional system prompt to use for prompt enhancement. The system prompt will be used by the current + text encoder (by default, a `Gemma3ForConditionalGeneration` model) to generate an enhanced prompt from + the original `prompt` to condition generation. If not supplied, prompt enhancement will not be + performed. + prompt_max_new_tokens (`int`, *optional*, defaults to `512`): + The maximum number of new tokens to generate when performing prompt enhancement. + prompt_enhancement_kwargs (`dict[str, Any]`, *optional*, defaults to `None`): + Keyword arguments for `self.text_encoder.generate`. If not supplied, default arguments of + `do_sample=True` and `temperature=0.7` will be used. See + https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate + for more details. + prompt_enhancement_seed (`int`, *optional*, defaults to `10`): + Random seed for any random operations during prompt enhancement. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -910,6 +1002,11 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + audio_guidance_scale = guidance_scale if audio_guidance_scale is None else audio_guidance_scale + audio_stg_scale = stg_scale if audio_stg_scale is None else audio_stg_scale + audio_modality_scale = modality_scale if audio_modality_scale is None else audio_modality_scale + audio_guidance_rescale = guidance_rescale if audio_guidance_rescale is None else audio_guidance_rescale + # 1. Check inputs. 
Raise error if not correct self.check_inputs( prompt=prompt, @@ -920,10 +1017,21 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + stg_scale=stg_scale, + audio_stg_scale=audio_stg_scale, ) + # Per-modality guidance scales (video, audio) self._guidance_scale = guidance_scale + self._stg_scale = stg_scale + self._modality_scale = modality_scale self._guidance_rescale = guidance_rescale + self._audio_guidance_scale = audio_guidance_scale + self._audio_stg_scale = audio_stg_scale + self._audio_modality_scale = audio_modality_scale + self._audio_guidance_rescale = audio_guidance_rescale + self._attention_kwargs = attention_kwargs self._interrupt = False self._current_timestep = None @@ -939,6 +1047,17 @@ def __call__( device = self._execution_device # 3. Prepare text embeddings + if system_prompt is not None and prompt is not None: + prompt = self.enhance_prompt( + prompt=prompt, + system_prompt=system_prompt, + max_new_tokens=prompt_max_new_tokens, + seed=prompt_enhancement_seed, + generator=generator, + generation_kwargs=prompt_enhancement_kwargs, + device=device, + ) + ( prompt_embeds, prompt_attention_mask, @@ -960,9 +1079,11 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) - additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder + if getattr(self, "tokenizer", None) is not None: + tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left") connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( - prompt_embeds, additive_attention_mask, additive_mask=True + 
prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side ) # 4. Prepare latent variables @@ -984,7 +1105,7 @@ def __call__( raise ValueError( f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." ) - video_sequence_length = latent_num_frames * latent_height * latent_width + # video_sequence_length = latent_num_frames * latent_height * latent_width num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( @@ -1041,7 +1162,7 @@ def __call__( # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas mu = calculate_shift( - video_sequence_length, + self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("base_image_seq_len", 1024), self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("base_shift", 0.95), @@ -1069,11 +1190,6 @@ def __call__( self._num_timesteps = len(timesteps) # 6. 
Prepare micro-conditions - rope_interpolation_scale = ( - self.vae_temporal_compression_ratio / frame_rate, - self.vae_spatial_compression_ratio, - self.vae_spatial_compression_ratio, - ) # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop video_coords = self.transformer.rope.prepare_video_coords( latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate @@ -1111,6 +1227,7 @@ def __call__( encoder_hidden_states=connector_prompt_embeds, audio_encoder_hidden_states=connector_audio_prompt_embeds, timestep=timestep, + sigma=timestep, # Used by LTX-2.3 encoder_attention_mask=connector_attention_mask, audio_encoder_attention_mask=connector_attention_mask, num_frames=latent_num_frames, @@ -1120,7 +1237,10 @@ def __call__( audio_num_frames=audio_num_frames, video_coords=video_coords, audio_coords=audio_coords, - # rope_interpolation_scale=rope_interpolation_scale, + isolate_modalities=False, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, attention_kwargs=attention_kwargs, return_dict=False, ) @@ -1128,24 +1248,155 @@ def __call__( noise_pred_audio = noise_pred_audio.float() if self.do_classifier_free_guidance: - noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) - noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( - noise_pred_video_text - noise_pred_video_uncond + noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2) + noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler) + noise_pred_video_uncond_text = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_text, i, self.scheduler ) + # Use delta formulation as it works more nicely with multiple guidance terms + video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text) - noise_pred_audio_uncond, noise_pred_audio_text = 
noise_pred_audio.chunk(2) - noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( - noise_pred_audio_text - noise_pred_audio_uncond + noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2) + noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler) + noise_pred_audio_uncond_text = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_text, i, audio_scheduler + ) + audio_cfg_delta = (self.audio_guidance_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_text ) - if self.guidance_rescale > 0: - # Based on 3.4. in https://huggingface.co/papers/2305.08891 - noise_pred_video = rescale_noise_cfg( - noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + # Get positive values from merged CFG inputs in case we need to do other DiT forward passes + if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance: + if i == 0: + # Only split values that remain constant throughout the loop once + video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1] + audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1] + prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1] + + video_pos_ids = video_coords.chunk(2, dim=0)[0] + audio_pos_ids = audio_coords.chunk(2, dim=0)[0] + + # Split values that vary each denoising loop iteration + timestep = timestep.chunk(2, dim=0)[0] + else: + video_cfg_delta = audio_cfg_delta = 0 + + video_prompt_embeds = connector_prompt_embeds + audio_prompt_embeds = connector_audio_prompt_embeds + prompt_attn_mask = connector_attention_mask + + video_pos_ids = video_coords + audio_pos_ids = audio_coords + + noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler) + noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler) + + if self.do_spatio_temporal_guidance: + with self.transformer.cache_context("uncond_stg"): + 
noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + isolate_modalities=False, + # Use STG at given blocks to perturb model + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, ) - noise_pred_audio = rescale_noise_cfg( - noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float() + noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float() + noise_pred_video_uncond_stg = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_stg, i, self.scheduler + ) + noise_pred_audio_uncond_stg = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_stg, i, audio_scheduler + ) + + video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg) + audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg) + else: + video_stg_delta = audio_stg_delta = 0 + + if self.do_modality_isolation_guidance: + with self.transformer.cache_context("uncond_modality"): + noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + 
encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + # Turn off A2V and V2A cross attn to isolate video and audio modalities + isolate_modalities=True, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, ) + noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float() + noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float() + noise_pred_video_uncond_modality = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_modality, i, self.scheduler + ) + noise_pred_audio_uncond_modality = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_modality, i, audio_scheduler + ) + + video_modality_delta = (self.modality_scale - 1) * ( + noise_pred_video - noise_pred_video_uncond_modality + ) + audio_modality_delta = (self.audio_modality_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_modality + ) + else: + video_modality_delta = audio_modality_delta = 0 + + # Now apply all guidance terms + noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta + noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta + + # Apply LTX-2.X guidance rescaling + if self.guidance_rescale > 0: + noise_pred_video = rescale_noise_cfg( + noise_pred_video_g, noise_pred_video, guidance_rescale=self.guidance_rescale + ) + else: + noise_pred_video = noise_pred_video_g + + if self.audio_guidance_rescale > 0: + noise_pred_audio = rescale_noise_cfg( + 
noise_pred_audio_g, noise_pred_audio, guidance_rescale=self.audio_guidance_rescale + ) + else: + noise_pred_audio = noise_pred_audio_g + + # Convert back to velocity for scheduler + noise_pred_video = self.convert_x0_to_velocity(latents, noise_pred_video, i, self.scheduler) + noise_pred_audio = self.convert_x0_to_velocity(audio_latents, noise_pred_audio, i, audio_scheduler) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0] @@ -1177,9 +1428,6 @@ def __call__( self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, ) - latents = self._denormalize_latents( - latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor - ) audio_latents = self._denormalize_audio_latents( audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std @@ -1187,6 +1435,9 @@ def __call__( audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) if output_type == "latent": + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) video = latents audio = audio_latents else: @@ -1209,6 +1460,10 @@ def __call__( ] latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(self.vae.dtype) video = self.vae.decode(latents, timestep, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 4c451330f439..a80d011015cf 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -33,7 +33,7 @@ from ..pipeline_utils import DiffusionPipeline 
from .connectors import LTX2TextConnectors from .pipeline_output import LTX2PipelineOutput -from .vocoder import LTX2Vocoder +from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE if is_torch_xla_available(): @@ -254,7 +254,7 @@ def __init__( tokenizer: GemmaTokenizer | GemmaTokenizerFast, connectors: LTX2TextConnectors, transformer: LTX2VideoTransformer3DModel, - vocoder: LTX2Vocoder, + vocoder: LTX2Vocoder | LTX2VocoderWithBWE, ): super().__init__() @@ -300,74 +300,6 @@ def __init__( self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 ) - @staticmethod - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds - def _pack_text_embeds( - text_hidden_states: torch.Tensor, - sequence_lengths: torch.Tensor, - device: str | torch.device, - padding_side: str = "left", - scale_factor: int = 8, - eps: float = 1e-6, - ) -> torch.Tensor: - """ - Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and - per-layer in a masked fashion (only over non-padded positions). - - Args: - text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): - Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). - sequence_lengths (`torch.Tensor of shape `(batch_size,)`): - The number of valid (non-padded) tokens for each batch instance. - device: (`str` or `torch.device`, *optional*): - torch device to place the resulting embeddings on - padding_side: (`str`, *optional*, defaults to `"left"`): - Whether the text tokenizer performs padding on the `"left"` or `"right"`. - scale_factor (`int`, *optional*, defaults to `8`): - Scaling factor to multiply the normalized hidden states by. - eps (`float`, *optional*, defaults to `1e-6`): - A small positive value for numerical stability when performing normalization. 
- - Returns: - `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: - Normed and flattened text encoder hidden states. - """ - batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape - original_dtype = text_hidden_states.dtype - - # Create padding mask - token_indices = torch.arange(seq_len, device=device).unsqueeze(0) - if padding_side == "right": - # For right padding, valid tokens are from 0 to sequence_length-1 - mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] - elif padding_side == "left": - # For left padding, valid tokens are from (T - sequence_length) to T-1 - start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] - mask = token_indices >= start_indices # [B, T] - else: - raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") - mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] - - # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) - masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) - num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) - masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) - - # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) - x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) - x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) - - # Normalization - normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) - normalized_hidden_states = normalized_hidden_states * scale_factor - - # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.flatten(2) - mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) - normalized_hidden_states = 
normalized_hidden_states.masked_fill(~mask_flat, 0.0) - normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) - return normalized_hidden_states - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds def _get_gemma_prompt_embeds( self, @@ -421,16 +353,7 @@ def _get_gemma_prompt_embeds( ) text_encoder_hidden_states = text_encoder_outputs.hidden_states text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) - sequence_lengths = prompt_attention_mask.sum(dim=-1) - - prompt_embeds = self._pack_text_embeds( - text_encoder_hidden_states, - sequence_lengths, - device=device, - padding_side=self.tokenizer.padding_side, - scale_factor=scale_factor, - ) - prompt_embeds = prompt_embeds.to(dtype=dtype) + prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape @@ -541,6 +464,9 @@ def check_inputs( negative_prompt_attention_mask=None, latents=None, audio_latents=None, + spatio_temporal_guidance_blocks=None, + stg_scale=None, + audio_stg_scale=None, ): if height % 32 != 0 or width % 32 != 0: raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") @@ -597,6 +523,12 @@ def check_inputs( f" using the `_unpack_audio_latents` method)." ) + if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks: + raise ValueError( + "Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. 
Please supply a list of " + "block indices at which to apply STG in `spatio_temporal_guidance_blocks`" + ) + @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: @@ -984,6 +916,24 @@ def prepare_audio_latents( latents = self._pack_audio_latents(latents) return latents + def convert_velocity_to_x0( + self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None + ) -> torch.Tensor: + if scheduler is None: + scheduler = self.scheduler + + sample_x0 = sample - denoised_output * scheduler.sigmas[step_idx] + return sample_x0 + + def convert_x0_to_velocity( + self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None + ) -> torch.Tensor: + if scheduler is None: + scheduler = self.scheduler + + sample_v = (sample - denoised_output) / scheduler.sigmas[step_idx] + return sample_v + @property def guidance_scale(self): return self._guidance_scale @@ -992,9 +942,41 @@ def guidance_scale(self): def guidance_rescale(self): return self._guidance_rescale + @property + def stg_scale(self): + return self._stg_scale + + @property + def modality_scale(self): + return self._modality_scale + + @property + def audio_guidance_scale(self): + return self._audio_guidance_scale + + @property + def audio_guidance_rescale(self): + return self._audio_guidance_rescale + + @property + def audio_stg_scale(self): + return self._audio_stg_scale + + @property + def audio_modality_scale(self): + return self._audio_modality_scale + @property def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 + return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0) + + @property + def do_spatio_temporal_guidance(self): + return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0) + + @property + def do_modality_isolation_guidance(self): + return 
(self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0) @property def num_timesteps(self): @@ -1027,7 +1009,14 @@ def __call__( sigmas: list[float] | None = None, timesteps: list[float] | None = None, guidance_scale: float = 4.0, + stg_scale: float = 0.0, + modality_scale: float = 1.0, guidance_rescale: float = 0.0, + audio_guidance_scale: float | None = None, + audio_stg_scale: float | None = None, + audio_modality_scale: float | None = None, + audio_guidance_rescale: float | None = None, + spatio_temporal_guidance_blocks: list[int] | None = None, noise_scale: float | None = None, num_videos_per_prompt: int | None = 1, generator: torch.Generator | list[torch.Generator] | None = None, @@ -1039,6 +1028,7 @@ def __call__( negative_prompt_attention_mask: torch.Tensor | None = None, decode_timestep: float | list[float] = 0.0, decode_noise_scale: float | list[float] | None = None, + use_cross_timestep: bool = False, output_type: str = "pil", return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, @@ -1079,13 +1069,47 @@ def __call__( Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. + the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is + a separate value `audio_guidance_scale` for the audio modality). + stg_scale (`float`, *optional*, defaults to `0.0`): + Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for + Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate + where we move the sample away from a weak sample from a perturbed version of the denoising model. 
+ Enabling STG will result in an additional denoising model forward pass; the default value of `0.0` + means that STG is disabled. + modality_scale (`float`, *optional*, defaults to `1.0`): + Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a + weaker sample generated by the denoising model with cross-modality (audio-to-video and video-to-audio) + cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an + additional denoising model forward pass; the default value of `1.0` means that modality guidance is + disabled. guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when - using zero terminal SNR. + using zero terminal SNR. Used for the video modality. + audio_guidance_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for + video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest + that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for + LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value + `guidance_scale`. + audio_stg_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and + audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the + video value `stg_scale`. 
+ audio_modality_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule + is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and + audio. If `None`, defaults to the video value `modality_scale`. + audio_guidance_rescale (`float`, *optional*, defaults to `None`): + A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value + `guidance_rescale`. + spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`): + The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used + (`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0 + and `[28]` is recommended for LTX-2.3. noise_scale (`float`, *optional*, defaults to `None`): The interpolation factor between random noise and denoised latents at each timestep. Applying noise to the `latents` and `audio_latents` before continue denoising. If not set, will be inferred from the @@ -1117,6 +1141,10 @@ def __call__( The timestep at which generated video is decoded. decode_noise_scale (`float`, defaults to `None`): The interpolation factor between random noise and denoised latents at the decode timestep. + use_cross_timestep (`bool` *optional*, defaults to `False`): + Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when + calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior; + `False` is the legacy LTX-2.0 behavior. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
@@ -1149,6 +1177,11 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + audio_guidance_scale = audio_guidance_scale or guidance_scale + audio_stg_scale = audio_stg_scale or stg_scale + audio_modality_scale = audio_modality_scale or modality_scale + audio_guidance_rescale = audio_guidance_rescale or guidance_rescale + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt=prompt, @@ -1161,10 +1194,21 @@ def __call__( negative_prompt_attention_mask=negative_prompt_attention_mask, latents=latents, audio_latents=audio_latents, + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + stg_scale=stg_scale, + audio_stg_scale=audio_stg_scale, ) + # Per-modality guidance scales (video, audio) self._guidance_scale = guidance_scale + self._stg_scale = stg_scale + self._modality_scale = modality_scale self._guidance_rescale = guidance_rescale + self._audio_guidance_scale = audio_guidance_scale + self._audio_stg_scale = audio_stg_scale + self._audio_modality_scale = audio_modality_scale + self._audio_guidance_rescale = audio_guidance_rescale + self._attention_kwargs = attention_kwargs self._interrupt = False self._current_timestep = None @@ -1208,9 +1252,11 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) - additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder + if getattr(self, "tokenizer", None) is not None: + tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left") connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( - prompt_embeds, additive_attention_mask, additive_mask=True + prompt_embeds, 
prompt_attention_mask, padding_side=tokenizer_padding_side ) # 4. Prepare latent variables @@ -1222,7 +1268,7 @@ def __call__( "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." ) _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] - video_sequence_length = latent_num_frames * latent_height * latent_width + # video_sequence_length = latent_num_frames * latent_height * latent_width num_channels_latents = self.transformer.config.in_channels latents, conditioning_mask, clean_latents = self.prepare_latents( @@ -1272,7 +1318,7 @@ def __call__( # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas mu = calculate_shift( - video_sequence_length, + self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("base_image_seq_len", 1024), self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("base_shift", 0.95), @@ -1301,11 +1347,6 @@ def __call__( self._num_timesteps = len(timesteps) # 6. 
Prepare micro-conditions - rope_interpolation_scale = ( - self.vae_temporal_compression_ratio / frame_rate, - self.vae_spatial_compression_ratio, - self.vae_spatial_compression_ratio, - ) # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop video_coords = self.transformer.rope.prepare_video_coords( latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate @@ -1344,6 +1385,7 @@ def __call__( audio_encoder_hidden_states=connector_audio_prompt_embeds, timestep=video_timestep, audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 encoder_attention_mask=connector_attention_mask, audio_encoder_attention_mask=connector_attention_mask, num_frames=latent_num_frames, @@ -1353,7 +1395,10 @@ def __call__( audio_num_frames=audio_num_frames, video_coords=video_coords, audio_coords=audio_coords, - # rope_interpolation_scale=rope_interpolation_scale, + isolate_modalities=False, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, attention_kwargs=attention_kwargs, return_dict=False, ) @@ -1361,41 +1406,172 @@ def __call__( noise_pred_audio = noise_pred_audio.float() if self.do_classifier_free_guidance: - noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) - noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( - noise_pred_video_text - noise_pred_video_uncond + noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2) + noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler) + noise_pred_video_uncond_text = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_text, i, self.scheduler ) + # Use delta formulation as it works more nicely with multiple guidance terms + video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text) - noise_pred_audio_uncond, noise_pred_audio_text = 
noise_pred_audio.chunk(2) - noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( - noise_pred_audio_text - noise_pred_audio_uncond + noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2) + noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler) + noise_pred_audio_uncond_text = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_text, i, audio_scheduler + ) + audio_cfg_delta = (self.audio_guidance_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_text ) - if self.guidance_rescale > 0: - # Based on 3.4. in https://huggingface.co/papers/2305.08891 - noise_pred_video = rescale_noise_cfg( - noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + # Get positive values from merged CFG inputs in case we need to do other DiT forward passes + if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance: + if i == 0: + # Only split values that remain constant throughout the loop once + video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1] + audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1] + prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1] + + video_pos_ids = video_coords.chunk(2, dim=0)[0] + audio_pos_ids = audio_coords.chunk(2, dim=0)[0] + + # Split values that vary each denoising loop iteration + timestep = timestep.chunk(2, dim=0)[0] + video_timestep = video_timestep.chunk(2, dim=0)[0] + else: + video_cfg_delta = audio_cfg_delta = 0 + + video_prompt_embeds = connector_prompt_embeds + audio_prompt_embeds = connector_audio_prompt_embeds + prompt_attn_mask = connector_attention_mask + + video_pos_ids = video_coords + audio_pos_ids = audio_coords + + noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler) + noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler) + + if self.do_spatio_temporal_guidance: + 
with self.transformer.cache_context("uncond_stg"): + noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + isolate_modalities=False, + # Use STG at given blocks to perturb model + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, ) - noise_pred_audio = rescale_noise_cfg( - noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float() + noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float() + noise_pred_video_uncond_stg = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_stg, i, self.scheduler + ) + noise_pred_audio_uncond_stg = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_stg, i, audio_scheduler + ) + + video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg) + audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg) + else: + video_stg_delta = audio_stg_delta = 0 + + if self.do_modality_isolation_guidance: + with self.transformer.cache_context("uncond_modality"): + noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + 
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + # Turn off A2V and V2A cross attn to isolate video and audio modalities + isolate_modalities=True, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, ) + noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float() + noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float() + noise_pred_video_uncond_modality = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_modality, i, self.scheduler + ) + noise_pred_audio_uncond_modality = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_modality, i, audio_scheduler + ) + + video_modality_delta = (self.modality_scale - 1) * ( + noise_pred_video - noise_pred_video_uncond_modality + ) + audio_modality_delta = (self.audio_modality_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_modality + ) + else: + video_modality_delta = audio_modality_delta = 0 + + # Now apply all guidance terms + noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta + noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta + + # Apply LTX-2.X guidance rescaling + if self.guidance_rescale > 0: + noise_pred_video = rescale_noise_cfg( + noise_pred_video_g, noise_pred_video, guidance_rescale=self.guidance_rescale + ) + else: + noise_pred_video = noise_pred_video_g 
+ + if self.audio_guidance_rescale > 0: + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio_g, noise_pred_audio, guidance_rescale=self.audio_guidance_rescale + ) + else: + noise_pred_audio = noise_pred_audio_g # NOTE: use only the first chunk of conditioning mask in case it is duplicated for CFG bsz = noise_pred_video.size(0) - sigma = self.scheduler.sigmas[i] - # Convert the noise_pred_video velocity model prediction into a sample (x0) prediction - denoised_sample = latents - noise_pred_video * sigma # Apply the (packed) conditioning mask to the denoised (x0) sample and clean conditioning. The # conditioning mask contains conditioning strengths from 0 (always use denoised sample) to 1 (always # use conditions), with intermediate values specifying how strongly to follow the conditions. + # NOTE: this operation should be applied in sample (x0) space and not velocity space (which is the + # space the denoising model outputs are in) denoised_sample_cond = ( - denoised_sample * (1 - conditioning_mask[:bsz]) + clean_latents.float() * conditioning_mask[:bsz] + noise_pred_video * (1 - conditioning_mask[:bsz]) + clean_latents.float() * conditioning_mask[:bsz] ).to(noise_pred_video.dtype) + # Convert the denoised (x0) sample back to a velocity for the scheduler - denoised_latents_cond = ((latents - denoised_sample_cond) / sigma).to(noise_pred_video.dtype) + noise_pred_video = self.convert_x0_to_velocity(latents, denoised_sample_cond, i, self.scheduler) + noise_pred_audio = self.convert_x0_to_velocity(audio_latents, noise_pred_audio, i, audio_scheduler) # Compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(denoised_latents_cond, t, latents, return_dict=False)[0] + latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0] # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in # the step method (such as _step_index) @@ -1425,9 +1601,6 @@ def __call__( 
self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, ) - latents = self._denormalize_latents( - latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor - ) audio_latents = self._denormalize_audio_latents( audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std @@ -1435,6 +1608,9 @@ def __call__( audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) if output_type == "latent": + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) video = latents audio = audio_latents else: @@ -1457,6 +1633,10 @@ def __call__( ] latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(self.vae.dtype) video = self.vae.decode(latents, timestep, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 83ba2cd7c685..997bfd9fc9dc 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -18,7 +18,7 @@ import numpy as np import torch -from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast +from transformers import Gemma3ForConditionalGeneration, Gemma3Processor, GemmaTokenizer, GemmaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput @@ -32,7 +32,7 @@ from ..pipeline_utils import DiffusionPipeline from .connectors import LTX2TextConnectors from .pipeline_output import LTX2PipelineOutput -from .vocoder import LTX2Vocoder +from .vocoder import LTX2Vocoder, 
LTX2VocoderWithBWE if is_torch_xla_available(): @@ -212,7 +212,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL """ model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" - _optional_components = [] + _optional_components = ["processor"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( @@ -224,7 +224,8 @@ def __init__( tokenizer: GemmaTokenizer | GemmaTokenizerFast, connectors: LTX2TextConnectors, transformer: LTX2VideoTransformer3DModel, - vocoder: LTX2Vocoder, + vocoder: LTX2Vocoder | LTX2VocoderWithBWE, + processor: Gemma3Processor | None = None, ): super().__init__() @@ -237,6 +238,7 @@ def __init__( transformer=transformer, vocoder=vocoder, scheduler=scheduler, + processor=processor, ) self.vae_spatial_compression_ratio = ( @@ -271,74 +273,6 @@ def __init__( self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 ) - @staticmethod - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds - def _pack_text_embeds( - text_hidden_states: torch.Tensor, - sequence_lengths: torch.Tensor, - device: str | torch.device, - padding_side: str = "left", - scale_factor: int = 8, - eps: float = 1e-6, - ) -> torch.Tensor: - """ - Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and - per-layer in a masked fashion (only over non-padded positions). - - Args: - text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): - Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). - sequence_lengths (`torch.Tensor of shape `(batch_size,)`): - The number of valid (non-padded) tokens for each batch instance. 
- device: (`str` or `torch.device`, *optional*): - torch device to place the resulting embeddings on - padding_side: (`str`, *optional*, defaults to `"left"`): - Whether the text tokenizer performs padding on the `"left"` or `"right"`. - scale_factor (`int`, *optional*, defaults to `8`): - Scaling factor to multiply the normalized hidden states by. - eps (`float`, *optional*, defaults to `1e-6`): - A small positive value for numerical stability when performing normalization. - - Returns: - `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: - Normed and flattened text encoder hidden states. - """ - batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape - original_dtype = text_hidden_states.dtype - - # Create padding mask - token_indices = torch.arange(seq_len, device=device).unsqueeze(0) - if padding_side == "right": - # For right padding, valid tokens are from 0 to sequence_length-1 - mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] - elif padding_side == "left": - # For left padding, valid tokens are from (T - sequence_length) to T-1 - start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] - mask = token_indices >= start_indices # [B, T] - else: - raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") - mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] - - # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) - masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) - num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) - masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) - - # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) - x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) - x_max = text_hidden_states.masked_fill(~mask, 
float("-inf")).amax(dim=(1, 2), keepdim=True) - - # Normalization - normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) - normalized_hidden_states = normalized_hidden_states * scale_factor - - # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.flatten(2) - mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) - normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) - return normalized_hidden_states - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds def _get_gemma_prompt_embeds( self, @@ -392,16 +326,7 @@ def _get_gemma_prompt_embeds( ) text_encoder_hidden_states = text_encoder_outputs.hidden_states text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) - sequence_lengths = prompt_attention_mask.sum(dim=-1) - - prompt_embeds = self._pack_text_embeds( - text_encoder_hidden_states, - sequence_lengths, - device=device, - padding_side=self.tokenizer.padding_side, - scale_factor=scale_factor, - ) - prompt_embeds = prompt_embeds.to(dtype=dtype) + prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape @@ -500,6 +425,57 @@ def encode_prompt( return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + @torch.no_grad() + def enhance_prompt( + self, + image: PipelineImageInput, + prompt: str, + system_prompt: str, + max_new_tokens: int = 512, + seed: int = 10, + generator: torch.Generator | None = None, + generation_kwargs: dict[str, Any] | None = None, + device: str | torch.device | None = None, + ): + """ + Enhances the supplied `prompt` by generating a new prompt using the current 
text encoder (default is a + `transformers.Gemma3ForConditionalGeneration` model) from it and a system prompt. + """ + device = device or self._execution_device + if generation_kwargs is None: + # Set to default generation kwargs + generation_kwargs = {"do_sample": True, "temperature": 0.7} + + messages = [ + {"role": "system", "content": system_prompt}, + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": f"User Raw Input Prompt: {prompt}."}, + ], + }, + ] + template = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + model_inputs = self.processor(text=template, images=image, return_tensors="pt").to(device) + self.text_encoder.to(device) + + # `transformers.GenerationMixin.generate` does not support using a `torch.Generator` to control randomness, + # so manually apply a seed for reproducible generation. + if generator is not None: + # Overwrite seed to generator's initial seed + seed = generator.initial_seed() + torch.manual_seed(seed) + generated_sequences = self.text_encoder.generate( + **model_inputs, + max_new_tokens=max_new_tokens, + **generation_kwargs, + ) # tensor of shape [batch_size, seq_len] + + generated_ids = [seq[len(model_inputs.input_ids[i]) :] for i, seq in enumerate(generated_sequences)] + enhanced_prompt = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + return enhanced_prompt + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs def check_inputs( self, @@ -511,6 +487,9 @@ def check_inputs( negative_prompt_embeds=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, + spatio_temporal_guidance_blocks=None, + stg_scale=None, + audio_stg_scale=None, ): if height % 32 != 0 or width % 32 != 0: raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") @@ -554,6 +533,12 @@ def check_inputs( f" {negative_prompt_attention_mask.shape}." 
) + if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks: + raise ValueError( + "Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of" + "block indices at which to apply STG in `spatio_temporal_guidance_blocks`" + ) + @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: @@ -788,7 +773,6 @@ def prepare_audio_latents( latents = self._create_noised_state(latents, noise_scale, generator) return latents.to(device=device, dtype=dtype) - # TODO: confirm whether this logic is correct latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) @@ -803,6 +787,24 @@ def prepare_audio_latents( latents = self._pack_audio_latents(latents) return latents + def convert_velocity_to_x0( + self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None + ) -> torch.Tensor: + if scheduler is None: + scheduler = self.scheduler + + sample_x0 = sample - denoised_output * scheduler.sigmas[step_idx] + return sample_x0 + + def convert_x0_to_velocity( + self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None + ) -> torch.Tensor: + if scheduler is None: + scheduler = self.scheduler + + sample_v = (sample - denoised_output) / scheduler.sigmas[step_idx] + return sample_v + @property def guidance_scale(self): return self._guidance_scale @@ -811,9 +813,41 @@ def guidance_scale(self): def guidance_rescale(self): return self._guidance_rescale + @property + def stg_scale(self): + return self._stg_scale + + @property + def modality_scale(self): + return self._modality_scale + + @property + def audio_guidance_scale(self): + return self._audio_guidance_scale + + @property + def 
audio_guidance_rescale(self): + return self._audio_guidance_rescale + + @property + def audio_stg_scale(self): + return self._audio_stg_scale + + @property + def audio_modality_scale(self): + return self._audio_modality_scale + @property def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 + return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0) + + @property + def do_spatio_temporal_guidance(self): + return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0) + + @property + def do_modality_isolation_guidance(self): + return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0) @property def num_timesteps(self): @@ -846,7 +880,14 @@ def __call__( sigmas: list[float] | None = None, timesteps: list[int] | None = None, guidance_scale: float = 4.0, + stg_scale: float = 0.0, + modality_scale: float = 1.0, guidance_rescale: float = 0.0, + audio_guidance_scale: float | None = None, + audio_stg_scale: float | None = None, + audio_modality_scale: float | None = None, + audio_guidance_rescale: float | None = None, + spatio_temporal_guidance_blocks: list[int] | None = None, noise_scale: float = 0.0, num_videos_per_prompt: int = 1, generator: torch.Generator | list[torch.Generator] | None = None, @@ -858,6 +899,11 @@ def __call__( negative_prompt_attention_mask: torch.Tensor | None = None, decode_timestep: float | list[float] = 0.0, decode_noise_scale: float | list[float] | None = None, + use_cross_timestep: bool = False, + system_prompt: str | None = None, + prompt_max_new_tokens: int = 512, + prompt_enhancement_kwargs: dict[str, Any] | None = None, + prompt_enhancement_seed: int = 10, output_type: str = "pil", return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, @@ -898,13 +944,47 @@ def __call__( Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). 
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. + the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is + a separate value `audio_guidance_scale` for the audio modality). + stg_scale (`float`, *optional*, defaults to `0.0`): + Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for + Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate + where we move the sample away from a weak sample from a perturbed version of the denoising model. + Enabling STG will result in an additional denoising model forward pass; the default value of `0.0` + means that STG is disabled. + modality_scale (`float`, *optional*, defaults to `1.0`): + Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a + weaker sample generated by the denoising model withy cross-modality (audio-to-video and video-to-audio) + cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an + additional denoising model forward pass; the default value of `1.0` means that modality guidance is + disabled. guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when - using zero terminal SNR. + using zero terminal SNR. Used for the video modality. + audio_guidance_scale (`float`, *optional* defaults to `None`): + Audio guidance scale for CFG with respect to the negative prompt. 
The CFG update rule is the same for + video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest + that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for + LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value + `guidance_scale`. + audio_stg_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and + audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the + video value `stg_scale`. + audio_modality_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule + is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and + audio. If `None`, defaults to the video value `modality_scale`. + audio_guidance_rescale (`float`, *optional*, defaults to `None`): + A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value + `guidance_rescale`. + spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`): + The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used + (`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0 + and `[28]` is recommended for LTX-2.3. noise_scale (`float`, *optional*, defaults to `0.0`): The interpolation factor between random noise and denoised latents at each timestep. Applying noise to the `latents` and `audio_latents` before continue denoising. @@ -935,6 +1015,24 @@ def __call__( The timestep at which generated video is decoded. decode_noise_scale (`float`, defaults to `None`): The interpolation factor between random noise and denoised latents at the decode timestep. 
+ use_cross_timestep (`bool` *optional*, defaults to `False`): + Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when + calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior; + `False` is the legacy LTX-2.0 behavior. + system_prompt (`str`, *optional*, defaults to `None`): + Optional system prompt to use for prompt enhancement. The system prompt will be used by the current + text encoder (by default, a `Gemma3ForConditionalGeneration` model) to generate an enhanced prompt from + the original `prompt` to condition generation. If not supplied, prompt enhancement will not be + performed. + prompt_max_new_tokens (`int`, *optional*, defaults to `512`): + The maximum number of new tokens to generate when performing prompt enhancement. + prompt_enhancement_kwargs (`dict[str, Any]`, *optional*, defaults to `None`): + Keyword arguments for `self.text_encoder.generate`. If not supplied, default arguments of + `do_sample=True` and `temperature=0.7` will be used. See + https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate + for more details. + prompt_enhancement_seed (`int`, *optional*, default to `10`): + Random seed for any random operations during prompt enhancement. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -967,6 +1065,11 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + audio_guidance_scale = audio_guidance_scale or guidance_scale + audio_stg_scale = audio_stg_scale or stg_scale + audio_modality_scale = audio_modality_scale or modality_scale + audio_guidance_rescale = audio_guidance_rescale or guidance_rescale + # 1. Check inputs. 
Raise error if not correct self.check_inputs( prompt=prompt, @@ -977,10 +1080,21 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + stg_scale=stg_scale, + audio_stg_scale=audio_stg_scale, ) + # Per-modality guidance scales (video, audio) self._guidance_scale = guidance_scale + self._stg_scale = stg_scale + self._modality_scale = modality_scale self._guidance_rescale = guidance_rescale + self._audio_guidance_scale = audio_guidance_scale + self._audio_stg_scale = audio_stg_scale + self._audio_modality_scale = audio_modality_scale + self._audio_guidance_rescale = audio_guidance_rescale + self._attention_kwargs = attention_kwargs self._interrupt = False self._current_timestep = None @@ -996,6 +1110,18 @@ def __call__( device = self._execution_device # 3. Prepare text embeddings + if system_prompt is not None and prompt is not None: + prompt = self.enhance_prompt( + image=image, + prompt=prompt, + system_prompt=system_prompt, + max_new_tokens=prompt_max_new_tokens, + seed=prompt_enhancement_seed, + generator=generator, + generation_kwargs=prompt_enhancement_kwargs, + device=device, + ) + ( prompt_embeds, prompt_attention_mask, @@ -1017,9 +1143,11 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) - additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder + if getattr(self, "tokenizer", None) is not None: + tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left") connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( - prompt_embeds, additive_attention_mask, 
additive_mask=True + prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side ) # 4. Prepare latent variables @@ -1041,7 +1169,7 @@ def __call__( raise ValueError( f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." ) - video_sequence_length = latent_num_frames * latent_height * latent_width + # video_sequence_length = latent_num_frames * latent_height * latent_width if latents is None: image = self.video_processor.preprocess(image, height=height, width=width) @@ -1105,7 +1233,7 @@ def __call__( # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas mu = calculate_shift( - video_sequence_length, + self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("base_image_seq_len", 1024), self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("base_shift", 0.95), @@ -1134,11 +1262,6 @@ def __call__( self._num_timesteps = len(timesteps) # 6. 
Prepare micro-conditions - rope_interpolation_scale = ( - self.vae_temporal_compression_ratio / frame_rate, - self.vae_spatial_compression_ratio, - self.vae_spatial_compression_ratio, - ) # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop video_coords = self.transformer.rope.prepare_video_coords( latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate @@ -1177,6 +1300,7 @@ def __call__( audio_encoder_hidden_states=connector_audio_prompt_embeds, timestep=video_timestep, audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 encoder_attention_mask=connector_attention_mask, audio_encoder_attention_mask=connector_attention_mask, num_frames=latent_num_frames, @@ -1186,7 +1310,10 @@ def __call__( audio_num_frames=audio_num_frames, video_coords=video_coords, audio_coords=audio_coords, - # rope_interpolation_scale=rope_interpolation_scale, + isolate_modalities=False, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, attention_kwargs=attention_kwargs, return_dict=False, ) @@ -1194,24 +1321,154 @@ def __call__( noise_pred_audio = noise_pred_audio.float() if self.do_classifier_free_guidance: - noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) - noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( - noise_pred_video_text - noise_pred_video_uncond + noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2) + noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler) + noise_pred_video_uncond_text = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_text, i, self.scheduler ) + # Use delta formulation as it works more nicely with multiple guidance terms + video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text) - noise_pred_audio_uncond, noise_pred_audio_text = 
noise_pred_audio.chunk(2) - noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( - noise_pred_audio_text - noise_pred_audio_uncond + noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2) + noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler) + noise_pred_audio_uncond_text = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_text, i, audio_scheduler + ) + audio_cfg_delta = (self.audio_guidance_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_text ) - if self.guidance_rescale > 0: - # Based on 3.4. in https://huggingface.co/papers/2305.08891 - noise_pred_video = rescale_noise_cfg( - noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + # Get positive values from merged CFG inputs in case we need to do other DiT forward passes + if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance: + if i == 0: + # Only split values that remain constant throughout the loop once + video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1] + audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1] + prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1] + + video_pos_ids = video_coords.chunk(2, dim=0)[0] + audio_pos_ids = audio_coords.chunk(2, dim=0)[0] + + # Split values that vary each denoising loop iteration + timestep = timestep.chunk(2, dim=0)[0] + video_timestep = video_timestep.chunk(2, dim=0)[0] + else: + video_cfg_delta = audio_cfg_delta = 0 + + video_prompt_embeds = connector_prompt_embeds + audio_prompt_embeds = connector_audio_prompt_embeds + prompt_attn_mask = connector_attention_mask + + video_pos_ids = video_coords + audio_pos_ids = audio_coords + + noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler) + noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler) + + if self.do_spatio_temporal_guidance: + 
with self.transformer.cache_context("uncond_stg"): + noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + isolate_modalities=False, + # Use STG at given blocks to perturb model + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, ) - noise_pred_audio = rescale_noise_cfg( - noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float() + noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float() + noise_pred_video_uncond_stg = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_stg, i, self.scheduler + ) + noise_pred_audio_uncond_stg = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_stg, i, audio_scheduler + ) + + video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg) + audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg) + else: + video_stg_delta = audio_stg_delta = 0 + + if self.do_modality_isolation_guidance: + with self.transformer.cache_context("uncond_modality"): + noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + 
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + # Turn off A2V and V2A cross attn to isolate video and audio modalities + isolate_modalities=True, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, ) + noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float() + noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float() + noise_pred_video_uncond_modality = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_modality, i, self.scheduler + ) + noise_pred_audio_uncond_modality = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_modality, i, audio_scheduler + ) + + video_modality_delta = (self.modality_scale - 1) * ( + noise_pred_video - noise_pred_video_uncond_modality + ) + audio_modality_delta = (self.audio_modality_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_modality + ) + else: + video_modality_delta = audio_modality_delta = 0 + + # Now apply all guidance terms + noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta + noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta + + # Apply LTX-2.X guidance rescaling + if self.guidance_rescale > 0: + noise_pred_video = rescale_noise_cfg( + noise_pred_video_g, noise_pred_video, guidance_rescale=self.guidance_rescale + ) + else: + noise_pred_video = noise_pred_video_g 
+ + if self.audio_guidance_rescale > 0: + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio_g, noise_pred_audio, guidance_rescale=self.audio_guidance_rescale + ) + else: + noise_pred_audio = noise_pred_audio_g # compute the previous noisy sample x_t -> x_t-1 noise_pred_video = self._unpack_latents( @@ -1231,6 +1488,10 @@ def __call__( self.transformer_temporal_patch_size, ) + # Convert back to velocity for scheduler + noise_pred_video = self.convert_x0_to_velocity(latents, noise_pred_video, i, self.scheduler) + noise_pred_audio = self.convert_x0_to_velocity(audio_latents, noise_pred_audio, i, audio_scheduler) + noise_pred_video = noise_pred_video[:, :, 1:] noise_latents = latents[:, :, 1:] pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0] @@ -1268,9 +1529,6 @@ def __call__( self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, ) - latents = self._denormalize_latents( - latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor - ) audio_latents = self._denormalize_audio_latents( audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std @@ -1278,6 +1536,9 @@ def __call__( audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) if output_type == "latent": + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) video = latents audio = audio_latents else: @@ -1300,6 +1561,10 @@ def __call__( ] latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(self.vae.dtype) video = self.vae.decode(latents, timestep, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) diff --git a/src/diffusers/pipelines/ltx2/vocoder.py 
b/src/diffusers/pipelines/ltx2/vocoder.py index 551c3ac5980f..f0004f2ec02d 100644 --- a/src/diffusers/pipelines/ltx2/vocoder.py +++ b/src/diffusers/pipelines/ltx2/vocoder.py @@ -8,6 +8,209 @@ from ...models.modeling_utils import ModelMixin +def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor: + """ + Creates a Kaiser sinc kernel for low-pass filtering. + + Args: + cutoff (`float`): + Normalized frequency cutoff (relative to the sampling rate). Must be between 0 and 0.5 (the Nyquist + frequency). + half_width (`float`): + Used to determine the Kaiser window's beta parameter. + kernel_size: + Size of the Kaiser window (and ultimately the Kaiser sinc kernel). + + Returns: + `torch.Tensor` of shape `(kernel_size,)`: + The Kaiser sinc kernel. + """ + delta_f = 4 * half_width + half_size = kernel_size // 2 + amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if amplitude > 50.0: + beta = 0.1102 * (amplitude - 8.7) + elif amplitude >= 21.0: + beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0) + else: + beta = 0.0 + + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + even = kernel_size % 2 == 0 + time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size + + if cutoff == 0.0: + filter = torch.zeros_like(time) + else: + time = 2 * cutoff * time + sinc = torch.where( + time == 0, + torch.ones_like(time), + torch.sin(math.pi * time) / math.pi / time, + ) + filter = 2 * cutoff * window * sinc + filter = filter / filter.sum() + return filter + + +class DownSample1d(nn.Module): + """1D low-pass filter for antialias downsampling.""" + + def __init__( + self, + ratio: int = 2, + kernel_size: int | None = None, + use_padding: bool = True, + padding_mode: str = "replicate", + persistent: bool = True, + ): + super().__init__() + self.ratio = ratio + self.kernel_size = kernel_size or int(6 * ratio // 2) * 2 + self.pad_left = self.kernel_size // 2 + 
(self.kernel_size % 2) - 1 + self.pad_right = self.kernel_size // 2 + self.use_padding = use_padding + self.padding_mode = padding_mode + + cutoff = 0.5 / ratio + half_width = 0.6 / ratio + low_pass_filter = kaiser_sinc_filter1d(cutoff, half_width, self.kernel_size) + self.register_buffer("filter", low_pass_filter.view(1, 1, self.kernel_size), persistent=persistent) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x expected shape: [batch_size, num_channels, hidden_dim] + num_channels = x.shape[1] + if self.use_padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + x_filtered = F.conv1d(x, self.filter.expand(num_channels, -1, -1), stride=self.ratio, groups=num_channels) + return x_filtered + + +class UpSample1d(nn.Module): + def __init__( + self, + ratio: int = 2, + kernel_size: int | None = None, + window_type: str = "kaiser", + padding_mode: str = "replicate", + persistent: bool = True, + ): + super().__init__() + self.ratio = ratio + self.padding_mode = padding_mode + + if window_type == "hann": + rolloff = 0.99 + lowpass_filter_width = 6 + width = math.ceil(lowpass_filter_width / rolloff) + self.kernel_size = 2 * width * ratio + 1 + self.pad = width + self.pad_left = 2 * width * ratio + self.pad_right = self.kernel_size - ratio + + time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff + time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width) + window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2 + sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1) + else: + # Kaiser sinc filter is BigVGAN default + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.ratio + (self.kernel_size - self.ratio) // 2 + self.pad_right = self.pad * self.ratio + (self.kernel_size - self.ratio + 1) // 2 + + sinc_filter = kaiser_sinc_filter1d( + cutoff=0.5 / 
ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size, + ) + + self.register_buffer("filter", sinc_filter.view(1, 1, self.kernel_size), persistent=persistent) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x expected shape: [batch_size, num_channels, hidden_dim] + num_channels = x.shape[1] + x = F.pad(x, (self.pad, self.pad), mode=self.padding_mode) + low_pass_filter = self.filter.to(dtype=x.dtype, device=x.device).expand(num_channels, -1, -1) + x = self.ratio * F.conv_transpose1d(x, low_pass_filter, stride=self.ratio, groups=num_channels) + return x[..., self.pad_left : -self.pad_right] + + +class AntiAliasAct1d(nn.Module): + """ + Antialiasing activation for a 1D signal: upsamples, applies an activation (usually snakebeta), and then downsamples + to avoid aliasing. + """ + + def __init__( + self, + act_fn: str | nn.Module, + ratio: int = 2, + kernel_size: int = 12, + **kwargs, + ): + super().__init__() + self.upsample = UpSample1d(ratio=ratio, kernel_size=kernel_size) + if isinstance(act_fn, str): + if act_fn == "snakebeta": + act_fn = SnakeBeta(**kwargs) + elif act_fn == "snake": + act_fn = SnakeBeta(**kwargs) + else: + act_fn = nn.LeakyReLU(**kwargs) + self.act = act_fn + self.downsample = DownSample1d(ratio=ratio, kernel_size=kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + return x + + +class SnakeBeta(nn.Module): + """ + Implements the Snake and SnakeBeta activations, which help with learning periodic patterns. 
+ """ + + def __init__( + self, + channels: int, + alpha: float = 1.0, + eps: float = 1e-9, + trainable_params: bool = True, + logscale: bool = True, + use_beta: bool = True, + ): + super().__init__() + self.eps = eps + self.logscale = logscale + self.use_beta = use_beta + + self.alpha = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha) + self.alpha.requires_grad = trainable_params + if use_beta: + self.beta = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha) + self.beta.requires_grad = trainable_params + + def forward(self, hidden_states: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + broadcast_shape = [1] * hidden_states.ndim + broadcast_shape[channel_dim] = -1 + alpha = self.alpha.view(broadcast_shape) + if self.use_beta: + beta = self.beta.view(broadcast_shape) + + if self.logscale: + alpha = torch.exp(alpha) + if self.use_beta: + beta = torch.exp(beta) + + amplitude = beta if self.use_beta else alpha + hidden_states = hidden_states + (1.0 / (amplitude + self.eps)) * torch.sin(hidden_states * alpha).pow(2) + return hidden_states + + class ResBlock(nn.Module): def __init__( self, @@ -15,12 +218,15 @@ def __init__( kernel_size: int = 3, stride: int = 1, dilations: tuple[int, ...] 
= (1, 3, 5), + act_fn: str = "leaky_relu", leaky_relu_negative_slope: float = 0.1, + antialias: bool = False, + antialias_ratio: int = 2, + antialias_kernel_size: int = 12, padding_mode: str = "same", ): super().__init__() self.dilations = dilations - self.negative_slope = leaky_relu_negative_slope self.convs1 = nn.ModuleList( [ @@ -28,6 +234,18 @@ def __init__( for dilation in dilations ] ) + self.acts1 = nn.ModuleList() + for _ in range(len(self.convs1)): + if act_fn == "snakebeta": + act = SnakeBeta(channels, use_beta=True) + elif act_fn == "snake": + act = SnakeBeta(channels, use_beta=False) + else: + act = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope) + + if antialias: + act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size) + self.acts1.append(act) self.convs2 = nn.ModuleList( [ @@ -35,12 +253,24 @@ def __init__( for _ in range(len(dilations)) ] ) + self.acts2 = nn.ModuleList() + for _ in range(len(self.convs2)): + if act_fn == "snakebeta": + act = SnakeBeta(channels, use_beta=True) + elif act_fn == "snake": + act = SnakeBeta(channels, use_beta=False) + else: + act_fn = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope) + + if antialias: + act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size) + self.acts2.append(act) def forward(self, x: torch.Tensor) -> torch.Tensor: - for conv1, conv2 in zip(self.convs1, self.convs2): - xt = F.leaky_relu(x, negative_slope=self.negative_slope) + for act1, conv1, act2, conv2 in zip(self.acts1, self.convs1, self.acts2, self.convs2): + xt = act1(x) xt = conv1(xt) - xt = F.leaky_relu(xt, negative_slope=self.negative_slope) + xt = act2(xt) xt = conv2(xt) x = x + xt return x @@ -61,7 +291,13 @@ def __init__( upsample_factors: list[int] = [6, 5, 2, 2, 2], resnet_kernel_sizes: list[int] = [3, 7, 11], resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + act_fn: str = "leaky_relu", leaky_relu_negative_slope: float = 0.1, + antialias: bool = 
False, + antialias_ratio: int = 2, + antialias_kernel_size: int = 12, + final_act_fn: str | None = "tanh", # tanh, clamp, None + final_bias: bool = True, output_sampling_rate: int = 24000, ): super().__init__() @@ -69,7 +305,9 @@ def __init__( self.resnets_per_upsample = len(resnet_kernel_sizes) self.out_channels = out_channels self.total_upsample_factor = math.prod(upsample_factors) + self.act_fn = act_fn self.negative_slope = leaky_relu_negative_slope + self.final_act_fn = final_act_fn if self.num_upsample_layers != len(upsample_factors): raise ValueError( @@ -83,6 +321,13 @@ def __init__( f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively." ) + supported_act_fns = ["snakebeta", "snake", "leaky_relu"] + if self.act_fn not in supported_act_fns: + raise ValueError( + f"Unsupported activation function: {self.act_fn}. Currently supported values of `act_fn` are " + f"{supported_act_fns}." + ) + self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3) self.upsamplers = nn.ModuleList() @@ -103,15 +348,27 @@ def __init__( for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations): self.resnets.append( ResBlock( - output_channels, - kernel_size, + channels=output_channels, + kernel_size=kernel_size, dilations=dilations, + act_fn=act_fn, leaky_relu_negative_slope=leaky_relu_negative_slope, + antialias=antialias, + antialias_ratio=antialias_ratio, + antialias_kernel_size=antialias_kernel_size, ) ) input_channels = output_channels - self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3) + if act_fn == "snakebeta" or act_fn == "snake": + # Always use antialiasing + act_out = SnakeBeta(channels=output_channels, use_beta=True) + self.act_out = AntiAliasAct1d(act_out, ratio=antialias_ratio, kernel_size=antialias_kernel_size) + elif act_fn == "leaky_relu": + # NOTE: does NOT use self.negative_slope, following the original code + self.act_out = nn.LeakyReLU() + + self.conv_out = 
nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3, bias=final_bias) def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor: r""" @@ -139,7 +396,9 @@ def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch hidden_states = self.conv_in(hidden_states) for i in range(self.num_upsample_layers): - hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope) + if self.act_fn == "leaky_relu": + # Other activations are inside each upsampling block + hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope) hidden_states = self.upsamplers[i](hidden_states) # Run all resnets in parallel on hidden_states @@ -149,10 +408,190 @@ def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch hidden_states = torch.mean(resnet_outputs, dim=0) - # NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of - # 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended - hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01) + hidden_states = self.act_out(hidden_states) hidden_states = self.conv_out(hidden_states) - hidden_states = torch.tanh(hidden_states) + if self.final_act_fn == "tanh": + hidden_states = torch.tanh(hidden_states) + elif self.final_act_fn == "clamp": + hidden_states = torch.clamp(hidden_states, -1, 1) return hidden_states + + +class CausalSTFT(nn.Module): + """ + Performs a causal short-time Fourier transform (STFT) using causal Hann windows on a waveform. The DFT bases + multiplied by the Hann windows are pre-calculated and stored as buffers. For exact parity with training, the exact + buffers should be loaded from the checkpoint in bfloat16. 
+ """ + + def __init__(self, filter_length: int = 512, hop_length: int = 80, window_length: int = 512): + super().__init__() + self.hop_length = hop_length + self.window_length = window_length + n_freqs = filter_length // 2 + 1 + + self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True) + self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True) + + def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if waveform.ndim == 2: + waveform = waveform.unsqueeze(1) # [B, num_channels, num_samples] + + left_pad = max(0, self.window_length - self.hop_length) # causal: left-only + waveform = F.pad(waveform, (left_pad, 0)) + + spec = F.conv1d(waveform, self.forward_basis, stride=self.hop_length, padding=0) + n_freqs = spec.shape[1] // 2 + real, imag = spec[:, :n_freqs], spec[:, n_freqs:] + magnitude = torch.sqrt(real**2 + imag**2) + phase = torch.atan2(imag.float(), real.float()).to(dtype=real.dtype) + return magnitude, phase + + +class MelSTFT(nn.Module): + """ + Calculates a causal log-mel spectrogram from a waveform. Uses a pre-calculated mel filterbank, which should be + loaded from the checkpoint in bfloat16. 
+ """ + + def __init__( + self, + filter_length: int = 512, + hop_length: int = 80, + window_length: int = 512, + num_mel_channels: int = 64, + ): + super().__init__() + self.stft_fn = CausalSTFT(filter_length, hop_length, window_length) + + num_freqs = filter_length // 2 + 1 + self.register_buffer("mel_basis", torch.zeros(num_mel_channels, num_freqs), persistent=True) + + def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + magnitude, phase = self.stft_fn(waveform) + energy = torch.norm(magnitude, dim=1) + mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude) + log_mel = torch.log(torch.clamp(mel, min=1e-5)) + return log_mel, magnitude, phase, energy + + +class LTX2VocoderWithBWE(ModelMixin, ConfigMixin): + """ + LTX-2.X vocoder with bandwidth extension (BWE) upsampling. The vocoder and the BWE module run in sequence, with the + BWE module upsampling the vocoder output waveform to a higher sampling rate. The BWE module itself has the same + architecture as the original vocoder. 
+ """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + hidden_channels: int = 1536, + out_channels: int = 2, + upsample_kernel_sizes: list[int] = [11, 4, 4, 4, 4, 4], + upsample_factors: list[int] = [5, 2, 2, 2, 2, 2], + resnet_kernel_sizes: list[int] = [3, 7, 11], + resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + act_fn: str = "snakebeta", + leaky_relu_negative_slope: float = 0.1, + antialias: bool = True, + antialias_ratio: int = 2, + antialias_kernel_size: int = 12, + final_act_fn: str | None = None, + final_bias: bool = False, + bwe_in_channels: int = 128, + bwe_hidden_channels: int = 512, + bwe_out_channels: int = 2, + bwe_upsample_kernel_sizes: list[int] = [12, 11, 4, 4, 4], + bwe_upsample_factors: list[int] = [6, 5, 2, 2, 2], + bwe_resnet_kernel_sizes: list[int] = [3, 7, 11], + bwe_resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + bwe_act_fn: str = "snakebeta", + bwe_leaky_relu_negative_slope: float = 0.1, + bwe_antialias: bool = True, + bwe_antialias_ratio: int = 2, + bwe_antialias_kernel_size: int = 12, + bwe_final_act_fn: str | None = None, + bwe_final_bias: bool = False, + filter_length: int = 512, + hop_length: int = 80, + window_length: int = 512, + num_mel_channels: int = 64, + input_sampling_rate: int = 16000, + output_sampling_rate: int = 48000, + ): + super().__init__() + + self.vocoder = LTX2Vocoder( + in_channels=in_channels, + hidden_channels=hidden_channels, + out_channels=out_channels, + upsample_kernel_sizes=upsample_kernel_sizes, + upsample_factors=upsample_factors, + resnet_kernel_sizes=resnet_kernel_sizes, + resnet_dilations=resnet_dilations, + act_fn=act_fn, + leaky_relu_negative_slope=leaky_relu_negative_slope, + antialias=antialias, + antialias_ratio=antialias_ratio, + antialias_kernel_size=antialias_kernel_size, + final_act_fn=final_act_fn, + final_bias=final_bias, + output_sampling_rate=input_sampling_rate, + ) + self.bwe_generator = LTX2Vocoder( + 
in_channels=bwe_in_channels, + hidden_channels=bwe_hidden_channels, + out_channels=bwe_out_channels, + upsample_kernel_sizes=bwe_upsample_kernel_sizes, + upsample_factors=bwe_upsample_factors, + resnet_kernel_sizes=bwe_resnet_kernel_sizes, + resnet_dilations=bwe_resnet_dilations, + act_fn=bwe_act_fn, + leaky_relu_negative_slope=bwe_leaky_relu_negative_slope, + antialias=bwe_antialias, + antialias_ratio=bwe_antialias_ratio, + antialias_kernel_size=bwe_antialias_kernel_size, + final_act_fn=bwe_final_act_fn, + final_bias=bwe_final_bias, + output_sampling_rate=output_sampling_rate, + ) + + self.mel_stft = MelSTFT( + filter_length=filter_length, + hop_length=hop_length, + window_length=window_length, + num_mel_channels=num_mel_channels, + ) + + self.resampler = UpSample1d( + ratio=output_sampling_rate // input_sampling_rate, + window_type="hann", + persistent=False, + ) + + def forward(self, mel_spec: torch.Tensor) -> torch.Tensor: + # 1. Run stage 1 vocoder to get low sampling rate waveform + x = self.vocoder(mel_spec) + batch_size, num_channels, num_samples = x.shape + + # Pad to exact multiple of hop_length for exact mel frame count + remainder = num_samples % self.config.hop_length + if remainder != 0: + x = F.pad(x, (0, self.hop_length - remainder)) + + # 2. Compute mel spectrogram on vocoder output + mel, _, _, _ = self.mel_stft(x.flatten(0, 1)) + mel = mel.unflatten(0, (-1, num_channels)) + + # 3. Run bandwidth extender (BWE) on new mel spectrogram + mel_for_bwe = mel.transpose(2, 3) # [B, C, num_mel_bins, num_frames] --> [B, C, num_frames, num_mel_bins] + residual = self.bwe_generator(mel_for_bwe) + + # 4. 
Residual connection with resampler + skip = self.resampler(x) + waveform = torch.clamp(residual + skip, -1, 1) + output_samples = num_samples * self.config.output_sampling_rate // self.config.input_sampling_rate + waveform = waveform[..., :output_samples] + return waveform diff --git a/tests/pipelines/ltx2/test_ltx2.py b/tests/pipelines/ltx2/test_ltx2.py index 7d1a3bfc9987..0941ae550989 100644 --- a/tests/pipelines/ltx2/test_ltx2.py +++ b/tests/pipelines/ltx2/test_ltx2.py @@ -171,6 +171,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "connectors": connectors, "vocoder": vocoder, + "processor": None, } return components diff --git a/tests/pipelines/ltx2/test_ltx2_image2video.py b/tests/pipelines/ltx2/test_ltx2_image2video.py index 92c000c7bf7c..a0e4cb803084 100644 --- a/tests/pipelines/ltx2/test_ltx2_image2video.py +++ b/tests/pipelines/ltx2/test_ltx2_image2video.py @@ -171,6 +171,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "connectors": connectors, "vocoder": vocoder, + "processor": None, } return components From be2424ba45cd8c51befd0bd0399aee2d31ce1b40 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 19 Mar 2026 18:07:41 -1000 Subject: [PATCH 057/215] [agents]support skills (#13269) * support skills * update * Apply suggestions from code review Co-authored-by: Sayak Paul Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * update based on new best practice * Update .ai/skills/parity-testing/pitfalls.md Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * update --------- Co-authored-by: yiyi@huggingface.co Co-authored-by: Sayak Paul Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: yiyi@huggingface.co Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- .ai/AGENTS.md | 54 +----- .ai/skills/model-integration/SKILL.md | 167 +++++++++++++++++ .../model-integration/modular-conversion.md | 152 ++++++++++++++++ .ai/skills/parity-testing/SKILL.md | 170
++++++++++++++++++ .../parity-testing/checkpoint-mechanism.md | 103 +++++++++++ .ai/skills/parity-testing/pitfalls.md | 116 ++++++++++++ .gitignore | 4 +- Makefile | 7 + docs/source/en/conceptual/contribution.md | 10 +- 9 files changed, 728 insertions(+), 55 deletions(-) create mode 100644 .ai/skills/model-integration/SKILL.md create mode 100644 .ai/skills/model-integration/modular-conversion.md create mode 100644 .ai/skills/parity-testing/SKILL.md create mode 100644 .ai/skills/parity-testing/checkpoint-mechanism.md create mode 100644 .ai/skills/parity-testing/pitfalls.md diff --git a/.ai/AGENTS.md b/.ai/AGENTS.md index 9e93ae79df92..a42ca5ede871 100644 --- a/.ai/AGENTS.md +++ b/.ai/AGENTS.md @@ -24,54 +24,10 @@ Strive to write code as simple and explicit as possible. ### Models - All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls. -- Try to not introduce graph breaks as much as possible for better compatibility with `torch.compile`. For example, DO NOT arbitrarily insert operations from NumPy in the forward implementations. -- Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`. +- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations, and avoid any other patterns that break `torch.compile` with `fullgraph=True`. +- See the **model-integration** skill for the attention pattern, pipeline rules, test setup instructions, and other important details. 
-```python -# transformer_mymodel.py +## Skills -class MyModelAttnProcessor: - _attention_backend = None - _parallel_config = None - - def __call__(self, attn, hidden_states, attention_mask=None, ...): - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - # reshape, apply rope, etc. - hidden_states = dispatch_attention_fn( - query, key, value, - attn_mask=attention_mask, - backend=self._attention_backend, - parallel_config=self._parallel_config, - ) - hidden_states = hidden_states.flatten(2, 3) - return attn.to_out[0](hidden_states) - - -class MyModelAttention(nn.Module, AttentionModuleMixin): - _default_processor_cls = MyModelAttnProcessor - _available_processors = [MyModelAttnProcessor] - - def __init__(self, query_dim, heads=8, dim_head=64, ...): - super().__init__() - self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False) - self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False) - self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False) - self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)]) - self.set_processor(MyModelAttnProcessor()) - - def forward(self, hidden_states, attention_mask=None, **kwargs): - return self.processor(self, hidden_states, attention_mask, **kwargs) -``` - -Consult the implementations in `src/diffusers/models/transformers/` if you need further references. - -### Pipeline -- All pipelines must inherit from `DiffusionPipeline`. Consult implementations in `src/diffusers/pipelines` in case you need references. -- DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline` which will be a part of the core codebase (`src`). 
- - -### Tests -- Slow tests gated with `@slow` and `RUN_SLOW=1` -- All model-level tests must use the `BaseModelTesterConfig`, `ModelTesterMixin`, `MemoryTesterMixin`, `AttentionTesterMixin`, `LoraTesterMixin`, and `TrainingTesterMixin` classes initially to write the tests. Any additional tests should be added after discussions with the maintainers. Use `tests/models/transformers/test_models_transformer_flux.py` as a reference. +Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents. +Available skills: **model-integration** (adding/converting pipelines), **parity-testing** (debugging numerical parity). diff --git a/.ai/skills/model-integration/SKILL.md b/.ai/skills/model-integration/SKILL.md new file mode 100644 index 000000000000..16880e4d4850 --- /dev/null +++ b/.ai/skills/model-integration/SKILL.md @@ -0,0 +1,167 @@ +--- +name: integrating-models +description: > + Use when adding a new model or pipeline to diffusers, setting up file + structure for a new model, converting a pipeline to modular format, or + converting weights for a new version of an already-supported model. +--- + +## Goal + +Integrate a new model into diffusers end-to-end. The overall flow: + +1. **Gather info** — ask the user for the reference repo, setup guide, a runnable inference script, and other objectives such as standard vs modular. +2. **Confirm the plan** — once you have everything, tell the user exactly what you'll do: e.g. "I'll integrate model X with pipeline Y into diffusers based on your script. I'll run parity tests (model-level and pipeline-level) using the `parity-testing` skill to verify numerical correctness against the reference." +3. **Implement** — write the diffusers code (model, pipeline, scheduler if needed), convert weights, register in `__init__.py`. +4. **Parity test** — use the `parity-testing` skill to verify component and e2e parity against the reference implementation. +5. 
**Deliver a unit test** — provide a self-contained test script that runs the diffusers implementation, checks numerical output (np allclose), and saves an image/video for visual verification. This is what the user runs to confirm everything works. + +Work one workflow at a time — get it to full parity before moving on. + +## Setup — gather before starting + +Before writing any code, gather info in this order: + +1. **Reference repo** — ask for the github link. If they've already set it up locally, ask for the path. Otherwise, ask what setup steps are needed (install deps, download checkpoints, set env vars, etc.) and run through them before proceeding. +2. **Inference script** — ask for a runnable end-to-end script for a basic workflow first (e.g. T2V). Then ask what other workflows they want to support (I2V, V2V, etc.) and agree on the full implementation order together. +3. **Standard vs modular** — standard pipelines, modular, or both? + +Use `AskUserQuestion` with structured choices for step 3 when the options are known. + +## Standard Pipeline Integration + +### File structure for a new model + +``` +src/diffusers/ + models/transformers/transformer_.py # The core model + schedulers/scheduling_.py # If model needs a custom scheduler + pipelines// + __init__.py + pipeline_.py # Main pipeline + pipeline__.py # Variant pipelines (e.g. 
pyramid, distilled) + pipeline_output.py # Output dataclass + loaders/lora_pipeline.py # LoRA mixin (add to existing file) + +tests/ + models/transformers/test_models_transformer_.py + pipelines//test_.py + lora/test_lora_layers_.py + +docs/source/en/api/ + pipelines/.md + models/_transformer3d.md # or appropriate name +``` + +### Integration checklist + +- [ ] Implement transformer model with `from_pretrained` support +- [ ] Implement or reuse scheduler +- [ ] Implement pipeline(s) with `__call__` method +- [ ] Add LoRA support if applicable +- [ ] Register all classes in `__init__.py` files (lazy imports) +- [ ] Write unit tests (model, pipeline, LoRA) +- [ ] Write docs +- [ ] Run `make style` and `make quality` +- [ ] Test parity with reference implementation (see `parity-testing` skill) + +### Attention pattern + +Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`. + +```python +# transformer_mymodel.py + +class MyModelAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __call__(self, attn, hidden_states, attention_mask=None, ...): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + # reshape, apply rope, etc. 
+ hidden_states = dispatch_attention_fn( + query, key, value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + return attn.to_out[0](hidden_states) + + +class MyModelAttention(nn.Module, AttentionModuleMixin): + _default_processor_cls = MyModelAttnProcessor + _available_processors = [MyModelAttnProcessor] + + def __init__(self, query_dim, heads=8, dim_head=64, ...): + super().__init__() + self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False) + self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False) + self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False) + self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)]) + self.set_processor(MyModelAttnProcessor()) + + def forward(self, hidden_states, attention_mask=None, **kwargs): + return self.processor(self, hidden_states, attention_mask, **kwargs) +``` + +Consult the implementations in `src/diffusers/models/transformers/` if you need further references. + +### Implementation rules + +1. **Don't combine structural changes with behavioral changes.** Restructuring code to fit diffusers APIs (ModelMixin, ConfigMixin, etc.) is unavoidable. But don't also "improve" the algorithm, refactor computation order, or rename internal variables for aesthetics. Keep numerical logic as close to the reference as possible, even if it looks unclean. For standard → modular, this is stricter: copy loop logic verbatim and only restructure into blocks. Clean up in a separate commit after parity is confirmed. +2. **Pipelines must inherit from `DiffusionPipeline`.** Consult implementations in `src/diffusers/pipelines` in case you need references. +3. **Don't subclass an existing pipeline for a variant.** DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`). 
+ +### Test setup + +- Slow tests are gated with `@slow` and `RUN_SLOW=1` +- All model-level tests must use the `BaseModelTesterConfig`, `ModelTesterMixin`, `MemoryTesterMixin`, `AttentionTesterMixin`, `LoraTesterMixin`, and `TrainingTesterMixin` classes initially to write the tests. Any additional tests should be added after discussions with the maintainers. Use `tests/models/transformers/test_models_transformer_flux.py` as a reference. + +### Common diffusers conventions + +- Pipelines inherit from `DiffusionPipeline` +- Models use `ModelMixin` with `register_to_config` for config serialization +- Schedulers use `SchedulerMixin` with `ConfigMixin` +- Use `@torch.no_grad()` on pipeline `__call__` +- Support `output_type="latent"` for skipping VAE decode +- Support `generator` parameter for reproducibility +- Use `self.progress_bar(timesteps)` for progress tracking + +## Gotchas + +1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`. + +2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`. + +3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise. + +4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors. + +5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Forgetting this makes autograd retain activations for every forward pass, which causes GPU OOM during inference. + +6. 
**Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value. + +7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures. + +8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision. + +--- + +## Modular Pipeline Conversion + +See [modular-conversion.md](modular-conversion.md) for the full guide on converting standard pipelines to modular format, including block types, build order, guider abstraction, and conversion checklist. + +--- + +## Weight Conversion Tips + + diff --git a/.ai/skills/model-integration/modular-conversion.md b/.ai/skills/model-integration/modular-conversion.md new file mode 100644 index 000000000000..a143d1f84ba3 --- /dev/null +++ b/.ai/skills/model-integration/modular-conversion.md @@ -0,0 +1,152 @@ +# Modular Pipeline Conversion Reference + +## When to use + +Modular pipelines break a monolithic `__call__` into composable blocks. Convert when: +- The model supports multiple workflows (T2V, I2V, V2V, etc.) +- Users need to swap guidance strategies (CFG, CFG-Zero*, PAG) +- You want to share blocks across pipeline variants + +## File structure + +``` +src/diffusers/modular_pipelines// + __init__.py # Lazy imports + modular_pipeline.py # Pipeline class (tiny, mostly config) + encoders.py # Text encoder + image/video VAE encoder blocks + before_denoise.py # Pre-denoise setup blocks + denoise.py # The denoising loop blocks + decoders.py # VAE decode block + modular_blocks_.py # Block assembly (AutoBlocks) +``` + +## Block types decision tree + +``` +Is this a single operation? 
+ YES -> ModularPipelineBlocks (leaf block) + +Does it run multiple blocks in sequence? + YES -> SequentialPipelineBlocks + Does it iterate (e.g. chunk loop)? + YES -> LoopSequentialPipelineBlocks + +Does it choose ONE block based on which input is present? + Is the selection 1:1 with trigger inputs? + YES -> AutoPipelineBlocks (simple trigger mapping) + NO -> ConditionalPipelineBlocks (custom select_block method) +``` + +## Build order (easiest first) + +1. `decoders.py` -- Takes latents, runs VAE decode, returns images/videos +2. `encoders.py` -- Takes prompt, returns prompt_embeds. Add image/video VAE encoder if needed +3. `before_denoise.py` -- Timesteps, latent prep, noise setup. Each logical operation = one block +4. `denoise.py` -- The hardest. Convert guidance to guider abstraction + +## Key pattern: Guider abstraction + +Original pipeline has guidance baked in: +```python +for i, t in enumerate(timesteps): + noise_pred = self.transformer(latents, prompt_embeds, ...) + if self.do_classifier_free_guidance: + noise_uncond = self.transformer(latents, negative_prompt_embeds, ...) 
+ noise_pred = noise_uncond + scale * (noise_pred - noise_uncond) + latents = self.scheduler.step(noise_pred, t, latents).prev_sample +``` + +Modular pipeline separates concerns: +```python +guider_inputs = { + "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds), +} + +for i, t in enumerate(timesteps): + components.guider.set_state(step=i, num_inference_steps=num_steps, timestep=t) + guider_state = components.guider.prepare_inputs(guider_inputs) + + for batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = {k: getattr(batch, k) for k in guider_inputs} + context_name = getattr(batch, components.guider._identifier_key) + with components.transformer.cache_context(context_name): + batch.noise_pred = components.transformer( + hidden_states=latents, timestep=timestep, + return_dict=False, **cond_kwargs, **shared_kwargs, + )[0] + components.guider.cleanup_models(components.transformer) + + noise_pred = components.guider(guider_state)[0] + latents = components.scheduler.step(noise_pred, t, latents, generator=generator)[0] +``` + +## Key pattern: Chunk loops for video models + +Use `LoopSequentialPipelineBlocks` for outer loop: +```python +class ChunkDenoiseStep(LoopSequentialPipelineBlocks): + block_classes = [PrepareChunkStep, NoiseGenStep, DenoiseInnerStep, UpdateStep] +``` + +Note: blocks inside `LoopSequentialPipelineBlocks` receive `(components, block_state, k)` where `k` is the loop iteration index. 
+ +## Key pattern: Workflow selection + +```python +class AutoDenoise(ConditionalPipelineBlocks): + block_classes = [V2VDenoiseStep, I2VDenoiseStep, T2VDenoiseStep] + block_trigger_inputs = ["video_latents", "image_latents"] + default_block_name = "text2video" +``` + +## Standard InputParam/OutputParam templates + +```python +# Inputs +InputParam.template("prompt") # str, required +InputParam.template("negative_prompt") # str, optional +InputParam.template("image") # PIL.Image, optional +InputParam.template("generator") # torch.Generator, optional +InputParam.template("num_inference_steps") # int, default=50 +InputParam.template("latents") # torch.Tensor, optional + +# Outputs +OutputParam.template("prompt_embeds") +OutputParam.template("negative_prompt_embeds") +OutputParam.template("image_latents") +OutputParam.template("latents") +OutputParam.template("videos") +OutputParam.template("images") +``` + +## ComponentSpec patterns + +```python +# Heavy models - loaded from pretrained +ComponentSpec("transformer", YourTransformerModel) +ComponentSpec("vae", AutoencoderKL) + +# Lightweight objects - created inline from config +ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config" +) +``` + +## Conversion checklist + +- [ ] Read original pipeline's `__call__` end-to-end, map stages +- [ ] Write test scripts (reference + target) with identical seeds +- [ ] Create file structure under `modular_pipelines//` +- [ ] Write decoder block (simplest) +- [ ] Write encoder blocks (text, image, video) +- [ ] Write before_denoise blocks (timesteps, latent prep, noise) +- [ ] Write denoise block with guider abstraction (hardest) +- [ ] Create pipeline class with `default_blocks_name` +- [ ] Assemble blocks in `modular_blocks_.py` +- [ ] Wire up `__init__.py` with lazy imports +- [ ] Run `make style` and `make quality` +- [ ] Test all workflows for parity with reference diff --git 
a/.ai/skills/parity-testing/SKILL.md b/.ai/skills/parity-testing/SKILL.md new file mode 100644 index 000000000000..9638e947723e --- /dev/null +++ b/.ai/skills/parity-testing/SKILL.md @@ -0,0 +1,170 @@ +--- +name: testing-parity +description: > + Use when debugging or verifying numerical parity between pipeline + implementations (e.g., research repo vs diffusers, standard vs modular). + Also relevant when outputs look wrong — washed out, pixelated, or have + visual artifacts — as these are usually parity bugs. +--- + +## Setup — gather before starting + +Before writing any test code, gather: + +1. **Which two implementations** are being compared (e.g. research repo → diffusers, standard → modular, or research → modular). Use `AskUserQuestion` with structured choices if not already clear. +2. **Two equivalent runnable scripts** — one for each implementation, both expected to produce identical output given the same inputs. These scripts define what "parity" means concretely. + +When invoked from the `model-integration` skill, you already have context: the reference script comes from step 2 of setup, and the diffusers script is the one you just wrote. You just need to make sure both scripts are runnable and use the same inputs/seed/params. + +## Test strategy + +**Component parity (CPU/float32) -- always run, as you build.** +Test each component before assembling the pipeline. This is the foundation -- if individual pieces are wrong, the pipeline can't be right. Each component in isolation, strict max_diff < 1e-3. + +Test freshly converted checkpoints and saved checkpoints. +- **Fresh**: convert from checkpoint weights, compare against reference (catches conversion bugs) +- **Saved**: load from saved model on disk, compare against reference (catches stale saves) + +Keep component test scripts around -- you will need to re-run them during pipeline debugging with different inputs or config values. 
+ +Template -- one self-contained script per component, reference and diffusers side-by-side: +```python +@torch.inference_mode() +def test_my_component(mode="fresh", model_path=None): + # 1. Deterministic input + gen = torch.Generator().manual_seed(42) + x = torch.randn(1, 3, 64, 64, generator=gen, dtype=torch.float32) + + # 2. Reference: load from checkpoint, run, free + ref_model = ReferenceModel.from_config(config) + ref_model.load_state_dict(load_weights("prefix"), strict=True) + ref_model = ref_model.float().eval() + ref_out = ref_model(x).clone() + del ref_model + + # 3. Diffusers: fresh (convert weights) or saved (from_pretrained) + if mode == "fresh": + diff_model = convert_my_component(load_weights("prefix")) + else: + diff_model = DiffusersModel.from_pretrained(model_path, torch_dtype=torch.float32) + diff_model = diff_model.float().eval() + diff_out = diff_model(x) + del diff_model + + # 4. Compare in same script -- no saving to disk + max_diff = (ref_out - diff_out).abs().max().item() + assert max_diff < 1e-3, f"FAIL: max_diff={max_diff:.2e}" +``` +Key points: (a) both reference and diffusers component in one script -- never split into separate scripts that save/load intermediates, (b) deterministic input via seeded generator, (c) load one model at a time to fit in CPU RAM, (d) `.clone()` the reference output before deleting the model. + +**E2E visual (GPU/bfloat16) -- once the pipeline is assembled.** +Both pipelines generate independently with identical seeds/params. Save outputs and compare visually. If outputs look identical, you're done -- no need for deeper testing. + +**Pipeline stage tests -- only if E2E fails and you need to isolate the bug.** +If the user already suspects where divergence is, start there. Otherwise, work through stages in order. + +First, **match noise generation**: the way initial noise/latents are constructed (seed handling, generator, randn call order) often differs between the two scripts. 
If the noise doesn't match, nothing downstream will match. Check how noise is initialized in the diffusers script — if it doesn't match the reference, temporarily change it to match. Note what you changed so it can be reverted after parity is confirmed. + +For small models, run on CPU/float32 for strict comparison. For large models (e.g. 22B params), CPU/float32 is impractical -- use GPU/bfloat16 with `enable_model_cpu_offload()` and relax tolerances (max_diff < 1e-1 for bfloat16 is typical for passing tests; cosine similarity > 0.9999 is a good secondary check). + +Test encode and decode stages first -- they're simpler and bugs there are easier to fix. Only debug the denoising loop if encode and decode both pass. + +The challenge: pipelines are monolithic `__call__` methods -- you can't just call "the encode part". See [checkpoint-mechanism.md](checkpoint-mechanism.md) for the checkpoint class that lets you stop, save, or inject tensors at named locations inside the pipeline. + +**Stage test order — encode, decode, then denoise:** + +- **`encode`** (test first): Stop both pipelines at `"preloop"`. Compare **every single variable** that will be consumed by the denoising loop -- not just latents and sigmas, but also prompt embeddings, attention masks, positional coordinates, connector outputs, and any conditioning inputs. +- **`decode`** (test second, before denoise): Run the reference pipeline fully -- checkpoint the post-loop latents AND let it finish to get the decoded output. Then feed those same post-loop latents through the diffusers pipeline's decode path. Compare both numerically AND visually. +- **`denoise`** (test last): Run both pipelines with realistic `num_steps` (e.g. 30) so the scheduler computes correct sigmas/timesteps, but stop after 2 loop iterations using `after_step_1`. Don't set `num_steps=2` -- that produces unrealistic sigma schedules. 
+ +```python +# Encode stage -- stop before the loop, compare ALL inputs: +ref_ckpts = {"preloop": Checkpoint(save=True, stop=True)} +run_reference_pipeline(ref_ckpts) +ref_data = ref_ckpts["preloop"].data + +diff_ckpts = {"preloop": Checkpoint(save=True, stop=True)} +run_diffusers_pipeline(diff_ckpts) +diff_data = diff_ckpts["preloop"].data + +# Compare EVERY variable consumed by the denoise loop: +compare_tensors("latents", ref_data["latents"], diff_data["latents"]) +compare_tensors("sigmas", ref_data["sigmas"], diff_data["sigmas"]) +compare_tensors("prompt_embeds", ref_data["prompt_embeds"], diff_data["prompt_embeds"]) +# ... every single tensor the transformer forward() will receive +``` + +**E2E-injected visual test**: Once you've identified a suspected root cause using stage tests, confirm it with an e2e-injected run -- inject the known-good tensor from reference and generate a full video. If the output looks identical to reference, you've confirmed the root cause. + +## Debugging technique: Injection for root-cause isolation + +When stage tests show divergence, **inject a known-good tensor from one pipeline into the other** to test whether the remaining code is correct. + +The principle: if you suspect input X is the root cause of divergence in stage S: +1. Run the reference pipeline and capture X +2. Run the diffusers pipeline but **replace** its X with the reference's X (via checkpoint load) +3. Compare outputs of stage S + +If outputs now match: X was the root cause. If they still diverge: the bug is in the stage logic itself, not in X. + +| What you're testing | What you inject | Where you inject | +|---|---|---| +| Is the decode stage correct? | Post-loop latents from reference | Before decode | +| Is the denoise loop correct? | Pre-loop latents from reference | Before the loop | +| Is step N correct? 
| Post-step-(N-1) latents from reference | Before step N | + +**Per-step accumulation tracing**: When injection confirms the loop is correct but you want to understand *how* a small initial difference compounds, capture `after_step_{i}` for every step and plot the max_diff curve. A healthy curve stays bounded; an exponential blowup in later steps points to an amplification mechanism (see Pitfall #13 in [pitfalls.md](pitfalls.md)). + +## Debugging technique: Visual comparison via frame extraction + +For video pipelines, numerical metrics alone can be misleading. Extract and view individual frames: + +```python +import numpy as np +from PIL import Image + +def extract_frames(video_np, frame_indices): + """video_np: (frames, H, W, 3) float array in [0, 1]""" + for idx in frame_indices: + frame = (video_np[idx] * 255).clip(0, 255).astype(np.uint8) + img = Image.fromarray(frame) + img.save(f"frame_{idx}.png") + +# Compare specific frames from both pipelines +extract_frames(ref_video, [0, 60, 120]) +extract_frames(diff_video, [0, 60, 120]) +``` + +## Testing rules + +1. **Never use reference code in the diffusers test path.** Each side must use only its own code. +2. **Never monkey-patch model internals in tests.** Do not replace `model.forward` or patch internal methods. +3. **Debugging instrumentation must be non-destructive.** Checkpoint captures for debugging are fine, but must not alter control flow or outputs. +4. **Prefer CPU/float32 for numerical comparison when practical.** Float32 avoids bfloat16 precision noise that obscures real bugs. But for large models (22B+), GPU/bfloat16 with `enable_model_cpu_offload()` is necessary -- use relaxed tolerances and cosine similarity as a secondary metric. +5. **Test both fresh conversion AND saved model.** Fresh catches conversion logic bugs; saved catches stale/corrupted weights from previous runs. +6. **Diff configs before debugging.** Before investigating any divergence, dump and compare all config values. 
A 30-second config diff prevents hours of debugging based on wrong assumptions. +7. **Never modify cached/downloaded model configs directly.** Don't edit files in `~/.cache/huggingface/`. Instead, save to a local directory or open a PR on the upstream repo. +8. **Compare ALL loop inputs in the encode test.** The preloop checkpoint must capture every single tensor the transformer forward() will receive. + +## Comparison utilities + +```python +def compare_tensors(name: str, a: torch.Tensor, b: torch.Tensor, tol: float = 1e-3) -> bool: + if a.shape != b.shape: + print(f" FAIL {name}: shape mismatch {a.shape} vs {b.shape}") + return False + diff = (a.float() - b.float()).abs() + max_diff = diff.max().item() + mean_diff = diff.mean().item() + cos = torch.nn.functional.cosine_similarity( + a.float().flatten().unsqueeze(0), b.float().flatten().unsqueeze(0) + ).item() + passed = max_diff < tol + print(f" {'PASS' if passed else 'FAIL'} {name}: max={max_diff:.2e}, mean={mean_diff:.2e}, cos={cos:.5f}") + return passed +``` +Cosine similarity is especially useful for GPU/bfloat16 tests where max_diff can be noisy -- `cos > 0.9999` is a strong signal even when max_diff exceeds tolerance. + +## Gotchas + +See [pitfalls.md](pitfalls.md) for the full list of gotchas to watch for during parity testing. diff --git a/.ai/skills/parity-testing/checkpoint-mechanism.md b/.ai/skills/parity-testing/checkpoint-mechanism.md new file mode 100644 index 000000000000..43743ebb07a5 --- /dev/null +++ b/.ai/skills/parity-testing/checkpoint-mechanism.md @@ -0,0 +1,103 @@ +# Checkpoint Mechanism for Stage Testing + +## Overview + +Pipelines are monolithic `__call__` methods -- you can't just call "the encode part". The checkpoint mechanism lets you stop, save, or inject tensors at named locations inside the pipeline. + +## The Checkpoint class + +Add a `_checkpoints` argument to both the diffusers pipeline and the reference implementation. 
+ +```python +@dataclass +class Checkpoint: + save: bool = False # capture variables into ckpt.data + stop: bool = False # halt pipeline after this point + load: bool = False # inject ckpt.data into local variables + data: dict = field(default_factory=dict) +``` + +## Pipeline instrumentation + +The pipeline accepts an optional `dict[str, Checkpoint]`. Place checkpoint calls at boundaries between pipeline stages -- after each encoder, before the denoising loop (capture all loop inputs), after each loop iteration, after the loop (capture final latents before decode). + +```python +def __call__(self, prompt, ..., _checkpoints=None): + # --- text encoding --- + prompt_embeds = self.text_encoder(prompt) + _maybe_checkpoint(_checkpoints, "text_encoding", { + "prompt_embeds": prompt_embeds, + }) + + # --- prepare latents, sigmas, positions --- + latents = self.prepare_latents(...) + sigmas = self.scheduler.sigmas + # ... + + _maybe_checkpoint(_checkpoints, "preloop", { + "latents": latents, + "sigmas": sigmas, + "prompt_embeds": prompt_embeds, + "prompt_attention_mask": prompt_attention_mask, + "video_coords": video_coords, + # capture EVERYTHING the loop needs -- every tensor the transformer + # forward() receives. Missing even one variable here means you can't + # tell if it's the source of divergence during denoise debugging. + }) + + # --- denoising loop --- + for i, t in enumerate(timesteps): + noise_pred = self.transformer(latents, t, prompt_embeds, ...) 
+ latents = self.scheduler.step(noise_pred, t, latents)[0] + + _maybe_checkpoint(_checkpoints, f"after_step_{i}", { + "latents": latents, + }) + + _maybe_checkpoint(_checkpoints, "post_loop", { + "latents": latents, + }) + + # --- decode --- + video = self.vae.decode(latents) + return video +``` + +## The helper function + +Each checkpoint acts on the Checkpoint's flags. Two are handled inside `_maybe_checkpoint`: `save` captures the local variables into `ckpt.data`, and `stop` halts execution (raises an exception caught at the top level). The third, `load`, injects pre-populated `ckpt.data` back into local variables; it is handled at the call site rather than in the helper, because a helper function cannot rebind its caller's locals (see "Injection support" below). + +```python +def _maybe_checkpoint(checkpoints, name, data): + if not checkpoints: + return + ckpt = checkpoints.get(name) + if ckpt is None: + return + if ckpt.save: + ckpt.data.update(data) + if ckpt.stop: + raise PipelineStop # caught at __call__ level, returns None +``` + +## Injection support + +Add `load` support at each checkpoint where you might want to inject: + +```python +_maybe_checkpoint(_checkpoints, "preloop", {"latents": latents, ...}) + +# Load support: replace local variables with injected data +if _checkpoints: + ckpt = _checkpoints.get("preloop") + if ckpt is not None and ckpt.load: + latents = ckpt.data["latents"].to(device=device, dtype=latents.dtype) +``` + +## Key insight + +The checkpoint dict is passed into the pipeline and mutated in-place. After the pipeline returns (or stops early), you read back `ckpt.data` to get the captured tensors. Both pipelines save under their own key names, so the test maps between them (e.g. reference `"video_state.latent"` -> diffusers `"latents"`). + +## Memory management for large models + +For large models, free the source pipeline's GPU memory before loading the target pipeline. Clone injected tensors to CPU, delete everything else, then run the target with `enable_model_cpu_offload()`. 
diff --git a/.ai/skills/parity-testing/pitfalls.md b/.ai/skills/parity-testing/pitfalls.md new file mode 100644 index 000000000000..b0f59876f94a --- /dev/null +++ b/.ai/skills/parity-testing/pitfalls.md @@ -0,0 +1,116 @@ +# Complete Pitfalls Reference + +## 1. Global CPU RNG +`MultivariateNormal.sample()` uses the global CPU RNG, not `torch.Generator`. Must call `torch.manual_seed(seed)` before each pipeline run. A `generator=` kwarg won't help. + +## 2. Timestep dtype +Many transformers expect `int64` timesteps. `get_timestep_embedding` casts to float, so `745.3` and `745` produce different embeddings. Match the reference's casting. + +## 3. Guidance parameter mapping +Parameter names may differ: reference `zero_steps=1` (meaning `i <= 1`, 2 steps) vs target `zero_init_steps=2` (meaning `step < 2`, same thing). Check exact semantics. + +## 4. `patch_size` in noise generation +If noise generation depends on `patch_size` (e.g. `sample_block_noise`), it must be passed through. Missing it changes noise spatial structure. + +## 5. Variable shadowing in nested loops +Nested loops (stages -> chunks -> timesteps) can shadow variable names. If outer loop uses `latents` and inner loop also assigns to `latents`, scoping must match the reference. + +## 6. Float precision differences -- don't dismiss them +Target may compute in float32 where reference used bfloat16. Small per-element diffs (1e-3 to 1e-2) *look* harmless but can compound catastrophically over iterative processes like denoising loops (see Pitfalls #11 and #13). Before dismissing a precision difference: (a) check whether it feeds into an iterative process, (b) if so, trace the accumulation curve over all iterations to see if it stays bounded or grows exponentially. Only truly non-iterative precision diffs (e.g. in a single-pass encoder) are safe to accept. + +## 7. Scheduler state reset between stages +Some schedulers accumulate state (e.g. `model_outputs` in UniPC) that must be cleared between stages. + +## 8. 
Component access +Standard: `self.transformer`. Modular: `components.transformer`. Missing this causes AttributeError. + +## 9. Guider state across stages +In multi-stage denoising, the guider's internal state (e.g. `zero_init_steps`) may need save/restore between stages. + +## 10. Model storage location +NEVER store converted models in `/tmp/` -- temporary directories get wiped on restart. Always save converted checkpoints under a persistent path in the project repo (e.g. `models/ltx23-diffusers/`). + +## 11. Noise dtype mismatch (causes washed-out output) + +Reference code often generates noise in float32 then casts to model dtype (bfloat16) before storing: + +```python +noise = torch.randn(..., dtype=torch.float32, generator=gen) +noise = noise.to(dtype=model_dtype) # bfloat16 -- values get quantized +``` + +Diffusers pipelines may keep latents in float32 throughout the loop. The per-element difference is only ~1.5e-02, but this compounds over 30 denoising steps via 1/sigma amplification (Pitfall #13) and produces completely washed-out output. + +**Fix**: Match the reference -- generate noise in the model's working dtype: +```python +latent_dtype = self.transformer.dtype # e.g. bfloat16 +latents = self.prepare_latents(..., dtype=latent_dtype, ...) +``` + +**Detection**: Encode stage test shows initial latent max_diff of exactly ~1.5e-02. This specific magnitude is the signature of float32->bfloat16 quantization error. + +## 12. RoPE position dtype + +RoPE cosine/sine values are sensitive to position coordinate dtype. If reference uses bfloat16 positions but diffusers uses float32, the RoPE output diverges significantly (max_diff up to 2.0). Different modalities may use different position dtypes (e.g. video bfloat16, audio float32) -- check the reference carefully. + +## 13. 1/sigma error amplification in Euler denoising + +In Euler/flow-matching, the velocity formula divides by sigma: `v = (latents - pred_x0) / sigma`. 
As sigma shrinks from ~1.0 (step 0) to ~0.001 (step 29), errors are amplified up to 1000x. A 1.5e-02 init difference grows linearly through mid-steps, then exponentially in final steps, reaching max_diff ~6.0. This is why dtype mismatches (Pitfalls #11, #12) that seem tiny at init produce visually broken output. Use per-step accumulation tracing to diagnose. + +## 14. Config value assumptions -- always diff, never assume + +When debugging parity, don't assume config values match code defaults. The published model checkpoint may override defaults with different values. A wrong assumption about a single config field can send you down hours of debugging in the wrong direction. + +**The pattern that goes wrong:** +1. You see `param_x` has default `1` in the code +2. The reference code also uses `param_x` with a default of `1` +3. You assume both sides use `1` and apply a "fix" based on that +4. But the actual checkpoint config has `param_x: 1000`, and so does the published diffusers config +5. Your "fix" now *creates* divergence instead of fixing it + +**Prevention -- config diff first:** +```python +# Reference: read from checkpoint metadata (no model loading needed) +from safetensors import safe_open +import json +ref_config = json.loads(safe_open(checkpoint_path, framework="pt").metadata()["config"]) + +# Diffusers: read from model config +from diffusers import MyModel +diff_model = MyModel.from_pretrained(model_path, subfolder="transformer") +diff_config = dict(diff_model.config) + +# Compare all values +for key in sorted(set(list(ref_config.get("transformer", {}).keys()) + list(diff_config.keys()))): + ref_val = ref_config.get("transformer", {}).get(key, "MISSING") + diff_val = diff_config.get(key, "MISSING") + if ref_val != diff_val: + print(f" DIFF {key}: ref={ref_val}, diff={diff_val}") +``` + +Run this **before** writing any hooks, analysis code, or fixes. It takes 30 seconds and catches wrong assumptions immediately. 
+ +**When debugging divergence -- trace values, don't reason about them:** +If two implementations diverge, hook the actual intermediate values at the point of divergence rather than reading code to figure out what the values "should" be. Code analysis builds on assumptions; value tracing reveals facts. + +## 15. Decoder config mismatch (causes pixelated artifacts) + +The upstream model config may have wrong values for decoder-specific parameters (e.g. `upsample_residual`, `upsample_type`). These control whether the decoder uses skip connections in upsampling -- getting them wrong produces severe pixelation or blocky artifacts. + +**Detection**: Feed identical post-loop latents through both decoders. If max pixel diff is large (PSNR < 40 dB) on CPU/float32, it's a real bug, not precision noise. Trace through decoder blocks (conv_in -> mid_block -> up_blocks) to find where divergence starts. + +**Fix**: Correct the config value. Don't edit cached files in `~/.cache/huggingface/` -- either save to a local model directory or open a PR on the upstream repo (see Testing Rule #7). + +## 16. Incomplete injection tests -- inject ALL variables or the test is invalid + +When doing injection tests (feeding reference tensors into the diffusers pipeline), you must inject **every** divergent input, including sigmas/timesteps. A common mistake: the preloop checkpoint saves sigmas but the injection code only loads latents and embeddings. The test then runs with different sigma schedules, making it impossible to isolate the real cause. + +**Prevention**: After writing injection code, verify by listing every variable the injected stage consumes and checking each one is either (a) injected from reference, or (b) confirmed identical between pipelines. + +## 17. bf16 connector/encoder divergence -- don't chase it + +When running on GPU/bfloat16, multi-layer encoders (e.g. 8-layer connector transformers) accumulate bf16 rounding noise that looks alarming (max_diff 0.3-2.7). 
Before investigating, re-run the component test on CPU/float32. If it passes (max_diff < 1e-4), the divergence is pure precision noise, not a code bug. Don't spend hours tracing through layers -- confirm on CPU/float32 and move on. + +## 18. Stale test fixtures + +When using saved tensors for cross-pipeline comparison, always ensure both sets of tensors were captured from the same run configuration (same seed, same config, same code version). Mixing fixtures from different runs (e.g. reference tensors from yesterday, diffusers tensors from today after a code change) creates phantom divergence that wastes debugging time. Regenerate both sides in a single test script execution. diff --git a/.gitignore b/.gitignore index d281b8d1511c..5d19ea2db3c9 100644 --- a/.gitignore +++ b/.gitignore @@ -182,4 +182,6 @@ wandb # AI agent generated symlinks /AGENTS.md -/CLAUDE.md \ No newline at end of file +/CLAUDE.md +/.agents/skills +/.claude/skills \ No newline at end of file diff --git a/Makefile b/Makefile index 491baba074c9..138b0bfa5101 100644 --- a/Makefile +++ b/Makefile @@ -103,9 +103,16 @@ post-patch: codex: ln -snf .ai/AGENTS.md AGENTS.md + mkdir -p .agents + rm -rf .agents/skills + ln -snf ../.ai/skills .agents/skills claude: ln -snf .ai/AGENTS.md CLAUDE.md + mkdir -p .claude + rm -rf .claude/skills + ln -snf ../.ai/skills .claude/skills clean-ai: rm -f AGENTS.md CLAUDE.md + rm -rf .agents/skills .claude/skills diff --git a/docs/source/en/conceptual/contribution.md b/docs/source/en/conceptual/contribution.md index 22bb265b1d79..47cde3251168 100644 --- a/docs/source/en/conceptual/contribution.md +++ b/docs/source/en/conceptual/contribution.md @@ -572,9 +572,9 @@ For documentation strings, 🧨 Diffusers follows the [Google style](https://goo The repository keeps AI-agent configuration in `.ai/` and exposes local agent files via symlinks. 
-- **Source of truth** — edit `.ai/AGENTS.md` (and any future `.ai/skills/`) -- **Don't edit** generated root-level `AGENTS.md` or `CLAUDE.md` — they are symlinks +- **Source of truth** — edit files under `.ai/` (`AGENTS.md` for coding guidelines, `skills/` for on-demand task knowledge) +- **Don't edit** generated root-level `AGENTS.md`, `CLAUDE.md`, or `.agents/skills`/`.claude/skills` — they are symlinks - Setup commands: - - `make codex` — symlink for OpenAI Codex - - `make claude` — symlink for Claude Code - - `make clean-ai` — remove generated symlinks \ No newline at end of file + - `make codex` — symlink guidelines + skills for OpenAI Codex + - `make claude` — symlink guidelines + skills for Claude Code + - `make clean-ai` — remove all generated symlinks \ No newline at end of file From dc5084293defda5ca11f8c4aa0645a06e763ee59 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 20 Mar 2026 10:36:03 +0530 Subject: [PATCH 058/215] [Modular] Test for catching dtype and device issues with AutoModel type hints (#13287) * update * update * update --- .../test_modular_pipelines_custom_blocks.py | 111 +++++++++++++++++- 1 file changed, 110 insertions(+), 1 deletion(-) diff --git a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py index 7c6e97a36eb7..59d6a3e75f55 100644 --- a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py +++ b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py @@ -31,7 +31,41 @@ WanModularPipeline, ) -from ..testing_utils import nightly, require_torch, slow +from ..testing_utils import nightly, require_torch, require_torch_accelerator, slow, torch_device + + +def _create_tiny_model_dir(model_dir): + TINY_MODEL_CODE = ( + "import torch\n" + "from diffusers import ModelMixin, ConfigMixin\n" + "from diffusers.configuration_utils import register_to_config\n" + "\n" + "class TinyModel(ModelMixin, ConfigMixin):\n" + " @register_to_config\n" + " 
def __init__(self, hidden_size=4):\n" + " super().__init__()\n" + " self.linear = torch.nn.Linear(hidden_size, hidden_size)\n" + "\n" + " def forward(self, x):\n" + " return self.linear(x)\n" + ) + + with open(os.path.join(model_dir, "modeling.py"), "w") as f: + f.write(TINY_MODEL_CODE) + + config = { + "_class_name": "TinyModel", + "_diffusers_version": "0.0.0", + "auto_map": {"AutoModel": "modeling.TinyModel"}, + "hidden_size": 4, + } + with open(os.path.join(model_dir, "config.json"), "w") as f: + json.dump(config, f) + + torch.save( + {"linear.weight": torch.randn(4, 4), "linear.bias": torch.randn(4)}, + os.path.join(model_dir, "diffusion_pytorch_model.bin"), + ) class DummyCustomBlockSimple(ModularPipelineBlocks): @@ -341,6 +375,81 @@ def __call__(self, components, state: PipelineState) -> PipelineState: loaded_pipe.update_components(custom_model=custom_model) assert getattr(loaded_pipe, "custom_model", None) is not None + def test_automodel_type_hint_preserves_torch_dtype(self, tmp_path): + """Regression test for #13271: torch_dtype was incorrectly removed when type_hint is AutoModel.""" + from diffusers import AutoModel + + model_dir = str(tmp_path / "model") + os.makedirs(model_dir) + _create_tiny_model_dir(model_dir) + + class DtypeTestBlock(ModularPipelineBlocks): + @property + def expected_components(self): + return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)] + + @property + def inputs(self) -> List[InputParam]: + return [InputParam("prompt", type_hint=str, required=True)] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam("output", type_hint=str)] + + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.output = "test" + self.set_block_state(state, block_state) + return components, state + + block = DtypeTestBlock() + pipe = 
block.init_pipeline() + pipe.load_components(torch_dtype=torch.float16, trust_remote_code=True) + + assert pipe.model.dtype == torch.float16 + + @require_torch_accelerator + def test_automodel_type_hint_preserves_device(self, tmp_path): + """Test that ComponentSpec with AutoModel type_hint correctly passes device_map.""" + from diffusers import AutoModel + + model_dir = str(tmp_path / "model") + os.makedirs(model_dir) + _create_tiny_model_dir(model_dir) + + class DeviceTestBlock(ModularPipelineBlocks): + @property + def expected_components(self): + return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)] + + @property + def inputs(self) -> List[InputParam]: + return [InputParam("prompt", type_hint=str, required=True)] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam("output", type_hint=str)] + + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.output = "test" + self.set_block_state(state, block_state) + return components, state + + block = DeviceTestBlock() + pipe = block.init_pipeline() + pipe.load_components(device_map=torch_device, trust_remote_code=True) + + assert pipe.model.device.type == torch_device + def test_custom_block_loads_from_hub(self): repo_id = "hf-internal-testing/tiny-modular-diffusers-block" block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True) From d765c4f6d0f79702aac95e88e3b4fce76ff92401 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 20 Mar 2026 11:40:06 +0530 Subject: [PATCH 059/215] [CI] Update transformer version in release tests (#13296) update --- .github/workflows/release_tests_fast.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/release_tests_fast.yml b/.github/workflows/release_tests_fast.yml index f667d715090d..7d097d165928 100644 --- 
a/.github/workflows/release_tests_fast.yml +++ b/.github/workflows/release_tests_fast.yml @@ -4,6 +4,7 @@ name: (Release) Fast GPU Tests on main on: + workflow_dispatch: push: branches: - "v*.*.*-release" @@ -33,6 +34,7 @@ jobs: - name: Install dependencies run: | uv pip install -e ".[quality]" + uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git - name: Environment run: | python utils/print_env.py @@ -74,6 +76,7 @@ jobs: run: | uv pip install -e ".[quality]" uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git + uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git - name: Environment run: | python utils/print_env.py @@ -125,6 +128,7 @@ jobs: uv pip install -e ".[quality]" uv pip install peft@git+https://github.com/huggingface/peft.git uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git + uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git - name: Environment run: | @@ -175,6 +179,7 @@ jobs: uv pip install -e ".[quality]" uv pip install peft@git+https://github.com/huggingface/peft.git uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git + uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git - name: Environment run: | @@ -232,6 +237,7 @@ jobs: - name: Install dependencies run: | uv pip install -e ".[quality,training]" + uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git - name: Environment run: | python 
utils/print_env.py @@ -274,6 +280,7 @@ jobs: - name: Install dependencies run: | uv pip install -e ".[quality,training]" + uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git - name: Environment run: | python utils/print_env.py @@ -316,6 +323,7 @@ jobs: - name: Install dependencies run: | uv pip install -e ".[quality,training]" + uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git - name: Environment run: | From 166c1759609640ec0527252241ef5affda9932d7 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 20 Mar 2026 16:02:16 +0530 Subject: [PATCH 060/215] [ci] hoping to fix is_flaky with wanvace. (#13294) * hoping to fix is_flaky with wanvace. * revert changes in src/diffusers/utils/testing_utils.py and propagate them to tests/testing_utils.py. * up --- tests/lora/test_lora_layers_wanvace.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py index c8acaea9bef0..dc435094d780 100644 --- a/tests/lora/test_lora_layers_wanvace.py +++ b/tests/lora/test_lora_layers_wanvace.py @@ -28,7 +28,6 @@ from ..testing_utils import ( floats_tensor, - is_flaky, require_peft_backend, require_peft_version_greater, skip_mps, @@ -46,7 +45,6 @@ @require_peft_backend @skip_mps -@is_flaky(max_attempts=10, description="very flaky class") class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = WanVACEPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler @@ -73,8 +71,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "base_dim": 3, "z_dim": 4, "dim_mult": [1, 1, 1, 1], - "latents_mean": torch.randn(4).numpy().tolist(), - "latents_std": torch.randn(4).numpy().tolist(), + "latents_mean": [-0.7571, -0.7089, -0.9113, -0.7245], + "latents_std": 
[2.8184, 1.4541, 2.3275, 2.6558], "num_res_blocks": 1, "temperal_downsample": [False, True, True], } From a11507040f8f603ed9b38479910fdfc3d50d5dea Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 20 Mar 2026 17:28:09 +0530 Subject: [PATCH 061/215] [core] fa4 support. (#13280) * start fa4 support. * up * specify minimum version --- .../en/optimization/attention_backends.md | 1 + src/diffusers/models/attention_dispatch.py | 43 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/docs/source/en/optimization/attention_backends.md b/docs/source/en/optimization/attention_backends.md index f3ff4781c6ec..6dab9a2b1f50 100644 --- a/docs/source/en/optimization/attention_backends.md +++ b/docs/source/en/optimization/attention_backends.md @@ -143,6 +143,7 @@ Refer to the table below for a complete list of available attention backends and | `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention | | `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels | | `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm | +| `flash_4_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-4 | | `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 | | `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 | | `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels | diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 5b1f831ed060..c407f59037e6 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -229,6 +229,7 @@ class AttentionBackendName(str, Enum): FLASH_HUB = "flash_hub" FLASH_VARLEN = 
"flash_varlen" FLASH_VARLEN_HUB = "flash_varlen_hub" + FLASH_4_HUB = "flash_4_hub" _FLASH_3 = "_flash_3" _FLASH_VARLEN_3 = "_flash_varlen_3" _FLASH_3_HUB = "_flash_3_hub" @@ -358,6 +359,11 @@ class _HubKernelConfig: function_attr="sageattn", version=1, ), + AttentionBackendName.FLASH_4_HUB: _HubKernelConfig( + repo_id="kernels-staging/flash-attn4", + function_attr="flash_attn_func", + version=0, + ), } @@ -521,6 +527,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None AttentionBackendName._FLASH_3_HUB, AttentionBackendName._FLASH_3_VARLEN_HUB, AttentionBackendName.SAGE_HUB, + AttentionBackendName.FLASH_4_HUB, ]: if not is_kernels_available(): raise RuntimeError( @@ -531,6 +538,11 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`." ) + if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"): + raise RuntimeError( + f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`." 
+ ) + elif backend == AttentionBackendName.AITER: if not _CAN_USE_AITER_ATTN: raise RuntimeError( @@ -2676,6 +2688,37 @@ def _flash_attention_3_varlen_hub( return (out, lse) if return_lse else out +@_AttentionBackendRegistry.register( + AttentionBackendName.FLASH_4_HUB, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_context_parallel=False, +) +def _flash_attention_4_hub( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor | None = None, + scale: float | None = None, + is_causal: bool = False, + return_lse: bool = False, + _parallel_config: "ParallelConfig" | None = None, +) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for flash-attn 4.") + + func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn + out = func( + q=query, + k=key, + v=value, + softmax_scale=scale, + causal=is_causal, + ) + if isinstance(out, tuple): + return (out[0], out[1]) if return_lse else out[0] + return out + + @_AttentionBackendRegistry.register( AttentionBackendName._FLASH_VARLEN_3, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], From e4535fd1536078c2c7d297ce53e5803c230f7311 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 20 Mar 2026 20:53:21 +0530 Subject: [PATCH 062/215] [tests] fix audioldm2 tests. (#13293) fix audioldm2 tests. --- .../pipelines/audioldm2/pipeline_audioldm2.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py index b023974a33dd..b79ee280ca34 100644 --- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py +++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py @@ -324,17 +324,18 @@ def generate_language_model( `inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): The sequence of generated hidden-states. 
""" - cache_position_kwargs = {} - if is_transformers_version("<", "4.52.1"): - cache_position_kwargs["input_ids"] = inputs_embeds - else: - cache_position_kwargs["seq_length"] = inputs_embeds.shape[0] - cache_position_kwargs["device"] = ( - self.language_model.device if getattr(self, "language_model", None) is not None else self.device - ) - cache_position_kwargs["model_kwargs"] = model_kwargs max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens - model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs) + if hasattr(self.language_model, "_get_initial_cache_position"): + cache_position_kwargs = {} + if is_transformers_version("<", "4.52.1"): + cache_position_kwargs["input_ids"] = inputs_embeds + else: + cache_position_kwargs["seq_length"] = inputs_embeds.shape[0] + cache_position_kwargs["device"] = ( + self.language_model.device if getattr(self, "language_model", None) is not None else self.device + ) + cache_position_kwargs["model_kwargs"] = model_kwargs + model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs) for _ in range(max_new_tokens): # prepare model inputs From 0127a0e5300c54f027801085608e889dedb76774 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 21 Mar 2026 09:41:48 +0530 Subject: [PATCH 063/215] [tests] test load_components in modular (#13245) * test load_components. 
* fix * fix * u[ * up --- .../test_modular_pipelines_common.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index d897ed793376..8a65999b2006 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -5,6 +5,7 @@ import pytest import torch +from huggingface_hub import hf_hub_download import diffusers from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks @@ -32,6 +33,33 @@ ) +def _get_specified_components(path_or_repo_id, cache_dir=None): + if os.path.isdir(path_or_repo_id): + config_path = os.path.join(path_or_repo_id, "modular_model_index.json") + else: + try: + config_path = hf_hub_download( + repo_id=path_or_repo_id, + filename="modular_model_index.json", + local_dir=cache_dir, + ) + except Exception: + return None + + with open(config_path) as f: + config = json.load(f) + + components = set() + for k, v in config.items(): + if isinstance(v, (str, int, float, bool)): + continue + for entry in v: + if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")): + components.add(k) + break + return components + + class ModularPipelineTesterMixin: """ It provides a set of common tests for each modular pipeline, @@ -360,6 +388,39 @@ def test_save_from_pretrained(self, tmp_path): assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + def test_load_expected_components_from_pretrained(self, tmp_path): + pipe = self.get_pipeline() + expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path) + if not expected: + pytest.skip("Skipping test as we couldn't fetch the expected components.") + + actual = { + name + for name in pipe.components + if getattr(pipe, name, None) is not None + and getattr(getattr(pipe, name), "_diffusers_load_id", None) 
not in (None, "null") + } + assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}" + + def test_load_expected_components_from_save_pretrained(self, tmp_path): + pipe = self.get_pipeline() + save_dir = str(tmp_path / "saved-pipeline") + pipe.save_pretrained(save_dir) + + expected = _get_specified_components(save_dir) + loaded_pipe = ModularPipeline.from_pretrained(save_dir) + loaded_pipe.load_components(torch_dtype=torch.float32) + + actual = { + name + for name in loaded_pipe.components + if getattr(loaded_pipe, name, None) is not None + and getattr(getattr(loaded_pipe, name), "_diffusers_load_id", None) not in (None, "null") + } + assert expected == actual, ( + f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}" + ) + def test_modular_index_consistency(self, tmp_path): pipe = self.get_pipeline() components_spec = pipe._component_specs From 7d555606f92e64055ae70f1ae08fe00fe7f652d9 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 23 Mar 2026 16:56:08 +0530 Subject: [PATCH 064/215] [CI] Flux2 Model Test Refactor (#13071) * update * update * update --------- Co-authored-by: Sayak Paul --- .../test_models_transformer_flux.py | 9 +- .../test_models_transformer_flux2.py | 613 +++++++++++++++--- 2 files changed, 542 insertions(+), 80 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 2d39dadfcad1..24be833d0ed2 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -41,7 +41,6 @@ ModelOptCompileTesterMixin, ModelOptTesterMixin, ModelTesterMixin, - PyramidAttentionBroadcastTesterMixin, QuantoCompileTesterMixin, QuantoTesterMixin, SingleFileTesterMixin, @@ -219,6 +218,10 @@ class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin): class 
TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin): """Training tests for Flux Transformer.""" + def test_gradient_checkpointing_is_applied(self): + expected_set = {"FluxTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin): """Attention processor tests for Flux Transformer.""" @@ -412,10 +415,6 @@ class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAn """BitsAndBytes + compile tests for Flux Transformer.""" -class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin): - """PyramidAttentionBroadcast cache tests for Flux Transformer.""" - - class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin): """FirstBlockCache tests for Flux Transformer.""" diff --git a/tests/models/transformers/test_models_transformer_flux2.py b/tests/models/transformers/test_models_transformer_flux2.py index 316d5fa770bb..a109f603411d 100644 --- a/tests/models/transformers/test_models_transformer_flux2.py +++ b/tests/models/transformers/test_models_transformer_flux2.py @@ -13,48 +13,95 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest - import torch -from diffusers import Flux2Transformer2DModel, attention_backend +from diffusers import Flux2Transformer2DModel +from diffusers.models.transformers.transformer_flux2 import ( + Flux2KVAttnProcessor, + Flux2KVCache, + Flux2KVLayerCache, + Flux2KVParallelSelfAttnProcessor, +) +from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + BitsAndBytesTesterMixin, + ContextParallelTesterMixin, + GGUFCompileTesterMixin, + GGUFTesterMixin, + LoraHotSwappingForModelTesterMixin, + LoraTesterMixin, + MemoryTesterMixin, + ModelTesterMixin, + TorchAoCompileTesterMixin, + TorchAoTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) enable_full_determinism() -class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase): - model_class = Flux2Transformer2DModel - main_input_name = "hidden_states" - # We override the items here because the transformer under consideration is small. - model_split_percents = [0.7, 0.6, 0.6] - - # Skip setting testing with default: AttnProcessor - uses_custom_attn_processor = True - +class Flux2TransformerTesterConfig(BaseModelTesterConfig): @property - def dummy_input(self): - return self.prepare_dummy_input() + def model_class(self): + return Flux2Transformer2DModel @property - def input_shape(self): + def output_shape(self) -> tuple[int, int]: return (16, 4) @property - def output_shape(self): + def input_shape(self) -> tuple[int, int]: return (16, 4) - def prepare_dummy_input(self, height=4, width=4): + @property + def model_split_percents(self) -> list: + # We override the items here because the transformer under consideration is small. 
+ return [0.7, 0.6, 0.6] + + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def uses_custom_attn_processor(self) -> bool: + # Skip setting testing with default: AttnProcessor + return True + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict[str, int | list[int]]: + return { + "patch_size": 1, + "in_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 16, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "timestep_guidance_channels": 256, # Hardcoded in original code + "axes_dims_rope": [4, 4, 4, 4], + } + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: batch_size = 1 num_latent_channels = 4 sequence_length = 48 embedding_dim = 32 - hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + hidden_states = randn_tensor( + (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ) t_coords = torch.arange(1) h_coords = torch.arange(height) @@ -82,8 +129,286 @@ def prepare_dummy_input(self, height=4, width=4): "guidance": guidance, } - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + +class TestFlux2Transformer(Flux2TransformerTesterConfig, ModelTesterMixin): + pass + + +class TestFlux2TransformerMemory(Flux2TransformerTesterConfig, MemoryTesterMixin): + """Memory optimization tests for Flux2 Transformer.""" + + +class TestFlux2TransformerTraining(Flux2TransformerTesterConfig, TrainingTesterMixin): + """Training tests for Flux2 Transformer.""" + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"Flux2Transformer2DModel"} + 
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestFlux2TransformerAttention(Flux2TransformerTesterConfig, AttentionTesterMixin): + """Attention processor tests for Flux2 Transformer.""" + + +class TestFlux2TransformerContextParallel(Flux2TransformerTesterConfig, ContextParallelTesterMixin): + """Context Parallel inference tests for Flux2 Transformer.""" + + +class TestFlux2TransformerLoRA(Flux2TransformerTesterConfig, LoraTesterMixin): + """LoRA adapter tests for Flux2 Transformer.""" + + +class TestFlux2TransformerLoRAHotSwap(Flux2TransformerTesterConfig, LoraHotSwappingForModelTesterMixin): + """LoRA hot-swapping tests for Flux2 Transformer.""" + + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: + """Override to support dynamic height/width for LoRA hotswap tests.""" + batch_size = 1 + num_latent_channels = 4 + sequence_length = 48 + embedding_dim = 32 + + hidden_states = randn_tensor( + (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ) + + t_coords = torch.arange(1) + h_coords = torch.arange(height) + w_coords = torch.arange(width) + l_coords = torch.arange(1) + image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords) + image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device) + + text_t_coords = torch.arange(1) + text_h_coords = torch.arange(1) + text_w_coords = torch.arange(1) + text_l_coords = torch.arange(sequence_length) + text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords) + text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device) + + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + 
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "img_ids": image_ids, + "txt_ids": text_ids, + "timestep": timestep, + "guidance": guidance, + } + + +class TestFlux2TransformerCompile(Flux2TransformerTesterConfig, TorchCompileTesterMixin): + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: + """Override to support dynamic height/width for compilation tests.""" + batch_size = 1 + num_latent_channels = 4 + sequence_length = 48 + embedding_dim = 32 + + hidden_states = randn_tensor( + (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ) + + t_coords = torch.arange(1) + h_coords = torch.arange(height) + w_coords = torch.arange(width) + l_coords = torch.arange(1) + image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords) + image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device) + + text_t_coords = torch.arange(1) + text_h_coords = torch.arange(1) + text_w_coords = torch.arange(1) + text_l_coords = torch.arange(sequence_length) + text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords) + text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device) + + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "img_ids": image_ids, + "txt_ids": text_ids, + "timestep": timestep, + "guidance": guidance, + } + + +class TestFlux2TransformerBitsAndBytes(Flux2TransformerTesterConfig, 
BitsAndBytesTesterMixin): + """BitsAndBytes quantization tests for Flux2 Transformer.""" + + +class TestFlux2TransformerTorchAo(Flux2TransformerTesterConfig, TorchAoTesterMixin): + """TorchAO quantization tests for Flux2 Transformer.""" + + +class TestFlux2TransformerGGUF(Flux2TransformerTesterConfig, GGUFTesterMixin): + """GGUF quantization tests for Flux2 Transformer.""" + + @property + def gguf_filename(self): + return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf" + + @property + def torch_dtype(self): + return torch.bfloat16 + + def get_dummy_inputs(self): + """Override to provide inputs matching the real FLUX2 model dimensions. + + Flux2 defaults: in_channels=128, joint_attention_dim=15360 + """ + batch_size = 1 + height = 64 + width = 64 + sequence_length = 512 + + hidden_states = randn_tensor( + (batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype + ) + + # Flux2 uses 4D image/text IDs (t, h, w, l) + t_coords = torch.arange(1) + h_coords = torch.arange(height) + w_coords = torch.arange(width) + l_coords = torch.arange(1) + image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords) + image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device) + + text_t_coords = torch.arange(1) + text_h_coords = torch.arange(1) + text_w_coords = torch.arange(1) + text_l_coords = torch.arange(sequence_length) + text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords) + text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device) + + timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype) + guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, 
+ "img_ids": image_ids, + "txt_ids": text_ids, + "timestep": timestep, + "guidance": guidance, + } + + +class TestFlux2TransformerTorchAoCompile(Flux2TransformerTesterConfig, TorchAoCompileTesterMixin): + """TorchAO + compile tests for Flux2 Transformer.""" + + +class TestFlux2TransformerGGUFCompile(Flux2TransformerTesterConfig, GGUFCompileTesterMixin): + """GGUF + compile tests for Flux2 Transformer.""" + + @property + def gguf_filename(self): + return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf" + + @property + def torch_dtype(self): + return torch.bfloat16 + + def get_dummy_inputs(self): + """Override to provide inputs matching the real FLUX2 model dimensions. + + Flux2 defaults: in_channels=128, joint_attention_dim=15360 + """ + batch_size = 1 + height = 64 + width = 64 + sequence_length = 512 + + hidden_states = randn_tensor( + (batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype + ) + + # Flux2 uses 4D image/text IDs (t, h, w, l) + t_coords = torch.arange(1) + h_coords = torch.arange(height) + w_coords = torch.arange(width) + l_coords = torch.arange(1) + image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords) + image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device) + + text_t_coords = torch.arange(1) + text_h_coords = torch.arange(1) + text_w_coords = torch.arange(1) + text_l_coords = torch.arange(sequence_length) + text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords) + text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device) + + timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype) + guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype) + + return { + "hidden_states": hidden_states, + 
"encoder_hidden_states": encoder_hidden_states, + "img_ids": image_ids, + "txt_ids": text_ids, + "timestep": timestep, + "guidance": guidance, + } + + +class Flux2TransformerKVCacheTesterConfig(BaseModelTesterConfig): + num_ref_tokens = 4 + + @property + def model_class(self): + return Flux2Transformer2DModel + + @property + def output_shape(self) -> tuple[int, int]: + return (16, 4) + + @property + def input_shape(self) -> tuple[int, int]: + return (16, 4) + + @property + def model_split_percents(self) -> list: + return [0.7, 0.6, 0.6] + + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def uses_custom_attn_processor(self) -> bool: + return True + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict[str, int | list[int]]: + return { "patch_size": 1, "in_channels": 4, "num_layers": 1, @@ -91,72 +416,210 @@ def prepare_init_args_and_inputs_for_common(self): "attention_head_dim": 16, "num_attention_heads": 2, "joint_attention_dim": 32, - "timestep_guidance_channels": 256, # Hardcoded in original code + "timestep_guidance_channels": 256, "axes_dims_rope": [4, 4, 4, 4], } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - # TODO (Daniel, Sayak): We can remove this test. 
- def test_flux2_consistency(self, seed=0): - torch.manual_seed(seed) - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - torch.manual_seed(seed) - model = self.model_class(**init_dict) - # state_dict = model.state_dict() - # for key, param in state_dict.items(): - # print(f"{key} | {param.shape}") - # torch.save(state_dict, "/raid/daniel_gu/test_flux2_params/diffusers.pt") - model.to(torch_device) - model.eval() - - with attention_backend("native"): - with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.to_tuple()[0] + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: + batch_size = 1 + num_latent_channels = 4 + sequence_length = 48 + embedding_dim = 32 + num_ref_tokens = self.num_ref_tokens + + ref_hidden_states = randn_tensor( + (batch_size, num_ref_tokens, num_latent_channels), generator=self.generator, device=torch_device + ) + img_hidden_states = randn_tensor( + (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + ) + hidden_states = torch.cat([ref_hidden_states, img_hidden_states], dim=1) + + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ) + + ref_t_coords = torch.arange(1) + ref_h_coords = torch.arange(num_ref_tokens) + ref_w_coords = torch.arange(1) + ref_l_coords = torch.arange(1) + ref_ids = torch.cartesian_prod(ref_t_coords, ref_h_coords, ref_w_coords, ref_l_coords) + ref_ids = ref_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device) - self.assertIsNotNone(output) + t_coords = torch.arange(1) + h_coords = torch.arange(height) + w_coords = torch.arange(width) + l_coords = torch.arange(1) + image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords) + image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device) + image_ids = torch.cat([ref_ids, image_ids], dim=1) - # 
input & output have to have the same shape - input_tensor = inputs_dict[self.main_input_name] - expected_shape = input_tensor.shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + text_t_coords = torch.arange(1) + text_h_coords = torch.arange(1) + text_w_coords = torch.arange(1) + text_l_coords = torch.arange(sequence_length) + text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords) + text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device) - # Check against expected slice - # fmt: off - expected_slice = torch.tensor([-0.3662, 0.4844, 0.6334, -0.3497, 0.2162, 0.0188, 0.0521, -0.2061, -0.2041, -0.0342, -0.7107, 0.4797, -0.3280, 0.7059, -0.0849, 0.4416]) - # fmt: on + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size) - flat_output = output.cpu().flatten() - generated_slice = torch.cat([flat_output[:8], flat_output[-8:]]) - self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-4)) + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "img_ids": image_ids, + "txt_ids": text_ids, + "timestep": timestep, + "guidance": guidance, + } - def test_gradient_checkpointing_is_applied(self): - expected_set = {"Flux2Transformer2DModel"} - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) +class TestFlux2TransformerKVCache(Flux2TransformerKVCacheTesterConfig): + """KV cache tests for Flux2 Transformer.""" + + def test_kv_layer_cache_store_and_get(self): + cache = Flux2KVLayerCache() + k = torch.randn(1, 4, 2, 16) + v = torch.randn(1, 4, 2, 16) + cache.store(k, v) + k_out, v_out = cache.get() + assert torch.equal(k, k_out) + assert torch.equal(v, v_out) + + def test_kv_layer_cache_get_before_store_raises(self): + cache = Flux2KVLayerCache() + try: + cache.get() + assert False, "Expected RuntimeError" + except 
RuntimeError: + pass + + def test_kv_layer_cache_clear(self): + cache = Flux2KVLayerCache() + cache.store(torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16)) + cache.clear() + assert cache.k_ref is None + assert cache.v_ref is None + + def test_kv_cache_structure(self): + num_double = 3 + num_single = 2 + cache = Flux2KVCache(num_double, num_single) + assert len(cache.double_block_caches) == num_double + assert len(cache.single_block_caches) == num_single + assert cache.num_ref_tokens == 0 + + for i in range(num_double): + assert isinstance(cache.get_double(i), Flux2KVLayerCache) + for i in range(num_single): + assert isinstance(cache.get_single(i), Flux2KVLayerCache) + + def test_kv_cache_clear(self): + cache = Flux2KVCache(2, 1) + cache.num_ref_tokens = 4 + cache.get_double(0).store(torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16)) + cache.clear() + assert cache.num_ref_tokens == 0 + assert cache.get_double(0).k_ref is None + + def _set_kv_attn_processors(self, model): + for block in model.transformer_blocks: + block.attn.set_processor(Flux2KVAttnProcessor()) + for block in model.single_transformer_blocks: + block.attn.set_processor(Flux2KVParallelSelfAttnProcessor()) + + @torch.no_grad() + def test_extract_mode_returns_cache(self): + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + self._set_kv_attn_processors(model) + + output = model( + **self.get_dummy_inputs(), + kv_cache_mode="extract", + num_ref_tokens=self.num_ref_tokens, + ref_fixed_timestep=0.0, + ) + + assert output.kv_cache is not None + assert isinstance(output.kv_cache, Flux2KVCache) + assert output.kv_cache.num_ref_tokens == self.num_ref_tokens + + for layer_cache in output.kv_cache.double_block_caches: + assert layer_cache.k_ref is not None + assert layer_cache.v_ref is not None + + for layer_cache in output.kv_cache.single_block_caches: + assert layer_cache.k_ref is not None + assert layer_cache.v_ref is not None + + @torch.no_grad() + def 
test_extract_mode_output_shape(self): + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() -class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = Flux2Transformer2DModel - different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] + height, width = 4, 4 + output = model( + **self.get_dummy_inputs(height=height, width=width), + kv_cache_mode="extract", + num_ref_tokens=self.num_ref_tokens, + ref_fixed_timestep=0.0, + ) - def prepare_init_args_and_inputs_for_common(self): - return Flux2TransformerTests().prepare_init_args_and_inputs_for_common() + assert output.sample.shape == (1, height * width, 4) - def prepare_dummy_input(self, height, width): - return Flux2TransformerTests().prepare_dummy_input(height=height, width=width) + @torch.no_grad() + def test_cached_mode_uses_cache(self): + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + height, width = 4, 4 + extract_output = model( + **self.get_dummy_inputs(height=height, width=width), + kv_cache_mode="extract", + num_ref_tokens=self.num_ref_tokens, + ref_fixed_timestep=0.0, + ) + + base_config = Flux2TransformerTesterConfig() + cached_inputs = base_config.get_dummy_inputs(height=height, width=width) + cached_output = model( + **cached_inputs, + kv_cache=extract_output.kv_cache, + kv_cache_mode="cached", + ) + + assert cached_output.sample.shape == (1, height * width, 4) + assert cached_output.kv_cache is None + + @torch.no_grad() + def test_extract_return_dict_false(self): + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() -class Flux2TransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): - model_class = Flux2Transformer2DModel - different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] + output = model( + **self.get_dummy_inputs(), + kv_cache_mode="extract", + num_ref_tokens=self.num_ref_tokens, + ref_fixed_timestep=0.0, + 
return_dict=False, + ) + + assert isinstance(output, tuple) + assert len(output) == 2 + assert isinstance(output[1], Flux2KVCache) + + @torch.no_grad() + def test_no_kv_cache_mode_returns_no_cache(self): + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() - def prepare_init_args_and_inputs_for_common(self): - return Flux2TransformerTests().prepare_init_args_and_inputs_for_common() + base_config = Flux2TransformerTesterConfig() + output = model(**base_config.get_dummy_inputs()) - def prepare_dummy_input(self, height, width): - return Flux2TransformerTests().prepare_dummy_input(height=height, width=width) + assert output.kv_cache is None From 9da4e24544a0a8d6b5a7ae9815c394a8512030fb Mon Sep 17 00:00:00 2001 From: Charles Date: Mon, 23 Mar 2026 13:40:07 +0100 Subject: [PATCH 065/215] [export] Add export-safe LRU cache helper (#13290) * [core] Add export-safe LRU cache helper * torch version check! --------- Co-authored-by: Sayak Paul --- src/diffusers/hooks/context_parallel.py | 5 ++-- src/diffusers/models/attention_dispatch.py | 4 ++-- .../transformers/transformer_qwenimage.py | 9 ++++---- src/diffusers/utils/torch_utils.py | 23 +++++++++++++++++++ 4 files changed, 31 insertions(+), 10 deletions(-) diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py index 6130be2b8290..f6ab623a1865 100644 --- a/src/diffusers/hooks/context_parallel.py +++ b/src/diffusers/hooks/context_parallel.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import copy -import functools import inspect from dataclasses import dataclass from typing import Type @@ -32,7 +31,7 @@ gather_size_by_comm, ) from ..utils import get_logger -from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module +from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph, unwrap_module from .hooks import HookRegistry, ModelHook @@ -327,7 +326,7 @@ def unshard_anything( return tensor -@functools.lru_cache(maxsize=64) +@lru_cache_unless_export(maxsize=64) def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]: gather_shapes = [] for i in range(world_size): diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index c407f59037e6..42dc63273740 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -49,7 +49,7 @@ is_xformers_version, ) from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS -from ..utils.torch_utils import maybe_allow_in_graph +from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph from ._modeling_parallel import gather_size_by_comm @@ -587,7 +587,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None ) -@functools.lru_cache(maxsize=128) +@lru_cache_unless_export(maxsize=128) def _prepare_for_flash_attn_or_sage_varlen_without_mask( batch_size: int, seq_len_q: int, diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index a54cb3b8e092..c5419b9f107e 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import functools import math from math import prod from typing import Any @@ -25,7 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import apply_lora_scale, deprecate, logging -from ...utils.torch_utils import maybe_allow_in_graph +from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn @@ -307,7 +306,7 @@ def forward( return vid_freqs, txt_freqs - @functools.lru_cache(maxsize=128) + @lru_cache_unless_export(maxsize=128) def _compute_video_freqs( self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None ) -> torch.Tensor: @@ -428,7 +427,7 @@ def forward( return vid_freqs, txt_freqs - @functools.lru_cache(maxsize=None) + @lru_cache_unless_export(maxsize=None) def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None): seq_lens = frame * height * width pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs @@ -450,7 +449,7 @@ def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) return freqs.clone().contiguous() - @functools.lru_cache(maxsize=None) + @lru_cache_unless_export(maxsize=None) def _compute_condition_freqs(self, frame, height, width, device: torch.device = None): seq_lens = frame * height * width pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 7f4cb3e12766..8a48316bf3dd 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -19,11 +19,16 @@ import functools import os +from typing import Callable, 
ParamSpec, TypeVar from . import logging from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version +T = TypeVar("T") +P = ParamSpec("P") + + if is_torch_available(): import torch from torch.fft import fftn, fftshift, ifftn, ifftshift @@ -333,5 +338,23 @@ def disable_full_determinism(): torch.use_deterministic_algorithms(False) +@functools.wraps(functools.lru_cache) +def lru_cache_unless_export(maxsize=128, typed=False): + def outer_wrapper(fn: Callable[P, T]): + cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn) + if is_torch_version("<", "2.7.0"): + return cached + + @functools.wraps(fn) + def inner_wrapper(*args: P.args, **kwargs: P.kwargs): + if torch.compiler.is_exporting(): + return fn(*args, **kwargs) + return cached(*args, **kwargs) + + return inner_wrapper + + return outer_wrapper + + if is_torch_available(): torch_device = get_device() From 77b87263434213185af65ce487b42a4765ca74ff Mon Sep 17 00:00:00 2001 From: ddavidchick Date: Tue, 24 Mar 2026 01:56:49 +0300 Subject: [PATCH 066/215] Add KVAE 1.0 (#13033) * add kvae2d * add kvae3d video * add docs for kvae2d and kvae3d video * style fixes * fix kvae3d docs * fix normalzation * fix kvae video for code style * fix kvae video * kvae minor fixes * add gradient ckpting for kvaes * get rid of inplace ops kvae video * add tests for KVAEs * kvae2d normalization style change * kvaes fix style * update dummy_pt_objects test for kvaes --------- Co-authored-by: YiYi Xu --- docs/source/en/_toctree.yml | 4 + .../en/api/models/autoencoder_kl_kvae.md | 32 + .../api/models/autoencoder_kl_kvae_video.md | 33 + src/diffusers/__init__.py | 4 + src/diffusers/models/__init__.py | 4 + src/diffusers/models/autoencoders/__init__.py | 2 + .../autoencoders/autoencoder_kl_kvae.py | 802 +++++++++++++++ .../autoencoders/autoencoder_kl_kvae_video.py | 954 ++++++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 30 + .../test_models_autoencoder_kl_kvae.py | 73 ++ 
.../test_models_autoencoder_kl_kvae_video.py | 118 +++ 11 files changed, 2056 insertions(+) create mode 100644 docs/source/en/api/models/autoencoder_kl_kvae.md create mode 100644 docs/source/en/api/models/autoencoder_kl_kvae_video.md create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_kvae.py create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_kvae_video.py create mode 100644 tests/models/autoencoders/test_models_autoencoder_kl_kvae.py create mode 100644 tests/models/autoencoders/test_models_autoencoder_kl_kvae_video.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 6b1a7288d60f..c2c62151132f 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -446,6 +446,10 @@ title: AutoencoderKLHunyuanVideo - local: api/models/autoencoder_kl_hunyuan_video15 title: AutoencoderKLHunyuanVideo15 + - local: api/models/autoencoder_kl_kvae + title: AutoencoderKLKVAE + - local: api/models/autoencoder_kl_kvae_video + title: AutoencoderKLKVAEVideo - local: api/models/autoencoderkl_audio_ltx_2 title: AutoencoderKLLTX2Audio - local: api/models/autoencoderkl_ltx_2 diff --git a/docs/source/en/api/models/autoencoder_kl_kvae.md b/docs/source/en/api/models/autoencoder_kl_kvae.md new file mode 100644 index 000000000000..39cbb4c85c5f --- /dev/null +++ b/docs/source/en/api/models/autoencoder_kl_kvae.md @@ -0,0 +1,32 @@ + + +# AutoencoderKLKVAE + +The 2D variational autoencoder (VAE) model with KL loss. + +The model can be loaded with the following code snippet. 
+ +```python +import torch +from diffusers import AutoencoderKLKVAE + +vae = AutoencoderKLKVAE.from_pretrained("kandinskylab/KVAE-2D-1.0", subfolder="diffusers", torch_dtype=torch.bfloat16) +``` + +## AutoencoderKLKVAE + +[[autodoc]] AutoencoderKLKVAE + - decode + - all diff --git a/docs/source/en/api/models/autoencoder_kl_kvae_video.md b/docs/source/en/api/models/autoencoder_kl_kvae_video.md new file mode 100644 index 000000000000..0120dc2adc51 --- /dev/null +++ b/docs/source/en/api/models/autoencoder_kl_kvae_video.md @@ -0,0 +1,33 @@ + + +# AutoencoderKLKVAEVideo + +The 3D variational autoencoder (VAE) model with KL loss. + +The model can be loaded with the following code snippet. + +```python +import torch +from diffusers import AutoencoderKLKVAEVideo + +vae = AutoencoderKLKVAEVideo.from_pretrained("kandinskylab/KVAE-3D-1.0", subfolder="diffusers", torch_dtype=torch.float16) +``` + +## AutoencoderKLKVAEVideo + +[[autodoc]] AutoencoderKLKVAEVideo + - decode + - all + diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0be7b8166a37..eb5068b499cc 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -193,6 +193,8 @@ "AutoencoderKLHunyuanImageRefiner", "AutoencoderKLHunyuanVideo", "AutoencoderKLHunyuanVideo15", + "AutoencoderKLKVAE", + "AutoencoderKLKVAEVideo", "AutoencoderKLLTX2Audio", "AutoencoderKLLTX2Video", "AutoencoderKLLTXVideo", @@ -975,6 +977,8 @@ AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, + AutoencoderKLKVAE, + AutoencoderKLKVAEVideo, AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, AutoencoderKLLTXVideo, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e4bc95fdf884..7ded56049833 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -40,6 +40,8 @@ _import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"] 
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"] _import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"] + _import_structure["autoencoders.autoencoder_kl_kvae"] = ["AutoencoderKLKVAE"] + _import_structure["autoencoders.autoencoder_kl_kvae_video"] = ["AutoencoderKLKVAEVideo"] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] _import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"] _import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"] @@ -161,6 +163,8 @@ AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, + AutoencoderKLKVAE, + AutoencoderKLKVAEVideo, AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, AutoencoderKLLTXVideo, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index b6a673f7f7a7..609146ec340d 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -9,6 +9,8 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15 +from .autoencoder_kl_kvae import AutoencoderKLKVAE +from .autoencoder_kl_kvae_video import AutoencoderKLKVAEVideo from .autoencoder_kl_ltx import AutoencoderKLLTXVideo from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_kvae.py b/src/diffusers/models/autoencoders/autoencoder_kl_kvae.py new file mode 100644 index 000000000000..1bd2363af448 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_kvae.py @@ -0,0 +1,802 @@ +# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +class KVAEResnetBlock2D(nn.Module): + r""" + A Resnet block with optional guidance. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + conv_shortcut (`bool`, *optional*, default to `False`): + If `True` and `in_channels` not equal to `out_channels`, add a 3x3 nn.conv2d layer for skip-connection. + temb_channels (`int`, *optional*, default to `512`): The number of channels in timestep embedding. + zq_ch (`int`, *optional*, default to `None`): Guidance channels for normalization. + add_conv (`bool`, *optional*, default to `False`): + If `True` add conv2d layer for normalization. + normalization (`nn.Module`, *optional*, default to `None`): The normalization layer. + act_fn (`str`, *optional*, default to `"swish"`): The activation function to use. 
+ """ + + def __init__( + self, + *, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + temb_channels: int = 512, + zq_ch: Optional[int] = None, + add_conv: bool = False, + act_fn: str = "swish", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.nonlinearity = get_activation(act_fn) + + if zq_ch is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.norm1 = KVAEDecoderSpatialNorm2D(in_channels, zq_channels=zq_ch, add_conv=add_conv) + + self.conv1 = nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=(1, 1), padding_mode="replicate" + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + if zq_ch is None: + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.norm2 = KVAEDecoderSpatialNorm2D(out_channels, zq_channels=zq_ch, add_conv=add_conv) + self.conv2 = nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + padding=(1, 1), + padding_mode="replicate", + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=(1, 1), + padding_mode="replicate", + ) + else: + self.nin_shortcut = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, x: torch.Tensor, temb: torch.Tensor, zq: torch.Tensor = None) -> torch.Tensor: + h = x + + if zq is None: + h = self.norm1(h) + else: + h = self.norm1(h, zq) + + h = self.nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if zq is None: + h = 
self.norm2(h) + else: + h = self.norm2(h, zq) + + h = self.nonlinearity(h) + + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class KVAEPXSDownsample(nn.Module): + def __init__(self, in_channels: int, factor: int = 2): + r""" + A Downsampling module. + + Args: + in_channels (`int`): The number of channels in the input. + factor (`int`, *optional*, default to `2`): The downsampling factor. + """ + super().__init__() + self.factor = factor + self.unshuffle = nn.PixelUnshuffle(self.factor) + self.spatial_conv = nn.Conv2d( + in_channels, in_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode="reflect" + ) + self.linear = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (bchw) + pxs_interm = self.unshuffle(x) + b, c, h, w = pxs_interm.shape + pxs_interm_view = pxs_interm.view(b, c // self.factor**2, self.factor**2, h, w) + pxs_out = torch.mean(pxs_interm_view, dim=2) + + conv_out = self.spatial_conv(x) + + # adding it all together + out = conv_out + pxs_out + return self.linear(out) + + +class KVAEPXSUpsample(nn.Module): + def __init__(self, in_channels: int, factor: int = 2): + r""" + An Upsampling module. + + Args: + in_channels (`int`): The number of channels in the input. + factor (`int`, *optional*, default to `2`): The upsampling factor. 
+ """ + super().__init__() + self.factor = factor + self.shuffle = nn.PixelShuffle(self.factor) + self.spatial_conv = nn.Conv2d( + in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode="reflect" + ) + + self.linear = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + repeated = x.repeat_interleave(self.factor**2, dim=1) + pxs_interm = self.shuffle(repeated) + + image_like_ups = F.interpolate(x, scale_factor=2, mode="nearest") + conv_out = self.spatial_conv(image_like_ups) + + # adding it all together + out = conv_out + pxs_interm + return self.linear(out) + + +class KVAEDecoderSpatialNorm2D(nn.Module): + r""" + A 2D normalization module for decoder. + + Args: + in_channels (`int`): The number of channels in the input. + zq_channels (`int`): The number of channels in the guidance. + add_conv (`bool`, *optional*, default to `false`): + If `True` add conv2d 3x3 layer for guidance in the beginning. 
+ """ + + def __init__( + self, + in_channels: int, + zq_channels: int, + add_conv: bool = False, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + + self.add_conv = add_conv + if add_conv: + self.conv = nn.Conv2d( + in_channels=zq_channels, + out_channels=zq_channels, + kernel_size=3, + padding=(1, 1), + padding_mode="replicate", + ) + + self.conv_y = nn.Conv2d( + in_channels=zq_channels, + out_channels=in_channels, + kernel_size=1, + ) + self.conv_b = nn.Conv2d( + in_channels=zq_channels, + out_channels=in_channels, + kernel_size=1, + ) + + def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: + f_first = f + f_first_size = f_first.shape[2:] + zq = F.interpolate(zq, size=f_first_size, mode="nearest") + + if self.add_conv: + zq = self.conv(zq) + + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + + +class KVAEEncoder2D(nn.Module): + r""" + A 2D encoder module. + + Args: + ch (`int`): The base number of channels in multiresolution blocks. + ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`): + The channel multipliers in multiresolution blocks. + num_res_blocks (`int`): The number of Resnet blocks. + in_channels (`int`): The number of channels in the input. + z_channels (`int`): The number of output channels. + double_z (`bool`, *optional*, defaults to `True`): + Whether to double the number of output channels for the last block. + act_fn (`str`, *optional*, default to `"swish"`): The activation function to use. + """ + + def __init__( + self, + *, + ch: int, + ch_mult: Tuple[int, ...] 
= (1, 2, 4, 8), + num_res_blocks: int, + in_channels: int, + z_channels: int, + double_z: bool = True, + act_fn: str = "swish", + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + if isinstance(num_res_blocks, int): + self.num_res_blocks = [num_res_blocks] * self.num_resolutions + else: + self.num_res_blocks = num_res_blocks + self.nonlinearity = get_activation(act_fn) + + self.in_channels = in_channels + + self.conv_in = nn.Conv2d( + in_channels=in_channels, + out_channels=self.ch, + kernel_size=3, + padding=(1, 1), + ) + + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks[i_level]): + block.append( + KVAEResnetBlock2D( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + ) + ) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level < self.num_resolutions - 1: + down.downsample = KVAEPXSDownsample(in_channels=block_in) # mb: bad out channels + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = KVAEResnetBlock2D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + ) + + self.mid.block_2 = KVAEResnetBlock2D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + ) + + # end + self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True) + + self.conv_out = nn.Conv2d( + in_channels=block_in, + out_channels=2 * z_channels if double_z else z_channels, + kernel_size=3, + padding=(1, 1), + ) + + self.gradient_checkpointing = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # timestep embedding + temb = None + + # downsampling + h = self.conv_in(x) + for i_level in range(self.num_resolutions): + for 
i_block in range(self.num_res_blocks[i_level]): + if torch.is_grad_enabled() and self.gradient_checkpointing: + h = self._gradient_checkpointing_func(self.down[i_level].block[i_block], h, temb) + else: + h = self.down[i_level].block[i_block](h, temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = self.down[i_level].downsample(h) + + # middle + if torch.is_grad_enabled() and self.gradient_checkpointing: + h = self._gradient_checkpointing_func(self.mid.block_1, h, temb) + h = self._gradient_checkpointing_func(self.mid.block_2, h, temb) + else: + h = self.mid.block_1(h, temb) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = self.nonlinearity(h) + h = self.conv_out(h) + + return h + + +class KVAEDecoder2D(nn.Module): + r""" + A 2D decoder module. + + Args: + ch (`int`): The base number of channels in multiresolution blocks. + out_ch (`int`): The number of output channels. + ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`): + The channel multipliers in multiresolution blocks. + num_res_blocks (`int`): The number of Resnet blocks. + in_channels (`int`): The number of channels in the input. + z_channels (`int`): The number of input channels. + give_pre_end (`bool`, *optional*, default to `false`): + If `True` exit the forward pass early and return the penultimate feature map. + zq_ch (`bool`, *optional*, default to `None`): The number of channels in the guidance. + add_conv (`bool`, *optional*, default to `false`): If `True` add conv2d layer for Resnet normalization layer. + act_fn (`str`, *optional*, default to `"swish"`): The activation function to use. + """ + + def __init__( + self, + *, + ch: int, + out_ch: int, + ch_mult: Tuple[int, ...] 
= (1, 2, 4, 8), + num_res_blocks: int, + in_channels: int, + z_channels: int, + give_pre_end: bool = False, + zq_ch: Optional[int] = None, + add_conv: bool = False, + act_fn: str = "swish", + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.nonlinearity = get_activation(act_fn) + + if zq_ch is None: + zq_ch = z_channels + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + + self.conv_in = nn.Conv2d( + in_channels=z_channels, out_channels=block_in, kernel_size=3, padding=(1, 1), padding_mode="replicate" + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = KVAEResnetBlock2D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + zq_ch=zq_ch, + add_conv=add_conv, + ) + + self.mid.block_2 = KVAEResnetBlock2D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + zq_ch=zq_ch, + add_conv=add_conv, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + KVAEResnetBlock2D( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + zq_ch=zq_ch, + add_conv=add_conv, + ) + ) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = KVAEPXSUpsample(in_channels=block_in) + self.up.insert(0, up) + + self.norm_out = KVAEDecoderSpatialNorm2D(block_in, zq_ch, add_conv=add_conv) # , gather=gather_norm) + + self.conv_out = nn.Conv2d( + in_channels=block_in, out_channels=out_ch, kernel_size=3, padding=(1, 1), padding_mode="replicate" + ) + + self.gradient_checkpointing = False + + def forward(self, z: torch.Tensor) -> 
torch.Tensor: + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + zq = z + h = self.conv_in(z) + + # middle + if torch.is_grad_enabled() and self.gradient_checkpointing: + h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, zq) + h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, zq) + else: + h = self.mid.block_1(h, temb, zq) + h = self.mid.block_2(h, temb, zq) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + if torch.is_grad_enabled() and self.gradient_checkpointing: + h = self._gradient_checkpointing_func(self.up[i_level].block[i_block], h, temb, zq) + else: + h = self.up[i_level].block[i_block](h, temb, zq) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, zq) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h, zq) + h = self.nonlinearity(h) + h = self.conv_out(h) + + return h + + +class AutoencoderKLKVAE(ModelMixin, AutoencoderMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for + all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + channels (int, *optional*, defaults to 128): The base number of channels in multiresolution blocks. + num_enc_blocks (int, *optional*, defaults to 2): + The number of Resnet blocks in encoder multiresolution layers. + num_dec_blocks (int, *optional*, defaults to 2): + The number of Resnet blocks in decoder multiresolution layers. + z_channels (int, *optional*, defaults to 16): Number of channels in the latent space. 
+ double_z (`bool`, *optional*, defaults to `True`): + Whether to double the number of output channels of encoder. + ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`): + The channel multipliers in multiresolution blocks. + sample_size (`int`, *optional*, defaults to `1024`): Sample input size. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + channels: int = 128, + num_enc_blocks: int = 2, + num_dec_blocks: int = 2, + z_channels: int = 16, + double_z: bool = True, + ch_mult: Tuple[int, ...] = (1, 2, 4, 8), + sample_size: int = 1024, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = KVAEEncoder2D( + in_channels=in_channels, + ch=channels, + ch_mult=ch_mult, + num_res_blocks=num_enc_blocks, + z_channels=z_channels, + double_z=double_z, + ) + + # pass init params to Decoder + self.decoder = KVAEDecoder2D( + out_ch=in_channels, + ch=channels, + ch_mult=ch_mult, + num_res_blocks=num_dec_blocks, + in_channels=None, + z_channels=z_channels, + ) + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.ch_mult) - 1))) + self.tile_overlap_factor = 0.25 + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size): + return self._tiled_encode(x) + + enc = self.encoder(x) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. 
+ + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.tiled_decode(z, return_dict=return_dict) + + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. 
+ + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. 
+ rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + enc = torch.cat(result_rows, dim=2) + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
+ rows = [] + for i in range(0, z.shape[2], overlap_size): + row = [] + for j in range(0, z.shape[3], overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_kvae_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_kvae_video.py new file mode 100644 index 000000000000..7038f45fc30e --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_kvae_video.py @@ -0,0 +1,954 @@ +# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Dict, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def nonlinearity(x: torch.Tensor) -> torch.Tensor: + return F.silu(x) + + +# ============================================================================= +# Base layers +# ============================================================================= + + +class KVAESafeConv3d(nn.Conv3d): + r""" + A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM. 
+ """ + + def forward(self, input: torch.Tensor, write_to: torch.Tensor = None) -> torch.Tensor: + memory_count = input.numel() * input.element_size() / (10**9) + + if memory_count > 3: + kernel_size = self.kernel_size[0] + part_num = math.ceil(memory_count / 2) + input_chunks = torch.chunk(input, part_num, dim=2) + + if write_to is None: + output = [] + for i, chunk in enumerate(input_chunks): + if i == 0 or kernel_size == 1: + z = torch.clone(chunk) + else: + z = torch.cat([z[:, :, -kernel_size + 1 :], chunk], dim=2) + output.append(super().forward(z)) + return torch.cat(output, dim=2) + else: + time_offset = 0 + for i, chunk in enumerate(input_chunks): + if i == 0 or kernel_size == 1: + z = torch.clone(chunk) + else: + z = torch.cat([z[:, :, -kernel_size + 1 :], chunk], dim=2) + z_time = z.size(2) - (kernel_size - 1) + write_to[:, :, time_offset : time_offset + z_time] = super().forward(z) + time_offset += z_time + return write_to + else: + if write_to is None: + return super().forward(input) + else: + write_to[...] = super().forward(input) + return write_to + + +class KVAECausalConv3d(nn.Module): + r""" + A 3D causal convolution layer. 
+ """ + + def __init__( + self, + chan_in: int, + chan_out: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Tuple[int, int, int] = (1, 1, 1), + dilation: Tuple[int, int, int] = (1, 1, 1), + **kwargs, + ): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + self.height_pad = height_kernel_size // 2 + self.width_pad = width_kernel_size // 2 + self.time_pad = time_kernel_size - 1 + self.time_kernel_size = time_kernel_size + self.stride = stride + + self.conv = KVAESafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + padding_3d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad, self.time_pad, 0) + input_padded = F.pad(input, padding_3d, mode="replicate") + return self.conv(input_padded) + + +class KVAECachedCausalConv3d(nn.Module): + r""" + A 3D causal convolution layer with caching for temporal processing. 
+ """ + + def __init__( + self, + chan_in: int, + chan_out: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Tuple[int, int, int] = (1, 1, 1), + dilation: Tuple[int, int, int] = (1, 1, 1), + **kwargs, + ): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + self.height_pad = height_kernel_size // 2 + self.width_pad = width_kernel_size // 2 + self.time_pad = time_kernel_size - 1 + self.time_kernel_size = time_kernel_size + self.stride = stride + + self.conv = KVAESafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, input: torch.Tensor, cache: Dict) -> torch.Tensor: + t_stride = self.stride[0] + padding_3d = (self.height_pad, self.height_pad, self.width_pad, self.width_pad, 0, 0) + input_parallel = F.pad(input, padding_3d, mode="replicate") + + if cache["padding"] is None: + first_frame = input_parallel[:, :, :1] + time_pad_shape = list(first_frame.shape) + time_pad_shape[2] = self.time_pad + padding = first_frame.expand(time_pad_shape) + else: + padding = cache["padding"] + + out_size = list(input.shape) + out_size[1] = self.conv.out_channels + if t_stride == 2: + out_size[2] = (input.size(2) + 1) // 2 + output = torch.empty(tuple(out_size), dtype=input.dtype, device=input.device) + + offset_out = math.ceil(padding.size(2) / t_stride) + offset_in = offset_out * t_stride - padding.size(2) + + if offset_out > 0: + padding_poisoned = torch.cat( + [padding, input_parallel[:, :, : offset_in + self.time_kernel_size - t_stride]], dim=2 + ) + output[:, :, :offset_out] = self.conv(padding_poisoned) + + if offset_out < output.size(2): + output[:, :, offset_out:] = self.conv(input_parallel[:, :, offset_in:]) + + pad_offset = ( + offset_in + + t_stride * math.trunc((input_parallel.size(2) - offset_in - self.time_kernel_size) / t_stride) + + t_stride + ) + cache["padding"] 
= torch.clone(input_parallel[:, :, pad_offset:]) + + return output + + +class KVAECachedGroupNorm(nn.Module): + r""" + GroupNorm with caching support for temporal processing. + """ + + def __init__(self, in_channels: int): + super().__init__() + self.norm_layer = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + def forward(self, x: torch.Tensor, cache: Dict = None) -> torch.Tensor: + out = self.norm_layer(x) + if cache is not None and cache.get("mean") is None and cache.get("var") is None: + cache["mean"] = 1 + cache["var"] = 1 + return out + + +# ============================================================================= +# Cached layers +# ============================================================================= + + +class KVAECachedSpatialNorm3D(nn.Module): + r""" + Spatially conditioned normalization for decoder with caching. + """ + + def __init__( + self, + f_channels: int, + zq_channels: int, + add_conv: bool = False, + ): + super().__init__() + self.norm_layer = KVAECachedGroupNorm(f_channels) + self.add_conv = add_conv + + if add_conv: + self.conv = KVAECachedCausalConv3d(chan_in=zq_channels, chan_out=zq_channels, kernel_size=3) + + self.conv_y = KVAESafeConv3d(zq_channels, f_channels, kernel_size=1) + self.conv_b = KVAESafeConv3d(zq_channels, f_channels, kernel_size=1) + + def forward(self, f: torch.Tensor, zq: torch.Tensor, cache: Dict) -> torch.Tensor: + if cache["norm"].get("mean") is None and cache["norm"].get("var") is None: + f_first, f_rest = f[:, :, :1], f[:, :, 1:] + f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] + zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:] + + zq_first = F.interpolate(zq_first, size=f_first_size, mode="nearest") + + if zq.size(2) > 1: + zq_rest_splits = torch.split(zq_rest, 32, dim=1) + interpolated_splits = [ + F.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits + ] + zq_rest = torch.cat(interpolated_splits, dim=1) + zq = 
torch.cat([zq_first, zq_rest], dim=2) + else: + zq = zq_first + else: + f_size = f.shape[-3:] + zq_splits = torch.split(zq, 32, dim=1) + interpolated_splits = [F.interpolate(split, size=f_size, mode="nearest") for split in zq_splits] + zq = torch.cat(interpolated_splits, dim=1) + + if self.add_conv: + zq = self.conv(zq, cache["add_conv"]) + + norm_f = self.norm_layer(f, cache["norm"]) + norm_f = norm_f * self.conv_y(zq) + norm_f = norm_f + self.conv_b(zq) + + return norm_f + + +class KVAECachedResnetBlock3D(nn.Module): + r""" + A 3D ResNet block with caching. + """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 0, + zq_ch: Optional[int] = None, + add_conv: bool = False, + gather_norm: bool = False, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + if zq_ch is None: + self.norm1 = KVAECachedGroupNorm(in_channels) + else: + self.norm1 = KVAECachedSpatialNorm3D(in_channels, zq_ch, add_conv=add_conv) + + self.conv1 = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=out_channels, kernel_size=3) + + if temb_channels > 0: + self.temb_proj = nn.Linear(temb_channels, out_channels) + + if zq_ch is None: + self.norm2 = KVAECachedGroupNorm(out_channels) + else: + self.norm2 = KVAECachedSpatialNorm3D(out_channels, zq_ch, add_conv=add_conv) + + self.conv2 = KVAECachedCausalConv3d(chan_in=out_channels, chan_out=out_channels, kernel_size=3) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=out_channels, kernel_size=3) + else: + self.nin_shortcut = KVAESafeConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor, temb: torch.Tensor, layer_cache: Dict, zq: 
torch.Tensor = None) -> torch.Tensor: + h = x + + if zq is None: + # Encoder path - norm takes cache + h = self.norm1(h, cache=layer_cache["norm1"]) + else: + # Decoder path - spatial norm takes zq and cache + h = self.norm1(h, zq, cache=layer_cache["norm1"]) + + h = F.silu(h) + h = self.conv1(h, cache=layer_cache["conv1"]) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None] + + if zq is None: + h = self.norm2(h, cache=layer_cache["norm2"]) + else: + h = self.norm2(h, zq, cache=layer_cache["norm2"]) + + h = F.silu(h) + h = self.conv2(h, cache=layer_cache["conv2"]) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x, cache=layer_cache["conv_shortcut"]) + else: + x = self.nin_shortcut(x) + + return x + h + + +class KVAECachedPXSDownsample(nn.Module): + r""" + A 3D downsampling layer using PixelUnshuffle with caching. + """ + + def __init__(self, in_channels: int, compress_time: bool, factor: int = 2): + super().__init__() + self.temporal_compress = compress_time + self.factor = factor + self.unshuffle = nn.PixelUnshuffle(self.factor) + self.s_pool = nn.AvgPool3d((1, 2, 2), (1, 2, 2)) + + self.spatial_conv = KVAESafeConv3d( + in_channels, + in_channels, + kernel_size=(1, 3, 3), + stride=(1, 2, 2), + padding=(0, 1, 1), + padding_mode="reflect", + ) + + if self.temporal_compress: + self.temporal_conv = KVAECachedCausalConv3d( + in_channels, in_channels, kernel_size=(3, 1, 1), stride=(2, 1, 1), dilation=(1, 1, 1) + ) + + self.linear = nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1) + + def spatial_downsample(self, input: torch.Tensor) -> torch.Tensor: + b, c, t, h, w = input.shape + pxs_input = input.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + # pxs_input = rearrange(input, 'b c t h w -> (b t) c h w') + pxs_interm = self.unshuffle(pxs_input) + b_it, c_it, h_it, w_it = pxs_interm.shape + pxs_interm_view = pxs_interm.view(b_it, c_it // self.factor**2, 
self.factor**2, h_it, w_it) + pxs_out = torch.mean(pxs_interm_view, dim=2) + pxs_out = pxs_out.view(b, t, -1, h_it, w_it).permute(0, 2, 1, 3, 4) + # pxs_out = rearrange(pxs_out, '(b t) c h w -> b c t h w', t=input.size(2)) + conv_out = self.spatial_conv(input) + return conv_out + pxs_out + + def temporal_downsample(self, input: torch.Tensor, cache: list) -> torch.Tensor: + b, c, t, h, w = input.shape + + permuted = input.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) + + if cache[0]["padding"] is None: + first, rest = permuted[..., :1], permuted[..., 1:] + if rest.size(-1) > 0: + rest_interp = F.avg_pool1d(rest, kernel_size=2, stride=2) + full_interp = torch.cat([first, rest_interp], dim=-1) + else: + full_interp = first + else: + rest = permuted + if rest.size(-1) > 0: + full_interp = F.avg_pool1d(rest, kernel_size=2, stride=2) + + t_new = full_interp.size(-1) + full_interp = full_interp.view(b, h, w, c, t_new).permute(0, 3, 4, 1, 2) + conv_out = self.temporal_conv(input, cache[0]) + return conv_out + full_interp + + def forward(self, x: torch.Tensor, cache: list) -> torch.Tensor: + out = self.spatial_downsample(x) + + if self.temporal_compress: + out = self.temporal_downsample(out, cache=cache) + + return self.linear(out) + + +class KVAECachedPXSUpsample(nn.Module): + r""" + A 3D upsampling layer using PixelShuffle with caching. 
+ """ + + def __init__(self, in_channels: int, compress_time: bool, factor: int = 2): + super().__init__() + self.temporal_compress = compress_time + self.factor = factor + self.shuffle = nn.PixelShuffle(self.factor) + + self.spatial_conv = KVAESafeConv3d( + in_channels, + in_channels, + kernel_size=(1, 3, 3), + stride=(1, 1, 1), + padding=(0, 1, 1), + padding_mode="reflect", + ) + + if self.temporal_compress: + self.temporal_conv = KVAECachedCausalConv3d( + in_channels, in_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), dilation=(1, 1, 1) + ) + + self.linear = KVAESafeConv3d(in_channels, in_channels, kernel_size=1, stride=1) + + def spatial_upsample(self, input: torch.Tensor) -> torch.Tensor: + b, c, t, h, w = input.shape + input_view = input.permute(0, 2, 1, 3, 4).reshape(b, t * c, h, w) + input_interp = F.interpolate(input_view, scale_factor=2, mode="nearest") + input_interp = input_interp.view(b, t, c, 2 * h, 2 * w).permute(0, 2, 1, 3, 4) + + out = self.spatial_conv(input_interp) + return input_interp + out + + def temporal_upsample(self, input: torch.Tensor, cache: Dict) -> torch.Tensor: + time_factor = 1.0 + 1.0 * (input.size(2) > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + + repeated = input.repeat_interleave(int(time_factor), dim=2) + + if cache["padding"] is None: + tail = repeated[..., int(time_factor - 1) :, :, :] + else: + tail = repeated + + conv_out = self.temporal_conv(tail, cache) + return conv_out + tail + + def forward(self, x: torch.Tensor, cache: Dict) -> torch.Tensor: + if self.temporal_compress: + x = self.temporal_upsample(x, cache) + + s_out = self.spatial_upsample(x) + to = torch.empty_like(s_out) + lin_out = self.linear(s_out, write_to=to) + return lin_out + + +# ============================================================================= +# Cached Encoder/Decoder +# ============================================================================= + + +class KVAECachedEncoder3D(nn.Module): + r""" + 
Cached 3D Encoder for KVAE. + """ + + def __init__( + self, + ch: int = 128, + ch_mult: Tuple[int, ...] = (1, 2, 4, 8), + num_res_blocks: int = 2, + dropout: float = 0.0, + in_channels: int = 3, + z_channels: int = 16, + double_z: bool = True, + temporal_compress_times: int = 4, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_channels + self.temporal_compress_level = int(np.log2(temporal_compress_times)) + + self.conv_in = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=self.ch, kernel_size=3) + + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + block_in = ch + + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + + for i_block in range(self.num_res_blocks): + block.append( + KVAECachedResnetBlock3D( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + temb_channels=self.temb_ch, + ) + ) + block_in = block_out + + down = nn.Module() + down.block = block + down.attn = attn + + if i_level != self.num_resolutions - 1: + if i_level < self.temporal_compress_level: + down.downsample = KVAECachedPXSDownsample(block_in, compress_time=True) + else: + down.downsample = KVAECachedPXSDownsample(block_in, compress_time=False) + self.down.append(down) + + self.mid = nn.Module() + self.mid.block_1 = KVAECachedResnetBlock3D( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.block_2 = KVAECachedResnetBlock3D( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + + self.norm_out = KVAECachedGroupNorm(block_in) + self.conv_out = KVAECachedCausalConv3d( + chan_in=block_in, chan_out=2 * z_channels if double_z else z_channels, kernel_size=3 + ) + + self.gradient_checkpointing = False + + def forward(self, x: 
torch.Tensor, cache_dict: Dict) -> torch.Tensor: + temb = None + + h = self.conv_in(x, cache=cache_dict["conv_in"]) + + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + h = self._gradient_checkpointing_func( + self.down[i_level].block[i_block], h, temb, cache_dict[i_level][i_block] + ) + else: + h = self.down[i_level].block[i_block](h, temb, layer_cache=cache_dict[i_level][i_block]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = self.down[i_level].downsample(h, cache=cache_dict[i_level]["down"]) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, cache_dict["mid_1"]) + h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, cache_dict["mid_2"]) + else: + h = self.mid.block_1(h, temb, layer_cache=cache_dict["mid_1"]) + h = self.mid.block_2(h, temb, layer_cache=cache_dict["mid_2"]) + + h = self.norm_out(h, cache=cache_dict["norm_out"]) + h = nonlinearity(h) + h = self.conv_out(h, cache=cache_dict["conv_out"]) + + return h + + +class KVAECachedDecoder3D(nn.Module): + r""" + Cached 3D Decoder for KVAE. + """ + + def __init__( + self, + ch: int = 128, + out_ch: int = 3, + ch_mult: Tuple[int, ...] 
= (1, 2, 4, 8), + num_res_blocks: int = 2, + dropout: float = 0.0, + z_channels: int = 16, + zq_ch: Optional[int] = None, + add_conv: bool = False, + temporal_compress_times: int = 4, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.temporal_compress_level = int(np.log2(temporal_compress_times)) + + if zq_ch is None: + zq_ch = z_channels + + block_in = ch * ch_mult[self.num_resolutions - 1] + + self.conv_in = KVAECachedCausalConv3d(chan_in=z_channels, chan_out=block_in, kernel_size=3) + + self.mid = nn.Module() + self.mid.block_1 = KVAECachedResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + ) + self.mid.block_2 = KVAECachedResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + ) + + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + + for i_block in range(self.num_res_blocks + 1): + block.append( + KVAECachedResnetBlock3D( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + ) + ) + block_in = block_out + + up = nn.Module() + up.block = block + up.attn = attn + + if i_level != 0: + if i_level < self.num_resolutions - self.temporal_compress_level: + up.upsample = KVAECachedPXSUpsample(block_in, compress_time=False) + else: + up.upsample = KVAECachedPXSUpsample(block_in, compress_time=True) + self.up.insert(0, up) + + self.norm_out = KVAECachedSpatialNorm3D(block_in, zq_ch, add_conv=add_conv) + self.conv_out = KVAECachedCausalConv3d(chan_in=block_in, chan_out=out_ch, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, z: torch.Tensor, cache_dict: Dict) -> 
torch.Tensor: + temb = None + zq = z + + h = self.conv_in(z, cache_dict["conv_in"]) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, cache_dict["mid_1"], zq) + h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, cache_dict["mid_2"], zq) + else: + h = self.mid.block_1(h, temb, layer_cache=cache_dict["mid_1"], zq=zq) + h = self.mid.block_2(h, temb, layer_cache=cache_dict["mid_2"], zq=zq) + + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + if torch.is_grad_enabled() and self.gradient_checkpointing: + h = self._gradient_checkpointing_func( + self.up[i_level].block[i_block], h, temb, cache_dict[i_level][i_block], zq + ) + else: + h = self.up[i_level].block[i_block](h, temb, layer_cache=cache_dict[i_level][i_block], zq=zq) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, zq) + if i_level != 0: + h = self.up[i_level].upsample(h, cache_dict[i_level]["up"]) + + h = self.norm_out(h, zq, cache_dict["norm_out"]) + h = nonlinearity(h) + h = self.conv_out(h, cache_dict["conv_out"]) + + return h + + +# ============================================================================= +# Main AutoencoderKL class +# ============================================================================= + + +class AutoencoderKLKVAEVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in + [KVAE](https://github.com/kandinskylab/kvae-1). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for + all models (such as downloading or saving). + + Parameters: + ch (`int`, *optional*, defaults to 128): Base channel count. + ch_mult (`Tuple[int]`, *optional*, defaults to `(1, 2, 4, 8)`): Channel multipliers per level. 
+ num_res_blocks (`int`, *optional*, defaults to 2): Number of residual blocks per level. + in_channels (`int`, *optional*, defaults to 3): Number of input channels. + out_ch (`int`, *optional*, defaults to 3): Number of output channels. + z_channels (`int`, *optional*, defaults to 16): Number of latent channels. + temporal_compress_times (`int`, *optional*, defaults to 4): Temporal compression factor. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["KVAECachedResnetBlock3D"] + + @register_to_config + def __init__( + self, + ch: int = 128, + ch_mult: Tuple[int, ...] = (1, 2, 4, 8), + num_res_blocks: int = 2, + in_channels: int = 3, + out_ch: int = 3, + z_channels: int = 16, + temporal_compress_times: int = 4, + ): + super().__init__() + + self.encoder = KVAECachedEncoder3D( + ch=ch, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + in_channels=in_channels, + z_channels=z_channels, + double_z=True, + temporal_compress_times=temporal_compress_times, + ) + + self.decoder = KVAECachedDecoder3D( + ch=ch, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + out_ch=out_ch, + z_channels=z_channels, + temporal_compress_times=temporal_compress_times, + ) + + self.use_slicing = False + self.use_tiling = False + + def _make_encoder_cache(self) -> Dict: + """Create empty cache for cached encoder.""" + + def make_dict(name, p=None): + if name == "conv": + return {"padding": None} + + layer, module = name.split("_") + if layer == "norm": + if module == "enc": + return {"mean": None, "var": None} + else: + return {"norm": make_dict("norm_enc"), "add_conv": make_dict("conv")} + elif layer == "resblock": + return { + "norm1": make_dict(f"norm_{module}"), + "norm2": make_dict(f"norm_{module}"), + "conv1": make_dict("conv"), + "conv2": make_dict("conv"), + "conv_shortcut": make_dict("conv"), + } + elif layer.isdigit(): + out_dict = {"down": [make_dict("conv"), make_dict("conv")], "up": make_dict("conv")} + for i in range(p): + out_dict[i] = 
make_dict(f"resblock_{module}") + return out_dict + + cache = { + "conv_in": make_dict("conv"), + "mid_1": make_dict("resblock_enc"), + "mid_2": make_dict("resblock_enc"), + "norm_out": make_dict("norm_enc"), + "conv_out": make_dict("conv"), + } + # Encoder uses num_res_blocks per level + for i in range(len(self.config.ch_mult)): + cache[i] = make_dict(f"{i}_enc", p=self.config.num_res_blocks) + return cache + + def _make_decoder_cache(self) -> Dict: + """Create empty cache for decoder.""" + + def make_dict(name, p=None): + if name == "conv": + return {"padding": None} + + layer, module = name.split("_") + if layer == "norm": + if module == "enc": + return {"mean": None, "var": None} + else: + return {"norm": make_dict("norm_enc"), "add_conv": make_dict("conv")} + elif layer == "resblock": + return { + "norm1": make_dict(f"norm_{module}"), + "norm2": make_dict(f"norm_{module}"), + "conv1": make_dict("conv"), + "conv2": make_dict("conv"), + "conv_shortcut": make_dict("conv"), + } + elif layer.isdigit(): + out_dict = {"down": [make_dict("conv"), make_dict("conv")], "up": make_dict("conv")} + for i in range(p): + out_dict[i] = make_dict(f"resblock_{module}") + return out_dict + + cache = { + "conv_in": make_dict("conv"), + "mid_1": make_dict("resblock_dec"), + "mid_2": make_dict("resblock_dec"), + "norm_out": make_dict("norm_dec"), + "conv_out": make_dict("conv"), + } + for i in range(len(self.config.ch_mult)): + cache[i] = make_dict(f"{i}_dec", p=self.config.num_res_blocks + 1) + return cache + + def enable_slicing(self) -> None: + r"""Enable sliced VAE decoding.""" + self.use_slicing = True + + def disable_slicing(self) -> None: + r"""Disable sliced VAE decoding.""" + self.use_slicing = False + + def _encode(self, x: torch.Tensor, seg_len: int = 16) -> torch.Tensor: + # Cached encoder processes by segments + cache = self._make_encoder_cache() + + split_list = [seg_len + 1] + n_frames = x.size(2) - (seg_len + 1) + while n_frames > 0: + split_list.append(seg_len) + 
n_frames -= seg_len + split_list[-1] += n_frames + + latent = [] + for chunk in torch.split(x, split_list, dim=2): + l = self.encoder(chunk, cache) + sample, _ = torch.chunk(l, 2, dim=1) + latent.append(sample) + + return torch.cat(latent, dim=2) + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of videos into latents. + + Args: + x (`torch.Tensor`): Input batch of videos with shape (B, C, T, H, W). + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + # For cached encoder, we already did the split in _encode + h_double = torch.cat([h, torch.zeros_like(h)], dim=1) + posterior = DiagonalGaussianDistribution(h_double) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, seg_len: int = 16) -> torch.Tensor: + cache = self._make_decoder_cache() + temporal_compress = self.config.temporal_compress_times + + split_list = [seg_len + 1] + n_frames = temporal_compress * (z.size(2) - 1) - seg_len + while n_frames > 0: + split_list.append(seg_len) + n_frames -= seg_len + split_list[-1] += n_frames + split_list = [math.ceil(size / temporal_compress) for size in split_list] + + recs = [] + for chunk in torch.split(z, split_list, dim=2): + out = self.decoder(chunk, cache) + recs.append(out) + + return torch.cat(recs, dim=2) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of videos. 
+ + Args: + z (`torch.Tensor`): Input batch of latent vectors with shape (B, C, T, H, W). + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: Decoded video. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 3425cc8d2b61..c41410d153c9 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -521,6 +521,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLKVAE(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AutoencoderKLKVAEVideo(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, 
**kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLLTX2Audio(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_kvae.py b/tests/models/autoencoders/test_models_autoencoder_kl_kvae.py new file mode 100644 index 000000000000..adae981f9c76 --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_kl_kvae.py @@ -0,0 +1,73 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from diffusers import AutoencoderKLKVAE + +from ...testing_utils import enable_full_determinism, floats_tensor, torch_device +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin + + +enable_full_determinism() + + +class AutoencoderKLKVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): + model_class = AutoencoderKLKVAE + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_kvae_config(self): + return { + "in_channels": 3, + "channels": 32, + "num_enc_blocks": 1, + "num_dec_blocks": 1, + "z_channels": 4, + "double_z": True, + "ch_mult": (1, 2), + "sample_size": 32, + } + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + + return {"sample": image} + + @property + def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_kvae_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "KVAEEncoder2D", + "KVAEDecoder2D", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_kvae_video.py b/tests/models/autoencoders/test_models_autoencoder_kl_kvae_video.py new file mode 100644 index 000000000000..7e9eebb87cf4 --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_kl_kvae_video.py @@ -0,0 +1,118 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from diffusers import AutoencoderKLKVAEVideo + +from ...testing_utils import enable_full_determinism, floats_tensor, torch_device +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin + + +enable_full_determinism() + + +class AutoencoderKLKVAEVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): + model_class = AutoencoderKLKVAEVideo + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_kvae_video_config(self): + return { + "ch": 32, + "ch_mult": (1, 2), + "num_res_blocks": 1, + "in_channels": 3, + "out_ch": 3, + "z_channels": 4, + "temporal_compress_times": 2, + } + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 3 # satisfies (T-1) % temporal_compress_times == 0 with temporal_compress_times=2 + num_channels = 3 + sizes = (16, 16) + + video = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + + return {"sample": video} + + @property + def input_shape(self): + return (3, 3, 16, 16) + + @property + def output_shape(self): + return (3, 3, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_kvae_video_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "KVAECachedEncoder3D", + "KVAECachedDecoder3D", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Unsupported test.") + def 
test_outputs_equivalence(self): + pass + + @unittest.skip( + "Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass." + ) + def test_model_parallelism(self): + pass + + @unittest.skip( + "Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass." + ) + def test_sharded_checkpoints_device_map(self): + pass + + def _run_nondeterministic(self, fn): + # reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation; + # temporarily relax the requirement for training tests that do backward passes. + import torch + + torch.use_deterministic_algorithms(False) + try: + fn() + finally: + torch.use_deterministic_algorithms(True) + + def test_training(self): + self._run_nondeterministic(super().test_training) + + def test_ema_training(self): + self._run_nondeterministic(super().test_ema_training) + + @unittest.skip( + "Gradient checkpointing recomputes the forward pass, but the model uses a stateful cache_dict " + "that is mutated during the first forward. On recomputation the cache is already populated, " + "causing a different execution path and numerically different gradients. " + "GC still reduces peak memory usage; gradient correctness in the presence of GC is a known limitation." 
+ ) + def test_effective_gradient_checkpointing(self): + pass + + def test_layerwise_casting_training(self): + self._run_nondeterministic(super().test_layerwise_casting_training) From 03d23c038bf76448c9522e3073f334c591bbd253 Mon Sep 17 00:00:00 2001 From: Cheung Ka Wai Date: Tue, 24 Mar 2026 11:26:40 +0800 Subject: [PATCH 067/215] change QwenImageTransformer UT to batch inputs (#13312) * UT expands to batch inputs * update according to suggestion * update according to suggestion 2 * fix CI * update according to suggestion 3 * clean line --- tests/models/testing_utils/parallelism.py | 20 +++++++++++++++++-- .../test_models_transformer_flux.py | 3 +-- .../test_models_transformer_flux2.py | 3 +-- .../test_models_transformer_qwenimage.py | 18 +++++++++-------- 4 files changed, 30 insertions(+), 14 deletions(-) diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index db9817c86995..2b6aab59a662 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -22,6 +22,7 @@ import torch.multiprocessing as mp from diffusers.models._modeling_parallel import ContextParallelConfig +from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry from ...testing_utils import ( is_context_parallel, @@ -160,16 +161,21 @@ def _custom_mesh_worker( @require_torch_multi_accelerator class ContextParallelTesterMixin: @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"]) - def test_context_parallel_inference(self, cp_type): + def test_context_parallel_inference(self, cp_type, batch_size: int = 1): if not torch.distributed.is_available(): pytest.skip("torch.distributed is not available.") if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None: pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") + if cp_type == "ring_degree": + active_backend, _ = 
_AttentionBackendRegistry.get_active_backend() + if active_backend == AttentionBackendName.NATIVE: + pytest.skip("Ring attention is not supported with the native attention backend.") + world_size = 2 init_dict = self.get_init_dict() - inputs_dict = self.get_dummy_inputs() + inputs_dict = self.get_dummy_inputs(batch_size=batch_size) # Move all tensors to CPU for multiprocessing inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()} @@ -194,6 +200,11 @@ def test_context_parallel_inference(self, cp_type): f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}" ) + @pytest.mark.xfail(reason="Context parallel may not support batch_size > 1") + @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"]) + def test_context_parallel_batch_inputs(self, cp_type): + self.test_context_parallel_inference(cp_type, batch_size=2) + @pytest.mark.parametrize( "cp_type,mesh_shape,mesh_dim_names", [ @@ -209,6 +220,11 @@ def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names) if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None: pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") + if cp_type == "ring_degree": + active_backend, _ = _AttentionBackendRegistry.get_active_backend() + if active_backend == AttentionBackendName.NATIVE: + pytest.skip("Ring attention is not supported with the native attention backend.") + world_size = 2 init_dict = self.get_init_dict() inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()} diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 24be833d0ed2..a15b7be50b97 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -150,8 +150,7 @@ def 
get_init_dict(self) -> dict[str, int | list[int]]: "axes_dims_rope": [4, 4, 8], } - def get_dummy_inputs(self) -> dict[str, torch.Tensor]: - batch_size = 1 + def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: height = width = 4 num_latent_channels = 4 num_image_channels = 3 diff --git a/tests/models/transformers/test_models_transformer_flux2.py b/tests/models/transformers/test_models_transformer_flux2.py index a109f603411d..77b5f1b86e59 100644 --- a/tests/models/transformers/test_models_transformer_flux2.py +++ b/tests/models/transformers/test_models_transformer_flux2.py @@ -90,8 +90,7 @@ def get_init_dict(self) -> dict[str, int | list[int]]: "axes_dims_rope": [4, 4, 4, 4], } - def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: - batch_size = 1 + def get_dummy_inputs(self, height: int = 4, width: int = 4, batch_size: int = 1) -> dict[str, torch.Tensor]: num_latent_channels = 4 sequence_length = 48 embedding_dim = 32 diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index 713a1bec70a5..5b45577f2dff 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -14,6 +14,7 @@ import warnings +import pytest import torch from diffusers import QwenImageTransformer2DModel @@ -77,8 +78,7 @@ def get_init_dict(self) -> dict[str, int | list[int]]: "axes_dims_rope": (8, 4, 4), } - def get_dummy_inputs(self) -> dict[str, torch.Tensor]: - batch_size = 1 + def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: num_latent_channels = embedding_dim = 16 height = width = 4 sequence_length = 8 @@ -106,9 +106,10 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin): - def test_infers_text_seq_len_from_mask(self): + 
@pytest.mark.parametrize("batch_size", [1, 2]) + def test_infers_text_seq_len_from_mask(self, batch_size): init_dict = self.get_init_dict() - inputs = self.get_dummy_inputs() + inputs = self.get_dummy_inputs(batch_size=batch_size) model = self.model_class(**init_dict).to(torch_device) encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() @@ -122,7 +123,7 @@ def test_infers_text_seq_len_from_mask(self): assert isinstance(per_sample_len, torch.Tensor) assert int(per_sample_len.max().item()) == 2 assert normalized_mask.dtype == torch.bool - assert normalized_mask.sum().item() == 2 + assert normalized_mask.sum().item() == 2 * batch_size assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1] inputs["encoder_hidden_states_mask"] = normalized_mask @@ -139,7 +140,7 @@ def test_infers_text_seq_len_from_mask(self): ) assert int(per_sample_len2.max().item()) == 8 - assert normalized_mask2.sum().item() == 5 + assert normalized_mask2.sum().item() == 5 * batch_size rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask( inputs["encoder_hidden_states"], None @@ -149,9 +150,10 @@ def test_infers_text_seq_len_from_mask(self): assert per_sample_len_none is None assert normalized_mask_none is None - def test_non_contiguous_attention_mask(self): + @pytest.mark.parametrize("batch_size", [1, 2]) + def test_non_contiguous_attention_mask(self, batch_size): init_dict = self.get_init_dict() - inputs = self.get_dummy_inputs() + inputs = self.get_dummy_inputs(batch_size=batch_size) model = self.model_class(**init_dict).to(torch_device) encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() From 99a9835d0b294e29a17936c1d1d660ba064910c9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 24 Mar 2026 10:54:35 +0530 Subject: [PATCH 068/215] [chore] properly deprecate src.diffusers.utils.testing_utils. (#13314) properly deprecate src.diffusers.utils.testing_utils. 
--- src/diffusers/utils/testing_utils.py | 9 ++++++--- tests/others/test_utils.py | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index cb7bf942c648..619a37034949 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -29,6 +29,7 @@ from packaging import version from .constants import DIFFUSERS_REQUEST_TIMEOUT +from .deprecation_utils import deprecate from .import_utils import ( BACKENDS_MAPPING, is_accelerate_available, @@ -67,9 +68,11 @@ global_rng = random.Random() logger = get_logger(__name__) -logger.warning( - "diffusers.utils.testing_utils' is deprecated and will be removed in a future version. " - "Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. " +deprecate( + "diffusers.utils.testing_utils", + "1.0.0", + "diffusers.utils.testing_utils is deprecated and will be removed in a future version. " + "Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. ", ) _required_peft_version = is_peft_available() and version.parse( version.parse(importlib.metadata.version("peft")).base_version diff --git a/tests/others/test_utils.py b/tests/others/test_utils.py index 747b8d584058..bb0656386394 100755 --- a/tests/others/test_utils.py +++ b/tests/others/test_utils.py @@ -13,8 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib import os import unittest +import warnings import pytest @@ -182,6 +184,25 @@ def test_deprecate_stacklevel(self): assert str(warning.warning) == "This message is better!!!" 
assert "diffusers/tests/others/test_utils.py" in warning.filename + def test_deprecate_testing_utils_module(self): + import diffusers.utils.testing_utils + + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") + importlib.reload(diffusers.utils.testing_utils) + + deprecation_warnings = [w for w in caught_warnings if issubclass(w.category, FutureWarning)] + assert len(deprecation_warnings) >= 1, "Expected at least one FutureWarning from diffusers.utils.testing_utils" + + messages = [str(w.message) for w in deprecation_warnings] + assert any("diffusers.utils.testing_utils" in msg for msg in messages), ( + f"Expected a deprecation warning mentioning 'diffusers.utils.testing_utils', got: {messages}" + ) + assert any( + "diffusers.utils.testing_utils is deprecated and will be removed in a future version." in msg + for msg in messages + ), f"Expected deprecation message substring not found, got: {messages}" + # Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py class ExpectationsTester(unittest.TestCase): From c481755aa38e2c8c078f185a895073ce34203ffc Mon Sep 17 00:00:00 2001 From: YangKai0616 Date: Tue, 24 Mar 2026 17:00:05 +0800 Subject: [PATCH 069/215] Stabilize low-precision custom autoencoder RMS normalization (#13316) * Stabilize low-precision custom autoencoder RMS normalization * Add fp8/4 * Apply style fixes --------- Co-authored-by: github-actions[bot] Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- .../autoencoders/autoencoder_kl_hunyuanimage_refiner.py | 9 ++++++++- .../models/autoencoders/autoencoder_kl_hunyuanvideo15.py | 9 ++++++++- .../models/autoencoders/autoencoder_kl_qwenimage.py | 9 ++++++++- src/diffusers/models/autoencoders/autoencoder_kl_wan.py | 9 ++++++++- 4 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py 
b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py index 81957e2feed4..9f53371aadf5 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py @@ -87,7 +87,14 @@ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bi self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 def forward(self, x): - return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any( + t in str(x.dtype) for t in ("float4_", "float8_") + ) + normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to( + x.dtype + ) + + return normalized * self.scale * self.gamma + self.bias class HunyuanImageRefinerAttnBlock(nn.Module): diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py index 2c38b174a100..e43483b92240 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py @@ -87,7 +87,14 @@ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bi self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 def forward(self, x): - return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any( + t in str(x.dtype) for t in ("float4_", "float8_") + ) + normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to( + x.dtype + ) + + return normalized * self.scale * self.gamma + self.bias class HunyuanVideo15AttnBlock(nn.Module): diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py 
b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py index f2ca0f42a272..f52071bf470b 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -105,7 +105,14 @@ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bi self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 def forward(self, x): - return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any( + t in str(x.dtype) for t in ("float4_", "float8_") + ) + normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to( + x.dtype + ) + + return normalized * self.scale * self.gamma + self.bias class QwenImageUpsample(nn.Upsample): diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index ea5d2efe642f..7ba0de0f4a18 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -196,7 +196,14 @@ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bi self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 def forward(self, x): - return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any( + t in str(x.dtype) for t in ("float4_", "float8_") + ) + normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to( + x.dtype + ) + + return normalized * self.scale * self.gamma + self.bias class WanUpsample(nn.Upsample): From 1c6644d5f4359ada846998c3354d4c92ec45a07b Mon Sep 17 00:00:00 2001 From: Cheung Ka Wai Date: Tue, 24 Mar 2026 17:12:50 +0800 Subject: [PATCH 070/215] Fix the attention mask in ulysses 
SP for QwenImage (#13278) * fix mask in SP * change the modification to qwen specific * drop xfail since qwen-image mask is fixed --------- Co-authored-by: Sayak Paul --- src/diffusers/models/transformers/transformer_qwenimage.py | 1 + tests/models/testing_utils/parallelism.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index c5419b9f107e..d88aef4dcf2a 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -933,6 +933,7 @@ def forward( batch_size, image_seq_len = hidden_states.shape[:2] image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) + joint_attention_mask = joint_attention_mask[:, None, None, :] block_attention_kwargs["attention_mask"] = joint_attention_mask for index_block, block in enumerate(self.transformer_blocks): diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index 2b6aab59a662..bea832904041 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -200,7 +200,6 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1): f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}" ) - @pytest.mark.xfail(reason="Context parallel may not support batch_size > 1") @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"]) def test_context_parallel_batch_inputs(self, cp_type): self.test_context_parallel_inference(cp_type, batch_size=2) From b0bb01938f25d4e276ca175ec6c49169cd78a48b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 24 Mar 2026 15:48:03 +0530 Subject: [PATCH 071/215] [tests] fix lora logging tests for models. 
(#13318) * fix lora logging tests for models. * make style --- tests/models/testing_utils/lora.py | 52 +++++++++++++------ .../test_models_transformer_qwenimage.py | 8 +++ 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/tests/models/testing_utils/lora.py b/tests/models/testing_utils/lora.py index dfdc4835ee88..dfa326b014d0 100644 --- a/tests/models/testing_utils/lora.py +++ b/tests/models/testing_utils/lora.py @@ -481,6 +481,8 @@ def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog): # ensure that enable_lora_hotswap is called before loading the first adapter import logging + from diffusers.utils import logging as diffusers_logging + lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) init_dict = self.get_init_dict() model = self.model_class(**init_dict).to(torch_device) @@ -488,21 +490,31 @@ def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog): msg = ( "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation." 
) - with caplog.at_level(logging.WARNING): - model.enable_lora_hotswap(target_rank=32, check_compiled="warn") - assert any(msg in record.message for record in caplog.records) + diffusers_logging.enable_propagation() + try: + with caplog.at_level(logging.WARNING): + model.enable_lora_hotswap(target_rank=32, check_compiled="warn") + assert any(msg in record.message for record in caplog.records) + finally: + diffusers_logging.disable_propagation() def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog): # check possibility to ignore the error/warning import logging + from diffusers.utils import logging as diffusers_logging + lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) init_dict = self.get_init_dict() model = self.model_class(**init_dict).to(torch_device) model.add_adapter(lora_config) - with caplog.at_level(logging.WARNING): - model.enable_lora_hotswap(target_rank=32, check_compiled="ignore") - assert len(caplog.records) == 0 + diffusers_logging.enable_propagation() + try: + with caplog.at_level(logging.WARNING): + model.enable_lora_hotswap(target_rank=32, check_compiled="ignore") + assert len(caplog.records) == 0 + finally: + diffusers_logging.disable_propagation() def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self): # check that wrong argument value raises an error @@ -518,20 +530,26 @@ def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplo # check the error and log import logging + from diffusers.utils import logging as diffusers_logging + # at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers target_modules0 = ["to_q"] target_modules1 = ["to_q", "to_k"] - with pytest.raises(RuntimeError): # peft raises RuntimeError - with caplog.at_level(logging.ERROR): - self._check_model_hotswap( - tmp_path, - do_compile=True, - rank0=8, - rank1=8, - target_modules0=target_modules0, - target_modules1=target_modules1, - ) - assert any("Hotswapping adapter0 
was unsuccessful" in record.message for record in caplog.records) + diffusers_logging.enable_propagation() + try: + with pytest.raises(RuntimeError): # peft raises RuntimeError + with caplog.at_level(logging.ERROR): + self._check_model_hotswap( + tmp_path, + do_compile=True, + rank0=8, + rank1=8, + target_modules0=target_modules0, + target_modules1=target_modules1, + ) + assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records) + finally: + diffusers_logging.disable_propagation() @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) @require_torch_version_greater("2.7.1") diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index 5b45577f2dff..7933aa98f3f2 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -286,6 +286,14 @@ class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterM class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin): """LoRA hot-swapping tests for QwenImage Transformer.""" + @pytest.mark.xfail(True, reason="Recompilation issues.", strict=True) + def test_hotswapping_compiled_model_linear(self): + super().test_hotswapping_compiled_model_linear() + + @pytest.mark.xfail(True, reason="Recompilation issues.", strict=True) + def test_hotswapping_compiled_model_both_linear_and_other(self): + super().test_hotswapping_compiled_model_both_linear_and_other() + @property def different_shapes_for_compilation(self): return [(4, 4), (4, 8), (8, 8)] From e13a045b6a0efe07a4dc9469477dc2eca9b2486d Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 24 Mar 2026 16:00:24 +0530 Subject: [PATCH 072/215] Fix unguarded `torchvision` import in Cosmos (#13321) update --- .../cosmos/pipeline_cosmos2_5_predict.py | 15 +++++++++++---- 1 file changed, 11 
insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py index cdea71a5ab93..581711205814 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py @@ -16,22 +16,29 @@ import numpy as np import torch -import torchvision -import torchvision.transforms -import torchvision.transforms.functional from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput from ...models import AutoencoderKLWan, CosmosTransformer3DModel from ...schedulers import UniPCMultistepScheduler -from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils import ( + is_cosmos_guardrail_available, + is_torch_xla_available, + is_torchvision_available, + logging, + replace_example_docstring, +) from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import CosmosPipelineOutput +if is_torchvision_available(): + import torchvision.transforms.functional + + if is_cosmos_guardrail_available(): from cosmos_guardrail import CosmosSafetyChecker else: From 8ad13a5f7c6dcb73727aabdca991fb41a90cc6d9 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 24 Mar 2026 16:42:32 +0530 Subject: [PATCH 073/215] [CI] Update fetching pipelines for latest HF Hub Version (#13322) update --- utils/fetch_torch_cuda_pipeline_test_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/fetch_torch_cuda_pipeline_test_matrix.py b/utils/fetch_torch_cuda_pipeline_test_matrix.py index 196f35628ac1..86f3c4bf5f37 100644 --- a/utils/fetch_torch_cuda_pipeline_test_matrix.py +++ b/utils/fetch_torch_cuda_pipeline_test_matrix.py 
@@ -43,7 +43,7 @@ def filter_pipelines(usage_dict, usage_cutoff=10000): def fetch_pipeline_objects(): - models = api.list_models(library="diffusers") + models = api.list_models(filter="diffusers") downloads = defaultdict(int) for model in models: From ed2014713127701f29566ca9f5c83a3094e39b7b Mon Sep 17 00:00:00 2001 From: Alexey Kirillov <43682987+Alexkkir@users.noreply.github.com> Date: Tue, 24 Mar 2026 15:19:50 +0300 Subject: [PATCH 074/215] Use defaultdict for _SET_ADAPTER_SCALE_FN_MAPPING (#13320) refactor: use defaultdict for _SET_ADAPTER_SCALE_FN_MAPPING Co-authored-by: Alexkkir Co-authored-by: Sayak Paul --- src/diffusers/loaders/peft.py | 35 ++++++++--------------------------- 1 file changed, 8 insertions(+), 27 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index a96542c2a50c..daa078bc25d5 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -15,6 +15,7 @@ import inspect import json import os +from collections import defaultdict from functools import partial from pathlib import Path from typing import Literal @@ -44,33 +45,13 @@ logger = logging.get_logger(__name__) -_SET_ADAPTER_SCALE_FN_MAPPING = { - "UNet2DConditionModel": _maybe_expand_lora_scales, - "UNetMotionModel": _maybe_expand_lora_scales, - "SD3Transformer2DModel": lambda model_cls, weights: weights, - "FluxTransformer2DModel": lambda model_cls, weights: weights, - "CogVideoXTransformer3DModel": lambda model_cls, weights: weights, - "ConsisIDTransformer3DModel": lambda model_cls, weights: weights, - "HeliosTransformer3DModel": lambda model_cls, weights: weights, - "MochiTransformer3DModel": lambda model_cls, weights: weights, - "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights, - "LTXVideoTransformer3DModel": lambda model_cls, weights: weights, - "SanaTransformer2DModel": lambda model_cls, weights: weights, - "AuraFlowTransformer2DModel": lambda model_cls, weights: weights, - "Lumina2Transformer2DModel": 
lambda model_cls, weights: weights, - "WanTransformer3DModel": lambda model_cls, weights: weights, - "CogView4Transformer2DModel": lambda model_cls, weights: weights, - "HiDreamImageTransformer2DModel": lambda model_cls, weights: weights, - "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights, - "WanVACETransformer3DModel": lambda model_cls, weights: weights, - "ChromaTransformer2DModel": lambda model_cls, weights: weights, - "ChronoEditTransformer3DModel": lambda model_cls, weights: weights, - "QwenImageTransformer2DModel": lambda model_cls, weights: weights, - "Flux2Transformer2DModel": lambda model_cls, weights: weights, - "ZImageTransformer2DModel": lambda model_cls, weights: weights, - "LTX2VideoTransformer3DModel": lambda model_cls, weights: weights, - "LTX2TextConnectors": lambda model_cls, weights: weights, -} +_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict( + lambda: (lambda model_cls, weights: weights), + { + "UNet2DConditionModel": _maybe_expand_lora_scales, + "UNetMotionModel": _maybe_expand_lora_scales, + }, +) class PeftAdapterMixin: From 56c65c56e4dfc167898d7528da436efb89b89f76 Mon Sep 17 00:00:00 2001 From: Beinsezii <39478211+Beinsezii@users.noreply.github.com> Date: Tue, 24 Mar 2026 09:06:50 -0700 Subject: [PATCH 075/215] ZImageTransformer2D: Only build attention mask if seqlens are not equal (#12955) --- .../models/transformers/transformer_z_image.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 3bbf78bc5e01..8aa30ee082ff 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -788,9 +788,12 @@ def _prepare_sequence( freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]] # Attention mask - attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, 
device=device) - for i, seq_len in enumerate(item_seqlens): - attn_mask[i, :seq_len] = 1 + if all(seq == max_seqlen for seq in item_seqlens): + attn_mask = None + else: + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(item_seqlens): + attn_mask[i, :seq_len] = 1 # Noise mask noise_mask_tensor = None @@ -871,9 +874,12 @@ def _build_unified_sequence( unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0) # Attention mask - attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(unified_seqlens): - attn_mask[i, :seq_len] = 1 + if all(seq == max_seqlen for seq in unified_seqlens): + attn_mask = None + else: + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_seqlens): + attn_mask[i, :seq_len] = 1 # Noise mask noise_mask_tensor = None From 888a6568f1a1acf6feae6e29bb10d4bb401d1914 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 25 Mar 2026 07:51:35 +0530 Subject: [PATCH 076/215] fix klein lora loading. 
(#13313) --- .../loaders/lora_conversion_utils.py | 185 ++++++++++++++++++ src/diffusers/loaders/lora_pipeline.py | 8 + 2 files changed, 193 insertions(+) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 298aa61d37ed..41948d205c89 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2443,6 +2443,191 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict): return converted_state_dict +def _convert_kohya_flux2_lora_to_diffusers(state_dict): + def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + + # scale weight by alpha and dim + rank = down_weight.shape[0] + default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False) + alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item() + scale = alpha / rank + + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down + ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up + + def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + up_weight = sds_sd.pop(sds_key + ".lora_up.weight") + sd_lora_rank = down_weight.shape[0] + + default_alpha = torch.tensor( + sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False + ) + alpha = sds_sd.pop(sds_key + ".alpha", default_alpha) + scale = alpha / sd_lora_rank + + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + down_weight = down_weight * scale_down + up_weight = up_weight * scale_up + + num_splits = 
len(ait_keys) + if dims is None: + dims = [up_weight.shape[0] // num_splits] * num_splits + else: + assert sum(dims) == up_weight.shape[0] + + # check if upweight is sparse + is_sparse = False + if sd_lora_rank % num_splits == 0: + ait_rank = sd_lora_rank // num_splits + is_sparse = True + i = 0 + for j in range(len(dims)): + for k in range(len(dims)): + if j == k: + continue + is_sparse = is_sparse and torch.all( + up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0 + ) + i += dims[j] + if is_sparse: + logger.info(f"weight is sparse: {sds_key}") + + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + if not is_sparse: + ait_sd.update(dict.fromkeys(ait_down_keys, down_weight)) + ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 + else: + ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416 + i = 0 + for j in range(len(dims)): + ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous() + i += dims[j] + + # Detect number of blocks from keys + num_double_layers = 0 + num_single_layers = 0 + for key in state_dict.keys(): + if key.startswith("lora_unet_double_blocks_"): + block_idx = int(key.split("_")[4]) + num_double_layers = max(num_double_layers, block_idx + 1) + elif key.startswith("lora_unet_single_blocks_"): + block_idx = int(key.split("_")[4]) + num_single_layers = max(num_single_layers, block_idx + 1) + + ait_sd = {} + + for i in range(num_double_layers): + # Attention projections + _convert_to_ai_toolkit( + state_dict, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_proj", + f"transformer.transformer_blocks.{i}.attn.to_out.0", + ) + _convert_to_ai_toolkit_cat( + state_dict, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.to_q", + 
f"transformer.transformer_blocks.{i}.attn.to_k", + f"transformer.transformer_blocks.{i}.attn.to_v", + ], + ) + _convert_to_ai_toolkit( + state_dict, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_proj", + f"transformer.transformer_blocks.{i}.attn.to_add_out", + ) + _convert_to_ai_toolkit_cat( + state_dict, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.add_q_proj", + f"transformer.transformer_blocks.{i}.attn.add_k_proj", + f"transformer.transformer_blocks.{i}.attn.add_v_proj", + ], + ) + # MLP layers (Flux2 uses ff.linear_in/linear_out) + _convert_to_ai_toolkit( + state_dict, + ait_sd, + f"lora_unet_double_blocks_{i}_img_mlp_0", + f"transformer.transformer_blocks.{i}.ff.linear_in", + ) + _convert_to_ai_toolkit( + state_dict, + ait_sd, + f"lora_unet_double_blocks_{i}_img_mlp_2", + f"transformer.transformer_blocks.{i}.ff.linear_out", + ) + _convert_to_ai_toolkit( + state_dict, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_mlp_0", + f"transformer.transformer_blocks.{i}.ff_context.linear_in", + ) + _convert_to_ai_toolkit( + state_dict, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_mlp_2", + f"transformer.transformer_blocks.{i}.ff_context.linear_out", + ) + + for i in range(num_single_layers): + # Single blocks: linear1 -> attn.to_qkv_mlp_proj (fused, no split needed) + _convert_to_ai_toolkit( + state_dict, + ait_sd, + f"lora_unet_single_blocks_{i}_linear1", + f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj", + ) + # Single blocks: linear2 -> attn.to_out + _convert_to_ai_toolkit( + state_dict, + ait_sd, + f"lora_unet_single_blocks_{i}_linear2", + f"transformer.single_transformer_blocks.{i}.attn.to_out", + ) + + # Handle optional extra keys + extra_mappings = { + "lora_unet_img_in": "transformer.x_embedder", + "lora_unet_txt_in": "transformer.context_embedder", + "lora_unet_time_in_in_layer": "transformer.time_guidance_embed.timestep_embedder.linear_1", + "lora_unet_time_in_out_layer": 
"transformer.time_guidance_embed.timestep_embedder.linear_2", + "lora_unet_final_layer_linear": "transformer.proj_out", + } + for sds_key, ait_key in extra_mappings.items(): + _convert_to_ai_toolkit(state_dict, ait_sd, sds_key, ait_key) + + remaining_keys = list(state_dict.keys()) + if remaining_keys: + logger.warning(f"Unsupported keys for Kohya Flux2 LoRA conversion: {remaining_keys}") + + return ait_sd + + def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict): """ Convert non-diffusers ZImage LoRA state dict to diffusers format. diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 5d10f596f2e6..6ec23389ac08 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -43,6 +43,7 @@ _convert_bfl_flux_control_lora_to_diffusers, _convert_fal_kontext_lora_to_diffusers, _convert_hunyuan_video_lora_to_diffusers, + _convert_kohya_flux2_lora_to_diffusers, _convert_kohya_flux_lora_to_diffusers, _convert_musubi_wan_lora_to_diffusers, _convert_non_diffusers_flux2_lora_to_diffusers, @@ -5673,6 +5674,13 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + is_kohya = any(".lora_down.weight" in k for k in state_dict) + if is_kohya: + state_dict = _convert_kohya_flux2_lora_to_diffusers(state_dict) + # Kohya already takes care of scaling the LoRA parameters with alpha. + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out + is_peft_format = any(k.startswith("base_model.model.") for k in state_dict) if is_peft_format: state_dict = {k.replace("base_model.model.", "diffusion_model."): v for k, v in state_dict.items()} From 7da43a6f4baf15c200f2f9355a2683c3bfb15756 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 25 Mar 2026 11:47:02 +0530 Subject: [PATCH 077/215] fix to device and to dtype tests. 
(#13323) --- tests/pipelines/test_pipelines_common.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index af3573ce84cb..4d9d1717ba86 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1534,14 +1534,18 @@ def test_to_device(self): pipe.set_progress_bar_config(disable=None) pipe.to("cpu") - model_devices = [component.device.type for component in components.values() if hasattr(component, "device")] + model_devices = [ + component.device.type for component in components.values() if getattr(component, "device", None) + ] self.assertTrue(all(device == "cpu" for device in model_devices)) output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] self.assertTrue(np.isnan(output_cpu).sum() == 0) pipe.to(torch_device) - model_devices = [component.device.type for component in components.values() if hasattr(component, "device")] + model_devices = [ + component.device.type for component in components.values() if getattr(component, "device", None) + ] self.assertTrue(all(device == torch_device for device in model_devices)) output_device = pipe(**self.get_dummy_inputs(torch_device))[0] @@ -1552,11 +1556,11 @@ def test_to_dtype(self): pipe = self.pipeline_class(**components) pipe.set_progress_bar_config(disable=None) - model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] + model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)] self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) pipe.to(dtype=torch.float16) - model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] + model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)] self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) def 
test_attention_slicing_forward_pass(self, expected_max_diff=1e-3): From ff22088bbdd8e6bc95da76157c07b4a8f73e274b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 25 Mar 2026 11:47:50 +0100 Subject: [PATCH 078/215] [Discrete Diffusion] Add LLaDA2 pipeline (#13226) * feat: add LLaDA2 and BlockRefinement pipelines for discrete text diffusion Add support for LLaDA2/LLaDA2.1 discrete diffusion text generation: - BlockRefinementPipeline: block-wise iterative refinement with confidence-based token commitment, supporting editing threshold for LLaDA2.1 models - LLaDA2Pipeline: convenience wrapper with LLaDA2-specific defaults - DiscreteDiffusionPipelineMixin: shared SAR sampling utilities (top-k, top-p, temperature) and prompt/prefix helpers - compute_confidence_aware_loss: CAP-style training loss - Examples: sampling scripts for LLaDA2 and block refinement, training scripts with Qwen causal LM - Docs and tests included * feat: add BlockRefinementScheduler for commit-by-confidence scheduling Extract the confidence-based token commit logic from BlockRefinementPipeline into a dedicated BlockRefinementScheduler, following diffusers conventions. The scheduler owns: - Transfer schedule computation (get_num_transfer_tokens) - Timestep management (set_timesteps) - Step logic: confidence-based mask-filling and optional token editing The pipeline now delegates scheduling to self.scheduler.step() and accepts a scheduler parameter in __init__. * test: add unit tests for BlockRefinementScheduler 12 tests covering set_timesteps, get_num_transfer_tokens, step logic (confidence-based commits, threshold behavior, editing, prompt masking, batched inputs, tuple output). 
* docs: add toctree entries and standalone scheduler doc page - Add BlockRefinement and LLaDA2 to docs sidebar navigation - Add BlockRefinementScheduler to schedulers sidebar navigation - Move scheduler autodoc to its own page under api/schedulers/ * feat: add --revision flag and fix dtype deprecation in sample_llada2.py - Add --revision argument for loading model revisions from the Hub - Replace deprecated torch_dtype with dtype for transformers 5.x compat * fix: use 1/0 attention mask instead of 0/-inf for LLaDA2 compat LLaDA2 models expect a boolean-style (1/0) attention mask, not an additive (0/-inf) mask. The model internally converts to additive, so passing 0/-inf caused double-masking and gibberish output. * refactor: consolidate training scripts into single train_block_refinement.py - Remove toy train_block_refinement_cap.py (self-contained demo with tiny model) - Rename train_block_refinement_qwen_cap.py to train_block_refinement.py (already works with any causal LM via AutoModelForCausalLM) - Fix torch_dtype deprecation and update README with correct script names * fix formatting * docs: improve LLaDA2 and BlockRefinement documentation - Add usage examples with real model IDs and working code - Add recommended parameters table for LLaDA2.1 quality/speed modes - Note that editing is LLaDA2.1-only (not for LLaDA2.0 models) - Remove misleading config defaults section from BlockRefinement docs * feat: set LLaDA2Pipeline defaults to recommended model parameters - threshold: 0.95 -> 0.7 (quality mode) - max_post_steps: 0 -> 16 (recommended for LLaDA2.1, harmless for 2.0) - eos_early_stop: False -> True (stop at EOS token) block_length=32, steps=32, temperature=0.0 were already correct. editing_threshold remains None (users enable for LLaDA2.1 models). * feat: default editing_threshold=0.5 for LLaDA2.1 quality mode LLaDA2.1 is the current generation. Users with LLaDA2.0 models can disable editing by passing editing_threshold=None. 
* fix: align sampling utilities with official LLaDA2 implementation - top_p filtering: add shift-right to preserve at least one token above threshold (matches official code line 1210) - temperature ordering: apply scaling before top-k/top-p filtering so filtering operates on scaled logits (matches official code lines 1232-1235) - greedy branch: return argmax directly when temperature=0 without filtering (matches official code lines 1226-1230) * refactor: remove duplicate prompt encoding, reuse mixin's _prepare_input_ids LLaDA2Pipeline._prepare_prompt_ids was a near-copy of DiscreteDiffusionPipelineMixin._prepare_input_ids. Remove the duplicate and call the mixin method directly. Also simplify _extract_input_ids since we always pass return_dict=True. * formatting * fix: replace deprecated torch_dtype with dtype in examples and docstrings - Update EXAMPLE_DOC_STRING to use dtype= and LLaDA2.1-mini model ID - Fix sample_block_refinement.py to use dtype= * remove BlockRefinementPipeline * cleanup * fix readme * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/llada2/pipeline_llada2.py 
Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: YiYi Xu * removed DiscreteDiffusionPipelineMixin * add support for 2d masks for flash attn * Update src/diffusers/training_utils.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/training_utils.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * fix issues from review * added tests * formatting * add check_eos_finished to scheduler * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/schedulers/scheduling_block_refinement.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/schedulers/scheduling_block_refinement.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * fix renaming issues and types * remove duplicate check * Update docs/source/en/api/pipelines/llada2.md Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update 
src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/pipelines/llada2/pipeline_llada2.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --------- Co-authored-by: YiYi Xu Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- docs/source/en/_toctree.yml | 4 + docs/source/en/api/pipelines/llada2.md | 83 +++ docs/source/en/api/pipelines/overview.md | 1 + .../en/api/schedulers/block_refinement.md | 25 + examples/discrete_diffusion/README.md | 50 ++ examples/discrete_diffusion/sample_llada2.py | 263 ++++++++++ examples/discrete_diffusion/train_llada2.py | 321 ++++++++++++ src/diffusers/__init__.py | 8 + src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/llada2/__init__.py | 47 ++ .../pipelines/llada2/pipeline_llada2.py | 491 ++++++++++++++++++ src/diffusers/schedulers/__init__.py | 2 + .../schedulers/scheduling_block_refinement.py | 459 ++++++++++++++++ src/diffusers/training_utils.py | 87 ++++ src/diffusers/utils/dummy_pt_objects.py | 30 ++ .../dummy_torch_and_transformers_objects.py | 30 ++ tests/others/test_training.py | 46 +- tests/pipelines/llada2/__init__.py | 0 tests/pipelines/llada2/test_llada2.py | 245 +++++++++ .../test_scheduler_block_refinement.py | 470 +++++++++++++++++ 20 files changed, 2663 insertions(+), 1 deletion(-) create mode 100644 docs/source/en/api/pipelines/llada2.md create mode 100644 docs/source/en/api/schedulers/block_refinement.md create mode 100644 examples/discrete_diffusion/README.md create mode 100644 examples/discrete_diffusion/sample_llada2.py create mode 100644 examples/discrete_diffusion/train_llada2.py create mode 100644 src/diffusers/pipelines/llada2/__init__.py create mode 100644 src/diffusers/pipelines/llada2/pipeline_llada2.py create mode 100644 src/diffusers/schedulers/scheduling_block_refinement.py create mode 100644 tests/pipelines/llada2/__init__.py create mode 100644 
tests/pipelines/llada2/test_llada2.py create mode 100644 tests/schedulers/test_scheduler_block_refinement.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c2c62151132f..394d539350d6 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -580,6 +580,8 @@ title: Latent Diffusion - local: api/pipelines/ledits_pp title: LEDITS++ + - local: api/pipelines/llada2 + title: LLaDA2 - local: api/pipelines/longcat_image title: LongCat-Image - local: api/pipelines/lumina2 @@ -718,6 +720,8 @@ - sections: - local: api/schedulers/overview title: Overview + - local: api/schedulers/block_refinement + title: BlockRefinementScheduler - local: api/schedulers/cm_stochastic_iterative title: CMStochasticIterativeScheduler - local: api/schedulers/ddim_cogvideox diff --git a/docs/source/en/api/pipelines/llada2.md b/docs/source/en/api/pipelines/llada2.md new file mode 100644 index 000000000000..cf0fa0b0d7b6 --- /dev/null +++ b/docs/source/en/api/pipelines/llada2.md @@ -0,0 +1,83 @@ + + +# LLaDA2 + +[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) is a family of discrete diffusion language models +that generate text through block-wise iterative refinement. Instead of autoregressive token-by-token generation, +LLaDA2 starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement +steps. 
+ +## Usage + +```py +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from diffusers import BlockRefinementScheduler, LLaDA2Pipeline + +model_id = "inclusionAI/LLaDA2.1-mini" +model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto") +tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) +scheduler = BlockRefinementScheduler() + +pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer) +output = pipe( + prompt="Write a short poem about the ocean.", + gen_length=256, + block_length=32, + num_inference_steps=32, + threshold=0.7, + editing_threshold=0.5, + max_post_steps=16, + temperature=0.0, +) +print(output.texts[0]) +``` + +## Callbacks + +Callbacks run after each refinement step and can inspect or modify the current tokens. + +```py +def on_step_end(pipe, step, timestep, callback_kwargs): + cur_x = callback_kwargs["cur_x"] + # Inspect or modify `cur_x` here. + return {"cur_x": cur_x} + +out = pipe( + prompt="Write a short poem.", + callback_on_step_end=on_step_end, + callback_on_step_end_tensor_inputs=["cur_x"], +) +``` + +## Recommended parameters + +LLaDA2.1 models support two modes: + +| Mode | `threshold` | `editing_threshold` | `max_post_steps` | +|------|-------------|---------------------|------------------| +| Quality | 0.7 | 0.5 | 16 | +| Speed | 0.5 | 0.0 | 16 | + +For LLaDA2.0 models, disable editing by passing `editing_threshold=None`. + +For all models: `block_length=32`, `temperature=0.0`, `num_inference_steps=32`.
+ +## LLaDA2Pipeline +[[autodoc]] LLaDA2Pipeline + - all + - __call__ + +## LLaDA2PipelineOutput +[[autodoc]] pipelines.LLaDA2PipelineOutput diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md index cf5950686f22..3cfdfee8cc2b 100644 --- a/docs/source/en/api/pipelines/overview.md +++ b/docs/source/en/api/pipelines/overview.md @@ -63,6 +63,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an | [Latent Diffusion](latent_diffusion) | text2image, super-resolution | | [Latte](latte) | text2image | | [LEDITS++](ledits_pp) | image editing | +| [LLaDA2](llada2) | text2text | | [Lumina-T2X](lumina) | text2image | | [Marigold](marigold) | depth-estimation, normals-estimation, intrinsic-decomposition | | [MultiDiffusion](panorama) | text2image | diff --git a/docs/source/en/api/schedulers/block_refinement.md b/docs/source/en/api/schedulers/block_refinement.md new file mode 100644 index 000000000000..408da0d80552 --- /dev/null +++ b/docs/source/en/api/schedulers/block_refinement.md @@ -0,0 +1,25 @@ + + +# BlockRefinementScheduler + +The `BlockRefinementScheduler` manages block-wise iterative refinement for discrete token diffusion. At each step it +commits the most confident tokens and optionally edits already-committed tokens when the model predicts a different +token with high confidence. + +This scheduler is used by [`LLaDA2Pipeline`]. 
+ +## BlockRefinementScheduler +[[autodoc]] BlockRefinementScheduler + +## BlockRefinementSchedulerOutput +[[autodoc]] schedulers.scheduling_block_refinement.BlockRefinementSchedulerOutput diff --git a/examples/discrete_diffusion/README.md b/examples/discrete_diffusion/README.md new file mode 100644 index 000000000000..a3a8253b1927 --- /dev/null +++ b/examples/discrete_diffusion/README.md @@ -0,0 +1,50 @@ +# Discrete Token Diffusion (Experimental) + +This folder contains **training and sampling examples** for *discrete diffusion over token IDs* (language-model style), built to follow the `diffusers` + `accelerate` training conventions. + +## LLaDA2 + +[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) generates text through block-wise iterative refinement. Instead of autoregressive token-by-token generation, it starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement steps. + +### Train + +The training script uses confidence-aware loss and works with any causal LM from the Hub (e.g. Qwen, Llama, Mistral): + +```bash +accelerate launch examples/discrete_diffusion/train_llada2.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --text_column text \ + --output_dir llada2-output \ + --max_train_steps 1000 \ + --prompt_length 32 \ + --block_length 32 \ + --lambda_conf 2.0 \ + --conf_temperature 0.5 +``` + +If you don't want to download a dataset, you can use random-token data: + +```bash +accelerate launch examples/discrete_diffusion/train_llada2.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --output_dir llada2-output \ + --use_dummy_data \ + --num_dummy_samples 2048 +``` + +### Sample + +```bash +python examples/discrete_diffusion/sample_llada2.py \ + --model_id inclusionAI/LLaDA2.1-mini \ + --prompt "Write a short poem about the ocean." 
\ + --gen_length 256 \ + --num_inference_steps 32 \ + --threshold 0.7 \ + --editing_threshold 0.5 \ + --max_post_steps 16 \ + --use_chat_template \ + --add_generation_prompt +``` diff --git a/examples/discrete_diffusion/sample_llada2.py b/examples/discrete_diffusion/sample_llada2.py new file mode 100644 index 000000000000..067f50fca153 --- /dev/null +++ b/examples/discrete_diffusion/sample_llada2.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Sample script for LLaDA2-style discrete diffusion text generation. + +This script demonstrates how to use the LLaDA2Pipeline for text generation +using block-wise iterative refinement. + +Example usage: + python sample_llada2.py --model_id inclusionAI/LLaDA2.0-mini --prompt "What is the capital of France?" + python sample_llada2.py --model_id inclusionAI/LLaDA2.0-flash-CAP --prompt "Explain quantum computing." --temperature 0.7 +""" + +import argparse + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from diffusers import BlockRefinementScheduler, LLaDA2Pipeline +from diffusers.hooks import apply_group_offloading + + +def main(): + parser = argparse.ArgumentParser( + description="Generate text using LLaDA2Pipeline with block-wise discrete diffusion." 
+ ) + parser.add_argument( + "--model_id", + type=str, + default="inclusionAI/LLaDA2.0-mini", + help="HuggingFace model ID or path to local model.", + ) + parser.add_argument( + "--prompt", + type=str, + default="Why does Camus think that Sisyphus is happy?", + help="Text prompt to generate from.", + ) + parser.add_argument( + "--gen_length", + type=int, + default=2048, + help="Number of tokens to generate.", + ) + parser.add_argument( + "--block_length", + type=int, + default=32, + help="Size of each generation block.", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=32, + help="Number of refinement steps per block.", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Sampling temperature (0.0 for greedy).", + ) + parser.add_argument( + "--top_p", + type=float, + default=None, + help="Nucleus sampling probability threshold.", + ) + parser.add_argument( + "--top_k", + type=int, + default=None, + help="Top-k sampling parameter.", + ) + parser.add_argument( + "--threshold", + type=float, + default=0.95, + help="Confidence threshold for committing tokens.", + ) + parser.add_argument( + "--editing_threshold", + type=float, + default=None, + help="Confidence threshold for editing already-committed tokens. Set to enable post-mask editing (e.g. 0.5).", + ) + parser.add_argument( + "--max_post_steps", + type=int, + default=0, + help="Maximum post-mask editing iterations per block (e.g. 16). 
Only used when --editing_threshold is set.", + ) + parser.add_argument( + "--sampling_method", + type=str, + default="multinomial", + choices=["auto", "greedy", "multinomial"], + help="Sampling method for block refinement.", + ) + parser.add_argument( + "--eos_early_stop", + action="store_true", + help="Stop generation early when EOS token is generated.", + ) + parser.add_argument( + "--use_chat_template", + action="store_true", + help="Use the tokenizer chat template for the prompt.", + ) + parser.add_argument( + "--add_generation_prompt", + action="store_true", + help="Add the generation prompt when using the chat template.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to run inference on.", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=["float32", "float16", "bfloat16"], + help="Model dtype.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Random seed for reproducibility.", + ) + parser.add_argument( + "--offload", + type=str, + default=None, + choices=["group", "sequential"], + help="Memory offloading strategy: 'group' for group offloading (faster), 'sequential' for sequential CPU offload (slower but lower memory).", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Model revision (branch, tag, or commit hash) to load from the Hub.", + ) + + args = parser.parse_args() + + # Parse dtype + dtype_map = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + torch_dtype = dtype_map[args.dtype] + + print(f"Loading model: {args.model_id}") + tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True, revision=args.revision) + + # Load model with appropriate memory settings based on offload strategy + if args.offload == "group": + # For group offloading, load to CPU first then apply hooks + print("Using group offloading for memory 
efficiency...") + model = AutoModelForCausalLM.from_pretrained( + args.model_id, + trust_remote_code=True, + dtype=torch_dtype, + low_cpu_mem_usage=True, + revision=args.revision, + ) + # Apply group offloading with CUDA streams for better performance + onload_device = torch.device(args.device) + offload_device = torch.device("cpu") + apply_group_offloading( + model, + onload_device=onload_device, + offload_device=offload_device, + offload_type="leaf_level", + use_stream=True, + ) + elif args.offload == "sequential": + # For sequential offloading, load to CPU first + print("Using sequential CPU offloading (slower but lower memory)...") + model = AutoModelForCausalLM.from_pretrained( + args.model_id, + trust_remote_code=True, + dtype=torch_dtype, + low_cpu_mem_usage=True, + revision=args.revision, + ) + # Sequential offloading will be applied via pipeline + else: + # Default: use device_map="auto" for automatic memory management + model = AutoModelForCausalLM.from_pretrained( + args.model_id, + trust_remote_code=True, + dtype=torch_dtype, + device_map="auto", + low_cpu_mem_usage=True, + revision=args.revision, + ) + model.eval() + + # Create pipeline + scheduler = BlockRefinementScheduler() + pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer) + + # Apply sequential CPU offload if requested + if args.offload == "sequential": + pipe.enable_sequential_cpu_offload() + + # Set up generator for reproducibility + generator = None + if args.seed is not None: + generator = torch.Generator(device=args.device).manual_seed(args.seed) + + print(f"\nPrompt: {args.prompt}") + print( + f"Generating {args.gen_length} tokens with block_length={args.block_length}, steps={args.num_inference_steps}" + ) + print("-" * 50) + + # Generate + output = pipe( + prompt=args.prompt, + use_chat_template=args.use_chat_template, + add_generation_prompt=args.add_generation_prompt, + gen_length=args.gen_length, + block_length=args.block_length, + 
num_inference_steps=args.num_inference_steps, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + threshold=args.threshold, + editing_threshold=args.editing_threshold, + max_post_steps=args.max_post_steps, + sampling_method=args.sampling_method, + eos_early_stop=args.eos_early_stop, + generator=generator, + ) + + print("\nGenerated text:") + print(output.texts[0]) + + print(f"\nGenerated {output.sequences.shape[1]} tokens") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/train_llada2.py b/examples/discrete_diffusion/train_llada2.py new file mode 100644 index 000000000000..7e1967abdd88 --- /dev/null +++ b/examples/discrete_diffusion/train_llada2.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import math +import os +from dataclasses import asdict, dataclass +from typing import Dict, Optional + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, get_scheduler + +from diffusers import BlockRefinementScheduler +from diffusers.training_utils import compute_confidence_aware_loss + + +logger = get_logger(__name__) + + +@dataclass +class TrainConfig: + model_name_or_path: str + dataset_name: str + dataset_config_name: Optional[str] + text_column: str + cache_dir: Optional[str] + use_dummy_data: bool + num_dummy_samples: int + + output_dir: str + seed: int + max_train_steps: int + checkpointing_steps: int + logging_steps: int + + per_device_train_batch_size: int + gradient_accumulation_steps: int + learning_rate: float + weight_decay: float + lr_scheduler: str + lr_warmup_steps: int + + max_length: int + prompt_length: int + block_length: int + + lambda_conf: float + conf_temperature: float + + +def parse_args() -> TrainConfig: + parser = argparse.ArgumentParser(description="Train block-refinement with a confidence-aware loss on a causal LM.") + + parser.add_argument("--model_name_or_path", type=str, default="Qwen/Qwen2.5-0.5B") + parser.add_argument("--dataset_name", type=str, default="wikitext") + parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1") + parser.add_argument("--text_column", type=str, default="text") + parser.add_argument("--cache_dir", type=str, default=None) + parser.add_argument("--use_dummy_data", action="store_true", help="Use random-token data instead of downloading.") + parser.add_argument("--num_dummy_samples", type=int, default=2048) + + parser.add_argument("--output_dir", type=str, default="block-refinement-output") + 
parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--max_train_steps", type=int, default=1000) + parser.add_argument("--checkpointing_steps", type=int, default=500) + parser.add_argument("--logging_steps", type=int, default=50) + + parser.add_argument("--per_device_train_batch_size", type=int, default=1) + parser.add_argument("--gradient_accumulation_steps", type=int, default=8) + parser.add_argument("--learning_rate", type=float, default=2e-5) + parser.add_argument("--weight_decay", type=float, default=0.0) + parser.add_argument( + "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"] + ) + parser.add_argument("--lr_warmup_steps", type=int, default=100) + + parser.add_argument("--max_length", type=int, default=256) + parser.add_argument("--prompt_length", type=int, default=32) + parser.add_argument("--block_length", type=int, default=32) + + parser.add_argument("--lambda_conf", type=float, default=2.0) + parser.add_argument("--conf_temperature", type=float, default=0.5) + + args = parser.parse_args() + return TrainConfig(**vars(args)) + + +def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int): + texts = examples[text_column] + texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0] + return tokenizer(texts, truncation=True, padding=False, max_length=max_length) + + +class RandomTokenDataset(torch.utils.data.Dataset): + def __init__(self, *, num_samples: int, seq_len: int, vocab_size: int, pad_token_id: int): + self.num_samples = int(num_samples) + self.seq_len = int(seq_len) + self.vocab_size = int(vocab_size) + self.pad_token_id = int(pad_token_id) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + del idx + input_ids = torch.randint(0, self.vocab_size, (self.seq_len,), dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +def main(): + cfg = 
parse_args() + if cfg.prompt_length >= cfg.max_length: + raise ValueError("`prompt_length` must be < `max_length`.") + if cfg.block_length <= 0: + raise ValueError("`block_length` must be > 0.") + + project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs")) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + project_config=project_config, + ) + if accelerator.is_main_process: + os.makedirs(cfg.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + set_seed(cfg.seed) + logger.info("Training configuration: %s", asdict(cfg)) + + tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True, cache_dir=cfg.cache_dir) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + + if tokenizer.mask_token_id is None: + tokenizer.add_special_tokens({"mask_token": "[MASK]"}) + + load_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 + model = AutoModelForCausalLM.from_pretrained(cfg.model_name_or_path, cache_dir=cfg.cache_dir, dtype=load_dtype) + model.resize_token_embeddings(len(tokenizer)) + if load_dtype == torch.float32: + model.to(dtype=torch.float32) + + mask_token_id = int(tokenizer.mask_token_id) + + if cfg.use_dummy_data: + dataset = RandomTokenDataset( + num_samples=cfg.num_dummy_samples, + seq_len=cfg.max_length, + vocab_size=len(tokenizer), + pad_token_id=int(tokenizer.pad_token_id), + ) + train_dataloader = DataLoader( + dataset, + shuffle=True, + batch_size=cfg.per_device_train_batch_size, + drop_last=True, + ) + else: + raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name, cache_dir=cfg.cache_dir) + if "train" not in raw_datasets: + raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.") + + with accelerator.main_process_first(): + tokenized = raw_datasets["train"].map( + lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length), + batched=True, + 
remove_columns=raw_datasets["train"].column_names, + desc="Tokenizing", + ) + + collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt") + train_dataloader = DataLoader( + tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True + ) + + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps) + num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + name=cfg.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.lr_warmup_steps, + num_training_steps=cfg.max_train_steps, + ) + + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + noise_scheduler = BlockRefinementScheduler(block_length=cfg.block_length) + + global_step = 0 + model.train() + + for _epoch in range(num_train_epochs): + for batch in train_dataloader: + with accelerator.accumulate(model): + input_ids = batch["input_ids"] + attention_mask = batch.get("attention_mask", torch.ones_like(input_ids)) + + gen = torch.Generator(device=input_ids.device).manual_seed(cfg.seed + global_step) + noisy, noisy_rev, masked, masked_rev = noise_scheduler.add_noise( + input_ids, + attention_mask, + prompt_length=cfg.prompt_length, + block_length=cfg.block_length, + mask_token_id=mask_token_id, + generator=gen, + ) + + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand_as(input_ids) + ) + + logits = model(input_ids=noisy, attention_mask=attention_mask, position_ids=position_ids).logits + logits_rev = model( + input_ids=noisy_rev, attention_mask=attention_mask, position_ids=position_ids + ).logits + + logits = logits.clone() + logits[..., mask_token_id] = torch.finfo(logits.dtype).min + logits_rev = 
logits_rev.clone() + logits_rev[..., mask_token_id] = torch.finfo(logits_rev.dtype).min + + valid = attention_mask.to(dtype=torch.bool) + masked = masked & valid + masked_rev = masked_rev & valid + + labels = input_ids.clone() + labels[~masked] = -100 + labels_rev = input_ids.clone() + labels_rev[~masked_rev] = -100 + + weights = masked.to(dtype=logits.dtype) + weights_rev = masked_rev.to(dtype=logits.dtype) + + loss, loss_sft, loss_conf = compute_confidence_aware_loss( + logits, + labels, + lambda_conf=cfg.lambda_conf, + temperature=cfg.conf_temperature, + per_token_weights=weights, + ) + loss_rev, loss_sft_rev, loss_conf_rev = compute_confidence_aware_loss( + logits_rev, + labels_rev, + lambda_conf=cfg.lambda_conf, + temperature=cfg.conf_temperature, + per_token_weights=weights_rev, + ) + + total_loss = loss + loss_rev + accelerator.backward(total_loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + global_step += 1 + + if global_step % cfg.logging_steps == 0 and accelerator.is_main_process: + logger.info( + "step=%d loss=%.4f sft=%.4f conf=%.4f lr=%.6g", + global_step, + total_loss.item(), + (loss_sft + loss_sft_rev).item(), + (loss_conf + loss_conf_rev).item(), + lr_scheduler.get_last_lr()[0], + ) + print( + f"step={global_step} loss={total_loss.item():.4f} " + f"sft={(loss_sft + loss_sft_rev).item():.4f} " + f"conf={(loss_conf + loss_conf_rev).item():.4f} " + f"lr={lr_scheduler.get_last_lr()[0]:.6g}" + ) + + if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}") + os.makedirs(save_dir, exist_ok=True) + accelerator.unwrap_model(model).save_pretrained(save_dir, save_function=accelerator.save) + tokenizer.save_pretrained(save_dir) + + if global_step >= cfg.max_train_steps: + break + + if global_step >= cfg.max_train_steps: + break + 
+ accelerator.wait_for_everyone() + if accelerator.is_main_process: + final_dir = os.path.join(cfg.output_dir, "final") + os.makedirs(final_dir, exist_ok=True) + accelerator.unwrap_model(model).save_pretrained(final_dir, save_function=accelerator.save) + tokenizer.save_pretrained(final_dir) + + logger.info("Done.") + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index eb5068b499cc..7d966452d1a2 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -344,6 +344,8 @@ _import_structure["schedulers"].extend( [ "AmusedScheduler", + "BlockRefinementScheduler", + "BlockRefinementSchedulerOutput", "CMStochasticIterativeScheduler", "CogVideoXDDIMScheduler", "CogVideoXDPMScheduler", @@ -580,6 +582,8 @@ "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", + "LLaDA2Pipeline", + "LLaDA2PipelineOutput", "LongCatImageEditPipeline", "LongCatImagePipeline", "LTX2ConditionPipeline", @@ -1124,6 +1128,8 @@ from .quantizers import DiffusersQuantizer from .schedulers import ( AmusedScheduler, + BlockRefinementScheduler, + BlockRefinementSchedulerOutput, CMStochasticIterativeScheduler, CogVideoXDDIMScheduler, CogVideoXDPMScheduler, @@ -1339,6 +1345,8 @@ LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, + LLaDA2Pipeline, + LLaDA2PipelineOutput, LongCatImageEditPipeline, LongCatImagePipeline, LTX2ConditionPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b9596f4b7952..3dafb56fdd65 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -285,6 +285,7 @@ ] ) _import_structure["latte"] = ["LattePipeline"] + _import_structure["llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"] _import_structure["ltx"] = [ "LTXPipeline", "LTXImageToVideoPipeline", @@ -728,6 +729,7 @@ LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) + 
from .llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline from .ltx import ( LTXConditionPipeline, diff --git a/src/diffusers/pipelines/llada2/__init__.py b/src/diffusers/pipelines/llada2/__init__.py new file mode 100644 index 000000000000..45a02e6851e2 --- /dev/null +++ b/src/diffusers/pipelines/llada2/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/llada2/pipeline_llada2.py b/src/diffusers/pipelines/llada2/pipeline_llada2.py new file mode 100644 index 000000000000..d4b037ada151 --- /dev/null +++ b/src/diffusers/pipelines/llada2/pipeline_llada2.py @@ -0,0 +1,491 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +import torch +from tqdm.auto import tqdm + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...schedulers import BlockRefinementScheduler +from ...utils import BaseOutput, logging, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from transformers import AutoModelForCausalLM, AutoTokenizer + >>> from diffusers import BlockRefinementScheduler, LLaDA2Pipeline + + >>> model_id = "inclusionAI/LLaDA2.1-mini" + >>> model = AutoModelForCausalLM.from_pretrained( + ... model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto" + ... ) + >>> tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + >>> scheduler = BlockRefinementScheduler() + + >>> pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer) + >>> output = pipe(prompt="What is the meaning of life?", gen_length=256) + >>> print(output.texts[0]) + ``` +""" + + +@dataclass +class LLaDA2PipelineOutput(BaseOutput): + sequences: torch.LongTensor + texts: list[str] | None = None + + +class LLaDA2Pipeline(DiffusionPipeline): + r""" + Pipeline for LLaDA2-style discrete diffusion text generation via block-wise iterative refinement. 
+ + This pipeline maintains a template sequence filled with a `mask_token_id` and refines it in blocks. In each + refinement step, it samples candidate tokens for the active block and commits a subset based on confidence. + + The model is expected to accept an attention mask and `position_ids`, and to return logits of shape `[batch, seq, + vocab_size]`. + """ + + model: Any + scheduler: BlockRefinementScheduler + tokenizer: Any + + _callback_tensor_inputs = ["block_x", "x0", "x0_p", "transfer_index", "confidence", "active_block"] + + def __init__( + self, + model: Any, + scheduler: BlockRefinementScheduler, + tokenizer: Any | None = None, + ): + super().__init__() + self.register_modules(model=model, scheduler=scheduler, tokenizer=tokenizer) + self.eos_token_id = getattr(self.tokenizer, "eos_token_id", None) if self.tokenizer is not None else None + self.mask_token_id = getattr(self.tokenizer, "mask_token_id", None) if self.tokenizer is not None else None + + @property + def num_timesteps(self): + return self._num_timesteps + + # --- Prompt encoding --- + + def _prepare_input_ids( + self, + *, + prompt: str | list[str] | None, + messages: list[dict[str, str]] | None, + input_ids: torch.LongTensor | None, + use_chat_template: bool, + add_generation_prompt: bool, + chat_template_kwargs: dict[str, Any] | None, + ) -> torch.LongTensor: + """Convert prompt/messages/input_ids to a [batch, seq] LongTensor.""" + if input_ids is not None: + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + if input_ids.ndim != 2: + raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.") + if input_ids.dtype != torch.long: + raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.") + return input_ids + + if self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + + if messages is not None and prompt is not None: + raise ValueError("Provide either `prompt` or `messages`, not 
both.") + if messages is None and prompt is None: + raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.") + + chat_template_kwargs = chat_template_kwargs or {} + + if messages is not None: + encoded = self.tokenizer.apply_chat_template( + messages, + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + **chat_template_kwargs, + ) + return encoded["input_ids"] + + if use_chat_template and getattr(self.tokenizer, "chat_template", None): + if isinstance(prompt, list): + raise ValueError("`prompt` must be a string when `use_chat_template=True`.") + encoded = self.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + **chat_template_kwargs, + ) + return encoded["input_ids"] + + encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list)) + return encoded["input_ids"] + + def check_inputs( + self, + prompt: str | list[str] | None, + messages: list[dict[str, str]] | None, + input_ids: torch.LongTensor | None, + gen_length: int, + block_length: int, + num_inference_steps: int, + minimal_topk: int, + threshold: float, + sampling_method: str, + output_type: str, + callback_on_step_end: Callable | PipelineCallback | MultiPipelineCallbacks | None, + callback_on_step_end_tensor_inputs: list[str] | None, + ): + # Input source validation + if prompt is None and messages is None and input_ids is None: + raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.") + if prompt is not None and messages is not None: + raise ValueError("Provide either `prompt` or `messages`, not both.") + if input_ids is not None: + if input_ids.ndim not in (1, 2): + raise ValueError(f"`input_ids` must be 1D or 2D, got shape {tuple(input_ids.shape)}.") + if input_ids.dtype != torch.long: + raise ValueError(f"`input_ids` must be int64 token IDs, got 
dtype={input_ids.dtype}.") + if prompt is not None and input_ids is None and self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + if messages is not None and input_ids is None and self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + + # Generation parameter validation + if gen_length <= 0: + raise ValueError(f"`gen_length` must be > 0, got {gen_length}.") + if block_length <= 0: + raise ValueError(f"`block_length` must be > 0, got {block_length}.") + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + if minimal_topk <= 0: + raise ValueError(f"`minimal_topk` must be > 0, got {minimal_topk}.") + if not (0.0 <= threshold <= 1.0) and not (threshold > 1.0): + raise ValueError(f"`threshold` must be in [0, 1] (or > 1 to force top-k commits), got {threshold}.") + if sampling_method not in {"auto", "greedy", "multinomial"}: + raise ValueError( + f"`sampling_method` must be one of {{'auto','greedy','multinomial'}}, got {sampling_method!r}." 
+ ) + if output_type not in {"seq", "text"}: + raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.") + + # Callback validation + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + messages: list[dict[str, str]] | None = None, + input_ids: torch.LongTensor | None = None, + use_chat_template: bool = True, + add_generation_prompt: bool = True, + gen_length: int = 2048, + block_length: int = 32, + num_inference_steps: int = 32, + temperature: float = 0.0, + top_p: float | None = None, + top_k: int | None = None, + sampling_method: str = "multinomial", + threshold: float = 0.7, + editing_threshold: float | None = 0.5, + max_post_steps: int = 16, + minimal_topk: int = 1, + eos_early_stop: bool = True, + eos_token_id: int | None = None, + mask_token_id: int | None = None, + generator: torch.Generator | None = None, + output_type: str = "text", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int, dict], None] + | PipelineCallback + | MultiPipelineCallbacks + | None = None, + callback_on_step_end_tensor_inputs: list[str] | None = None, + ) -> LLaDA2PipelineOutput | tuple[torch.LongTensor, list[str] | None]: + """ + Generate text with block-wise refinement. + + Args: + prompt (`str` or `List[str]`, *optional*): + Prompt text. 
When `use_chat_template` is `True` (default) and a tokenizer with a chat template is + available, the prompt is wrapped in a chat message before tokenization. + messages (`List[Dict[str, str]]`, *optional*): + Chat messages to encode (e.g. `[{"role": "user", "content": "Hello"}]`). Takes precedence over `prompt` + when provided. Requires a tokenizer with `apply_chat_template`. + input_ids (`torch.LongTensor`, *optional*): + Pre-tokenized input IDs. Takes precedence over `prompt` and `messages`. + use_chat_template (`bool`, defaults to `True`): + Whether to wrap the prompt in a chat template. + add_generation_prompt (`bool`, defaults to `True`): + Whether to add the generation prompt when using chat templates. + gen_length (`int`): + Number of tokens to generate. + block_length (`int`): + Block size for refinement. + num_inference_steps (`int`): + Number of refinement steps per block. + temperature (`float`): + Sampling temperature. + top_p (`float`, *optional*): + Nucleus sampling cutoff. + top_k (`int`, *optional*): + Top-k sampling cutoff. + sampling_method (`str`): + Sampling method (`auto`, `greedy`, `multinomial`). + threshold (`float`): + Confidence threshold for committing tokens. + editing_threshold (`float`, *optional*): + Confidence threshold for editing already-committed (non-mask) tokens. When set, after all mask tokens + in a block are resolved, the pipeline continues refining: if the model predicts a different token with + confidence above this threshold, the existing token is replaced. Set to `None` or a negative value to + disable editing. Defaults to `0.5`. + max_post_steps (`int`): + Maximum number of additional refinement iterations after all mask tokens in a block are resolved. Only + used when `editing_threshold` is enabled. Defaults to `16`. + minimal_topk (`int`): + Minimum number of tokens to commit per step. + eos_early_stop (`bool`): + Whether to stop after committing EOS in a block. 
+ eos_token_id (`int`, *optional*): + EOS token ID to use for early stopping. + mask_token_id (`int`, *optional*): + Mask token ID to use for the template. + generator (`torch.Generator`, *optional*): + RNG for sampling. + output_type (`str`, defaults to `"text"`): + Output format. `"text"` decodes sequences into strings (requires a tokenizer). `"seq"` returns raw + token ID sequences only. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`LLaDA2PipelineOutput`] instead of a tuple. + callback_on_step_end (`Callable` or `PipelineCallback`, *optional*): + Callback executed after each refinement step with signature `callback_on_step_end(self, step: int, + timestep: int, callback_kwargs: Dict)`. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + Tensor keys to pass to the callback. Allowed keys: `block_x`, `x0`, `x0_p`, `transfer_index`, + `confidence`, `active_block`. + + Examples: + """ + # 1. Check inputs early + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["block_x"] + + self.check_inputs( + prompt=prompt, + messages=messages, + input_ids=input_ids, + gen_length=gen_length, + block_length=block_length, + num_inference_steps=num_inference_steps, + minimal_topk=minimal_topk, + threshold=threshold, + sampling_method=sampling_method, + output_type=output_type, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + # 2. 
Prepare input IDs from prompt/messages/input_ids + prompt_ids = self._prepare_input_ids( + prompt=prompt, + messages=messages, + input_ids=input_ids, + use_chat_template=use_chat_template, + add_generation_prompt=add_generation_prompt, + chat_template_kwargs=None, + ) + + device = self._execution_device + + if prompt_ids.ndim == 1: + prompt_ids = prompt_ids.unsqueeze(0) + prompt_ids = prompt_ids.to(device=device) + batch_size, prompt_length = prompt_ids.shape + + if eos_token_id is None: + eos_token_id = self.eos_token_id + if mask_token_id is None: + mask_token_id = self.mask_token_id + if mask_token_id is None: + raise ValueError("`mask_token_id` must be provided (or available on the tokenizer).") + + num_inference_steps = min(num_inference_steps, gen_length // minimal_topk) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # 3. Build attention mask and position IDs + num_blocks = (prompt_length + gen_length + block_length - 1) // block_length + total_length = num_blocks * block_length + + # 2D attention mask (no padding) — the model handles backend-specific conversion internally. + attn_mask = torch.ones((batch_size, total_length), device=device, dtype=torch.long) + + position_ids = torch.arange(total_length, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1) + + # 4. Prepare latents (fully masked sequence) + x = torch.full((batch_size, total_length), mask_token_id, device=device, dtype=torch.long) + if prompt_length > 0: + x[:, :prompt_length] = prompt_ids + + prefill_blocks = prompt_length // block_length + self._num_timesteps = num_inference_steps * max(num_blocks - prefill_blocks, 0) + + finished = torch.zeros((batch_size,), device=device, dtype=torch.bool) + editing_enabled = editing_threshold is not None and editing_threshold >= 0.0 + global_step = 0 + + # 5. 
Block-wise refinement loop + block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy() + block_progress_bar_config["position"] = 0 + block_progress_bar_config["desc"] = "Blocks" + for num_block in tqdm(range(prefill_blocks, num_blocks), **block_progress_bar_config): + current_window_end = (num_block + 1) * block_length + block_x = x[:, :current_window_end] + block_attn_mask = attn_mask[:, :current_window_end] + block_position_ids = position_ids[:, :current_window_end] + + # Identify which positions in the block are prompt (non-editable). + block_start_pos = num_block * block_length + prompt_mask_in_block = torch.zeros(block_length, device=device, dtype=torch.bool) + if block_start_pos < prompt_length: + prompt_end_in_block = min(prompt_length - block_start_pos, block_length) + prompt_mask_in_block[:prompt_end_in_block] = True + + post_steps = 0 + step_idx = 0 + should_continue = True + self.set_progress_bar_config(position=1, leave=False, desc=f"Block {num_block} Inference Steps") + progress_bar = self.progress_bar(total=num_inference_steps) + + while should_continue: + block_tokens = block_x[:, -block_length:] + masks_remaining = (block_tokens == mask_token_id).any() + + if not masks_remaining: + post_steps += 1 + + logits = self.model(block_x, attention_mask=block_attn_mask, position_ids=block_position_ids).logits + block_logits = logits[:, -block_length:, :] + + scheduler_output = self.scheduler.step( + model_output=block_logits, + timestep=step_idx, + sample=block_tokens, + mask_token_id=mask_token_id, + temperature=temperature, + top_p=top_p, + top_k=top_k, + sampling_method=sampling_method, + threshold=threshold, + editing_threshold=editing_threshold, + minimal_topk=minimal_topk, + prompt_mask=prompt_mask_in_block, + generator=generator, + return_dict=True, + ) + + transfer_index = scheduler_output.transfer_index + editing_transfer_index = scheduler_output.editing_transfer_index + final_transfer = transfer_index | editing_transfer_index + + 
if final_transfer.any(): + block_x[:, -block_length:] = scheduler_output.prev_sample + + if eos_early_stop and eos_token_id is not None: + finished = self.scheduler.check_eos_finished( + cur_x=block_x, + sampled_tokens=scheduler_output.sampled_tokens, + final_transfer=final_transfer, + finished=finished, + eos_token_id=eos_token_id, + mask_token_id=mask_token_id, + prompt_length=prompt_length, + ) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, global_step, step_idx, callback_kwargs) + block_x = callback_outputs.pop("block_x", block_x) + + global_step += 1 + if masks_remaining: + step_idx += 1 + progress_bar.update(1) + + should_continue = self.scheduler.check_block_should_continue( + step_idx=step_idx, + masks_remaining=masks_remaining, + editing_enabled=editing_enabled, + editing_transfer_index=editing_transfer_index, + post_steps=post_steps, + max_post_steps=max_post_steps, + finished=finished, + ) + + progress_bar.close() + x[:, :current_window_end] = block_x + if eos_early_stop and finished.all(): + break + + # 6. 
Post-process output + generated = x[:, : prompt_length + gen_length] + sequences = generated[:, prompt_length:] + if eos_token_id is not None and batch_size == 1: + eos_positions = (sequences[0] == eos_token_id).nonzero(as_tuple=True)[0] + if len(eos_positions) > 0: + sequences = sequences[:, : int(eos_positions[0].item()) + 1] + + texts = None + if output_type == "text" and self.tokenizer is not None: + texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) + + if not return_dict: + return sequences.to(device=device), texts + return LLaDA2PipelineOutput(sequences=sequences.to(device=device), texts=texts) + + +__all__ = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"] diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index c7101d1b0401..b1f75bed7dc5 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -40,6 +40,7 @@ else: _import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"] _import_structure["scheduling_amused"] = ["AmusedScheduler"] + _import_structure["scheduling_block_refinement"] = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"] _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] _import_structure["scheduling_ddim"] = ["DDIMScheduler"] @@ -145,6 +146,7 @@ else: from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler from .scheduling_amused import AmusedScheduler + from .scheduling_block_refinement import BlockRefinementScheduler, BlockRefinementSchedulerOutput from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler from .scheduling_ddim import DDIMScheduler diff --git a/src/diffusers/schedulers/scheduling_block_refinement.py b/src/diffusers/schedulers/scheduling_block_refinement.py new file mode 100644 index 
000000000000..5717cee7f8a8 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_block_refinement.py @@ -0,0 +1,459 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class BlockRefinementSchedulerOutput(BaseOutput): + """ + Output class for block refinement scheduling. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch_size, block_length)`): + Updated block tokens after the current refinement step. + transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`): + Boolean mask indicating which tokens were committed (mask-filling). + editing_transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`): + Boolean mask indicating which tokens were edited (non-mask replacement). + sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`): + Sampled token IDs from the model logits. + sampled_probs (`torch.Tensor` of shape `(batch_size, block_length)`): + Probabilities of the sampled tokens. 
+ """ + + prev_sample: torch.LongTensor + transfer_index: torch.BoolTensor + editing_transfer_index: torch.BoolTensor + sampled_tokens: torch.LongTensor + sampled_probs: torch.Tensor + + +class BlockRefinementScheduler(SchedulerMixin, ConfigMixin): + """ + Scheduler for block-wise iterative refinement (commit-by-confidence). + + At each step, the scheduler samples candidate tokens from model logits and commits those with the highest + confidence. The number of tokens to commit per step is determined by evenly distributing the block length across + the number of refinement steps. + + Optionally supports editing: after all mask tokens are resolved, tokens can be replaced if the model predicts a + different token with confidence above `editing_threshold`. + """ + + order = 1 + + @register_to_config + def __init__( + self, + block_length: int = 32, + num_inference_steps: int = 32, + threshold: float = 0.95, + editing_threshold: float | None = None, + minimal_topk: int = 1, + ): + self.num_inference_steps = num_inference_steps + self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, dtype=torch.long) + self._transfer_schedule: torch.LongTensor | None = None + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + self.num_inference_steps = num_inference_steps + self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long) + self._transfer_schedule = self.get_num_transfer_tokens(self.config.block_length, self.num_inference_steps).to( + device=device if device is not None else "cpu" + ) + + def get_num_transfer_tokens(self, block_length: int, num_inference_steps: int) -> torch.LongTensor: + """Evenly distribute `block_length` token commits across `num_inference_steps` steps.""" + if num_inference_steps <= 0: + return torch.zeros((0,), dtype=torch.long) + 
base = block_length // num_inference_steps + remainder = block_length % num_inference_steps + out = torch.full((num_inference_steps,), base, dtype=torch.long) + out[:remainder] += 1 + return out + + # --- SAR sampling utilities --- + + @staticmethod + def _top_p_filtering(logits: torch.Tensor, top_p: float | None) -> torch.Tensor: + """Nucleus (top-p) logit filtering.""" + if top_p is None or top_p >= 1.0: + return logits + if not (0.0 < top_p <= 1.0): + raise ValueError(f"`top_p` must be in (0, 1], got {top_p}.") + + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + sorted_probs = torch.softmax(sorted_logits, dim=-1) + cumulative_probs = sorted_probs.cumsum(dim=-1) + + sorted_indices_to_remove = cumulative_probs > float(top_p) + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, torch.finfo(sorted_logits.dtype).min) + filtered = logits.scatter(-1, sorted_indices, sorted_logits) + return filtered + + @staticmethod + def _top_k_filtering(logits: torch.Tensor, top_k: int | None) -> torch.Tensor: + """Top-k logit filtering.""" + if top_k is None or top_k <= 0: + return logits + if top_k >= logits.shape[-1]: + return logits + values, _ = torch.topk(logits, k=top_k, dim=-1) + min_keep = values[..., -1, None] + return logits.masked_fill(logits < min_keep, torch.finfo(logits.dtype).min) + + @staticmethod + def _sample_from_logits( + logits: torch.Tensor, + *, + temperature: float, + top_k: int | None, + top_p: float | None, + generator: torch.Generator | None, + use_multinomial: bool, + ) -> tuple[torch.LongTensor, torch.Tensor]: + """Sample tokens from logits with temperature scaling, top-k, and top-p.""" + if temperature < 0: + raise ValueError(f"`temperature` must be >= 0, got {temperature}.") + + vocab_size = logits.shape[-1] + flat_logits = logits.reshape(-1, vocab_size) + + if temperature == 0.0 or not 
use_multinomial: + probs = torch.softmax(flat_logits.float(), dim=-1) + token = flat_logits.argmax(dim=-1, keepdim=True) + token_prob = torch.gather(probs, -1, token) + return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1]) + + scaled = flat_logits + if temperature != 1.0: + scaled = flat_logits / temperature + + filtered = BlockRefinementScheduler._top_k_filtering(scaled, top_k=top_k) + filtered = BlockRefinementScheduler._top_p_filtering(filtered, top_p=top_p) + + probs = torch.softmax(filtered.float(), dim=-1) + token = torch.multinomial(probs, num_samples=1, generator=generator) + token_prob = torch.gather(probs, -1, token) + + return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1]) + + def step( + self, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.LongTensor, + *, + mask_token_id: int, + temperature: float = 0.0, + top_p: float | None = None, + top_k: int | None = None, + sampling_method: str = "auto", + threshold: float | None = None, + editing_threshold: float | None = None, + minimal_topk: int | None = None, + prompt_mask: torch.BoolTensor | None = None, + generator: torch.Generator | None = None, + return_dict: bool = True, + ) -> ( + BlockRefinementSchedulerOutput + | tuple[torch.LongTensor, torch.BoolTensor, torch.BoolTensor, torch.LongTensor, torch.Tensor] + ): + """ + Perform a single refinement step: sample from logits, commit confident tokens, and optionally edit existing + ones. + + Args: + model_output (`torch.Tensor` of shape `(batch_size, block_length, vocab_size)`): + Raw logits from the model for the current block. + timestep (`int` or `torch.Tensor`): + Current step index within the block's refinement schedule. + sample (`torch.LongTensor` of shape `(batch_size, block_length)`): + Current block token IDs (contains mask tokens for uncommitted positions). + mask_token_id (`int`): + Token ID used for masked positions. + temperature (`float`): + Sampling temperature. 
+ top_p (`float`, *optional*): + Nucleus sampling cutoff. + top_k (`int`, *optional*): + Top-k sampling cutoff. + sampling_method (`str`): + Sampling method (`auto`, `greedy`, `multinomial`). + threshold (`float`, *optional*): + Confidence threshold for committing tokens. Defaults to config value. + editing_threshold (`float`, *optional*): + Confidence threshold for editing non-mask tokens. Defaults to config value. + minimal_topk (`int`, *optional*): + Minimum tokens to commit per step. Defaults to config value. + prompt_mask (`torch.BoolTensor`, *optional*): + Boolean mask of shape `(block_length,)` where `True` marks prompt (non-editable) positions. + generator (`torch.Generator`, *optional*): + RNG for sampling. + return_dict (`bool`): + Whether to return a `BlockRefinementSchedulerOutput` or a tuple. + """ + if threshold is None: + threshold = float(self.config.threshold) + if editing_threshold is None: + editing_threshold = self.config.editing_threshold + if minimal_topk is None: + minimal_topk = self.config.minimal_topk + + # Sample from logits + use_multinomial = sampling_method == "multinomial" or (sampling_method == "auto" and temperature != 0.0) + sampled_tokens, sampled_probs = self._sample_from_logits( + model_output, + temperature=temperature, + top_k=top_k, + top_p=top_p, + generator=generator, + use_multinomial=use_multinomial, + ) + + batch_size, block_length = sample.shape + active_block = sample == mask_token_id + masks_remaining = active_block.any() + + if isinstance(timestep, torch.Tensor): + step_index = int(timestep.item()) + else: + step_index = int(timestep) + + # --- Mask-filling transfer --- + transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool) + if masks_remaining and self._transfer_schedule is not None: + clamped_step = min(step_index, len(self._transfer_schedule) - 1) + num_to_transfer = int(self._transfer_schedule[clamped_step].item()) + + confidence = torch.where( + active_block, + 
sampled_probs.to(dtype=torch.float32), + torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32), + ) + + for b in range(batch_size): + high_conf = confidence[b] > threshold + if high_conf.sum().item() >= num_to_transfer: + transfer_index[b] = high_conf + else: + k = min(num_to_transfer, int(active_block[b].sum().item())) + if k > 0: + _, idx = torch.topk(confidence[b], k=k) + transfer_index[b, idx] = True + + # --- Editing transfer (non-mask, non-prompt positions) --- + editing_enabled = editing_threshold is not None and editing_threshold >= 0.0 + editing_transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool) + if editing_enabled: + if prompt_mask is None: + prompt_mask = torch.zeros(block_length, device=sample.device, dtype=torch.bool) + editable = (~active_block) & (~prompt_mask.unsqueeze(0)) + editing_conf = torch.where( + editable, + sampled_probs.to(dtype=torch.float32), + torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32), + ) + high_conf_edit = editing_conf > float(editing_threshold) + token_changed = sampled_tokens != sample + editing_transfer_index = high_conf_edit & token_changed & editable + + # Apply transfers + final_transfer = transfer_index | editing_transfer_index + prev_sample = sample.clone() + if final_transfer.any(): + prev_sample[final_transfer] = sampled_tokens[final_transfer] + + if not return_dict: + return prev_sample, transfer_index, editing_transfer_index, sampled_tokens, sampled_probs + return BlockRefinementSchedulerOutput( + prev_sample=prev_sample, + transfer_index=transfer_index, + editing_transfer_index=editing_transfer_index, + sampled_tokens=sampled_tokens, + sampled_probs=sampled_probs, + ) + + @staticmethod + def check_eos_finished( + cur_x: torch.LongTensor, + sampled_tokens: torch.LongTensor, + final_transfer: torch.BoolTensor, + finished: torch.BoolTensor, + eos_token_id: int, + mask_token_id: int, + prompt_length: int, + ) -> torch.BoolTensor: + """ + Update per-batch finished flags when 
EOS tokens are committed. + + Args: + cur_x (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Current full sequence including all blocks up to the current window. + sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`): + Tokens sampled by the scheduler in this step. + final_transfer (`torch.BoolTensor` of shape `(batch_size, block_length)`): + Combined mask of committed and edited positions. + finished (`torch.BoolTensor` of shape `(batch_size,)`): + Current per-batch finished flags. + eos_token_id (`int`): + EOS token ID. + mask_token_id (`int`): + Mask token ID. + prompt_length (`int`): + Number of prompt tokens at the start of the sequence. + + Returns: + `torch.BoolTensor`: Updated finished flags. + """ + batch_size = cur_x.shape[0] + for b in range(batch_size): + if finished[b]: + continue + eos_in_commits = (sampled_tokens[b][final_transfer[b]] == eos_token_id).any().item() + if not eos_in_commits: + continue + eos_pos = (cur_x[b] == eos_token_id).nonzero(as_tuple=True) + if len(eos_pos[0]) == 0: + continue + eos_pos = int(eos_pos[0][0].item()) + if prompt_length >= eos_pos: + continue + if (cur_x[b, prompt_length:eos_pos] != mask_token_id).all().item(): + finished[b] = True + return finished + + def check_block_should_continue( + self, + step_idx: int, + masks_remaining: bool, + editing_enabled: bool, + editing_transfer_index: torch.BoolTensor, + post_steps: int, + max_post_steps: int, + finished: torch.BoolTensor, + ) -> bool: + """ + Determine whether the inner refinement loop should continue for the current block. + + Args: + step_idx (`int`): + Current refinement step index within this block. + masks_remaining (`bool`): + Whether any mask tokens remain in the block. + editing_enabled (`bool`): + Whether editing mode is active. + editing_transfer_index (`torch.BoolTensor`): + Which tokens were edited in this step. + post_steps (`int`): + Number of post-mask editing steps taken so far. 
+ max_post_steps (`int`): + Maximum allowed post-mask editing steps. + finished (`torch.BoolTensor`): + Per-batch finished flags (from EOS detection). + + Returns: + `bool`: `True` if refinement should continue, `False` to break. + """ + if finished.all(): + return False + if not masks_remaining and not editing_enabled: + return False + if not masks_remaining and not editing_transfer_index.any(): + return False + if masks_remaining and step_idx >= self.num_inference_steps: + return False + if not masks_remaining and post_steps > max_post_steps: + return False + return True + + def add_noise( + self, + original_samples: torch.LongTensor, + attention_mask: torch.LongTensor, + *, + prompt_length: int, + block_length: int, + mask_token_id: int, + generator: torch.Generator | None = None, + ) -> tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]: + """ + Apply the forward (noising) process for semi-autoregressive block masking. + + For each block after the prompt, a random fraction of valid (non-padding) tokens are replaced with + `mask_token_id`. Two complementary views are returned: `noisy` and `noisy_rev`, where the masked positions in + one are the unmasked positions in the other. + + Args: + original_samples (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Clean token IDs. + attention_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Padding mask (1 for valid, 0 for padding). + prompt_length (`int`): + Number of leading prompt tokens to keep unmasked. + block_length (`int`): + Block size for masking. + mask_token_id (`int`): + Token ID to use for masked positions. + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + + Returns: + `tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]`: + `(noisy, noisy_rev, masked, masked_rev)` — the two complementary noisy sequences and their + corresponding boolean masks. 
+ """ + batch_size, seq_len = original_samples.shape + device = original_samples.device + + noisy = original_samples.clone() + noisy_rev = original_samples.clone() + masked = torch.zeros_like(original_samples, dtype=torch.bool) + masked_rev = torch.zeros_like(original_samples, dtype=torch.bool) + + valid = attention_mask.to(dtype=torch.bool) + for block_start in range(prompt_length, seq_len, block_length): + block_end = min(seq_len, block_start + block_length) + seg_len = block_end - block_start + if seg_len <= 0: + continue + + p_mask = torch.rand((batch_size, 1), device=device, generator=generator) + seg = torch.rand((batch_size, seg_len), device=device, generator=generator) < p_mask + seg = seg & valid[:, block_start:block_end] + seg_rev = (~seg) & valid[:, block_start:block_end] + + masked[:, block_start:block_end] = seg + masked_rev[:, block_start:block_end] = seg_rev + + noisy = torch.where(masked, torch.full_like(noisy, mask_token_id), noisy) + noisy_rev = torch.where(masked_rev, torch.full_like(noisy_rev, mask_token_id), noisy_rev) + return noisy, noisy_rev, masked, masked_rev + + +__all__ = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"] diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 6c07a30c2ccc..080f852e2490 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -11,6 +11,7 @@ import numpy as np import torch +import torch.nn.functional as F if getattr(torch, "distributed", None) is not None: @@ -109,6 +110,92 @@ def compute_snr(noise_scheduler, timesteps): return snr +def compute_confidence_aware_loss( + logits: torch.Tensor, + labels: torch.Tensor, + *, + lambda_conf: float = 0.0, + temperature: float = 1.0, + per_token_weights: torch.Tensor | None = None, + ignore_index: int = -100, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes a confidence-aware training loss for token classification-style heads. 
+ + This loss combines: + - `loss_sft`: standard supervised cross-entropy on all non-ignored labels. + - `loss_conf`: an entropy penalty applied only on tokens that are already predicted correctly. + + Args: + logits (`torch.Tensor`): Logits of shape `(..., vocab_size)`. + labels (`torch.Tensor`): Labels of shape `(...)`, matching `logits.shape[:-1]`. Values set to `ignore_index` + are excluded from both losses. + lambda_conf (`float`, *optional*, defaults to `0.0`): Weight for the confidence term. + temperature (`float`, *optional*, defaults to `1.0`): Temperature used for the entropy term only. Lower values + sharpen the distribution and change the strength of the confidence gradients. + per_token_weights (`torch.Tensor`, *optional*): Optional weights of shape `(...)` to reweight both losses per + token (e.g. schedule-aware weights). Tokens with weight `0` contribute nothing. + ignore_index (`int`, *optional*, defaults to `-100`): Ignore index for labels. + + Returns: + `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: `(loss, loss_sft, loss_conf)`. + """ + if logits.ndim < 2: + raise ValueError(f"`logits` must have at least 2 dims, got shape {tuple(logits.shape)}.") + if labels.shape != logits.shape[:-1]: + raise ValueError( + f"`labels` shape must match `logits.shape[:-1]`, got labels={tuple(labels.shape)} logits={tuple(logits.shape)}." + ) + if temperature <= 0: + raise ValueError(f"`temperature` must be > 0, got {temperature}.") + + valid = labels.ne(ignore_index) + if per_token_weights is None: + weights = torch.ones_like(labels, dtype=logits.dtype) + else: + if per_token_weights.shape != labels.shape: + raise ValueError( + f"`per_token_weights` shape must match `labels` shape, got {tuple(per_token_weights.shape)} != {tuple(labels.shape)}." + ) + weights = per_token_weights.to(dtype=logits.dtype) + + # Supervised CE (optionally weighted). 
+ vocab_size = logits.shape[-1] + per_token_nll = F.cross_entropy( + logits.reshape(-1, vocab_size), + labels.reshape(-1), + reduction="none", + ignore_index=ignore_index, + ).reshape_as(labels) + + denom_sft = (weights * valid.to(weights.dtype)).sum().clamp_min(1) + loss_sft = (per_token_nll * weights * valid.to(per_token_nll.dtype)).sum() / denom_sft + + # Confidence loss: penalize entropy only where prediction is already correct. + if lambda_conf == 0.0: + loss_conf = torch.zeros((), device=logits.device, dtype=loss_sft.dtype) + return loss_sft, loss_sft, loss_conf + + with torch.no_grad(): + pred = logits.argmax(dim=-1) + correct = valid & pred.eq(labels) + + scaled_logits = logits.float() + if temperature != 1.0: + scaled_logits = scaled_logits / float(temperature) + + probs = torch.softmax(scaled_logits, dim=-1) + eps = torch.finfo(probs.dtype).tiny + log_probs = torch.log(probs.clamp_min(eps)) + entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype) + + denom_conf = (weights * correct.to(weights.dtype)).sum().clamp_min(1) + loss_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / denom_conf + + loss = loss_sft + float(lambda_conf) * loss_conf + return loss, loss_sft, loss_conf + + def resolve_interpolation_mode(interpolation_type: str): """ Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. 
The diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index c41410d153c9..fa37388fe75a 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2518,6 +2518,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class BlockRefinementScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class BlockRefinementSchedulerOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class CMStochasticIterativeScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 2ec5bc002f41..1e4d14566160 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2222,6 +2222,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LLaDA2Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class 
LLaDA2PipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LongCatImageEditPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/others/test_training.py b/tests/others/test_training.py index 2038a98a813e..d8e86984ef1e 100644 --- a/tests/others/test_training.py +++ b/tests/others/test_training.py @@ -18,7 +18,7 @@ import torch from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel -from diffusers.training_utils import set_seed +from diffusers.training_utils import compute_confidence_aware_loss, set_seed from ..testing_utils import slow @@ -85,3 +85,47 @@ def test_training_step_equality(self): self.assertTrue(torch.allclose(ddpm_noisy_images, ddim_noisy_images, atol=1e-5)) self.assertTrue(torch.allclose(ddpm_noise_pred, ddim_noise_pred, atol=1e-5)) + + def test_confidence_aware_loss(self): + logits = torch.tensor([[[5.0, 0.0], [0.0, 5.0]]]) + labels = torch.tensor([[0, 0]]) + weights = torch.tensor([[1.0, 2.0]]) + + loss, loss_sft, loss_conf = compute_confidence_aware_loss( + logits, labels, lambda_conf=0.0, per_token_weights=weights + ) + self.assertTrue(torch.allclose(loss, loss_sft)) + self.assertTrue(torch.allclose(loss_conf, torch.zeros_like(loss_conf))) + + lambda_conf = 0.25 + loss, loss_sft, loss_conf = compute_confidence_aware_loss( + logits, labels, lambda_conf=lambda_conf, per_token_weights=weights + ) + + # Manual expected values for the small 2-class case. 
+ per_token_nll = torch.nn.functional.cross_entropy(logits.view(-1, 2), labels.view(-1), reduction="none").view( + 1, 2 + ) + expected_sft = (per_token_nll * weights).sum() / weights.sum() + + pred = logits.argmax(dim=-1) + correct = pred.eq(labels) + log_probs = torch.log_softmax(logits.float(), dim=-1) + probs = log_probs.exp() + entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype) + expected_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / ( + weights * correct.to(weights.dtype) + ).sum().clamp_min(1) + + expected = expected_sft + lambda_conf * expected_conf + self.assertTrue(torch.allclose(loss_sft, expected_sft)) + self.assertTrue(torch.allclose(loss_conf, expected_conf)) + self.assertTrue(torch.allclose(loss, expected)) + + # Temperature affects only the confidence term. + loss_t, loss_sft_t, loss_conf_t = compute_confidence_aware_loss( + logits, labels, lambda_conf=lambda_conf, temperature=0.5, per_token_weights=weights + ) + self.assertTrue(torch.allclose(loss_sft_t, expected_sft)) + self.assertFalse(torch.allclose(loss_conf_t, expected_conf)) + self.assertTrue(torch.allclose(loss_t, loss_sft_t + lambda_conf * loss_conf_t)) diff --git a/tests/pipelines/llada2/__init__.py b/tests/pipelines/llada2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/llada2/test_llada2.py b/tests/pipelines/llada2/test_llada2.py new file mode 100644 index 000000000000..c3511918fe67 --- /dev/null +++ b/tests/pipelines/llada2/test_llada2.py @@ -0,0 +1,245 @@ +import unittest + +import torch + +from diffusers import BlockRefinementScheduler, LLaDA2Pipeline + + +class _DummyModelOutput: + def __init__(self, logits): + self.logits = logits + + +class _DummyCausalLM(torch.nn.Module): + def __init__(self, vocab_size: int): + super().__init__() + self.vocab_size = int(vocab_size) + self.register_buffer("_device_anchor", torch.empty(0)) + + @property + def dtype(self): + return torch.float32 + + @property + def 
device(self): + return self._device_anchor.device + + def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs): + batch_size, seq_len = input_ids.shape + logits = torch.zeros((batch_size, seq_len, self.vocab_size), device=input_ids.device, dtype=torch.float32) + + # Make confidence vary with token position so top-k commits are deterministic. + positions = torch.arange(seq_len, device=input_ids.device, dtype=torch.float32).view(1, seq_len, 1) + token_ids = (torch.arange(seq_len, device=input_ids.device) % (self.vocab_size - 2)).view(1, seq_len, 1) + logits.scatter_(2, token_ids.expand(batch_size, -1, -1), 1.0 + positions.expand(batch_size, -1, -1) * 0.1) + return _DummyModelOutput(logits=logits) + + +def _make_pipeline(tokenizer=None): + model = _DummyCausalLM(vocab_size=32) + scheduler = BlockRefinementScheduler() + return LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer) + + +class LLaDA2PipelineTest(unittest.TestCase): + def test_pipeline_runs(self): + pipe = _make_pipeline().to("cpu") + + input_ids = torch.tensor([[5, 6, 7, 8], [1, 2, 3, 4]], dtype=torch.long) + out = pipe( + input_ids=input_ids, + use_chat_template=False, + gen_length=24, + block_length=8, + num_inference_steps=8, + temperature=0.0, + threshold=2.0, # force top-k commits + minimal_topk=1, + eos_early_stop=False, + mask_token_id=31, + eos_token_id=None, + output_type="seq", + ) + + self.assertEqual(out.sequences.shape, (2, 24)) + self.assertFalse((out.sequences == 31).any().item()) + + def test_pipeline_return_tuple(self): + pipe = _make_pipeline().to("cpu") + + input_ids = torch.tensor([[5, 6, 7, 8]], dtype=torch.long) + sequences, texts = pipe( + input_ids=input_ids, + use_chat_template=False, + gen_length=16, + block_length=8, + num_inference_steps=4, + temperature=0.0, + threshold=2.0, + minimal_topk=1, + eos_early_stop=False, + mask_token_id=31, + output_type="seq", + return_dict=False, + ) + + self.assertEqual(sequences.shape, (1, 16)) + 
self.assertIsNone(texts) + + def test_output_type_seq(self): + """output_type='seq' should return sequences but no texts.""" + pipe = _make_pipeline().to("cpu") + + out = pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + use_chat_template=False, + gen_length=16, + block_length=8, + num_inference_steps=4, + temperature=0.0, + threshold=2.0, + minimal_topk=1, + eos_early_stop=False, + mask_token_id=31, + output_type="seq", + ) + + self.assertIsNotNone(out.sequences) + self.assertEqual(out.sequences.shape, (1, 16)) + self.assertIsNone(out.texts) + + def test_output_type_text_without_tokenizer(self): + """output_type='text' without a tokenizer should return texts=None.""" + pipe = _make_pipeline(tokenizer=None).to("cpu") + + out = pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + use_chat_template=False, + gen_length=16, + block_length=8, + num_inference_steps=4, + temperature=0.0, + threshold=2.0, + minimal_topk=1, + eos_early_stop=False, + mask_token_id=31, + output_type="text", + ) + + self.assertIsNotNone(out.sequences) + self.assertIsNone(out.texts) + + def test_output_type_text_with_tokenizer(self): + """output_type='text' with a tokenizer should return decoded texts.""" + tok = type( + "Tok", + (), + { + "eos_token_id": None, + "mask_token_id": 31, + "batch_decode": lambda self, seqs, **kw: [f"decoded_{len(s)}" for s in seqs], + }, + )() + pipe = _make_pipeline(tokenizer=tok).to("cpu") + + out = pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + use_chat_template=False, + gen_length=16, + block_length=8, + num_inference_steps=4, + temperature=0.0, + threshold=2.0, + minimal_topk=1, + eos_early_stop=False, + output_type="text", + ) + + self.assertIsNotNone(out.sequences) + self.assertIsNotNone(out.texts) + self.assertEqual(len(out.texts), 1) + self.assertTrue(out.texts[0].startswith("decoded_")) + + def test_output_type_invalid_raises(self): + """Invalid output_type should raise ValueError.""" + pipe = 
_make_pipeline().to("cpu") + + with self.assertRaises(ValueError): + pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + use_chat_template=False, + gen_length=16, + block_length=8, + num_inference_steps=4, + mask_token_id=31, + output_type="invalid", + ) + + def test_prepare_input_ids_from_tensor(self): + pipe = _make_pipeline() + ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + result = pipe._prepare_input_ids( + prompt=None, + messages=None, + input_ids=ids, + use_chat_template=False, + add_generation_prompt=False, + chat_template_kwargs=None, + ) + self.assertTrue(torch.equal(result, ids)) + + def test_prepare_input_ids_from_1d_tensor(self): + pipe = _make_pipeline() + ids = torch.tensor([1, 2, 3], dtype=torch.long) + result = pipe._prepare_input_ids( + prompt=None, + messages=None, + input_ids=ids, + use_chat_template=False, + add_generation_prompt=False, + chat_template_kwargs=None, + ) + self.assertEqual(result.shape, (1, 3)) + + def test_prepare_input_ids_no_tokenizer_raises(self): + pipe = _make_pipeline(tokenizer=None) + with self.assertRaises(ValueError): + pipe._prepare_input_ids( + prompt="hello", + messages=None, + input_ids=None, + use_chat_template=False, + add_generation_prompt=False, + chat_template_kwargs=None, + ) + + def test_prepare_input_ids_both_prompt_and_messages_raises(self): + pipe = _make_pipeline() + # Manually set tokenizer to a simple object so _prepare_input_ids doesn't short-circuit + pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})() + with self.assertRaises(ValueError): + pipe._prepare_input_ids( + prompt="hello", + messages=[{"role": "user", "content": "hi"}], + input_ids=None, + use_chat_template=False, + add_generation_prompt=False, + chat_template_kwargs=None, + ) + + def test_prepare_input_ids_neither_raises(self): + pipe = _make_pipeline() + pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})() + with self.assertRaises(ValueError): + 
pipe._prepare_input_ids( + prompt=None, + messages=None, + input_ids=None, + use_chat_template=False, + add_generation_prompt=False, + chat_template_kwargs=None, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/schedulers/test_scheduler_block_refinement.py b/tests/schedulers/test_scheduler_block_refinement.py new file mode 100644 index 000000000000..2e5e404e5f9a --- /dev/null +++ b/tests/schedulers/test_scheduler_block_refinement.py @@ -0,0 +1,470 @@ +import tempfile +import unittest + +import torch + +from diffusers import BlockRefinementScheduler + + +class BlockRefinementSchedulerTest(unittest.TestCase): + def get_scheduler(self, **kwargs): + config = { + "block_length": 32, + "num_inference_steps": 8, + "threshold": 0.95, + "editing_threshold": None, + "minimal_topk": 1, + } + config.update(kwargs) + return BlockRefinementScheduler(**config) + + def _make_logits_from_probs(self, target_probs: torch.Tensor, vocab_size: int = 100) -> torch.Tensor: + """Create logits where softmax of the target token has approximately the given probability.""" + batch_size, block_length = target_probs.shape + logits = torch.zeros(batch_size, block_length, vocab_size) + # Set token 0 as the "predicted" token with a logit proportional to desired probability + for b in range(batch_size): + for t in range(block_length): + p = target_probs[b, t].item() + if p > 0: + logits[b, t, t % (vocab_size - 1)] = 10.0 * p + return logits + + def test_set_timesteps(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(8) + self.assertEqual(scheduler.num_inference_steps, 8) + self.assertEqual(len(scheduler.timesteps), 8) + self.assertEqual(scheduler.timesteps[0].item(), 7) + self.assertEqual(scheduler.timesteps[-1].item(), 0) + + def test_set_timesteps_invalid(self): + scheduler = self.get_scheduler() + with self.assertRaises(ValueError): + scheduler.set_timesteps(0) + + def test_get_num_transfer_tokens_even(self): + scheduler = self.get_scheduler() + schedule 
= scheduler.get_num_transfer_tokens(block_length=32, num_inference_steps=8) + self.assertEqual(schedule.sum().item(), 32) + self.assertEqual(len(schedule), 8) + self.assertTrue((schedule == 4).all().item()) + + def test_get_num_transfer_tokens_remainder(self): + scheduler = self.get_scheduler() + schedule = scheduler.get_num_transfer_tokens(block_length=10, num_inference_steps=3) + self.assertEqual(schedule.sum().item(), 10) + self.assertEqual(len(schedule), 3) + self.assertEqual(schedule[0].item(), 4) + self.assertEqual(schedule[1].item(), 3) + self.assertEqual(schedule[2].item(), 3) + + def test_transfer_schedule_created_on_set_timesteps(self): + scheduler = self.get_scheduler(block_length=16) + scheduler.set_timesteps(4) + self.assertIsNotNone(scheduler._transfer_schedule) + self.assertEqual(scheduler._transfer_schedule.sum().item(), 16) + + def test_save_load_config_round_trip(self): + scheduler = self.get_scheduler(block_length=64, threshold=0.8, editing_threshold=0.5, minimal_topk=2) + with tempfile.TemporaryDirectory() as tmpdir: + scheduler.save_config(tmpdir) + loaded = BlockRefinementScheduler.from_pretrained(tmpdir) + + self.assertEqual(loaded.config.block_length, 64) + self.assertEqual(loaded.config.threshold, 0.8) + self.assertEqual(loaded.config.editing_threshold, 0.5) + self.assertEqual(loaded.config.minimal_topk, 2) + + def test_from_config(self): + scheduler = self.get_scheduler(block_length=16, threshold=0.7) + new_scheduler = BlockRefinementScheduler.from_config(scheduler.config) + self.assertEqual(new_scheduler.config.block_length, 16) + self.assertEqual(new_scheduler.config.threshold, 0.7) + + def test_step_commits_tokens(self): + """Verify that step() commits mask tokens based on confidence.""" + scheduler = self.get_scheduler(block_length=8) + scheduler.set_timesteps(2) + + batch_size, block_length, vocab_size = 1, 8, 32 + mask_id = 31 + + sample = torch.full((batch_size, block_length), mask_id, dtype=torch.long) + # Create logits where 
confidence decreases with position + logits = torch.zeros(batch_size, block_length, vocab_size) + for i in range(block_length): + logits[0, i, i] = 10.0 - i # decreasing confidence + + out = scheduler.step( + model_output=logits, + timestep=0, + sample=sample, + mask_token_id=mask_id, + temperature=0.0, + threshold=0.95, + return_dict=True, + ) + + # With 8 tokens and 2 steps, first step should commit 4 tokens + committed = out.transfer_index[0].sum().item() + self.assertEqual(committed, 4) + + def test_step_no_editing_by_default(self): + """Without editing_threshold, no non-mask tokens should be changed.""" + scheduler = self.get_scheduler(block_length=4) + scheduler.set_timesteps(2) + + vocab_size = 32 + sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long) + logits = torch.zeros(1, 4, vocab_size) + logits[0, :, 15] = 10.0 # predict token 15 for all positions + + out = scheduler.step( + model_output=logits, + timestep=0, + sample=sample, + mask_token_id=31, + temperature=0.0, + editing_threshold=None, + return_dict=True, + ) + + self.assertFalse(out.editing_transfer_index.any().item()) + self.assertFalse(out.transfer_index[0, 0].item()) + self.assertFalse(out.transfer_index[0, 1].item()) + + def test_step_editing_replaces_tokens(self): + """With editing_threshold, non-mask tokens with high confidence and different prediction get replaced.""" + scheduler = self.get_scheduler(block_length=4) + scheduler.set_timesteps(2) + + vocab_size = 32 + sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long) + logits = torch.zeros(1, 4, vocab_size) + # Token 0: predict 50 (different from 10) with very high logit + logits[0, 0, 15] = 20.0 + # Token 1: predict 20 (same as current) + logits[0, 1, 20] = 20.0 + # Mask tokens + logits[0, 2, 5] = 5.0 + logits[0, 3, 6] = 5.0 + + out = scheduler.step( + model_output=logits, + timestep=0, + sample=sample, + mask_token_id=31, + temperature=0.0, + editing_threshold=0.5, + return_dict=True, + ) + + # Token 0 should be edited 
(different prediction, high confidence) + self.assertTrue(out.editing_transfer_index[0, 0].item()) + # Token 1 should NOT be edited (same prediction) + self.assertFalse(out.editing_transfer_index[0, 1].item()) + + def test_step_prompt_mask_prevents_editing(self): + """Prompt positions should never be edited even with editing enabled.""" + scheduler = self.get_scheduler(block_length=4) + scheduler.set_timesteps(2) + + vocab_size = 32 + sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long) + logits = torch.zeros(1, 4, vocab_size) + logits[0, :, 15] = 20.0 + prompt_mask = torch.tensor([True, True, False, False]) + + out = scheduler.step( + model_output=logits, + timestep=0, + sample=sample, + mask_token_id=31, + temperature=0.0, + editing_threshold=0.5, + prompt_mask=prompt_mask, + return_dict=True, + ) + + self.assertFalse(out.editing_transfer_index[0, 0].item()) + self.assertFalse(out.editing_transfer_index[0, 1].item()) + + def test_step_return_tuple(self): + """Verify tuple output when return_dict=False.""" + scheduler = self.get_scheduler(block_length=4) + scheduler.set_timesteps(2) + + vocab_size = 32 + sample = torch.full((1, 4), 31, dtype=torch.long) + logits = torch.randn(1, 4, vocab_size) + + result = scheduler.step( + model_output=logits, + timestep=0, + sample=sample, + mask_token_id=31, + temperature=0.0, + return_dict=False, + ) + + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 5) + + def test_step_batched(self): + """Verify step works with batch_size > 1.""" + scheduler = self.get_scheduler(block_length=4) + scheduler.set_timesteps(2) + + batch_size, vocab_size = 3, 32 + mask_id = 31 + sample = torch.full((batch_size, 4), mask_id, dtype=torch.long) + logits = torch.randn(batch_size, 4, vocab_size) + + out = scheduler.step( + model_output=logits, + timestep=0, + sample=sample, + mask_token_id=mask_id, + temperature=0.0, + return_dict=True, + ) + + self.assertEqual(out.prev_sample.shape, (batch_size, 4)) + 
self.assertEqual(out.transfer_index.shape, (batch_size, 4)) + + def test_check_block_should_continue_finished(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(8) + finished = torch.tensor([True, True]) + result = scheduler.check_block_should_continue( + step_idx=0, + masks_remaining=True, + editing_enabled=False, + editing_transfer_index=torch.zeros(2, 32, dtype=torch.bool), + post_steps=0, + max_post_steps=16, + finished=finished, + ) + self.assertFalse(result) + + def test_check_block_should_continue_no_masks_no_edits(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(8) + finished = torch.tensor([False]) + result = scheduler.check_block_should_continue( + step_idx=5, + masks_remaining=False, + editing_enabled=True, + editing_transfer_index=torch.zeros(1, 32, dtype=torch.bool), + post_steps=1, + max_post_steps=16, + finished=finished, + ) + self.assertFalse(result) + + def test_check_block_should_continue_steps_exhausted(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(8) + finished = torch.tensor([False]) + result = scheduler.check_block_should_continue( + step_idx=8, + masks_remaining=True, + editing_enabled=False, + editing_transfer_index=torch.zeros(1, 32, dtype=torch.bool), + post_steps=0, + max_post_steps=16, + finished=finished, + ) + self.assertFalse(result) + + def test_check_eos_finished_marks_batch(self): + """When EOS is committed and all tokens before it are unmasked, mark batch as finished.""" + mask_id, eos_id, prompt_length = 99, 2, 2 + # cur_x: [prompt, prompt, token, eos, mask, mask] + cur_x = torch.tensor([[10, 11, 5, eos_id, mask_id, mask_id]], dtype=torch.long) + sampled_tokens = torch.tensor([[0, 0, 0, eos_id]], dtype=torch.long) + final_transfer = torch.tensor([[False, False, False, True]]) + finished = torch.tensor([False]) + + finished = BlockRefinementScheduler.check_eos_finished( + cur_x=cur_x, + sampled_tokens=sampled_tokens, + final_transfer=final_transfer, + finished=finished, + 
eos_token_id=eos_id, + mask_token_id=mask_id, + prompt_length=prompt_length, + ) + self.assertTrue(finished[0].item()) + + def test_check_eos_finished_ignores_when_masks_before_eos(self): + """If there are still mask tokens between prompt and EOS, don't mark as finished.""" + mask_id, eos_id, prompt_length = 99, 2, 2 + # cur_x: [prompt, prompt, mask, eos] — mask before EOS + cur_x = torch.tensor([[10, 11, mask_id, eos_id]], dtype=torch.long) + sampled_tokens = torch.tensor([[0, 0]], dtype=torch.long) + final_transfer = torch.tensor([[False, True]]) + finished = torch.tensor([False]) + + finished = BlockRefinementScheduler.check_eos_finished( + cur_x=cur_x, + sampled_tokens=sampled_tokens, + final_transfer=final_transfer, + finished=finished, + eos_token_id=eos_id, + mask_token_id=mask_id, + prompt_length=prompt_length, + ) + self.assertFalse(finished[0].item()) + + def test_check_eos_finished_already_finished(self): + """Already-finished batches should stay finished.""" + mask_id, eos_id = 99, 2 + cur_x = torch.tensor([[10, 11, 5, 6]], dtype=torch.long) + sampled_tokens = torch.tensor([[0, 0]], dtype=torch.long) + final_transfer = torch.tensor([[False, False]]) + finished = torch.tensor([True]) + + finished = BlockRefinementScheduler.check_eos_finished( + cur_x=cur_x, + sampled_tokens=sampled_tokens, + final_transfer=final_transfer, + finished=finished, + eos_token_id=eos_id, + mask_token_id=mask_id, + prompt_length=2, + ) + self.assertTrue(finished[0].item()) + + def test_add_noise(self): + scheduler = self.get_scheduler(block_length=4) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + mask_token_id = 99 + + gen = torch.Generator().manual_seed(42) + noisy, noisy_rev, masked, masked_rev = scheduler.add_noise( + input_ids, + attention_mask, + prompt_length=2, + block_length=4, + mask_token_id=mask_token_id, + generator=gen, + ) + + # Prompt positions should never be masked + 
self.assertFalse(masked[0, 0].item()) + self.assertFalse(masked[0, 1].item()) + self.assertFalse(masked_rev[0, 0].item()) + self.assertFalse(masked_rev[0, 1].item()) + + # Noisy should have mask_token_id where masked is True + self.assertTrue((noisy[masked] == mask_token_id).all().item()) + self.assertTrue((noisy_rev[masked_rev] == mask_token_id).all().item()) + + # masked and masked_rev should be complementary within valid non-prompt positions + non_prompt = torch.zeros_like(masked) + non_prompt[0, 2:] = True + combined = masked | masked_rev + self.assertTrue((combined[0, 2:] == non_prompt[0, 2:]).all().item()) + + +class TestTopPFiltering(unittest.TestCase): + def test_top_p_filtering(self): + logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]]) + filtered = BlockRefinementScheduler._top_p_filtering(logits, top_p=0.5) + self.assertTrue((filtered > torch.finfo(filtered.dtype).min).any()) + self.assertTrue((filtered == torch.finfo(filtered.dtype).min).any()) + + def test_top_p_filtering_none(self): + logits = torch.tensor([[1.0, 2.0, 3.0]]) + result = BlockRefinementScheduler._top_p_filtering(logits, top_p=None) + self.assertTrue(torch.equal(result, logits)) + + def test_top_p_filtering_one(self): + logits = torch.tensor([[1.0, 2.0, 3.0]]) + result = BlockRefinementScheduler._top_p_filtering(logits, top_p=1.0) + self.assertTrue(torch.equal(result, logits)) + + +class TestTopKFiltering(unittest.TestCase): + def test_top_k_filtering(self): + logits = torch.tensor([[1.0, 4.0, 2.0, 3.0]]) + filtered = BlockRefinementScheduler._top_k_filtering(logits, top_k=2) + self.assertAlmostEqual(filtered[0, 1].item(), 4.0) + self.assertAlmostEqual(filtered[0, 3].item(), 3.0) + self.assertEqual(filtered[0, 0].item(), torch.finfo(filtered.dtype).min) + self.assertEqual(filtered[0, 2].item(), torch.finfo(filtered.dtype).min) + + def test_top_k_filtering_none(self): + logits = torch.tensor([[1.0, 2.0, 3.0]]) + result = BlockRefinementScheduler._top_k_filtering(logits, top_k=None) + 
self.assertTrue(torch.equal(result, logits)) + + def test_top_k_filtering_zero(self): + logits = torch.tensor([[1.0, 2.0, 3.0]]) + result = BlockRefinementScheduler._top_k_filtering(logits, top_k=0) + self.assertTrue(torch.equal(result, logits)) + + def test_top_k_filtering_large_k(self): + logits = torch.tensor([[1.0, 2.0, 3.0]]) + result = BlockRefinementScheduler._top_k_filtering(logits, top_k=100) + self.assertTrue(torch.equal(result, logits)) + + +class TestSampleFromLogits(unittest.TestCase): + def test_greedy_sampling(self): + logits = torch.tensor([[1.0, 5.0, 2.0]]) + tokens, probs = BlockRefinementScheduler._sample_from_logits( + logits, + temperature=0.0, + top_k=None, + top_p=None, + generator=None, + use_multinomial=False, + ) + self.assertEqual(tokens.item(), 1) + self.assertEqual(tokens.shape, (1,)) + self.assertEqual(probs.shape, (1,)) + + def test_multinomial_sampling(self): + logits = torch.tensor([[0.0, 100.0, -100.0]]) + gen = torch.Generator().manual_seed(42) + tokens, probs = BlockRefinementScheduler._sample_from_logits( + logits, + temperature=1.0, + top_k=None, + top_p=None, + generator=gen, + use_multinomial=True, + ) + self.assertEqual(tokens.item(), 1) + + def test_temperature_scaling(self): + logits = torch.tensor([[1.0, 2.0, 3.0]]) + tokens, _ = BlockRefinementScheduler._sample_from_logits( + logits, + temperature=0.01, + top_k=None, + top_p=None, + generator=None, + use_multinomial=False, + ) + self.assertEqual(tokens.item(), 2) + + def test_negative_temperature_raises(self): + logits = torch.tensor([[1.0, 2.0]]) + with self.assertRaises(ValueError): + BlockRefinementScheduler._sample_from_logits( + logits, + temperature=-1.0, + top_k=None, + top_p=None, + generator=None, + use_multinomial=False, + ) + + +if __name__ == "__main__": + unittest.main() From 7b2c2acb2c5bfc29d376c97278d0b411271a5a1b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 25 Mar 2026 13:19:31 +0100 Subject: [PATCH 079/215] [LLADA2] documentation fixes (#13333) 
documentation fixes --- docs/source/en/_toctree.yml | 6 +++-- docs/source/en/api/pipelines/llada2.md | 25 ++++++++++++------- .../pipelines/llada2/pipeline_llada2.py | 10 ++++---- .../schedulers/scheduling_block_refinement.py | 7 +++--- 4 files changed, 29 insertions(+), 19 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 394d539350d6..caaba0fa5e51 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -580,8 +580,6 @@ title: Latent Diffusion - local: api/pipelines/ledits_pp title: LEDITS++ - - local: api/pipelines/llada2 - title: LLaDA2 - local: api/pipelines/longcat_image title: LongCat-Image - local: api/pipelines/lumina2 @@ -672,6 +670,10 @@ - local: api/pipelines/z_image title: Z-Image title: Image + - sections: + - local: api/pipelines/llada2 + title: LLaDA2 + title: Text - sections: - local: api/pipelines/allegro title: Allegro diff --git a/docs/source/en/api/pipelines/llada2.md b/docs/source/en/api/pipelines/llada2.md index cf0fa0b0d7b6..94555f615c23 100644 --- a/docs/source/en/api/pipelines/llada2.md +++ b/docs/source/en/api/pipelines/llada2.md @@ -26,7 +26,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from diffusers import BlockRefinementScheduler, LLaDA2Pipeline model_id = "inclusionAI/LLaDA2.1-mini" -model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto") +model = AutoModelForCausalLM.from_pretrained( + model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto" +) tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) scheduler = BlockRefinementScheduler() @@ -46,18 +48,21 @@ print(output.texts[0]) ## Callbacks -Callbacks run after each refinement step and can inspect or modify the current tokens. +Callbacks run after each refinement step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are +included in `callback_kwargs`. 
In the current implementation, `block_x` (the sequence window being refined) and +`transfer_index` (mask-filling commit mask) are provided; return `{"block_x": ...}` from the callback to replace the +window. ```py def on_step_end(pipe, step, timestep, callback_kwargs): - cur_x = callback_kwargs["cur_x"] - # Inspect or modify `cur_x` here. - return {"cur_x": cur_x} + block_x = callback_kwargs["block_x"] + # Inspect or modify `block_x` here. + return {"block_x": block_x} out = pipe( prompt="Write a short poem.", callback_on_step_end=on_step_end, - callback_on_step_end_tensor_inputs=["cur_x"], + callback_on_step_end_tensor_inputs=["block_x"], ) ``` @@ -68,11 +73,13 @@ LLaDA2.1 models support two modes: | Mode | `threshold` | `editing_threshold` | `max_post_steps` | |------|-------------|---------------------|------------------| | Quality | 0.7 | 0.5 | 16 | -| Speed | 0.5 | 0.0 | 16 | +| Speed | 0.5 | `None` | 16 | + +Pass `editing_threshold=None`, `0.0`, or a negative value to turn off post-mask editing. -For LLaDA2.0 models, disable editing by passing `editing_threshold=None`. +For LLaDA2.0 models, disable editing by passing `editing_threshold=None` or `0.0`. -For all models: `block_length=32`, `temperature=0.0`, `steps=32`. +For all models: `block_length=32`, `temperature=0.0`, `num_inference_steps=32`. ## LLaDA2Pipeline [[autodoc]] LLaDA2Pipeline diff --git a/src/diffusers/pipelines/llada2/pipeline_llada2.py b/src/diffusers/pipelines/llada2/pipeline_llada2.py index d4b037ada151..a6ba6e8ff689 100644 --- a/src/diffusers/pipelines/llada2/pipeline_llada2.py +++ b/src/diffusers/pipelines/llada2/pipeline_llada2.py @@ -273,10 +273,10 @@ def __call__( threshold (`float`): Confidence threshold for committing tokens. editing_threshold (`float`, *optional*): - Confidence threshold for editing already-committed (non-mask) tokens. 
When set, after all mask tokens - in a block are resolved, the pipeline continues refining: if the model predicts a different token with - confidence above this threshold, the existing token is replaced. Set to `None` or a negative value to - disable editing. Defaults to `0.5`. + Confidence threshold for editing already-committed (non-mask) tokens. When positive, after all mask + tokens in a block are resolved, the pipeline continues refining: if the model predicts a different + token with confidence above this threshold, the existing token is replaced. Set to `None`, `0.0`, or a + negative value to disable editing. Defaults to `0.5`. max_post_steps (`int`): Maximum number of additional refinement iterations after all mask tokens in a block are resolved. Only used when `editing_threshold` is enabled. Defaults to `16`. @@ -373,7 +373,7 @@ def __call__( self._num_timesteps = num_inference_steps * max(num_blocks - prefill_blocks, 0) finished = torch.zeros((batch_size,), device=device, dtype=torch.bool) - editing_enabled = editing_threshold is not None and editing_threshold >= 0.0 + editing_enabled = editing_threshold is not None and editing_threshold > 0.0 global_step = 0 # 5. Block-wise refinement loop diff --git a/src/diffusers/schedulers/scheduling_block_refinement.py b/src/diffusers/schedulers/scheduling_block_refinement.py index 5717cee7f8a8..296ad1b6a5fe 100644 --- a/src/diffusers/schedulers/scheduling_block_refinement.py +++ b/src/diffusers/schedulers/scheduling_block_refinement.py @@ -57,7 +57,7 @@ class BlockRefinementScheduler(SchedulerMixin, ConfigMixin): the number of refinement steps. Optionally supports editing: after all mask tokens are resolved, tokens can be replaced if the model predicts a - different token with confidence above `editing_threshold`. + different token with confidence above a positive `editing_threshold` (`None`, `0.0`, or negative disables editing). 
""" order = 1 @@ -208,7 +208,8 @@ def step( threshold (`float`, *optional*): Confidence threshold for committing tokens. Defaults to config value. editing_threshold (`float`, *optional*): - Confidence threshold for editing non-mask tokens. Defaults to config value. + Confidence threshold for editing non-mask tokens; must be positive to enable editing. Defaults to + config value. minimal_topk (`int`, *optional*): Minimum tokens to commit per step. Defaults to config value. prompt_mask (`torch.BoolTensor`, *optional*): @@ -268,7 +269,7 @@ def step( transfer_index[b, idx] = True # --- Editing transfer (non-mask, non-prompt positions) --- - editing_enabled = editing_threshold is not None and editing_threshold >= 0.0 + editing_enabled = editing_threshold is not None and editing_threshold > 0.0 editing_transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool) if editing_enabled: if prompt_mask is None: From 182228392234aee803f1a22f60a07038cf87d930 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 25 Mar 2026 21:30:06 +0530 Subject: [PATCH 080/215] [ci] claude in ci. (#13297) * claude in ci. * review feedback. --- .ai/review-rules.md | 11 +++++++++ .github/workflows/claude_review.yml | 38 +++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 .ai/review-rules.md create mode 100644 .github/workflows/claude_review.yml diff --git a/.ai/review-rules.md b/.ai/review-rules.md new file mode 100644 index 000000000000..12efc94c4b61 --- /dev/null +++ b/.ai/review-rules.md @@ -0,0 +1,11 @@ +# PR Review Rules + +Review-specific rules for Claude. Focus on correctness — style is handled by ruff. 
+ +Before reviewing, read and apply the guidelines in: +- [AGENTS.md](AGENTS.md) — coding style, dependencies, copied code, model conventions +- [skills/model-integration/SKILL.md](skills/model-integration/SKILL.md) — attention pattern, pipeline rules, implementation checklist, gotchas +- [skills/parity-testing/SKILL.md](skills/parity-testing/SKILL.md) — testing rules, comparison utilities +- [skills/parity-testing/pitfalls.md](skills/parity-testing/pitfalls.md) — known pitfalls (dtype mismatches, config assumptions, etc.) + +## Common mistakes (add new rules below this line) diff --git a/.github/workflows/claude_review.yml b/.github/workflows/claude_review.yml new file mode 100644 index 000000000000..e772415d6322 --- /dev/null +++ b/.github/workflows/claude_review.yml @@ -0,0 +1,38 @@ +name: Claude PR Review + +on: + issue_comment: + types: [created] + pull_request_review_comment: + types: [created] + +permissions: + contents: write + pull-requests: write + issues: read + +jobs: + claude-review: + if: | + ( + github.event_name == 'issue_comment' && + github.event.issue.pull_request && + github.event.issue.state == 'open' && + contains(github.event.comment.body, '@claude') && + (github.event.comment.author_association == 'MEMBER' || + github.event.comment.author_association == 'OWNER' || + github.event.comment.author_association == 'COLLABORATOR') + ) || ( + github.event_name == 'pull_request_review_comment' && + contains(github.event.comment.body, '@claude') && + (github.event.comment.author_association == 'MEMBER' || + github.event.comment.author_association == 'OWNER' || + github.event.comment.author_association == 'COLLABORATOR') + ) + runs-on: ubuntu-latest + steps: + - uses: anthropics/claude-code-action@v1 + with: + anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} + claude_args: | + --append-system-prompt "Review this PR against the rules in .ai/review-rules.md. Focus on correctness, not style (ruff handles style). 
Only review changes under src/diffusers/. Do NOT commit changes unless the comment explicitly asks you to using the phrase 'commit this'." From 746fe5c0dbb565ff75592c8ea58e3288f7f6e639 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Wed, 25 Mar 2026 09:31:54 -0700 Subject: [PATCH 081/215] [docs] kernels (#13139) * kernels * feedback --- docs/source/en/optimization/fp16.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/source/en/optimization/fp16.md b/docs/source/en/optimization/fp16.md index 941f53604cec..0e427d3a0afb 100644 --- a/docs/source/en/optimization/fp16.md +++ b/docs/source/en/optimization/fp16.md @@ -248,6 +248,24 @@ Refer to the [diffusers/benchmarks](https://huggingface.co/datasets/diffusers/be The [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao#benchmarking-results) repository also contains benchmarking results for compiled versions of Flux and CogVideoX. +## Kernels + +[Kernels](https://huggingface.co/docs/kernels/index) is a library for building, distributing, and loading optimized compute kernels on the [Hub](https://huggingface.co/kernels-community). It supports [attention](./attention_backends#set_attention_backend) kernels and custom CUDA kernels for operations like RMSNorm, GEGLU, RoPE, and AdaLN. + +The [Diffusers Pipeline Integration](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/references/diffusers-integration.md) guide shows how to integrate a kernel with the [add cuda-kernels](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/SKILL.md) skill. This skill enables an agent, like Claude or Codex, to write custom kernels targeted towards a specific model and your hardware. + +> [!TIP] +> Install the [add cuda-kernels](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/SKILL.md) skill to teach an agent how to write a kernel. 
The [Custom kernels for all from Codex and Claude](https://huggingface.co/blog/custom-cuda-kernels-agent-skills) blog post covers this in more detail. + +For example, a custom RMSNorm kernel (generated by the `add cuda-kernels` skill) with [torch.compile](#torchcompile) speeds up LTX-Video generation 1.43x on an H100. + + + ## Dynamic quantization [Dynamic quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) improves inference speed by reducing precision to enable faster math operations. This particular type of quantization determines how to scale the activations based on the data at runtime rather than using a fixed scaling factor. As a result, the scaling factor is more accurately aligned with the data. From b367069836ffbef1375ca108d3bb53abc538abb9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 26 Mar 2026 08:48:16 +0530 Subject: [PATCH 082/215] [tests] Tests for conditional pipeline blocks (#13247) * implement test suite for conditional blocks. * remove * another fix. * Revert "another fix." This reverts commit ab07b603abefcb1a0c6137c222327ff129332c2a. --- .../test_conditional_pipeline_blocks.py | 242 ++++++++++++++++++ .../test_modular_pipelines_common.py | 117 --------- .../test_modular_pipelines_custom_blocks.py | 117 ++++++++- 3 files changed, 358 insertions(+), 118 deletions(-) create mode 100644 tests/modular_pipelines/test_conditional_pipeline_blocks.py diff --git a/tests/modular_pipelines/test_conditional_pipeline_blocks.py b/tests/modular_pipelines/test_conditional_pipeline_blocks.py new file mode 100644 index 000000000000..5d9a7fe7d2d3 --- /dev/null +++ b/tests/modular_pipelines/test_conditional_pipeline_blocks.py @@ -0,0 +1,242 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from diffusers.modular_pipelines import ( + AutoPipelineBlocks, + ConditionalPipelineBlocks, + InputParam, + ModularPipelineBlocks, +) + + +class TextToImageBlock(ModularPipelineBlocks): + model_name = "text2img" + + @property + def inputs(self): + return [InputParam(name="prompt")] + + @property + def intermediate_outputs(self): + return [] + + @property + def description(self): + return "text-to-image workflow" + + def __call__(self, components, state): + block_state = self.get_block_state(state) + block_state.workflow = "text2img" + self.set_block_state(state, block_state) + return components, state + + +class ImageToImageBlock(ModularPipelineBlocks): + model_name = "img2img" + + @property + def inputs(self): + return [InputParam(name="prompt"), InputParam(name="image")] + + @property + def intermediate_outputs(self): + return [] + + @property + def description(self): + return "image-to-image workflow" + + def __call__(self, components, state): + block_state = self.get_block_state(state) + block_state.workflow = "img2img" + self.set_block_state(state, block_state) + return components, state + + +class InpaintBlock(ModularPipelineBlocks): + model_name = "inpaint" + + @property + def inputs(self): + return [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")] + + @property + def intermediate_outputs(self): + return [] + + @property + def description(self): + return "inpaint workflow" + + def __call__(self, components, state): + block_state = self.get_block_state(state) + block_state.workflow = "inpaint" + 
self.set_block_state(state, block_state) + return components, state + + +class ConditionalImageBlocks(ConditionalPipelineBlocks): + block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock] + block_names = ["inpaint", "img2img", "text2img"] + block_trigger_inputs = ["mask", "image"] + default_block_name = "text2img" + + @property + def description(self): + return "Conditional image blocks for testing" + + def select_block(self, mask=None, image=None) -> str | None: + if mask is not None: + return "inpaint" + if image is not None: + return "img2img" + return None # falls back to default_block_name + + +class OptionalConditionalBlocks(ConditionalPipelineBlocks): + block_classes = [InpaintBlock, ImageToImageBlock] + block_names = ["inpaint", "img2img"] + block_trigger_inputs = ["mask", "image"] + default_block_name = None # no default; block can be skipped + + @property + def description(self): + return "Optional conditional blocks (skippable)" + + def select_block(self, mask=None, image=None) -> str | None: + if mask is not None: + return "inpaint" + if image is not None: + return "img2img" + return None + + +class AutoImageBlocks(AutoPipelineBlocks): + block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock] + block_names = ["inpaint", "img2img", "text2img"] + block_trigger_inputs = ["mask", "image", None] + + @property + def description(self): + return "Auto image blocks for testing" + + +class TestConditionalPipelineBlocksSelectBlock: + def test_select_block_with_mask(self): + blocks = ConditionalImageBlocks() + assert blocks.select_block(mask="something") == "inpaint" + + def test_select_block_with_image(self): + blocks = ConditionalImageBlocks() + assert blocks.select_block(image="something") == "img2img" + + def test_select_block_with_mask_and_image(self): + blocks = ConditionalImageBlocks() + assert blocks.select_block(mask="m", image="i") == "inpaint" + + def test_select_block_no_triggers_returns_none(self): + blocks = 
ConditionalImageBlocks() + assert blocks.select_block() is None + + def test_select_block_explicit_none_values(self): + blocks = ConditionalImageBlocks() + assert blocks.select_block(mask=None, image=None) is None + + +class TestConditionalPipelineBlocksWorkflowSelection: + def test_default_workflow_when_no_triggers(self): + blocks = ConditionalImageBlocks() + execution = blocks.get_execution_blocks() + assert execution is not None + assert isinstance(execution, TextToImageBlock) + + def test_mask_trigger_selects_inpaint(self): + blocks = ConditionalImageBlocks() + execution = blocks.get_execution_blocks(mask=True) + assert isinstance(execution, InpaintBlock) + + def test_image_trigger_selects_img2img(self): + blocks = ConditionalImageBlocks() + execution = blocks.get_execution_blocks(image=True) + assert isinstance(execution, ImageToImageBlock) + + def test_mask_and_image_selects_inpaint(self): + blocks = ConditionalImageBlocks() + execution = blocks.get_execution_blocks(mask=True, image=True) + assert isinstance(execution, InpaintBlock) + + def test_skippable_block_returns_none(self): + blocks = OptionalConditionalBlocks() + execution = blocks.get_execution_blocks() + assert execution is None + + def test_skippable_block_still_selects_when_triggered(self): + blocks = OptionalConditionalBlocks() + execution = blocks.get_execution_blocks(image=True) + assert isinstance(execution, ImageToImageBlock) + + +class TestAutoPipelineBlocksSelectBlock: + def test_auto_select_mask(self): + blocks = AutoImageBlocks() + assert blocks.select_block(mask="m") == "inpaint" + + def test_auto_select_image(self): + blocks = AutoImageBlocks() + assert blocks.select_block(image="i") == "img2img" + + def test_auto_select_default(self): + blocks = AutoImageBlocks() + # No trigger -> returns None -> falls back to default (text2img) + assert blocks.select_block() is None + + def test_auto_select_priority_order(self): + blocks = AutoImageBlocks() + assert blocks.select_block(mask="m", 
image="i") == "inpaint" + + +class TestAutoPipelineBlocksWorkflowSelection: + def test_auto_default_workflow(self): + blocks = AutoImageBlocks() + execution = blocks.get_execution_blocks() + assert isinstance(execution, TextToImageBlock) + + def test_auto_mask_workflow(self): + blocks = AutoImageBlocks() + execution = blocks.get_execution_blocks(mask=True) + assert isinstance(execution, InpaintBlock) + + def test_auto_image_workflow(self): + blocks = AutoImageBlocks() + execution = blocks.get_execution_blocks(image=True) + assert isinstance(execution, ImageToImageBlock) + + +class TestConditionalPipelineBlocksStructure: + def test_block_names_accessible(self): + blocks = ConditionalImageBlocks() + sub = dict(blocks.sub_blocks) + assert set(sub.keys()) == {"inpaint", "img2img", "text2img"} + + def test_sub_block_types(self): + blocks = ConditionalImageBlocks() + sub = dict(blocks.sub_blocks) + assert isinstance(sub["inpaint"], InpaintBlock) + assert isinstance(sub["img2img"], ImageToImageBlock) + assert isinstance(sub["text2img"], TextToImageBlock) + + def test_description(self): + blocks = ConditionalImageBlocks() + assert "Conditional" in blocks.description diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index 8a65999b2006..8b212c0cbf4e 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -10,11 +10,6 @@ import diffusers from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks from diffusers.guiders import ClassifierFreeGuidance -from diffusers.modular_pipelines import ( - ConditionalPipelineBlocks, - LoopSequentialPipelineBlocks, - SequentialPipelineBlocks, -) from diffusers.modular_pipelines.modular_pipeline_utils import ( ComponentSpec, ConfigSpec, @@ -25,7 +20,6 @@ from diffusers.utils import logging from ..testing_utils import ( - CaptureLogger, 
backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, @@ -498,117 +492,6 @@ def test_guider_cfg(self, expected_max_diff=1e-2): assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference" -class TestCustomBlockRequirements: - def get_dummy_block_pipe(self): - class DummyBlockOne: - # keep two arbitrary deps so that we can test warnings. - _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"} - - class DummyBlockTwo: - # keep two dependencies that will be available during testing. - _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"} - - pipe = SequentialPipelineBlocks.from_blocks_dict( - {"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo} - ) - return pipe - - def get_dummy_conditional_block_pipe(self): - class DummyBlockOne: - _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"} - - class DummyBlockTwo: - _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"} - - class DummyConditionalBlocks(ConditionalPipelineBlocks): - block_classes = [DummyBlockOne, DummyBlockTwo] - block_names = ["block_one", "block_two"] - block_trigger_inputs = [] - - def select_block(self, **kwargs): - return "block_one" - - return DummyConditionalBlocks() - - def get_dummy_loop_block_pipe(self): - class DummyBlockOne: - _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"} - - class DummyBlockTwo: - _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"} - - return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo}) - - def test_sequential_block_requirements_save_load(self, tmp_path): - pipe = self.get_dummy_block_pipe() - pipe.save_pretrained(str(tmp_path)) - - config_path = tmp_path / "modular_config.json" - - with open(config_path, "r") as f: - config = json.load(f) - - assert "requirements" in config - requirements = config["requirements"] - - expected_requirements = { - "xyz": ">=0.8.0", - "abc": ">=10.0.0", - 
"transformers": ">=4.44.0", - "diffusers": ">=0.2.0", - } - assert expected_requirements == requirements - - def test_sequential_block_requirements_warnings(self, tmp_path): - pipe = self.get_dummy_block_pipe() - - logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils") - logger.setLevel(30) - - with CaptureLogger(logger) as cap_logger: - pipe.save_pretrained(str(tmp_path)) - - template = "{req} was specified in the requirements but wasn't found in the current environment" - msg_xyz = template.format(req="xyz") - msg_abc = template.format(req="abc") - assert msg_xyz in str(cap_logger.out) - assert msg_abc in str(cap_logger.out) - - def test_conditional_block_requirements_save_load(self, tmp_path): - pipe = self.get_dummy_conditional_block_pipe() - pipe.save_pretrained(str(tmp_path)) - - config_path = tmp_path / "modular_config.json" - with open(config_path, "r") as f: - config = json.load(f) - - assert "requirements" in config - expected_requirements = { - "xyz": ">=0.8.0", - "abc": ">=10.0.0", - "transformers": ">=4.44.0", - "diffusers": ">=0.2.0", - } - assert expected_requirements == config["requirements"] - - def test_loop_block_requirements_save_load(self, tmp_path): - pipe = self.get_dummy_loop_block_pipe() - pipe.save_pretrained(str(tmp_path)) - - config_path = tmp_path / "modular_config.json" - with open(config_path, "r") as f: - config = json.load(f) - - assert "requirements" in config - expected_requirements = { - "xyz": ">=0.8.0", - "abc": ">=10.0.0", - "transformers": ">=4.44.0", - "diffusers": ">=0.2.0", - } - assert expected_requirements == config["requirements"] - - class TestModularModelCardContent: def create_mock_block(self, name="TestBlock", description="Test block description"): class MockBlock: diff --git a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py index 59d6a3e75f55..315e16d7b260 100644 --- 
a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py +++ b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py @@ -24,14 +24,18 @@ from diffusers import FluxTransformer2DModel from diffusers.modular_pipelines import ( ComponentSpec, + ConditionalPipelineBlocks, InputParam, + LoopSequentialPipelineBlocks, ModularPipelineBlocks, OutputParam, PipelineState, + SequentialPipelineBlocks, WanModularPipeline, ) +from diffusers.utils import logging -from ..testing_utils import nightly, require_torch, require_torch_accelerator, slow, torch_device +from ..testing_utils import CaptureLogger, nightly, require_torch, require_torch_accelerator, slow, torch_device def _create_tiny_model_dir(model_dir): @@ -463,6 +467,117 @@ def test_custom_block_loads_from_hub(self): assert output_prompt.startswith("Modular diffusers + ") +class TestCustomBlockRequirements: + def get_dummy_block_pipe(self): + class DummyBlockOne: + # keep two arbitrary deps so that we can test warnings. + _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"} + + class DummyBlockTwo: + # keep two dependencies that will be available during testing. 
+ _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"} + + pipe = SequentialPipelineBlocks.from_blocks_dict( + {"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo} + ) + return pipe + + def get_dummy_conditional_block_pipe(self): + class DummyBlockOne: + _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"} + + class DummyBlockTwo: + _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"} + + class DummyConditionalBlocks(ConditionalPipelineBlocks): + block_classes = [DummyBlockOne, DummyBlockTwo] + block_names = ["block_one", "block_two"] + block_trigger_inputs = [] + + def select_block(self, **kwargs): + return "block_one" + + return DummyConditionalBlocks() + + def get_dummy_loop_block_pipe(self): + class DummyBlockOne: + _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"} + + class DummyBlockTwo: + _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"} + + return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo}) + + def test_sequential_block_requirements_save_load(self, tmp_path): + pipe = self.get_dummy_block_pipe() + pipe.save_pretrained(str(tmp_path)) + + config_path = tmp_path / "modular_config.json" + + with open(config_path, "r") as f: + config = json.load(f) + + assert "requirements" in config + requirements = config["requirements"] + + expected_requirements = { + "xyz": ">=0.8.0", + "abc": ">=10.0.0", + "transformers": ">=4.44.0", + "diffusers": ">=0.2.0", + } + assert expected_requirements == requirements + + def test_sequential_block_requirements_warnings(self, tmp_path): + pipe = self.get_dummy_block_pipe() + + logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils") + logger.setLevel(30) + + with CaptureLogger(logger) as cap_logger: + pipe.save_pretrained(str(tmp_path)) + + template = "{req} was specified in the requirements but wasn't found in the current environment" + msg_xyz = template.format(req="xyz") + 
msg_abc = template.format(req="abc") + assert msg_xyz in str(cap_logger.out) + assert msg_abc in str(cap_logger.out) + + def test_conditional_block_requirements_save_load(self, tmp_path): + pipe = self.get_dummy_conditional_block_pipe() + pipe.save_pretrained(str(tmp_path)) + + config_path = tmp_path / "modular_config.json" + with open(config_path, "r") as f: + config = json.load(f) + + assert "requirements" in config + expected_requirements = { + "xyz": ">=0.8.0", + "abc": ">=10.0.0", + "transformers": ">=4.44.0", + "diffusers": ">=0.2.0", + } + assert expected_requirements == config["requirements"] + + def test_loop_block_requirements_save_load(self, tmp_path): + pipe = self.get_dummy_loop_block_pipe() + pipe.save_pretrained(str(tmp_path)) + + config_path = tmp_path / "modular_config.json" + with open(config_path, "r") as f: + config = json.load(f) + + assert "requirements" in config + expected_requirements = { + "xyz": ">=0.8.0", + "abc": ">=10.0.0", + "transformers": ">=4.44.0", + "diffusers": ">=0.2.0", + } + assert expected_requirements == config["requirements"] + + @slow @nightly @require_torch From f797f175c8efbc340f6b64a59a2d441e758b95c7 Mon Sep 17 00:00:00 2001 From: kaixuanliu Date: Thu, 26 Mar 2026 15:10:53 +0800 Subject: [PATCH 083/215] avoid hardcode device in flux-control example (#13336) Signed-off-by: Liu, Kaixuan --- examples/flux-control/train_control_flux.py | 2 +- examples/flux-control/train_control_lora_flux.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/flux-control/train_control_flux.py b/examples/flux-control/train_control_flux.py index c5f93fa2e987..5c817751038d 100644 --- a/examples/flux-control/train_control_flux.py +++ b/examples/flux-control/train_control_flux.py @@ -1105,7 +1105,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # text encoding. 
captions = batch["captions"] - text_encoding_pipeline = text_encoding_pipeline.to("cuda") + text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) with torch.no_grad(): prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt( captions, prompt_2=None diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py index f5d3c822b3ef..f372284d7abc 100644 --- a/examples/flux-control/train_control_lora_flux.py +++ b/examples/flux-control/train_control_lora_flux.py @@ -1251,7 +1251,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # text encoding. captions = batch["captions"] - text_encoding_pipeline = text_encoding_pipeline.to("cuda") + text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device) with torch.no_grad(): prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt( captions, prompt_2=None From b4703862e4add79561df1140c713229a32776857 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 26 Mar 2026 15:39:10 +0530 Subject: [PATCH 084/215] fix claude workflow to include id-token with write. 
(#13338) --- .github/workflows/claude_review.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/claude_review.yml b/.github/workflows/claude_review.yml index e772415d6322..2df3b47eb1c5 100644 --- a/.github/workflows/claude_review.yml +++ b/.github/workflows/claude_review.yml @@ -10,6 +10,7 @@ permissions: contents: write pull-requests: write issues: read + id-token: write jobs: claude-review: From de3efa879db353d25e6886333a374d43eb7a5742 Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Thu, 26 Mar 2026 17:51:29 -0700 Subject: [PATCH 085/215] Update LTX-2 Docs to Cover LTX-2.3 Models (#13337) * Update LTX-2 docs to cover multimodal guidance and prompt enhancement * Apply suggestions from code review Co-authored-by: Sayak Paul Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Apply reviewer feedback --------- Co-authored-by: Sayak Paul Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/pipelines/ltx2.md | 166 +++++++++++++++++++++++--- src/diffusers/pipelines/ltx2/utils.py | 149 +++++++++++++++++++++++ 2 files changed, 300 insertions(+), 15 deletions(-) diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md index 85b0f9691891..bcddd40e6691 100644 --- a/docs/source/en/api/pipelines/ltx2.md +++ b/docs/source/en/api/pipelines/ltx2.md @@ -18,7 +18,7 @@ LoRA -LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution. +[LTX-2](https://hf.co/papers/2601.03233) is a DiT-based foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution. 
You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization. @@ -293,6 +293,7 @@ import torch from diffusers import LTX2ConditionPipeline from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition from diffusers.pipelines.ltx2.export_utils import encode_video +from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT from diffusers.utils import load_image, load_video device = "cuda" @@ -315,19 +316,6 @@ prompt = ( "landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the " "solitude and beauty of a winter drive through a mountainous region." ) -negative_prompt = ( - "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " - "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " - "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " - "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " - "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " - "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " - "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " - "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " - "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " - "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " - "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
-) cond_video = load_video( "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4" @@ -343,7 +331,7 @@ frame_rate = 24.0 video, audio = pipe( conditions=conditions, prompt=prompt, - negative_prompt=negative_prompt, + negative_prompt=DEFAULT_NEGATIVE_PROMPT, width=width, height=height, num_frames=121, @@ -366,6 +354,154 @@ encode_video( Because the conditioning is done via latent frames, the 8 data space frames corresponding to the specified latent frame for an image condition will tend to be static. +## Multimodal Guidance + +LTX-2.X pipelines support multimodal guidance. It is composed of three terms, all using a CFG-style update rule: + +1. Classifier-Free Guidance (CFG): standard [CFG](https://huggingface.co/papers/2207.12598) where the perturbed ("weaker") output is generated using the negative prompt. +2. Spatio-Temporal Guidance (STG): [STG](https://huggingface.co/papers/2411.18664) moves away from a perturbed output created from short-cutting self-attention operations and substitutes in the attention values instead. The idea is that this creates sharper videos and better spatiotemporal consistency. +3. Modality Isolation Guidance: moves away from a perturbed output created from disabling cross-modality (audio-to-video and video-to-audio) cross attention. This guidance is more specific to [LTX-2.X](https://huggingface.co/papers/2601.03233) models, with the idea that this produces better consistency between the generated audio and video. + +These are controlled by the `guidance_scale`, `stg_scale`, and `modality_scale` arguments and can be set separately for video and audio. Additionally, for STG the transformer block indices where self-attention is skipped need to be specified via the `spatio_temporal_guidance_blocks` argument.
The LTX-2.X pipelines also support [guidance rescaling](https://huggingface.co/papers/2305.08891) to help reduce over-exposure, which can be a problem when the guidance scales are set to high values. + +```py +import torch +from diffusers import LTX2ImageToVideoPipeline +from diffusers.pipelines.ltx2.export_utils import encode_video +from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT +from diffusers.utils import load_image + +device = "cuda" +width = 768 +height = 512 +random_seed = 42 +frame_rate = 24.0 +generator = torch.Generator(device).manual_seed(random_seed) +model_path = "dg845/LTX-2.3-Diffusers" + +pipe = LTX2ImageToVideoPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16) +pipe.enable_sequential_cpu_offload(device=device) +pipe.vae.enable_tiling() + +prompt = ( + "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in " + "gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs " + "before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small " + "fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly " + "shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a " + "smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the " + "distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a " + "breath-taking, movie-like shot." 
+) + +image = load_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg", +) + +video, audio = pipe( + image=image, + prompt=prompt, + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + width=width, + height=height, + num_frames=121, + frame_rate=frame_rate, + num_inference_steps=30, + guidance_scale=3.0, # Recommended LTX-2.3 guidance parameters + stg_scale=1.0, # Note that 0.0 (not 1.0) means that STG is disabled (all other guidance is disabled at 1.0) + modality_scale=3.0, + guidance_rescale=0.7, + audio_guidance_scale=7.0, # Note that a higher CFG guidance scale is recommended for audio + audio_stg_scale=1.0, + audio_modality_scale=3.0, + audio_guidance_rescale=0.7, + spatio_temporal_guidance_blocks=[28], + use_cross_timestep=True, + generator=generator, + output_type="np", + return_dict=False, +) + +encode_video( + video[0], + fps=frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=pipe.vocoder.config.output_sampling_rate, + output_path="ltx2_3_i2v_stage_1.mp4", +) +``` + +## Prompt Enhancement + +The LTX-2.X models are sensitive to prompting style. Refer to the [official prompting guide](https://ltx.io/model/model-blog/prompting-guide-for-ltx-2) for recommendations on how to write a good prompt. Using prompt enhancement, where the supplied prompts are enhanced using the pipeline's text encoder (by default a [Gemma 3](https://huggingface.co/google/gemma-3-12b-it-qat-q4_0-unquantized) model) given a system prompt, can also improve sample quality. The optional `processor` pipeline component needs to be present to use prompt enhancement. 
Enable prompt enhancement by supplying a `system_prompt` argument: + + +```py +import torch +from transformers import Gemma3Processor +from diffusers import LTX2Pipeline +from diffusers.pipelines.ltx2.export_utils import encode_video +from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT, T2V_DEFAULT_SYSTEM_PROMPT + +device = "cuda" +width = 768 +height = 512 +random_seed = 42 +frame_rate = 24.0 +generator = torch.Generator(device).manual_seed(random_seed) +model_path = "dg845/LTX-2.3-Diffusers" + +pipe = LTX2Pipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16) +pipe.enable_model_cpu_offload(device=device) +pipe.vae.enable_tiling() +if getattr(pipe, "processor", None) is None: + processor = Gemma3Processor.from_pretrained("google/gemma-3-12b-it-qat-q4_0-unquantized") + pipe.processor = processor + +prompt = ( + "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in " + "gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs " + "before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small " + "fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly " + "shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a " + "smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the " + "distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a " + "breath-taking, movie-like shot." 
+) + +video, audio = pipe( + prompt=prompt, + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + width=width, + height=height, + num_frames=121, + frame_rate=frame_rate, + num_inference_steps=30, + guidance_scale=3.0, + stg_scale=1.0, + modality_scale=3.0, + guidance_rescale=0.7, + audio_guidance_scale=7.0, + audio_stg_scale=1.0, + audio_modality_scale=3.0, + audio_guidance_rescale=0.7, + spatio_temporal_guidance_blocks=[28], + use_cross_timestep=True, + system_prompt=T2V_DEFAULT_SYSTEM_PROMPT, + generator=generator, + output_type="np", + return_dict=False, +) + +encode_video( + video[0], + fps=frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=pipe.vocoder.config.output_sampling_rate, + output_path="ltx2_3_t2v_stage_1.mp4", +) +``` + ## LTX2Pipeline [[autodoc]] LTX2Pipeline diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py index f80469817fe6..52d446c46819 100644 --- a/src/diffusers/pipelines/ltx2/utils.py +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -1,6 +1,155 @@ +# Copyright 2026 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ # Pre-trained sigma values for distilled model are taken from # https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875] # Reduced schedule for super-resolution stage 2 (subset of distilled values) STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875] + + +# Default negative prompt from +# https://github.com/Lightricks/LTX-2/blob/ae855f8538843825f9015a419cf4ba5edaf5eec2/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py#L131-L143 +DEFAULT_NEGATIVE_PROMPT = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
+) + + +# System prompts for prompt enhancement +# https://github.com/Lightricks/LTX-2/blob/ae855f8538843825f9015a419cf4ba5edaf5eec2/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_t2v_system_prompt.txt#L1 +# Disable line-too-long rule in ruff to keep the prompts exactly the same (e.g. in terms of newlines) +# Supported in ruff>=0.15.0 +# ruff: disable[E501] +T2V_DEFAULT_SYSTEM_PROMPT = """ +You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed +video generation prompt with specific visuals and integrated audio to guide a text-to-video model. + +#### Guidelines +- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions, + actions, camera movement, audio). + - If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc. + - For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters. +- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural + movements. +- Maintain chronological flow: use temporal connectors ("as," "then," "while"). +- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested). + Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g., + "ambient sound is present"). +- Speech (only when requested): + - For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with + voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'"). + - Specify language if not English and accent if relevant. +- Style: Include visual style at the beginning: "Style: