5454from maxtext .utils import model_creation_utils
5555
5656# Tunix Imports
57- from tunix .distillation import distillation_trainer
57+ from tunix .sft import peft_trainer
5858from tunix .sft import metrics_logger
5959from tunix .sft import profiler
6060
@@ -174,13 +174,99 @@ def _log_config_details(config: pyconfig.HyperParameters, label: str) -> None:
174174 max_logging .log (f" Checkpoint: { config .load_parameters_path } " )
175175
176176
177- class MaxTextDistillationTrainer (distillation_trainer .DistillationTrainer ):
class ModelBundle(nnx.Module):
  """Container pairing a frozen teacher module with a trainable student.

  Direct invocation is deliberately disabled: callers must choose a side via
  `call_student` or `call_teacher`, keeping the distillation roles explicit.
  """

  def __init__(self, teacher_model: nnx.Module, student_model: nnx.Module):
    self.teacher_model = teacher_model
    self.student_model = student_model

  def __call__(self, *args, **kwargs):
    # Ambiguous dispatch is an error by design; see class docstring.
    raise NotImplementedError("Use `call_student` or `call_teacher` explicitly.")

  def call_student(self, *args, **kwargs):
    """Forward pass through the student; gradients flow normally."""
    return self.student_model(*args, **kwargs)

  def call_teacher(self, *args, **kwargs):
    """Forward pass through the teacher with gradient flow blocked."""
    teacher_out = self.teacher_model(*args, **kwargs)
    return jax.lax.stop_gradient(teacher_out)
