2626from typing import Any , Callable
2727
2828import drjax
29+ from flax import nnx
2930from flax import struct
3031from flax .training import train_state
3132import jax
@@ -153,15 +154,23 @@ def add_diloco_dim(x):
153154 momentum = config .diloco_outer_momentum ,
154155 nesterov = True ,
155156 )
156- outer_opt_state = jax .eval_shape (outer_optimizer .init , abstract_state .params )
157+ # For NNX, model params (Param variables only) live under abstract_state.model;
158+ # for Linen under abstract_state.params.
159+ if config .pure_nnx :
160+ model_params = abstract_state .model .filter (nnx .Param )
161+ model_params_sharding = state_mesh_shardings .model .filter (nnx .Param )
162+ else :
163+ model_params = abstract_state .params
164+ model_params_sharding = state_mesh_shardings .params
165+ outer_opt_state = jax .eval_shape (outer_optimizer .init , model_params )
157166
158167 # Create abstract step
159168 abstract_step = jax .ShapeDtypeStruct ((), jnp .int32 )
160169
161170 # Build abstract DiLoCo state
162171 diloco_state = DiLoCoTrainState (
163172 inner_state = inner_state ,
164- params = abstract_state . params ,
173+ params = model_params ,
165174 outer_opt_state = outer_opt_state ,
166175 step = abstract_step ,
167176 )
@@ -171,12 +180,12 @@ def add_diloco_dim(x):
171180 # Sharding for outer_opt_state. For SGD with momentum, it is (TraceState(trace=...), EmptyState())
172181 # We shard the momentum trace the same way as the parameters.
173182 outer_opt_state_sharding = (
174- optax .TraceState (trace = state_mesh_shardings . params ),
183+ optax .TraceState (trace = model_params_sharding ),
175184 optax .EmptyState (),
176185 )
177186 diloco_state_shardings = DiLoCoTrainState (
178187 inner_state = inner_state_shardings ,
179- params = state_mesh_shardings . params ,
188+ params = model_params_sharding ,
180189 outer_opt_state = outer_opt_state_sharding ,
181190 step = None ,
182191 )
@@ -205,11 +214,15 @@ def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]:
205214 # mesh automatically when jax.set_mesh is used.
206215 inner_state = drjax .broadcast (state , mesh = mesh )
207216 # Outer state retains a single copy of the model parameters and optimizer state.
208- outer_params = state .params
217+ # For NNX, model params (Param variables only) live under state.model;
218+ # for Linen under state.params.
219+ outer_params = state .model .filter (nnx .Param ) if config .pure_nnx else state .params
209220 outer_opt_state = outer_optimizer .init (outer_params )
210221 outer_opt_state_sharding = jax .tree_util .tree_map (lambda x : x .sharding , outer_opt_state )
222+ # For NNX, the step counter lives at state.optimizer.step; for Linen at state.step.
223+ step = state .optimizer .step if config .pure_nnx else state .step
211224 return (
212- DiLoCoTrainState (inner_state = inner_state , params = outer_params , outer_opt_state = outer_opt_state , step = state . step ),
225+ DiLoCoTrainState (inner_state = inner_state , params = outer_params , outer_opt_state = outer_opt_state , step = step ),
213226 outer_opt_state_sharding ,
214227 )
215228
@@ -244,7 +257,11 @@ def synchronize(state):
244257 # Calculate the delta between the current replica's state and the global
245258 # state (since last synchronization).
246259 broadcast_outer_params = drjax .broadcast (state .params , mesh = mesh )
247- model_delta = jax .tree .map (lambda x , y : y - x , state .inner_state .params , broadcast_outer_params )
260+ # For NNX, model Param vars live under inner_state.model; for Linen under inner_state.params.
261+ inner_model_params = (
262+ nnx .filter_state (state .inner_state .model , nnx .Param ) if config .pure_nnx else state .inner_state .params
263+ )
264+ model_delta = jax .tree .map (lambda x , y : y - x , inner_model_params , broadcast_outer_params )
248265 # Treat the average delta as the outer optimizer's gradient and apply to
249266 # the global (outer) model params.
250267 averaged_pseudo_grad = drjax .reduce_mean (model_delta )
@@ -253,7 +270,27 @@ def synchronize(state):
253270 # Replace inner model params with the new global model params.
254271 # NOTE: inner optimizer state is retained despite the change in parameters,
255272 # see section 6.1 in https://arxiv.org/pdf/2311.08105.
256- new_inner_state = drjax .map_fn (lambda state : state .replace (params = new_outer_params ), state .inner_state , mesh = mesh )
273+ if config .pure_nnx :
274+ # For NNX: merge new Param vars back with the non-Param model vars (e.g. RNG state).
def replace_nnx_model_params(s, new_params):
    """Return a copy of inner state `s` with its model Param variables
    swapped for `new_params`, leaving non-Param model state and the
    optimizer state untouched."""
    # Keep every model variable that is NOT a Param (e.g. RNG state).
    preserved_vars = nnx.filter_state(s.model, nnx.Not(nnx.Param))
    # Recombine the preserved variables with the synchronized params.
    merged_model = nnx.merge_state(preserved_vars, new_params)
    # Assemble the result via item assignment rather than a dict literal:
    # __setitem__ stores nested States as plain dicts internally, matching
    # the pytree structure produced by nnx.state(). Passing State objects
    # through the constructor keeps them as-is, which makes jax.lax.cond
    # see mismatched pytree structures between its two branches.
    rebuilt = type(s)({})
    rebuilt["model"] = merged_model
    rebuilt["optimizer"] = s["optimizer"]
    return rebuilt
286+
287+ new_inner_state = drjax .map_fn (
288+ lambda s : replace_nnx_model_params (s , new_outer_params ),
289+ state .inner_state ,
290+ mesh = mesh ,
291+ )
292+ else :
293+ new_inner_state = drjax .map_fn (lambda s : s .replace (params = new_outer_params ), state .inner_state , mesh = mesh )
257294 return state .replace (
258295 params = new_outer_params ,
259296 outer_opt_state = new_opt_state ,
@@ -271,14 +308,16 @@ def diloco_train_step(state, batch, prng):
271308 broadcast_rng = drjax .broadcast (prng , mesh = mesh )
272309 inner_state , metrics = drjax .map_fn (train_step , (state .inner_state , batch , broadcast_rng ), mesh = mesh )
273310 avg_metrics = typed_reduce_mean (metrics )
311+ # For NNX, the step counter lives at inner_state.optimizer.step; for Linen at inner_state.step.
312+ new_step = inner_state .optimizer .step [0 ] if config .pure_nnx else inner_state .step [0 ]
274313 state = state .replace (
275314 inner_state = inner_state ,
276- step = inner_state . step [ 0 ] ,
315+ step = new_step ,
277316 )
278317 # Either synchronize the model, or no-op, depending on whether the current
279318 # step falls on the synchronization period.
280319 state = jax .lax .cond (
281- inner_state . step [ 0 ] % config .diloco_sync_period == 0 ,
320+ new_step % config .diloco_sync_period == 0 ,
282321 synchronize ,
283322 lambda x : x , # no-op
284323 state ,
0 commit comments