Skip to content

Commit 324eb0a

Browse files
committed
feat: support QLoRA with NNX and Qwix
1 parent 397f319 commit 324eb0a

6 files changed

Lines changed: 263 additions & 18 deletions

File tree

src/maxtext/configs/post_train/sft.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ lora:
2727
lora_rank: 0
2828
lora_alpha: 0.0
2929
lora_module_path: ""
30+
# For QLoRA, set lora_weight_qtype (e.g., "nf4") and optionally lora_tile_size.
31+
lora_weight_qtype: null
32+
lora_tile_size: null
3033
# Optional path to LoRA weights to load before training. Ignored if the current run is resumed.
3134
lora_restore_path: ""
3235

src/maxtext/configs/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,14 @@ class LoRA(BaseModel):
12061206
"Regex identifying target modules for LoRA, e.g." " '.*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj'."
12071207
),
12081208
)
1209+
lora_weight_qtype: str | None = Field(
1210+
None,
1211+
description=("Optional quantization type for QLoRA (e.g., 'nf4'). If set, QLoRA is applied."),
1212+
)
1213+
lora_tile_size: NonNegativeInt | None = Field(
1214+
None,
1215+
description="Optional tile size for QLoRA (e.g., 128 or 256).",
1216+
)
12091217
lora_restore_path: PathStr = Field(
12101218
"",
12111219
description=("Optional path to LoRA weights to load before training. Ignored if the current run is resumed."),

src/maxtext/layers/nnx_decoders.py

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,12 @@
6262
)
6363
from maxtext.multimodal import utils as mm_utils
6464
from maxtext.utils import max_logging, max_utils, maxtext_utils, sharding
65-
from maxtext.utils.sharding import create_sharding
65+
from maxtext.utils.sharding import (
66+
create_sharding,
67+
nnx_ensure_scan_leading_axis,
68+
nnx_reconcile_sharding,
69+
nnx_sync_moveaxis,
70+
)
6671

6772
# ------------------------------------------------------------------------------
6873
# The network: Decoder Definitions
@@ -453,7 +458,7 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, kv_caches
453458
graphdef, params, state = nnx.split(layers, nnx.Param, ...)
454459

455460
if scan_axis != 0:
456-
params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params)
461+
params = nnx_sync_moveaxis(params, scan_axis, 0)
457462

458463
sig = inspect.signature(layers.__class__.__call__)
459464
valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
@@ -463,7 +468,24 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, kv_caches
463468

464469
use_kv = kv_caches_stacked is not None
465470

471+
def stash_origin_metadata(x):
472+
is_var = hasattr(x, "get_metadata") and hasattr(x, "replace")
473+
if is_var:
474+
metadata = x.get_metadata()
475+
updates = {"origin_shape": x.value.shape}
476+
for k in ["sharding", "out_sharding", "sharding_names"]:
477+
if k in metadata:
478+
updates[f"origin_{k}"] = metadata[k]
479+
return x.replace(**updates)
480+
return x
481+
482+
params = jax.tree.map(stash_origin_metadata, params)
483+
state = jax.tree.map(stash_origin_metadata, state)
484+
466485
def layer_fn(carry, scanned_vars):
486+
# Ensure metadata rank matches the sliced values
487+
scanned_vars = nnx_reconcile_sharding(scanned_vars, "layers")
488+
467489
if use_kv:
468490
current_params, current_state, kv_cache_layer = scanned_vars
469491
else:
@@ -527,21 +549,57 @@ def layer_fn(carry, scanned_vars):
527549
else:
528550
layer_fn_wrapped = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
529551

530-
def _ensure_scan_leading_axis(x):
531-
if not hasattr(x, "shape") or len(x.shape) == 0:
532-
return jnp.broadcast_to(x, (length,))
533-
return x
534-
535-
params = jax.tree.map(_ensure_scan_leading_axis, params)
536-
state = jax.tree.map(_ensure_scan_leading_axis, state)
552+
params = nnx_ensure_scan_leading_axis(params, length)
553+
state = nnx_ensure_scan_leading_axis(state, length)
537554

538555
final_carry, (scanned_params, scanned_other) = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
539556

