Skip to content

Commit 49ea4d1

Browse files
committed
style
1 parent 58dbe0c commit 49ea4d1

5 files changed

Lines changed: 81 additions & 53 deletions

File tree

src/diffusers/modular_pipelines/components_manager.py

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -232,19 +232,35 @@ def search_best_candidate(module_sizes, min_memory_offload):
232232

233233

234234
class ComponentsManager:
235-
_available_info_fields = ["model_id", "added_time", "collection", "class_name", "size_gb", "adapters", "has_hook", "execution_device", "ip_adapter"]
236-
235+
_available_info_fields = [
236+
"model_id",
237+
"added_time",
238+
"collection",
239+
"class_name",
240+
"size_gb",
241+
"adapters",
242+
"has_hook",
243+
"execution_device",
244+
"ip_adapter",
245+
]
246+
237247
def __init__(self):
238248
self.components = OrderedDict()
239249
self.added_time = OrderedDict() # Store when components were added
240250
self.collections = OrderedDict() # collection_name -> set of component_names
241251
self.model_hooks = None
242252
self._auto_offload_enabled = False
243253

244-
def _lookup_ids(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None, components: Optional[OrderedDict] = None):
254+
def _lookup_ids(
255+
self,
256+
name: Optional[str] = None,
257+
collection: Optional[str] = None,
258+
load_id: Optional[str] = None,
259+
components: Optional[OrderedDict] = None,
260+
):
245261
"""
246-
Lookup component_ids by name, collection, or load_id. Does not support pattern matching.
247-
Returns a set of component_ids
262+
Lookup component_ids by name, collection, or load_id. Does not support pattern matching. Returns a set of
263+
component_ids
248264
"""
249265
if components is None:
250266
components = self.components
@@ -318,10 +334,14 @@ def add(self, name, component, collection: Optional[str] = None):
318334
if component_id not in self.collections[collection]:
319335
comp_ids_in_collection = self._lookup_ids(name=name, collection=collection)
320336
for comp_id in comp_ids_in_collection:
321-
logger.warning(f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}")
337+
logger.warning(
338+
f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}"
339+
)
322340
self.remove(comp_id)
323341
self.collections[collection].add(component_id)
324-
logger.info(f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}")
342+
logger.info(
343+
f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}"
344+
)
325345
else:
326346
logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'")
327347

