@@ -267,30 +267,44 @@ def _safe_shard(x, pspec):
267267 nnx .update (self .optimizer , optimizer_sharded_state )
268268
def _train_step(self, model, optimizer, inputs):
    """Run one training step for the student of a ModelBundle.

    Overrides the main JIT block to natively handle the ModelBundle module.

    Uses jax.value_and_grad with explicit split/merge to avoid nesting
    nnx.value_and_grad inside nnx.jit, which causes Flax NNX to assign
    conflicting outer_index values and raises:
        ValueError: The graph structure of a node added to cached_partial was
        mutated inside the transformation.

    Args:
        model: Bundle exposing ``student_model``, ``teacher_model`` and a
            ``training_step`` variable.
        optimizer: Optimizer whose ``update(student, grads)`` applies the
            student gradients.
        inputs: Raw inputs; converted to a batch by ``self.gen_model_input_fn``.

    Returns:
        ``(loss, aux, grad_norm)`` when ``self._tunix_expects_grad_norm`` is
        truthy (or the attribute is absent), otherwise ``(loss, aux)``.
    """
    batch = self.gen_model_input_fn(inputs)
    # Read the step counter up front: it is fed to the loss and incremented
    # after the update. Without this binding the increment below raises
    # NameError.
    current_step = model.training_step.value

    student = model.student_model
    teacher = model.teacher_model

    # Run teacher inference outside of value_and_grad.
    # The teacher output is wrapped in stop_gradient, so it is a constant
    # from the perspective of the student gradient computation.
    if "teacher_output" in batch:
        teacher_output = batch["teacher_output"]
    else:
        teacher_output = self.strategy.teacher_forward_fn(
            model=teacher,
            input_tokens=batch["input_tokens"],
            positions=batch["positions"],
            attention_mask=batch.get("attention_mask"),
            decoder_segment_ids=batch.get("decoder_segment_ids"),
            decoder_target_tokens=batch.get("targets", None),
            decoder_target_mask=batch.get("targets_segmentation", None),
            cache=None,
        )
    teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output)

    # Split student into differentiable params and non-differentiable rest.
    # Capture graphdef outside of jax.value_and_grad for stable graph tracking.
    student_graphdef, diff_params, rest = nnx.split(student, self.wrt_filter, ...)

    def loss_wrapper_pure(params, non_diff_state):
        # Rebuild a private student so mutations stay local to this trace.
        local_student = nnx.merge(student_graphdef, params, non_diff_state, copy=True)
        # NOTE(review): the decoder_segment_ids / decoder_target_tokens
        # arguments were not visible in the reviewed hunk and are mirrored
        # from the teacher call above -- confirm against the full file.
        student_output = self.strategy.student_forward_fn(
            model=local_student,
            input_tokens=batch["input_tokens"],
            positions=batch["positions"],
            attention_mask=batch.get("attention_mask"),
            decoder_segment_ids=batch.get("decoder_segment_ids"),
            decoder_target_tokens=batch.get("targets", None),
            decoder_target_mask=batch.get("targets_segmentation", None),
            cache=None,
        )
        labels = self.strategy.create_labels(
            batch["targets"],
            targets_segmentation=batch.get("targets_segmentation", None),
        )
        # Keep forwarding the step, as the loss call did before the
        # split/merge refactor.
        loss, aux = self.strategy.compute_loss(
            student_output, teacher_output, labels, step=current_step
        )
        # Capture updated non-param state (e.g. RNG counters) from local_student.
        _, _, new_rest = nnx.split(local_student, self.wrt_filter, ...)
        return loss, (aux, new_rest)

    grad_fn = jax.value_and_grad(loss_wrapper_pure, argnums=0, has_aux=True)
    (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest)

    # Propagate updated non-param state back to student.
    nnx.update(student, new_rest)

    optimizer.update(student, grads)

    # Increment step counter after loss computation
    model.training_step.value = current_step + 1

    tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True)
    if tunix_expects_grad_norm:
        return loss, aux, optax.global_norm(grads)
    return loss, aux
327338 def _eval_step (self , model , inputs ):
328339 """Evaluation only needs the student."""
0 commit comments