Skip to content

Commit 2d237f8

Browse files
committed
feat: support QLoRA with NNX and Qwix
1 parent 5a7bedb commit 2d237f8

5 files changed

Lines changed: 136 additions & 5 deletions

File tree

src/maxtext/configs/post_train/sft.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ lora:
2727
lora_rank: 0
2828
lora_alpha: 0.0
2929
lora_module_path: ""
30+
# For QLoRA, set lora_weight_qtype (e.g., "nf4") and optionally lora_tile_size.
31+
lora_weight_qtype: null
32+
lora_tile_size: null
3033
# Optional path to LoRA weights to load before training. Ignored if the current run is resumed.
3134
lora_restore_path: ""
3235

src/maxtext/configs/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,14 @@ class LoRA(BaseModel):
12061206
"Regex identifying target modules for LoRA, e.g." " '.*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj'."
12071207
),
12081208
)
1209+
lora_weight_qtype: str | None = Field(
1210+
None,
1211+
description=("Optional quantization type for QLoRA (e.g., 'nf4'). If set, QLoRA is applied."),
1212+
)
1213+
lora_tile_size: NonNegativeInt | None = Field(
1214+
None,
1215+
description="Optional tile size for QLoRA (e.g., 128 or 256).",
1216+
)
12091217
lora_restore_path: PathStr = Field(
12101218
"",
12111219
description=("Optional path to LoRA weights to load before training. Ignored if the current run is resumed."),

src/maxtext/layers/nnx_decoders.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,13 +463,61 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, kv_caches
463463

464464
use_kv = kv_caches_stacked is not None
465465

466+
def stash_origin_metadata(x):
    """Copy a variable's current shape and sharding metadata into ``origin_*`` keys.

    Objects that expose ``get_metadata``, ``replace`` and ``value`` get an
    ``origin_shape`` entry (from ``x.value.shape``) plus an ``origin_<key>``
    entry for each of ``sharding``, ``out_sharding`` and ``sharding_names``
    that is present in their metadata.  ``restore_origin_metadata`` later reads
    those ``origin_*`` entries back.  Anything else is returned unchanged.

    NOTE(review): the visible call sites map this function over ``params`` and
    ``state`` with ``jax.tree.map`` but without ``is_leaf``, unlike the later
    ``fix_node_rank`` calls.  If the tree flattens variables down to raw
    arrays, this function never sees a variable and is a no-op -- confirm.

    Args:
      x: A tree leaf; either a variable-like object or any other value.

    Returns:
      ``x.replace(**origin_entries)`` for variable-like objects, else ``x``.
    """
    # ``value`` is part of the check (as in fix_node_rank) because the body
    # reads ``x.value.shape``; without it a value-less object would raise.
    is_var = hasattr(x, "get_metadata") and hasattr(x, "replace") and hasattr(x, "value")
    if not is_var:
        return x

    metadata = x.get_metadata()
    updates = {'origin_shape': x.value.shape}
    updates.update(
        {f'origin_{k}': metadata[k] for k in ("sharding", "out_sharding", "sharding_names") if k in metadata}
    )
    return x.replace(**updates)
476+
477+
params = jax.tree.map(stash_origin_metadata, params)
478+
state = jax.tree.map(stash_origin_metadata, state)
479+
466480
def layer_fn(carry, scanned_vars):
467481
if use_kv:
468482
current_params, current_state, kv_cache_layer = scanned_vars
469483
else:
470484
current_params, current_state = scanned_vars
471485
kv_cache_layer = None
472486

487+
def rank_consistent_spec(spec, shape):
    """Return a PartitionSpec with exactly ``len(shape)`` entries.

    Args:
      spec: A PartitionSpec (or other iterable of partition entries), or None.
      shape: The shape whose rank the returned spec must match.

    Returns:
      None when ``spec`` is None; otherwise a ``jax.sharding.PartitionSpec``
      built from ``spec`` after trimming or padding it to ``len(shape)``.
    """
    if spec is None:
        return None

    target_rank = len(shape)
    entries = list(spec)

    if len(entries) > target_rank:
        # Prefer dropping the "layers" / "stage" names (first occurrence of
        # each, in that order) and stop as soon as the rank matches.
        for scan_axis_name in ("layers", "stage"):
            if scan_axis_name in entries:
                entries.remove(scan_axis_name)
            if len(entries) == target_rank:
                break

    surplus = len(entries) - target_rank
    if surplus > 0:
        # Still too long: discard entries from the front.
        entries = entries[surplus:]
    elif surplus < 0:
        # Too short: left-pad with unsharded (None) entries.
        entries = [None] * (-surplus) + entries

    return jax.sharding.PartitionSpec(*entries)
502+
503+
def fix_node_rank(x):
    """Make a variable's sharding-like metadata match the rank of its value.

    For objects exposing ``get_metadata``, ``replace`` and ``value``, every
    metadata entry that is a ``PartitionSpec``, tuple or list and whose length
    differs from ``x.value.ndim`` is rewritten through
    ``rank_consistent_spec``.  Tuple/list entries are written back as tuples,
    ``PartitionSpec`` entries as ``PartitionSpec``.

    Entries whose key starts with ``origin_`` are left untouched.  They are the
    snapshot written by ``stash_origin_metadata`` and read back by
    ``restore_origin_metadata``; ``origin_shape`` is a shape rather than a
    spec, and trimming the stashed shardings here would mean the later restore
    puts back an already rank-reduced spec.

    Args:
      x: A tree leaf; either a variable-like object or any other value.

    Returns:
      ``x.replace(**rewritten_entries)`` if anything needed fixing, else ``x``.
    """
    if not (hasattr(x, "get_metadata") and hasattr(x, "replace") and hasattr(x, "value")):
        return x

    ndim = x.value.ndim
    updates = {}
    for k, axes in x.get_metadata().items():
        if isinstance(k, str) and k.startswith("origin_"):
            # Stashed pre-scan snapshot: must survive unchanged.
            continue
        if not isinstance(axes, (jax.sharding.PartitionSpec, tuple, list)):
            continue
        is_sequence = isinstance(axes, (tuple, list))
        spec_obj = jax.sharding.PartitionSpec(*axes) if is_sequence else axes
        if len(spec_obj) != ndim:
            new_spec = rank_consistent_spec(spec_obj, x.value.shape)
            updates[k] = tuple(new_spec) if is_sequence else new_spec

    return x.replace(**updates) if updates else x
516+
517+
is_nnx_var = lambda x: hasattr(x, "get_metadata") and hasattr(x, "replace")
518+
current_params = jax.tree.map(fix_node_rank, current_params, is_leaf=is_nnx_var)
519+
current_state = jax.tree.map(fix_node_rank, current_state, is_leaf=is_nnx_var)
520+
473521
if self.config.parameter_memory_host_offload:
474522
current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params)
475523

