@@ -282,6 +282,41 @@ def _dot_general_with_3d(
282282
283283 lora_provider .dot_general = types .MethodType (_dot_general_with_3d , lora_provider )
284284
285+
286+ def _patch_qwix_update_boxed (qwix_flax_util ):
287+ """Patches Qwix flax_util.update_boxed to handle PartitionSpec."""
288+ original_update_boxed = qwix_flax_util .update_boxed
289+
290+ def patched_update_boxed (
291+ boxed ,
292+ * ,
293+ value = None ,
294+ split = None ,
295+ merge = None ,
296+ transpose = None ,
297+ ):
298+ import jax
299+ from flax import nnx
300+
301+ if isinstance (boxed , nnx .Variable ):
302+ if value is not None :
303+ boxed = boxed .replace (value )
304+ shape = boxed .value .shape
305+ metadata = boxed .get_metadata ()
306+ sharding_key = "out_sharding" if "out_sharding" in metadata else "sharding_names"
307+ axes = metadata .get (sharding_key , None )
308+ if isinstance (axes , (list , tuple , jax .sharding .PartitionSpec )):
309+ axes = qwix_flax_util .update_sharding (
310+ axes , shape = shape , split = split , merge = merge , transpose = transpose
311+ )
312+ boxed .set_metadata (sharding_key , axes )
313+ return boxed
314+ return original_update_boxed (
315+ boxed , value = value , split = split , merge = merge , transpose = transpose
316+ )
317+
318+ qwix_flax_util .update_boxed = patched_update_boxed
319+
285320def _prepare_dummy_inputs (mt_config , mesh ):
286321 """Builds dummy decoder inputs used to materialize LoRA parameters."""
287322 batch_size = getattr (mt_config , "per_device_batch_size" , 1 )
@@ -518,6 +553,7 @@ def maybe_apply_lora(model, mesh, mt_config):
518553 lora_provider = _build_lora_provider (mt_config , qwix )
519554
520555 _patch_qwix_dot_general_with_3d (lora_provider , qwix_flax_util , qwix_lora , qwix_ptq , types )
556+ _patch_qwix_update_boxed (qwix_flax_util )
521557
522558 decoder_input_tokens , decoder_positions = _prepare_dummy_inputs (mt_config , mesh )
523559 lora_model = qwix .apply_lora_to_model (
0 commit comments