Skip to content

Commit aa36bb9

Browse files
committed
feat: support QLoRA with NNX and Qwix
1 parent 397f319 commit aa36bb9

5 files changed

Lines changed: 196 additions & 8 deletions

File tree

src/maxtext/configs/post_train/sft.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ lora:
2727
lora_rank: 0
2828
lora_alpha: 0.0
2929
lora_module_path: ""
30+
# For QLoRA, set lora_weight_qtype (e.g., "nf4") and optionally lora_tile_size.
31+
lora_weight_qtype: null
32+
lora_tile_size: null
3033
# Optional path to LoRA weights to load before training. Ignored if the current run is resumed.
3134
lora_restore_path: ""
3235

src/maxtext/configs/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,14 @@ class LoRA(BaseModel):
12061206
"Regex identifying target modules for LoRA, e.g." " '.*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj'."
12071207
),
12081208
)
1209+
lora_weight_qtype: str | None = Field(
1210+
None,
1211+
description=("Optional quantization type for QLoRA (e.g., 'nf4'). If set, QLoRA is applied."),
1212+
)
1213+
lora_tile_size: NonNegativeInt | None = Field(
1214+
None,
1215+
description="Optional tile size for QLoRA (e.g., 128 or 256).",
1216+
)
12091217
lora_restore_path: PathStr = Field(
12101218
"",
12111219
description=("Optional path to LoRA weights to load before training. Ignored if the current run is resumed."),

src/maxtext/layers/nnx_decoders.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,13 +463,63 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, kv_caches
463463

464464
use_kv = kv_caches_stacked is not None
465465

466+
def stash_origin_metadata(x):
467+
is_var = hasattr(x, "get_metadata") and hasattr(x, "replace")
468+
if is_var:
469+
metadata = x.get_metadata()
470+
updates = {"origin_shape": x.value.shape}
471+
for k in ["sharding", "out_sharding", "sharding_names"]:
472+
if k in metadata:
473+
updates[f"origin_{k}"] = metadata[k]
474+
return x.replace(**updates)
475+
return x
476+
477+
params = jax.tree.map(stash_origin_metadata, params)
478+
state = jax.tree.map(stash_origin_metadata, state)
479+
466480
def layer_fn(carry, scanned_vars):
467481
if use_kv:
468482
current_params, current_state, kv_cache_layer = scanned_vars
469483
else:
470484
current_params, current_state = scanned_vars
471485
kv_cache_layer = None
472486

487+
def rank_consistent_spec(spec, shape):
488+
if spec is None:
489+
return None
490+
spec_list = list(spec)
491+
if len(spec_list) > len(shape):
492+
for axis_name in ["layers", "stage"]:
493+
if axis_name in spec_list:
494+
spec_list.remove(axis_name)
495+
if len(spec_list) == len(shape):
496+
break
497+
while len(spec_list) > len(shape):
498+
spec_list.pop(0)
499+
while len(spec_list) < len(shape):
500+
spec_list.insert(0, None)
501+
return jax.sharding.PartitionSpec(*spec_list)
502+
503+
def fix_node_rank(x):
504+
if hasattr(x, "get_metadata") and hasattr(x, "replace") and hasattr(x, "value"):
505+
metadata = x.get_metadata()
506+
updates = {}
507+
for k, axes in metadata.items():
508+
if isinstance(axes, (jax.sharding.PartitionSpec, tuple, list)):
509+
spec_obj = jax.sharding.PartitionSpec(*axes) if isinstance(axes, (tuple, list)) else axes
510+
if len(spec_obj) != x.value.ndim:
511+
new_spec = rank_consistent_spec(spec_obj, x.value.shape)
512+
updates[k] = tuple(new_spec) if isinstance(axes, (tuple, list)) else new_spec
513+
if updates:
514+
return x.replace(**updates)
515+
return x
516+
517+
def is_nnx_var(x):
518+
return hasattr(x, "get_metadata") and hasattr(x, "replace")
519+
520+
current_params = jax.tree.map(fix_node_rank, current_params, is_leaf=is_nnx_var)
521+
current_state = jax.tree.map(fix_node_rank, current_state, is_leaf=is_nnx_var)
522+
473523
if self.config.parameter_memory_host_offload:
474524
current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params)
475525

@@ -540,8 +590,45 @@ def _ensure_scan_leading_axis(x):
540590
if scan_axis != 0:
541591
scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params)
542592

