Skip to content

Commit 7d2a633

Browse files
committed
style
1 parent cb328d3 commit 7d2a633

28 files changed

Lines changed: 825 additions & 1449 deletions

src/diffusers/__init__.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -794,8 +794,8 @@
794794
LayerSkipConfig,
795795
PyramidAttentionBroadcastConfig,
796796
SmoothedEnergyGuidanceConfig,
797-
apply_layer_skip,
798797
apply_faster_cache,
798+
apply_layer_skip,
799799
apply_pyramid_attention_broadcast,
800800
)
801801
from .models import (
@@ -875,6 +875,13 @@
875875
WanTransformer3DModel,
876876
WanVACETransformer3DModel,
877877
)
878+
from .modular_pipelines import (
879+
ComponentsManager,
880+
ComponentSpec,
881+
ModularLoader,
882+
ModularPipeline,
883+
ModularPipelineBlocks,
884+
)
878885
from .optimization import (
879886
get_constant_schedule,
880887
get_constant_schedule_with_warmup,
@@ -907,13 +914,6 @@
907914
ScoreSdeVePipeline,
908915
StableDiffusionMixin,
909916
)
910-
from .modular_pipelines import (
911-
ModularLoader,
912-
ModularPipeline,
913-
ModularPipelineBlocks,
914-
ComponentSpec,
915-
ComponentsManager,
916-
)
917917
from .quantizers import DiffusersQuantizer
918918
from .schedulers import (
919919
AmusedScheduler,
@@ -978,6 +978,10 @@
978978
except OptionalDependencyNotAvailable:
979979
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
980980
else:
981+
from .modular_pipelines import (
982+
StableDiffusionXLAutoPipeline,
983+
StableDiffusionXLModularLoader,
984+
)
981985
from .pipelines import (
982986
AllegroPipeline,
983987
AltDiffusionImg2ImgPipeline,
@@ -1182,10 +1186,6 @@
11821186
WuerstchenDecoderPipeline,
11831187
WuerstchenPriorPipeline,
11841188
)
1185-
from .modular_pipelines import (
1186-
StableDiffusionXLAutoPipeline,
1187-
StableDiffusionXLModularLoader,
1188-
)
11891189

11901190
try:
11911191
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):

