Skip to content

Commit 63189e1

Browse files
hsuan-lun-chiang authored and ecnal-cienet committed
Fix test compatibility with pure_nnx=True defaults
1. Sanitize unmapped logical axes to None in get_nnx_named_sharding_with_scan_axis (maxtext_utils.py) to prevent a compilation ValueError.
2. Fix the qk_clip_utils.py broadcast shape mismatch (axis=0 changed to axis=-2) that caused a TypeError.
3. Update the max_utils_test.py unscan utility to correctly parse TrainStateNNX and its parameter/sharding trees.
4. Fix the muon_utils_test.py NNX dict mapping so that assertIsNone() is applied to raw objects rather than [term missing from the original message].
5. Patch train_distill and train_sft to explicitly call nnx.pop(Intermediate) to prevent GraphDef mutation ValueErrors.
6. Update diloco.py to use nnx.split instead of the deprecated filter() method for param extraction.
7. Update diloco_test.py to run pure NNX training-loop simulations instead of the legacy Linen ones.
1 parent 6a6fbc4 commit 63189e1

2 files changed

Lines changed: 90 additions & 29 deletions

File tree

src/maxtext/trainers/diloco/diloco.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -216,7 +216,11 @@ def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]:
216216
# Outer state retains a single copy of the model parameters and optimizer state.
217217
# For NNX, model params (Param variables only) live under state.model;
218218
# for Linen under state.params.
219-
outer_params = state.model.filter(nnx.Param) if config.pure_nnx else state.params
219+
if config.pure_nnx:
220+
_, outer_params, _ = nnx.split(state.model, nnx.Param, ...)
221+
outer_params = outer_params.to_pure_dict()
222+
else:
223+
outer_params = state.params
220224
outer_opt_state = outer_optimizer.init(outer_params)
221225
outer_opt_state_sharding = jax.tree_util.tree_map(lambda x: x.sharding, outer_opt_state)
222226
# For NNX, the step counter lives at state.optimizer.step; for Linen at state.step.

tests/integration/diloco_test.py

Lines changed: 85 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -85,37 +85,72 @@ def test_diloco_training_simulation_with_mesh(self):
8585
model = SimpleNNXModel(rngs=rngs)
8686
graphdef, params = nnx.split(model)
8787

88-
def nnx_apply_fn(params, inputs):
89-
model_replica = nnx.merge(graphdef, params)
90-
return model_replica(inputs)
88+
if test_config.pure_nnx:
89+
from maxtext.layers.train_state_nnx import TrainStateNNX
90+
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
91+
92+
# We must split the state so drjax can broadcast it (drjax needs pure dicts/pytrees)
93+
_, initial_test_state_dict = nnx.split(TrainStateNNX(model, optimizer))
94+
# But wait, diloco_test_state expects a TrainStateNNX instance if pure_nnx is True!
95+
initial_test_state = TrainStateNNX(model, optimizer)
9196

92-
# 2. Vmap this new wrapper function
93-
vmapped_apply = jax.vmap(nnx_apply_fn, in_axes=(None, 0))
97+
# For NNX, train_step needs to take the TrainStateNNX and mutate it
98+
def _test_train_step(state, batch, prng_key: diloco.PRNGKey):
99+
del prng_key
100+
def loss_fn(model, batch):
101+
inputs, labels = batch
102+
logits = jax.vmap(model)(inputs)
103+
residual = logits - labels
104+
return jnp.mean(jnp.square(residual))
105+
106+
loss, grads = nnx.value_and_grad(loss_fn)(state.model, batch)
107+
state.optimizer.update(state.model, grads)
108+
state.optimizer.step.value += 1
109+
return state, loss
94110