593+
def restore_origin_metadata(x):
594+
is_var = hasattr(x, "get_metadata") and hasattr(x, "replace")
595+
if is_var:
596+
metadata = x.get_metadata()
597+
updates = {}
598+
for k in ["sharding", "out_sharding", "sharding_names"]:
599+
origin_key = f"origin_{k}"
600+
if origin_key in metadata:
601+
updates[k] = metadata[origin_key]
602+
else:
603+
axes = metadata.get(k)
604+
if isinstance(axes, (jax.sharding.PartitionSpec, tuple, list)):
605+
spec_list = list(axes)
606+
if "layers" not in spec_list:
607+
pos = min(self.config.param_scan_axis, len(spec_list))
608+
spec_list.insert(pos, "layers")
609+
new_spec = jax.sharding.PartitionSpec(*spec_list)
610+
updates[k] = tuple(new_spec) if isinstance(axes, (tuple, list)) else new_spec
611+
if updates:
612+
return x.replace(**updates)
613+
return x
614+
615+
def is_leaf_with_metadata(x):
616+
return hasattr(x, "get_metadata") and hasattr(x, "replace")
617+
618+
scanned_params = jax.tree.map(restore_origin_metadata, scanned_params, is_leaf=is_leaf_with_metadata)
619+
scanned_other = jax.tree.map(restore_origin_metadata, scanned_other, is_leaf=is_leaf_with_metadata)
620+
543621
if dynamic_graph_init:
544622
out_layers = nnx.merge(updated_graphdef[0], scanned_params, scanned_other)
623+
624+
for attr_name, attr_val in self.__dict__.items():
625+
if attr_val is layers:
626+
setattr(self, attr_name, out_layers)
627+
break
628+
629+
g, s = nnx.split(self)
630+
new_self = nnx.merge(g, s)
631+
nnx.update(self, nnx.state(new_self))
545632
else:
546633
nnx.update(layers, nnx.State.merge(scanned_params, scanned_other))
547634
out_layers = layers

src/maxtext/utils/lora_utils.py

Lines changed: 96 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
""" Common LoRA utils needed to support LoRA adapters."""
15+
"""Common LoRA utils needed to support LoRA adapters."""
1616
from functools import partial
1717
import json
1818
import os
@@ -385,14 +385,20 @@ def _get_lora_module_path(mt_config: pyconfig.HyperParameters) -> str:
385385
model_name = mt_config.model_name.lower()
386386

387387
# Find the first matching architecture prefix or use 'default'
388-
matched_key = next((k for k in lora_configs if k != "default" and model_name.startswith(k)), "default")
388+
matched_key = next(
389+
(k for k in lora_configs if k != "default" and model_name.startswith(k)),
390+
"default",
391+
)
389392

390393
if matched_key == "default":
391394
max_logging.log(f"Warning: Model '{model_name}' is unverified; falling back to default LoRA path.")
392395
else:
393396
max_logging.log(f"Auto-detected lora_module_path for model '{model_name}' (matched: '{matched_key}')")
394397

395-
raw_path = lora_configs.get(matched_key, "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))")
398+
raw_path = lora_configs.get(
399+
matched_key,
400+
"decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))",
401+
)
396402

397403
# This regex makes the layer index optional, matching both scanned and unscanned layer paths
398404
# (e.g. 'layers/0/mlp/...' vs 'layers/mlp/...').
@@ -412,10 +418,15 @@ def _build_lora_provider(mt_config: pyconfig.HyperParameters) -> qwix.LoraProvid
412418
"alpha": mt_config.lora.lora_alpha,
413419
"dropout": 0.0,
414420
}
415-
max_logging.log(
416-
f"LoRA configured: module_path={lora_module_path} "
417-
f"rank={mt_config.lora.lora_rank} alpha={mt_config.lora.lora_alpha}"
418-
)
421+
if mt_config.lora.lora_tile_size is not None:
422+
lora_kwargs["tile_size"] = mt_config.lora.lora_tile_size
423+
if mt_config.lora.lora_weight_qtype is not None:
424+
lora_kwargs["weight_qtype"] = mt_config.lora.lora_weight_qtype
425+
426+
lora_type = "QLoRA" if mt_config.lora.lora_weight_qtype else "LoRA"
427+
args_str = " ".join(f"{k}={v}" for k, v in lora_kwargs.items() if k != "dropout")
428+
max_logging.log(f"{lora_type} configured: {args_str}")
429+
419430
return qwix.LoraProvider(**lora_kwargs)
420431

421432

@@ -448,7 +459,7 @@ def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperPar
448459
matched_module_paths = []
449460
sample_module_paths = []
450461

451-
for path, _ in nnx.iter_modules(lora_model):
462+
for path, _ in nnx.iter_graph(lora_model):
452463
module_path = "/".join(str(p) for p in path)
453464
if len(sample_module_paths) < 100:
454465
sample_module_paths.append(module_path)
@@ -469,6 +480,81 @@ def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperPar
469480
)
470481

