33import os
44
55# Simple typed wrapper for parameter overrides
6- from dataclasses import asdict , dataclass
6+ from dataclasses import asdict , dataclass , field
77from typing import Any , Dict , List , Optional , Tuple , Union
88
99from huggingface_hub import create_repo , hf_hub_download
244244 "outputs" : [
245245 "controlnet" ,
246246 ],
247- "block_names " : [ "controlnet_vae_encoder" ] ,
247+ "block_name " : "controlnet_vae_encoder" ,
248248 },
249249 "denoise" : {
250250 "inputs" : [
270270 "latents" ,
271271 "latents_preview" ,
272272 ],
273- "block_names " : [ "denoise" ] ,
273+ "block_name " : "denoise" ,
274274 },
275275 "vae_encoder" : {
276276 "inputs" : [
284284 "outputs" : [
285285 "image_latents" ,
286286 ],
287- "block_names " : [ "vae_encoder" ] ,
287+ "block_name " : "vae_encoder" ,
288288 },
289289 "text_encoder" : {
290290 "inputs" : [
299299 "outputs" : [
300300 "embeddings" ,
301301 ],
302- "block_names " : [ "text_encoder" ] ,
302+ "block_name " : "text_encoder" ,
303303 },
304304 "decoder" : {
305305 "inputs" : [
311311 "outputs" : [
312312 "images" ,
313313 ],
314- "block_names " : [ "decode" ] ,
314+ "block_name " : "decode" ,
315315 },
316316}
317317
@@ -353,21 +353,24 @@ class MellonNodeConfig(PushToHubMixin):
353353 inputs : List [Union [str , MellonParam ]]
354354 model_inputs : List [Union [str , MellonParam ]]
355355 outputs : List [Union [str , MellonParam ]]
356- blocks_names : list [ str ]
356+ block_name : str
357357 node_type : str
358+ required_inputs : List [str ] = field (default_factory = list )
359+ required_model_inputs : List [str ] = field (default_factory = list )
358360 config_name = "mellon_config.json"
359361
360362 def __post_init__ (self ):
363+
361364 if isinstance (self .inputs , list ):
362- self .inputs = self ._resolve_params_list (self .inputs , MELLON_INPUT_PARAMS )
365+ self .inputs = self ._resolve_params_list (self .inputs , MELLON_INPUT_PARAMS , required = self . required_inputs )
363366 if isinstance (self .model_inputs , list ):
364- self .model_inputs = self ._resolve_params_list (self .model_inputs , MELLON_MODEL_PARAMS )
367+ self .model_inputs = self ._resolve_params_list (self .model_inputs , MELLON_MODEL_PARAMS , required = self . required_model_inputs )
365368 if isinstance (self .outputs , list ):
366369 self .outputs = self ._resolve_params_list (self .outputs , MELLON_OUTPUT_PARAMS )
367370
368371 @staticmethod
369372 def _resolve_params_list (
370- params : List [Union [str , MellonParam ]], default_map : Dict [str , Dict [str , Any ]]
373+ params : List [Union [str , MellonParam ]], default_map : Dict [str , Dict [str , Any ]], required : Optional [ List [ str ]] = None
371374 ) -> Dict [str , Dict [str , Any ]]:
372375 def _resolve_param (
373376 param : Union [str , MellonParam ], default_params_map : Dict [str , Dict [str , Any ]]
@@ -392,6 +395,10 @@ def _resolve_param(
392395 if name in resolved :
393396 raise ValueError (f"Duplicate param '{ name } '" )
394397 resolved [name ] = cfg
398+ if required is not None :
399+ for name in required :
400+ if name in resolved and not resolved [name ]["label" ].endswith (" *" ):
401+ resolved [name ]["label" ] = f"{ resolved [name ]['label' ]} *"
395402 return resolved
396403
397404 @classmethod
@@ -625,7 +632,7 @@ def to_mellon_dict(self) -> Dict[str, Any]:
625632
626633 return {
627634 "node_type" : self .node_type ,
628- "blocks_names " : self .blocks_names ,
635+ "block_name " : self .block_name ,
629636 "params" : merged_params ,
630637 }
631638
@@ -655,50 +662,6 @@ def from_mellon_dict(cls, mellon_dict: Dict[str, Any]) -> "MellonNodeConfig":
655662 inputs = inputs ,
656663 model_inputs = model_inputs ,
657664 outputs = outputs ,
658- blocks_names = mellon_dict .get ("blocks_names " , [] ),
665+ block_name = mellon_dict .get ("block_name " , None ),
659666 node_type = mellon_dict .get ("node_type" ),
660667 )
661-
662- # YiYi Notes: not used yet
663- @classmethod
664- def from_blocks (cls , blocks : ModularPipelineBlocks , node_type : str ) -> "MellonNodeConfig" :
665- """
666- Create an instance from a ModularPipeline object. If a preset exists in NODE_TYPE_PARAMS_MAP for the node_type,
667- use it; otherwise fall back to deriving lists from the pipeline's expected inputs/components/outputs.
668- """
669- if node_type not in NODE_TYPE_PARAMS_MAP :
670- raise ValueError (f"Node type { node_type } not supported" )
671-
672- blocks_names = list (blocks .sub_blocks .keys ())
673-
674- default_node_config = NODE_TYPE_PARAMS_MAP [node_type ]
675- inputs_list : List [Union [str , MellonParam ]] = default_node_config .get ("inputs" , [])
676- model_inputs_list : List [Union [str , MellonParam ]] = default_node_config .get ("model_inputs" , [])
677- outputs_list : List [Union [str , MellonParam ]] = default_node_config .get ("outputs" , [])
678-
679- for required_input_name in blocks .required_inputs :
680- if required_input_name not in inputs_list :
681- inputs_list .append (
682- MellonParam (
683- name = required_input_name , label = required_input_name , type = required_input_name , display = "input"
684- )
685- )
686-
687- for component_spec in blocks .expected_components :
688- if component_spec .name not in model_inputs_list :
689- model_inputs_list .append (
690- MellonParam (
691- name = component_spec .name ,
692- label = component_spec .name ,
693- type = "diffusers_auto_model" ,
694- display = "input" ,
695- )
696- )
697-
698- return cls (
699- inputs = inputs_list ,
700- model_inputs = model_inputs_list ,
701- outputs = outputs_list ,
702- blocks_names = blocks_names ,
703- node_type = node_type ,
704- )
0 commit comments