Skip to content

Commit bbbcdd8

Browse files
yiyixuxu, yiyi@huggingface.co, yiyi@huggingface.co
authored
[modular]Update model card to include workflow (#13195)
* up
* up
* update
* remove test

---------

Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-161-123.ec2.internal>
Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal>
1 parent 47e8faf commit bbbcdd8

File tree

4 files changed

+150
-115
lines changed

4 files changed

+150
-115
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1883,6 +1883,7 @@ def save_pretrained(
18831883
private = kwargs.pop("private", None)
18841884
create_pr = kwargs.pop("create_pr", False)
18851885
token = kwargs.pop("token", None)
1886+
update_model_card = kwargs.pop("update_model_card", False)
18861887
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
18871888

18881889
for component_name, component_spec in self._component_specs.items():
@@ -1957,6 +1958,7 @@ def save_pretrained(
19571958
is_pipeline=True,
19581959
model_description=MODULAR_MODEL_CARD_TEMPLATE.format(**card_content),
19591960
is_modular=True,
1961+
update_model_card=update_model_card,
19601962
)
19611963
model_card = populate_model_card(model_card, tags=card_content["tags"])
19621964
model_card.save(os.path.join(save_directory, "README.md"))

src/diffusers/modular_pipelines/modular_pipeline_utils.py

Lines changed: 124 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,7 @@
5050
5151
{components_description} {configs_section}
5252
53-
## Input/Output Specification
54-
55-
### Inputs {inputs_description}
56-
57-
### Outputs {outputs_description}
53+
{io_specification_section}
5854
"""
5955

6056

@@ -811,6 +807,46 @@ def format_output_params(output_params, indent_level=4, max_line_length=115):
811807
return format_params(output_params, "Outputs", indent_level, max_line_length)
812808

813809

810+
def format_params_markdown(params, header="Inputs"):
811+
"""Format a list of InputParam or OutputParam objects as a markdown bullet-point list.
812+
813+
Suitable for model cards rendered on Hugging Face Hub.
814+
815+
Args:
816+
params: list of InputParam or OutputParam objects to format
817+
header: Header text (e.g. "Inputs" or "Outputs")
818+
819+
Returns:
820+
A formatted markdown string, or empty string if params is empty.
821+
"""
822+
if not params:
823+
return ""
824+
825+
def get_type_str(type_hint):
826+
if isinstance(type_hint, UnionType) or get_origin(type_hint) is Union:
827+
type_strs = [t.__name__ if hasattr(t, "__name__") else str(t) for t in get_args(type_hint)]
828+
return " | ".join(type_strs)
829+
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
830+
831+
lines = [f"**{header}:**\n"] if header else []
832+
for param in params:
833+
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
834+
name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name
835+
param_str = f"- `{name}` (`{type_str}`"
836+
837+
if hasattr(param, "required") and not param.required:
838+
param_str += ", *optional*"
839+
if param.default is not None:
840+
param_str += f", defaults to `{param.default}`"
841+
param_str += ")"
842+
843+
desc = param.description if param.description else "No description provided"
844+
param_str += f": {desc}"
845+
lines.append(param_str)
846+
847+
return "\n".join(lines)
848+
849+
814850
def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
815851
"""Format a list of ComponentSpec objects into a readable string representation.
816852
@@ -1067,8 +1103,7 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
10671103
- blocks_description: Detailed architecture of blocks
10681104
- components_description: List of required components
10691105
- configs_section: Configuration parameters section
1070-
- inputs_description: Input parameters specification
1071-
- outputs_description: Output parameters specification
1106+
- io_specification_section: Input/Output specification (per-workflow or unified)
10721107
- trigger_inputs_section: Conditional execution information
10731108
- tags: List of relevant tags for the model card
10741109
"""
@@ -1087,15 +1122,6 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
10871122
if block_desc:
10881123
blocks_desc_parts.append(f" - {block_desc}")
10891124

1090-
# add sub-blocks if any
1091-
if hasattr(block, "sub_blocks") and block.sub_blocks:
1092-
for sub_name, sub_block in block.sub_blocks.items():
1093-
sub_class = sub_block.__class__.__name__
1094-
sub_desc = sub_block.description.split("\n")[0] if getattr(sub_block, "description", "") else ""
1095-
blocks_desc_parts.append(f" - *{sub_name}*: `{sub_class}`")
1096-
if sub_desc:
1097-
blocks_desc_parts.append(f" - {sub_desc}")
1098-
10991125
blocks_description = "\n".join(blocks_desc_parts) if blocks_desc_parts else "No blocks defined."
11001126

11011127
components = getattr(blocks, "expected_components", [])
@@ -1121,63 +1147,76 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
11211147
if configs_description:
11221148
configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}"
11231149

1124-
inputs = blocks.inputs
1125-
outputs = blocks.outputs
1126-
1127-
# format inputs as markdown list
1128-
inputs_parts = []
1129-
required_inputs = [inp for inp in inputs if inp.required]
1130-
optional_inputs = [inp for inp in inputs if not inp.required]
1131-
1132-
if required_inputs:
1133-
inputs_parts.append("**Required:**\n")
1134-
for inp in required_inputs:
1135-
if hasattr(inp.type_hint, "__name__"):
1136-
type_str = inp.type_hint.__name__
1137-
elif inp.type_hint is not None:
1138-
type_str = str(inp.type_hint).replace("typing.", "")
1139-
else:
1140-
type_str = "Any"
1141-
desc = inp.description or "No description provided"
1142-
inputs_parts.append(f"- `{inp.name}` (`{type_str}`): {desc}")
1150+
# Branch on whether workflows are defined
1151+
has_workflows = getattr(blocks, "_workflow_map", None) is not None
11431152

1144-
if optional_inputs:
1145-
if required_inputs:
1146-
inputs_parts.append("")
1147-
inputs_parts.append("**Optional:**\n")
1148-
for inp in optional_inputs:
1149-
if hasattr(inp.type_hint, "__name__"):
1150-
type_str = inp.type_hint.__name__
1151-
elif inp.type_hint is not None:
1152-
type_str = str(inp.type_hint).replace("typing.", "")
1153-
else:
1154-
type_str = "Any"
1155-
desc = inp.description or "No description provided"
1156-
default_str = f", default: `{inp.default}`" if inp.default is not None else ""
1157-
inputs_parts.append(f"- `{inp.name}` (`{type_str}`){default_str}: {desc}")
1158-
1159-
inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined."
1160-
1161-
# format outputs as markdown list
1162-
outputs_parts = []
1163-
for out in outputs:
1164-
if hasattr(out.type_hint, "__name__"):
1165-
type_str = out.type_hint.__name__
1166-
elif out.type_hint is not None:
1167-
type_str = str(out.type_hint).replace("typing.", "")
1168-
else:
1169-
type_str = "Any"
1170-
desc = out.description or "No description provided"
1171-
outputs_parts.append(f"- `{out.name}` (`{type_str}`): {desc}")
1172-
1173-
outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs."
1174-
1175-
trigger_inputs_section = ""
1176-
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
1177-
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
1178-
if trigger_inputs_list:
1179-
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
1180-
trigger_inputs_section = f"""
1153+
if has_workflows:
1154+
workflow_map = blocks._workflow_map
1155+
parts = []
1156+
1157+
# If blocks overrides outputs (e.g. to return just "images" instead of all intermediates),
1158+
# use that as the shared output for all workflows
1159+
blocks_outputs = blocks.outputs
1160+
blocks_intermediate = getattr(blocks, "intermediate_outputs", None)
1161+
shared_outputs = (
1162+
blocks_outputs if blocks_intermediate is not None and blocks_outputs != blocks_intermediate else None
1163+
)
1164+
1165+
parts.append("## Workflow Input Specification\n")
1166+
1167+
# Per-workflow details: show trigger inputs with full param descriptions
1168+
for wf_name, trigger_inputs in workflow_map.items():
1169+
trigger_input_names = set(trigger_inputs.keys())
1170+
try:
1171+
workflow_blocks = blocks.get_workflow(wf_name)
1172+
except Exception:
1173+
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
1174+
parts.append("*Could not resolve workflow blocks.*\n")
1175+
parts.append("</details>\n")
1176+
continue
1177+
1178+
wf_inputs = workflow_blocks.inputs
1179+
# Show only trigger inputs with full parameter descriptions
1180+
trigger_params = [p for p in wf_inputs if p.name in trigger_input_names]
1181+
1182+
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
1183+
1184+
inputs_str = format_params_markdown(trigger_params, header=None)
1185+
parts.append(inputs_str if inputs_str else "No additional inputs required.")
1186+
parts.append("")
1187+
1188+
parts.append("</details>\n")
1189+
1190+
# Common Inputs & Outputs section (like non-workflow pipelines)
1191+
all_inputs = blocks.inputs
1192+
all_outputs = shared_outputs if shared_outputs is not None else blocks.outputs
1193+
1194+
inputs_str = format_params_markdown(all_inputs, "Inputs")
1195+
outputs_str = format_params_markdown(all_outputs, "Outputs")
1196+
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
1197+
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
1198+
1199+
parts.append(f"\n## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}")
1200+
1201+
io_specification_section = "\n".join(parts)
1202+
# Suppress trigger_inputs_section when workflows are shown (it's redundant)
1203+
trigger_inputs_section = ""
1204+
else:
1205+
# Unified I/O section (original behavior)
1206+
inputs = blocks.inputs
1207+
outputs = blocks.outputs
1208+
inputs_str = format_params_markdown(inputs, "Inputs")
1209+
outputs_str = format_params_markdown(outputs, "Outputs")
1210+
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
1211+
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
1212+
io_specification_section = f"## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}"
1213+
1214+
trigger_inputs_section = ""
1215+
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
1216+
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
1217+
if trigger_inputs_list:
1218+
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
1219+
trigger_inputs_section = f"""
11811220
### Conditional Execution
11821221
11831222
This pipeline contains blocks that are selected at runtime based on inputs:
@@ -1190,7 +1229,18 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
11901229
if hasattr(blocks, "model_name") and blocks.model_name:
11911230
tags.append(blocks.model_name)
11921231

1193-
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
1232+
if has_workflows:
1233+
# Derive tags from workflow names
1234+
workflow_names = set(blocks._workflow_map.keys())
1235+
if any("inpainting" in wf for wf in workflow_names):
1236+
tags.append("inpainting")
1237+
if any("image2image" in wf for wf in workflow_names):
1238+
tags.append("image-to-image")
1239+
if any("controlnet" in wf for wf in workflow_names):
1240+
tags.append("controlnet")
1241+
if any("text2image" in wf for wf in workflow_names):
1242+
tags.append("text-to-image")
1243+
elif hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
11941244
triggers = blocks.trigger_inputs
11951245
if any(t in triggers for t in ["mask", "mask_image"]):
11961246
tags.append("inpainting")
@@ -1218,8 +1268,7 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
12181268
"blocks_description": blocks_description,
12191269
"components_description": components_description,
12201270
"configs_section": configs_section,
1221-
"inputs_description": inputs_description,
1222-
"outputs_description": outputs_description,
1271+
"io_specification_section": io_specification_section,
12231272
"trigger_inputs_section": trigger_inputs_section,
12241273
"tags": tags,
12251274
}

src/diffusers/utils/hub_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def load_or_create_model_card(
107107
widget: list[dict] | None = None,
108108
inference: bool | None = None,
109109
is_modular: bool = False,
110+
update_model_card: bool = False,
110111
) -> ModelCard:
111112
"""
112113
Loads or creates a model card.
@@ -133,6 +134,9 @@ def load_or_create_model_card(
133134
`load_or_create_model_card` from a training script.
134135
is_modular: (`bool`, optional): Boolean flag to denote if the model card is for a modular pipeline.
135136
When True, uses model_description as-is without additional template formatting.
137+
update_model_card: (`bool`, optional): When True, regenerates the model card content even if one
138+
already exists on the remote repo. Existing card metadata (tags, license, etc.) is preserved. Only
139+
supported for modular pipelines (i.e., `is_modular=True`).
136140
"""
137141
if not is_jinja_available():
138142
raise ValueError(
@@ -141,9 +145,17 @@ def load_or_create_model_card(
141145
" To install it, please run `pip install Jinja2`."
142146
)
143147

148+
if update_model_card and not is_modular:
149+
raise ValueError("`update_model_card=True` is only supported for modular pipelines (`is_modular=True`).")
150+
144151
try:
145152
# Check if the model card is present on the remote repo
146153
model_card = ModelCard.load(repo_id_or_path, token=token)
154+
# For modular pipelines, regenerate card content when requested (preserve existing metadata)
155+
if update_model_card and is_modular and model_description is not None:
156+
existing_data = model_card.data
157+
model_card = ModelCard(model_description)
158+
model_card.data = existing_data
147159
except (EntryNotFoundError, RepositoryNotFoundError):
148160
# Otherwise create a model card from template
149161
if from_training:

tests/modular_pipelines/test_modular_pipelines_common.py

Lines changed: 12 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -483,8 +483,7 @@ def test_basic_model_card_content_structure(self):
483483
"blocks_description",
484484
"components_description",
485485
"configs_section",
486-
"inputs_description",
487-
"outputs_description",
486+
"io_specification_section",
488487
"trigger_inputs_section",
489488
"tags",
490489
]
@@ -581,18 +580,19 @@ def test_inputs_description_required_and_optional(self):
581580
blocks = self.create_mock_blocks(inputs=inputs)
582581
content = generate_modular_model_card_content(blocks)
583582

584-
assert "**Required:**" in content["inputs_description"]
585-
assert "**Optional:**" in content["inputs_description"]
586-
assert "prompt" in content["inputs_description"]
587-
assert "num_steps" in content["inputs_description"]
588-
assert "default: `50`" in content["inputs_description"]
583+
io_section = content["io_specification_section"]
584+
assert "**Inputs:**" in io_section
585+
assert "prompt" in io_section
586+
assert "num_steps" in io_section
587+
assert "*optional*" in io_section
588+
assert "defaults to `50`" in io_section
589589

590590
def test_inputs_description_empty(self):
591591
"""Test handling of pipelines without specific inputs."""
592592
blocks = self.create_mock_blocks(inputs=[])
593593
content = generate_modular_model_card_content(blocks)
594594

595-
assert "No specific inputs defined" in content["inputs_description"]
595+
assert "No specific inputs defined" in content["io_specification_section"]
596596

597597
def test_outputs_description_formatting(self):
598598
"""Test that outputs are correctly formatted."""
@@ -602,15 +602,16 @@ def test_outputs_description_formatting(self):
602602
blocks = self.create_mock_blocks(outputs=outputs)
603603
content = generate_modular_model_card_content(blocks)
604604

605-
assert "images" in content["outputs_description"]
606-
assert "Generated images" in content["outputs_description"]
605+
io_section = content["io_specification_section"]
606+
assert "images" in io_section
607+
assert "Generated images" in io_section
607608

608609
def test_outputs_description_empty(self):
609610
"""Test handling of pipelines without specific outputs."""
610611
blocks = self.create_mock_blocks(outputs=[])
611612
content = generate_modular_model_card_content(blocks)
612613

613-
assert "Standard pipeline outputs" in content["outputs_description"]
614+
assert "Standard pipeline outputs" in content["io_specification_section"]
614615

615616
def test_trigger_inputs_section_with_triggers(self):
616617
"""Test that trigger inputs section is generated when present."""
@@ -628,35 +629,6 @@ def test_trigger_inputs_section_empty(self):
628629

629630
assert content["trigger_inputs_section"] == ""
630631

631-
def test_blocks_description_with_sub_blocks(self):
632-
"""Test that blocks with sub-blocks are correctly described."""
633-
634-
class MockBlockWithSubBlocks:
635-
def __init__(self):
636-
self.__class__.__name__ = "ParentBlock"
637-
self.description = "Parent block"
638-
self.sub_blocks = {
639-
"child1": self.create_child_block("ChildBlock1", "Child 1 description"),
640-
"child2": self.create_child_block("ChildBlock2", "Child 2 description"),
641-
}
642-
643-
def create_child_block(self, name, desc):
644-
class ChildBlock:
645-
def __init__(self):
646-
self.__class__.__name__ = name
647-
self.description = desc
648-
649-
return ChildBlock()
650-
651-
blocks = self.create_mock_blocks()
652-
blocks.sub_blocks["parent"] = MockBlockWithSubBlocks()
653-
654-
content = generate_modular_model_card_content(blocks)
655-
656-
assert "parent" in content["blocks_description"]
657-
assert "child1" in content["blocks_description"]
658-
assert "child2" in content["blocks_description"]
659-
660632
def test_model_description_includes_block_count(self):
661633
"""Test that model description includes the number of blocks."""
662634
blocks = self.create_mock_blocks(num_blocks=5)

0 commit comments

Comments
 (0)