Skip to content

Commit 32b4cfc

Browse files
authored
[Modular] Test for catching dtype and device issues with AutoModel type hints (#13287)
* update * update * update
1 parent a13e5cf commit 32b4cfc

File tree

1 file changed

+110
-1
lines changed

1 file changed

+110
-1
lines changed

tests/modular_pipelines/test_modular_pipelines_custom_blocks.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,41 @@
3131
WanModularPipeline,
3232
)
3333

34-
from ..testing_utils import nightly, require_torch, slow
34+
from ..testing_utils import nightly, require_torch, require_torch_accelerator, slow, torch_device
35+
36+
37+
def _create_tiny_model_dir(model_dir):
38+
TINY_MODEL_CODE = (
39+
"import torch\n"
40+
"from diffusers import ModelMixin, ConfigMixin\n"
41+
"from diffusers.configuration_utils import register_to_config\n"
42+
"\n"
43+
"class TinyModel(ModelMixin, ConfigMixin):\n"
44+
" @register_to_config\n"
45+
" def __init__(self, hidden_size=4):\n"
46+
" super().__init__()\n"
47+
" self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
48+
"\n"
49+
" def forward(self, x):\n"
50+
" return self.linear(x)\n"
51+
)
52+
53+
with open(os.path.join(model_dir, "modeling.py"), "w") as f:
54+
f.write(TINY_MODEL_CODE)
55+
56+
config = {
57+
"_class_name": "TinyModel",
58+
"_diffusers_version": "0.0.0",
59+
"auto_map": {"AutoModel": "modeling.TinyModel"},
60+
"hidden_size": 4,
61+
}
62+
with open(os.path.join(model_dir, "config.json"), "w") as f:
63+
json.dump(config, f)
64+
65+
torch.save(
66+
{"linear.weight": torch.randn(4, 4), "linear.bias": torch.randn(4)},
67+
os.path.join(model_dir, "diffusion_pytorch_model.bin"),
68+
)
3569

3670

3771
class DummyCustomBlockSimple(ModularPipelineBlocks):
@@ -341,6 +375,81 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
341375
loaded_pipe.update_components(custom_model=custom_model)
342376
assert getattr(loaded_pipe, "custom_model", None) is not None
343377

378+
def test_automodel_type_hint_preserves_torch_dtype(self, tmp_path):
379+
"""Regression test for #13271: torch_dtype was incorrectly removed when type_hint is AutoModel."""
380+
from diffusers import AutoModel
381+
382+
model_dir = str(tmp_path / "model")
383+
os.makedirs(model_dir)
384+
_create_tiny_model_dir(model_dir)
385+
386+
class DtypeTestBlock(ModularPipelineBlocks):
387+
@property
388+
def expected_components(self):
389+
return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)]
390+
391+
@property
392+
def inputs(self) -> List[InputParam]:
393+
return [InputParam("prompt", type_hint=str, required=True)]
394+
395+
@property
396+
def intermediate_inputs(self) -> List[InputParam]:
397+
return []
398+
399+
@property
400+
def intermediate_outputs(self) -> List[OutputParam]:
401+
return [OutputParam("output", type_hint=str)]
402+
403+
def __call__(self, components, state: PipelineState) -> PipelineState:
404+
block_state = self.get_block_state(state)
405+
block_state.output = "test"
406+
self.set_block_state(state, block_state)
407+
return components, state
408+
409+
block = DtypeTestBlock()
410+
pipe = block.init_pipeline()
411+
pipe.load_components(torch_dtype=torch.float16, trust_remote_code=True)
412+
413+
assert pipe.model.dtype == torch.float16
414+
415+
@require_torch_accelerator
416+
def test_automodel_type_hint_preserves_device(self, tmp_path):
417+
"""Test that ComponentSpec with AutoModel type_hint correctly passes device_map."""
418+
from diffusers import AutoModel
419+
420+
model_dir = str(tmp_path / "model")
421+
os.makedirs(model_dir)
422+
_create_tiny_model_dir(model_dir)
423+
424+
class DeviceTestBlock(ModularPipelineBlocks):
425+
@property
426+
def expected_components(self):
427+
return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)]
428+
429+
@property
430+
def inputs(self) -> List[InputParam]:
431+
return [InputParam("prompt", type_hint=str, required=True)]
432+
433+
@property
434+
def intermediate_inputs(self) -> List[InputParam]:
435+
return []
436+
437+
@property
438+
def intermediate_outputs(self) -> List[OutputParam]:
439+
return [OutputParam("output", type_hint=str)]
440+
441+
def __call__(self, components, state: PipelineState) -> PipelineState:
442+
block_state = self.get_block_state(state)
443+
block_state.output = "test"
444+
self.set_block_state(state, block_state)
445+
return components, state
446+
447+
block = DeviceTestBlock()
448+
pipe = block.init_pipeline()
449+
pipe.load_components(device_map=torch_device, trust_remote_code=True)
450+
451+
assert pipe.model.device.type == torch_device
452+
344453
def test_custom_block_loads_from_hub(self):
345454
repo_id = "hf-internal-testing/tiny-modular-diffusers-block"
346455
block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)

0 commit comments

Comments
 (0)