193+
194+ class MaxTextDistillationTrainer (peft_trainer .PeftTrainer ):
178195 """Custom Trainer to preserve MaxText fields and log Teacher metrics.
179196
180197 This class overrides `_prepare_inputs` to ensure MaxText-specific fields
181198 (positions, segment_ids) are passed to the model.
182199 """
183200
201+ def __init__ (self , model , strategy , optimizer , training_config , ** kwargs ):
202+ super ().__init__ (model = model , optimizer = optimizer , training_config = training_config , ** kwargs )
203+
204+ self .strategy = strategy
205+
206+ # override optimizer to only use student_model.
207+ wrt = nnx .LoRAParam if self ._lora_enabled else nnx .Param
208+ self .optimizer = nnx .Optimizer (model .student_model , optimizer , wrt = wrt )
209+
210+ def _train_step (self , model , optimizer , inputs ):
211+ """Overrides the main JIT block to natively handle ModelBundle module."""
212+
213+ batch = self .gen_model_input_fn (inputs )
214+
215+ def loss_wrapper (student , teacher , batch ):
216+ if "teacher_output" in batch :
217+ teacher_output = batch ["teacher_output" ]
218+ else :
219+ teacher_output = self .strategy .teacher_forward_fn (
220+ model = teacher ,
221+ input_tokens = batch ["input_tokens" ],
222+ positions = batch ["positions" ],
223+ attention_mask = batch .get ("attention_mask" ),
224+ decoder_segment_ids = batch .get ("decoder_segment_ids" ),
225+ cache = None ,
226+ )
227+
228+ teacher_output = jax .tree .map (jax .lax .stop_gradient , teacher_output )
229+
230+ student_output = self .strategy .student_forward_fn (
231+ model = student ,
232+ input_tokens = batch ["input_tokens" ],
233+ positions = batch ["positions" ],
234+ attention_mask = batch .get ("attention_mask" ),
235+ decoder_segment_ids = batch .get ("decoder_segment_ids" ),
236+ cache = None ,
237+ )
238+ labels = self .strategy .labels_fn (batch ["targets" ])
239+ return self .strategy .compute_loss (student_output , teacher_output , labels )
240+
241+ # Because student is the 0th argument, argnums=0 guarantees
242+ # we only compute gradients for the student.
243+ grad_fn = nnx .value_and_grad (
244+ loss_wrapper ,
245+ argnums = 0 ,
246+ has_aux = True ,
247+ )
248+
249+ out , grads = grad_fn (model .student_model , model .teacher_model , batch )
250+
251+ optimizer .update (model .student_model , grads )
252+
253+ return out [0 ], out [1 ] # loss, aux
254+
255+ def _eval_step (self , model , inputs ):
256+ """Evaluation only needs the student."""
257+ inputs = self .gen_model_input_fn (inputs )
258+
259+ student_output = self .strategy .student_forward_fn (
260+ model = model .student_model ,
261+ input_tokens = inputs ["input_tokens" ],
262+ positions = inputs ["positions" ],
263+ attention_mask = inputs .get ("attention_mask" ),
264+ decoder_segment_ids = inputs .get ("decoder_segment_ids" ),
265+ cache = None ,
266+ )
267+ labels = self .strategy .labels_fn (inputs ["targets" ])
268+ return self .strategy .compute_eval_loss (student_output , labels )
269+
184270 def _prepare_inputs (
185271 self , input_data : distillation_utils .MaxTextTrainingInput
186272 ) -> distillation_utils .MaxTextTrainingInput :
@@ -195,22 +281,12 @@ def _prepare_inputs(
195281 Returns:
196282 A new MaxTextTrainingInput containing the Teacher's outputs (logits).
197283 """
198- # 1. Generate inputs dictionary for the Teacher model
199- inputs = self .gen_model_input_fn (input_data )["inputs" ]
200-
201- if self ._mode == metrics_logger .Mode .EVAL :
202- teacher_output = None
203- else :
204- # 2. Run Teacher to get soft targets (logits)
205- # The strategy ensures these are stop_gradient-ed
206- teacher_output = self .strategy .get_teacher_outputs (self .teacher_model , inputs )
207284
208285 # 3. Return extended object so fields are available for Student training step
209286 # pylint: disable=unexpected-keyword-arg
210287 return distillation_utils .MaxTextTrainingInput (
211288 input_tokens = input_data .input_tokens ,
212289 input_mask = input_data .input_mask ,
213- teacher_output = teacher_output ,
214290 positions = input_data .positions ,
215291 decoder_segment_ids = input_data .decoder_segment_ids ,
216292 targets = input_data .targets ,
@@ -380,8 +456,6 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
380456 sft_mode = student_config .use_sft ,
381457 )
382458
383- student_model , teacher_model = strategy .pre_process_models (student_model , teacher_model )
384-
385459 # 4. Optimizer & Config
386460 optimizer = get_distillation_optimizer (student_config , student_config .steps )
387461
@@ -405,7 +479,7 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
405479 log_dir = student_config .tensorboard_dir , flush_every_n_steps = student_config .log_period
406480 )
407481
408- train_config = distillation_trainer .TrainingConfig (
482+ train_config = peft_trainer .TrainingConfig (
409483 max_steps = student_config .steps ,
410484 eval_every_n_steps = student_config .eval_interval ,
411485 metrics_logging_options = metrics_logging_options ,
@@ -419,10 +493,14 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
419493 max_logging .log ("Initializing Data Iterators via MaxText pipeline..." )
420494 raw_train_iter , raw_eval_iter = input_pipeline_interface .create_data_iterator (student_config , mesh )
421495
496+ teacher_model .eval ()
497+ student_model .train ()
498+
499+ model_bundle = ModelBundle (teacher_model , student_model )
500+
422501 # 6. Initialize Trainer
423502 trainer = MaxTextDistillationTrainer (
424- student_model = student_model ,
425- teacher_model = teacher_model ,
503+ model = model_bundle ,
426504 strategy = strategy ,
427505 optimizer = optimizer ,
428506 training_config = train_config ,
@@ -472,7 +550,10 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
472550 max_logging .log (f"Saving final checkpoint to { student_config .checkpoint_dir } ..." )
473551 try :
474552 saved = trainer .checkpoint_manager .save (
475- trainer .train_steps , trainer .model , save_only_lora_params = getattr (trainer , "_lora_enabled" , False ), force = True
553+ trainer .train_steps ,
554+ trainer .model .student_model ,
555+ save_only_lora_params = getattr (trainer , "_lora_enabled" , False ),
556+ force = True ,
476557 )
477558 if saved :
478559 # Ensure underlying orbax manager finishes writing
0 commit comments