@@ -52,20 +52,10 @@ class ClassifierFreeGuidance(BaseGuidance):
5252
5353 Use `use_original_formulation=True` to switch to the original formulation.
5454
55- **Guidance-Distilled Models:**
56-
57- For models with distilled guidance (guidance baked into the model via distillation), set `distilled_guidance_scale`
58- to the desired guidance value. The pipeline will pass this to the model during forward passes. Set to `None` for
59- regular (non-distilled) models.
60-
6155 Args:
6256 guidance_scale (`float`, defaults to `7.5`):
6357 CFG scale applied by this guider during post-processing. Higher values = stronger prompt conditioning but
6458 may reduce quality. Typical range: 1.0-20.0.
65- distilled_guidance_scale (`float`, *optional*, defaults to `None`):
66- Guidance scale for distilled models, passed directly to the model during forward pass. If `None`, assumes a
67- regular (non-distilled) model. Allows pipelines to configure different defaults for distilled vs.
68- non-distilled models. Typical range for distilled models: 1.0-8.0.
6959 guidance_rescale (`float`, defaults to `0.0`):
7060 Rescaling factor to prevent overexposure from high guidance scales. Based on [Common Diffusion Noise
7161 Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Range: 0.0 (no rescaling)
@@ -89,7 +79,6 @@ class ClassifierFreeGuidance(BaseGuidance):
8979 def __init__ (
9080 self ,
9181 guidance_scale : float = 7.5 ,
92- distilled_guidance_scale : Optional [float ] = None ,
9382 guidance_rescale : float = 0.0 ,
9483 use_original_formulation : bool = False ,
9584 start : float = 0.0 ,
@@ -99,20 +88,15 @@ def __init__(
9988 super ().__init__ (start , stop , enabled )
10089
10190 self .guidance_scale = guidance_scale
102- self .distilled_guidance_scale = distilled_guidance_scale
10391 self .guidance_rescale = guidance_rescale
10492 self .use_original_formulation = use_original_formulation
10593
106- def prepare_inputs (
107- self , data : "BlockState" , input_fields : Optional [Dict [str , Union [str , Tuple [str , str ]]]] = None
108- ) -> List ["BlockState" ]:
109- if input_fields is None :
110- input_fields = self ._input_fields
94+ def prepare_inputs (self , data : Dict [str , Tuple [torch .Tensor , torch .Tensor ]]) -> List ["BlockState" ]:
11195
11296 tuple_indices = [0 ] if self .num_conditions == 1 or not self ._is_cfg_enabled () else [0 , 1 ]
11397 data_batches = []
11498 for tuple_idx , input_prediction in zip (tuple_indices , self ._input_predictions ):
115- data_batch = self ._prepare_batch (input_fields , data , tuple_idx , input_prediction )
99+ data_batch = self ._prepare_batch (data , tuple_idx , input_prediction )
116100 data_batches .append (data_batch )
117101 return data_batches
118102
0 commit comments