Skip to content

Commit 6798fcd

Browse files
committed
feat: enhance update_sharding for PartitionSpec support
1 parent a82eb1b commit 6798fcd

2 files changed

Lines changed: 40 additions & 92 deletions

File tree

qwix_lora.patch renamed to qwix.patch

Lines changed: 25 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
diff --git a/qwix/_src/flax_util.py b/qwix/_src/flax_util.py
2-
index 4ea5d80..7ae95d3 100644
2+
index 4ea5d80..17bb94b 100644
33
--- a/qwix/_src/flax_util.py
44
+++ b/qwix/_src/flax_util.py
55
@@ -308,13 +308,13 @@ def unbox(maybe_boxed: Any) -> Any:
@@ -18,72 +18,54 @@ index 4ea5d80..7ae95d3 100644
1818
"""Derives the partition spec from an existing spec.
1919

2020
Args:
21-
@@ -330,21 +330,30 @@ def update_sharding(
21+
@@ -330,6 +330,8 @@ def update_sharding(
2222
The updated partition spec.
2323
"""
2424
assert bool(split) + bool(merge) + bool(transpose) <= 1
2525
+ is_pspec = isinstance(spec, jax.sharding.PartitionSpec)
26-
+ spec_list = list(spec)
2726
+
2827
if split:
29-
- spec = [(a, None) if i in split else (a,) for i, a in enumerate(spec)]
30-
- spec = sum(spec, ()) # flatten the list of tuples.
31-
+ spec_list = [
32-
+ (a, None) if i in split else (a,) for i, a in enumerate(spec_list)
33-
+ ]
34-
+ spec_list = list(sum(spec_list, ())) # flatten the list of tuples.
35-
elif merge:
36-
for i in merge:
37-
- spec = spec[: i + 1] + spec[i + 2 :] # pytype: disable=unsupported-operands
38-
+ spec_list = spec_list[: i + 1] + spec_list[i + 2 :]
39-
elif transpose:
40-
- spec = tuple(spec[i] if i is not None else None for i in transpose)
41-
+ spec_list = [
42-
+ spec_list[i] if i is not None else None for i in transpose
43-
+ ]
44-
45-
if shape:
46-
- assert len(shape) == len(spec), f'{shape=} {spec=}'
47-
+ assert len(shape) == len(spec_list), f'{shape=} {spec_list=}'
28+
spec = [(a, None) if i in split else (a,) for i, a in enumerate(spec)]
29+
spec = sum(spec, ()) # flatten the list of tuples.
30+
@@ -344,6 +346,9 @@ def update_sharding(
4831
# For scales: remove sharding for dimensions of size 1.
49-
- spec = tuple(None if d == 1 else a for a, d in zip(spec, shape))
50-
+ spec_list = [None if d == 1 else a for a, d in zip(spec_list, shape)]
32+
spec = tuple(None if d == 1 else a for a, d in zip(spec, shape))
5133

52-
- return spec
5334
+ if is_pspec:
54-
+ return jax.sharding.PartitionSpec(*spec_list)
55-
+ return tuple(spec_list)
35+
+ return jax.sharding.PartitionSpec(*spec)
36+
+
37+
return spec
5638

5739

58-
def update_boxed(
59-
@@ -380,7 +389,7 @@ def update_boxed(
40+
@@ -380,10 +385,8 @@ def update_boxed(
6041
shape = boxed.unbox().shape
6142
for possible_field in ('names', 'mesh_axes', 'axes_types'):
6243
axes = getattr(boxed, possible_field, None)
6344
- if isinstance(axes, (list, tuple)):
45+
- axes = update_sharding(
46+
- axes, shape=shape, split=split, merge=merge, transpose=transpose
47+
- )
6448
+ if isinstance(axes, (list, tuple, jax.sharding.PartitionSpec)):
65-
axes = update_sharding(
66-
axes, shape=shape, split=split, merge=merge, transpose=transpose
67-
)
68-
@@ -396,11 +405,13 @@ def update_boxed(
49+
+ axes = update_sharding(axes, shape=shape, split=split, merge=merge, transpose=transpose)
50+
boxed = dataclasses.replace(boxed, **{possible_field: axes})
51+
elif isinstance(boxed, nnx.Variable):
52+
if value is not None:
53+
@@ -396,10 +399,9 @@ def update_boxed(
6954
else:
7055
sharding_key = 'sharding_names'
7156
axes = metadata.get(sharding_key, None)
7257
- if isinstance(axes, (list, tuple)):
7358
- axes = update_sharding(
59+
- axes, shape=shape, split=split, merge=merge, transpose=transpose
60+
- )
61+
+
7462
+ if isinstance(axes, (list, tuple, jax.sharding.PartitionSpec)):
75-
+ updated_axes = update_sharding(
76-
axes, shape=shape, split=split, merge=merge, transpose=transpose
77-
)
78-
- boxed.set_metadata(sharding_key, axes)
79-
+ # Avoid mutating metadata unless sharding actually changed.
80-
+ if axes != updated_axes:
81-
+ boxed.set_metadata(sharding_key, updated_axes)
63+
+ axes = update_sharding(axes, shape=shape, split=split, merge=merge, transpose=transpose)
64+
boxed.set_metadata(sharding_key, axes)
8265
elif isinstance(boxed, jax.Array): # not boxed.
8366
if value is not None:
84-
boxed = value
8567
diff --git a/qwix/_src/providers/lora.py b/qwix/_src/providers/lora.py
86-
index e98c833..07be623 100644
68+
index e98c833..39ce3ef 100644
8769
--- a/qwix/_src/providers/lora.py
8870
+++ b/qwix/_src/providers/lora.py
8971
@@ -13,6 +13,7 @@
@@ -94,7 +76,7 @@ index e98c833..07be623 100644
9476
import string
9577
from typing import Any, Callable, Collection, Sequence
9678
import warnings
97-
@@ -189,29 +190,87 @@ class LoraProvider(ptq.PtqProvider):
79+
@@ -189,29 +190,68 @@ class LoraProvider(ptq.PtqProvider):
9880
if weight_name is None: # rhs is not a weight.
9981
return res
10082

@@ -107,25 +89,6 @@ index e98c833..07be623 100644
10789
+ (contract_lhs, contract_rhs) = dimension_numbers[0]
10890
+ (batch_lhs, batch_rhs) = dimension_numbers[1]
10991
+
110-
+ if len(rhs.shape) == 2 and not batch_rhs:
111-
+ # Standard LoRA path for ...a,ab->...b
112-
+ lora_a, lora_b = _get_or_create_lora_params(
113-
+ name=weight_name,
114-
+ rule=rule,
115-
+ a_shape=(rhs.shape[0], rule.rank),
116-
+ b_shape=(rule.rank, rhs.shape[1]),
117-
+ a_sharding_transpose=(0, None),
118-
+ b_sharding_transpose=(None, 1),
119-
+ )
120-
+
121-
+ if rule.dropout > 0:
122-
+ lhs = nnx.Dropout(rule.dropout, deterministic=False)(
123-
+ lhs, rngs=flax_util.make_rng('dropout')
124-
+ )
125-
+
126-
+ return res + lhs @ lora_a @ lora_b * (rule.alpha / rule.rank)
127-
+
128-
+ # General LoRA path for N-D kernels and batch dimensions.
12992
+ # Identify contracting, batch, and out axes for rhs.
13093
+ contract_rhs = tuple(contract_rhs)
13194
+ batch_rhs = tuple(batch_rhs)

src/maxtext/utils/lora_utils.py

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -418,77 +418,62 @@ def _patch_nnx_decoder_apply_layers_sequentially(model: nnx.Module) -> None:
418418
"""Patches the NNX decoder's _apply_layers_sequentially to include Qwix specific logic."""
419419

420420
def _apply_layers_sequentially_with_qwix(self, layers, x_in, *args, length: int, **kwargs):
421-
"""Runs the layer stack using nnx.scan with Qwix specific graph init and VJP downcasting."""
421+
"""Runs the layer stack using nnx.scan with Qwix specific graph init."""
422422
policy = self.get_remat_policy()
423423
prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config)
424-
graphdef, params, state = nnx.split(
425-
layers, nnx.Param, ...
426-
) # state: the mutable state we carry (KV cache, RNGs, etc.)
424+
graphdef, params, state = nnx.split(layers, nnx.Param, ...)
427425

428426
scan_axis = self.config.param_scan_axis
429427
if scan_axis != 0:
430-
# Move scan_axis to 0 so scan can iterate over it
431-
params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params)
428+
params = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, scan_axis, 0), params)
432429

433-
layer_cls = layers.__class__
434-
sig = inspect.signature(layer_cls.__call__)
430+
sig = inspect.signature(layers.__class__.__call__)
435431
valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
436-
# During Qwix init (disable_quant_stats_update=True), params may be lazily
437-
# created and the layer graphdef can grow. Allow graphdef refresh in that
438-
# phase only. Keep scanned training path static for remat purity.
432+
439433
dynamic_graph_init = bool(getattr(self, "disable_quant_stats_update", False))
440434
updated_graphdef = [graphdef]
441435

442436
def layer_fn(carry, scanned_vars):
443-
# Unpack the sliced variables for THIS layer
444437
current_params, current_state = scanned_vars
445438

446439
if self.config.parameter_memory_host_offload:
447-
current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params)
440+
current_params = jax.tree_util.tree_map(lambda x: jax.device_put(x, max_utils.device_space()), current_params)
448441

449-
# Merge using the SLICED state
450442
layer = nnx.merge(graphdef, current_params, current_state)
451-
452-
# Run the layer (Filter kwargs if using the solution from previous turn)
453443
layer_out = layer(carry, *args, **valid_kwargs)
454-
455444
new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
456445

457-
# Qwix init: return updated params so graphdef can grow.
458-
# In normal training, keep params unchanged to avoid extra memory use.
459446
new_graphdef, updated_params, updated_state = nnx.split(layer, nnx.Param, ...)
447+
460448
if dynamic_graph_init:
461449
updated_graphdef[0] = new_graphdef
462450
returned_params = updated_params
463451
else:
464452
returned_params = current_params
453+
465454
return new_carry, (returned_params, updated_state)
466455

467456
layer_fn_wrapped = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
468457

469458
def _ensure_scan_leading_axis(x):
470-
# Promote scalars for scan axis compatibility.
471-
if not hasattr(x, "shape"):
472-
return x
473-
if len(x.shape) == 0:
459+
if not hasattr(x, "shape") or len(x.shape) == 0:
474460
return jnp.broadcast_to(x, (length,))
475461
return x
476462

477-
params = jax.tree.map(_ensure_scan_leading_axis, params)
478-
state = jax.tree.map(_ensure_scan_leading_axis, state)
463+
params = jax.tree_util.tree_map(_ensure_scan_leading_axis, params)
464+
state = jax.tree_util.tree_map(_ensure_scan_leading_axis, state)
479465

480466
final_carry, (scanned_params, scanned_other) = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
481467

482468
if scan_axis != 0:
483-
scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params)
469+
scanned_params = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params)
484470

485471
if dynamic_graph_init:
486472
return final_carry, nnx.merge(updated_graphdef[0], scanned_params, scanned_other)
487-
else:
488-
nnx.update(layers, nnx.State.merge(scanned_params, scanned_other))
489-
return final_carry, layers
473+
474+
nnx.update(layers, nnx.State.merge(scanned_params, scanned_other))
475+
return final_carry, layers
490476

491-
# IMPORTANT: Patch the class so nnx.merge doesn't lose the patch
492477
model.decoder.__class__._apply_layers_sequentially = _apply_layers_sequentially_with_qwix # pylint: disable=protected-access
493478

494479

0 commit comments

Comments
 (0)