src/diffusers/commands/custom_blocks.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@
1818
"""
1919

2020
import ast
21-
from argparse import ArgumentParser, Namespace
22-
from pathlib import Path
2321
import importlib.util
2422
import os
23+
from argparse import ArgumentParser, Namespace
24+
from pathlib import Path
25+
2526
from ..utils import logging
2627
from . import BaseDiffusersCLICommand
2728

@@ -57,7 +58,7 @@ def run(self):
5758
# determine the block to be saved.
5859
out = self._get_class_names(self.block_module_name)
5960
classes_found = list({cls for cls, _ in out})
60-
61+
6162
if self.block_class_name is not None:
6263
child_class, parent_class = self._choose_block(out, self.block_class_name)
6364
if child_class is None and parent_class is None:
@@ -125,9 +126,9 @@ def _get_base_name(self, node: ast.expr):
125126
val = self._get_base_name(node.value)
126127
return f"{val}.{node.attr}" if val else node.attr
127128
return None
128-
129+
129130
def _create_automap(self, parent_class, child_class):
130131
module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1]
131132
auto_map = {f"{parent_class}": f"{module}.{child_class}"}
132133
return {"auto_map": auto_map}
133-
134+

src/diffusers/commands/diffusers_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
from argparse import ArgumentParser
1717

18+
from .custom_blocks import CustomBlocksCommand
1819
from .env import EnvironmentCommand
1920
from .fp16_safetensors import FP16SafetensorsCommand
20-
from .custom_blocks import CustomBlocksCommand
2121

2222

2323
def main():

src/diffusers/guiders/adaptive_projected_guidance.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
16+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
1717

1818
import torch
1919

2020
from .guider_utils import BaseGuidance, rescale_noise_cfg
2121

22+
2223
if TYPE_CHECKING:
2324
from ..modular_pipelines.modular_pipeline import BlockState
2425

@@ -74,10 +75,10 @@ def __init__(
7475
self.momentum_buffer = None
7576

7677
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
77-
78+
7879
if input_fields is None:
7980
input_fields = self._input_fields
80-
81+
8182
if self._step == 0:
8283
if self.adaptive_projected_guidance_momentum is not None:
8384
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
@@ -123,19 +124,19 @@ def num_conditions(self) -> int:
123124
def _is_apg_enabled(self) -> bool:
124125
if not self._enabled:
125126
return False
126-
127+
127128
is_within_range = True
128129
if self._num_inference_steps is not None:
129130
skip_start_step = int(self._start * self._num_inference_steps)
130131
skip_stop_step = int(self._stop * self._num_inference_steps)
131132
is_within_range = skip_start_step <= self._step < skip_stop_step
132-
133+
133134
is_close = False
134135
if self.use_original_formulation:
135136
is_close = math.isclose(self.guidance_scale, 0.0)
136137
else:
137138
is_close = math.isclose(self.guidance_scale, 1.0)
138-
139+
139140
return is_within_range and not is_close
140141

141142

@@ -160,25 +161,25 @@ def normalized_guidance(
160161
):
161162
diff = pred_cond - pred_uncond
162163
dim = [-i for i in range(1, len(diff.shape))]
163-
164+
164165
if momentum_buffer is not None:
165166
momentum_buffer.update(diff)
166167
diff = momentum_buffer.running_average
167-
168+
168169
if norm_threshold > 0:
169170
ones = torch.ones_like(diff)
170171
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
171172
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
172173
diff = diff * scale_factor
173-
174+
174175
v0, v1 = diff.double(), pred_cond.double()
175176
v1 = torch.nn.functional.normalize(v1, dim=dim)
176177
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
177178
v0_orthogonal = v0 - v0_parallel
178179
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
179180
normalized_update = diff_orthogonal + eta * diff_parallel
180-
181+
181182
pred = pred_cond if use_original_formulation else pred_uncond
182183
pred = pred + guidance_scale * normalized_update
183-
184+
184185
return pred

src/diffusers/guiders/auto_guidance.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
16+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
1717

1818
import torch
1919

2020
from ..hooks import HookRegistry, LayerSkipConfig
2121
from ..hooks.layer_skip import _apply_layer_skip_hook
2222
from .guider_utils import BaseGuidance, rescale_noise_cfg
2323

24+
2425
if TYPE_CHECKING:
2526
from ..modular_pipelines.modular_pipeline import BlockState
2627

@@ -113,18 +114,18 @@ def prepare_models(self, denoiser: torch.nn.Module) -> None:
113114
if self._is_ag_enabled() and self.is_unconditional:
114115
for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
115116
_apply_layer_skip_hook(denoiser, config, name=name)
116-
117+
117118
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
118119
if self._is_ag_enabled() and self.is_unconditional:
119120
for name in self._auto_guidance_hook_names:
120121
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
121122
registry.remove_hook(name, recurse=True)
122-
123+
123124
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
124-
125+
125126
if input_fields is None:
126127
input_fields = self._input_fields
127-
128+
128129
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
129130
data_batches = []
130131
for i in range(self.num_conditions):
@@ -144,9 +145,9 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
144145

145146
if self.guidance_rescale > 0.0:
146147
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
147-
148+
148149
return pred, {}
149-
150+
150151
@property
151152
def is_conditional(self) -> bool:
152153
return self._count_prepared == 1
@@ -161,17 +162,17 @@ def num_conditions(self) -> int:
161162
def _is_ag_enabled(self) -> bool:
162163
if not self._enabled:
163164
return False
164-
165+
165166
is_within_range = True
166167
if self._num_inference_steps is not None:
167168
skip_start_step = int(self._start * self._num_inference_steps)
168169
skip_stop_step = int(self._stop * self._num_inference_steps)
169170
is_within_range = skip_start_step <= self._step < skip_stop_step
170-
171+
171172
is_close = False
172173
if self.use_original_formulation:
173174
is_close = math.isclose(self.guidance_scale, 0.0)
174175
else:
175176
is_close = math.isclose(self.guidance_scale, 1.0)
176-
177+
177178
return is_within_range and not is_close

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
16+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
1717

1818
import torch
1919

2020
from .guider_utils import BaseGuidance, rescale_noise_cfg
2121

22+
2223
if TYPE_CHECKING:
2324
from ..modular_pipelines.modular_pipeline import BlockState
2425

@@ -74,12 +75,12 @@ def __init__(
7475
self.guidance_scale = guidance_scale
7576
self.guidance_rescale = guidance_rescale
7677
self.use_original_formulation = use_original_formulation
77-
78+
7879
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
79-
80+
8081
if input_fields is None:
8182
input_fields = self._input_fields
82-
83+
8384
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
8485
data_batches = []
8586
for i in range(self.num_conditions):
@@ -116,17 +117,17 @@ def num_conditions(self) -> int:
116117
def _is_cfg_enabled(self) -> bool:
117118
if not self._enabled:
118119
return False
119-
120+
120121
is_within_range = True
121122
if self._num_inference_steps is not None:
122123
skip_start_step = int(self._start * self._num_inference_steps)
123124
skip_stop_step = int(self._stop * self._num_inference_steps)
124125
is_within_range = skip_start_step <= self._step < skip_stop_step
125-
126+
126127
is_close = False
127128
if self.use_original_formulation:
128129
is_close = math.isclose(self.guidance_scale, 0.0)
129130
else:
130131
is_close = math.isclose(self.guidance_scale, 1.0)
131-
132+
132133
return is_within_range and not is_close

src/diffusers/guiders/classifier_free_zero_star_guidance.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
16+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
1717

1818
import torch
1919

2020
from .guider_utils import BaseGuidance, rescale_noise_cfg
2121

22+
2223
if TYPE_CHECKING:
2324
from ..modular_pipelines.modular_pipeline import BlockState
2425

@@ -72,12 +73,12 @@ def __init__(
7273
self.zero_init_steps = zero_init_steps
7374
self.guidance_rescale = guidance_rescale
7475
self.use_original_formulation = use_original_formulation
75-
76+
7677
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
77-
78+
7879
if input_fields is None:
7980
input_fields = self._input_fields
80-
81+
8182
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
8283
data_batches = []
8384
for i in range(self.num_conditions):
@@ -106,7 +107,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
106107
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
107108

108109
return pred, {}
109-
110+
110111
@property
111112
def is_conditional(self) -> bool:
112113
return self._count_prepared == 1
@@ -121,19 +122,19 @@ def num_conditions(self) -> int:
121122
def _is_cfg_enabled(self) -> bool:
122123
if not self._enabled:
123124
return False
124-
125+
125126
is_within_range = True
126127
if self._num_inference_steps is not None:
127128
skip_start_step = int(self._start * self._num_inference_steps)
128129
skip_stop_step = int(self._stop * self._num_inference_steps)
129130
is_within_range = skip_start_step <= self._step < skip_stop_step
130-
131+
131132
is_close = False
132133
if self.use_original_formulation:
133134
is_close = math.isclose(self.guidance_scale, 0.0)
134135
else:
135136
is_close = math.isclose(self.guidance_scale, 1.0)
136-
137+
137138
return is_within_range and not is_close
138139

139140

0 commit comments

Comments
 (0)