|
31 | 31 | WanModularPipeline, |
32 | 32 | ) |
33 | 33 |
|
34 | | -from ..testing_utils import nightly, require_torch, slow |
| 34 | +from ..testing_utils import nightly, require_torch, require_torch_accelerator, slow, torch_device |
| 35 | + |
| 36 | + |
| 37 | +def _create_tiny_model_dir(model_dir): |
| 38 | + TINY_MODEL_CODE = ( |
| 39 | + "import torch\n" |
| 40 | + "from diffusers import ModelMixin, ConfigMixin\n" |
| 41 | + "from diffusers.configuration_utils import register_to_config\n" |
| 42 | + "\n" |
| 43 | + "class TinyModel(ModelMixin, ConfigMixin):\n" |
| 44 | + " @register_to_config\n" |
| 45 | + " def __init__(self, hidden_size=4):\n" |
| 46 | + " super().__init__()\n" |
| 47 | + " self.linear = torch.nn.Linear(hidden_size, hidden_size)\n" |
| 48 | + "\n" |
| 49 | + " def forward(self, x):\n" |
| 50 | + " return self.linear(x)\n" |
| 51 | + ) |
| 52 | + |
| 53 | + with open(os.path.join(model_dir, "modeling.py"), "w") as f: |
| 54 | + f.write(TINY_MODEL_CODE) |
| 55 | + |
| 56 | + config = { |
| 57 | + "_class_name": "TinyModel", |
| 58 | + "_diffusers_version": "0.0.0", |
| 59 | + "auto_map": {"AutoModel": "modeling.TinyModel"}, |
| 60 | + "hidden_size": 4, |
| 61 | + } |
| 62 | + with open(os.path.join(model_dir, "config.json"), "w") as f: |
| 63 | + json.dump(config, f) |
| 64 | + |
| 65 | + torch.save( |
| 66 | + {"linear.weight": torch.randn(4, 4), "linear.bias": torch.randn(4)}, |
| 67 | + os.path.join(model_dir, "diffusion_pytorch_model.bin"), |
| 68 | + ) |
35 | 69 |
|
36 | 70 |
|
37 | 71 | class DummyCustomBlockSimple(ModularPipelineBlocks): |
@@ -341,6 +375,81 @@ def __call__(self, components, state: PipelineState) -> PipelineState: |
341 | 375 | loaded_pipe.update_components(custom_model=custom_model) |
342 | 376 | assert getattr(loaded_pipe, "custom_model", None) is not None |
343 | 377 |
|
| 378 | + def test_automodel_type_hint_preserves_torch_dtype(self, tmp_path): |
| 379 | + """Regression test for #13271: torch_dtype was incorrectly removed when type_hint is AutoModel.""" |
| 380 | + from diffusers import AutoModel |
| 381 | + |
| 382 | + model_dir = str(tmp_path / "model") |
| 383 | + os.makedirs(model_dir) |
| 384 | + _create_tiny_model_dir(model_dir) |
| 385 | + |
| 386 | + class DtypeTestBlock(ModularPipelineBlocks): |
| 387 | + @property |
| 388 | + def expected_components(self): |
| 389 | + return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)] |
| 390 | + |
| 391 | + @property |
| 392 | + def inputs(self) -> List[InputParam]: |
| 393 | + return [InputParam("prompt", type_hint=str, required=True)] |
| 394 | + |
| 395 | + @property |
| 396 | + def intermediate_inputs(self) -> List[InputParam]: |
| 397 | + return [] |
| 398 | + |
| 399 | + @property |
| 400 | + def intermediate_outputs(self) -> List[OutputParam]: |
| 401 | + return [OutputParam("output", type_hint=str)] |
| 402 | + |
| 403 | + def __call__(self, components, state: PipelineState) -> PipelineState: |
| 404 | + block_state = self.get_block_state(state) |
| 405 | + block_state.output = "test" |
| 406 | + self.set_block_state(state, block_state) |
| 407 | + return components, state |
| 408 | + |
| 409 | + block = DtypeTestBlock() |
| 410 | + pipe = block.init_pipeline() |
| 411 | + pipe.load_components(torch_dtype=torch.float16, trust_remote_code=True) |
| 412 | + |
| 413 | + assert pipe.model.dtype == torch.float16 |
| 414 | + |
| 415 | + @require_torch_accelerator |
| 416 | + def test_automodel_type_hint_preserves_device(self, tmp_path): |
| 417 | + """Test that ComponentSpec with AutoModel type_hint correctly passes device_map.""" |
| 418 | + from diffusers import AutoModel |
| 419 | + |
| 420 | + model_dir = str(tmp_path / "model") |
| 421 | + os.makedirs(model_dir) |
| 422 | + _create_tiny_model_dir(model_dir) |
| 423 | + |
| 424 | + class DeviceTestBlock(ModularPipelineBlocks): |
| 425 | + @property |
| 426 | + def expected_components(self): |
| 427 | + return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)] |
| 428 | + |
| 429 | + @property |
| 430 | + def inputs(self) -> List[InputParam]: |
| 431 | + return [InputParam("prompt", type_hint=str, required=True)] |
| 432 | + |
| 433 | + @property |
| 434 | + def intermediate_inputs(self) -> List[InputParam]: |
| 435 | + return [] |
| 436 | + |
| 437 | + @property |
| 438 | + def intermediate_outputs(self) -> List[OutputParam]: |
| 439 | + return [OutputParam("output", type_hint=str)] |
| 440 | + |
| 441 | + def __call__(self, components, state: PipelineState) -> PipelineState: |
| 442 | + block_state = self.get_block_state(state) |
| 443 | + block_state.output = "test" |
| 444 | + self.set_block_state(state, block_state) |
| 445 | + return components, state |
| 446 | + |
| 447 | + block = DeviceTestBlock() |
| 448 | + pipe = block.init_pipeline() |
| 449 | + pipe.load_components(device_map=torch_device, trust_remote_code=True) |
| 450 | + |
| 451 | + assert pipe.model.device.type == torch_device |
| 452 | + |
344 | 453 | def test_custom_block_loads_from_hub(self): |
345 | 454 | repo_id = "hf-internal-testing/tiny-modular-diffusers-block" |
346 | 455 | block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True) |
|
0 commit comments