557+
# Ensure metadata rank matches the stacked values
558+
scanned_params = nnx_reconcile_sharding(scanned_params, "layers")
559+
scanned_other = nnx_reconcile_sharding(scanned_other, "layers")
560+
540561
if scan_axis != 0:
541-
scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params)
562+
scanned_params = nnx_sync_moveaxis(scanned_params, 0, scan_axis)
563+
564+
def restore_origin_metadata(x):
565+
is_var = hasattr(x, "get_metadata") and hasattr(x, "replace")
566+
if is_var:
567+
metadata = x.get_metadata()
568+
updates = {}
569+
for k in ["sharding", "out_sharding", "sharding_names"]:
570+
origin_key = f"origin_{k}"
571+
if origin_key in metadata:
572+
updates[k] = metadata[origin_key]
573+
else:
574+
axes = metadata.get(k)
575+
if isinstance(axes, (jax.sharding.PartitionSpec, tuple, list)):
576+
spec_list = list(axes)
577+
if "layers" not in spec_list:
578+
pos = min(self.config.param_scan_axis, len(spec_list))
579+
spec_list.insert(pos, "layers")
580+
new_spec = jax.sharding.PartitionSpec(*spec_list)
581+
updates[k] = tuple(new_spec) if isinstance(axes, (tuple, list)) else new_spec
582+
if updates:
583+
return x.replace(**updates)
584+
return x
585+
586+
def is_leaf_with_metadata(x):
587+
return hasattr(x, "get_metadata") and hasattr(x, "replace")
588+
589+
scanned_params = jax.tree.map(restore_origin_metadata, scanned_params, is_leaf=is_leaf_with_metadata)
590+
scanned_other = jax.tree.map(restore_origin_metadata, scanned_other, is_leaf=is_leaf_with_metadata)
542591

543592
if dynamic_graph_init:
544593
out_layers = nnx.merge(updated_graphdef[0], scanned_params, scanned_other)
594+
595+
for attr_name, attr_val in self.__dict__.items():
596+
if attr_val is layers:
597+
setattr(self, attr_name, out_layers)
598+
break
599+
600+
g, s = nnx.split(self)
601+
new_self = nnx.merge(g, s)
602+
nnx.update(self, nnx.state(new_self))
545603
else:
546604
nnx.update(layers, nnx.State.merge(scanned_params, scanned_other))
547605
out_layers = layers

src/maxtext/utils/lora_utils.py

Lines changed: 96 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
""" Common LoRA utils needed to support LoRA adapters."""
15+
"""Common LoRA utils needed to support LoRA adapters."""
1616
from functools import partial
1717
import json
1818
import os
@@ -385,14 +385,20 @@ def _get_lora_module_path(mt_config: pyconfig.HyperParameters) -> str:
385385
model_name = mt_config.model_name.lower()
386386

387387
# Find the first matching architecture prefix or use 'default'
388-
matched_key = next((k for k in lora_configs if k != "default" and model_name.startswith(k)), "default")
388+
matched_key = next(
389+
(k for k in lora_configs if k != "default" and model_name.startswith(k)),
390+
"default",
391+
)
389392

390393
if matched_key == "default":
391394
max_logging.log(f"Warning: Model '{model_name}' is unverified; falling back to default LoRA path.")
392395
else:
393396
max_logging.log(f"Auto-detected lora_module_path for model '{model_name}' (matched: '{matched_key}')")
394397

395-
raw_path = lora_configs.get(matched_key, "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))")
398+
raw_path = lora_configs.get(
399+
matched_key,
400+
"decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))",
401+
)
396402

397403
# This regex makes the layer index optional, matching both scanned and unscanned layer paths
398404
# (e.g. 'layers/0/mlp/...' vs 'layers/mlp/...').
@@ -412,10 +418,15 @@ def _build_lora_provider(mt_config: pyconfig.HyperParameters) -> qwix.LoraProvid
412418
"alpha": mt_config.lora.lora_alpha,
413419
"dropout": 0.0,
414420
}
415-
max_logging.log(
416-
f"LoRA configured: module_path={lora_module_path} "
417-
f"rank={mt_config.lora.lora_rank} alpha={mt_config.lora.lora_alpha}"
418-
)
421+
if mt_config.lora.lora_tile_size is not None:
422+
lora_kwargs["tile_size"] = mt_config.lora.lora_tile_size
423+
if mt_config.lora.lora_weight_qtype is not None:
424+
lora_kwargs["weight_qtype"] = mt_config.lora.lora_weight_qtype
425+
426+
lora_type = "QLoRA" if mt_config.lora.lora_weight_qtype else "LoRA"
427+
args_str = " ".join(f"{k}={v}" for k, v in lora_kwargs.items() if k != "dropout")
428+
max_logging.log(f"{lora_type} configured: {args_str}")
429+
419430
return qwix.LoraProvider(**lora_kwargs)
420431

