Skip to content

Commit 5c07324

Browse files
hsuan-lun-chiangecnal-cienet
authored andcommitted
Fix diloco related unit tests
1 parent 63189e1 commit 5c07324

4 files changed

Lines changed: 76 additions & 47 deletions

File tree

src/maxtext/trainers/diloco/diloco.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import optax
3636

3737
from maxtext.configs import pyconfig
38+
from maxtext.layers.train_state_nnx import TrainStateNNX
3839

3940
Batch = Any
4041
Params = PyTree
@@ -157,8 +158,10 @@ def add_diloco_dim(x):
157158
# For NNX, model params (Param variables only) live under abstract_state.model;
158159
# for Linen under abstract_state.params.
159160
if config.pure_nnx:
160-
model_params = abstract_state.model.filter(nnx.Param)
161-
model_params_sharding = state_mesh_shardings.model.filter(nnx.Param)
161+
_, model_params, _ = nnx.split(abstract_state.model, nnx.Param, ...)
162+
model_params = model_params.to_pure_dict()
163+
_, model_params_sharding, _ = nnx.split(state_mesh_shardings.model, nnx.Param, ...)
164+
model_params_sharding = model_params_sharding.to_pure_dict()
162165
else:
163166
model_params = abstract_state.params
164167
model_params_sharding = state_mesh_shardings.params
@@ -262,9 +265,11 @@ def synchronize(state):
262265
# state (since last synchronization).
263266
broadcast_outer_params = drjax.broadcast(state.params, mesh=mesh)
264267
# For NNX, model Param vars live under inner_state.model; for Linen under inner_state.params.
265-
inner_model_params = (
266-
nnx.filter_state(state.inner_state.model, nnx.Param) if config.pure_nnx else state.inner_state.params
267-
)
268+
if config.pure_nnx:
269+
_, inner_model_params, _ = nnx.split(state.inner_state.model, nnx.Param, ...)
270+
inner_model_params = inner_model_params.to_pure_dict()
271+
else:
272+
inner_model_params = state.inner_state.params
268273
model_delta = jax.tree.map(lambda x, y: y - x, inner_model_params, broadcast_outer_params)
269274
# Treat the average delta as the outer optimizer's gradient and apply to
270275
# the global (outer) model params.
@@ -277,15 +282,25 @@ def synchronize(state):
277282
if config.pure_nnx:
278283
# For NNX: merge new Param vars back with the non-Param model vars (e.g. RNG state).
279284
def replace_nnx_model_params(s, new_params):
280-
non_param_model = nnx.filter_state(s.model, nnx.Not(nnx.Param))
281-
new_model = nnx.merge_state(non_param_model, new_params)
282-
# Assign via __setitem__ so nested States are stored as plain dicts (matching
283-
# nnx.state()'s pytree structure). The dict-literal constructor keeps them as
284-
# State objects, which makes jax.lax.cond see mismatched pytree structures.
285-
result = type(s)({})
286-
result["model"] = new_model
287-
result["optimizer"] = s["optimizer"]
288-
return result
285+
s_model = s["model"] if hasattr(s, "keys") else s.model
286+
s_opt = s["optimizer"] if hasattr(s, "keys") else s.optimizer
287+
288+
graphdef, _, non_param_state = nnx.split(s_model, nnx.Param, ...)
289+
new_model = nnx.merge(graphdef, new_params, non_param_state)
290+
291+
if type(s_model).__name__ == "State":
292+
new_model = nnx.state(new_model)
293+
elif isinstance(s_model, dict):
294+
new_model = nnx.to_pure_dict(new_model)
295+
296+
if hasattr(s, "keys"):
297+
leaves, treedef = jax.tree_util.tree_flatten(s)
298+
new_model_leaves, _ = jax.tree_util.tree_flatten(new_model)
299+
N = len(new_model_leaves)
300+
new_leaves = new_model_leaves + leaves[N:]
301+
return jax.tree_util.tree_unflatten(treedef, new_leaves)
302+
else:
303+
return TrainStateNNX(new_model, s_opt)
289304

