Skip to content

Commit c37faef

Browse files
committed
feat/fix: QLoRA support and NNX Decoder Sharding Fixes
1 parent 7c7f628 commit c37faef

7 files changed

Lines changed: 251 additions & 28 deletions

File tree

src/maxtext/configs/post_train/sft.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ lora:
2727
lora_rank: 0
2828
lora_alpha: 0.0
2929
lora_module_path: ""
30+
# For QLoRA, set lora_weight_qtype (e.g., "nf4") and optionally lora_tile_size.
31+
lora_weight_qtype: null
32+
lora_tile_size: null
3033
# Optional path to LoRA weights to load before training. Ignored if the current run is resumed.
3134
lora_restore_path: ""
3235

src/maxtext/configs/types.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1280,9 +1280,19 @@ class LoRA(BaseModel):
12801280
lora_module_path: str = Field(
12811281
"",
12821282
description=(
1283-
"Regex identifying target modules for LoRA, e.g." " '.*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj'."
1283+
"Regex identifying target NNX modules for LoRA. "
1284+
"Example for standard models: 'decoder/layers/.*(self_attention/(query|out)|mlp/(wi_0|wo))'. "
1285+
"Example for MoE: 'decoder/scanned_blocks/layers.*/.*(MoeBlock_0|shared_experts)/(wi_0|wo)'."
12841286
),
12851287
)
1288+
lora_weight_qtype: str | None = Field(
1289+
None,
1290+
description=("Optional quantization type for QLoRA (e.g., 'nf4'). If set, QLoRA is applied."),
1291+
)
1292+
lora_tile_size: NonNegativeInt | None = Field(
1293+
None,
1294+
description=("Tile size for block-wise quantization. Typically 32 or 64."),
1295+
)
12861296
lora_restore_path: PathStr = Field(
12871297
"",
12881298
description=("Optional path to LoRA weights to load before training. Ignored if the current run is resumed."),

src/maxtext/layers/nnx_decoders.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@
6464
simple_layer,
6565
)
6666
from maxtext.multimodal import utils as mm_utils
67-
from maxtext.utils import max_logging, max_utils, maxtext_utils, sharding
68-
from maxtext.utils.maxtext_utils_nnx import nnx_ensure_scan_leading_axis
67+
from maxtext.utils import max_logging, max_utils, maxtext_utils, maxtext_utils_nnx, sharding
6968
from maxtext.utils.sharding import create_sharding
7069

7170
# ------------------------------------------------------------------------------
@@ -601,6 +600,8 @@ def _extract_matching_state(template, full):
601600
use_kv = kv_caches_stacked is not None
602601

603602
def layer_fn(carry, scanned_vars):
603+
# Ensure metadata rank matches the sliced values
604+
scanned_vars = maxtext_utils_nnx.nnx_remove_scan_axis(scanned_vars, "layers")
604605

605606
# Unpack the sliced variables for THIS layer
606607
if use_kv:
@@ -670,8 +671,8 @@ def layer_fn(carry, scanned_vars):
670671
# inference with vLLM, parameters do not change and we don't need intermediates.
671672
return current_carry, layers, None
672673
else:
673-
params = nnx_ensure_scan_leading_axis(params, length)
674-
state = nnx_ensure_scan_leading_axis(state, length)
674+
params = maxtext_utils_nnx.nnx_ensure_scan_leading_axis(params, length)
675+
state = maxtext_utils_nnx.nnx_ensure_scan_leading_axis(state, length)
675676

676677
# Linen FP8 ops keep amax_history in mutable Linen scope; jax.lax.scan
677678
# leaks the tracer and hits UnexpectedTracerError. Use a Python for-loop
@@ -691,10 +692,15 @@ def layer_fn(carry, scanned_vars):
691692
final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
692693
returned_kv_stacked = None
693694

694-
if scan_axis != 0:
695-
new_params, new_rest = scanned_state.split(nnx.Param, ...)
696-
new_params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), new_params)
697-
scanned_state = nnx.merge_state(new_params, new_rest)
695+
# Ensure metadata rank matches the stacked values
696+
scanned_state = maxtext_utils_nnx.nnx_add_scan_axis(scanned_state, "layers", 0)
697+
698+
if scan_axis != 0:
699+
new_params, new_rest = scanned_state.split(nnx.Param, ...)
700+
new_params = maxtext_utils_nnx.nnx_sync_moveaxis(new_params, 0, scan_axis)
701+
scanned_state = nnx.merge_state(new_params, new_rest)
702+
703+
returned_kv_stacked = None
698704

