Skip to content

Commit b44e7f5

Browse files
committed
update mellon node config, adding* to required_inputs and required_model_inputs
1 parent 64415ab commit b44e7f5

1 file changed

Lines changed: 19 additions & 56 deletions

File tree

src/diffusers/modular_pipelines/mellon_node_utils.py

Lines changed: 19 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44

55
# Simple typed wrapper for parameter overrides
6-
from dataclasses import asdict, dataclass
6+
from dataclasses import asdict, dataclass, field
77
from typing import Any, Dict, List, Optional, Tuple, Union
88

99
from huggingface_hub import create_repo, hf_hub_download
@@ -244,7 +244,7 @@
244244
"outputs": [
245245
"controlnet",
246246
],
247-
"block_names": ["controlnet_vae_encoder"],
247+
"block_name": "controlnet_vae_encoder",
248248
},
249249
"denoise": {
250250
"inputs": [
@@ -270,7 +270,7 @@
270270
"latents",
271271
"latents_preview",
272272
],
273-
"block_names": ["denoise"],
273+
"block_name": "denoise",
274274
},
275275
"vae_encoder": {
276276
"inputs": [
@@ -284,7 +284,7 @@
284284
"outputs": [
285285
"image_latents",
286286
],
287-
"block_names": ["vae_encoder"],
287+
"block_name": "vae_encoder",
288288
},
289289
"text_encoder": {
290290
"inputs": [
@@ -299,7 +299,7 @@
299299
"outputs": [
300300
"embeddings",
301301
],
302-
"block_names": ["text_encoder"],
302+
"block_name": "text_encoder",
303303
},
304304
"decoder": {
305305
"inputs": [
@@ -311,7 +311,7 @@
311311
"outputs": [
312312
"images",
313313
],
314-
"block_names": ["decode"],
314+
"block_name": "decode",
315315
},
316316
}
317317

@@ -353,21 +353,24 @@ class MellonNodeConfig(PushToHubMixin):
353353
inputs: List[Union[str, MellonParam]]
354354
model_inputs: List[Union[str, MellonParam]]
355355
outputs: List[Union[str, MellonParam]]
356-
blocks_names: list[str]
356+
block_name: str
357357
node_type: str
358+
required_inputs: List[str] = field(default_factory=list)
359+
required_model_inputs: List[str] = field(default_factory=list)
358360
config_name = "mellon_config.json"
359361

360362
def __post_init__(self):
363+
361364
if isinstance(self.inputs, list):
362-
self.inputs = self._resolve_params_list(self.inputs, MELLON_INPUT_PARAMS)
365+
self.inputs = self._resolve_params_list(self.inputs, MELLON_INPUT_PARAMS, required=self.required_inputs)
363366
if isinstance(self.model_inputs, list):
364-
self.model_inputs = self._resolve_params_list(self.model_inputs, MELLON_MODEL_PARAMS)
367+
self.model_inputs = self._resolve_params_list(self.model_inputs, MELLON_MODEL_PARAMS, required=self.required_model_inputs)
365368
if isinstance(self.outputs, list):
366369
self.outputs = self._resolve_params_list(self.outputs, MELLON_OUTPUT_PARAMS)
367370

368371
@staticmethod
369372
def _resolve_params_list(
370-
params: List[Union[str, MellonParam]], default_map: Dict[str, Dict[str, Any]]
373+
params: List[Union[str, MellonParam]], default_map: Dict[str, Dict[str, Any]], required: Optional[List[str]] = None
371374
) -> Dict[str, Dict[str, Any]]:
372375
def _resolve_param(
373376
param: Union[str, MellonParam], default_params_map: Dict[str, Dict[str, Any]]
@@ -392,6 +395,10 @@ def _resolve_param(
392395
if name in resolved:
393396
raise ValueError(f"Duplicate param '{name}'")
394397
resolved[name] = cfg
398+
if required is not None:
399+
for name in required:
400+
if name in resolved and not resolved[name]["label"].endswith(" *"):
401+
resolved[name]["label"] = f"{resolved[name]['label']} *"
395402
return resolved
396403

397404
@classmethod
@@ -625,7 +632,7 @@ def to_mellon_dict(self) -> Dict[str, Any]:
625632

626633
return {
627634
"node_type": self.node_type,
628-
"blocks_names": self.blocks_names,
635+
"block_name": self.block_name,
629636
"params": merged_params,
630637
}
631638

@@ -655,50 +662,6 @@ def from_mellon_dict(cls, mellon_dict: Dict[str, Any]) -> "MellonNodeConfig":
655662
inputs=inputs,
656663
model_inputs=model_inputs,
657664
outputs=outputs,
658-
blocks_names=mellon_dict.get("blocks_names", []),
665+
block_name=mellon_dict.get("block_name", None),
659666
node_type=mellon_dict.get("node_type"),
660667
)
661-
662-
# YiYi Notes: not used yet
663-
@classmethod
664-
def from_blocks(cls, blocks: ModularPipelineBlocks, node_type: str) -> "MellonNodeConfig":
665-
"""
666-
Create an instance from a ModularPipeline object. If a preset exists in NODE_TYPE_PARAMS_MAP for the node_type,
667-
use it; otherwise fall back to deriving lists from the pipeline's expected inputs/components/outputs.
668-
"""
669-
if node_type not in NODE_TYPE_PARAMS_MAP:
670-
raise ValueError(f"Node type {node_type} not supported")
671-
672-
blocks_names = list(blocks.sub_blocks.keys())
673-
674-
default_node_config = NODE_TYPE_PARAMS_MAP[node_type]
675-
inputs_list: List[Union[str, MellonParam]] = default_node_config.get("inputs", [])
676-
model_inputs_list: List[Union[str, MellonParam]] = default_node_config.get("model_inputs", [])
677-
outputs_list: List[Union[str, MellonParam]] = default_node_config.get("outputs", [])
678-
679-
for required_input_name in blocks.required_inputs:
680-
if required_input_name not in inputs_list:
681-
inputs_list.append(
682-
MellonParam(
683-
name=required_input_name, label=required_input_name, type=required_input_name, display="input"
684-
)
685-
)
686-
687-
for component_spec in blocks.expected_components:
688-
if component_spec.name not in model_inputs_list:
689-
model_inputs_list.append(
690-
MellonParam(
691-
name=component_spec.name,
692-
label=component_spec.name,
693-
type="diffusers_auto_model",
694-
display="input",
695-
)
696-
)
697-
698-
return cls(
699-
inputs=inputs_list,
700-
model_inputs=model_inputs_list,
701-
outputs=outputs_list,
702-
blocks_names=blocks_names,
703-
node_type=node_type,
704-
)

0 commit comments

Comments
 (0)