290305
new_inner_state = drjax.map_fn(
291306
lambda s: replace_nnx_model_params(s, new_outer_params),

src/maxtext/utils/train_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,12 +295,13 @@ def create_train_state_fn():
295295
state, outer_opt_state_sharding = diloco.build_diloco_state(config, lambda: state, mesh=mesh)
296296

297297
# create state_mesh_shardings for the DilocoState
298+
step_mesh = state_mesh_shardings.optimizer.step.mesh if config.pure_nnx else state_mesh_shardings.step.mesh
298299
inner_state_shardings = diloco.add_diloco_to_sharding(state_mesh_shardings)
299300
state_mesh_shardings = diloco.DiLoCoTrainState(
300301
inner_state_shardings,
301-
state_mesh_shardings.params,
302+
state_mesh_shardings_params,
302303
outer_opt_state_sharding,
303-
jax.sharding.NamedSharding(mesh=state_mesh_shardings.step.mesh, spec=jax.sharding.PartitionSpec()),
304+
jax.sharding.NamedSharding(mesh=step_mesh, spec=jax.sharding.PartitionSpec()),
304305
)
305306

306307
# TODO(aireenmei, hengtaoguo): support sharding in vit for multimodal

tests/integration/diloco_test.py

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import pytest
3131

3232
from maxtext.configs.pyconfig import initialize_pydantic
33+
from maxtext.layers.train_state_nnx import TrainStateNNX
3334
from maxtext.trainers.pre_train.train_compile import main as train_compile_main
3435
from maxtext.trainers.diloco import diloco
3536
from tests.utils.test_helpers import get_test_config_path
@@ -86,51 +87,49 @@ def test_diloco_training_simulation_with_mesh(self):
8687
graphdef, params = nnx.split(model)
8788

8889
if test_config.pure_nnx:
89-
from maxtext.layers.train_state_nnx import TrainStateNNX
9090
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
91-
92-
# We must split the state so drjax can broadcast it (drjax needs pure dicts/pytrees)
93-
_, initial_test_state_dict = nnx.split(TrainStateNNX(model, optimizer))
94-
# But wait, diloco_test_state expects a TrainStateNNX instance if pure_nnx is True!
91+
# diloco_test_state expects a TrainStateNNX instance when pure_nnx is True.
9592
initial_test_state = TrainStateNNX(model, optimizer)
9693

9794
# For NNX, train_step needs to take the TrainStateNNX and mutate it
95+
9896
def _test_train_step(state, batch, prng_key: diloco.PRNGKey):
9997
del prng_key
98+
10099
def loss_fn(model, batch):
101100
inputs, labels = batch
102101
logits = jax.vmap(model)(inputs)
103102
residual = logits - labels
104103
return jnp.mean(jnp.square(residual))
105-
104+
106105
loss, grads = nnx.value_and_grad(loss_fn)(state.model, batch)
107106
state.optimizer.update(state.model, grads)
108-
state.optimizer.step.value += 1
109107
return state, loss
110108

111109
else:
110+
112111
def nnx_apply_fn(params, inputs):
113112
model_replica = nnx.merge(graphdef, params)
114113
return model_replica(inputs)
115-
114+
116115
# 2. Vmap this new wrapper function
117116
vmapped_apply = jax.vmap(nnx_apply_fn, in_axes=(None, 0))
118-
117+
119118
def _test_train_step(state: train_state.TrainState, batch, prng_key: diloco.PRNGKey):
120119
"""A simple MSE loss train step to enable numerics testing."""
121120
del prng_key
122-
121+
123122
def loss_fn(params, batch):
124123
inputs, labels = batch
125124
logits = vmapped_apply(params, inputs)
126125
residual = logits - labels
127126
sq_residual = jnp.square(residual)
128127
msq_residual = jnp.mean(sq_residual)
129128
return msq_residual
130-
129+
131130
loss, grad = jax.value_and_grad(loss_fn)(state.params, batch)
132131
return state.apply_gradients(grads=grad), loss
133-
132+
134133
initial_test_state = train_state.TrainState.create(
135134
apply_fn=vmapped_apply,
136135
params=params,
@@ -141,12 +140,11 @@ def loss_fn(params, batch):
141140
chex.assert_equal(diloco_test_state.step, 0)
142141
if test_config.pure_nnx:
143142
_, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...)
144-
143+
145144
# diloco_test_state.params might contain nnx.Variables instead of pure arrays.
146145
# We need to unwrap them if they do.
147146
diloco_params_pure = jax.tree_util.tree_map(
148-
lambda x: x.value if hasattr(x, 'value') else x,
149-
diloco_test_state.params
147+
lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params
150148
)
151149
chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict())
152150
else:
@@ -200,12 +198,11 @@ def loss_fn(params, batch):
200198
# Assert no updates to the global model yet (no synchronization)
201199
if test_config.pure_nnx:
202200
_, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...)
203-
201+
204202
# diloco_test_state.params might contain nnx.Variables instead of pure arrays.
205203
# We need to unwrap them if they do.
206204
diloco_params_pure = jax.tree_util.tree_map(
207-
lambda x: x.value if hasattr(x, 'value') else x,
208-
diloco_test_state.params
205+
lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params
209206
)
210207
chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict())
211208
else:
@@ -241,12 +238,11 @@ def loss_fn(params, batch):
241238
# Assert no updates to the global model yet (no synchronization)
242239
if test_config.pure_nnx:
243240
_, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...)
244-
241+
245242
# diloco_test_state.params might contain nnx.Variables instead of pure arrays.
246243
# We need to unwrap them if they do.
247244
diloco_params_pure = jax.tree_util.tree_map(
248-
lambda x: x.value if hasattr(x, 'value') else x,
249-
diloco_test_state.params
245+
lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params
250246
)
251247
chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict())
252248
else:
@@ -285,14 +281,31 @@ def loss_fn(params, batch):
285281
chex.assert_trees_all_close(loss, 0.4481)
286282
# Assert that inner and outer parameters are all equal now that
287283
# synchronization has happened.
288-
chex.assert_trees_all_equal(
289-
diloco_test_state.params,
290-
jax.tree.map(lambda arr: arr[0, ...], diloco_test_state.inner_state.params),
291-
)
292-
chex.assert_trees_all_equal(
293-
diloco_test_state.params,
294-
jax.tree.map(lambda arr: arr[1, ...], diloco_test_state.inner_state.params),
295-
)
284+
if test_config.pure_nnx:
285+
_, inner_params, _ = nnx.split(diloco_test_state.inner_state.model, nnx.Param, ...)
286+
inner_params_pure = jax.tree_util.tree_map(
287+
lambda x: x.value if hasattr(x, "value") else x, inner_params.to_pure_dict()
288+
)
289+
diloco_params_pure_3 = jax.tree_util.tree_map(
290+
lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params
291+
)
292+
chex.assert_trees_all_equal(
293+
diloco_params_pure_3,
294+
jax.tree.map(lambda arr: arr[0, ...], inner_params_pure),
295+
)
296+
chex.assert_trees_all_equal(
297+
diloco_params_pure_3,
298+
jax.tree.map(lambda arr: arr[1, ...], inner_params_pure),
299+
)
300+
else:
301+
chex.assert_trees_all_equal(
302+
diloco_test_state.params,
303+
jax.tree.map(lambda arr: arr[0, ...], diloco_test_state.inner_state.params),
304+
)
305+
chex.assert_trees_all_equal(
306+
diloco_test_state.params,
307+
jax.tree.map(lambda arr: arr[1, ...], diloco_test_state.inner_state.params),
308+
)
296309

297310
# Run the fourth step (no synchronization).
298311
# Replica 0:

tests/unit/max_utils_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,14 +206,14 @@ def test_unscan_train_state_params(self):
206206

207207
# Check that the original state is unchanged.
208208

209-
if hasattr(state, 'model'):
209+
if hasattr(state, "model"):
210210
_, params_state, _ = nnx.split(state.model, nnx.Param, ...)
211211
state_decoder_params = params_state.to_pure_dict()["decoder"]
212212
self.assertIn("layers", state_decoder_params)
213213
else:
214214
self.assertIn("layers", state.params["params"]["decoder"])
215215

216-
if hasattr(state, 'model'):
216+
if hasattr(state, "model"):
217217
self.assertNotIn("layers_0", state_decoder_params)
218218
else:
219219
self.assertNotIn("layers_0", state.params["params"]["decoder"])

0 commit comments

Comments
 (0)