diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e9441ef71a31..9d0968a73801 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -463,6 +463,8 @@ "QwenImageLayeredAutoBlocks", "QwenImageLayeredModularPipeline", "QwenImageModularPipeline", + "StableDiffusion3AutoBlocks", + "StableDiffusion3ModularPipeline", "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", "Wan22Blocks", @@ -1242,6 +1244,8 @@ QwenImageLayeredAutoBlocks, QwenImageLayeredModularPipeline, QwenImageModularPipeline, + StableDiffusion3AutoBlocks, + StableDiffusion3ModularPipeline, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, Wan22Blocks, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index fd9bd691ca87..ea10761af6ba 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -46,6 +46,7 @@ "InsertableDict", ] _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] + _import_structure["stable_diffusion_3"] = ["StableDiffusion3AutoBlocks", "StableDiffusion3ModularPipeline"] _import_structure["wan"] = [ "WanBlocks", "Wan22Blocks", @@ -140,6 +141,7 @@ QwenImageLayeredModularPipeline, QwenImageModularPipeline, ) + from .stable_diffusion_3 import StableDiffusion3AutoBlocks, StableDiffusion3ModularPipeline from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from .wan import ( Wan22Blocks, diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 9cd2f9f5c6ae..25fc5baa6779 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -119,6 +119,7 @@ def _helios_pyramid_map_fn(config_dict=None): MODULAR_PIPELINE_MAPPING = OrderedDict( [ ("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")), + ("stable-diffusion-3", _create_default_map_fn("StableDiffusion3ModularPipeline")), ("wan", _wan_map_fn), ("wan-i2v", _wan_i2v_map_fn), ("flux", _create_default_map_fn("FluxModularPipeline")), diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py new file mode 100644 index 000000000000..51cb69ed1e8b --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_stable_diffusion_3"] = [ + "StableDiffusion3AutoBlocks" + ] + _import_structure["modular_pipeline"] = ["StableDiffusion3ModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_stable_diffusion_3 import StableDiffusion3AutoBlocks + from .modular_pipeline import StableDiffusion3ModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py new file mode 100644 index 000000000000..462f2d93d97d --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py @@ -0,0 +1,450 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import torch + +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import StableDiffusion3ModularPipeline + +logger = logging.get_logger(__name__) + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def _get_initial_timesteps_and_optionals( + transformer, + scheduler, + height, + width, + patch_size, + vae_scale_factor, + num_inference_steps, + sigmas, + device, + mu=None, +): + scheduler_kwargs = {} + if scheduler.config.get("use_dynamic_shifting", None) and mu is None: + image_seq_len = (height // vae_scale_factor // patch_size) * ( + width // vae_scale_factor // patch_size + ) + mu = calculate_shift( + image_seq_len, + scheduler.config.get("base_image_seq_len", 256), + scheduler.config.get("max_image_seq_len", 4096), + scheduler.config.get("base_shift", 0.5), + scheduler.config.get("max_shift", 1.16), + ) + scheduler_kwargs["mu"] = mu + elif mu is not None: + scheduler_kwargs["mu"] = mu + + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs + ) + return timesteps, num_inference_steps + + +class StableDiffusion3SetTimestepsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "num_inference_steps", + default=50, + description="The number of denoising steps.", + ), + InputParam( + "timesteps", + description="Custom timesteps to use for the denoising process.", + ), + InputParam( + "sigmas", description="Custom sigmas to use for the denoising process." + ), + InputParam( + "height", + type_hint=int, + description="The height in pixels of the generated image.", + ), + InputParam( + "width", + type_hint=int, + description="The width in pixels of the generated image.", + ), + InputParam( + "mu", + type_hint=float, + description="The mu value used for dynamic shifting. If not provided, it is dynamically calculated.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "timesteps", + type_hint=torch.Tensor, + description="The timesteps schedule for the denoising process.", + ), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The final number of inference steps.", + ), + ] + + @torch.no_grad() + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + timesteps, num_inference_steps = _get_initial_timesteps_and_optionals( + components.transformer, + components.scheduler, + block_state.height, + block_state.width, + components.patch_size, + components.vae_scale_factor, + block_state.num_inference_steps, + block_state.sigmas, + block_state.device, + getattr(block_state, "mu", None), + ) + + block_state.timesteps = timesteps + block_state.num_inference_steps = num_inference_steps + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusion3Img2ImgSetTimestepsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for img2img inference" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "num_inference_steps", + default=50, + description="The number of denoising steps.", + ), + InputParam( + "timesteps", + description="Custom timesteps to use for the denoising process.", + ), + InputParam( + "sigmas", description="Custom sigmas to use for the denoising process." + ), + InputParam( + "strength", + default=0.6, + description="Indicates extent to transform the reference image.", + ), + InputParam( + "height", + type_hint=int, + description="The height in pixels of the generated image.", + ), + InputParam( + "width", + type_hint=int, + description="The width in pixels of the generated image.", + ), + InputParam( + "mu", + type_hint=float, + description="The mu value used for dynamic shifting. If not provided, it is dynamically calculated.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "timesteps", + type_hint=torch.Tensor, + description="The timesteps schedule for the denoising process.", + ), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The final number of inference steps.", + ), + ] + + @staticmethod + def get_timesteps(scheduler, num_inference_steps, strength): + init_timestep = min(num_inference_steps * strength, num_inference_steps) + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = scheduler.timesteps[t_start * scheduler.order :] + if hasattr(scheduler, "set_begin_index"): + scheduler.set_begin_index(t_start * scheduler.order) + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + timesteps, num_inference_steps = _get_initial_timesteps_and_optionals( + components.transformer, + components.scheduler, + block_state.height, + block_state.width, + components.patch_size, + components.vae_scale_factor, + block_state.num_inference_steps, + block_state.sigmas, + block_state.device, + getattr(block_state, "mu", None), + ) + + timesteps, num_inference_steps = self.get_timesteps( + components.scheduler, num_inference_steps, block_state.strength + ) + + block_state.timesteps = timesteps + block_state.num_inference_steps = num_inference_steps + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusion3PrepareLatentsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def description(self) -> str: + return "Prepare latents step for Text-to-Image" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "height", + type_hint=int, + description="The height in pixels of the generated image.", + ), + InputParam( + "width", + type_hint=int, + description="The width in pixels of the generated image.", + ), + InputParam( + "latents", + type_hint=torch.Tensor | None, + description="Pre-generated noisy latents to be used as inputs for image generation.", + ), + InputParam( + "num_images_per_prompt", + type_hint=int, + default=1, + description="The number of images to generate per prompt.", + ), + InputParam( + "generator", + description="One or a list of torch generator(s) to make generation deterministic.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="The batch size for latent generation.", + ), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The data type for the latents.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latents", + type_hint=torch.Tensor, + description="The prepared latent tensors to be denoised.", + ) + ] + + @torch.no_grad() + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + + if block_state.latents is not None: + block_state.latents = block_state.latents.to( + device=block_state.device, dtype=block_state.dtype + ) + else: + shape = ( + batch_size, + components.num_channels_latents, + int(block_state.height) // components.vae_scale_factor, + int(block_state.width) // components.vae_scale_factor, + ) + block_state.latents = randn_tensor( + shape, + generator=block_state.generator, + device=block_state.device, + dtype=block_state.dtype, + ) + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusion3Img2ImgPrepareLatentsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to be scaled by the scheduler.", + ), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The image latents encoded by the VAE.", + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps schedule.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "initial_noise", + type_hint=torch.Tensor, + description="The initial noise applied to the image latents.", + ) + ] + + @torch.no_grad() + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: + block_state = self.get_block_state(state) + latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0]) + block_state.initial_noise = block_state.latents + block_state.latents = components.scheduler.scale_noise( + block_state.image_latents, latent_timestep, block_state.latents + ) + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py new file mode 100644 index 000000000000..079181635e24 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py @@ -0,0 +1,80 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + +logger = logging.get_logger(__name__) + + +class StableDiffusion3DecodeStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "output_type", + default="pil", + description="The output format of the generated image (e.g., 'pil', 'pt', 'np').", + ), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents to be decoded.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam("images", type_hint=list[PIL.Image.Image] | torch.Tensor)] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae = components.vae + + if not block_state.output_type == "latent": + latents = ( + block_state.latents / vae.config.scaling_factor + ) + vae.config.shift_factor + block_state.images = vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + else: + block_state.images = block_state.latents + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py new file mode 100644 index 000000000000..f5886d4ac40e --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py @@ -0,0 +1,268 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import StableDiffusion3ModularPipeline + +logger = logging.get_logger(__name__) + + +class StableDiffusion3LoopDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", SD3Transformer2DModel), + ] + + @property + def description(self) -> str: + return "Step within the denoising loop that denoises the latents." + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam( + "joint_attention_kwargs", + type_hint=dict, + description="A kwargs dictionary passed along to the AttentionProcessor.", + ), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process.", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings for guidance.", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pooled text embeddings for guidance.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="Negative text embeddings for guidance.", + ), + InputParam( + "negative_pooled_prompt_embeds", + type_hint=torch.Tensor, + description="Negative pooled text embeddings for guidance.", + ), + InputParam( + "guidance_scale", + default=7.0, + description="Guidance scale as defined in Classifier-Free Diffusion Guidance.", + ), + InputParam( + "skip_layer_guidance_scale", + default=2.8, + description="The scale of the guidance for the skipped layers.", + ), + InputParam( + "skip_layer_guidance_stop", + default=0.2, + description="The step fraction at which the guidance for skipped layers stops.", + ), + InputParam( + "skip_layer_guidance_start", + default=0.01, + description="The step fraction at which the guidance for skipped layers starts.", + ), + InputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps.", + ), + ] + + @torch.no_grad() + def __call__( + self, + components: StableDiffusion3ModularPipeline, + block_state: BlockState, + i: int, + t: torch.Tensor, + ) -> PipelineState: + guider_inputs = { + "encoder_hidden_states": ( + getattr(block_state, "prompt_embeds", None), + getattr(block_state, "negative_prompt_embeds", None), + ), + "pooled_projections": ( + getattr(block_state, "pooled_prompt_embeds", None), + getattr(block_state, "negative_pooled_prompt_embeds", None), + ), + } + + if hasattr(components.guider, "guidance_scale"): + components.guider.guidance_scale = block_state.guidance_scale + if hasattr(components.guider, "skip_layer_guidance_scale"): + components.guider.skip_layer_guidance_scale = ( + block_state.skip_layer_guidance_scale + ) + if hasattr(components.guider, "skip_layer_guidance_start"): + components.guider.skip_layer_guidance_start = ( + block_state.skip_layer_guidance_start + ) + if hasattr(components.guider, "skip_layer_guidance_stop"): + components.guider.skip_layer_guidance_stop = ( + block_state.skip_layer_guidance_stop + ) + + components.guider.set_state( + step=i, num_inference_steps=block_state.num_inference_steps, timestep=t + ) + guider_state = components.guider.prepare_inputs(guider_inputs) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = { + input_name: getattr(guider_state_batch, input_name) + for input_name in guider_inputs.keys() + } + + timestep = t.expand(block_state.latents.shape[0]) + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latents, + timestep=timestep, + joint_attention_kwargs=getattr( + block_state, "joint_attention_kwargs", None + ), + return_dict=False, + **cond_kwargs, + )[0] + + components.guider.cleanup_models(components.transformer) + + guider_output = components.guider(guider_state) + block_state.noise_pred = guider_output.pred + + return components, block_state + + +class StableDiffusion3LoopAfterDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latents", + type_hint=torch.Tensor, + description="The denoised latent tensors.", + ) + ] + + @torch.no_grad() + def __call__( + self, + components: StableDiffusion3ModularPipeline, + block_state: BlockState, + i: int, + t: torch.Tensor, + ): + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class StableDiffusion3DenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", SD3Transformer2DModel), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam("timesteps", required=True, type_hint=torch.Tensor), + InputParam("num_inference_steps", required=True, type_hint=int), + ] + + @torch.no_grad() + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: + block_state = self.get_block_state(state) + block_state.num_warmup_steps = max( + len(block_state.timesteps) + - block_state.num_inference_steps * components.scheduler.order, + 0, + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step( + components, block_state, i=i, t=t + ) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps + and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class StableDiffusion3DenoiseStep(StableDiffusion3DenoiseLoopWrapper): + block_classes = [StableDiffusion3LoopDenoiser, StableDiffusion3LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py new file mode 100644 index 000000000000..83b5d592ac55 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py @@ -0,0 +1,651 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import SD3LoraLoaderMixin +from ...models import AutoencoderKL +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import StableDiffusion3ModularPipeline + +logger = logging.get_logger(__name__) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, + generator: torch.Generator | None = None, + sample_mode: str = "sample", +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def encode_vae_image( + vae: AutoencoderKL, + image: torch.Tensor, + generator: torch.Generator, + sample_mode="sample", +): + if isinstance(generator, list): + image_latents = [ + retrieve_latents( + vae.encode(image[i : i + 1]), + generator=generator[i], + sample_mode=sample_mode, + ) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents( + vae.encode(image), generator=generator, sample_mode=sample_mode + ) + + image_latents = ( + image_latents - vae.config.shift_factor + ) * vae.config.scaling_factor + return image_latents + + +def _get_t5_prompt_embeds( + text_encoder: T5EncoderModel | None, + tokenizer: T5TokenizerFast | None, + prompt: str | list[str] = None, + max_sequence_length: int = 256, + device: torch.device | None = None, + joint_attention_dim: int = 4096, + dtype: torch.dtype | None = None, +): + device = device or ( + text_encoder.device if text_encoder is not None else torch.device("cpu") + ) + dtype = dtype or (text_encoder.dtype if text_encoder is not None else torch.float32) + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if text_encoder is None or tokenizer is None: + return torch.zeros( + (batch_size, max_sequence_length, joint_attention_dim), + device=device, + dtype=dtype, + ) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + f"The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds + + +def _get_clip_prompt_embeds( + text_encoder: CLIPTextModelWithProjection | None, + tokenizer: CLIPTokenizer | None, + prompt: str | list[str], + device: torch.device | None = None, + clip_skip: int | None = None, + hidden_size: int = 768, + dtype: torch.dtype | None = None, +): + device = device or ( + text_encoder.device if text_encoder is not None else torch.device("cpu") + ) + dtype = dtype or (text_encoder.dtype if text_encoder is not None else torch.float32) + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if text_encoder is None or tokenizer is None: + prompt_embeds = torch.zeros( + (batch_size, 77, hidden_size), device=device, dtype=dtype + ) + pooled_prompt_embeds = torch.zeros( + (batch_size, hidden_size), device=device, dtype=dtype + ) + return prompt_embeds, pooled_prompt_embeds + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + f"The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, pooled_prompt_embeds + + +def encode_prompt( + components, + prompt: str | list[str], + prompt_2: str | list[str] | None = None, + prompt_3: str | list[str] | None = None, + device: torch.device | None = None, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + clip_skip: int | None = None, + max_sequence_length: int = 256, + lora_scale: float | None = None, +): + device = device or components._execution_device + + expected_dtype = None + if components.text_encoder is not None: + expected_dtype = components.text_encoder.dtype + elif components.text_encoder_2 is not None: + expected_dtype = components.text_encoder_2.dtype + elif getattr(components, "transformer", None) is not None: + expected_dtype = components.transformer.dtype + else: + expected_dtype = torch.float32 + + if lora_scale is not None and isinstance(components, SD3LoraLoaderMixin): + components._lora_scale = lora_scale + if components.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(components.text_encoder, lora_scale) + if components.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(components.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = _get_clip_prompt_embeds( + components.text_encoder, + components.tokenizer, + prompt=prompt, + device=device, + clip_skip=clip_skip, + hidden_size=768, + dtype=expected_dtype, + ) + prompt_2_embed, pooled_prompt_2_embed = _get_clip_prompt_embeds( + components.text_encoder_2, + components.tokenizer_2, + prompt=prompt_2, + device=device, + clip_skip=clip_skip, + hidden_size=1280, + dtype=expected_dtype, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = _get_t5_prompt_embeds( + components.text_encoder_3, + components.tokenizer_3, + prompt=prompt_3, + max_sequence_length=max_sequence_length, + device=device, + joint_attention_dim=( + components.transformer.config.joint_attention_dim + if getattr(components, "transformer", None) is not None + else 4096 + ), + dtype=expected_dtype, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, + (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]), + ) + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat( + [pooled_prompt_embed, pooled_prompt_2_embed], dim=-1 + ) + + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + negative_prompt = ( + batch_size * [negative_prompt] + if isinstance(negative_prompt, str) + else negative_prompt + ) + negative_prompt_2 = ( + batch_size * [negative_prompt_2] + if isinstance(negative_prompt_2, str) + else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] + if isinstance(negative_prompt_3, str) + else negative_prompt_3 + ) + + negative_prompt_embed, negative_pooled_prompt_embed = _get_clip_prompt_embeds( + components.text_encoder, + components.tokenizer, + prompt=negative_prompt, + device=device, + clip_skip=None, + hidden_size=768, + dtype=expected_dtype, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = _get_clip_prompt_embeds( + components.text_encoder_2, + components.tokenizer_2, + prompt=negative_prompt_2, + device=device, + clip_skip=None, + hidden_size=1280, + dtype=expected_dtype, + ) + negative_clip_prompt_embeds = torch.cat( + [negative_prompt_embed, negative_prompt_2_embed], dim=-1 + ) + + t5_negative_prompt_embed = _get_t5_prompt_embeds( + components.text_encoder_3, + components.tokenizer_3, + prompt=negative_prompt_3, + max_sequence_length=max_sequence_length, + device=device, + joint_attention_dim=( + components.transformer.config.joint_attention_dim + if getattr(components, "transformer", None) is not None + else 4096 + ), + dtype=expected_dtype, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + negative_prompt_embeds = torch.cat( + [negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2 + ) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if ( + components.text_encoder is not None + and isinstance(components, SD3LoraLoaderMixin) + and USE_PEFT_BACKEND + ): + unscale_lora_layers(components.text_encoder, lora_scale) + if ( + components.text_encoder_2 is not None + and isinstance(components, SD3LoraLoaderMixin) + and USE_PEFT_BACKEND + ): + unscale_lora_layers(components.text_encoder_2, lora_scale) + + return ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + +class StableDiffusion3ProcessImagesInputStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def description(self) -> str: + return "Image Preprocess step for SD3." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8, "vae_latent_channels": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("resized_image", description="The pre-resized image input."), + InputParam( + "image", + description="The input image to be used as the starting point for the image-to-image process.", + ), + InputParam( + "height", description="The height in pixels of the generated image." + ), + InputParam( + "width", description="The width in pixels of the generated image." + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="processed_image", description="The pre-processed image tensor." + ) + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor, patch_size): + if height is not None and height % (vae_scale_factor * patch_size) != 0: + raise ValueError( + f"Height must be divisible by {vae_scale_factor * patch_size} but is {height}" + ) + + if width is not None and width % (vae_scale_factor * patch_size) != 0: + raise ValueError( + f"Width must be divisible by {vae_scale_factor * patch_size} but is {width}" + ) + + @torch.no_grad() + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ): + block_state = self.get_block_state(state) + + if block_state.resized_image is None and block_state.image is None: + raise ValueError( + "`resized_image` and `image` cannot be None at the same time" + ) + + if block_state.resized_image is None: + image = block_state.image + self.check_inputs( + height=block_state.height, + width=block_state.width, + vae_scale_factor=components.vae_scale_factor, + patch_size=components.patch_size, + ) + height = block_state.height or components.default_height + width = block_state.width or components.default_width + else: + width, height = block_state.resized_image[0].size + image = block_state.resized_image + + block_state.processed_image = components.image_processor.preprocess( + image=image, height=height, width=width + ) + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusion3VaeEncoderStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + def __init__( + self, + input_name: str = "processed_image", + output_name: str = "image_latents", + sample_mode: str = "sample", + ): + self._image_input_name = input_name + self._image_latents_output_name = output_name + self.sample_mode = sample_mode + super().__init__() + + @property + def description(self) -> str: + return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("vae", AutoencoderKL)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + self._image_input_name, + description="The processed image input to be encoded.", + ), + InputParam( + "generator", + description="One or a list of torch generator(s) to make generation deterministic.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + self._image_latents_output_name, + type_hint=torch.Tensor, + description="The latents representing the reference image", + ) + ] + + @torch.no_grad() + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: + block_state = self.get_block_state(state) + image = getattr(block_state, self._image_input_name) + + if image is None: + setattr(block_state, self._image_latents_output_name, None) + else: + device = components._execution_device + dtype = components.vae.dtype + image = image.to(device=device, dtype=dtype) + image_latents = encode_vae_image( + image=image, + vae=components.vae, + generator=block_state.generator, + sample_mode=self.sample_mode, + ) + setattr(block_state, self._image_latents_output_name, image_latents) + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusion3TextEncoderStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings to guide the image generation for SD3." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", CLIPTextModelWithProjection), + ComponentSpec("tokenizer", CLIPTokenizer), + ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), + ComponentSpec("tokenizer_2", CLIPTokenizer), + ComponentSpec("text_encoder_3", T5EncoderModel), + ComponentSpec("tokenizer_3", T5TokenizerFast), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "prompt", + description="The prompt or prompts to guide the image generation.", + ), + InputParam( + "prompt_2", + description="The prompt or prompts to be sent to tokenizer_2 and text_encoder_2.", + ), + InputParam( + "prompt_3", + description="The prompt or prompts to be sent to tokenizer_3 and text_encoder_3.", + ), + InputParam( + "negative_prompt", + description="The prompt or prompts not to guide the image generation.", + ), + InputParam( + "negative_prompt_2", + description="The prompt or prompts not to guide the image generation for tokenizer_2.", + ), + InputParam( + "negative_prompt_3", + description="The prompt or prompts not to guide the image generation for tokenizer_3.", + ), + InputParam( + "prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated text embeddings.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings.", + ), + InputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated pooled text embeddings.", + ), + InputParam( + "negative_pooled_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated negative pooled text embeddings.", + ), + InputParam( + "clip_skip", + type_hint=int, + description="Number of layers to be skipped from CLIP while computing the prompt embeddings.", + ), + InputParam( + "max_sequence_length", + type_hint=int, + default=256, + description="Maximum sequence length to use with the prompt.", + ), + InputParam( + "joint_attention_kwargs", + description="A kwargs dictionary passed along to the AttentionProcessor.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("prompt_embeds", type_hint=torch.Tensor), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor), + ] + + @torch.no_grad() + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + lora_scale = ( + block_state.joint_attention_kwargs.get("scale", None) + if getattr(block_state, "joint_attention_kwargs", None) + else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = encode_prompt( + components=components, + prompt=block_state.prompt, + prompt_2=getattr(block_state, "prompt_2", None), + prompt_3=getattr(block_state, "prompt_3", None), + device=block_state.device, + negative_prompt=getattr(block_state, "negative_prompt", None), + negative_prompt_2=getattr(block_state, "negative_prompt_2", None), + negative_prompt_3=getattr(block_state, "negative_prompt_3", None), + clip_skip=getattr(block_state, "clip_skip", None), + max_sequence_length=getattr(block_state, "max_sequence_length", 256), + lora_scale=lora_scale, + ) + + block_state.prompt_embeds = prompt_embeds + block_state.negative_prompt_embeds = negative_prompt_embeds + block_state.pooled_prompt_embeds = pooled_prompt_embeds + block_state.negative_pooled_prompt_embeds = negative_pooled_prompt_embeds + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py new file mode 100644 index 000000000000..9fc3a21b178e --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py @@ -0,0 +1,352 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import InputParam, OutputParam +from .modular_pipeline import StableDiffusion3ModularPipeline + +logger = logging.get_logger(__name__) + + +# Copied from diffusers.modular_pipelines.qwenimage.inputs.repeat_tensor_to_batch_size +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_images_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat tensor elements to match the final batch size. + + This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt) + by repeating each element along dimension 0. + + The input tensor must have batch size 1 or batch_size. The function will: + - If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times + - If batch size equals batch_size: repeat each element num_images_per_prompt times + + Args: + input_name (str): Name of the input tensor (used for error messages) + input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size. + batch_size (int): The base batch size (number of prompts) + num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1. + + Returns: + torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt) + + Raises: + ValueError: If input_tensor is not a torch.Tensor or has invalid batch size + + Examples: + tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor, + batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: + [4, 3] + + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image", + tensor, batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) + - shape: [4, 3] + """ + # make sure input is a tensor + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts + if input_tensor.shape[0] == 1: + repeat_by = batch_size * num_images_per_prompt + elif input_tensor.shape[0] == batch_size: + repeat_by = num_images_per_prompt + else: + raise ValueError( + f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}" + ) + + # expand the tensor to match the batch_size * num_images_per_prompt + input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0) + + return input_tensor + + +# Copied from diffusers.modular_pipelines.qwenimage.inputs.calculate_dimension_from_latents +def calculate_dimension_from_latents( + latents: torch.Tensor, vae_scale_factor: int +) -> tuple[int, int]: + """Calculate image dimensions from latent tensor dimensions. + + This function converts latent space dimensions to image space dimensions by multiplying the latent height and width + by the VAE scale factor. + + Args: + latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions. + Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width] + vae_scale_factor (int): The scale factor used by the VAE to compress images. + Typically 8 for most VAEs (image is 8x larger than latents in each dimension) + + Returns: + tuple[int, int]: The calculated image dimensions as (height, width) + + Raises: + ValueError: If latents tensor doesn't have 4 or 5 dimensions + + """ + # make sure the latents are not packed + if latents.ndim != 4 and latents.ndim != 5: + raise ValueError( + f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}" + ) + + latent_height, latent_width = latents.shape[-2:] + + height = latent_height * vae_scale_factor + width = latent_width * vae_scale_factor + + return height, width + + +class StableDiffusion3TextInputStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def description(self) -> str: + return "Text input processing step that standardizes text embeddings for SD3, applying CFG duplication if needed." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "num_images_per_prompt", + default=1, + description="The number of images to generate per prompt.", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated text embeddings.", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated pooled text embeddings.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings.", + ), + InputParam( + "negative_pooled_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated negative pooled text embeddings.", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="The batch size for the inference.", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="The expected data type for latents.", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + description="The processed text embeddings.", + ), + OutputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + description="The processed pooled text embeddings.", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="The processed negative text embeddings.", + ), + OutputParam( + "negative_pooled_prompt_embeds", + type_hint=torch.Tensor, + description="The processed negative pooled text embeddings.", + ), + ] + + @torch.no_grad() + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + prompt_embeds = block_state.prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + prompt_embeds = prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt + ) + pooled_prompt_embeds = pooled_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, -1 + ) + + if getattr(block_state, "negative_prompt_embeds", None) is not None: + _, neg_seq_len, _ = block_state.negative_prompt_embeds.shape + negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, + neg_seq_len, + -1, + ) + + negative_pooled_prompt_embeds = ( + block_state.negative_pooled_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt + ) + ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, -1 + ) + + block_state.negative_prompt_embeds = negative_prompt_embeds + block_state.negative_pooled_prompt_embeds = negative_pooled_prompt_embeds + else: + block_state.negative_prompt_embeds = None + block_state.negative_pooled_prompt_embeds = None + + block_state.prompt_embeds = prompt_embeds + block_state.pooled_prompt_embeds = pooled_prompt_embeds + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusion3AdditionalInputsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + def __init__( + self, + image_latent_inputs: list[str] = ["image_latents"], + additional_batch_inputs: list[str] = [], + ): + self._image_latent_inputs = ( + image_latent_inputs + if isinstance(image_latent_inputs, list) + else [image_latent_inputs] + ) + self._additional_batch_inputs = ( + additional_batch_inputs + if isinstance(additional_batch_inputs, list) + else [additional_batch_inputs] + ) + super().__init__() + + @property + def description(self) -> str: + return "Updates height/width if None, and expands batch size. SD3 does not pack latents on pipeline level." + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam( + "num_images_per_prompt", + default=1, + description="The number of images to generate per prompt.", + ), + InputParam("batch_size", required=True, description="The batch size."), + InputParam( + "height", description="The height in pixels of the generated image." + ), + InputParam( + "width", description="The width in pixels of the generated image." + ), + ] + for name in self._image_latent_inputs + self._additional_batch_inputs: + inputs.append( + InputParam(name, description=f"Latent input {name} to be processed.") + ) + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "image_height", + type_hint=int, + description="The height of the generated image.", + ), + OutputParam( + "image_width", + type_hint=int, + description="The width of the generated image.", + ), + ] + + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: + block_state = self.get_block_state(state) + + for input_name in self._image_latent_inputs: + tensor = getattr(block_state, input_name) + if tensor is None: + continue + + height, width = calculate_dimension_from_latents( + tensor, components.vae_scale_factor + ) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + if not hasattr(block_state, "image_height"): + block_state.image_height = height + if not hasattr(block_state, "image_width"): + block_state.image_width = width + + tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + setattr(block_state, input_name, tensor) + + for input_name in self._additional_batch_inputs: + tensor = getattr(block_state, input_name) + if tensor is None: + continue + tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + setattr(block_state, input_name, tensor) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py new file mode 100644 index 000000000000..29171e0c64d2 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py @@ -0,0 +1,411 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + StableDiffusion3Img2ImgPrepareLatentsStep, + StableDiffusion3Img2ImgSetTimestepsStep, + StableDiffusion3PrepareLatentsStep, + StableDiffusion3SetTimestepsStep, +) +from .decoders import StableDiffusion3DecodeStep +from .denoise import StableDiffusion3DenoiseStep +from .encoders import ( + StableDiffusion3ProcessImagesInputStep, + StableDiffusion3TextEncoderStep, + StableDiffusion3VaeEncoderStep, +) +from .inputs import StableDiffusion3AdditionalInputsStep, StableDiffusion3TextInputStep + +logger = logging.get_logger(__name__) + + +# auto_docstring +class StableDiffusion3Img2ImgVaeEncoderStep(SequentialPipelineBlocks): + """ + Components: + image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) + + Inputs: + resized_image (`None`, *optional*): + The pre-resized image input. + image (`None`, *optional*): + The input image to be used as the starting point for the image-to-image process. + height (`None`, *optional*): + The height in pixels of the generated image. + width (`None`, *optional*): + The width in pixels of the generated image. + generator (`None`, *optional*): + One or a list of torch generator(s) to make generation deterministic. + + Outputs: + processed_image (`None`): + The pre-processed image tensor. + image_latents (`Tensor`): + The latents representing the reference image + """ + + model_name = "stable-diffusion-3" + block_classes = [ + StableDiffusion3ProcessImagesInputStep(), + StableDiffusion3VaeEncoderStep(), + ] + block_names = ["preprocess", "encode"] + + +# auto_docstring +class StableDiffusion3AutoVaeEncoderStep(AutoPipelineBlocks): + """ + Components: + image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) + + Inputs: + resized_image (`None`, *optional*): + The pre-resized image input. + image (`None`, *optional*): + The input image to be used as the starting point for the image-to-image process. + height (`None`, *optional*): + The height in pixels of the generated image. + width (`None`, *optional*): + The width in pixels of the generated image. + generator (`None`, *optional*): + One or a list of torch generator(s) to make generation deterministic. + + Outputs: + processed_image (`None`): + The pre-processed image tensor. + image_latents (`Tensor`): + The latents representing the reference image + """ + + model_name = "stable-diffusion-3" + block_classes = [StableDiffusion3Img2ImgVaeEncoderStep] + block_names = ["img2img"] + block_trigger_inputs = ["image"] + + +# auto_docstring +class StableDiffusion3T2ICoreDenoiseStep(SequentialPipelineBlocks): + """ + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) guider (`ClassifierFreeGuidance`) transformer + (`SD3Transformer2DModel`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. + pooled_prompt_embeds (`Tensor`): + Pre-generated pooled text embeddings. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative pooled text embeddings. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor | NoneType`, *optional*): + Pre-generated noisy latents to be used as inputs for image generation. + generator (`None`, *optional*): + One or a list of torch generator(s) to make generation deterministic. + num_inference_steps (`None`, *optional*, defaults to 50): + The number of denoising steps. + timesteps (`None`, *optional*): + Custom timesteps to use for the denoising process. + sigmas (`None`, *optional*): + Custom sigmas to use for the denoising process. + mu (`float`, *optional*): + The mu value used for dynamic shifting. If not provided, it is dynamically calculated. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary passed along to the AttentionProcessor. + guidance_scale (`None`, *optional*, defaults to 7.0): + Guidance scale as defined in Classifier-Free Diffusion Guidance. + skip_layer_guidance_scale (`None`, *optional*, defaults to 2.8): + The scale of the guidance for the skipped layers. + skip_layer_guidance_stop (`None`, *optional*, defaults to 0.2): + The step fraction at which the guidance for skipped layers stops. + skip_layer_guidance_start (`None`, *optional*, defaults to 0.01): + The step fraction at which the guidance for skipped layers starts. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "stable-diffusion-3" + block_classes = [ + StableDiffusion3TextInputStep(), + StableDiffusion3PrepareLatentsStep(), + StableDiffusion3SetTimestepsStep(), + StableDiffusion3DenoiseStep(), + ] + block_names = ["text_inputs", "prepare_latents", "set_timesteps", "denoise"] + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# auto_docstring +class StableDiffusion3I2ICoreDenoiseStep(SequentialPipelineBlocks): + """ + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) guider (`ClassifierFreeGuidance`) transformer + (`SD3Transformer2DModel`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. + pooled_prompt_embeds (`Tensor`): + Pre-generated pooled text embeddings. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative pooled text embeddings. + height (`None`, *optional*): + The height in pixels of the generated image. + width (`None`, *optional*): + The width in pixels of the generated image. + image_latents (`None`, *optional*): + Latent input image_latents to be processed. + latents (`Tensor | NoneType`, *optional*): + Pre-generated noisy latents to be used as inputs for image generation. + generator (`None`, *optional*): + One or a list of torch generator(s) to make generation deterministic. + num_inference_steps (`None`, *optional*, defaults to 50): + The number of denoising steps. + timesteps (`None`, *optional*): + Custom timesteps to use for the denoising process. + sigmas (`None`, *optional*): + Custom sigmas to use for the denoising process. + strength (`None`, *optional*, defaults to 0.6): + Indicates extent to transform the reference image. + mu (`float`, *optional*): + The mu value used for dynamic shifting. If not provided, it is dynamically calculated. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary passed along to the AttentionProcessor. + guidance_scale (`None`, *optional*, defaults to 7.0): + Guidance scale as defined in Classifier-Free Diffusion Guidance. + skip_layer_guidance_scale (`None`, *optional*, defaults to 2.8): + The scale of the guidance for the skipped layers. + skip_layer_guidance_stop (`None`, *optional*, defaults to 0.2): + The step fraction at which the guidance for skipped layers stops. + skip_layer_guidance_start (`None`, *optional*, defaults to 0.01): + The step fraction at which the guidance for skipped layers starts. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "stable-diffusion-3" + block_classes = [ + StableDiffusion3TextInputStep(), + StableDiffusion3AdditionalInputsStep(), + StableDiffusion3PrepareLatentsStep(), + StableDiffusion3Img2ImgSetTimestepsStep(), + StableDiffusion3Img2ImgPrepareLatentsStep(), + StableDiffusion3DenoiseStep(), + ] + block_names = [ + "text_inputs", + "additional_inputs", + "prepare_latents", + "set_timesteps", + "prepare_img2img_latents", + "denoise", + ] + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# auto_docstring +class StableDiffusion3AutoCoreDenoiseStep(AutoPipelineBlocks): + """ + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) guider (`ClassifierFreeGuidance`) transformer + (`SD3Transformer2DModel`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. + pooled_prompt_embeds (`Tensor`): + Pre-generated pooled text embeddings. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative pooled text embeddings. + height (`None`, *optional*): + The height in pixels of the generated image. + width (`None`, *optional*): + The width in pixels of the generated image. + image_latents (`None`, *optional*): + Latent input image_latents to be processed. + latents (`Tensor | NoneType`): + Pre-generated noisy latents to be used as inputs for image generation. + generator (`None`, *optional*): + One or a list of torch generator(s) to make generation deterministic. + num_inference_steps (`None`): + The number of denoising steps. + timesteps (`None`): + Custom timesteps to use for the denoising process. + sigmas (`None`, *optional*): + Custom sigmas to use for the denoising process. + strength (`None`, *optional*, defaults to 0.6): + Indicates extent to transform the reference image. + mu (`float`, *optional*): + The mu value used for dynamic shifting. If not provided, it is dynamically calculated. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary passed along to the AttentionProcessor. + guidance_scale (`None`, *optional*, defaults to 7.0): + Guidance scale as defined in Classifier-Free Diffusion Guidance. + skip_layer_guidance_scale (`None`, *optional*, defaults to 2.8): + The scale of the guidance for the skipped layers. + skip_layer_guidance_stop (`None`, *optional*, defaults to 0.2): + The step fraction at which the guidance for skipped layers stops. + skip_layer_guidance_start (`None`, *optional*, defaults to 0.01): + The step fraction at which the guidance for skipped layers starts. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "stable-diffusion-3" + block_classes = [ + StableDiffusion3I2ICoreDenoiseStep, + StableDiffusion3T2ICoreDenoiseStep, + ] + block_names = ["img2img", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", StableDiffusion3TextEncoderStep()), + ("vae_encoder", StableDiffusion3AutoVaeEncoderStep()), + ("denoise", StableDiffusion3AutoCoreDenoiseStep()), + ("decode", StableDiffusion3DecodeStep()), + ] +) + + +# auto_docstring +class StableDiffusion3AutoBlocks(SequentialPipelineBlocks): + """ + Supported workflows: + - `text2image`: requires `prompt` + - `image2image`: requires `image`, `prompt` + + Components: + text_encoder (`CLIPTextModelWithProjection`) tokenizer (`CLIPTokenizer`) text_encoder_2 + (`CLIPTextModelWithProjection`) tokenizer_2 (`CLIPTokenizer`) text_encoder_3 (`T5EncoderModel`) tokenizer_3 + (`T5TokenizerFast`) image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) scheduler + (`FlowMatchEulerDiscreteScheduler`) guider (`ClassifierFreeGuidance`) transformer (`SD3Transformer2DModel`) + + Inputs: + prompt (`None`, *optional*): + The prompt or prompts to guide the image generation. + prompt_2 (`None`, *optional*): + The prompt or prompts to be sent to tokenizer_2 and text_encoder_2. + prompt_3 (`None`, *optional*): + The prompt or prompts to be sent to tokenizer_3 and text_encoder_3. + negative_prompt (`None`, *optional*): + The prompt or prompts not to guide the image generation. + negative_prompt_2 (`None`, *optional*): + The prompt or prompts not to guide the image generation for tokenizer_2. + negative_prompt_3 (`None`, *optional*): + The prompt or prompts not to guide the image generation for tokenizer_3. + prompt_embeds (`Tensor`, *optional*): + Pre-generated text embeddings. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. + negative_pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative pooled text embeddings. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. + max_sequence_length (`int`, *optional*, defaults to 256): + Maximum sequence length to use with the prompt. + joint_attention_kwargs (`None`, *optional*): + A kwargs dictionary passed along to the AttentionProcessor. + resized_image (`None`, *optional*): + The pre-resized image input. + image (`None`, *optional*): + The input image to be used as the starting point for the image-to-image process. + height (`None`, *optional*): + The height in pixels of the generated image. + width (`None`, *optional*): + The width in pixels of the generated image. + generator (`None`, *optional*): + One or a list of torch generator(s) to make generation deterministic. + num_images_per_prompt (`None`, *optional*, defaults to 1): + The number of images to generate per prompt. + image_latents (`None`, *optional*): + Latent input image_latents to be processed. + latents (`Tensor | NoneType`): + Pre-generated noisy latents to be used as inputs for image generation. + num_inference_steps (`None`): + The number of denoising steps. + timesteps (`None`): + Custom timesteps to use for the denoising process. + sigmas (`None`, *optional*): + Custom sigmas to use for the denoising process. + strength (`None`, *optional*, defaults to 0.6): + Indicates extent to transform the reference image. + mu (`float`, *optional*): + The mu value used for dynamic shifting. If not provided, it is dynamically calculated. + guidance_scale (`None`, *optional*, defaults to 7.0): + Guidance scale as defined in Classifier-Free Diffusion Guidance. + skip_layer_guidance_scale (`None`, *optional*, defaults to 2.8): + The scale of the guidance for the skipped layers. + skip_layer_guidance_stop (`None`, *optional*, defaults to 0.2): + The step fraction at which the guidance for skipped layers stops. + skip_layer_guidance_start (`None`, *optional*, defaults to 0.01): + The step fraction at which the guidance for skipped layers starts. + output_type (`None`, *optional*, defaults to pil): + The output format of the generated image (e.g., 'pil', 'pt', 'np'). + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "stable-diffusion-3" + block_classes = AUTO_BLOCKS.values() + block_names = AUTO_BLOCKS.keys() + + _workflow_map = { + "text2image": {"prompt": True}, + "image2image": {"image": True, "prompt": True}, + } + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py new file mode 100644 index 000000000000..645ad930b426 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py @@ -0,0 +1,70 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + +logger = logging.get_logger(__name__) + + +class StableDiffusion3ModularPipeline( + ModularPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin +): + """ + A ModularPipeline for Stable Diffusion 3. + + >[!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "StableDiffusion3AutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + if getattr(self, "transformer", None) is not None: + return self.transformer.config.sample_size + return 128 + + @property + def patch_size(self): + if getattr(self, "transformer", None) is not None: + return self.transformer.config.patch_size + return 2 + + @property + def tokenizer_max_length(self): + if getattr(self, "tokenizer", None) is not None: + return self.tokenizer.model_max_length + return 77 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if getattr(self, "vae", None) is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + if getattr(self, "transformer", None) is not None: + return self.transformer.config.in_channels + return 16 diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index eff798a59051..55edede82b7b 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -392,6 +392,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusion3AutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusion3ModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Wan22Blocks(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/modular_pipelines/stable_diffusion_3/__init__.py b/tests/modular_pipelines/stable_diffusion_3/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py new file mode 100644 index 000000000000..5519303af592 --- /dev/null +++ b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py @@ -0,0 +1,182 @@ +# coding=utf-8 +# Copyright 2026 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import numpy as np +import PIL +import torch + +from diffusers.image_processor import VaeImageProcessor +from diffusers.modular_pipelines.stable_diffusion_3 import ( + StableDiffusion3AutoBlocks, + StableDiffusion3ModularPipeline, +) + +from ...testing_utils import floats_tensor, torch_device +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + +SD3_TEXT2IMAGE_WORKFLOWS = { + "text2image": [ + ("text_encoder", "StableDiffusion3TextEncoderStep"), + ("denoise.text_inputs", "StableDiffusion3TextInputStep"), + ("denoise.prepare_latents", "StableDiffusion3PrepareLatentsStep"), + ("denoise.set_timesteps", "StableDiffusion3SetTimestepsStep"), + ("denoise.denoise", "StableDiffusion3DenoiseStep"), + ("decode", "StableDiffusion3DecodeStep"), + ] +} + + +class TestStableDiffusion3ModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = StableDiffusion3ModularPipeline + pipeline_blocks_class = StableDiffusion3AutoBlocks + pretrained_model_name_or_path = "AlanPonnachan/tiny-sd3-modular" + + params = frozenset(["prompt", "height", "width", "guidance_scale"]) + batch_params = frozenset(["prompt"]) + expected_workflow_blocks = SD3_TEXT2IMAGE_WORKFLOWS + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + return { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 32, + "width": 32, + "max_sequence_length": 48, + "output_type": "pt", + } + + def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): + return super().get_pipeline(components_manager, torch_dtype) + + def test_save_from_pretrained(self, tmp_path): + pipes = [] + base_pipe = self.get_pipeline().to(torch_device) + pipes.append(base_pipe) + + base_pipe.save_pretrained(str(tmp_path)) + pipe = self.pipeline_class.from_pretrained(tmp_path).to(torch_device) + pipe.load_components(torch_dtype=torch.float32) + pipe.to(torch_device) + pipes.append(pipe) + + image_slices = [] + for p in pipes: + inputs = self.get_dummy_inputs() + image = p(**inputs, output="images") + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + + def test_load_expected_components_from_save_pretrained(self, tmp_path): + base_pipe = self.get_pipeline() + base_pipe.save_pretrained(str(tmp_path)) + + pipe = self.pipeline_class.from_pretrained(tmp_path) + pipe.load_components(torch_dtype=torch.float32) + + assert set(base_pipe.components.keys()) == set(pipe.components.keys()) + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + +SD3_IMAGE2IMAGE_WORKFLOWS = { + "image2image": [ + ("text_encoder", "StableDiffusion3TextEncoderStep"), + ("vae_encoder.preprocess", "StableDiffusion3ProcessImagesInputStep"), + ("vae_encoder.encode", "StableDiffusion3VaeEncoderStep"), + ("denoise.text_inputs", "StableDiffusion3TextInputStep"), + ("denoise.additional_inputs", "StableDiffusion3AdditionalInputsStep"), + ("denoise.prepare_latents", "StableDiffusion3PrepareLatentsStep"), + ("denoise.set_timesteps", "StableDiffusion3Img2ImgSetTimestepsStep"), + ( + "denoise.prepare_img2img_latents", + "StableDiffusion3Img2ImgPrepareLatentsStep", + ), + ("denoise.denoise", "StableDiffusion3DenoiseStep"), + ("decode", "StableDiffusion3DecodeStep"), + ] +} + + +class TestStableDiffusion3Img2ImgModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = StableDiffusion3ModularPipeline + pipeline_blocks_class = StableDiffusion3AutoBlocks + pretrained_model_name_or_path = "AlanPonnachan/tiny-sd3-modular" + + params = frozenset(["prompt", "height", "width", "guidance_scale", "image"]) + batch_params = frozenset(["prompt", "image"]) + expected_workflow_blocks = SD3_IMAGE2IMAGE_WORKFLOWS + + def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): + pipeline = super().get_pipeline(components_manager, torch_dtype) + pipeline.image_processor = VaeImageProcessor(vae_scale_factor=8) + return pipeline + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 4, + "guidance_scale": 5.0, + "height": 32, + "width": 32, + "max_sequence_length": 48, + "output_type": "pt", + } + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(torch_device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = PIL.Image.fromarray(np.uint8(image)).convert("RGB") + inputs["image"] = init_image + inputs["strength"] = 0.5 + return inputs + + def test_save_from_pretrained(self, tmp_path): + pipes = [] + base_pipe = self.get_pipeline().to(torch_device) + pipes.append(base_pipe) + + base_pipe.save_pretrained(str(tmp_path)) + pipe = self.pipeline_class.from_pretrained(tmp_path).to(torch_device) + pipe.load_components(torch_dtype=torch.float32) + pipe.to(torch_device) + pipe.image_processor = VaeImageProcessor(vae_scale_factor=8) + pipes.append(pipe) + + image_slices = [] + for p in pipes: + inputs = self.get_dummy_inputs() + image = p(**inputs, output="images") + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + + def test_load_expected_components_from_save_pretrained(self, tmp_path): + base_pipe = self.get_pipeline() + base_pipe.save_pretrained(str(tmp_path)) + + pipe = self.pipeline_class.from_pretrained(tmp_path) + pipe.load_components(torch_dtype=torch.float32) + + assert set(base_pipe.components.keys()) == set(pipe.components.keys()) + + def test_float16_inference(self): + super().test_float16_inference(9e-2)