1919import jax
2020import jax .numpy as jnp
2121
22+ from flax import nnx
23+
2224from maxtext .utils import maxtext_utils
2325
2426
@@ -132,7 +134,14 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t
132134 - jax .nn .log_sigmoid (- BETA * logratios_delta ) * LABEL_SMOOTHING
133135 )
134136 total_loss , total_weights = jnp .mean (losses ), losses .shape [0 ]
135- loss = total_loss
137+ # Under manual gradient accumulation, return the unnormalized sum: the accumulator
138+ # sums per-microbatch grads then divides once by total_weights, so a pre-normalized
139+ # mean would scale the gradient down by an extra microbatch-size factor. Tunix GA
140+ # expects a normalized per-step loss. Mirrors loss_fn in train.py.
141+ if config .gradient_accumulation_steps > 1 and not config .use_tunix_gradient_accumulation :
142+ loss = jnp .sum (losses )
143+ else :
144+ loss = total_loss
136145
137146 moe_lb_loss = 0.0
138147 if config .num_experts > 1 :
@@ -148,10 +157,157 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t
148157 "total_weights" : total_weights ,
149158 "moe_lb_loss" : moe_lb_loss ,
150159 "reward_accuracy" : reward_accuracy ,
160+ "indexer_loss" : 0.0 , # for gradient_accumulation aux pytree compatibility
161+ "mtp_loss" : 0.0 , # for gradient_accumulation aux pytree compatibility
151162 }
152163 return loss , aux
153164
154165
155166def _merge_dpo_state (state , reference_params ):
156167 """Merge reference parameters back into DPO state."""
157168 return state .replace (params = dict (state .params , reference_params = reference_params ))
169+
170+
171+ # NNX DPO has no split/merge counterpart: the Linen path overlays
172+ # `reference_params` inside `state.params`, so it must be peeled off and
173+ # reattached around `apply_gradients`. The NNX path holds the reference as a
174+ # sibling field `TrainStateNNX.reference_model`; `apply_gradients` already
175+ # only touches `self.model`, so no split/merge is needed.
176+
177+
178+ def dpo_loss_fn_nnx (policy_model , config , data , dropout_rng , params , reference_model , is_train = True ):
179+ """NNX DPO loss_fn for both train and eval.
180+
181+ Signature mirrors the Linen `dpo_loss_fn` so it slots into the same
182+ dispatcher in `gradient_accumulation_loss_and_grad`:
183+ `(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True)`
184+
185+ Differences from the Linen `dpo_loss_fn`:
186+ * `policy_model` is an `nnx.Module` (carries its own params + RNG state).
187+ * `dropout_rng` and `params` are unused for NNX (kept positional for
188+ signature parity; NNX models manage these internally).
189+ * The 6th arg (the `extra_dpo_args[0]`) is a frozen reference
190+ `nnx.Module`, not a `reference_params` pytree.
191+ * Reference forward is wrapped in `jax.lax.stop_gradient`; combined with
192+ `nnx.value_and_grad(..., argnums=0)` over the policy, no gradient flows
193+ to the reference's `nnx.Param` leaves.
194+
195+ Args:
196+ policy_model: Policy `nnx.Module` (the model being trained).
197+ config: Config of parameters.
198+ data: Batch of preference data with `chosen` / `rejected` fields.
199+ dropout_rng: Unused for NNX (kept for signature parity with Linen).
200+ params: Unused for NNX (kept for signature parity with Linen).
201+ reference_model: Frozen reference `nnx.Module` for DPO logratio computation.
202+ is_train: True for train_step and False for eval_step.
203+
204+ Returns:
205+ loss: DPO preference loss + MoE load balance loss (if applicable).
206+ aux: dict with intermediate_outputs, xent_sum (always 0.0), dpo_loss,
207+ total_weights, moe_lb_loss, reward_accuracy.
208+ """
209+ del dropout_rng , params # unused for NNX
210+ # decimate proportion of data when per_device_batch_size<1
211+ if is_train :
212+ for k , v in data .items ():
213+ data [k ] = v [: config .micro_batch_size_to_train_on , :]
214+
215+ # for DPO we don't support packed sequences (they shouldn't be present in the first place)
216+ data ["chosen_segmentation" ] = (data ["chosen_segmentation" ] == 1 ).astype (jnp .int32 )
217+ data ["rejected_segmentation" ] = (data ["rejected_segmentation" ] == 1 ).astype (jnp .int32 )
218+ data ["chosen_position" ] = data ["chosen_position" ] * (data ["chosen_segmentation" ] == 1 )
219+ data ["rejected_position" ] = data ["rejected_position" ] * (data ["rejected_segmentation" ] == 1 )
220+
221+ # concatenated policy/reference forward pass
222+ inputs = jnp .concatenate ([data ["chosen" ], data ["rejected" ]], 0 )
223+ inputs_position = jnp .concatenate ([data ["chosen_position" ], data ["rejected_position" ]], 0 )
224+ inputs_segmentation = jnp .concatenate ([data ["chosen_segmentation" ], data ["rejected_segmentation" ]], 0 )
225+
226+ logits = policy_model (
227+ decoder_input_tokens = inputs ,
228+ decoder_positions = inputs_position ,
229+ decoder_segment_ids = inputs_segmentation ,
230+ enable_dropout = config .enable_dropout if is_train else False ,
231+ )
232+ # pop (not snapshot) so sown Intermediates don't persist on the model across
233+ # microbatches during gradient accumulation; matches loss_fn in train.py.
234+ intermediates = nnx .pop (policy_model , nnx .Intermediate )
235+ intermediate_outputs = intermediates .to_pure_dict ()
236+
237+ ref_logits = reference_model (
238+ decoder_input_tokens = inputs ,
239+ decoder_positions = inputs_position ,
240+ decoder_segment_ids = inputs_segmentation ,
241+ enable_dropout = False ,
242+ )
243+ ref_logits = jax .lax .stop_gradient (ref_logits )
244+
245+ # extract token ids, segmentation and logits for chosen and rejected sequences
246+ chosen_ids = data ["chosen" ][..., 1 :]
247+ rejected_ids = data ["rejected" ][..., 1 :]
248+ chosen_segmentation = data ["chosen_segmentation" ][..., 1 :]
249+ rejected_segmentation = data ["rejected_segmentation" ][..., 1 :]
250+ n_logits = logits .shape [- 3 ] // 2 # [B, S, E] - [batch, sequence, embedding/vocab]
251+ chosen_logits , rejected_logits = logits [:n_logits , :, :], logits [n_logits :, :, :]
252+ chosen_ref_logits , rejected_ref_logits = ref_logits [:n_logits , :, :], ref_logits [n_logits :, :, :]
253+
254+ # common subsequence and padding mask
255+ common_prefix_mask = jnp .cumsum (chosen_ids != rejected_ids , axis = - 1 ) == 0 # [B, S]
256+ valid_seq_mask = (chosen_segmentation != 0 ) & (rejected_segmentation != 0 ) & ~ common_prefix_mask # [B, S]
257+
258+ # compute logratios from the sequence-reduced observed token log-probability
259+ chosen_logps_seq = jnp .take_along_axis ( # [B, S]
260+ jax .nn .log_softmax (chosen_logits [..., :- 1 , :], axis = - 1 ), chosen_ids [..., None ], axis = - 1
261+ )[..., 0 ]
262+ chosen_logps = jnp .sum (chosen_logps_seq * valid_seq_mask , axis = - 1 ) # [B]
263+ chosen_ref_logps_seq = jnp .take_along_axis ( # [B, S]
264+ jax .nn .log_softmax (chosen_ref_logits [..., :- 1 , :], axis = - 1 ), chosen_ids [..., None ], axis = - 1
265+ )[..., 0 ]
266+ chosen_ref_logps = jnp .sum (chosen_ref_logps_seq * valid_seq_mask , axis = - 1 ) # [B]
267+ chosen_logratios = chosen_logps - chosen_ref_logps # [B]
268+
269+ rejected_logps_seq = jnp .take_along_axis ( # [B, S]
270+ jax .nn .log_softmax (rejected_logits [..., :- 1 , :], axis = - 1 ), rejected_ids [..., None ], axis = - 1
271+ )[..., 0 ]
272+ rejected_logps = jnp .sum (rejected_logps_seq * valid_seq_mask , axis = - 1 ) # [B]
273+ rejected_ref_logps_seq = jnp .take_along_axis ( # [B, S]
274+ jax .nn .log_softmax (rejected_ref_logits [..., :- 1 , :], axis = - 1 ), rejected_ids [..., None ], axis = - 1
275+ )[..., 0 ]
276+ rejected_ref_logps = jnp .sum (rejected_ref_logps_seq * valid_seq_mask , axis = - 1 ) # [B]
277+ rejected_logratios = rejected_logps - rejected_ref_logps # [B]
278+
279+ # DPO loss from chosen and rejected logratios
280+ LABEL_SMOOTHING , BETA = config .dpo_label_smoothing , config .dpo_beta
281+ logratios_delta = BETA * (chosen_logratios - rejected_logratios ) # [B]
282+ losses = ( # [B]
283+ - jax .nn .log_sigmoid (BETA * logratios_delta ) * (1 - LABEL_SMOOTHING )
284+ - jax .nn .log_sigmoid (- BETA * logratios_delta ) * LABEL_SMOOTHING
285+ )
286+ total_loss , total_weights = jnp .mean (losses ), losses .shape [0 ]
287+ # Under manual gradient accumulation, return the unnormalized sum: the accumulator
288+ # sums per-microbatch grads then divides once by total_weights, so a pre-normalized
289+ # mean would scale the gradient down by an extra microbatch-size factor. Tunix GA
290+ # expects a normalized per-step loss. Mirrors loss_fn in train.py.
291+ if config .gradient_accumulation_steps > 1 and not config .use_tunix_gradient_accumulation :
292+ loss = jnp .sum (losses )
293+ else :
294+ loss = total_loss
295+
296+ moe_lb_loss = 0.0
297+ if config .num_experts > 1 :
298+ moe_lb_losses = maxtext_utils .collect_intermediates_by_suffix (intermediate_outputs , "moe_lb_loss" )
299+ if moe_lb_losses :
300+ moe_lb_loss = jnp .mean (jnp .concatenate (moe_lb_losses ))
301+ loss += moe_lb_loss
302+ reward_accuracy = jnp .mean (chosen_logratios > rejected_logratios )
303+ aux = {
304+ "intermediate_outputs" : intermediate_outputs ,
305+ "xent_sum" : 0.0 , # DPO has no per-token cross-entropy sum; set to 0 for train_step compatibility
306+ "dpo_loss" : total_loss , # pure preference loss before MoE lb, analogous to lm_loss in pre-training
307+ "total_weights" : total_weights ,
308+ "moe_lb_loss" : moe_lb_loss ,
309+ "reward_accuracy" : reward_accuracy ,
310+ "indexer_loss" : 0.0 , # for gradient_accumulation aux pytree compatibility
311+ "mtp_loss" : 0.0 , # for gradient_accumulation aux pytree compatibility
312+ }
313+ return loss , aux
0 commit comments