|
5 | 5 |
|
6 | 6 | import pytest |
7 | 7 | import torch |
| 8 | +from huggingface_hub import hf_hub_download |
8 | 9 |
|
9 | 10 | import diffusers |
10 | 11 | from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks |
|
32 | 33 | ) |
33 | 34 |
|
34 | 35 |
|
| 36 | +def _get_specified_components(path_or_repo_id, cache_dir=None): |
| 37 | + if os.path.isdir(path_or_repo_id): |
| 38 | + config_path = os.path.join(path_or_repo_id, "modular_model_index.json") |
| 39 | + else: |
| 40 | + try: |
| 41 | + config_path = hf_hub_download( |
| 42 | + repo_id=path_or_repo_id, |
| 43 | + filename="modular_model_index.json", |
| 44 | + local_dir=cache_dir, |
| 45 | + ) |
| 46 | + except Exception: |
| 47 | + return None |
| 48 | + |
| 49 | + with open(config_path) as f: |
| 50 | + config = json.load(f) |
| 51 | + |
| 52 | + components = set() |
| 53 | + for k, v in config.items(): |
| 54 | + if isinstance(v, (str, int, float, bool)): |
| 55 | + continue |
| 56 | + for entry in v: |
| 57 | + if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")): |
| 58 | + components.add(k) |
| 59 | + break |
| 60 | + return components |
| 61 | + |
| 62 | + |
35 | 63 | class ModularPipelineTesterMixin: |
36 | 64 | """ |
37 | 65 | It provides a set of common tests for each modular pipeline, |
@@ -360,6 +388,39 @@ def test_save_from_pretrained(self, tmp_path): |
360 | 388 |
|
361 | 389 | assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 |
362 | 390 |
|
| 391 | + def test_load_expected_components_from_pretrained(self, tmp_path): |
| 392 | + pipe = self.get_pipeline() |
| 393 | + expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path) |
| 394 | + if not expected: |
| 395 | + pytest.skip("Skipping test as we couldn't fetch the expected components.") |
| 396 | + |
| 397 | + actual = { |
| 398 | + name |
| 399 | + for name in pipe.components |
| 400 | + if getattr(pipe, name, None) is not None |
| 401 | + and getattr(getattr(pipe, name), "_diffusers_load_id", None) not in (None, "null") |
| 402 | + } |
| 403 | + assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}" |
| 404 | + |
| 405 | + def test_load_expected_components_from_save_pretrained(self, tmp_path): |
| 406 | + pipe = self.get_pipeline() |
| 407 | + save_dir = str(tmp_path / "saved-pipeline") |
| 408 | + pipe.save_pretrained(save_dir) |
| 409 | + |
| 410 | + expected = _get_specified_components(save_dir) |
| 411 | + loaded_pipe = ModularPipeline.from_pretrained(save_dir) |
| 412 | + loaded_pipe.load_components(torch_dtype=torch.float32) |
| 413 | + |
| 414 | + actual = { |
| 415 | + name |
| 416 | + for name in loaded_pipe.components |
| 417 | + if getattr(loaded_pipe, name, None) is not None |
| 418 | + and getattr(getattr(loaded_pipe, name), "_diffusers_load_id", None) not in (None, "null") |
| 419 | + } |
| 420 | + assert expected == actual, ( |
| 421 | + f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}" |
| 422 | + ) |
| 423 | + |
363 | 424 | def test_modular_index_consistency(self, tmp_path): |
364 | 425 | pipe = self.get_pipeline() |
365 | 426 | components_spec = pipe._component_specs |
|
0 commit comments