Skip to content

Commit 4a2833c

Browse files
sayakpaul, DN6, and stevhliu
authored
[Modular] implement requirements validation for custom blocks (#12196)
* feat: implement requirements validation for custom blocks. * up * unify. * up * add tests * Apply suggestions from code review Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> * reviewer feedback. * [docs] validation for custom blocks (#13156) validation * move to tmp_path fixture. * propagate to conditional and loopsequential blocks. * up * remove collected tests --------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
1 parent 1fe688a commit 4a2833c

File tree

5 files changed

+291
-5
lines changed

5 files changed

+291
-5
lines changed

docs/source/en/modular_diffusers/custom_blocks.md

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,4 +332,49 @@ Make your custom block work with Mellon's visual interface. See the [Mellon Cust
332332
Browse the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for inspiration and ready-to-use blocks.
333333

334334
</hfoption>
335-
</hfoptions>
335+
</hfoptions>
336+
337+
## Dependencies
338+
339+
Declaring package dependencies in custom blocks prevents runtime import errors later on. Diffusers validates the dependencies and logs a warning if a package is missing or incompatible.
340+
341+
Set a `_requirements` attribute in your block class, mapping package names to version specifiers.
342+
343+
```py
344+
from diffusers.modular_pipelines import PipelineBlock
345+
346+
class MyCustomBlock(PipelineBlock):
347+
_requirements = {
348+
"transformers": ">=4.44.0",
349+
"sentencepiece": ">=0.2.0"
350+
}
351+
```
352+
353+
When there are blocks with different requirements, Diffusers merges their requirements.
354+
355+
```py
356+
from diffusers.modular_pipelines import SequentialPipelineBlocks
357+
358+
class BlockA(PipelineBlock):
359+
_requirements = {"transformers": ">=4.44.0"}
360+
# ...
361+
362+
class BlockB(PipelineBlock):
363+
_requirements = {"sentencepiece": ">=0.2.0"}
364+
# ...
365+
366+
pipe = SequentialPipelineBlocks.from_blocks_dict({
367+
"block_a": BlockA,
368+
"block_b": BlockB,
369+
})
370+
```
371+
372+
When this block is saved with [`~ModularPipeline.save_pretrained`], the requirements are saved to the `modular_config.json` file. When this block is loaded, Diffusers checks each requirement against the current environment. If there is a mismatch or a package isn't found, Diffusers returns the following warning.
373+
374+
```md
375+
# missing package
376+
xyz-package was specified in the requirements but wasn't found in the current environment.
377+
378+
# version mismatch
379+
xyz requirement 'specific-version' is not satisfied by the installed version 'actual-version'. Things might work unexpectedly.
380+
```

src/diffusers/commands/custom_blocks.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,6 @@ def run(self):
8989
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
9090
# with open(CONFIG, "w") as f:
9191
# json.dump(automap, f)
92-
with open("requirements.txt", "w") as f:
93-
f.write("")
9492

9593
def _choose_block(self, candidates, chosen=None):
9694
for cls, base in candidates:

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
InputParam,
4848
InsertableDict,
4949
OutputParam,
50+
_validate_requirements,
5051
combine_inputs,
5152
combine_outputs,
5253
format_components,
@@ -297,6 +298,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
297298

298299
config_name = "modular_config.json"
299300
model_name = None
301+
_requirements: dict[str, str] | None = None
300302
_workflow_map = None
301303

302304
@classmethod
@@ -411,6 +413,9 @@ def from_pretrained(
411413
"Selected model repository does not appear to have any custom code or does not have a valid `config.json` file."
412414
)
413415

416+
if "requirements" in config and config["requirements"] is not None:
417+
_ = _validate_requirements(config["requirements"])
418+
414419
class_ref = config["auto_map"][cls.__name__]
415420
module_file, class_name = class_ref.split(".")
416421
module_file = module_file + ".py"
@@ -435,8 +440,13 @@ def save_pretrained(self, save_directory, push_to_hub=False, **kwargs):
435440
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
436441
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
437442
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
438-
439443
self.register_to_config(auto_map=auto_map)
444+
445+
# resolve requirements
446+
requirements = _validate_requirements(getattr(self, "_requirements", None))
447+
if requirements:
448+
self.register_to_config(requirements=requirements)
449+
440450
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
441451
config = dict(self.config)
442452
self._internal_dict = FrozenDict(config)
@@ -658,6 +668,15 @@ def outputs(self) -> list[str]:
658668
combined_outputs = combine_outputs(*named_outputs)
659669
return combined_outputs
660670

671+
@property
672+
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks._requirements
673+
def _requirements(self) -> dict[str, str]:
674+
requirements = {}
675+
for block_name, block in self.sub_blocks.items():
676+
if getattr(block, "_requirements", None):
677+
requirements[block_name] = block._requirements
678+
return requirements
679+
661680
# used for `__repr__`
662681
def _get_trigger_inputs(self) -> set:
663682
"""
@@ -1247,6 +1266,14 @@ def doc(self):
12471266
expected_configs=self.expected_configs,
12481267
)
12491268

1269+
@property
1270+
def _requirements(self) -> dict[str, str]:
1271+
requirements = {}
1272+
for block_name, block in self.sub_blocks.items():
1273+
if getattr(block, "_requirements", None):
1274+
requirements[block_name] = block._requirements
1275+
return requirements
1276+
12501277

12511278
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
12521279
"""
@@ -1385,6 +1412,15 @@ def intermediate_outputs(self) -> list[str]:
13851412
def outputs(self) -> list[str]:
13861413
return next(reversed(self.sub_blocks.values())).intermediate_outputs
13871414

1415+
@property
1416+
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks._requirements
1417+
def _requirements(self) -> dict[str, str]:
1418+
requirements = {}
1419+
for block_name, block in self.sub_blocks.items():
1420+
if getattr(block, "_requirements", None):
1421+
requirements[block_name] = block._requirements
1422+
return requirements
1423+
13881424
def __init__(self):
13891425
sub_blocks = InsertableDict()
13901426
for block_name, block in zip(self.block_names, self.block_classes):

src/diffusers/modular_pipelines/modular_pipeline_utils.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@
2222

2323
import PIL.Image
2424
import torch
25+
from packaging.specifiers import InvalidSpecifier, SpecifierSet
2526

2627
from ..configuration_utils import ConfigMixin, FrozenDict
2728
from ..loaders.single_file_utils import _is_single_file_path_or_url
2829
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
30+
from ..utils.import_utils import _is_package_available
2931

3032

3133
if is_torch_available():
@@ -1020,6 +1022,89 @@ def make_doc_string(
10201022
return output
10211023

10221024

1025+
def _validate_requirements(reqs):
1026+
if reqs is None:
1027+
normalized_reqs = {}
1028+
else:
1029+
if not isinstance(reqs, dict):
1030+
raise ValueError(
1031+
"Requirements must be provided as a dictionary mapping package names to version specifiers."
1032+
)
1033+
normalized_reqs = _normalize_requirements(reqs)
1034+
1035+
if not normalized_reqs:
1036+
return {}
1037+
1038+
final: dict[str, str] = {}
1039+
for req, specified_ver in normalized_reqs.items():
1040+
req_available, req_actual_ver = _is_package_available(req)
1041+
if not req_available:
1042+
logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")
1043+
1044+
if specified_ver:
1045+
try:
1046+
specifier = SpecifierSet(specified_ver)
1047+
except InvalidSpecifier as err:
1048+
raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err
1049+
1050+
if req_actual_ver == "N/A":
1051+
logger.warning(
1052+
f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpectedly."
1053+
)
1054+
elif not specifier.contains(req_actual_ver, prereleases=True):
1055+
logger.warning(
1056+
f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpectedly."
1057+
)
1058+
1059+
final[req] = specified_ver
1060+
1061+
return final
1062+
1063+
1064+
def _normalize_requirements(reqs):
1065+
if not reqs:
1066+
return {}
1067+
1068+
normalized: "OrderedDict[str, str]" = OrderedDict()
1069+
1070+
def _accumulate(mapping: dict[str, Any]):
1071+
for pkg, spec in mapping.items():
1072+
if isinstance(spec, dict):
1073+
# This is recursive because blocks are composable. This way, we can merge requirements
1074+
# from multiple blocks.
1075+
_accumulate(spec)
1076+
continue
1077+
1078+
pkg_name = str(pkg).strip()
1079+
if not pkg_name:
1080+
raise ValueError("Requirement package name cannot be empty.")
1081+
1082+
spec_str = "" if spec is None else str(spec).strip()
1083+
if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")):
1084+
spec_str = f"=={spec_str}"
1085+
1086+
existing_spec = normalized.get(pkg_name)
1087+
if existing_spec is not None:
1088+
if not existing_spec and spec_str:
1089+
normalized[pkg_name] = spec_str
1090+
elif existing_spec and spec_str and existing_spec != spec_str:
1091+
try:
1092+
combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str])))
1093+
except InvalidSpecifier:
1094+
logger.warning(
1095+
f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'."
1096+
)
1097+
else:
1098+
normalized[pkg_name] = str(combined_spec)
1099+
continue
1100+
1101+
normalized[pkg_name] = spec_str
1102+
1103+
_accumulate(reqs)
1104+
1105+
return normalized
1106+
1107+
10231108
def combine_inputs(*named_input_lists: list[tuple[str, list[InputParam]]]) -> list[InputParam]:
10241109
"""
10251110
Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current

tests/modular_pipelines/test_modular_pipelines_common.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
import diffusers
1111
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
1212
from diffusers.guiders import ClassifierFreeGuidance
13+
from diffusers.modular_pipelines import (
14+
ConditionalPipelineBlocks,
15+
LoopSequentialPipelineBlocks,
16+
SequentialPipelineBlocks,
17+
)
1318
from diffusers.modular_pipelines.modular_pipeline_utils import (
1419
ComponentSpec,
1520
ConfigSpec,
@@ -19,7 +24,13 @@
1924
)
2025
from diffusers.utils import logging
2126

22-
from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device
27+
from ..testing_utils import (
28+
CaptureLogger,
29+
backend_empty_cache,
30+
numpy_cosine_similarity_distance,
31+
require_accelerator,
32+
torch_device,
33+
)
2334

2435

2536
class ModularPipelineTesterMixin:
@@ -429,6 +440,117 @@ def test_guider_cfg(self, expected_max_diff=1e-2):
429440
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
430441

431442

443+
class TestCustomBlockRequirements:
444+
def get_dummy_block_pipe(self):
445+
class DummyBlockOne:
446+
# keep two arbitrary deps so that we can test warnings.
447+
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
448+
449+
class DummyBlockTwo:
450+
# keep two dependencies that will be available during testing.
451+
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
452+
453+
pipe = SequentialPipelineBlocks.from_blocks_dict(
454+
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
455+
)
456+
return pipe
457+
458+
def get_dummy_conditional_block_pipe(self):
459+
class DummyBlockOne:
460+
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
461+
462+
class DummyBlockTwo:
463+
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
464+
465+
class DummyConditionalBlocks(ConditionalPipelineBlocks):
466+
block_classes = [DummyBlockOne, DummyBlockTwo]
467+
block_names = ["block_one", "block_two"]
468+
block_trigger_inputs = []
469+
470+
def select_block(self, **kwargs):
471+
return "block_one"
472+
473+
return DummyConditionalBlocks()
474+
475+
def get_dummy_loop_block_pipe(self):
476+
class DummyBlockOne:
477+
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
478+
479+
class DummyBlockTwo:
480+
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
481+
482+
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
483+
484+
def test_sequential_block_requirements_save_load(self, tmp_path):
485+
pipe = self.get_dummy_block_pipe()
486+
pipe.save_pretrained(tmp_path)
487+
488+
config_path = tmp_path / "modular_config.json"
489+
490+
with open(config_path, "r") as f:
491+
config = json.load(f)
492+
493+
assert "requirements" in config
494+
requirements = config["requirements"]
495+
496+
expected_requirements = {
497+
"xyz": ">=0.8.0",
498+
"abc": ">=10.0.0",
499+
"transformers": ">=4.44.0",
500+
"diffusers": ">=0.2.0",
501+
}
502+
assert expected_requirements == requirements
503+
504+
def test_sequential_block_requirements_warnings(self, tmp_path):
505+
pipe = self.get_dummy_block_pipe()
506+
507+
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
508+
logger.setLevel(30)
509+
510+
with CaptureLogger(logger) as cap_logger:
511+
pipe.save_pretrained(tmp_path)
512+
513+
template = "{req} was specified in the requirements but wasn't found in the current environment"
514+
msg_xyz = template.format(req="xyz")
515+
msg_abc = template.format(req="abc")
516+
assert msg_xyz in str(cap_logger.out)
517+
assert msg_abc in str(cap_logger.out)
518+
519+
def test_conditional_block_requirements_save_load(self, tmp_path):
520+
pipe = self.get_dummy_conditional_block_pipe()
521+
pipe.save_pretrained(tmp_path)
522+
523+
config_path = tmp_path / "modular_config.json"
524+
with open(config_path, "r") as f:
525+
config = json.load(f)
526+
527+
assert "requirements" in config
528+
expected_requirements = {
529+
"xyz": ">=0.8.0",
530+
"abc": ">=10.0.0",
531+
"transformers": ">=4.44.0",
532+
"diffusers": ">=0.2.0",
533+
}
534+
assert expected_requirements == config["requirements"]
535+
536+
def test_loop_block_requirements_save_load(self, tmp_path):
537+
pipe = self.get_dummy_loop_block_pipe()
538+
pipe.save_pretrained(tmp_path)
539+
540+
config_path = tmp_path / "modular_config.json"
541+
with open(config_path, "r") as f:
542+
config = json.load(f)
543+
544+
assert "requirements" in config
545+
expected_requirements = {
546+
"xyz": ">=0.8.0",
547+
"abc": ">=10.0.0",
548+
"transformers": ">=4.44.0",
549+
"diffusers": ">=0.2.0",
550+
}
551+
assert expected_requirements == config["requirements"]
552+
553+
432554
class TestModularModelCardContent:
433555
def create_mock_block(self, name="TestBlock", description="Test block description"):
434556
class MockBlock:

0 commit comments

Comments
 (0)