@@ -379,40 +399,43 @@ def search_components(
379399
- "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae"
380400
collection: Optional collection to filter by
381401
load_id: Optional load_id to filter by
382-
return_dict_with_names: If True, returns a dictionary with component names as keys, throw an error if multiple components with the same name are found
383-
If False, returns a dictionary with component IDs as keys
402+
return_dict_with_names:
403+
If True, returns a dictionary with component names as keys, throw an error if
404+
multiple components with the same name are found If False, returns a dictionary
405+
with component IDs as keys
384406
385407
Returns:
386-
Dictionary mapping component names to components if return_dict_with_names=True,
387-
or a dictionary mapping component IDs to components if return_dict_with_names=False
408+
Dictionary mapping component names to components if return_dict_with_names=True, or a dictionary mapping
409+
component IDs to components if return_dict_with_names=False
388410
"""
389411

390412
# select components based on collection and load_id filters
391413
selected_ids = self._lookup_ids(collection=collection, load_id=load_id)
392414
components = {k: self.components[k] for k in selected_ids}
393-
415+
394416
def get_return_dict(components, return_dict_with_names):
395417
"""
396-
Create a dictionary mapping component names to components if return_dict_with_names=True,
397-
or a dictionary mapping component IDs to components if return_dict_with_names=False,
398-
throw an error if duplicate component names are found when return_dict_with_names=True
418+
Create a dictionary mapping component names to components if return_dict_with_names=True, or a dictionary
419+
mapping component IDs to components if return_dict_with_names=False, throw an error if duplicate component
420+
names are found when return_dict_with_names=True
399421
"""
400422
if return_dict_with_names:
401423
dict_to_return = {}
402424
for comp_id, comp in components.items():
403425
comp_name = self._id_to_name(comp_id)
404426
if comp_name in dict_to_return:
405-
raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys")
427+
raise ValueError(
428+
f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys"
429+
)
406430
dict_to_return[comp_name] = comp
407431
return dict_to_return
408432
else:
409433
return components
410434

411-
412435
# if no names are provided, return the filtered components as it is
413436
if names is None:
414437
return get_return_dict(components, return_dict_with_names)
415-
438+
416439
# if names is not a string, raise an error
417440
elif not isinstance(names, str):
418441
raise ValueError(f"Invalid type for `names: {type(names)}, only support string")
@@ -488,9 +511,7 @@ def matches_pattern(component_id, pattern, exact_match=False):
488511
}
489512

490513
if is_not_pattern:
491-
logger.info(
492-
f"Getting all components except those with base name '{names}': {list(matches.keys())}"
493-
)
514+
logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
494515
else:
495516
logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
496517

@@ -584,8 +605,8 @@ def disable_auto_cpu_offload(self):
584605

585606
# YiYi TODO: (1) add quantization info
586607
def get_model_info(
587-
self,
588-
component_id: str,
608+
self,
609+
component_id: str,
589610
fields: Optional[Union[str, List[str]]] = None,
590611
) -> Optional[Dict[str, Any]]:
591612
"""Get comprehensive information about a component.
@@ -603,7 +624,7 @@ def get_model_info(
603624
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
604625

605626
component = self.components[component_id]
606-
627+
607628
# Validate fields if specified
608629
if fields is not None:
609630
if isinstance(fields, str):
@@ -662,7 +683,7 @@ def get_model_info(
662683
return {k: v for k, v in info.items() if k in fields}
663684
else:
664685
return info
665-
686+
666687
# YiYi TODO: (1) add display fields, allow user to set which fields to display in the comnponents table
667688
def __repr__(self):
668689
# Handle empty components case
@@ -820,11 +841,9 @@ def get_one(
820841
load_id: Optional[str] = None,
821842
) -> Any:
822843
"""
823-
Get a single component by either:
824-
(1) searching name (pattern matching), collection, or load_id.
825-
(2) passing in a component_id
826-
Raises an error if multiple components match or none are found.
827-
support pattern matching for name
844+
Get a single component by either: (1) searching name (pattern matching), collection, or load_id. (2) passing in
845+
a component_id Raises an error if multiple components match or none are found. support pattern matching for
846+
name
828847
829848
Args:
830849
component_id: Optional component ID to get
@@ -841,7 +860,7 @@ def get_one(
841860

842861
if component_id is not None and (name is not None or collection is not None or load_id is not None):
843862
raise ValueError("If searching by component_id, do not pass name, collection, or load_id")
844-
863+
845864
# search by component_id
846865
if component_id is not None:
847866
if component_id not in self.components:
@@ -857,7 +876,6 @@ def get_one(
857876
raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")
858877

859878
return next(iter(results.values()))
860-
861879

862880
def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] = None):
863881
"""
@@ -869,7 +887,7 @@ def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str]
869887
for name in names:
870888
ids.update(self._lookup_ids(name=name, collection=collection))
871889
return list(ids)
872-
890+
873891
def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional[bool] = True):
874892
"""
875893
Get components by a list of IDs.
@@ -881,7 +899,9 @@ def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional
881899
for comp_id, comp in components.items():
882900
comp_name = self._id_to_name(comp_id)
883901
if comp_name in dict_to_return:
884-
raise ValueError(f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys")
902+
raise ValueError(
903+
f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys"
904+
)
885905
dict_to_return[comp_name] = comp
886906
return dict_to_return
887907
else:
@@ -894,6 +914,7 @@ def get_components_by_names(self, names: List[str], collection: Optional[str] =
894914
ids = self.get_ids(names, collection)
895915
return self.get_components_by_ids(ids)
896916

917+
897918
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
898919
"""Summarizes a dictionary by finding common prefixes that share the same value.
899920

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1849,23 +1849,30 @@ def dtype(self) -> torch.dtype:
18491849
return module.dtype
18501850

18511851
return torch.float32
1852-
1852+
18531853
@property
18541854
def null_component_names(self) -> List[str]:
18551855
return [name for name in self._component_specs.keys() if hasattr(self, name) and getattr(self, name) is None]
1856-
1856+
18571857
@property
18581858
def component_names(self) -> List[str]:
18591859
return list(self.components.keys())
1860-
1860+
18611861
@property
18621862
def pretrained_component_names(self) -> List[str]:
1863-
return [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained"]
1864-
1863+
return [
1864+
name
1865+
for name in self._component_specs.keys()
1866+
if self._component_specs[name].default_creation_method == "from_pretrained"
1867+
]
1868+
18651869
@property
18661870
def config_component_names(self) -> List[str]:
1867-
return [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_config"]
1868-
1871+
return [
1872+
name
1873+
for name in self._component_specs.keys()
1874+
if self._component_specs[name].default_creation_method == "from_config"
1875+
]
18691876

18701877
@property
18711878
def components(self) -> Dict[str, Any]:
@@ -2430,9 +2437,13 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =
24302437
raise ValueError(f"Output '{output}' is not a valid output type")
24312438

24322439
def load_default_components(self, **kwargs):
2433-
names = [name for name in self.loader._component_specs.keys() if self.loader._component_specs[name].default_creation_method == "from_pretrained"]
2440+
names = [
2441+
name
2442+
for name in self.loader._component_specs.keys()
2443+
if self.loader._component_specs[name].default_creation_method == "from_pretrained"
2444+
]
24342445
self.loader.load(names=names, **kwargs)
2435-
2446+
24362447
def load_components(self, names: Union[List[str], str], **kwargs):
24372448
self.loader.load(names=names, **kwargs)
24382449

src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,18 @@
2323
else:
2424
_import_structure["encoders"] = ["StableDiffusionXLTextEncoderStep"]
2525
_import_structure["modular_blocks_presets"] = [
26+
"ALL_BLOCKS",
2627
"AUTO_BLOCKS",
2728
"CONTROLNET_BLOCKS",
2829
"IMAGE2IMAGE_BLOCKS",
2930
"INPAINT_BLOCKS",
3031
"IP_ADAPTER_BLOCKS",
31-
"ALL_BLOCKS",
3232
"TEXT2IMAGE_BLOCKS",
3333
"StableDiffusionXLAutoBlocks",
34+
"StableDiffusionXLAutoControlnetStep",
3435
"StableDiffusionXLAutoDecodeStep",
3536
"StableDiffusionXLAutoIPAdapterStep",
3637
"StableDiffusionXLAutoVaeEncoderStep",
37-
"StableDiffusionXLAutoControlnetStep",
3838
]
3939
_import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"]
4040

@@ -49,18 +49,18 @@
4949
StableDiffusionXLTextEncoderStep,
5050
)
5151
from .modular_blocks_presets import (
52+
ALL_BLOCKS,
5253
AUTO_BLOCKS,
5354
CONTROLNET_BLOCKS,
5455
IMAGE2IMAGE_BLOCKS,
5556
INPAINT_BLOCKS,
5657
IP_ADAPTER_BLOCKS,
57-
ALL_BLOCKS,
5858
TEXT2IMAGE_BLOCKS,
5959
StableDiffusionXLAutoBlocks,
60+
StableDiffusionXLAutoControlnetStep,
6061
StableDiffusionXLAutoDecodeStep,
6162
StableDiffusionXLAutoIPAdapterStep,
6263
StableDiffusionXLAutoVaeEncoderStep,
63-
StableDiffusionXLAutoControlnetStep,
6464
)
6565
from .modular_loader import StableDiffusionXLModularLoader
6666
else:

src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,7 @@ class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
7676

7777
@property
7878
def description(self):
79-
return (
80-
"Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n"
81-
)
79+
return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n"
8280

8381

8482
# before_denoise: text2img

src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,14 @@ class StableDiffusionXLModularLoader(
4444
StableDiffusionXLLoraLoaderMixin,
4545
ModularIPAdapterMixin,
4646
):
47-
4847
@property
4948
def default_height(self):
5049
return self.default_sample_size * self.vae_scale_factor
5150

5251
@property
5352
def default_width(self):
5453
return self.default_sample_size * self.vae_scale_factor
55-
56-
54+
5755
@property
5856
def default_sample_size(self):
5957
default_sample_size = 128

0 commit comments

Comments
 (0)