|
39 | 39 | is_torch_version, |
40 | 40 | require_peft_backend, |
41 | 41 | require_peft_version_greater, |
| 42 | + require_torch_accelerator, |
42 | 43 | require_transformers_version_greater, |
43 | 44 | skip_mps, |
44 | 45 | torch_device, |
@@ -2372,3 +2373,73 @@ def test_inference_load_delete_load_adapters(self): |
2372 | 2373 | pipe.load_lora_weights(tmpdirname) |
2373 | 2374 | output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] |
2374 | 2375 | self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3)) |
| 2376 | + |
| 2377 | + def _test_group_offloading_inference_denoiser(self, offload_type, use_stream): |
| 2378 | + from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook |
| 2379 | + |
| 2380 | + onload_device = torch_device |
| 2381 | + offload_device = torch.device("cpu") |
| 2382 | + |
| 2383 | + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0]) |
| 2384 | + pipe = self.pipeline_class(**components) |
| 2385 | + pipe = pipe.to(torch_device) |
| 2386 | + pipe.set_progress_bar_config(disable=None) |
| 2387 | + |
| 2388 | + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet |
| 2389 | + denoiser.add_adapter(denoiser_lora_config) |
| 2390 | + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") |
| 2391 | + |
| 2392 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 2393 | + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) |
| 2394 | + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) |
| 2395 | + self.pipeline_class.save_lora_weights( |
| 2396 | + save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts |
| 2397 | + ) |
| 2398 | + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) |
| 2399 | + |
| 2400 | + components, _, _ = self.get_dummy_components(self.scheduler_classes[0]) |
| 2401 | + pipe = self.pipeline_class(**components) |
| 2402 | + pipe = pipe.to(torch_device) |
| 2403 | + pipe.set_progress_bar_config(disable=None) |
| 2404 | + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet |
| 2405 | + |
| 2406 | + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) |
| 2407 | + check_if_lora_correctly_set(denoiser) |
| 2408 | + _, _, inputs = self.get_dummy_inputs(with_generator=False) |
| 2409 | + |
| 2410 | + # Test group offloading with load_lora_weights |
| 2411 | + denoiser.enable_group_offload( |
| 2412 | + onload_device=onload_device, |
| 2413 | + offload_device=offload_device, |
| 2414 | + offload_type=offload_type, |
| 2415 | + num_blocks_per_group=1, |
| 2416 | + use_stream=use_stream, |
| 2417 | + ) |
| 2418 | + group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser) |
| 2419 | + self.assertTrue(group_offload_hook_1 is not None) |
| 2420 | + output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 2421 | + |
| 2422 | + # Test group offloading after removing the lora |
| 2423 | + pipe.unload_lora_weights() |
| 2424 | + group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser) |
| 2425 | + self.assertTrue(group_offload_hook_2 is not None) |
| 2426 | + output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841 |
| 2427 | + |
| 2428 | + # Add the lora again and check if group offloading works |
| 2429 | + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) |
| 2430 | + check_if_lora_correctly_set(denoiser) |
| 2431 | + group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser) |
| 2432 | + self.assertTrue(group_offload_hook_3 is not None) |
| 2433 | + output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 2434 | + |
| 2435 | + self.assertTrue(np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3)) |
| 2436 | + |
| 2437 | + @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)]) |
| 2438 | + @require_torch_accelerator |
| 2439 | + def test_group_offloading_inference_denoiser(self, offload_type, use_stream): |
| 2440 | + for cls in inspect.getmro(self.__class__): |
| 2441 | + if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests: |
| 2442 | + # Skip this test if it is overwritten by child class. We need to do this because parameterized |
| 2443 | + # materializes the test methods on invocation which cannot be overridden. |
| 2444 | + return |
| 2445 | + self._test_group_offloading_inference_denoiser(offload_type, use_stream) |
0 commit comments