421432

@@ -448,7 +459,7 @@ def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperPar
448459
matched_module_paths = []
449460
sample_module_paths = []
450461

451-
for path, _ in nnx.iter_modules(lora_model):
462+
for path, _ in nnx.iter_graph(lora_model):
452463
module_path = "/".join(str(p) for p in path)
453464
if len(sample_module_paths) < 100:
454465
sample_module_paths.append(module_path)
@@ -469,6 +480,81 @@ def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperPar
469480
)
470481

471482

483+
def _patch_qwix_for_maxtext(mesh, mt_config):
  """Monkey-patches Qwix internals so QLoRA can be applied to MaxText NNX models.

  Two patches are installed, each at most once per process:

  1. `PtqProvider.get_intercept_map` is wrapped so the returned mapping also
     intercepts `jax.numpy.asarray`. The interceptor unwraps `nnx.State`
     containers (an "array" entry, or a "qvalue"/"scale" pair rebuilt as a
     `QArray`) and returns `WithAux`/`QArray` objects untouched instead of
     converting them to a dense array.
  2. `flax_util.find_param` is replaced by a lookup that inspects the current
     module's attributes and resolves `x` by identity, by walking single-input
     tracer parents, and finally by a unique shape match.

  Args:
    mesh: Currently unused by the patches.
    mt_config: Currently unused by the patches.
  """
  # pylint: disable=protected-access,import-outside-toplevel,redefined-outer-name,reimported,missing-function-docstring,consider-using-from-import
  import qwix._src.flax_util as flax_util
  import qwix._src.providers.ptq as ptq
  import jax.numpy as jnp
  from flax import nnx

  # 1. PTQ patch
  # Guarded with a sentinel so repeated calls do not stack another wrapper
  # around get_intercept_map every time LoRA is applied.
  if not getattr(ptq.PtqProvider, "_maxtext_intercept_patched", False):
    original_get_intercept_map = ptq.PtqProvider.get_intercept_map

    def patched_get_intercept_map(self):
      mapping = original_get_intercept_map(self)

      def intercept_asarray(a, dtype=None, order=None, **kwargs):
        if isinstance(a, nnx.State) and "array" in a:
          a = a["array"]
        if isinstance(a, nnx.State) and "qvalue" in a and "scale" in a:
          a = ptq.QArray(qvalue=a["qvalue"].value, scale=a["scale"].value)

        # Quantized containers must not be densified by asarray.
        if type(a).__name__ in ("WithAux", "QArray"):
          return a
        return jnp.asarray(a, dtype=dtype, order=order, **kwargs)

      mapping["jax.numpy.asarray"] = intercept_asarray
      return mapping

    ptq.PtqProvider.get_intercept_map = patched_get_intercept_map
    ptq.PtqProvider._maxtext_intercept_patched = True

  # 2. find_param patch
  if not hasattr(flax_util, "_maxtext_find_param_patched"):
    # Capture the original before it is overwritten: once patched,
    # `flax_util.find_param` refers to the replacement itself, so the original
    # can no longer be reached through the module attribute.
    original_find_param = flax_util.find_param

    def patched_find_param(x, ptq_array_type=None):
      module = flax_util.get_current_module()
      if module is None:
        return None
      if not isinstance(module, nnx.Module):
        # Non-NNX modules keep Qwix's own lookup.
        return original_find_param(x, ptq_array_type)

      array_types = nnx.Param | ptq_array_type if ptq_array_type else nnx.Param
      candidates = {}
      for name, node in module.__dict__.items():
        if isinstance(node, array_types):
          candidates[name] = node.value if isinstance(node, nnx.Param) else node

      # Identity lookup, including the array wrapped inside a WithAux container.
      candidates_by_id = {id(c): n for n, c in candidates.items()}
      for n, c in candidates.items():
        if type(c).__name__ == "WithAux" and hasattr(c, "array"):
          candidates_by_id[id(c.array)] = n

      if id(x) in candidates_by_id:
        return candidates_by_id[id(x)]

      if isinstance(x, jax.core.Tracer) and hasattr(x, "parent"):
        curr_x = x
        while True:
          if id(curr_x) in candidates_by_id:
            return candidates_by_id[id(curr_x)]
          # Ancestors reached through in_tracers are not guaranteed to expose
          # `parent`, so read it defensively.
          parent = getattr(curr_x, "parent", None)
          if parent and len(parent.in_tracers) == 1:
            curr_x = parent.in_tracers[0]
          elif hasattr(curr_x, "get_const") and id(const := curr_x.get_const()) in candidates_by_id:
            return candidates_by_id[id(const)]
          else:
            break

      # Last resort: accept a candidate only if its shape match is unambiguous.
      target_shape = getattr(x, "shape", None)
      filtered = [n for n, c in candidates.items() if hasattr(c, "shape") and c.shape == target_shape]
      if len(filtered) == 1:
        return filtered[0]
      return None

    flax_util.find_param = patched_find_param
    flax_util._maxtext_find_param_patched = True
