Skip to content

Commit e9f89be

Browse files
committed
feat: patch qwix update_boxed for PartitionSpec handling
1 parent 5793efa commit e9f89be

1 file changed

Lines changed: 36 additions & 0 deletions

File tree

src/maxtext/trainers/post_train/sft/train_sft.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,41 @@ def _dot_general_with_3d(
282282

283283
lora_provider.dot_general = types.MethodType(_dot_general_with_3d, lora_provider)
284284

285+
286+
def _patch_qwix_update_boxed(qwix_flax_util):
287+
"""Patches Qwix flax_util.update_boxed to handle PartitionSpec."""
288+
original_update_boxed = qwix_flax_util.update_boxed
289+
290+
def patched_update_boxed(
291+
boxed,
292+
*,
293+
value=None,
294+
split=None,
295+
merge=None,
296+
transpose=None,
297+
):
298+
import jax
299+
from flax import nnx
300+
301+
if isinstance(boxed, nnx.Variable):
302+
if value is not None:
303+
boxed = boxed.replace(value)
304+
shape = boxed.value.shape
305+
metadata = boxed.get_metadata()
306+
sharding_key = "out_sharding" if "out_sharding" in metadata else "sharding_names"
307+
axes = metadata.get(sharding_key, None)
308+
if isinstance(axes, (list, tuple, jax.sharding.PartitionSpec)):
309+
axes = qwix_flax_util.update_sharding(
310+
axes, shape=shape, split=split, merge=merge, transpose=transpose
311+
)
312+
boxed.set_metadata(sharding_key, axes)
313+
return boxed
314+
return original_update_boxed(
315+
boxed, value=value, split=split, merge=merge, transpose=transpose
316+
)
317+
318+
qwix_flax_util.update_boxed = patched_update_boxed
319+
285320
def _prepare_dummy_inputs(mt_config, mesh):
286321
"""Builds dummy decoder inputs used to materialize LoRA parameters."""
287322
batch_size = getattr(mt_config, "per_device_batch_size", 1)
@@ -518,6 +553,7 @@ def maybe_apply_lora(model, mesh, mt_config):
518553
lora_provider = _build_lora_provider(mt_config, qwix)
519554

520555
_patch_qwix_dot_general_with_3d(lora_provider, qwix_flax_util, qwix_lora, qwix_ptq, types)
556+
_patch_qwix_update_boxed(qwix_flax_util)
521557

522558
decoder_input_tokens, decoder_positions = _prepare_dummy_inputs(mt_config, mesh)
523559
lora_model = qwix.apply_lora_to_model(

0 commit comments

Comments
 (0)