Skip to content

Commit 3e12970

Browse files
committed
update pipeline, remove true_cfg_scale etc
1 parent 1792aab commit 3e12970

2 files changed

Lines changed: 66 additions & 136 deletions

File tree

src/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage.py

Lines changed: 33 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,7 @@ class HunyuanImagePipeline(DiffusionPipeline):
182182

183183
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
184184
_callback_tensor_inputs = ["latents", "prompt_embeds"]
185-
_optional_components = ["ocr_guider"]
186-
_guider_input_fields = {
187-
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
188-
"encoder_attention_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
189-
"encoder_hidden_states_2": ("prompt_embeds_2", "negative_prompt_embeds_2"),
190-
"encoder_attention_mask_2": ("prompt_embeds_mask_2", "negative_prompt_embeds_mask_2"),
191-
}
185+
_optional_components = ["ocr_guider", "guider"]
192186

193187
def __init__(
194188
self,
@@ -199,7 +193,7 @@ def __init__(
199193
text_encoder_2: T5EncoderModel,
200194
tokenizer_2: ByT5Tokenizer,
201195
transformer: HunyuanImageTransformer2DModel,
202-
guider: AdaptiveProjectedMixGuidance,
196+
guider: Optional[AdaptiveProjectedMixGuidance] = None,
203197
ocr_guider: Optional[AdaptiveProjectedMixGuidance] = None,
204198
):
205199
super().__init__()
@@ -509,8 +503,7 @@ def __call__(
509503
height: Optional[int] = None,
510504
width: Optional[int] = None,
511505
num_inference_steps: int = 50,
512-
true_cfg_scale: Optional[float] = None,
513-
guidance_scale: Optional[float] = None,
506+
distilled_guidance_scale: Optional[float] = 3.25,
514507
sigmas: Optional[List[float]] = None,
515508
num_images_per_prompt: int = 1,
516509
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -538,11 +531,7 @@ def __call__(
538531
instead.
539532
negative_prompt (`str` or `List[str]`, *optional*):
540533
The prompt or prompts not to guide the image generation. If not defined and negative_prompt_embeds is
541-
not provided, will use an empty negative prompt. Ignored when not using guidance (i.e., ignored if any
542-
of the following conditions are met:
543-
1. guider is diabled
544-
2. guider.guidance_scale is not greater than `1` and `true_cfg_scale` is not provided,
545-
3. `true_cfg_scale` is not greater than `1`.
534+
not provided, will use an empty negative prompt. Ignored when not using guidance.
546535
).
547536
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
548537
The height in pixels of the generated image. This is set to 1024 by default for the best results.
@@ -555,20 +544,14 @@ def __call__(
555544
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
556545
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
557546
will be used.
558-
true_cfg_scale (`float`, *optional*, defaults to None):
559-
Guidance scale as defined in [Classifier-Free Diffusion
560-
Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2.
561-
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by
562-
setting `true_cfg_scale > 1`. Higher guidance scale encourages to generate images that are closely
563-
linked to the text `prompt`, usually at the expense of lower image quality. If not defined, the default
564-
`guidance_scale` configured in guider will be used.
565-
guidance_scale (`float`, *optional*, defaults to None):
547+
distilled_guidance_scale (`float`, *optional*, defaults to None):
566548
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
567549
where the guidance scale is applied during inference through noise prediction rescaling, guidance
568550
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
569-
is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that
570-
are closely linked to the text `prompt`, usually at the expense of lower image quality. If not defined,
571-
the default `distilled_guidance_scale` configured in guider will be used.
551+
is enabled by setting `distilled_guidance_scale > 1`. Higher guidance scale encourages to generate images that
552+
are closely linked to the text `prompt`, usually at the expense of lower image quality.
553+
For guidance distilled models, this parameter is required.
554+
For non-distilled models, this parameter will be ignored.
572555
num_images_per_prompt (`int`, *optional*, defaults to 1):
573556
The number of images to generate per prompt.
574557
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -683,30 +666,17 @@ def __call__(
683666
prompt_embeds = prompt_embeds.to(self.transformer.dtype)
684667
prompt_embeds_2 = prompt_embeds_2.to(self.transformer.dtype)
685668

686-
# select guider based on if prompt contains OCR or not
669+
# select guider
687670
if not torch.all(prompt_embeds_2 == 0) and self.ocr_guider is not None:
688671
# prompt contains ocr and pipeline has a guider for ocr
689672
guider = self.ocr_guider
690-
else:
673+
elif self.guider is not None:
691674
guider = self.guider
692-
693-
is_guider_enabled = guider._enabled
694-
695-
# if true_cfg_scale/guidance_scale is provided, override the guidance_scale/distilled_guidance_scale in guider at runtime
696-
guider_kwargs = {}
697-
if true_cfg_scale is not None:
698-
guider_kwargs["guidance_scale"] = true_cfg_scale
699-
if guidance_scale is not None:
700-
guider_kwargs["distilled_guidance_scale"] = guidance_scale
701-
guider = guider.new(**guider_kwargs)
702-
703-
if is_guider_enabled:
704-
guider.enable()
675+
# distilled model does not use guidance method, use default guider with enabled=False
705676
else:
706-
guider.disable()
677+
guider = AdaptiveProjectedMixGuidance(enabled=False)
707678

708-
requires_unconditional_embeds = guider._enabled and guider.num_conditions > 1
709-
if requires_unconditional_embeds:
679+
if guider._enabled and guider.num_conditions > 1:
710680
(
711681
negative_prompt_embeds,
712682
negative_prompt_embeds_mask,
@@ -746,23 +716,13 @@ def __call__(
746716
self._num_timesteps = len(timesteps)
747717

748718
# handle guidance (for guidance-distilled model)
749-
if self.transformer.config.guidance_embeds and not (
750-
hasattr(guider, "distilled_guidance_scale") and guider.distilled_guidance_scale is not None
751-
):
752-
raise ValueError("`guidance_scale` is required for guidance-distilled model.")
753-
elif (
754-
not self.transformer.config.guidance_embeds
755-
and hasattr(guider, "distilled_guidance_scale")
756-
and guider.distilled_guidance_scale is not None
757-
):
758-
logger.warning(
759-
f"`distilled_guidance_scale` {guider.distilled_guidance_scale} is ignored since the model is not guidance-distilled. Please use `true_cfg_scale` instead."
760-
)
719+
if self.transformer.config.guidance_embeds and distilled_guidance_scale is None:
720+
raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.")
761721

762722
if self.transformer.config.guidance_embeds:
763723
guidance = (
764724
torch.tensor(
765-
[guider.distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device
725+
[distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device
766726
)
767727
* 1000.0
768728
)
@@ -794,43 +754,39 @@ def __call__(
794754
timestep_r = None
795755

796756
# Step 1: Collect model inputs needed for the guidance method
797-
# The `_guider_input_fields` defines which inputs model needs for conditional/unconditional predictions.
798-
# e.g. {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}
799-
# means both prompt_embeds (conditional) and negative_prompt_embeds (unconditional) as inputs.
800-
guider_inputs = {}
801-
for _, input_names_tuple in self._guider_input_fields.items():
802-
for input_name in input_names_tuple:
803-
guider_inputs[input_name] = locals()[input_name]
757+
# conditional inputs should always be first element in the tuple
758+
guider_inputs = {
759+
"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
760+
"encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
761+
"encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2),
762+
"encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2),
763+
}
804764

805765
# Step 2: Update guider's internal state for this denoising step
806766
guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
807767

808768
# Step 3: Prepare batched model inputs based on the guidance method
809769
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
810-
# For CFG with _guider_input_fields = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}:
770+
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
771+
# you will get a guider_state with two batches:
811772
# guider_state = [
812773
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
813774
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
814775
# ]
815776
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
816-
guider_state = guider.prepare_inputs(guider_inputs, self._guider_input_fields)
777+
guider_state = guider.prepare_inputs(guider_inputs)
817778
# Step 4: Run the denoiser for each batch
818-
# Each batch represents a different conditioning (conditional, unconditional, etc.).
819-
# We run the model once per batch and store the noise prediction in each batch dict.
820-
# After this loop, continuing the CFG example:
821-
# guider_state = [
822-
# {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"},
823-
# {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"},
824-
# ]
779+
# Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
780+
# We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
825781
for guider_state_batch in guider_state:
826782
guider.prepare_models(self.transformer)
827783

828784
# Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
829-
cond_kwargs = guider_state_batch.as_dict()
830-
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in self._guider_input_fields}
785+
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
831786

832-
guider_state_batch_identifier = getattr(guider_state_batch, guider._identifier_key)
833-
with self.transformer.cache_context(guider_state_batch_identifier):
787+
# e.g. "pred_cond"/"pred_uncond"
788+
context_name = getattr(guider_state_batch, guider._identifier_key)
789+
with self.transformer.cache_context(context_name):
834790
# Run denoiser and store noise prediction in this batch
835791
guider_state_batch.noise_pred = self.transformer(
836792
hidden_states=latents,

0 commit comments

Comments
 (0)