Skip to content

Commit e9b73fe

Browse files
Move the LoRA test
1 parent e01fb1f commit e9b73fe

2 files changed

Lines changed: 39 additions & 306 deletions

File tree

tests/lora/test_lora_layers_z_image.py

Lines changed: 39 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,36 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import os
1516
import sys
1617
import unittest
1718

18-
import numpy as np
1919
import torch
2020
from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
2121

22-
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel
22+
from diffusers import (
23+
AutoencoderKL,
24+
FlowMatchEulerDiscreteScheduler,
25+
ZImagePipeline,
26+
ZImageTransformer2DModel,
27+
)
2328

24-
from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend, skip_mps, torch_device
29+
from ..testing_utils import floats_tensor, require_peft_backend
2530

2631

27-
if is_peft_available():
28-
from peft import LoraConfig
32+
# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
33+
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
34+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
35+
torch.use_deterministic_algorithms(False)
36+
torch.backends.cudnn.deterministic = True
37+
torch.backends.cudnn.benchmark = False
38+
if hasattr(torch.backends, "cuda"):
39+
torch.backends.cuda.matmul.allow_tf32 = False
2940

3041

3142
sys.path.append(".")
3243

33-
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
44+
from .utils import PeftLoraLoaderMixinTests # noqa: E402
3445

3546

3647
@require_peft_backend
@@ -75,11 +86,9 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
7586
text_encoder_cls, text_encoder_id = Qwen3Model, None # Will be created inline
7687
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
7788

78-
supports_text_encoder_loras = False
79-
8089
@property
8190
def output_shape(self):
82-
return (1, 32, 32, 3)
91+
return (1, 8, 8, 3)
8392

8493
def get_dummy_inputs(self, with_generator=True):
8594
batch_size = 1
@@ -121,139 +130,20 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No
121130
tokenizer = Qwen2Tokenizer.from_pretrained(self.tokenizer_id)
122131

123132
transformer = self.transformer_cls(**self.transformer_kwargs)
124-
# `x_pad_token` and `cap_pad_token` are initialized with `torch.empty`.
125-
# This can cause NaN data values in our testing environment. Fixating them
126-
# helps prevent that issue.
127-
with torch.no_grad():
128-
transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data))
129-
transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data))
130133
vae = self.vae_cls(**self.vae_kwargs)
131134

132135
if scheduler_cls is None:
133136
scheduler_cls = self.scheduler_cls
134137
scheduler = scheduler_cls(**self.scheduler_kwargs)
135138

136-
rank = 4
137-
lora_alpha = rank if lora_alpha is None else lora_alpha
138-
139-
text_lora_config = LoraConfig(
140-
r=rank,
141-
lora_alpha=lora_alpha,
142-
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
143-
init_lora_weights=False,
144-
use_dora=use_dora,
145-
)
146-
147-
denoiser_lora_config = LoraConfig(
148-
r=rank,
149-
lora_alpha=lora_alpha,
150-
target_modules=self.denoiser_target_modules,
151-
init_lora_weights=False,
152-
use_dora=use_dora,
153-
)
154-
155-
pipeline_components = {
139+
return {
156140
"transformer": transformer,
157141
"vae": vae,
158142
"scheduler": scheduler,
159143
"text_encoder": text_encoder,
160144
"tokenizer": tokenizer,
161145
}
162146

163-
return pipeline_components, text_lora_config, denoiser_lora_config
164-
165-
def test_correct_lora_configs_with_different_ranks(self):
166-
components, _, denoiser_lora_config = self.get_dummy_components()
167-
pipe = self.pipeline_class(**components)
168-
pipe = pipe.to(torch_device)
169-
pipe.set_progress_bar_config(disable=None)
170-
_, _, inputs = self.get_dummy_inputs(with_generator=False)
171-
172-
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
173-
174-
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
175-
176-
lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
177-
178-
pipe.transformer.delete_adapters("adapter-1")
179-
180-
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
181-
for name, _ in denoiser.named_modules():
182-
if "to_k" in name and "attention" in name and "lora" not in name:
183-
module_name_to_rank_update = name.replace(".base_layer.", ".")
184-
break
185-
186-
# change the rank_pattern
187-
updated_rank = denoiser_lora_config.r * 2
188-
denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
189-
190-
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
191-
updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern
192-
193-
self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})
194-
195-
lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
196-
self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
197-
self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
198-
199-
pipe.transformer.delete_adapters("adapter-1")
200-
201-
# similarly change the alpha_pattern
202-
updated_alpha = denoiser_lora_config.lora_alpha * 2
203-
denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
204-
205-
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
206-
self.assertTrue(
207-
pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
208-
)
209-
210-
lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
211-
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
212-
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
213-
214-
@skip_mps
215-
def test_lora_fuse_nan(self):
216-
components, _, denoiser_lora_config = self.get_dummy_components()
217-
pipe = self.pipeline_class(**components)
218-
pipe = pipe.to(torch_device)
219-
pipe.set_progress_bar_config(disable=None)
220-
_, _, inputs = self.get_dummy_inputs(with_generator=False)
221-
222-
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
223-
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
224-
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
225-
226-
# corrupt one LoRA weight with `inf` values
227-
with torch.no_grad():
228-
possible_tower_names = ["noise_refiner"]
229-
filtered_tower_names = [
230-
tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
231-
]
232-
for tower_name in filtered_tower_names:
233-
transformer_tower = getattr(pipe.transformer, tower_name)
234-
transformer_tower[0].attention.to_q.lora_A["adapter-1"].weight += float("inf")
235-
236-
# with `safe_fusing=True` we should see an Error
237-
with self.assertRaises(ValueError):
238-
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
239-
240-
# without we should not see an error, but every image will be black
241-
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
242-
out = pipe(**inputs)[0]
243-
244-
self.assertTrue(np.isnan(out).all())
245-
246-
def test_lora_scale_kwargs_match_fusion(self):
247-
super().test_lora_scale_kwargs_match_fusion(5e-2, 5e-2)
248-
249-
@unittest.skip("Needs to be debugged.")
250-
def test_set_adapters_match_attention_kwargs(self):
251-
super().test_set_adapters_match_attention_kwargs()
252-
253-
@unittest.skip("Needs to be debugged.")
254-
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
255-
super().test_simple_inference_with_text_denoiser_lora_and_scale()
256-
257147
@unittest.skip("Not supported in ZImage.")
258148
def test_simple_inference_with_text_denoiser_block_scale(self):
259149
pass
@@ -265,3 +155,23 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
265155
@unittest.skip("Not supported in ZImage.")
266156
def test_modify_padding_mode(self):
267157
pass
158+
159+
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
160+
def test_simple_inference_with_partial_text_lora(self):
161+
pass
162+
163+
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
164+
def test_simple_inference_with_text_lora(self):
165+
pass
166+
167+
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
168+
def test_simple_inference_with_text_lora_and_scale(self):
169+
pass
170+
171+
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
172+
def test_simple_inference_with_text_lora_fused(self):
173+
pass
174+
175+
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
176+
def test_simple_inference_with_text_lora_save_load(self):
177+
pass

tests/test_lora_layers_z_image.py

Lines changed: 0 additions & 177 deletions
This file was deleted.

0 commit comments

Comments
 (0)