 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import sys
 import unittest
 
-import numpy as np
 import torch
 from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
 
-from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel
+from diffusers import (
+    AutoencoderKL,
+    FlowMatchEulerDiscreteScheduler,
+    ZImagePipeline,
+    ZImageTransformer2DModel,
+)
 
-from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend, skip_mps, torch_device
+from ..testing_utils import floats_tensor, require_peft_backend
 
 
-if is_peft_available():
-    from peft import LoraConfig
+# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+torch.use_deterministic_algorithms(False)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+if hasattr(torch.backends, "cuda"):
+    torch.backends.cuda.matmul.allow_tf32 = False
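+# NOTE: CUDA_LAUNCH_BLOCKING and CUBLAS_WORKSPACE_CONFIG must be set before
+# CUDA initializes to take effect, hence the assignments at import time.
+# Determinism is relaxed globally for the complex64 ops above, while the
+# cuDNN and TF32 settings keep the remaining kernels as reproducible as
+# possible.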
 
 
 sys.path.append(".")
 
-from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402
+from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 
 
 @require_peft_backend
@@ -75,11 +86,9 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     text_encoder_cls, text_encoder_id = Qwen3Model, None  # Will be created inline
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
 
-    supports_text_encoder_loras = False
-
     @property
     def output_shape(self):
-        return (1, 32, 32, 3)
+        return (1, 8, 8, 3)
 
     def get_dummy_inputs(self, with_generator=True):
         batch_size = 1
@@ -121,139 +130,20 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         tokenizer = Qwen2Tokenizer.from_pretrained(self.tokenizer_id)
 
         transformer = self.transformer_cls(**self.transformer_kwargs)
-        # `x_pad_token` and `cap_pad_token` are initialized with `torch.empty`.
-        # This can cause NaN data values in our testing environment. Fixating them
-        # helps prevent that issue.
-        with torch.no_grad():
-            transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data))
-            transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data))
         vae = self.vae_cls(**self.vae_kwargs)
 
         if scheduler_cls is None:
             scheduler_cls = self.scheduler_cls
         scheduler = scheduler_cls(**self.scheduler_kwargs)
 
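+        # NOTE: the LoRA configs (rank/alpha, target modules) are no longer
+        # built here; the shared PeftLoraLoaderMixinTests is assumed to
+        # construct its own LoraConfig objects, so this override now returns
+        # only the pipeline components.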
-        rank = 4
-        lora_alpha = rank if lora_alpha is None else lora_alpha
-
-        text_lora_config = LoraConfig(
-            r=rank,
-            lora_alpha=lora_alpha,
-            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
-            init_lora_weights=False,
-            use_dora=use_dora,
-        )
-
-        denoiser_lora_config = LoraConfig(
-            r=rank,
-            lora_alpha=lora_alpha,
-            target_modules=self.denoiser_target_modules,
-            init_lora_weights=False,
-            use_dora=use_dora,
-        )
-
-        pipeline_components = {
+        return {
             "transformer": transformer,
             "vae": vae,
             "scheduler": scheduler,
             "text_encoder": text_encoder,
             "tokenizer": tokenizer,
         }
 
-        return pipeline_components, text_lora_config, denoiser_lora_config
-
-    def test_correct_lora_configs_with_different_ranks(self):
-        components, _, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-
-        lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        pipe.transformer.delete_adapters("adapter-1")
-
-        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
-        for name, _ in denoiser.named_modules():
-            if "to_k" in name and "attention" in name and "lora" not in name:
-                module_name_to_rank_update = name.replace(".base_layer.", ".")
-                break
-
-        # change the rank_pattern
-        updated_rank = denoiser_lora_config.r * 2
-        denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
-
-        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-        updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern
-
-        self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})
-
-        lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
-        self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
-
-        pipe.transformer.delete_adapters("adapter-1")
-
-        # similarly change the alpha_pattern
-        updated_alpha = denoiser_lora_config.lora_alpha * 2
-        denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
-
-        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-        self.assertTrue(
-            pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
-        )
-
-        lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
-        self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
-
-    @skip_mps
-    def test_lora_fuse_nan(self):
-        components, _, denoiser_lora_config = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
-        # corrupt one LoRA weight with `inf` values
-        with torch.no_grad():
-            possible_tower_names = ["noise_refiner"]
-            filtered_tower_names = [
-                tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
-            ]
-            for tower_name in filtered_tower_names:
-                transformer_tower = getattr(pipe.transformer, tower_name)
-                transformer_tower[0].attention.to_q.lora_A["adapter-1"].weight += float("inf")
-
-        # with `safe_fusing=True` we should see an Error
-        with self.assertRaises(ValueError):
-            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
-
-        # without we should not see an error, but every image will be black
-        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
-        out = pipe(**inputs)[0]
-
-        self.assertTrue(np.isnan(out).all())
-
-    def test_lora_scale_kwargs_match_fusion(self):
-        super().test_lora_scale_kwargs_match_fusion(5e-2, 5e-2)
-
-    @unittest.skip("Needs to be debugged.")
-    def test_set_adapters_match_attention_kwargs(self):
-        super().test_set_adapters_match_attention_kwargs()
-
-    @unittest.skip("Needs to be debugged.")
-    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
-        super().test_simple_inference_with_text_denoiser_lora_and_scale()
-
     @unittest.skip("Not supported in ZImage.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
@@ -265,3 +155,23 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
     @unittest.skip("Not supported in ZImage.")
     def test_modify_padding_mode(self):
         pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_partial_text_lora(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_text_lora(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_text_lora_and_scale(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_text_lora_fused(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_text_lora_save_load(self):
+        pass