@@ -232,19 +232,35 @@ def search_best_candidate(module_sizes, min_memory_offload):
232232
233233
234234class ComponentsManager :
235- _available_info_fields = ["model_id" , "added_time" , "collection" , "class_name" , "size_gb" , "adapters" , "has_hook" , "execution_device" , "ip_adapter" ]
236-
235+ _available_info_fields = [
236+ "model_id" ,
237+ "added_time" ,
238+ "collection" ,
239+ "class_name" ,
240+ "size_gb" ,
241+ "adapters" ,
242+ "has_hook" ,
243+ "execution_device" ,
244+ "ip_adapter" ,
245+ ]
246+
237247 def __init__ (self ):
238248 self .components = OrderedDict ()
239249 self .added_time = OrderedDict () # Store when components were added
240250 self .collections = OrderedDict () # collection_name -> set of component_names
241251 self .model_hooks = None
242252 self ._auto_offload_enabled = False
243253
244- def _lookup_ids (self , name : Optional [str ] = None , collection : Optional [str ] = None , load_id : Optional [str ] = None , components : Optional [OrderedDict ] = None ):
254+ def _lookup_ids (
255+ self ,
256+ name : Optional [str ] = None ,
257+ collection : Optional [str ] = None ,
258+ load_id : Optional [str ] = None ,
259+ components : Optional [OrderedDict ] = None ,
260+ ):
245261 """
246- Lookup component_ids by name, collection, or load_id. Does not support pattern matching.
247- Returns a set of component_ids
262+ Lookup component_ids by name, collection, or load_id. Does not support pattern matching. Returns a set of
263+ component_ids
248264 """
249265 if components is None :
250266 components = self .components
@@ -318,10 +334,14 @@ def add(self, name, component, collection: Optional[str] = None):
318334 if component_id not in self .collections [collection ]:
319335 comp_ids_in_collection = self ._lookup_ids (name = name , collection = collection )
320336 for comp_id in comp_ids_in_collection :
321- logger .warning (f"ComponentsManager: removing existing { name } from collection '{ collection } ': { comp_id } " )
337+ logger .warning (
338+ f"ComponentsManager: removing existing { name } from collection '{ collection } ': { comp_id } "
339+ )
322340 self .remove (comp_id )
323341 self .collections [collection ].add (component_id )
324- logger .info (f"ComponentsManager: added component '{ name } ' in collection '{ collection } ': { component_id } " )
342+ logger .info (
343+ f"ComponentsManager: added component '{ name } ' in collection '{ collection } ': { component_id } "
344+ )
325345 else :
326346 logger .info (f"ComponentsManager: added component '{ name } ' as '{ component_id } '" )
327347
@@ -379,40 +399,43 @@ def search_components(
379399 - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae"
380400 collection: Optional collection to filter by
381401 load_id: Optional load_id to filter by
382- return_dict_with_names: If True, returns a dictionary with component names as keys, throw an error if multiple components with the same name are found
383- If False, returns a dictionary with component IDs as keys
402+ return_dict_with_names:
403+ If True, returns a dictionary with component names as keys, throw an error if
404+ multiple components with the same name are found If False, returns a dictionary
405+ with component IDs as keys
384406
385407 Returns:
386- Dictionary mapping component names to components if return_dict_with_names=True,
387- or a dictionary mapping component IDs to components if return_dict_with_names=False
408+ Dictionary mapping component names to components if return_dict_with_names=True, or a dictionary mapping
409+ component IDs to components if return_dict_with_names=False
388410 """
389411
390412 # select components based on collection and load_id filters
391413 selected_ids = self ._lookup_ids (collection = collection , load_id = load_id )
392414 components = {k : self .components [k ] for k in selected_ids }
393-
415+
394416 def get_return_dict (components , return_dict_with_names ):
395417 """
396- Create a dictionary mapping component names to components if return_dict_with_names=True,
397- or a dictionary mapping component IDs to components if return_dict_with_names=False,
398- throw an error if duplicate component names are found when return_dict_with_names=True
418+ Create a dictionary mapping component names to components if return_dict_with_names=True, or a dictionary
419+ mapping component IDs to components if return_dict_with_names=False, throw an error if duplicate component
420+ names are found when return_dict_with_names=True
399421 """
400422 if return_dict_with_names :
401423 dict_to_return = {}
402424 for comp_id , comp in components .items ():
403425 comp_name = self ._id_to_name (comp_id )
404426 if comp_name in dict_to_return :
405- raise ValueError (f"Duplicate component names found in the search results: { comp_name } , please set `return_dict_with_names=False` to return a dictionary with component IDs as keys" )
427+ raise ValueError (
428+ f"Duplicate component names found in the search results: { comp_name } , please set `return_dict_with_names=False` to return a dictionary with component IDs as keys"
429+ )
406430 dict_to_return [comp_name ] = comp
407431 return dict_to_return
408432 else :
409433 return components
410434
411-
412435 # if no names are provided, return the filtered components as it is
413436 if names is None :
414437 return get_return_dict (components , return_dict_with_names )
415-
438+
416439 # if names is not a string, raise an error
417440 elif not isinstance (names , str ):
418441 raise ValueError (f"Invalid type for `names: { type (names )} , only support string" )
@@ -488,9 +511,7 @@ def matches_pattern(component_id, pattern, exact_match=False):
488511 }
489512
490513 if is_not_pattern :
491- logger .info (
492- f"Getting all components except those with base name '{ names } ': { list (matches .keys ())} "
493- )
514+ logger .info (f"Getting all components except those with base name '{ names } ': { list (matches .keys ())} " )
494515 else :
495516 logger .info (f"Getting components with base name '{ names } ': { list (matches .keys ())} " )
496517
@@ -584,8 +605,8 @@ def disable_auto_cpu_offload(self):
584605
585606 # YiYi TODO: (1) add quantization info
586607 def get_model_info (
587- self ,
588- component_id : str ,
608+ self ,
609+ component_id : str ,
589610 fields : Optional [Union [str , List [str ]]] = None ,
590611 ) -> Optional [Dict [str , Any ]]:
591612 """Get comprehensive information about a component.
@@ -603,7 +624,7 @@ def get_model_info(
603624 raise ValueError (f"Component '{ component_id } ' not found in ComponentsManager" )
604625
605626 component = self .components [component_id ]
606-
627+
607628 # Validate fields if specified
608629 if fields is not None :
609630 if isinstance (fields , str ):
@@ -662,7 +683,7 @@ def get_model_info(
662683 return {k : v for k , v in info .items () if k in fields }
663684 else :
664685 return info
665-
686+
666687 # YiYi TODO: (1) add display fields, allow user to set which fields to display in the comnponents table
667688 def __repr__ (self ):
668689 # Handle empty components case
@@ -820,11 +841,9 @@ def get_one(
820841 load_id : Optional [str ] = None ,
821842 ) -> Any :
822843 """
823- Get a single component by either:
824- (1) searching name (pattern matching), collection, or load_id.
825- (2) passing in a component_id
826- Raises an error if multiple components match or none are found.
827- support pattern matching for name
844+ Get a single component by either: (1) searching name (pattern matching), collection, or load_id. (2) passing in
845+ a component_id Raises an error if multiple components match or none are found. support pattern matching for
846+ name
828847
829848 Args:
830849 component_id: Optional component ID to get
@@ -841,7 +860,7 @@ def get_one(
841860
842861 if component_id is not None and (name is not None or collection is not None or load_id is not None ):
843862 raise ValueError ("If searching by component_id, do not pass name, collection, or load_id" )
844-
863+
845864 # search by component_id
846865 if component_id is not None :
847866 if component_id not in self .components :
@@ -857,7 +876,6 @@ def get_one(
857876 raise ValueError (f"Multiple components found matching '{ name } ': { list (results .keys ())} " )
858877
859878 return next (iter (results .values ()))
860-
861879
862880 def get_ids (self , names : Union [str , List [str ]] = None , collection : Optional [str ] = None ):
863881 """
@@ -869,7 +887,7 @@ def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str]
869887 for name in names :
870888 ids .update (self ._lookup_ids (name = name , collection = collection ))
871889 return list (ids )
872-
890+
873891 def get_components_by_ids (self , ids : List [str ], return_dict_with_names : Optional [bool ] = True ):
874892 """
875893 Get components by a list of IDs.
@@ -881,7 +899,9 @@ def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional
881899 for comp_id , comp in components .items ():
882900 comp_name = self ._id_to_name (comp_id )
883901 if comp_name in dict_to_return :
884- raise ValueError (f"Duplicate component names found in the search results: { comp_name } , please set `return_dict_with_names=False` to return a dictionary with component IDs as keys" )
902+ raise ValueError (
903+ f"Duplicate component names found in the search results: { comp_name } , please set `return_dict_with_names=False` to return a dictionary with component IDs as keys"
904+ )
885905 dict_to_return [comp_name ] = comp
886906 return dict_to_return
887907 else :
@@ -894,6 +914,7 @@ def get_components_by_names(self, names: List[str], collection: Optional[str] =
894914 ids = self .get_ids (names , collection )
895915 return self .get_components_by_ids (ids )
896916
917+
897918def summarize_dict_by_value_and_parts (d : Dict [str , Any ]) -> Dict [str , Any ]:
898919 """Summarizes a dictionary by finding common prefixes that share the same value.
899920
0 commit comments