@@ -85,37 +85,72 @@ def test_diloco_training_simulation_with_mesh(self):
8585 model = SimpleNNXModel (rngs = rngs )
8686 graphdef , params = nnx .split (model )
8787
88- def nnx_apply_fn (params , inputs ):
89- model_replica = nnx .merge (graphdef , params )
90- return model_replica (inputs )
88+ if test_config .pure_nnx :
89+ from maxtext .layers .train_state_nnx import TrainStateNNX
90+ optimizer = nnx .Optimizer (model , tx , wrt = nnx .Param )
91+
92+ # We must split the state so drjax can broadcast it (drjax needs pure dicts/pytrees)
93+ _ , initial_test_state_dict = nnx .split (TrainStateNNX (model , optimizer ))
94+ # But wait, diloco_test_state expects a TrainStateNNX instance if pure_nnx is True!
95+ initial_test_state = TrainStateNNX (model , optimizer )
9196
92- # 2. Vmap this new wrapper function
93- vmapped_apply = jax .vmap (nnx_apply_fn , in_axes = (None , 0 ))
97+ # For NNX, train_step needs to take the TrainStateNNX and mutate it
98+ def _test_train_step (state , batch , prng_key : diloco .PRNGKey ):
99+ del prng_key
100+ def loss_fn (model , batch ):
101+ inputs , labels = batch
102+ logits = jax .vmap (model )(inputs )
103+ residual = logits - labels
104+ return jnp .mean (jnp .square (residual ))
105+
106+ loss , grads = nnx .value_and_grad (loss_fn )(state .model , batch )
107+ state .optimizer .update (state .model , grads )
108+ state .optimizer .step .value += 1
109+ return state , loss
94110
95- def _test_train_step (state : train_state .TrainState , batch , prng_key : diloco .PRNGKey ):
96- """A simple MSE loss train step to enable numerics testing."""
97- del prng_key
98-
99- def loss_fn (params , batch ):
100- inputs , labels = batch
101- logits = vmapped_apply (params , inputs )
102- residual = logits - labels
103- sq_residual = jnp .square (residual )
104- msq_residual = jnp .mean (sq_residual )
105- return msq_residual
106-
107- loss , grad = jax .value_and_grad (loss_fn )(state .params , batch )
108- return state .apply_gradients (grads = grad ), loss
109-
110- initial_test_state = train_state .TrainState .create (
111- apply_fn = vmapped_apply ,
112- params = params ,
113- tx = tx ,
114- )
111+ else :
112+ def nnx_apply_fn (params , inputs ):
113+ model_replica = nnx .merge (graphdef , params )
114+ return model_replica (inputs )
115+
116+ # 2. Vmap this new wrapper function
117+ vmapped_apply = jax .vmap (nnx_apply_fn , in_axes = (None , 0 ))
118+
119+ def _test_train_step (state : train_state .TrainState , batch , prng_key : diloco .PRNGKey ):
120+ """A simple MSE loss train step to enable numerics testing."""
121+ del prng_key
122+
123+ def loss_fn (params , batch ):
124+ inputs , labels = batch
125+ logits = vmapped_apply (params , inputs )
126+ residual = logits - labels
127+ sq_residual = jnp .square (residual )
128+ msq_residual = jnp .mean (sq_residual )
129+ return msq_residual
130+
131+ loss , grad = jax .value_and_grad (loss_fn )(state .params , batch )
132+ return state .apply_gradients (grads = grad ), loss
133+
134+ initial_test_state = train_state .TrainState .create (
135+ apply_fn = vmapped_apply ,
136+ params = params ,
137+ tx = tx ,
138+ )
115139
116140 diloco_test_state , _ = diloco .build_diloco_state (test_config , lambda : initial_test_state )
117141 chex .assert_equal (diloco_test_state .step , 0 )
118- chex .assert_trees_all_equal (diloco_test_state .params , initial_test_state .params )
142+ if test_config .pure_nnx :
143+ _ , params_pure , _ = nnx .split (initial_test_state .model , nnx .Param , ...)
144+
145+ # diloco_test_state.params might contain nnx.Variables instead of pure arrays.
146+ # We need to unwrap them if they do.
147+ diloco_params_pure = jax .tree_util .tree_map (
148+ lambda x : x .value if hasattr (x , 'value' ) else x ,
149+ diloco_test_state .params
150+ )
151+ chex .assert_trees_all_equal (diloco_params_pure , params_pure .to_pure_dict ())
152+ else :
153+ chex .assert_trees_all_equal (diloco_test_state .params , initial_test_state .params )
119154
120155 diloco_train_step = diloco .build_diloco_train_step (test_config , _test_train_step )
121156 inputs = jnp .array (
@@ -163,7 +198,18 @@ def loss_fn(params, batch):
163198 chex .assert_equal (diloco_test_state .step , 1.0 )
164199 chex .assert_equal (loss , 1.0 )
165200 # Assert no updates to the global model yet (no synchronization)
166- chex .assert_trees_all_equal (diloco_test_state .params , initial_test_state .params )
201+ if test_config .pure_nnx :
202+ _ , params_pure , _ = nnx .split (initial_test_state .model , nnx .Param , ...)
203+
204+ # diloco_test_state.params might contain nnx.Variables instead of pure arrays.
205+ # We need to unwrap them if they do.
206+ diloco_params_pure = jax .tree_util .tree_map (
207+ lambda x : x .value if hasattr (x , 'value' ) else x ,
208+ diloco_test_state .params
209+ )
210+ chex .assert_trees_all_equal (diloco_params_pure , params_pure .to_pure_dict ())
211+ else :
212+ chex .assert_trees_all_equal (diloco_test_state .params , initial_test_state .params )
167213
168214 # Run the second step (no synchronization).
169215 # Replica 0:
@@ -193,7 +239,18 @@ def loss_fn(params, batch):
193239 chex .assert_equal (diloco_test_state .step , 2.0 )
194240 chex .assert_trees_all_close (loss , 0.65 )
195241 # Assert no updates to the global model yet (no synchronization)
196- chex .assert_trees_all_equal (diloco_test_state .params , initial_test_state .params )
242+ if test_config .pure_nnx :
243+ _ , params_pure , _ = nnx .split (initial_test_state .model , nnx .Param , ...)
244+
245+ # diloco_test_state.params might contain nnx.Variables instead of pure arrays.
246+ # We need to unwrap them if they do.
247+ diloco_params_pure = jax .tree_util .tree_map (
248+ lambda x : x .value if hasattr (x , 'value' ) else x ,
249+ diloco_test_state .params
250+ )
251+ chex .assert_trees_all_equal (diloco_params_pure , params_pure .to_pure_dict ())
252+ else :
253+ chex .assert_trees_all_equal (diloco_test_state .params , initial_test_state .params )
197254
198255 # Run the third step, which synchronizes afterwards.
199256 # Replica 0:
0 commit comments