11diff --git a/qwix/_src/flax_util.py b/qwix/_src/flax_util.py
2- index 4ea5d80..7ae95d3 100644
2+ index 4ea5d80..17bb94b 100644
33--- a/qwix/_src/flax_util.py
44+++ b/qwix/_src/flax_util.py
55@@ -308,13 +308,13 @@ def unbox(maybe_boxed: Any) -> Any:
@@ -18,72 +18,54 @@ index 4ea5d80..7ae95d3 100644
1818 """Derives the partition spec from an existing spec.
1919
2020 Args:
21- @@ -330,21 +330,30 @@ def update_sharding(
21+ @@ -330,6 +330,8 @@ def update_sharding(
2222 The updated partition spec.
2323 """
2424 assert bool(split) + bool(merge) + bool(transpose) <= 1
2525+ is_pspec = isinstance(spec, jax.sharding.PartitionSpec)
26- + spec_list = list(spec)
2726+
2827 if split:
29- - spec = [(a, None) if i in split else (a,) for i, a in enumerate(spec)]
30- - spec = sum(spec, ()) # flatten the list of tuples.
31- + spec_list = [
32- + (a, None) if i in split else (a,) for i, a in enumerate(spec_list)
33- + ]
34- + spec_list = list(sum(spec_list, ())) # flatten the list of tuples.
35- elif merge:
36- for i in merge:
37- - spec = spec[: i + 1] + spec[i + 2 :] # pytype: disable=unsupported-operands
38- + spec_list = spec_list[: i + 1] + spec_list[i + 2 :]
39- elif transpose:
40- - spec = tuple(spec[i] if i is not None else None for i in transpose)
41- + spec_list = [
42- + spec_list[i] if i is not None else None for i in transpose
43- + ]
44-
45- if shape:
46- - assert len(shape) == len(spec), f'{shape=} {spec=}'
47- + assert len(shape) == len(spec_list), f'{shape=} {spec_list=}'
28+ spec = [(a, None) if i in split else (a,) for i, a in enumerate(spec)]
29+ spec = sum(spec, ()) # flatten the list of tuples.
30+ @@ -344,6 +346,9 @@ def update_sharding(
4831 # For scales: remove sharding for dimensions of size 1.
49- - spec = tuple(None if d == 1 else a for a, d in zip(spec, shape))
50- + spec_list = [None if d == 1 else a for a, d in zip(spec_list, shape)]
32+ spec = tuple(None if d == 1 else a for a, d in zip(spec, shape))
5133
52- - return spec
5334+ if is_pspec:
54- + return jax.sharding.PartitionSpec(*spec_list)
55- + return tuple(spec_list)
35+ + return jax.sharding.PartitionSpec(*spec)
36+ +
37+ return spec
5638
5739
58- def update_boxed(
59- @@ -380,7 +389,7 @@ def update_boxed(
40+ @@ -380,10 +385,8 @@ def update_boxed(
6041 shape = boxed.unbox().shape
6142 for possible_field in ('names', 'mesh_axes', 'axes_types'):
6243 axes = getattr(boxed, possible_field, None)
6344- if isinstance(axes, (list, tuple)):
45+ - axes = update_sharding(
46+ - axes, shape=shape, split=split, merge=merge, transpose=transpose
47+ - )
6448+ if isinstance(axes, (list, tuple, jax.sharding.PartitionSpec)):
65- axes = update_sharding(
66- axes, shape=shape, split=split, merge=merge, transpose=transpose
67- )
68- @@ -396,11 +405,13 @@ def update_boxed(
49+ + axes = update_sharding(axes, shape=shape, split=split, merge=merge, transpose=transpose)
50+ boxed = dataclasses.replace(boxed, **{possible_field: axes})
51+ elif isinstance(boxed, nnx.Variable):
52+ if value is not None:
53+ @@ -396,10 +399,9 @@ def update_boxed(
6954 else:
7055 sharding_key = 'sharding_names'
7156 axes = metadata.get(sharding_key, None)
7257- if isinstance(axes, (list, tuple)):
7358- axes = update_sharding(
59+ - axes, shape=shape, split=split, merge=merge, transpose=transpose
60+ - )
61+ +
7462+ if isinstance(axes, (list, tuple, jax.sharding.PartitionSpec)):
75- + updated_axes = update_sharding(
76- axes, shape=shape, split=split, merge=merge, transpose=transpose
77- )
78- - boxed.set_metadata(sharding_key, axes)
79- + # Avoid mutating metadata unless sharding actually changed.
80- + if axes != updated_axes:
81- + boxed.set_metadata(sharding_key, updated_axes)
63+ + axes = update_sharding(axes, shape=shape, split=split, merge=merge, transpose=transpose)
64+ boxed.set_metadata(sharding_key, axes)
8265 elif isinstance(boxed, jax.Array): # not boxed.
8366 if value is not None:
84- boxed = value
8567diff --git a/qwix/_src/providers/lora.py b/qwix/_src/providers/lora.py
86- index e98c833..07be623 100644
68+ index e98c833..39ce3ef 100644
8769--- a/qwix/_src/providers/lora.py
8870+++ b/qwix/_src/providers/lora.py
8971@@ -13,6 +13,7 @@
@@ -94,7 +76,7 @@ index e98c833..07be623 100644
9476 import string
9577 from typing import Any, Callable, Collection, Sequence
9678 import warnings
97- @@ -189,29 +190,87 @@ class LoraProvider(ptq.PtqProvider):
79+ @@ -189,29 +190,68 @@ class LoraProvider(ptq.PtqProvider):
9880 if weight_name is None: # rhs is not a weight.
9981 return res
10082
@@ -107,25 +89,6 @@ index e98c833..07be623 100644
10789+ (contract_lhs, contract_rhs) = dimension_numbers[0]
10890+ (batch_lhs, batch_rhs) = dimension_numbers[1]
10991+
110- + if len(rhs.shape) == 2 and not batch_rhs:
111- + # Standard LoRA path for ...a,ab->...b
112- + lora_a, lora_b = _get_or_create_lora_params(
113- + name=weight_name,
114- + rule=rule,
115- + a_shape=(rhs.shape[0], rule.rank),
116- + b_shape=(rule.rank, rhs.shape[1]),
117- + a_sharding_transpose=(0, None),
118- + b_sharding_transpose=(None, 1),
119- + )
120- +
121- + if rule.dropout > 0:
122- + lhs = nnx.Dropout(rule.dropout, deterministic=False)(
123- + lhs, rngs=flax_util.make_rng('dropout')
124- + )
125- +
126- + return res + lhs @ lora_a @ lora_b * (rule.alpha / rule.rank)
127- +
128- + # General LoRA path for N-D kernels and batch dimensions.
12992+ # Identify contracting, batch, and out axes for rhs.
13093+ contract_rhs = tuple(contract_rhs)
13194+ batch_rhs = tuple(batch_rhs)
0 commit comments