@@ -51,9 +51,9 @@ def main():
5151 lora_rank = config .get ("lora_rank" , 16 )
5252 lora_scale = config .get ("lora_scale" , 2.0 )
5353
54- # Try mlx-tune first (has DPO/SimPO support), fall back to manual implementation
54+ # Try mlx-tune first (has DPO/SimPO support with MoE gradient handling)
5555 try :
56- from mlx_tune import SimPOTrainer , DPOTrainer , TrainingArguments
56+ import mlx_tune # noqa: F401 — just check availability
5757 _train_with_mlx_tune (
5858 model_path = model_path ,
5959 dataset_path = dataset_path ,
@@ -92,8 +92,128 @@ def main():
9292
9393
9494def _train_with_mlx_tune (** kwargs ):
95- """Train using the mlx-tune package (preferred, has native SimPO)."""
96- raise ImportError ("mlx-tune integration not yet wired" )
95+ """Train using the mlx-tune package (preferred, has native SimPO for MoE)."""
96+ from mlx_tune import (
97+ FastVisionModel ,
98+ SimPOConfig , SimPOTrainer ,
99+ DPOConfig , DPOTrainer ,
100+ )
101+
102+ model_path = kwargs ["model_path" ]
103+ dataset_path = kwargs ["dataset_path" ]
104+ adapter_out = kwargs ["adapter_out" ]
105+ resume_adapter = kwargs .get ("resume_adapter" )
106+ method = kwargs .get ("method" , "simpo" )
107+ beta = kwargs .get ("beta" , 0.1 )
108+ gamma = kwargs .get ("gamma" , 1.0 )
109+ learning_rate = kwargs .get ("learning_rate" , 1e-6 )
110+ batch_size = kwargs .get ("batch_size" , 1 )
111+ max_steps = kwargs .get ("max_steps" , 500 )
112+ max_seq_len = kwargs .get ("max_seq_len" , 2048 )
113+ lora_rank = kwargs .get ("lora_rank" , 16 )
114+ lora_scale = kwargs .get ("lora_scale" , 2.0 )
115+
116+ sys .stderr .write (f"Loading model via mlx-tune: { model_path } \n " )
117+ sys .stderr .flush ()
118+
119+ # Gemma 4 is treated as VLM in mlx-tune
120+ model , tokenizer = FastVisionModel .from_pretrained (model_path )
121+
122+ # Load SFT adapter weights if provided
123+ if resume_adapter :
124+ import mlx .core as mx
125+ adapter_file = os .path .join (resume_adapter , "adapters.safetensors" )
126+ if os .path .exists (adapter_file ):
127+ sys .stderr .write (f"Loading SFT adapter from { adapter_file } \n " )
128+ weights = mx .load (adapter_file )
129+ model .load_weights (list (weights .items ()))
130+
131+ # Prepare preference dataset from our JSONL format
132+ # mlx-tune expects {"prompt": ..., "chosen": ..., "rejected": ...}
133+ import json as _json
134+ pairs = []
135+ with open (dataset_path ) as f :
136+ for line in f :
137+ line = line .strip ()
138+ if line :
139+ pair = _json .loads (line )
140+ pairs .append ({
141+ "prompt" : pair ["prompt" ],
142+ "chosen" : pair ["chosen" ],
143+ "rejected" : pair ["rejected" ],
144+ })
145+
146+ sys .stderr .write (f"Loaded { len (pairs )} preference pairs\n " )
147+ sys .stderr .write (f"Method: { method } , beta={ beta } , gamma={ gamma } , lr={ learning_rate } \n " )
148+ sys .stderr .flush ()
149+
150+ # Write temp dataset file for mlx-tune
151+ import tempfile
152+ with tempfile .NamedTemporaryFile (mode = 'w' , suffix = '.jsonl' , delete = False ) as tf :
153+ for pair in pairs :
154+ tf .write (_json .dumps (pair ) + "\n " )
155+ temp_dataset = tf .name
156+
157+ try :
158+ if method == "simpo" :
159+ config = SimPOConfig (
160+ beta = beta ,
161+ gamma = gamma ,
162+ output_dir = adapter_out ,
163+ learning_rate = learning_rate ,
164+ per_device_train_batch_size = batch_size ,
165+ max_steps = max_steps ,
166+ max_seq_length = max_seq_len ,
167+ logging_steps = 10 ,
168+ save_steps = max_steps , # Save only at end
169+ warmup_steps = min (10 , max_steps // 10 ),
170+ )
171+ trainer = SimPOTrainer (
172+ model = model ,
173+ tokenizer = tokenizer ,
174+ args = config ,
175+ train_dataset = temp_dataset ,
176+ )
177+ else :
178+ config = DPOConfig (
179+ beta = beta ,
180+ output_dir = adapter_out ,
181+ learning_rate = learning_rate ,
182+ per_device_train_batch_size = batch_size ,
183+ max_steps = max_steps ,
184+ max_seq_length = max_seq_len ,
185+ logging_steps = 10 ,
186+ save_steps = max_steps ,
187+ warmup_steps = min (10 , max_steps // 10 ),
188+ )
189+ trainer = DPOTrainer (
190+ model = model ,
191+ tokenizer = tokenizer ,
192+ args = config ,
193+ train_dataset = temp_dataset ,
194+ )
195+
196+ # Train — mlx-tune handles MoE gradient routing correctly
197+ trainer .train ()
198+
199+ sys .stderr .write (f"Training complete. Adapter saved to { adapter_out } \n " )
200+ sys .stderr .flush ()
201+
202+ # Write final progress line for Rust to parse
203+ progress = {
204+ "step" : max_steps ,
205+ "total_steps" : max_steps ,
206+ "loss" : 0.0 ,
207+ "learning_rate" : learning_rate ,
208+ "chosen_reward" : 0.0 ,
209+ "rejected_reward" : 0.0 ,
210+ "reward_margin" : 0.0 ,
211+ }
212+ sys .stdout .write (_json .dumps (progress ) + "\n " )
213+ sys .stdout .flush ()
214+
215+ finally :
216+ os .unlink (temp_dataset )
97217
98218
99219def _train_builtin (** kwargs ):
@@ -137,11 +257,12 @@ def _train_builtin(**kwargs):
137257 lora_config = {"rank" : lora_rank , "scale" : lora_scale , "dropout" : 0.0 }
138258 linear_to_lora_layers (model , 16 , lora_config )
139259
140- # Freeze non-LoRA parameters
141- model .freeze ()
142- for name , param in model .named_parameters ():
143- if "lora" in name .lower ():
144- param .requires_grad = True
260+ # LoRA layers are already trainable from load() with adapter_path.
261+ # Count trainable params for logging.
262+ trainable = model .trainable_parameters ()
263+ n_trainable = sum (p .size for _ , p in nn .utils .tree_flatten (trainable ))
264+ sys .stderr .write (f"Trainable parameters: { n_trainable :,} \n " )
265+ sys .stderr .flush ()
145266
146267 # Load preference dataset
147268 pairs = []
@@ -210,12 +331,11 @@ def _train_builtin(**kwargs):
210331 )
211332 sys .stderr .flush ()
212333
213- # Save adapter
334+ # Save adapter — extract LoRA weights from the parameter tree
214335 os .makedirs (adapter_out , exist_ok = True )
215- # Save only LoRA weights
216- lora_weights = {
217- k : v for k , v in model .parameters ().items () if "lora" in k .lower ()
218- }
336+ lora_weights = {}
337+ for name , param in nn .utils .tree_flatten (model .trainable_parameters ()):
338+ lora_weights [name ] = param
219339 mx .save_safetensors (os .path .join (adapter_out , "adapters.safetensors" ), lora_weights )
220340
221341 sys .stderr .write (f"Adapter saved to { adapter_out } \n " )
@@ -265,14 +385,27 @@ def loss_fn(model):
265385 margin = chosen_avg_logp - rejected_avg_logp
266386 loss = - mx .log (mx .sigmoid (beta * margin ))
267387
268- return loss , {
269- "chosen_reward" : chosen_avg_logp ,
270- "rejected_reward" : rejected_avg_logp ,
271- "margin" : margin ,
272- }
273-
274- # Compute loss and gradients
275- (loss_val , metrics ), grads = nn .value_and_grad (model , lambda m : loss_fn (m ))(model )
388+ return loss
389+
390+ # nn.value_and_grad returns a function that computes (loss, grads)
391+ loss_and_grad_fn = nn .value_and_grad (model , loss_fn )
392+ loss_val , grads = loss_and_grad_fn (model )
393+
394+ # Compute metrics from a separate forward pass (cheap, no grad graph)
395+ chosen_logits = model (chosen_ids [None , :- 1 ])
396+ chosen_lp = - nn .losses .cross_entropy (
397+ chosen_logits .squeeze (0 ), chosen_ids [1 :], reduction = "none"
398+ ).mean ()
399+ rejected_logits = model (rejected_ids [None , :- 1 ])
400+ rejected_lp = - nn .losses .cross_entropy (
401+ rejected_logits .squeeze (0 ), rejected_ids [1 :], reduction = "none"
402+ ).mean ()
403+
404+ metrics = {
405+ "chosen_reward" : chosen_lp ,
406+ "rejected_reward" : rejected_lp ,
407+ "margin" : chosen_lp - rejected_lp ,
408+ }
276409
277410 return loss_val , grads , metrics
278411
0 commit comments