471482

483+
def _patch_qwix_for_maxtext(mesh, mt_config):
484+
# pylint: disable=protected-access,import-outside-toplevel,redefined-outer-name,reimported,missing-function-docstring,consider-using-from-import
485+
import qwix._src.flax_util as flax_util
486+
import qwix._src.providers.ptq as ptq
487+
import jax.numpy as jnp
488+
from flax import nnx
489+
490+
# 1. PTQ patch
491+
original_get_intercept_map = ptq.PtqProvider.get_intercept_map
492+
493+
def patched_get_intercept_map(self):
494+
mapping = original_get_intercept_map(self)
495+
496+
def intercept_asarray(a, dtype=None, order=None, **kwargs):
497+
if isinstance(a, nnx.State) and "array" in a:
498+
a = a["array"]
499+
if isinstance(a, nnx.State) and "qvalue" in a and "scale" in a:
500+
a = ptq.QArray(qvalue=a["qvalue"].value, scale=a["scale"].value)
501+
502+
if type(a).__name__ in ("WithAux", "QArray"):
503+
return a
504+
return jnp.asarray(a, dtype=dtype, order=order, **kwargs)
505+
506+
mapping["jax.numpy.asarray"] = intercept_asarray
507+
return mapping
508+
509+
ptq.PtqProvider.get_intercept_map = patched_get_intercept_map
510+
511+
# 2. find_param patch
512+
if not hasattr(flax_util, "_maxtext_find_param_patched"):
513+
514+
def patched_find_param(x, ptq_array_type=None):
515+
module = flax_util.get_current_module()
516+
if module is None:
517+
return None
518+
candidates = {}
519+
if isinstance(module, nnx.Module):
520+
array_types = nnx.Param | ptq_array_type if ptq_array_type else nnx.Param
521+
for name, node in module.__dict__.items():
522+
if isinstance(node, array_types):
523+
candidates[name] = node.value if isinstance(node, nnx.Param) else node
524+
else:
525+
return (
526+
flax_util.find_param.__wrapped__(x, ptq_array_type) if hasattr(flax_util.find_param, "__wrapped__") else None
527+
)
528+
529+
candidates_by_id = {id(c): n for n, c in candidates.items()}
530+
for n, c in candidates.items():
531+
if type(c).__name__ == "WithAux" and hasattr(c, "array"):
532+
candidates_by_id[id(c.array)] = n
533+
534+
if id(x) in candidates_by_id:
535+
return candidates_by_id[id(x)]
536+
537+
if isinstance(x, jax.core.Tracer) and hasattr(x, "parent"):
538+
curr_x = x
539+
while True:
540+
if id(curr_x) in candidates_by_id:
541+
return candidates_by_id[id(curr_x)]
542+
if curr_x.parent and len(curr_x.parent.in_tracers) == 1:
543+
curr_x = curr_x.parent.in_tracers[0]
544+
elif hasattr(curr_x, "get_const") and id(const := curr_x.get_const()) in candidates_by_id:
545+
return candidates_by_id[id(const)]
546+
else:
547+
break
548+
549+
filtered = {n: c for n, c in candidates.items() if hasattr(c, "shape") and c.shape == getattr(x, "shape", None)}
550+
if len(filtered) == 1:
551+
return list(filtered.keys())[0]
552+
return None
553+
554+
flax_util.find_param = patched_find_param
555+
flax_util._maxtext_find_param_patched = True
556+
557+
472558
def apply_lora_to_model(
473559
model: nnx.Module,
474560
mesh: Optional[jax.sharding.Mesh],
@@ -485,6 +571,8 @@ def apply_lora_to_model(
485571

486572
# Dynamically detect and set LoRA rank before model creation if restoring
487573

574+
_patch_qwix_for_maxtext(mesh, mt_config)
575+
488576
lora_provider = _build_lora_provider(mt_config)
489577

490578
model_rngs = getattr(model.decoder, "rngs", None)

tests/post_training/unit/lora_utils_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ def test_build_lora_provider(self):
106106
mock_config.lora.lora_module_path = "custom/path"
107107
mock_config.lora.lora_rank = 8
108108
mock_config.lora.lora_alpha = 16.0
109+
mock_config.lora.lora_tile_size = None
110+
mock_config.lora.lora_weight_qtype = None
109111

110112
with mock.patch("qwix.LoraProvider") as mock_provider:
111113
lora_utils._build_lora_provider(mock_config)

0 commit comments

Comments
 (0)