|
15 | 15 | import inspect |
16 | 16 | import json |
17 | 17 | import os |
| 18 | +from collections import defaultdict |
18 | 19 | from functools import partial |
19 | 20 | from pathlib import Path |
20 | 21 | from typing import Literal |
|
44 | 45 |
|
45 | 46 | logger = logging.get_logger(__name__) |
46 | 47 |
|
47 | | -_SET_ADAPTER_SCALE_FN_MAPPING = { |
48 | | - "UNet2DConditionModel": _maybe_expand_lora_scales, |
49 | | - "UNetMotionModel": _maybe_expand_lora_scales, |
50 | | - "SD3Transformer2DModel": lambda model_cls, weights: weights, |
51 | | - "FluxTransformer2DModel": lambda model_cls, weights: weights, |
52 | | - "CogVideoXTransformer3DModel": lambda model_cls, weights: weights, |
53 | | - "ConsisIDTransformer3DModel": lambda model_cls, weights: weights, |
54 | | - "HeliosTransformer3DModel": lambda model_cls, weights: weights, |
55 | | - "MochiTransformer3DModel": lambda model_cls, weights: weights, |
56 | | - "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights, |
57 | | - "LTXVideoTransformer3DModel": lambda model_cls, weights: weights, |
58 | | - "SanaTransformer2DModel": lambda model_cls, weights: weights, |
59 | | - "AuraFlowTransformer2DModel": lambda model_cls, weights: weights, |
60 | | - "Lumina2Transformer2DModel": lambda model_cls, weights: weights, |
61 | | - "WanTransformer3DModel": lambda model_cls, weights: weights, |
62 | | - "CogView4Transformer2DModel": lambda model_cls, weights: weights, |
63 | | - "HiDreamImageTransformer2DModel": lambda model_cls, weights: weights, |
64 | | - "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights, |
65 | | - "WanVACETransformer3DModel": lambda model_cls, weights: weights, |
66 | | - "ChromaTransformer2DModel": lambda model_cls, weights: weights, |
67 | | - "ChronoEditTransformer3DModel": lambda model_cls, weights: weights, |
68 | | - "QwenImageTransformer2DModel": lambda model_cls, weights: weights, |
69 | | - "Flux2Transformer2DModel": lambda model_cls, weights: weights, |
70 | | - "ZImageTransformer2DModel": lambda model_cls, weights: weights, |
71 | | - "LTX2VideoTransformer3DModel": lambda model_cls, weights: weights, |
72 | | - "LTX2TextConnectors": lambda model_cls, weights: weights, |
73 | | -} |
| 48 | +_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict( |
| 49 | + lambda: (lambda model_cls, weights: weights), |
| 50 | + { |
| 51 | + "UNet2DConditionModel": _maybe_expand_lora_scales, |
| 52 | + "UNetMotionModel": _maybe_expand_lora_scales, |
| 53 | + }, |
| 54 | +) |
74 | 55 |
|
75 | 56 |
|
76 | 57 | class PeftAdapterMixin: |
|
0 commit comments