699705
if dynamic_graph_init:
700706
# If graph changed, we need to merge with the new graphdef.

src/maxtext/utils/lora_utils.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import re
2222
from typing import Any, Optional
2323

24-
from flax import nnx
24+
from flax import nnx, linen as nn
2525
from flax.linen import partitioning as nn_partitioning
2626
from flax.training import train_state
2727
import jax
@@ -35,7 +35,6 @@
3535
from maxtext.utils import max_logging
3636
from maxtext.utils import max_utils
3737
from maxtext.utils import maxtext_utils
38-
from maxtext.utils import sharding
3938
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
4039

4140

@@ -416,11 +415,18 @@ def _build_lora_provider(mt_config: pyconfig.HyperParameters) -> qwix.LoraProvid
416415
"rank": mt_config.lora.lora_rank,
417416
"alpha": mt_config.lora.lora_alpha,
418417
"dropout": 0.0,
418+
"weight_qtype": mt_config.lora.lora_weight_qtype,
419+
"tile_size": mt_config.lora.lora_tile_size,
419420
}
421+
# Distinguish between standard LoRA and QLoRA in logs
422+
lora_type = "QLoRA" if mt_config.lora.lora_weight_qtype else "LoRA"
423+
420424
max_logging.log(
421-
f"LoRA configured: module_path={lora_module_path} "
422-
f"rank={mt_config.lora.lora_rank} alpha={mt_config.lora.lora_alpha}"
425+
f"{lora_type} configured: rank={mt_config.lora.lora_rank} alpha={mt_config.lora.lora_alpha} "
426+
f"qtype={mt_config.lora.lora_weight_qtype} tile_size={mt_config.lora.lora_tile_size}"
423427
)
428+
429+
max_logging.log(f"Using lora_module_path: {lora_module_path}")
424430
return qwix.LoraProvider(**lora_kwargs)
425431

426432

