3030import pytest
3131
3232from maxtext .configs .pyconfig import initialize_pydantic
33+ from maxtext .layers .train_state_nnx import TrainStateNNX
3334from maxtext .trainers .pre_train .train_compile import main as train_compile_main
3435from maxtext .trainers .diloco import diloco
3536from tests .utils .test_helpers import get_test_config_path
@@ -86,51 +87,49 @@ def test_diloco_training_simulation_with_mesh(self):
8687 graphdef , params = nnx .split (model )
8788
8889 if test_config .pure_nnx :
89- from maxtext .layers .train_state_nnx import TrainStateNNX
9090 optimizer = nnx .Optimizer (model , tx , wrt = nnx .Param )
91-
92- # We must split the state so drjax can broadcast it (drjax needs pure dicts/pytrees)
93- _ , initial_test_state_dict = nnx .split (TrainStateNNX (model , optimizer ))
94- # But wait, diloco_test_state expects a TrainStateNNX instance if pure_nnx is True!
91+ # diloco_test_state expects a TrainStateNNX instance when pure_nnx is True.
9592 initial_test_state = TrainStateNNX (model , optimizer )
9693
9794 # For NNX, train_step needs to take the TrainStateNNX and mutate it
95+
9896 def _test_train_step (state , batch , prng_key : diloco .PRNGKey ):
9997 del prng_key
98+
10099 def loss_fn (model , batch ):
101100 inputs , labels = batch
102101 logits = jax .vmap (model )(inputs )
103102 residual = logits - labels
104103 return jnp .mean (jnp .square (residual ))
105-
104+
106105 loss , grads = nnx .value_and_grad (loss_fn )(state .model , batch )
107106 state .optimizer .update (state .model , grads )
108- state .optimizer .step .value += 1
109107 return state , loss
110108
111109 else :
110+
112111 def nnx_apply_fn (params , inputs ):
113112 model_replica = nnx .merge (graphdef , params )
114113 return model_replica (inputs )
115-
114+
116115 # 2. Vmap this new wrapper function
117116 vmapped_apply = jax .vmap (nnx_apply_fn , in_axes = (None , 0 ))
118-
117+
119118 def _test_train_step (state : train_state .TrainState , batch , prng_key : diloco .PRNGKey ):
120119 """A simple MSE loss train step to enable numerics testing."""
121120 del prng_key
122-
121+
123122 def loss_fn (params , batch ):
124123 inputs , labels = batch
125124 logits = vmapped_apply (params , inputs )
126125 residual = logits - labels
127126 sq_residual = jnp .square (residual )
128127 msq_residual = jnp .mean (sq_residual )
129128 return msq_residual
130-
129+
131130 loss , grad = jax .value_and_grad (loss_fn )(state .params , batch )
132131 return state .apply_gradients (grads = grad ), loss
133-
132+
134133 initial_test_state = train_state .TrainState .create (
135134 apply_fn = vmapped_apply ,
136135 params = params ,
@@ -141,12 +140,11 @@ def loss_fn(params, batch):
141140 chex .assert_equal (diloco_test_state .step , 0 )
142141 if test_config .pure_nnx :
143142 _ , params_pure , _ = nnx .split (initial_test_state .model , nnx .Param , ...)
144-
143+
145144 # diloco_test_state.params might contain nnx.Variables instead of pure arrays.
146145 # We need to unwrap them if they do.
147146 diloco_params_pure = jax .tree_util .tree_map (
148- lambda x : x .value if hasattr (x , 'value' ) else x ,
149- diloco_test_state .params
147+ lambda x : x .value if hasattr (x , "value" ) else x , diloco_test_state .params
150148 )
151149 chex .assert_trees_all_equal (diloco_params_pure , params_pure .to_pure_dict ())
152150 else :
@@ -200,12 +198,11 @@ def loss_fn(params, batch):
200198 # Assert no updates to the global model yet (no synchronization)
201199 if test_config .pure_nnx :
202200 _ , params_pure , _ = nnx .split (initial_test_state .model , nnx .Param , ...)
203-
201+
204202 # diloco_test_state.params might contain nnx.Variables instead of pure arrays.
205203 # We need to unwrap them if they do.
206204 diloco_params_pure = jax .tree_util .tree_map (
207- lambda x : x .value if hasattr (x , 'value' ) else x ,
208- diloco_test_state .params
205+ lambda x : x .value if hasattr (x , "value" ) else x , diloco_test_state .params
209206 )
210207 chex .assert_trees_all_equal (diloco_params_pure , params_pure .to_pure_dict ())
211208 else :
@@ -241,12 +238,11 @@ def loss_fn(params, batch):
241238 # Assert no updates to the global model yet (no synchronization)
242239 if test_config .pure_nnx :
243240 _ , params_pure , _ = nnx .split (initial_test_state .model , nnx .Param , ...)
244-
241+
245242 # diloco_test_state.params might contain nnx.Variables instead of pure arrays.
246243 # We need to unwrap them if they do.
247244 diloco_params_pure = jax .tree_util .tree_map (
248- lambda x : x .value if hasattr (x , 'value' ) else x ,
249- diloco_test_state .params
245+ lambda x : x .value if hasattr (x , "value" ) else x , diloco_test_state .params
250246 )
251247 chex .assert_trees_all_equal (diloco_params_pure , params_pure .to_pure_dict ())
252248 else :
@@ -285,14 +281,31 @@ def loss_fn(params, batch):
285281 chex .assert_trees_all_close (loss , 0.4481 )
286282 # Assert that inner and outer parameters are all equal now that
287283 # synchronization has happened.
288- chex .assert_trees_all_equal (
289- diloco_test_state .params ,
290- jax .tree .map (lambda arr : arr [0 , ...], diloco_test_state .inner_state .params ),
291- )
292- chex .assert_trees_all_equal (
293- diloco_test_state .params ,
294- jax .tree .map (lambda arr : arr [1 , ...], diloco_test_state .inner_state .params ),
295- )
284+ if test_config .pure_nnx :
285+ _ , inner_params , _ = nnx .split (diloco_test_state .inner_state .model , nnx .Param , ...)
286+ inner_params_pure = jax .tree_util .tree_map (
287+ lambda x : x .value if hasattr (x , "value" ) else x , inner_params .to_pure_dict ()
288+ )
289+ diloco_params_pure_3 = jax .tree_util .tree_map (
290+ lambda x : x .value if hasattr (x , "value" ) else x , diloco_test_state .params
291+ )
292+ chex .assert_trees_all_equal (
293+ diloco_params_pure_3 ,
294+ jax .tree .map (lambda arr : arr [0 , ...], inner_params_pure ),
295+ )
296+ chex .assert_trees_all_equal (
297+ diloco_params_pure_3 ,
298+ jax .tree .map (lambda arr : arr [1 , ...], inner_params_pure ),
299+ )
300+ else :
301+ chex .assert_trees_all_equal (
302+ diloco_test_state .params ,
303+ jax .tree .map (lambda arr : arr [0 , ...], diloco_test_state .inner_state .params ),
304+ )
305+ chex .assert_trees_all_equal (
306+ diloco_test_state .params ,
307+ jax .tree .map (lambda arr : arr [1 , ...], diloco_test_state .inner_state .params ),
308+ )
296309
297310 # Run the fourth step (no synchronization).
298311 # Replica 0:
0 commit comments