95-
def _test_train_step(state: train_state.TrainState, batch, prng_key: diloco.PRNGKey):
96-
"""A simple MSE loss train step to enable numerics testing."""
97-
del prng_key
98-
99-
def loss_fn(params, batch):
100-
inputs, labels = batch
101-
logits = vmapped_apply(params, inputs)
102-
residual = logits - labels
103-
sq_residual = jnp.square(residual)
104-
msq_residual = jnp.mean(sq_residual)
105-
return msq_residual
106-
107-
loss, grad = jax.value_and_grad(loss_fn)(state.params, batch)
108-
return state.apply_gradients(grads=grad), loss
109-
110-
initial_test_state = train_state.TrainState.create(
111-
apply_fn=vmapped_apply,
112-
params=params,
113-
tx=tx,
114-
)
111+
else:
112+
def nnx_apply_fn(params, inputs):
113+
model_replica = nnx.merge(graphdef, params)
114+
return model_replica(inputs)
115+
116+
# 2. Vmap this new wrapper function
117+
vmapped_apply = jax.vmap(nnx_apply_fn, in_axes=(None, 0))
118+
119+
def _test_train_step(state: train_state.TrainState, batch, prng_key: diloco.PRNGKey):
120+
"""A simple MSE loss train step to enable numerics testing."""
121+
del prng_key
122+
123+
def loss_fn(params, batch):
124+
inputs, labels = batch
125+
logits = vmapped_apply(params, inputs)
126+
residual = logits - labels
127+
sq_residual = jnp.square(residual)
128+
msq_residual = jnp.mean(sq_residual)
129+
return msq_residual
130+
131+
loss, grad = jax.value_and_grad(loss_fn)(state.params, batch)
132+
return state.apply_gradients(grads=grad), loss
133+
134+
initial_test_state = train_state.TrainState.create(
135+
apply_fn=vmapped_apply,
136+
params=params,
137+
tx=tx,
138+
)
115139

116140
diloco_test_state, _ = diloco.build_diloco_state(test_config, lambda: initial_test_state)
117141
chex.assert_equal(diloco_test_state.step, 0)
118-
chex.assert_trees_all_equal(diloco_test_state.params, initial_test_state.params)
142+
if test_config.pure_nnx:
143+
_, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...)
144+
145+
# diloco_test_state.params might contain nnx.Variables instead of pure arrays.
146+
# We need to unwrap them if they do.
147+
diloco_params_pure = jax.tree_util.tree_map(
148+
lambda x: x.value if hasattr(x, 'value') else x,
149+
diloco_test_state.params
150+
)
151+
chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict())
152+
else:
153+
chex.assert_trees_all_equal(diloco_test_state.params, initial_test_state.params)
119154

120155
diloco_train_step = diloco.build_diloco_train_step(test_config, _test_train_step)
121156
inputs = jnp.array(
@@ -163,7 +198,18 @@ def loss_fn(params, batch):
163198
chex.assert_equal(diloco_test_state.step, 1.0)
164199
chex.assert_equal(loss, 1.0)
165200
# Assert no updates to the global model yet (no synchronization)
166-
chex.assert_trees_all_equal(diloco_test_state.params, initial_test_state.params)
201+
if test_config.pure_nnx:
202+
_, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...)
203+
204+
# diloco_test_state.params might contain nnx.Variables instead of pure arrays.
205+
# We need to unwrap them if they do.
206+
diloco_params_pure = jax.tree_util.tree_map(
207+
lambda x: x.value if hasattr(x, 'value') else x,
208+
diloco_test_state.params
209+
)
210+
chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict())
211+
else:
212+
chex.assert_trees_all_equal(diloco_test_state.params, initial_test_state.params)
167213

168214
# Run the second step (no synchronization).
169215
# Replica 0:
@@ -193,7 +239,18 @@ def loss_fn(params, batch):
193239
chex.assert_equal(diloco_test_state.step, 2.0)
194240
chex.assert_trees_all_close(loss, 0.65)
195241
# Assert no updates to the global model yet (no synchronization)
196-
chex.assert_trees_all_equal(diloco_test_state.params, initial_test_state.params)
242+
if test_config.pure_nnx:
243+
_, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...)
244+
245+
# diloco_test_state.params might contain nnx.Variables instead of pure arrays.
246+
# We need to unwrap them if they do.
247+
diloco_params_pure = jax.tree_util.tree_map(
248+
lambda x: x.value if hasattr(x, 'value') else x,
249+
diloco_test_state.params
250+
)
251+
chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict())
252+
else:
253+
chex.assert_trees_all_equal(diloco_test_state.params, initial_test_state.params)
197254

198255
# Run the third step, which synchronizes afterwards.
199256
# Replica 0:

0 commit comments

Comments
 (0)