Skip to content

Commit f1a93c7

Browse files
dg845 and sayakpaul authored
Add Flag to PeftLoraLoaderMixinTests to Enable/Disable Text Encoder LoRA Tests (#12962)
* Improve incorrect LoRA format error message
* Add flag in PeftLoraLoaderMixinTests to disable text encoder LoRA tests
* Apply changes to LTX2LoraTests
* Further improve incorrect LoRA format error msg following review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 29a930a commit f1a93c7

16 files changed

Lines changed: 67 additions & 304 deletions

src/diffusers/loaders/lora_pipeline.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def load_lora_weights(
214214

215215
is_correct_format = all("lora" in key for key in state_dict.keys())
216216
if not is_correct_format:
217-
raise ValueError("Invalid LoRA checkpoint.")
217+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
218218

219219
self.load_lora_into_unet(
220220
state_dict,
@@ -641,7 +641,7 @@ def load_lora_weights(
641641

642642
is_correct_format = all("lora" in key for key in state_dict.keys())
643643
if not is_correct_format:
644-
raise ValueError("Invalid LoRA checkpoint.")
644+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
645645

646646
self.load_lora_into_unet(
647647
state_dict,
@@ -1081,7 +1081,7 @@ def load_lora_weights(
10811081

10821082
is_correct_format = all("lora" in key for key in state_dict.keys())
10831083
if not is_correct_format:
1084-
raise ValueError("Invalid LoRA checkpoint.")
1084+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
10851085

10861086
self.load_lora_into_transformer(
10871087
state_dict,
@@ -1377,7 +1377,7 @@ def load_lora_weights(
13771377

13781378
is_correct_format = all("lora" in key for key in state_dict.keys())
13791379
if not is_correct_format:
1380-
raise ValueError("Invalid LoRA checkpoint.")
1380+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
13811381

13821382
self.load_lora_into_transformer(
13831383
state_dict,
@@ -1659,7 +1659,7 @@ def load_lora_weights(
16591659
)
16601660

16611661
if not (has_lora_keys or has_norm_keys):
1662-
raise ValueError("Invalid LoRA checkpoint.")
1662+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
16631663

16641664
transformer_lora_state_dict = {
16651665
k: state_dict.get(k)
@@ -2506,7 +2506,7 @@ def load_lora_weights(
25062506

25072507
is_correct_format = all("lora" in key for key in state_dict.keys())
25082508
if not is_correct_format:
2509-
raise ValueError("Invalid LoRA checkpoint.")
2509+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
25102510

25112511
self.load_lora_into_transformer(
25122512
state_dict,
@@ -2703,7 +2703,7 @@ def load_lora_weights(
27032703

27042704
is_correct_format = all("lora" in key for key in state_dict.keys())
27052705
if not is_correct_format:
2706-
raise ValueError("Invalid LoRA checkpoint.")
2706+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
27072707

27082708
self.load_lora_into_transformer(
27092709
state_dict,
@@ -2906,7 +2906,7 @@ def load_lora_weights(
29062906

29072907
is_correct_format = all("lora" in key for key in state_dict.keys())
29082908
if not is_correct_format:
2909-
raise ValueError("Invalid LoRA checkpoint.")
2909+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
29102910

29112911
self.load_lora_into_transformer(
29122912
state_dict,
@@ -3115,7 +3115,7 @@ def load_lora_weights(
31153115

31163116
is_correct_format = all("lora" in key for key in state_dict.keys())
31173117
if not is_correct_format:
3118-
raise ValueError("Invalid LoRA checkpoint.")
3118+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
31193119

31203120
transformer_peft_state_dict = {
31213121
k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")
@@ -3333,7 +3333,7 @@ def load_lora_weights(
33333333

33343334
is_correct_format = all("lora" in key for key in state_dict.keys())
33353335
if not is_correct_format:
3336-
raise ValueError("Invalid LoRA checkpoint.")
3336+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
33373337

33383338
self.load_lora_into_transformer(
33393339
state_dict,
@@ -3536,7 +3536,7 @@ def load_lora_weights(
35363536

35373537
is_correct_format = all("lora" in key for key in state_dict.keys())
35383538
if not is_correct_format:
3539-
raise ValueError("Invalid LoRA checkpoint.")
3539+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
35403540

35413541
self.load_lora_into_transformer(
35423542
state_dict,
@@ -3740,7 +3740,7 @@ def load_lora_weights(
37403740

37413741
is_correct_format = all("lora" in key for key in state_dict.keys())
37423742
if not is_correct_format:
3743-
raise ValueError("Invalid LoRA checkpoint.")
3743+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
37443744

37453745
self.load_lora_into_transformer(
37463746
state_dict,
@@ -3940,7 +3940,7 @@ def load_lora_weights(
39403940

39413941
is_correct_format = all("lora" in key for key in state_dict.keys())
39423942
if not is_correct_format:
3943-
raise ValueError("Invalid LoRA checkpoint.")
3943+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
39443944

39453945
self.load_lora_into_transformer(
39463946
state_dict,
@@ -4194,7 +4194,7 @@ def load_lora_weights(
41944194
)
41954195
is_correct_format = all("lora" in key for key in state_dict.keys())
41964196
if not is_correct_format:
4197-
raise ValueError("Invalid LoRA checkpoint.")
4197+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
41984198

41994199
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
42004200
if load_into_transformer_2:
@@ -4471,7 +4471,7 @@ def load_lora_weights(
44714471
)
44724472
is_correct_format = all("lora" in key for key in state_dict.keys())
44734473
if not is_correct_format:
4474-
raise ValueError("Invalid LoRA checkpoint.")
4474+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
44754475

44764476
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
44774477
if load_into_transformer_2:
@@ -4691,7 +4691,7 @@ def load_lora_weights(
46914691

46924692
is_correct_format = all("lora" in key for key in state_dict.keys())
46934693
if not is_correct_format:
4694-
raise ValueError("Invalid LoRA checkpoint.")
4694+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
46954695

46964696
self.load_lora_into_transformer(
46974697
state_dict,
@@ -4894,7 +4894,7 @@ def load_lora_weights(
48944894

48954895
is_correct_format = all("lora" in key for key in state_dict.keys())
48964896
if not is_correct_format:
4897-
raise ValueError("Invalid LoRA checkpoint.")
4897+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
48984898

48994899
self.load_lora_into_transformer(
49004900
state_dict,
@@ -5100,7 +5100,7 @@ def load_lora_weights(
51005100

51015101
is_correct_format = all("lora" in key for key in state_dict.keys())
51025102
if not is_correct_format:
5103-
raise ValueError("Invalid LoRA checkpoint.")
5103+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
51045104

51055105
self.load_lora_into_transformer(
51065106
state_dict,
@@ -5306,7 +5306,7 @@ def load_lora_weights(
53065306

53075307
is_correct_format = all("lora" in key for key in state_dict.keys())
53085308
if not is_correct_format:
5309-
raise ValueError("Invalid LoRA checkpoint.")
5309+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
53105310

53115311
self.load_lora_into_transformer(
53125312
state_dict,
@@ -5509,7 +5509,7 @@ def load_lora_weights(
55095509

55105510
is_correct_format = all("lora" in key for key in state_dict.keys())
55115511
if not is_correct_format:
5512-
raise ValueError("Invalid LoRA checkpoint.")
5512+
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
55135513

55145514
self.load_lora_into_transformer(
55155515
state_dict,

tests/lora/test_lora_layers_auraflow.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
7676
text_encoder_target_modules = ["q", "k", "v", "o"]
7777
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"]
7878

79+
supports_text_encoder_loras = False
80+
7981
@property
8082
def output_shape(self):
8183
return (1, 8, 8, 3)
@@ -114,23 +116,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
114116
@unittest.skip("Not supported in AuraFlow.")
115117
def test_modify_padding_mode(self):
116118
pass
117-
118-
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
119-
def test_simple_inference_with_partial_text_lora(self):
120-
pass
121-
122-
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
123-
def test_simple_inference_with_text_lora(self):
124-
pass
125-
126-
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
127-
def test_simple_inference_with_text_lora_and_scale(self):
128-
pass
129-
130-
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
131-
def test_simple_inference_with_text_lora_fused(self):
132-
pass
133-
134-
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
135-
def test_simple_inference_with_text_lora_save_load(self):
136-
pass

tests/lora/test_lora_layers_cogvideox.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
8787

8888
text_encoder_target_modules = ["q", "k", "v", "o"]
8989

90+
supports_text_encoder_loras = False
91+
9092
@property
9193
def output_shape(self):
9294
return (1, 9, 16, 16, 3)
@@ -147,26 +149,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
147149
def test_modify_padding_mode(self):
148150
pass
149151

150-
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
151-
def test_simple_inference_with_partial_text_lora(self):
152-
pass
153-
154-
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
155-
def test_simple_inference_with_text_lora(self):
156-
pass
157-
158-
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
159-
def test_simple_inference_with_text_lora_and_scale(self):
160-
pass
161-
162-
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
163-
def test_simple_inference_with_text_lora_fused(self):
164-
pass
165-
166-
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
167-
def test_simple_inference_with_text_lora_save_load(self):
168-
pass
169-
170152
@unittest.skip("Not supported in CogVideoX.")
171153
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
172154
pass

tests/lora/test_lora_layers_cogview4.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
8585
"text_encoder",
8686
)
8787

88+
supports_text_encoder_loras = False
89+
8890
@property
8991
def output_shape(self):
9092
return (1, 32, 32, 3)
@@ -162,23 +164,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
162164
@unittest.skip("Not supported in CogView4.")
163165
def test_modify_padding_mode(self):
164166
pass
165-
166-
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
167-
def test_simple_inference_with_partial_text_lora(self):
168-
pass
169-
170-
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
171-
def test_simple_inference_with_text_lora(self):
172-
pass
173-
174-
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
175-
def test_simple_inference_with_text_lora_and_scale(self):
176-
pass
177-
178-
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
179-
def test_simple_inference_with_text_lora_fused(self):
180-
pass
181-
182-
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
183-
def test_simple_inference_with_text_lora_save_load(self):
184-
pass

tests/lora/test_lora_layers_flux2.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
6666
text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers"
6767
denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"]
6868

69+
supports_text_encoder_loras = False
70+
6971
@property
7072
def output_shape(self):
7173
return (1, 8, 8, 3)
@@ -146,23 +148,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
146148
@unittest.skip("Not supported in Flux2.")
147149
def test_modify_padding_mode(self):
148150
pass
149-
150-
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
151-
def test_simple_inference_with_partial_text_lora(self):
152-
pass
153-
154-
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
155-
def test_simple_inference_with_text_lora(self):
156-
pass
157-
158-
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
159-
def test_simple_inference_with_text_lora_and_scale(self):
160-
pass
161-
162-
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
163-
def test_simple_inference_with_text_lora_fused(self):
164-
pass
165-
166-
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
167-
def test_simple_inference_with_text_lora_save_load(self):
168-
pass

tests/lora/test_lora_layers_hunyuanvideo.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
117117
"text_encoder_2",
118118
)
119119

120+
supports_text_encoder_loras = False
121+
120122
@property
121123
def output_shape(self):
122124
return (1, 9, 32, 32, 3)
@@ -172,26 +174,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
172174
def test_modify_padding_mode(self):
173175
pass
174176

175-
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
176-
def test_simple_inference_with_partial_text_lora(self):
177-
pass
178-
179-
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
180-
def test_simple_inference_with_text_lora(self):
181-
pass
182-
183-
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
184-
def test_simple_inference_with_text_lora_and_scale(self):
185-
pass
186-
187-
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
188-
def test_simple_inference_with_text_lora_fused(self):
189-
pass
190-
191-
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
192-
def test_simple_inference_with_text_lora_save_load(self):
193-
pass
194-
195177

196178
@nightly
197179
@require_torch_accelerator

tests/lora/test_lora_layers_ltx2.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
150150

151151
denoiser_target_modules = ["to_q", "to_k", "to_out.0"]
152152

153+
supports_text_encoder_loras = False
154+
153155
@property
154156
def output_shape(self):
155157
return (1, 5, 32, 32, 3)
@@ -267,27 +269,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
267269
@unittest.skip("Not supported in LTX2.")
268270
def test_modify_padding_mode(self):
269271
pass
270-
271-
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
272-
def test_simple_inference_with_partial_text_lora(self):
273-
pass
274-
275-
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
276-
def test_simple_inference_with_text_lora(self):
277-
pass
278-
279-
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
280-
def test_simple_inference_with_text_lora_and_scale(self):
281-
pass
282-
283-
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
284-
def test_simple_inference_with_text_lora_fused(self):
285-
pass
286-
287-
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
288-
def test_simple_inference_with_text_lora_save_load(self):
289-
pass
290-
291-
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
292-
def test_simple_inference_save_pretrained_with_text_lora(self):
293-
pass

0 commit comments

Comments (0)