@@ -182,13 +182,7 @@ class HunyuanImagePipeline(DiffusionPipeline):
182182
183183 model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
184184 _callback_tensor_inputs = ["latents" , "prompt_embeds" ]
185- _optional_components = ["ocr_guider" ]
186- _guider_input_fields = {
187- "encoder_hidden_states" : ("prompt_embeds" , "negative_prompt_embeds" ),
188- "encoder_attention_mask" : ("prompt_embeds_mask" , "negative_prompt_embeds_mask" ),
189- "encoder_hidden_states_2" : ("prompt_embeds_2" , "negative_prompt_embeds_2" ),
190- "encoder_attention_mask_2" : ("prompt_embeds_mask_2" , "negative_prompt_embeds_mask_2" ),
191- }
185+ _optional_components = ["ocr_guider" , "guider" ]
192186
193187 def __init__ (
194188 self ,
@@ -199,7 +193,7 @@ def __init__(
199193 text_encoder_2 : T5EncoderModel ,
200194 tokenizer_2 : ByT5Tokenizer ,
201195 transformer : HunyuanImageTransformer2DModel ,
202- guider : AdaptiveProjectedMixGuidance ,
196+ guider : Optional [ AdaptiveProjectedMixGuidance ] = None ,
203197 ocr_guider : Optional [AdaptiveProjectedMixGuidance ] = None ,
204198 ):
205199 super ().__init__ ()
@@ -509,8 +503,7 @@ def __call__(
509503 height : Optional [int ] = None ,
510504 width : Optional [int ] = None ,
511505 num_inference_steps : int = 50 ,
512- true_cfg_scale : Optional [float ] = None ,
513- guidance_scale : Optional [float ] = None ,
506+ distilled_guidance_scale : Optional [float ] = 3.25 ,
514507 sigmas : Optional [List [float ]] = None ,
515508 num_images_per_prompt : int = 1 ,
516509 generator : Optional [Union [torch .Generator , List [torch .Generator ]]] = None ,
@@ -538,11 +531,7 @@ def __call__(
538531 instead.
539532 negative_prompt (`str` or `List[str]`, *optional*):
540533 The prompt or prompts not to guide the image generation. If not defined and negative_prompt_embeds is
541- not provided, will use an empty negative prompt. Ignored when not using guidance (i.e., ignored if any
542- of the following conditions are met:
543- 1. guider is diabled
544- 2. guider.guidance_scale is not greater than `1` and `true_cfg_scale` is not provided,
545- 3. `true_cfg_scale` is not greater than `1`.
534+ not provided, will use an empty negative prompt. Ignored when not using guidance (i.e., when the
546535 selected guider is disabled or does not need more than one condition).
547536 height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
548537 The height in pixels of the generated image. This is set to 1024 by default for the best results.
@@ -555,20 +544,14 @@ def __call__(
555544 Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
556545 their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
557546 will be used.
558- true_cfg_scale (`float`, *optional*, defaults to None):
559- Guidance scale as defined in [Classifier-Free Diffusion
560- Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2.
561- of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by
562- setting `true_cfg_scale > 1`. Higher guidance scale encourages to generate images that are closely
563- linked to the text `prompt`, usually at the expense of lower image quality. If not defined, the default
564- `guidance_scale` configured in guider will be used.
565- guidance_scale (`float`, *optional*, defaults to None):
547+ distilled_guidance_scale (`float`, *optional*, defaults to 3.25):
566548 A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
567549 where the guidance scale is applied during inference through noise prediction rescaling, guidance
568550 distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
569- is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that
570- are closely linked to the text `prompt`, usually at the expense of lower image quality. If not defined,
571- the default `distilled_guidance_scale` configured in guider will be used.
551+ is enabled by setting `distilled_guidance_scale > 1`. A higher guidance scale encourages the model to generate images that
552+ are closely linked to the text `prompt`, usually at the expense of lower image quality.
553+ For guidance distilled models, this parameter is required.
554+ For non-distilled models, this parameter will be ignored.
572555 num_images_per_prompt (`int`, *optional*, defaults to 1):
573556 The number of images to generate per prompt.
574557 generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -683,30 +666,17 @@ def __call__(
683666 prompt_embeds = prompt_embeds .to (self .transformer .dtype )
684667 prompt_embeds_2 = prompt_embeds_2 .to (self .transformer .dtype )
685668
686- # select guider based on if prompt contains OCR or not
669+ # select guider
687670 if not torch .all (prompt_embeds_2 == 0 ) and self .ocr_guider is not None :
688671 # prompt contains ocr and pipeline has a guider for ocr
689672 guider = self .ocr_guider
690- else :
673+ elif self . guider is not None :
691674 guider = self .guider
692-
693- is_guider_enabled = guider ._enabled
694-
695- # if true_cfg_scale/guidance_scale is provided, override the guidance_scale/distilled_guidance_scale in guider at runtime
696- guider_kwargs = {}
697- if true_cfg_scale is not None :
698- guider_kwargs ["guidance_scale" ] = true_cfg_scale
699- if guidance_scale is not None :
700- guider_kwargs ["distilled_guidance_scale" ] = guidance_scale
701- guider = guider .new (** guider_kwargs )
702-
703- if is_guider_enabled :
704- guider .enable ()
675+ # distilled model does not use guidance method, use default guider with enabled=False
705676 else :
706- guider . disable ( )
677+ guider = AdaptiveProjectedMixGuidance ( enabled = False )
707678
708- requires_unconditional_embeds = guider ._enabled and guider .num_conditions > 1
709- if requires_unconditional_embeds :
679+ if guider ._enabled and guider .num_conditions > 1 :
710680 (
711681 negative_prompt_embeds ,
712682 negative_prompt_embeds_mask ,
@@ -746,23 +716,13 @@ def __call__(
746716 self ._num_timesteps = len (timesteps )
747717
748718 # handle guidance (for guidance-distilled model)
749- if self .transformer .config .guidance_embeds and not (
750- hasattr (guider , "distilled_guidance_scale" ) and guider .distilled_guidance_scale is not None
751- ):
752- raise ValueError ("`guidance_scale` is required for guidance-distilled model." )
753- elif (
754- not self .transformer .config .guidance_embeds
755- and hasattr (guider , "distilled_guidance_scale" )
756- and guider .distilled_guidance_scale is not None
757- ):
758- logger .warning (
759- f"`distilled_guidance_scale` { guider .distilled_guidance_scale } is ignored since the model is not guidance-distilled. Please use `true_cfg_scale` instead."
760- )
719+ if self .transformer .config .guidance_embeds and distilled_guidance_scale is None :
720+ raise ValueError ("`distilled_guidance_scale` is required for guidance-distilled model." )
761721
762722 if self .transformer .config .guidance_embeds :
763723 guidance = (
764724 torch .tensor (
765- [guider . distilled_guidance_scale ] * latents .shape [0 ], dtype = self .transformer .dtype , device = device
725+ [distilled_guidance_scale ] * latents .shape [0 ], dtype = self .transformer .dtype , device = device
766726 )
767727 * 1000.0
768728 )
@@ -794,43 +754,39 @@ def __call__(
794754 timestep_r = None
795755
796756 # Step 1: Collect model inputs needed for the guidance method
797- # The `_guider_input_fields` defines which inputs model needs for conditional/unconditional predictions.
798- # e.g. {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}
799- # means both prompt_embeds (conditional) and negative_prompt_embeds (unconditional) as inputs.
800- guider_inputs = {}
801- for _ , input_names_tuple in self . _guider_input_fields . items ():
802- for input_name in input_names_tuple :
803- guider_inputs [ input_name ] = locals ()[ input_name ]
757+ # conditional inputs should always be first element in the tuple
758+ guider_inputs = {
759+ "encoder_hidden_states" : ( prompt_embeds , negative_prompt_embeds ),
760+ "encoder_attention_mask" : ( prompt_embeds_mask , negative_prompt_embeds_mask ),
761+ "encoder_hidden_states_2" : ( prompt_embeds_2 , negative_prompt_embeds_2 ),
762+ "encoder_attention_mask_2" : ( prompt_embeds_mask_2 , negative_prompt_embeds_mask_2 ),
763+ }
804764
805765 # Step 2: Update guider's internal state for this denoising step
806766 guider .set_state (step = i , num_inference_steps = num_inference_steps , timestep = t )
807767
808768 # Step 3: Prepare batched model inputs based on the guidance method
809769 # The guider splits model inputs into separate batches for conditional/unconditional predictions.
810- # For CFG with _guider_input_fields = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}:
770+ # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
771+ # you will get a guider_state with two batches:
811772 # guider_state = [
812773 # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
813774 # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
814775 # ]
815776 # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
816- guider_state = guider .prepare_inputs (guider_inputs , self . _guider_input_fields )
777+ guider_state = guider .prepare_inputs (guider_inputs )
817778 # Step 4: Run the denoiser for each batch
818- # Each batch represents a different conditioning (conditional, unconditional, etc.).
819- # We run the model once per batch and store the noise prediction in each batch dict.
820- # After this loop, continuing the CFG example:
821- # guider_state = [
822- # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"},
823- # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"},
824- # ]
779+ # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
780+ # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
825781 for guider_state_batch in guider_state :
826782 guider .prepare_models (self .transformer )
827783
828784 # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
829- cond_kwargs = guider_state_batch .as_dict ()
830- cond_kwargs = {k : v for k , v in cond_kwargs .items () if k in self ._guider_input_fields }
785+ cond_kwargs = {input_name : getattr (guider_state_batch , input_name ) for input_name in guider_inputs .keys ()}
831786
832- guider_state_batch_identifier = getattr (guider_state_batch , guider ._identifier_key )
833- with self .transformer .cache_context (guider_state_batch_identifier ):
787+ # e.g. "pred_cond"/"pred_uncond"
788+ context_name = getattr (guider_state_batch , guider ._identifier_key )
789+ with self .transformer .cache_context (context_name ):
834790 # Run denoiser and store noise prediction in this batch
835791 guider_state_batch .noise_pred = self .transformer (
836792 hidden_states = latents ,
0 commit comments