556+
557+
472558
def apply_lora_to_model(
473559
model: nnx.Module,
474560
mesh: Optional[jax.sharding.Mesh],
@@ -485,6 +571,8 @@ def apply_lora_to_model(
485571

486572
# Dynamically detect and set LoRA rank before model creation if restoring
487573

574+
_patch_qwix_for_maxtext(mesh, mt_config)
575+
488576
lora_provider = _build_lora_provider(mt_config)
489577

490578
model_rngs = getattr(model.decoder, "rngs", None)

src/maxtext/utils/sharding.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from collections.abc import Iterable
2121

2222
import jax
23+
import jax.numpy as jnp
2324
from jax.core import Tracer
2425
from jax.sharding import PartitionSpec as P, NamedSharding, reshard
2526

@@ -670,3 +671,88 @@ def all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules, sha
670671
# Apply the constraint to the model's current variables. This tells JAX to
671672
# gather the weights into this layout.
672673
return maybe_shard_with_name(variables, physical_constraint_no_fsdp, shard_mode=shard_mode)
674+
675+
676+
# ------------------------------------------------------------------------------
677+
# Metadata Synchronization Helpers for NNX Variables
678+
# ------------------------------------------------------------------------------
679+
680+
681+
def nnx_update_sharding_meta(variable, transform_fn):
  """Rewrites the sharding-related metadata entries of a variable.

  Args:
    variable: Object to update. Anything lacking `get_metadata` or `replace`
      is returned unchanged.
    transform_fn: Callable that receives a fresh list of axis names and
      returns the transformed sequence.

  Returns:
    `variable.replace(...)` with every non-empty "sharding", "out_sharding"
    and "sharding_names" entry transformed, or `variable` itself when nothing
    qualified. `PartitionSpec` entries stay `PartitionSpec`; tuple and list
    entries are written back as tuples.
  """
  supports_metadata = hasattr(variable, "get_metadata") and hasattr(variable, "replace")
  if not supports_metadata:
    return variable

  metadata = variable.get_metadata()
  rewritten = {}
  for field in ("sharding", "out_sharding", "sharding_names"):
    current = metadata.get(field)
    # Missing, empty, or non-sequence entries are left alone.
    if not current or not isinstance(current, (P, tuple, list)):
      continue
    axes = transform_fn(list(current))
    rewritten[field] = P(*axes) if isinstance(current, P) else tuple(axes)

  if not rewritten:
    return variable
  return variable.replace(**rewritten)
696+
697+
698+
def nnx_sync_moveaxis(tree, from_axis, to_axis):
  """Moves an axis in both the values and the sharding metadata of a tree.

  Args:
    tree: Pytree whose leaves are arrays or variable-like objects exposing
      `value`, `get_metadata` and `replace`.
    from_axis: Source axis position; negative values count from the end, as
      in `jnp.moveaxis`.
    to_axis: Destination axis position; negative values count from the end.

  Returns:
    `tree` itself when both axes are equal. Otherwise a new tree in which
    every leaf with a `shape` has the axis moved, and variable-like leaves
    also have their sharding metadata reordered to match. Leaves without a
    `shape` are returned unchanged.
  """
  if from_axis == to_axis:
    return tree

  def _op(x):
    is_var = hasattr(x, "value") and hasattr(x, "get_metadata")
    val = x.value if is_var else x
    if not hasattr(val, "shape"):
      return x

    new_val = jnp.moveaxis(val, from_axis, to_axis)
    if not is_var:
      return new_val

    # jnp.moveaxis accepts negative axes, but list.insert/list.pop interpret
    # them differently (insert(-1, ...) lands before the last element), so
    # normalise against the value's rank before touching the metadata.
    rank = len(val.shape)
    src = from_axis + rank if from_axis < 0 else from_axis
    dst = to_axis + rank if to_axis < 0 else to_axis

    def move_fn(axes):
      # Metadata shorter than the axes involved is left untouched.
      if len(axes) > max(src, dst):
        axes.insert(dst, axes.pop(src))
      return axes

    return nnx_update_sharding_meta(x.replace(value=new_val), move_fn)

  return jax.tree.map(_op, tree, is_leaf=lambda x: hasattr(x, "value") or hasattr(x, "shape"))
721+
722+
723+
def nnx_reconcile_sharding(tree, name="layers"):
  """Makes sharding metadata rank agree with value rank for every variable.

  Args:
    tree: Pytree whose variable-like leaves expose `value` and `get_metadata`;
      all other leaves are returned unchanged.
    name: Axis name removed when the metadata is longer than the value's rank
      and inserted at position 0 when it is shorter.

  Returns:
    A tree of the same structure with reconciled sharding metadata.
  """

  def _fit_variable(var):
    if not hasattr(var, "value") or not hasattr(var, "get_metadata"):
      return var

    def _fit_to_rank(axes):
      rank = var.value.ndim
      if len(axes) > rank:
        # Value was sliced: drop the named axis first, then trim from the front.
        if name in axes:
          axes.remove(name)
        surplus = len(axes) - rank
        if surplus > 0:
          del axes[:surplus]
      elif len(axes) < rank:
        # Value was stacked: add the named axis at the front, then left-pad.
        axes.insert(0, name)
        axes[:0] = [None] * (rank - len(axes))
      return axes

    return nnx_update_sharding_meta(var, _fit_to_rank)

  return jax.tree.map(_fit_variable, tree, is_leaf=lambda node: hasattr(node, "get_metadata"))
745+
746+
747+
def nnx_ensure_scan_leading_axis(tree, length):
  """Gives every scalar leaf a leading axis of size `length`.

  Rank-0 arrays and plain Python numeric scalars are broadcast to shape
  `(length,)`. Leaves that already have at least one dimension, and
  non-numeric leaves, are returned unchanged.

  Args:
    tree: Pytree whose leaves are arrays, Python scalars, or variable-like
      objects exposing `value`, `get_metadata` and `replace`.
    length: Size of the leading axis to create.

  Returns:
    A tree of the same structure. Variable-like leaves are rebuilt with
    `replace(value=...)`; their sharding metadata is not modified here.
  """

  def _op(x):
    is_var = hasattr(x, "value") and hasattr(x, "get_metadata")
    val = x.value if is_var else x

    if hasattr(val, "shape"):
      needs_axis = len(val.shape) == 0
    else:
      # Python numeric scalars have no `shape` but still need a leading axis.
      needs_axis = isinstance(val, (int, float, complex))
    if not needs_axis:
      return x

    new_val = jnp.broadcast_to(val, (length,))
    return x.replace(value=new_val) if is_var else new_val

  return jax.tree.map(_op, tree, is_leaf=lambda x: hasattr(x, "value") or hasattr(x, "shape"))

tests/post_training/unit/lora_utils_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ def test_build_lora_provider(self):
106106
mock_config.lora.lora_module_path = "custom/path"
107107
mock_config.lora.lora_rank = 8
108108
mock_config.lora.lora_alpha = 16.0
109+
mock_config.lora.lora_tile_size = None
110+
mock_config.lora.lora_weight_qtype = None
109111

110112
with mock.patch("qwix.LoraProvider") as mock_provider:
111113
lora_utils._build_lora_provider(mock_config)

0 commit comments

Comments
 (0)