Skip to content

Commit 009bcb5

Browse files
committed
[Modular] Test for catching dtype and device issues with AutoModel type hints (#13287)
* update * update * update
1 parent 37c4139 commit 009bcb5

File tree

1 file changed

+110
-1
lines changed

1 file changed

+110
-1
lines changed

tests/modular_pipelines/test_modular_pipelines_custom_blocks.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,41 @@
3131
WanModularPipeline,
3232
)
3333

34-
from ..testing_utils import nightly, require_torch, slow
34+
from ..testing_utils import nightly, require_torch, require_torch_accelerator, slow, torch_device
35+
36+
37+
def _create_tiny_model_dir(model_dir):
38+
TINY_MODEL_CODE = (
39+
"import torch\n"
40+
"from diffusers import ModelMixin, ConfigMixin\n"
41+
"from diffusers.configuration_utils import register_to_config\n"
42+
"\n"
43+
"class TinyModel(ModelMixin, ConfigMixin):\n"
44+
" @register_to_config\n"
45+
" def __init__(self, hidden_size=4):\n"
46+
" super().__init__()\n"
47+
" self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
48+
"\n"
49+
" def forward(self, x):\n"
50+
" return self.linear(x)\n"
51+
)
52+
53+
with open(os.path.join(model_dir, "modeling.py"), "w") as f:
54+
f.write(TINY_MODEL_CODE)
55+
56+
config = {
57+
"_class_name": "TinyModel",
58+
"_diffusers_version": "0.0.0",
59+
"auto_map": {"AutoModel": "modeling.TinyModel"},
60+
"hidden_size": 4,
61+
}
62+
with open(os.path.join(model_dir, "config.json"), "w") as f:
63+
json.dump(config, f)
64+
65+
torch.save(
66+
{"linear.weight": torch.randn(4, 4), "linear.bias": torch.randn(4)},
67+
os.path.join(model_dir, "diffusion_pytorch_model.bin"),
68+
)
3569

3670

3771
class DummyCustomBlockSimple(ModularPipelineBlocks):
@@ -342,6 +376,81 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
342376
loaded_pipe.update_components(custom_model=custom_model)
343377
assert getattr(loaded_pipe, "custom_model", None) is not None
344378

379+
def test_automodel_type_hint_preserves_torch_dtype(self, tmp_path):
380+
"""Regression test for #13271: torch_dtype was incorrectly removed when type_hint is AutoModel."""
381+
from diffusers import AutoModel
382+
383+
model_dir = str(tmp_path / "model")
384+
os.makedirs(model_dir)
385+
_create_tiny_model_dir(model_dir)
386+
387+
class DtypeTestBlock(ModularPipelineBlocks):
388+
@property
389+
def expected_components(self):
390+
return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)]
391+
392+
@property
393+
def inputs(self) -> List[InputParam]:
394+
return [InputParam("prompt", type_hint=str, required=True)]
395+
396+
@property
397+
def intermediate_inputs(self) -> List[InputParam]:
398+
return []
399+
400+
@property
401+
def intermediate_outputs(self) -> List[OutputParam]:
402+
return [OutputParam("output", type_hint=str)]
403+
404+
def __call__(self, components, state: PipelineState) -> PipelineState:
405+
block_state = self.get_block_state(state)
406+
block_state.output = "test"
407+
self.set_block_state(state, block_state)
408+
return components, state
409+
410+
block = DtypeTestBlock()
411+
pipe = block.init_pipeline()
412+
pipe.load_components(torch_dtype=torch.float16, trust_remote_code=True)
413+
414+
assert pipe.model.dtype == torch.float16
415+
416+
@require_torch_accelerator
417+
def test_automodel_type_hint_preserves_device(self, tmp_path):
418+
"""Test that ComponentSpec with AutoModel type_hint correctly passes device_map."""
419+
from diffusers import AutoModel
420+
421+
model_dir = str(tmp_path / "model")
422+
os.makedirs(model_dir)
423+
_create_tiny_model_dir(model_dir)
424+
425+
class DeviceTestBlock(ModularPipelineBlocks):
426+
@property
427+
def expected_components(self):
428+
return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)]
429+
430+
@property
431+
def inputs(self) -> List[InputParam]:
432+
return [InputParam("prompt", type_hint=str, required=True)]
433+
434+
@property
435+
def intermediate_inputs(self) -> List[InputParam]:
436+
return []
437+
438+
@property
439+
def intermediate_outputs(self) -> List[OutputParam]:
440+
return [OutputParam("output", type_hint=str)]
441+
442+
def __call__(self, components, state: PipelineState) -> PipelineState:
443+
block_state = self.get_block_state(state)
444+
block_state.output = "test"
445+
self.set_block_state(state, block_state)
446+
return components, state
447+
448+
block = DeviceTestBlock()
449+
pipe = block.init_pipeline()
450+
pipe.load_components(device_map=torch_device, trust_remote_code=True)
451+
452+
assert pipe.model.device.type == torch_device
453+
345454
def test_custom_block_loads_from_hub(self):
346455
repo_id = "hf-internal-testing/tiny-modular-diffusers-block"
347456
block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)

0 commit comments

Comments
 (0)