|
12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
| 15 | +import copy |
15 | 16 | import inspect |
16 | 17 | import os |
17 | 18 | import re |
@@ -290,6 +291,20 @@ def _get_modules_to_save(self, pipe, has_denoiser=False): |
290 | 291 |
|
291 | 292 | return modules_to_save |
292 | 293 |
|
| 294 | + def _get_exclude_modules(self, pipe): |
| 295 | + from diffusers.utils.peft_utils import _derive_exclude_modules |
| 296 | + |
| 297 | + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) |
| 298 | + denoiser = "unet" if self.unet_kwargs is not None else "transformer" |
| 299 | + modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser} |
| 300 | + denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"] |
| 301 | + pipe.unload_lora_weights() |
| 302 | + denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict() |
| 303 | + exclude_modules = _derive_exclude_modules( |
| 304 | + denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default" |
| 305 | + ) |
| 306 | + return exclude_modules |
| 307 | + |
293 | 308 | def check_if_adapters_added_correctly( |
294 | 309 | self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default" |
295 | 310 | ): |
@@ -2308,6 +2323,50 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): |
2308 | 2323 | np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match." |
2309 | 2324 | ) |
2310 | 2325 |
|
| 2326 | + @require_peft_version_greater("0.13.2") |
| 2327 | + def test_lora_exclude_modules(self): |
| 2328 | + scheduler_cls = self.scheduler_classes[0] |
| 2329 | + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) |
| 2330 | + pipe = self.pipeline_class(**components).to(torch_device) |
| 2331 | + _, _, inputs = self.get_dummy_inputs(with_generator=False) |
| 2332 | + |
| 2333 | + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 2334 | + self.assertTrue(output_no_lora.shape == self.output_shape) |
| 2335 | + |
| 2336 | + # only supported for `denoiser` now |
| 2337 | + pipe_cp = copy.deepcopy(pipe) |
| 2338 | + pipe_cp, _ = self.check_if_adapters_added_correctly( |
| 2339 | + pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config |
| 2340 | + ) |
| 2341 | + denoiser_exclude_modules = self._get_exclude_modules(pipe_cp) |
| 2342 | + pipe_cp.to("cpu") |
| 2343 | + del pipe_cp |
| 2344 | + |
| 2345 | + denoiser_lora_config.exclude_modules = denoiser_exclude_modules |
| 2346 | + pipe, _ = self.check_if_adapters_added_correctly( |
| 2347 | + pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config |
| 2348 | + ) |
| 2349 | + output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 2350 | + |
| 2351 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 2352 | + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) |
| 2353 | + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) |
| 2354 | + lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) |
| 2355 | + self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) |
| 2356 | + pipe.unload_lora_weights() |
| 2357 | + pipe.load_lora_weights(tmpdir) |
| 2358 | + |
| 2359 | + output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 2360 | + |
| 2361 | + self.assertTrue( |
| 2362 | + not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), |
| 2363 | + "LoRA should change outputs.", |
| 2364 | + ) |
| 2365 | + self.assertTrue( |
| 2366 | + np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), |
| 2367 | + "Lora outputs should match.", |
| 2368 | + ) |
| 2369 | + |
2311 | 2370 | def test_inference_load_delete_load_adapters(self): |
2312 | 2371 | "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works." |
2313 | 2372 | for scheduler_cls in self.scheduler_classes: |
|
0 commit comments