2929from maxtext .utils import lora_utils
3030from maxtext .utils import model_creation_utils
3131from maxtext .configs import pyconfig
32+ from maxtext .utils import maxtext_utils
33+ from jax .sharding import Mesh
3234from tests .utils .test_helpers import get_test_config_path
3335
3436# ---------------------------------------------------------------------------
@@ -104,10 +106,14 @@ def test_build_lora_provider(self):
104106 mock_config .lora .lora_module_path = "custom/path"
105107 mock_config .lora .lora_rank = 8
106108 mock_config .lora .lora_alpha = 16.0
109+ mock_config .lora .lora_weight_qtype = "int8"
110+ mock_config .lora .lora_tile_size = 32
107111
108112 with mock .patch ("qwix.LoraProvider" ) as mock_provider :
109113 lora_utils ._build_lora_provider (mock_config )
110- mock_provider .assert_called_once_with (module_path = "custom/path" , rank = 8 , alpha = 16.0 , dropout = 0.0 )
114+ mock_provider .assert_called_once_with (
115+ module_path = "custom/path" , rank = 8 , alpha = 16.0 , dropout = 0.0 , weight_qtype = "int8" , tile_size = 32
116+ )
111117
112118 def test_prepare_dummy_inputs (self ):
113119 """Test preparation of dummy inputs for LoRA verification."""
@@ -158,27 +164,36 @@ def test_apply_lora_to_model_adapters_loaded(self):
158164 # If we skip Qwix, it should stay False.
159165 self .assertFalse (lora_utils .is_lora_enabled (result ))
160166
161- def _run_apply_lora_test (self , scan_layers : bool ):
162- """Helper to run LoRA application test with/without scanned layers."""
167+ def _run_apply_lora_test (self , scan_layers : bool , weight_qtype = None , tile_size = None , mock_multihost : bool = False ):
168+ """Helper to run LoRA application test with/without scanned layers and optional QLoRA ."""
163169 # Passing nested dict as 'lora' kwarg to _make_config
164170 cfg = _make_config (
165171 lora = {
166172 "enable_lora" : True ,
167173 "lora_rank" : 4 ,
168174 "lora_alpha" : 8.0 ,
169175 "lora_module_path" : ".*mlp/wi_.*" ,
176+ "lora_weight_qtype" : weight_qtype ,
177+ "lora_tile_size" : tile_size ,
170178 },
171179 scan_layers = scan_layers ,
172180 )
173181
174182 # Create a real small model using standard creation utils
175- model , _ = model_creation_utils .from_pretrained (cfg , mesh = None , model_mode = model_creation_utils .MODEL_MODE_TRAIN )
183+ model , mesh = model_creation_utils .from_pretrained (cfg , mesh = None , model_mode = model_creation_utils .MODEL_MODE_TRAIN )
176184
177185 # Verify model is NOT lora enabled initially
178186 self .assertFalse (lora_utils .is_lora_enabled (model ))
179187
180- # Apply LoRA
181- lora_model = lora_utils .apply_lora_to_model (model , model .mesh , cfg )
188+ if mock_multihost :
189+ devices_array = maxtext_utils .create_device_mesh (cfg )
190+ dummy_mesh = Mesh (devices_array , cfg .mesh_axes )
191+
192+ # Just verify that apply_lora_to_model runs successfully with the dummy mesh
193+ lora_model = lora_utils .apply_lora_to_model (model , dummy_mesh , cfg )
194+ else :
195+ # Apply LoRA
196+ lora_model = lora_utils .apply_lora_to_model (model , mesh , cfg )
182197
183198 # Verify we can find LoRAParam in the state
184199 _ , state = nnx .split (lora_model )
@@ -200,13 +215,27 @@ def _run_apply_lora_test(self, scan_layers: bool):
200215 self .assertGreater (len (jax .tree_util .tree_leaves (opt_state )), 0 )
201216
202217 def test_apply_lora_to_model_scan_layers_false (self ):
203- """Test applying LoRA to model with scan_layers=False."""
218+ """Test applying standard LoRA to model with scan_layers=False."""
204219 self ._run_apply_lora_test (scan_layers = False )
205220
206221 def test_apply_lora_to_model_scan_layers_true (self ):
207- """Test applying LoRA to model with scan_layers=True."""
222+ """Test applying standard LoRA to model with scan_layers=True."""
208223 self ._run_apply_lora_test (scan_layers = True )
209224
225+ @unittest .skip ("Awaiting qwix fix for QLoRA params materialization" )
226+ def test_apply_qlora_to_model_scan_layers_false (self ):
227+ """Test applying QLoRA to model with scan_layers=False."""
228+ self ._run_apply_lora_test (scan_layers = False , weight_qtype = "int8" , tile_size = 32 )
229+
230+ @unittest .skip ("Awaiting qwix fix for QLoRA params materialization" )
231+ def test_apply_qlora_to_model_scan_layers_true (self ):
232+ """Test applying QLoRA to model with scan_layers=True."""
233+ self ._run_apply_lora_test (scan_layers = True , weight_qtype = "int8" , tile_size = 32 )
234+
235+ def test_apply_lora_multihost_mock (self ):
236+ """Test applying LoRA with a dummy mesh to trigger the multi-host reshard callback."""
237+ self ._run_apply_lora_test (scan_layers = False , mock_multihost = True )
238+
210239 def test_restore_lora_from_path (self ):
211240 """Test restoration of LoRA parameters from a path."""
212241 cfg = _make_config (
0 commit comments