@@ -540,8 +588,43 @@ def _ensure_scan_leading_axis(x):
540588
if scan_axis != 0:
541589
scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params)
542590

591+
def restore_origin_metadata(x):
# Rebuilds the "sharding", "out_sharding" and "sharding_names" metadata of a
# variable-like object (anything exposing get_metadata and replace).
# For each of those keys:
#   * if an "origin_<key>" entry exists in the metadata, that value is used;
#   * otherwise, if the current value is a PartitionSpec/tuple/list that does
#     not contain "layers", "layers" is inserted at index
#     min(self.config.param_scan_axis, len(spec)).
# Returns x.replace(**updates) when at least one key was rewritten, else x.
# The "origin_*" entries themselves are not removed here.
# NOTE(review): indentation is lost in this view, so it is unclear whether the
# new_spec / updates[k] lines run only when "layers" was inserted or for every
# spec-like value -- confirm against the real file before refactoring.
592+
is_var = hasattr(x, "get_metadata") and hasattr(x, "replace")
593+
if is_var:
594+
metadata = x.get_metadata()
595+
updates = {}
596+
for k in ["sharding", "out_sharding", "sharding_names"]:
597+
origin_key = f'origin_{k}'
598+
# Preferred path: reuse the stashed value verbatim.
if origin_key in metadata:
599+
updates[k] = metadata[origin_key]
600+
# Fallback path: nothing stashed, so re-add the "layers" axis name by hand.
else:
601+
axes = metadata.get(k)
602+
if isinstance(axes, (jax.sharding.PartitionSpec, tuple, list)):
603+
spec_list = list(axes)
604+
if "layers" not in spec_list:
605+
# Clamp so the insert position never exceeds the current spec length.
pos = min(self.config.param_scan_axis, len(spec_list))
606+
spec_list.insert(pos, "layers")
607+
new_spec = jax.sharding.PartitionSpec(*spec_list)
608+
# Write back in the same container kind the metadata used (tuple vs PartitionSpec).
updates[k] = tuple(new_spec) if isinstance(axes, (tuple, list)) else new_spec
609+
if updates:
610+
return x.replace(**updates)
611+
return x
612+
613+
is_leaf_with_metadata = lambda x: hasattr(x, "get_metadata") and hasattr(x, "replace")
614+
scanned_params = jax.tree.map(restore_origin_metadata, scanned_params, is_leaf=is_leaf_with_metadata)
615+
scanned_other = jax.tree.map(restore_origin_metadata, scanned_other, is_leaf=is_leaf_with_metadata)
616+
543617
if dynamic_graph_init:
544618
out_layers = nnx.merge(updated_graphdef[0], scanned_params, scanned_other)
619+
620+
for attr_name, attr_val in self.__dict__.items():
621+
if attr_val is layers:
622+
setattr(self, attr_name, out_layers)
623+
break
624+
625+
g, s = nnx.split(self)
626+
new_self = nnx.merge(g, s)
627+
nnx.update(self, nnx.state(new_self))
545628
else:
546629
nnx.update(layers, nnx.State.merge(scanned_params, scanned_other))
547630
out_layers = layers

