Skip to content

Commit 95fabd1

Browse files
authored
Apply suggestions from code review
1 parent ddae49b commit 95fabd1

10 files changed

Lines changed: 10 additions & 10 deletions

src/diffusers/guiders/adaptive_projected_guidance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
8181
if self._step == 0:
8282
if self.adaptive_projected_guidance_momentum is not None:
8383
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
84-
tuple_indices = [0] if self.num_conditions == 1 or not self._is_apg_enabled() else [0, 1]
84+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
8585
data_batches = []
8686
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
8787
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)

src/diffusers/guiders/adaptive_projected_guidance_mix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
9393
if self.adaptive_projected_guidance_momentum is not None:
9494
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
9595
tuple_indices = (
96-
[0] if self.num_conditions == 1 or not self._is_apg_enabled() and not self._is_cfg_enabled() else [0, 1]
96+
[0] if self.num_conditions == 1 else [0, 1]
9797
)
9898
data_batches = []
9999
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):

src/diffusers/guiders/auto_guidance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
134134
registry.remove_hook(name, recurse=True)
135135

136136
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
137-
tuple_indices = [0] if self.num_conditions == 1 or not self._is_ag_enabled() else [0, 1]
137+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
138138
data_batches = []
139139
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
140140
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def __init__(
9292
self.use_original_formulation = use_original_formulation
9393

9494
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
95-
tuple_indices = [0] if self.num_conditions == 1 or not self._is_cfg_enabled() else [0, 1]
95+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
9696
data_batches = []
9797
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
9898
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)

src/diffusers/guiders/classifier_free_zero_star_guidance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(
7878
self.use_original_formulation = use_original_formulation
7979

8080
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
81-
tuple_indices = [0] if self.num_conditions == 1 or not self._is_cfg_enabled() else [0, 1]
81+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
8282
data_batches = []
8383
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
8484
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)

src/diffusers/guiders/frequency_decoupled_guidance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def __init__(
219219
)
220220

221221
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
222-
tuple_indices = [0] if self.num_conditions == 1 or not self._is_fdg_enabled() else [0, 1]
222+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
223223
data_batches = []
224224
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
225225
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)

src/diffusers/guiders/perturbed_attention_guidance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
170170

171171
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
172172
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
173-
if self.num_conditions == 1 or not self._is_cfg_enabled() and not self._is_slg_enabled():
173+
if self.num_conditions == 1:
174174
tuple_indices = [0]
175175
input_predictions = ["pred_cond"]
176176
elif self.num_conditions == 2:

src/diffusers/guiders/skip_layer_guidance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
166166
registry.remove_hook(hook_name, recurse=True)
167167

168168
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
169-
if self.num_conditions == 1 or not self._is_cfg_enabled() and not self._is_slg_enabled():
169+
if self.num_conditions == 1:
170170
tuple_indices = [0]
171171
input_predictions = ["pred_cond"]
172172
elif self.num_conditions == 2:

src/diffusers/guiders/smoothed_energy_guidance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def cleanup_models(self, denoiser: torch.nn.Module):
155155
registry.remove_hook(hook_name, recurse=True)
156156

157157
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
158-
if self.num_conditions == 1 or not self._is_cfg_enabled() and not self._is_seg_enabled():
158+
if self.num_conditions == 1:
159159
tuple_indices = [0]
160160
input_predictions = ["pred_cond"]
161161
elif self.num_conditions == 2:

src/diffusers/guiders/tangential_classifier_free_guidance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(
6767
self.use_original_formulation = use_original_formulation
6868

6969
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
70-
tuple_indices = [0] if self.num_conditions == 1 or not self._is_tcfg_enabled() else [0, 1]
70+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
7171
data_batches = []
7272
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
7373
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)

0 commit comments

Comments
 (0)