Skip to content

Commit 04d6ddb

Browse files
committed
tests and version guard.
1 parent 76356ea commit 04d6ddb

2 files changed

Lines changed: 62 additions & 0 deletions

File tree

src/diffusers/utils/peft_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,8 @@ def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None
391391
`model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
392392
doesn't exist in `peft_state_dict`.
393393
"""
394+
if model_state_dict is None:
395+
return
394396
all_modules = set()
395397
string_to_replace = f"{adapter_name}." if adapter_name else ""
396398

@@ -402,6 +404,7 @@ def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None
402404
all_modules.add(module_name)
403405

404406
target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
407+
print(f"{target_modules_set=}")
405408
exclude_modules = list(all_modules - target_modules_set)
406409

407410
return exclude_modules

tests/lora/utils.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import copy
1516
import inspect
1617
import os
1718
import re
@@ -290,6 +291,20 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
290291

291292
return modules_to_save
292293

294+
def _get_exclude_modules(self, pipe):
295+
from diffusers.utils.peft_utils import _derive_exclude_modules
296+
297+
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
298+
denoiser = "unet" if self.unet_kwargs is not None else "transformer"
299+
modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
300+
denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
301+
pipe.unload_lora_weights()
302+
denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
303+
exclude_modules = _derive_exclude_modules(
304+
denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
305+
)
306+
return exclude_modules
307+
293308
def check_if_adapters_added_correctly(
294309
self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
295310
):
@@ -2308,6 +2323,50 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
23082323
np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
23092324
)
23102325

2326+
@require_peft_version_greater("0.13.2")
2327+
def test_lora_exclude_modules(self):
2328+
scheduler_cls = self.scheduler_classes[0]
2329+
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
2330+
pipe = self.pipeline_class(**components).to(torch_device)
2331+
_, _, inputs = self.get_dummy_inputs(with_generator=False)
2332+
2333+
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
2334+
self.assertTrue(output_no_lora.shape == self.output_shape)
2335+
2336+
# only supported for `denoiser` now
2337+
pipe_cp = copy.deepcopy(pipe)
2338+
pipe_cp, _ = self.check_if_adapters_added_correctly(
2339+
pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
2340+
)
2341+
denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
2342+
pipe_cp.to("cpu")
2343+
del pipe_cp
2344+
2345+
denoiser_lora_config.exclude_modules = denoiser_exclude_modules
2346+
pipe, _ = self.check_if_adapters_added_correctly(
2347+
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
2348+
)
2349+
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
2350+
2351+
with tempfile.TemporaryDirectory() as tmpdir:
2352+
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
2353+
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
2354+
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
2355+
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
2356+
pipe.unload_lora_weights()
2357+
pipe.load_lora_weights(tmpdir)
2358+
2359+
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
2360+
2361+
self.assertTrue(
2362+
not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
2363+
"LoRA should change outputs.",
2364+
)
2365+
self.assertTrue(
2366+
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
2367+
"Lora outputs should match.",
2368+
)
2369+
23112370
def test_inference_load_delete_load_adapters(self):
23122371
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
23132372
for scheduler_cls in self.scheduler_classes:

0 commit comments

Comments
 (0)