Skip to content

Commit c02c17c

Browse files
authored
[tests] test load_components in modular (#13245)
* test load_components. * fix * fix * u[ * up
1 parent a9855c4 commit c02c17c

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

tests/modular_pipelines/test_modular_pipelines_common.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import pytest
77
import torch
8+
from huggingface_hub import hf_hub_download
89

910
import diffusers
1011
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
@@ -32,6 +33,33 @@
3233
)
3334

3435

36+
def _get_specified_components(path_or_repo_id, cache_dir=None):
37+
if os.path.isdir(path_or_repo_id):
38+
config_path = os.path.join(path_or_repo_id, "modular_model_index.json")
39+
else:
40+
try:
41+
config_path = hf_hub_download(
42+
repo_id=path_or_repo_id,
43+
filename="modular_model_index.json",
44+
local_dir=cache_dir,
45+
)
46+
except Exception:
47+
return None
48+
49+
with open(config_path) as f:
50+
config = json.load(f)
51+
52+
components = set()
53+
for k, v in config.items():
54+
if isinstance(v, (str, int, float, bool)):
55+
continue
56+
for entry in v:
57+
if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")):
58+
components.add(k)
59+
break
60+
return components
61+
62+
3563
class ModularPipelineTesterMixin:
3664
"""
3765
It provides a set of common tests for each modular pipeline,
@@ -360,6 +388,39 @@ def test_save_from_pretrained(self, tmp_path):
360388

361389
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
362390

391+
def test_load_expected_components_from_pretrained(self, tmp_path):
392+
pipe = self.get_pipeline()
393+
expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path)
394+
if not expected:
395+
pytest.skip("Skipping test as we couldn't fetch the expected components.")
396+
397+
actual = {
398+
name
399+
for name in pipe.components
400+
if getattr(pipe, name, None) is not None
401+
and getattr(getattr(pipe, name), "_diffusers_load_id", None) not in (None, "null")
402+
}
403+
assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}"
404+
405+
def test_load_expected_components_from_save_pretrained(self, tmp_path):
406+
pipe = self.get_pipeline()
407+
save_dir = str(tmp_path / "saved-pipeline")
408+
pipe.save_pretrained(save_dir)
409+
410+
expected = _get_specified_components(save_dir)
411+
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
412+
loaded_pipe.load_components(torch_dtype=torch.float32)
413+
414+
actual = {
415+
name
416+
for name in loaded_pipe.components
417+
if getattr(loaded_pipe, name, None) is not None
418+
and getattr(getattr(loaded_pipe, name), "_diffusers_load_id", None) not in (None, "null")
419+
}
420+
assert expected == actual, (
421+
f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}"
422+
)
423+
363424
def test_modular_index_consistency(self, tmp_path):
364425
pipe = self.get_pipeline()
365426
components_spec = pipe._component_specs

0 commit comments

Comments
 (0)