
Commit fc27813

feat: enhance LoRA path flexibility and patch qwix
1 parent 4aa6ceb commit fc27813

2 files changed: 223 additions & 14 deletions

src/maxtext/layers/nnx_decoders.py

Lines changed: 183 additions & 8 deletions
@@ -449,13 +449,57 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
     def layer_fn(carry, scanned_vars):
       current_params, current_state = scanned_vars

+      def rank_consistent_spec(spec, shape):
+        if spec is None: return None
+        spec_list = list(spec)
+
+        # 1. Remove scanning axes if rank reduction is needed
+        if len(spec_list) > len(shape):
+          for axis_name in ["layers", "stage"]:
+            if axis_name in spec_list:
+              spec_list.remove(axis_name)
+              if len(spec_list) == len(shape): break
+
+        # 2. If still mismatched, strip from the left (standard JAX rank reduction)
+        while len(spec_list) > len(shape):
+          spec_list.pop(0)
+
+        # 3. If rank is too small, pad with None
+        while len(spec_list) < len(shape):
+          spec_list.insert(0, None)
+
+        return jax.sharding.PartitionSpec(*spec_list)
+
+      def fix_node_rank(x):
+        if hasattr(x, "get_metadata") and hasattr(x, "replace") and hasattr(x, "value"):
+          metadata = x.get_metadata()
+          updates = {}
+          for k, axes in metadata.items():
+            if isinstance(axes, (jax.sharding.PartitionSpec, tuple, list)):
+              # Convert tuple/list to spec for check
+              spec_obj = jax.sharding.PartitionSpec(*axes) if isinstance(axes, (tuple, list)) else axes
+              if len(spec_obj) != x.value.ndim:
+                new_spec = rank_consistent_spec(spec_obj, x.value.shape)
+                # Keep original type (tuple vs spec)
+                updates[k] = tuple(new_spec) if isinstance(axes, (tuple, list)) else new_spec
+                # print(f"[DEBUG] Normalizing metadata key '{k}' from rank {len(spec_obj)} to {len(new_spec)}")
+          if updates:
+            return x.replace(**updates)
+        return x
+
+      is_nnx_var = lambda x: hasattr(x, "get_metadata") and hasattr(x, "replace")
+      current_params = jax.tree.map(fix_node_rank, current_params, is_leaf=is_nnx_var)
+      current_state = jax.tree.map(fix_node_rank, current_state, is_leaf=is_nnx_var)
+
       if self.config.parameter_memory_host_offload:
         current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params)

       layer = nnx.merge(graphdef, current_params, current_state)
+
       layer_out = layer(carry, *args, **valid_kwargs)
       new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out

+      # Extract EVERYTHING to capture new parameters
       new_graphdef, updated_params, updated_state = nnx.split(layer, nnx.Param, ...)

       if dynamic_graph_init:
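
For reference, a standalone sketch of the rank-normalization rule that the new rank_consistent_spec helper applies to sharding metadata (reimplemented here outside MaxText; the axis names and shapes below are illustrative, not taken from the repo): a spec that still carries the scan axes 'layers'/'stage' is reduced to the rank of the per-layer value by dropping those names first, then trimming from the left, and a spec that is too short is left-padded with None.

import jax

P = jax.sharding.PartitionSpec

def normalize(spec, shape, scan_axes=("layers", "stage")):
  # Match a PartitionSpec's rank to an array shape (sketch of rank_consistent_spec).
  names = list(spec)
  for axis in scan_axes:          # drop scan axes first when the spec is too long
    if len(names) > len(shape) and axis in names:
      names.remove(axis)
  while len(names) > len(shape):  # then trim from the left
    names.pop(0)
  while len(names) < len(shape):  # or left-pad with None when too short
    names.insert(0, None)
  return P(*names)

# A stacked-weight spec applied to a single per-layer slice loses the 'layers' axis:
print(normalize(P("layers", "fsdp", "tensor"), (1024, 4096)))  # PartitionSpec('fsdp', 'tensor')
# A spec that is one axis short for a 3-D value is left-padded with None:
print(normalize(P("fsdp", "tensor"), (8, 1024, 4096)))         # PartitionSpec(None, 'fsdp', 'tensor')
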
@@ -466,23 +510,154 @@ def layer_fn(carry, scanned_vars):

       return new_carry, (returned_params, updated_state)

-    layer_fn_wrapped = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
+    if dynamic_graph_init:
+      print(f"[DEBUG] Starting Dynamic Graph Init Loop (length={length})")
+      curr_carry = x_in
+      out_params_list = []
+      out_other_list = []
+
+      def _slice_and_unpromote(x, i):
+        # Resolve physical value and shape
+        is_var = hasattr(x, "get_metadata") and hasattr(x, "replace")
+        val = x.value if is_var else x
+
+        if not hasattr(val, "shape") or len(val.shape) == 0 or val.shape[0] != length:
+          return x
+
+        # 1. Slice value
+        sliced_val = val[i]
+
+        # 2. Slice logical metadata if it's an NNX variable
+        if is_var:
+          metadata = x.get_metadata()
+          updates = {}
+          for sharding_key in ["sharding", "out_sharding", "sharding_names"]:
+            axes = metadata.get(sharding_key)
+            if isinstance(axes, jax.sharding.PartitionSpec):
+              spec_list = list(axes)
+
+              # Aggressively reduce rank to match sliced_val.ndim
+              for axis_to_remove in ["layers", "stage"]:
+                if axis_to_remove in spec_list and len(spec_list) > sliced_val.ndim:
+                  spec_list.remove(axis_to_remove)
+
+              while len(spec_list) > sliced_val.ndim:
+                spec_list.pop(0)
+
+              while len(spec_list) < sliced_val.ndim:
+                spec_list.insert(0, None)
+
+              new_spec = jax.sharding.PartitionSpec(*spec_list)
+              updates[sharding_key] = new_spec
+
+          return x.replace(value=sliced_val, **updates)
+
+        return sliced_val
+
+      def _promote_to_scanned(x):
+        """Adds 'layers' axis back to newly created parameters if scanning is enabled."""
+        if not self.config.scan_layers:
+          return x
+
+        is_nnx_leaf = lambda x: hasattr(x, "get_metadata") and hasattr(x, "replace")
+        if is_nnx_leaf(x):
+          metadata = x.get_metadata()
+          updates = {}
+          # Determine which axis to insert 'layers' into based on config
+          scan_axis = self.config.param_scan_axis
+
+          for sharding_key in ["sharding", "out_sharding", "sharding_names"]:
+            axes = metadata.get(sharding_key)
+            if isinstance(axes, jax.sharding.PartitionSpec):
+              spec_list = list(axes)
+              if "layers" not in spec_list:
+                # Insert 'layers' at the correct scan axis position
+                # Cap at current length to avoid index out of bounds
+                insert_pos = min(scan_axis, len(spec_list))
+                spec_list.insert(insert_pos, "layers")
+              updates[sharding_key] = jax.sharding.PartitionSpec(*spec_list)
+
+          if updates:
+            return x.replace(**updates)
+        return x
+
+      for i in range(length):
+        # Slice both values AND logical metadata!
+        is_nnx_leaf = lambda x: hasattr(x, "get_metadata") and hasattr(x, "replace")
+        curr_params = jax.tree.map(lambda x: _slice_and_unpromote(x, i), params, is_leaf=is_nnx_leaf)
+        curr_state = jax.tree.map(lambda x: _slice_and_unpromote(x, i), state, is_leaf=is_nnx_leaf)
+
+        curr_carry, (out_p, out_o) = layer_fn(curr_carry, (curr_params, curr_state))
+
+        # Promote ALL parameters back to rank-3 metadata immediately
+        # This ensures they are ready to be stacked correctly.
+        out_p = jax.tree.map(_promote_to_scanned, out_p, is_leaf=is_nnx_leaf)
+        out_o = jax.tree.map(_promote_to_scanned, out_o, is_leaf=is_nnx_leaf)
+
+        out_params_list.append(out_p)
+        out_other_list.append(out_o)
+
+      final_carry = curr_carry
+      scanned_params = jax.tree.map(lambda *args: jnp.stack(args), *out_params_list)
+      scanned_other = jax.tree.map(lambda *args: jnp.stack(args), *out_other_list)
+
+
+    else:
+      layer_fn_wrapped = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)

-    def _ensure_scan_leading_axis(x):
-      if not hasattr(x, "shape") or len(x.shape) == 0:
-        return jnp.broadcast_to(x, (length,))
-      return x
+      def _ensure_scan_leading_axis(x):
+        if not hasattr(x, "shape") or len(x.shape) == 0:
+          return jnp.broadcast_to(x, (length,))
+        return x

-    params = jax.tree.map(_ensure_scan_leading_axis, params)
-    state = jax.tree.map(_ensure_scan_leading_axis, state)
+      params = jax.tree.map(_ensure_scan_leading_axis, params)
+      state = jax.tree.map(_ensure_scan_leading_axis, state)

-    final_carry, (scanned_params, scanned_other) = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
+      final_carry, (scanned_params, scanned_other) = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))

     if scan_axis != 0:
       scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params)

+    scan_axis = self.config.param_scan_axis
+
+    def _force_promote(x):
+      is_nnx_leaf = hasattr(x, "get_metadata") and hasattr(x, "replace")
+      if is_nnx_leaf:
+        metadata = x.get_metadata()
+        updates = {}
+        val_ndim = x.value.ndim
+        for sharding_key in ["sharding", "out_sharding", "sharding_names"]:
+          axes = metadata.get(sharding_key)
+          if isinstance(axes, (jax.sharding.PartitionSpec, tuple, list)):
+            l = list(axes)
+            if len(l) < val_ndim and "layers" not in l:
+              pos = min(scan_axis, len(l))
+              l.insert(pos, "layers")
+              updates[sharding_key] = jax.sharding.PartitionSpec(*l) if isinstance(axes, jax.sharding.PartitionSpec) else tuple(l)
+        if updates:
+          return x.replace(**updates)
+      return x
+
+    is_leaf_with_metadata = lambda x: hasattr(x, "get_metadata") and hasattr(x, "replace")
+    scanned_params = jax.tree.map(_force_promote, scanned_params, is_leaf=is_leaf_with_metadata)
+    scanned_other = jax.tree.map(_force_promote, scanned_other, is_leaf=is_leaf_with_metadata)
+
     if dynamic_graph_init:
+      # Perform a structural update: merge the new structure with the stacked arrays
       out_layers = nnx.merge(updated_graphdef[0], scanned_params, scanned_other)
+
+      # We must update the PARENT (self) to point to the new structure.
+      for attr_name, attr_val in self.__dict__.items():
+        if attr_val is layers:
+          setattr(self, attr_name, out_layers)
+          print(f"[DEBUG] Materialization complete: updated self.{attr_name}")
+          break
+
+      # FORCE NNX to recognize new structural changes by splitting/merging the PARENT
+      # This updates the underlying GraphDef for the entire Decoder.
+      g, s = nnx.split(self)
+      new_self = nnx.merge(g, s)
+      nnx.update(self, nnx.state(new_self))
     else:
       nnx.update(layers, nnx.State.merge(scanned_params, scanned_other))
       out_layers = layers
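
The dynamic-init branch above swaps jax.lax.scan for a Python loop over layers and then rebuilds the stacked layout by stacking the per-layer outputs leaf-wise, as the `jax.tree.map(lambda *args: jnp.stack(args), ...)` lines do. A minimal sketch of that stacking step, with a made-up pytree and shapes:

import jax
import jax.numpy as jnp

# Stand-ins for the per-layer parameter pytrees collected by the loop (4 layers).
out_params_list = [{"mlp": {"wi_0": jnp.full((16, 64), i)}} for i in range(4)]

# Leaf-wise stack: every leaf gains a leading axis of size num_layers,
# mirroring the layout jax.lax.scan would have produced for scanned layers.
scanned_params = jax.tree.map(lambda *leaves: jnp.stack(leaves), *out_params_list)
print(scanned_params["mlp"]["wi_0"].shape)  # (4, 16, 64)
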

src/maxtext/utils/lora_utils.py

Lines changed: 40 additions & 6 deletions
@@ -28,7 +28,6 @@
 from orbax import checkpoint as ocp
 import qwix

-from maxtext.common import checkpointing
 from maxtext.configs import pyconfig
 from maxtext.utils import gcs_utils
 from maxtext.utils import max_logging
@@ -391,19 +390,23 @@ def _get_lora_module_path(mt_config: pyconfig.HyperParameters) -> str:

   for key, module_path in lora_configs.items():
     if key != "default" and model_name.startswith(key):
-      max_logging.log(f"Auto-detected lora_module_path for model '{model_name}': {module_path}")
-      return str(module_path)
+      # Make the layer index optional to support both scanned and non-scanned paths
+      # e.g., 'decoder/layers/0/mlp' vs 'decoder/layers/mlp'
+      flexible_path = str(module_path).replace("layers/", "layers/(?:[0-9]+/)?")
+      max_logging.log(f"Auto-detected lora_module_path for model '{model_name}': {flexible_path}")
+      return flexible_path

   default_path = lora_configs.get(
       "default",
       "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))",
   )
+  flexible_default = str(default_path).replace("layers/", "layers/(?:[0-9]+/)?")
   max_logging.log(
       f"Warning: Model '{model_name}' is not in the list of verified LoRA models. "
       "Auto-detection might not work. Please provide an explicit `lora_module_path` in your config if training fails."
   )
-  max_logging.log(f"Falling back to default lora_module_path: {default_path}")
-  return str(default_path)
+  max_logging.log(f"Falling back to flexible default lora_module_path: {flexible_default}")
+  return flexible_default


 def _build_lora_provider(mt_config: pyconfig.HyperParameters) -> qwix.LoraProvider:
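
The string rewrite above makes a numeric layer index optional inside the LoRA module-path regex, so a single pattern covers both scanned paths (no per-layer index) and unrolled paths (decoder/layers/<n>/...). A quick sketch of how such a pattern behaves; the module path below is illustrative, and re.fullmatch is used only for demonstration, not necessarily how qwix applies the pattern:

import re

module_path = "decoder/layers/mlp/(wi_0|wi_1|wo)"          # hypothetical configured path
flexible_path = module_path.replace("layers/", "layers/(?:[0-9]+/)?")
print(flexible_path)  # decoder/layers/(?:[0-9]+/)?mlp/(wi_0|wi_1|wo)

for candidate in ("decoder/layers/mlp/wi_0",     # scanned: no per-layer index
                  "decoder/layers/0/mlp/wi_0",   # unrolled: explicit layer index
                  "decoder/layers/12/mlp/wo"):
  print(candidate, bool(re.fullmatch(flexible_path, candidate)))  # all True
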
@@ -433,6 +436,7 @@ def _build_lora_provider(mt_config: pyconfig.HyperParameters) -> qwix.LoraProvid
       f"rank={lora_cfg.lora_rank} alpha={lora_cfg.lora_alpha} "
       f"tile_size={lora_cfg.lora_tile_size}"
   )
+
   return qwix.LoraProvider(**lora_kwargs)

@@ -465,7 +469,7 @@ def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperPar
   matched_module_paths = []
   sample_module_paths = []

-  for path, _ in nnx.iter_modules(lora_model):
+  for path, _ in nnx.iter_graph(lora_model):
     module_path = "/".join(str(p) for p in path)
     if len(sample_module_paths) < 100:
       sample_module_paths.append(module_path)
@@ -486,6 +490,34 @@ def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperPar
   )


+def _patch_qwix_for_maxtext(mesh, mt_config):
+  import qwix._src.flax_util as flax_util
+  import qwix._src.providers.ptq as ptq
+  import jax.numpy as jnp
+  from flax import nnx
+
+  # 1. PTQ patch
+  original_get_intercept_map = ptq.PtqProvider.get_intercept_map
+
+  def patched_get_intercept_map(self):
+    mapping = original_get_intercept_map(self)
+
+    def intercept_asarray(a, dtype=None, order=None, **kwargs):
+      if isinstance(a, nnx.State) and 'array' in a:
+        a = a['array']
+      if isinstance(a, nnx.State) and 'qvalue' in a and 'scale' in a:
+        a = ptq.QArray(qvalue=a['qvalue'].value, scale=a['scale'].value)
+
+      if type(a).__name__ in ("WithAux", "QArray"):
+        return a
+      return jnp.asarray(a, dtype=dtype, order=order, **kwargs)
+
+    mapping["jax.numpy.asarray"] = intercept_asarray
+    return mapping
+
+  ptq.PtqProvider.get_intercept_map = patched_get_intercept_map
+
+
 def apply_lora_to_model(
     model: nnx.Module,
     mesh: Optional[jax.sharding.Mesh],
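
_patch_qwix_for_maxtext uses a plain monkey-patch: capture the original method off the class, wrap it, and reinstall the wrapper so every later call goes through the override. A generic sketch of that pattern on a toy class (ToyProvider and its method are stand-ins, not qwix's real API):

class ToyProvider:
  def get_intercept_map(self):
    return {"toy.fn": lambda x: x}

_original_get_intercept_map = ToyProvider.get_intercept_map

def _patched_get_intercept_map(self):
  mapping = _original_get_intercept_map(self)   # start from the original mapping
  mapping["toy.fn"] = lambda x: x + 1           # override a single interception entry
  return mapping

ToyProvider.get_intercept_map = _patched_get_intercept_map
print(ToyProvider().get_intercept_map()["toy.fn"](41))  # 42
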
@@ -501,6 +533,8 @@ def apply_lora_to_model(
   if not getattr(lora_cfg, "enable_lora", False):
     return model

+  _patch_qwix_for_maxtext(mesh, mt_config)
+
   lora_provider = _build_lora_provider(mt_config)

   model_rngs = getattr(model.decoder, "rngs", None)