@@ -518,13 +524,22 @@ def apply_lora_to_model(
518524

519525
# Use logical_to_mesh_sharding to correctly map logical axes like 'embed'
520526
# to physical mesh axes.
521-
dst_shardings = sharding.logical_to_mesh_sharding(
522-
nnx.get_partition_spec(state), mesh, rules=mt_config.logical_axis_rules
523-
)
524-
525-
from tunix.rl import reshard # pylint: disable=import-outside-toplevel
527+
dst_shardings = nn.logical_to_mesh_sharding(nnx.get_partition_spec(state), mesh, mt_config.logical_axis_rules)
528+
529+
def _safe_reshard(var, sharding_spec):
530+
if not isinstance(var, nnx.Variable) or not isinstance(sharding_spec, jax.sharding.Sharding):
531+
return var
532+
val = var.get_value()
533+
if not isinstance(val, jax.Array):
534+
return var
535+
# make_array_from_callback natively constructs a globally sharded array
536+
# from the local host arrays, bypassing backend-specific device_put issues
537+
# on both Pathways and McJAX.
538+
resharded_val = jax.make_array_from_callback(val.shape, sharding_spec, lambda idx: val[idx])
539+
return var.replace(value=resharded_val)
540+
541+
state = jax.tree_util.tree_map(_safe_reshard, state, dst_shardings, is_leaf=lambda x: isinstance(x, nnx.Variable))
526542

527-
state = reshard.reshard_pytree(state, dst_shardings)
528543
lora_model = nnx.merge(graph_def, state)
529544

530545
_verify_lora_parameters(lora_model, mt_config)

src/maxtext/utils/maxtext_utils_nnx.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
from flax import nnx
2020
import jax
21-
from jax.sharding import Mesh, NamedSharding
21+
import jax.numpy as jnp
22+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2223

2324
from maxtext.utils import max_logging
2425
from maxtext.configs import pyconfig
@@ -187,3 +188,90 @@ def _op(x):
187188
return x
188189

189190
return jax.tree.map(_op, tree, is_leaf=lambda x: isinstance(x, nnx.Variable))
191+
192+
193+
# ------------------------------------------------------------------------------
194+
# Metadata Synchronization Helpers for NNX Variables
195+
# ------------------------------------------------------------------------------
196+
197+
198+
def nnx_update_sharding_meta(variable, transform_fn):
199+
"""Generic helper to apply a list transformation to all sharding-related metadata."""
200+
if not (hasattr(variable, "get_metadata") and hasattr(variable, "replace")):
201+
return variable
202+
203+
meta = variable.get_metadata()
204+
updates = {}
205+
206+
for key in ["sharding", "out_sharding", "sharding_names"]:
207+
if (val := meta.get(key)) and isinstance(val, (P, tuple, list)):
208+
new_list = list(val)
209+
transformed = transform_fn(new_list)
210+
updates[key] = P(*transformed) if isinstance(val, P) else tuple(transformed)
211+
212+
if updates:
213+
return variable.replace(**updates)
214+
return variable
215+
216+
217+
def nnx_sync_moveaxis(tree, from_axis, to_axis):
218+
"""Moves an axis in both values and sharding metadata of nnx.Variables."""
219+
if from_axis == to_axis:
220+
return tree
221+
222+
def _op(x):
223+
is_var = isinstance(x, nnx.Variable)
224+
val = x.get_value() if is_var else x
225+
if not hasattr(val, "shape"):
226+
return x
227+
228+
new_val = jnp.moveaxis(val, from_axis, to_axis)
229+
if not is_var:
230+
return new_val
231+
232+
def move_fn(l):
233+
if len(l) > max(from_axis, to_axis):
234+
l.insert(to_axis, l.pop(from_axis))
235+
return l
236+
237+
return nnx_update_sharding_meta(x.replace(value=new_val), move_fn)
238+
239+
return jax.tree.map(_op, tree, is_leaf=lambda x: isinstance(x, nnx.Variable) or hasattr(x, "shape"))
240+
241+
242+
def nnx_remove_scan_axis(tree, name="layers"):
243+
"""Removes the given scan axis from the PartitionSpec."""
244+
245+
def _op(x):
246+
if not isinstance(x, nnx.Variable):
247+
return x
248+
249+
def remove_fn(l):
250+
if name in l:
251+
l.remove(name)
252+
while len(l) > x.get_value().ndim:
253+
l.pop(0)
254+
return l
255+
256+
return nnx_update_sharding_meta(x, remove_fn)
257+
258+
return jax.tree.map(_op, tree, is_leaf=lambda x: isinstance(x, nnx.Variable))
259+
260+
261+
def nnx_add_scan_axis(tree, name="layers", pos=0):
262+
"""Adds the given scan axis to the PartitionSpec at the specified position."""
263+
264+
def _op(x):
265+
if not isinstance(x, nnx.Variable):
266+
return x
267+
268+
def add_fn(l):
269+
if name not in l:
270+
l.insert(pos, name)
271+
while len(l) < x.get_value().ndim:
272+
l.insert(pos, None)
273+
return l
274+
275+
return nnx_update_sharding_meta(x, add_fn)
276+
277+
return jax.tree.map(_op, tree, is_leaf=lambda x: isinstance(x, nnx.Variable))

tests/post_training/unit/lora_utils_test.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from maxtext.utils import lora_utils
3030
from maxtext.utils import model_creation_utils
3131
from maxtext.configs import pyconfig
32+
from maxtext.utils import maxtext_utils
33+
from jax.sharding import Mesh
3234
from tests.utils.test_helpers import get_test_config_path
3335

3436
# ---------------------------------------------------------------------------
@@ -104,10 +106,14 @@ def test_build_lora_provider(self):
104106
mock_config.lora.lora_module_path = "custom/path"
105107
mock_config.lora.lora_rank = 8
106108
mock_config.lora.lora_alpha = 16.0
109+
mock_config.lora.lora_weight_qtype = "int8"
110+
mock_config.lora.lora_tile_size = 32
107111

108112
with mock.patch("qwix.LoraProvider") as mock_provider:
109113
lora_utils._build_lora_provider(mock_config)
110-
mock_provider.assert_called_once_with(module_path="custom/path", rank=8, alpha=16.0, dropout=0.0)
114+
mock_provider.assert_called_once_with(
115+
module_path="custom/path", rank=8, alpha=16.0, dropout=0.0, weight_qtype="int8", tile_size=32
116+
)
111117

112118
def test_prepare_dummy_inputs(self):
113119
"""Test preparation of dummy inputs for LoRA verification."""
@@ -158,27 +164,36 @@ def test_apply_lora_to_model_adapters_loaded(self):
158164
# If we skip Qwix, it should stay False.
159165
self.assertFalse(lora_utils.is_lora_enabled(result))
160166

161-
def _run_apply_lora_test(self, scan_layers: bool):
162-
"""Helper to run LoRA application test with/without scanned layers."""
167+
def _run_apply_lora_test(self, scan_layers: bool, weight_qtype=None, tile_size=None, mock_multihost: bool = False):
168+
"""Helper to run LoRA application test with/without scanned layers and optional QLoRA."""
163169
# Passing nested dict as 'lora' kwarg to _make_config
164170
cfg = _make_config(
165171
lora={
166172
"enable_lora": True,
167173
"lora_rank": 4,
168174
"lora_alpha": 8.0,
169175
"lora_module_path": ".*mlp/wi_.*",
176+
"lora_weight_qtype": weight_qtype,
177+
"lora_tile_size": tile_size,
170178
},
171179
scan_layers=scan_layers,
172180
)
173181

174182
# Create a real small model using standard creation utils
175-
model, _ = model_creation_utils.from_pretrained(cfg, mesh=None, model_mode=model_creation_utils.MODEL_MODE_TRAIN)
183+
model, mesh = model_creation_utils.from_pretrained(cfg, mesh=None, model_mode=model_creation_utils.MODEL_MODE_TRAIN)
176184

177185
# Verify model is NOT lora enabled initially
178186
self.assertFalse(lora_utils.is_lora_enabled(model))
179187

180-
# Apply LoRA
181-
lora_model = lora_utils.apply_lora_to_model(model, model.mesh, cfg)
188+
if mock_multihost:
189+
devices_array = maxtext_utils.create_device_mesh(cfg)
190+
dummy_mesh = Mesh(devices_array, cfg.mesh_axes)
191+
192+
# Just verify that apply_lora_to_model runs successfully with the dummy mesh
193+
lora_model = lora_utils.apply_lora_to_model(model, dummy_mesh, cfg)
194+
else:
195+
# Apply LoRA
196+
lora_model = lora_utils.apply_lora_to_model(model, mesh, cfg)
182197

183198
# Verify we can find LoRAParam in the state
184199
_, state = nnx.split(lora_model)
@@ -200,13 +215,27 @@ def _run_apply_lora_test(self, scan_layers: bool):
200215
self.assertGreater(len(jax.tree_util.tree_leaves(opt_state)), 0)
201216

202217
def test_apply_lora_to_model_scan_layers_false(self):
203-
"""Test applying LoRA to model with scan_layers=False."""
218+
"""Test applying standard LoRA to model with scan_layers=False."""
204219
self._run_apply_lora_test(scan_layers=False)
205220

206221
def test_apply_lora_to_model_scan_layers_true(self):
207-
"""Test applying LoRA to model with scan_layers=True."""
222+
"""Test applying standard LoRA to model with scan_layers=True."""
208223
self._run_apply_lora_test(scan_layers=True)
209224

225+
@unittest.skip("Awaiting qwix fix for QLoRA params materialization")
226+
def test_apply_qlora_to_model_scan_layers_false(self):
227+
"""Test applying QLoRA to model with scan_layers=False."""
228+
self._run_apply_lora_test(scan_layers=False, weight_qtype="int8", tile_size=32)
229+
230+
@unittest.skip("Awaiting qwix fix for QLoRA params materialization")
231+
def test_apply_qlora_to_model_scan_layers_true(self):
232+
"""Test applying QLoRA to model with scan_layers=True."""
233+
self._run_apply_lora_test(scan_layers=True, weight_qtype="int8", tile_size=32)
234+
235+
def test_apply_lora_multihost_mock(self):
236+
"""Test applying LoRA with a dummy mesh to trigger the multi-host reshard callback."""
237+
self._run_apply_lora_test(scan_layers=False, mock_multihost=True)
238+
210239
def test_restore_lora_from_path(self):
211240
"""Test restoration of LoRA parameters from a path."""
212241
cfg = _make_config(

0 commit comments

Comments
 (0)