|
19 | 19 | import jax |
20 | 20 | import jax.numpy as jnp |
21 | 21 |
|
| 22 | +from flax import nnx |
| 23 | + |
22 | 24 | from maxtext.utils import maxtext_utils |
23 | 25 |
|
24 | 26 |
|
@@ -148,10 +150,147 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t |
148 | 150 | "total_weights": total_weights, |
149 | 151 | "moe_lb_loss": moe_lb_loss, |
150 | 152 | "reward_accuracy": reward_accuracy, |
| 153 | + "indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility |
| 154 | + "mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility |
151 | 155 | } |
152 | 156 | return loss, aux |
153 | 157 |
|
154 | 158 |
|
155 | 159 | def _merge_dpo_state(state, reference_params): |
156 | 160 | """Merge reference parameters back into DPO state.""" |
157 | 161 | return state.replace(params=dict(state.params, reference_params=reference_params)) |
| 162 | + |
| 163 | + |
| 164 | +# NNX DPO has no split/merge counterpart: the Linen path overlays |
| 165 | +# `reference_params` inside `state.params`, so it must be peeled off and |
| 166 | +# reattached around `apply_gradients`. The NNX path holds the reference as a |
| 167 | +# sibling field `TrainStateNNX.reference_model`; `apply_gradients` already |
| 168 | +# only touches `self.model`, so no split/merge is needed. |
| 169 | + |
| 170 | + |
| 171 | +def dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True): |
| 172 | + """NNX DPO loss_fn for both train and eval. |
| 173 | +
|
| 174 | + Signature mirrors the Linen `dpo_loss_fn` so it slots into the same |
| 175 | + dispatcher in `gradient_accumulation_loss_and_grad`: |
| 176 | + `(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True)` |
| 177 | +
|
| 178 | + Differences from the Linen `dpo_loss_fn`: |
| 179 | + * `policy_model` is an `nnx.Module` (carries its own params + RNG state). |
| 180 | + * `dropout_rng` and `params` are unused for NNX (kept positional for |
| 181 | + signature parity; NNX models manage these internally). |
| 182 | + * The 6th arg (the `extra_dpo_args[0]`) is a frozen reference |
| 183 | + `nnx.Module`, not a `reference_params` pytree. |
| 184 | + * Reference forward is wrapped in `jax.lax.stop_gradient`; combined with |
| 185 | + `nnx.value_and_grad(..., argnums=0)` over the policy, no gradient flows |
| 186 | + to the reference's `nnx.Param` leaves. |
| 187 | +
|
| 188 | + Args: |
| 189 | + policy_model: Policy `nnx.Module` (the model being trained). |
| 190 | + config: Config of parameters. |
| 191 | + data: Batch of preference data with `chosen` / `rejected` fields. |
| 192 | + dropout_rng: Unused for NNX (kept for signature parity with Linen). |
| 193 | + params: Unused for NNX (kept for signature parity with Linen). |
| 194 | + reference_model: Frozen reference `nnx.Module` for DPO logratio computation. |
| 195 | + is_train: True for train_step and False for eval_step. |
| 196 | +
|
| 197 | + Returns: |
| 198 | + loss: DPO preference loss + MoE load balance loss (if applicable). |
| 199 | + aux: dict with intermediate_outputs, xent_sum (always 0.0), dpo_loss, |
| 200 | + total_weights, moe_lb_loss, reward_accuracy. |
| 201 | + """ |
| 202 | + del dropout_rng, params # unused for NNX |
| 203 | + # decimate proportion of data when per_device_batch_size<1 |
| 204 | + if is_train: |
| 205 | + for k, v in data.items(): |
| 206 | + data[k] = v[: config.micro_batch_size_to_train_on, :] |
| 207 | + |
| 208 | + # for DPO we don't support packed sequences (they shouldn't be present in the first place) |
| 209 | + data["chosen_segmentation"] = (data["chosen_segmentation"] == 1).astype(jnp.int32) |
| 210 | + data["rejected_segmentation"] = (data["rejected_segmentation"] == 1).astype(jnp.int32) |
| 211 | + data["chosen_position"] = data["chosen_position"] * (data["chosen_segmentation"] == 1) |
| 212 | + data["rejected_position"] = data["rejected_position"] * (data["rejected_segmentation"] == 1) |
| 213 | + |
| 214 | + # concatenated policy/reference forward pass |
| 215 | + inputs = jnp.concatenate([data["chosen"], data["rejected"]], 0) |
| 216 | + inputs_position = jnp.concatenate([data["chosen_position"], data["rejected_position"]], 0) |
| 217 | + inputs_segmentation = jnp.concatenate([data["chosen_segmentation"], data["rejected_segmentation"]], 0) |
| 218 | + |
| 219 | + logits = policy_model( |
| 220 | + decoder_input_tokens=inputs, |
| 221 | + decoder_positions=inputs_position, |
| 222 | + decoder_segment_ids=inputs_segmentation, |
| 223 | + enable_dropout=config.enable_dropout if is_train else False, |
| 224 | + ) |
| 225 | + intermediate_outputs = nnx.state(policy_model, nnx.Intermediate).to_pure_dict() |
| 226 | + |
| 227 | + ref_logits = reference_model( |
| 228 | + decoder_input_tokens=inputs, |
| 229 | + decoder_positions=inputs_position, |
| 230 | + decoder_segment_ids=inputs_segmentation, |
| 231 | + enable_dropout=False, |
| 232 | + ) |
| 233 | + ref_logits = jax.lax.stop_gradient(ref_logits) |
| 234 | + |
| 235 | + # extract token ids, segmentation and logits for chosen and rejected sequences |
| 236 | + chosen_ids = data["chosen"][..., 1:] |
| 237 | + rejected_ids = data["rejected"][..., 1:] |
| 238 | + chosen_segmentation = data["chosen_segmentation"][..., 1:] |
| 239 | + rejected_segmentation = data["rejected_segmentation"][..., 1:] |
| 240 | + n_logits = logits.shape[-3] // 2 # [B, S, E] - [batch, sequence, embedding/vocab] |
| 241 | + chosen_logits, rejected_logits = logits[:n_logits, :, :], logits[n_logits:, :, :] |
| 242 | + chosen_ref_logits, rejected_ref_logits = ref_logits[:n_logits, :, :], ref_logits[n_logits:, :, :] |
| 243 | + |
| 244 | + # common subsequence and padding mask |
| 245 | + common_prefix_mask = jnp.cumsum(chosen_ids != rejected_ids, axis=-1) == 0 # [B, S] |
| 246 | + valid_seq_mask = (chosen_segmentation != 0) & (rejected_segmentation != 0) & ~common_prefix_mask # [B, S] |
| 247 | + |
| 248 | + # compute logratios from the sequence-reduced observed token log-probability |
| 249 | + chosen_logps_seq = jnp.take_along_axis( # [B, S] |
| 250 | + jax.nn.log_softmax(chosen_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1 |
| 251 | + )[..., 0] |
| 252 | + chosen_logps = jnp.sum(chosen_logps_seq * valid_seq_mask, axis=-1) # [B] |
| 253 | + chosen_ref_logps_seq = jnp.take_along_axis( # [B, S] |
| 254 | + jax.nn.log_softmax(chosen_ref_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1 |
| 255 | + )[..., 0] |
| 256 | + chosen_ref_logps = jnp.sum(chosen_ref_logps_seq * valid_seq_mask, axis=-1) # [B] |
| 257 | + chosen_logratios = chosen_logps - chosen_ref_logps # [B] |
| 258 | + |
| 259 | + rejected_logps_seq = jnp.take_along_axis( # [B, S] |
| 260 | + jax.nn.log_softmax(rejected_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1 |
| 261 | + )[..., 0] |
| 262 | + rejected_logps = jnp.sum(rejected_logps_seq * valid_seq_mask, axis=-1) # [B] |
| 263 | + rejected_ref_logps_seq = jnp.take_along_axis( # [B, S] |
| 264 | + jax.nn.log_softmax(rejected_ref_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1 |
| 265 | + )[..., 0] |
| 266 | + rejected_ref_logps = jnp.sum(rejected_ref_logps_seq * valid_seq_mask, axis=-1) # [B] |
| 267 | + rejected_logratios = rejected_logps - rejected_ref_logps # [B] |
| 268 | + |
| 269 | + # DPO loss from chosen and rejected logratios |
| 270 | + LABEL_SMOOTHING, BETA = config.dpo_label_smoothing, config.dpo_beta |
| 271 | + logratios_delta = BETA * (chosen_logratios - rejected_logratios) # [B] |
| 272 | + losses = ( # [B] |
| 273 | + -jax.nn.log_sigmoid(BETA * logratios_delta) * (1 - LABEL_SMOOTHING) |
| 274 | + - jax.nn.log_sigmoid(-BETA * logratios_delta) * LABEL_SMOOTHING |
| 275 | + ) |
| 276 | + total_loss, total_weights = jnp.mean(losses), losses.shape[0] |
| 277 | + loss = total_loss |
| 278 | + |
| 279 | + moe_lb_loss = 0.0 |
| 280 | + if config.num_experts > 1: |
| 281 | + moe_lb_losses = maxtext_utils.collect_intermediates_by_suffix(intermediate_outputs, "moe_lb_loss") |
| 282 | + if moe_lb_losses: |
| 283 | + moe_lb_loss = jnp.mean(jnp.concatenate(moe_lb_losses)) |
| 284 | + loss += moe_lb_loss |
| 285 | + reward_accuracy = jnp.mean(chosen_logratios > rejected_logratios) |
| 286 | + aux = { |
| 287 | + "intermediate_outputs": intermediate_outputs, |
| 288 | + "xent_sum": 0.0, # DPO has no per-token cross-entropy sum; set to 0 for train_step compatibility |
| 289 | + "dpo_loss": total_loss, # pure preference loss before MoE lb, analogous to lm_loss in pre-training |
| 290 | + "total_weights": total_weights, |
| 291 | + "moe_lb_loss": moe_lb_loss, |
| 292 | + "reward_accuracy": reward_accuracy, |
| 293 | + "indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility |
| 294 | + "mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility |
| 295 | + } |
| 296 | + return loss, aux |
0 commit comments