Skip to content

Commit 832676d

Browse files
AlexkkirAlexkkirsayakpaul
authored
Use defaultdict for _SET_ADAPTER_SCALE_FN_MAPPING (#13320)
refactor: use defaultdict for _SET_ADAPTER_SCALE_FN_MAPPING Co-authored-by: Alexkkir <alexkkir@gmail.coom> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 7bbd96d commit 832676d

1 file changed

Lines changed: 8 additions & 27 deletions

File tree

src/diffusers/loaders/peft.py

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import inspect
1616
import json
1717
import os
18+
from collections import defaultdict
1819
from functools import partial
1920
from pathlib import Path
2021
from typing import Literal
@@ -44,33 +45,13 @@
4445

4546
logger = logging.get_logger(__name__)
4647

47-
_SET_ADAPTER_SCALE_FN_MAPPING = {
48-
"UNet2DConditionModel": _maybe_expand_lora_scales,
49-
"UNetMotionModel": _maybe_expand_lora_scales,
50-
"SD3Transformer2DModel": lambda model_cls, weights: weights,
51-
"FluxTransformer2DModel": lambda model_cls, weights: weights,
52-
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
53-
"ConsisIDTransformer3DModel": lambda model_cls, weights: weights,
54-
"HeliosTransformer3DModel": lambda model_cls, weights: weights,
55-
"MochiTransformer3DModel": lambda model_cls, weights: weights,
56-
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
57-
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
58-
"SanaTransformer2DModel": lambda model_cls, weights: weights,
59-
"AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
60-
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
61-
"WanTransformer3DModel": lambda model_cls, weights: weights,
62-
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
63-
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
64-
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
65-
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
66-
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
67-
"ChronoEditTransformer3DModel": lambda model_cls, weights: weights,
68-
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
69-
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
70-
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
71-
"LTX2VideoTransformer3DModel": lambda model_cls, weights: weights,
72-
"LTX2TextConnectors": lambda model_cls, weights: weights,
73-
}
48+
_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict(
49+
lambda: (lambda model_cls, weights: weights),
50+
{
51+
"UNet2DConditionModel": _maybe_expand_lora_scales,
52+
"UNetMotionModel": _maybe_expand_lora_scales,
53+
},
54+
)
7455

7556

7657
class PeftAdapterMixin:

0 commit comments

Comments
 (0)