src/maxtext/utils/lora_utils.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -412,10 +412,15 @@ def _build_lora_provider(mt_config: pyconfig.HyperParameters) -> qwix.LoraProvid
412412
"alpha": mt_config.lora.lora_alpha,
413413
"dropout": 0.0,
414414
}
415-
max_logging.log(
416-
f"LoRA configured: module_path={lora_module_path} "
417-
f"rank={mt_config.lora.lora_rank} alpha={mt_config.lora.lora_alpha}"
418-
)
415+
if mt_config.lora.lora_tile_size is not None:
416+
lora_kwargs["tile_size"] = mt_config.lora.lora_tile_size
417+
if mt_config.lora.lora_weight_qtype is not None:
418+
lora_kwargs["weight_qtype"] = mt_config.lora.lora_weight_qtype
419+
420+
lora_type = "QLoRA" if mt_config.lora.lora_weight_qtype else "LoRA"
421+
args_str = " ".join(f"{k}={v}" for k, v in lora_kwargs.items() if k != "dropout")
422+
max_logging.log(f"{lora_type} configured: {args_str}")
423+
419424
return qwix.LoraProvider(**lora_kwargs)
420425

421426

@@ -448,7 +453,7 @@ def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperPar
448453
matched_module_paths = []
449454
sample_module_paths = []
450455

451-
for path, _ in nnx.iter_modules(lora_model):
456+
for path, _ in nnx.iter_graph(lora_model):
452457
module_path = "/".join(str(p) for p in path)
453458
if len(sample_module_paths) < 100:
454459
sample_module_paths.append(module_path)
@@ -469,6 +474,34 @@ def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperPar
469474
)
470475

471476

477+
def _patch_qwix_for_maxtext(mesh, mt_config):
    """Monkey-patch Qwix's ``PtqProvider`` so its ``jax.numpy.asarray`` intercept accepts NNX state.

    ``PtqProvider.get_intercept_map`` is wrapped so that the returned mapping
    carries a replacement for ``"jax.numpy.asarray"``.  The replacement:

    * unwraps an ``nnx.State`` that has an ``'array'`` entry;
    * rebuilds a ``ptq.QArray`` from an ``nnx.State`` holding ``'qvalue'`` and
      ``'scale'`` entries;
    * returns objects whose type is named ``WithAux`` or ``QArray`` untouched;
    * otherwise defers to ``jnp.asarray``.

    The patch is applied at most once per process.  ``apply_lora_to_model``
    calls this function whenever it reaches the LoRA setup path, and without
    the guard each call would wrap the previously installed wrapper again.

    NOTE(review): this relies on the private ``qwix._src`` layout and may break
    on a Qwix upgrade -- confirm the pinned version.

    Args:
      mesh: Unused; kept so existing call sites keep working.
      mt_config: Unused; kept so existing call sites keep working.
    """
    import jax.numpy as jnp
    from flax import nnx
    import qwix._src.providers.ptq as ptq

    # Idempotency guard: the sentinel is set on the wrapper installed below.
    if getattr(ptq.PtqProvider.get_intercept_map, "_maxtext_patched", False):
        return

    # 1. PTQ patch
    original_get_intercept_map = ptq.PtqProvider.get_intercept_map

    def patched_get_intercept_map(self):
        mapping = original_get_intercept_map(self)

        def intercept_asarray(a, dtype=None, order=None, **kwargs):
            if isinstance(a, nnx.State) and 'array' in a:
                a = a['array']
            if isinstance(a, nnx.State) and 'qvalue' in a and 'scale' in a:
                a = ptq.QArray(qvalue=a['qvalue'].value, scale=a['scale'].value)

            # Matched by class name rather than isinstance, as in the original.
            if type(a).__name__ in ("WithAux", "QArray"):
                return a
            return jnp.asarray(a, dtype=dtype, order=order, **kwargs)

        mapping["jax.numpy.asarray"] = intercept_asarray
        return mapping

    patched_get_intercept_map._maxtext_patched = True
    ptq.PtqProvider.get_intercept_map = patched_get_intercept_map
503+
504+
472505
def apply_lora_to_model(
473506
model: nnx.Module,
474507
mesh: Optional[jax.sharding.Mesh],
@@ -484,6 +517,8 @@ def apply_lora_to_model(
484517
return model
485518

486519
# Dynamically detect and set LoRA rank before model creation if restoring
520+
521+
_patch_qwix_for_maxtext(mesh, mt_config)
487522

488523
lora_provider = _build_lora_provider(mt_config)
489524

tests/post_training/unit/lora_utils_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ def test_build_lora_provider(self):
106106
mock_config.lora.lora_module_path = "custom/path"
107107
mock_config.lora.lora_rank = 8
108108
mock_config.lora.lora_alpha = 16.0
109+
mock_config.lora.lora_tile_size = None
110+
mock_config.lora.lora_weight_qtype = None
109111

110112
with mock.patch("qwix.LoraProvider") as mock_provider:
111113
lora_utils._build_lora_provider(mock_config)

0 commit comments

Comments
 (0)