Skip to content

Commit 1792aab

Browse files
committed
update guider: remove distilled guidannce scale, simplify prepare_inputs
1 parent ec3290d commit 1792aab

11 files changed

Lines changed: 27 additions & 96 deletions

src/diffusers/guiders/adaptive_projected_guidance.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,19 +77,14 @@ def __init__(
7777
self.use_original_formulation = use_original_formulation
7878
self.momentum_buffer = None
7979

80-
def prepare_inputs(
81-
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
82-
) -> List["BlockState"]:
83-
if input_fields is None:
84-
input_fields = self._input_fields
85-
80+
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
8681
if self._step == 0:
8782
if self.adaptive_projected_guidance_momentum is not None:
8883
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
8984
tuple_indices = [0] if self.num_conditions == 1 or not self._is_apg_enabled() else [0, 1]
9085
data_batches = []
9186
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
92-
data_batch = self._prepare_batch(input_fields, data, tuple_idx, input_prediction)
87+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
9388
data_batches.append(data_batch)
9489
return data_batches
9590

src/diffusers/guiders/adaptive_projected_guidance_mix.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def __init__(
7575
stop: float = 1.0,
7676
adaptive_projected_guidance_start_step: int = 5,
7777
enabled: bool = True,
78-
distilled_guidance_scale: Optional[float] = None,
7978
):
8079
super().__init__(start, stop, enabled)
8180

@@ -88,13 +87,9 @@ def __init__(
8887
self.adaptive_projected_guidance_start_step = adaptive_projected_guidance_start_step
8988
self.use_original_formulation = use_original_formulation
9089
self.momentum_buffer = None
91-
self.distilled_guidance_scale = distilled_guidance_scale
9290

9391
def prepare_inputs(
94-
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
95-
) -> List["BlockState"]:
96-
if input_fields is None:
97-
input_fields = self._input_fields
92+
self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
9893

9994
if self._step == 0:
10095
if self.adaptive_projected_guidance_momentum is not None:
@@ -104,7 +99,7 @@ def prepare_inputs(
10499
)
105100
data_batches = []
106101
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
107-
data_batch = self._prepare_batch(input_fields, data, tuple_idx, input_prediction)
102+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
108103
data_batches.append(data_batch)
109104
return data_batches
110105

src/diffusers/guiders/auto_guidance.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -133,16 +133,11 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
133133
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
134134
registry.remove_hook(name, recurse=True)
135135

136-
def prepare_inputs(
137-
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
138-
) -> List["BlockState"]:
139-
if input_fields is None:
140-
input_fields = self._input_fields
141-
136+
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
142137
tuple_indices = [0] if self.num_conditions == 1 or not self._is_ag_enabled() else [0, 1]
143138
data_batches = []
144139
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
145-
data_batch = self._prepare_batch(input_fields, data, tuple_idx, input_prediction)
140+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
146141
data_batches.append(data_batch)
147142
return data_batches
148143

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -52,20 +52,10 @@ class ClassifierFreeGuidance(BaseGuidance):
5252
5353
Use `use_original_formulation=True` to switch to the original formulation.
5454
55-
**Guidance-Distilled Models:**
56-
57-
For models with distilled guidance (guidance baked into the model via distillation), set `distilled_guidance_scale`
58-
to the desired guidance value. The pipeline will pass this to the model during forward passes. Set to `None` for
59-
regular (non-distilled) models.
60-
6155
Args:
6256
guidance_scale (`float`, defaults to `7.5`):
6357
CFG scale applied by this guider during post-processing. Higher values = stronger prompt conditioning but
6458
may reduce quality. Typical range: 1.0-20.0.
65-
distilled_guidance_scale (`float`, *optional*, defaults to `None`):
66-
Guidance scale for distilled models, passed directly to the model during forward pass. If `None`, assumes a
67-
regular (non-distilled) model. Allows pipelines to configure different defaults for distilled vs.
68-
non-distilled models. Typical range for distilled models: 1.0-8.0.
6959
guidance_rescale (`float`, defaults to `0.0`):
7060
Rescaling factor to prevent overexposure from high guidance scales. Based on [Common Diffusion Noise
7161
Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Range: 0.0 (no rescaling)
@@ -89,7 +79,6 @@ class ClassifierFreeGuidance(BaseGuidance):
8979
def __init__(
9080
self,
9181
guidance_scale: float = 7.5,
92-
distilled_guidance_scale: Optional[float] = None,
9382
guidance_rescale: float = 0.0,
9483
use_original_formulation: bool = False,
9584
start: float = 0.0,
@@ -99,20 +88,15 @@ def __init__(
9988
super().__init__(start, stop, enabled)
10089

10190
self.guidance_scale = guidance_scale
102-
self.distilled_guidance_scale = distilled_guidance_scale
10391
self.guidance_rescale = guidance_rescale
10492
self.use_original_formulation = use_original_formulation
10593

106-
def prepare_inputs(
107-
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
108-
) -> List["BlockState"]:
109-
if input_fields is None:
110-
input_fields = self._input_fields
94+
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
11195

11296
tuple_indices = [0] if self.num_conditions == 1 or not self._is_cfg_enabled() else [0, 1]
11397
data_batches = []
11498
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
115-
data_batch = self._prepare_batch(input_fields, data, tuple_idx, input_prediction)
99+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
116100
data_batches.append(data_batch)
117101
return data_batches
118102

src/diffusers/guiders/classifier_free_zero_star_guidance.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,11 @@ def __init__(
7777
self.guidance_rescale = guidance_rescale
7878
self.use_original_formulation = use_original_formulation
7979

80-
def prepare_inputs(
81-
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
82-
) -> List["BlockState"]:
83-
if input_fields is None:
84-
input_fields = self._input_fields
85-
80+
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
8681
tuple_indices = [0] if self.num_conditions == 1 or not self._is_cfg_enabled() else [0, 1]
8782
data_batches = []
8883
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
89-
data_batch = self._prepare_batch(input_fields, data, tuple_idx, input_prediction)
84+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
9085
data_batches.append(data_batch)
9186
return data_batches
9287

src/diffusers/guiders/frequency_decoupled_guidance.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -218,16 +218,11 @@ def __init__(
218218
f"({len(self.guidance_scales)})"
219219
)
220220

221-
def prepare_inputs(
222-
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
223-
) -> List["BlockState"]:
224-
if input_fields is None:
225-
input_fields = self._input_fields
226-
221+
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
227222
tuple_indices = [0] if self.num_conditions == 1 or not self._is_fdg_enabled() else [0, 1]
228223
data_batches = []
229224
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
230-
data_batch = self._prepare_batch(input_fields, data, tuple_idx, input_prediction)
225+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
231226
data_batches.append(data_batch)
232227
return data_batches
233228

src/diffusers/guiders/guider_utils.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,7 @@ def num_conditions(self) -> int:
190190
@classmethod
191191
def _prepare_batch(
192192
cls,
193-
input_fields: Dict[str, Union[str, Tuple[str, str]]],
194-
data: "BlockState",
193+
data: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
195194
tuple_index: int,
196195
identifier: str,
197196
) -> "BlockState":
@@ -217,24 +216,17 @@ def _prepare_batch(
217216
"""
218217
from ..modular_pipelines.modular_pipeline import BlockState
219218

220-
if isinstance(data, dict):
221-
data = BlockState(**data)
222219

223-
if input_fields is None:
224-
raise ValueError(
225-
"Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
226-
)
227220
data_batch = {}
228-
for key, value in input_fields.items():
221+
for key, value in data.items():
229222
try:
230-
if isinstance(value, str):
231-
data_batch[key] = getattr(data, value)
223+
if isinstance(value, torch.Tensor):
224+
data_batch[key] = value
232225
elif isinstance(value, tuple):
233-
data_batch[key] = getattr(data, value[tuple_index])
226+
data_batch[key] = value[tuple_index]
234227
else:
235-
# We've already checked that value is a string or a tuple of strings with length 2
236-
pass
237-
except AttributeError:
228+
raise ValueError(f"Invalid value type: {type(value)}")
229+
except ValueError:
238230
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
239231
data_batch[cls._identifier_key] = identifier
240232
return BlockState(**data_batch)

src/diffusers/guiders/perturbed_attention_guidance.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,12 +169,7 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
169169
registry.remove_hook(hook_name, recurse=True)
170170

171171
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
172-
def prepare_inputs(
173-
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
174-
) -> List["BlockState"]:
175-
if input_fields is None:
176-
input_fields = self._input_fields
177-
172+
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
178173
if self.num_conditions == 1 or not self._is_cfg_enabled() and not self._is_slg_enabled():
179174
tuple_indices = [0]
180175
input_predictions = ["pred_cond"]
@@ -188,7 +183,7 @@ def prepare_inputs(
188183
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
189184
data_batches = []
190185
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
191-
data_batch = self._prepare_batch(input_fields, data, tuple_idx, input_prediction)
186+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
192187
data_batches.append(data_batch)
193188
return data_batches
194189

src/diffusers/guiders/skip_layer_guidance.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,7 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
165165
for hook_name in self._skip_layer_hook_names:
166166
registry.remove_hook(hook_name, recurse=True)
167167

168-
def prepare_inputs(
169-
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
170-
) -> List["BlockState"]:
171-
if input_fields is None:
172-
input_fields = self._input_fields
173-
168+
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
174169
if self.num_conditions == 1 or not self._is_cfg_enabled() and not self._is_slg_enabled():
175170
tuple_indices = [0]
176171
input_predictions = ["pred_cond"]
@@ -184,7 +179,7 @@ def prepare_inputs(
184179
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
185180
data_batches = []
186181
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
187-
data_batch = self._prepare_batch(input_fields, data, tuple_idx, input_prediction)
182+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
188183
data_batches.append(data_batch)
189184
return data_batches
190185

src/diffusers/guiders/smoothed_energy_guidance.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,7 @@ def cleanup_models(self, denoiser: torch.nn.Module):
154154
for hook_name in self._seg_layer_hook_names:
155155
registry.remove_hook(hook_name, recurse=True)
156156

157-
def prepare_inputs(
158-
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
159-
) -> List["BlockState"]:
160-
if input_fields is None:
161-
input_fields = self._input_fields
162-
157+
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
163158
if self.num_conditions == 1 or not self._is_cfg_enabled() and not self._is_seg_enabled():
164159
tuple_indices = [0]
165160
input_predictions = ["pred_cond"]
@@ -173,7 +168,7 @@ def prepare_inputs(
173168
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
174169
data_batches = []
175170
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
176-
data_batch = self._prepare_batch(input_fields, data, tuple_idx, input_prediction)
171+
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
177172
data_batches.append(data_batch)
178173
return data_batches
179174

0 commit comments

Comments
 (0)