@@ -274,30 +274,45 @@ def wrt_filter(path, x):
274274 # Inherits _shard_optimizer from PeftTrainer.
275275
276276 def _train_step (self , model , optimizer , inputs ):
277- """Overrides the main JIT block to natively handle ModelBundle module."""
277+ """Overrides the main JIT block to natively handle ModelBundle module.
278278
279+ Uses jax.value_and_grad with explicit split/merge to avoid nesting
280+ nnx.value_and_grad inside nnx.jit, which causes Flax NNX to assign
281+ conflicting outer_index values and raises:
282+ ValueError: The graph structure of a node added to cached_partial was
283+ mutated inside the transformation.
284+ """
279285 batch = self .gen_model_input_fn (inputs )
286+ student = model .student_model
287+ teacher = model .teacher_model
280288 current_step = model .training_step [...]
281289
282- def loss_wrapper (student , teacher , batch ):
283- if "teacher_output" in batch :
284- teacher_output = batch ["teacher_output" ]
285- else :
286- teacher_output = self .strategy .teacher_forward_fn (
287- model = teacher ,
288- input_tokens = batch ["input_tokens" ],
289- positions = batch ["positions" ],
290- attention_mask = batch .get ("attention_mask" ),
291- decoder_segment_ids = batch .get ("decoder_segment_ids" ),
292- decoder_target_tokens = batch .get ("targets" , None ),
293- decoder_target_mask = batch .get ("targets_segmentation" , None ),
294- cache = None ,
295- )
290+ # Run teacher inference outside of value_and_grad.
291+ # The teacher is frozen (stop_gradient), so its output is a constant
292+ # from the perspective of the student gradient computation.
293+ if "teacher_output" in batch :
294+ teacher_output = batch ["teacher_output" ]
295+ else :
296+ teacher_output = self .strategy .teacher_forward_fn (
297+ model = teacher ,
298+ input_tokens = batch ["input_tokens" ],
299+ positions = batch ["positions" ],
300+ attention_mask = batch .get ("attention_mask" ),
301+ decoder_segment_ids = batch .get ("decoder_segment_ids" ),
302+ decoder_target_tokens = batch .get ("targets" , None ),
303+ decoder_target_mask = batch .get ("targets_segmentation" , None ),
304+ cache = None ,
305+ )
306+ teacher_output = jax .tree .map (jax .lax .stop_gradient , teacher_output )
296307
297- teacher_output = jax .tree .map (jax .lax .stop_gradient , teacher_output )
308+ # Split student into differentiable params and non-differentiable rest.
309+ # Capture graphdef outside of jax.value_and_grad for stable graph tracking.
310+ student_graphdef , diff_params , rest = nnx .split (student , self .wrt_filter , ...)
298311
312+ def loss_wrapper_pure (diff_params , rest ):
313+ local_student = nnx .merge (student_graphdef , diff_params , rest , copy = True )
299314 student_output = self .strategy .student_forward_fn (
300- model = student ,
315+ model = local_student ,
301316 input_tokens = batch ["input_tokens" ],
302317 positions = batch ["positions" ],
303318 attention_mask = batch .get ("attention_mask" ),
@@ -306,29 +321,26 @@ def loss_wrapper(student, teacher, batch):
306321 decoder_target_mask = batch .get ("targets_segmentation" , None ),
307322 cache = None ,
308323 )
309- # we should apply a mask for labels to disable segment-separator tokens
310324 labels = self .strategy .create_labels (batch ["targets" ], targets_segmentation = batch .get ("targets_segmentation" , None ))
311- return self .strategy .compute_loss (student_output , teacher_output , labels , step = current_step )
312-
313- # Because student is the 0th argument, argnums=0 guarantees
314- # we only compute gradients for the student.
315- grad_fn = nnx .value_and_grad (
316- loss_wrapper ,
317- argnums = nnx .DiffState (0 , self .wrt_filter ),
318- has_aux = True ,
319- )
325+ loss , aux = self .strategy .compute_loss (student_output , teacher_output , labels , step = current_step )
326+ # Capture updated non-param state (e.g. RNG counters) from local_student.
327+ _ , _ , new_rest = nnx .split (local_student , self .wrt_filter , ...)
328+ return loss , (aux , new_rest )
320329
321- out , grads = grad_fn (model .student_model , model .teacher_model , batch )
330+ grad_fn = jax .value_and_grad (loss_wrapper_pure , argnums = 0 , has_aux = True )
331+ (loss , (aux , new_rest )), grads = grad_fn (diff_params , rest )
322332
323- model .training_step .set_value (current_step + 1 )
333+ # Propagate updated non-param state back to student.
334+ nnx .update (student , new_rest )
324335
325- tunix_expects_grad_norm = getattr ( self , "_tunix_expects_grad_norm" , True )
336+ optimizer . update ( student , grads )
326337
327- optimizer . update ( model .student_model , grads )
338+ model .training_step . set_value ( current_step + 1 )
328339
340+ tunix_expects_grad_norm = getattr (self , "_tunix_expects_grad_norm" , True )
329341 if tunix_expects_grad_norm :
330- return out [ 0 ], out [ 1 ] , optax .global_norm (grads )
331- return out [ 0 ], out [ 1 ]
342+ return loss , aux , optax .global_norm (grads )
343+ return loss , aux
332344
333345 def _eval_step (self , model , inputs ):
334346 """Evaluation only needs the student."""
0 commit comments