Commit d825b68
committed
NNX: native DPO (TrainStateNNX.reference_model + dpo_loss_fn_nnx)
Implements NNX-native DPO so that the pure_nnx=True training path no
longer raises NotImplementedError on use_dpo runs. The Linen DPO
overlay pattern (model.apply(params=..., reference_params=...)) does
not translate to NNX modules, which carry their parameters internally.
Instead the policy and reference models are held as separate
nnx.Module instances on TrainStateNNX, and a new dpo_loss_fn_nnx
runs both forwards with stop_gradient on the reference logits.
TrainStateNNX:
- Add optional `reference_model: nnx.Module` field. apply_gradients
continues to update only `self.model`, leaving `self.reference_model`
bit-identical across steps.
dpo_utils.py:
- Add dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params,
reference_model, is_train=True). Signature mirrors the Linen
dpo_loss_fn so it slots into gradient_accumulation_loss_and_grad's
dispatcher (dropout_rng / params slots are unused for NNX; carried
for parity, and reference_model is passed as the single
extra_dpo_args entry). With nnx.value_and_grad(..., argnums=0) over
the policy, no gradient flows to the reference model's nnx.Param
leaves; the explicit jax.lax.stop_gradient on ref_logits is a
belt-and-braces guard.
- Both dpo_loss_fn (Linen) and dpo_loss_fn_nnx (NNX) now include
indexer_loss=0.0 and mtp_loss=0.0 in aux so the
gradient_accumulation aux pytree shape matches the non-DPO loss_fn.
train.py:
- Drop the NotImplementedError in train_step's NNX branch. When
use_dpo, dispatch to dpo_loss_fn_nnx with state.reference_model as
extra_dpo_args; otherwise use the regular loss_fn. eval_step gains
the same dispatch.
- diff_wrapper picks _loss_fn / extra_dpo_args from the per-path init
block, so both the GA and non-GA NNX paths route DPO identically.
- Checkpoint-save _split_dpo_state stripping is now Linen-only;
TrainStateNNX saves whole (reference_model included) — the step-0
reload later overwrites reference_model from the step-0 checkpoint.
train_utils.py:
- NNX init_state_fn materializes a frozen reference_model alongside
the policy when config.use_dpo. Both are constructed by
_create_model_partial() with config.init_weights_seed, so they
start identical (standard DPO practice) until the step-0 reload.
- Step-0 checkpoint reload: copy step0_state["model"] into
state["reference_model"]. Linen path unchanged.
Tests:
- New tests/unit/dpo_nnx_test.py (7 tests): TrainStateNNX
reference_model init/hasattr semantics; apply_gradients leaves
reference bit-identical; aux key set; identical policy/reference
yields loss=log(2) and reward_accuracy=0.0 (strict > on equal
logratios); dropout_rng/params slots are signature-compat only;
nnx.value_and_grad(argnums=0) over the policy yields finite grads
on policy params only.
- train_nnx_test.py: drop the two stale negative tests
(vocab_tiling_raises_not_implemented,
train_step_dpo_raises_for_nnx) — both features are now real.
Stats: 4 source files + 2 test files, +199/-22 source lines. Linen
DPO path behaviorally unchanged (only adds two harmless aux-dict
keys); NNX non-DPO path unchanged (all changes gated on
config.use_dpo).1 parent be19157 commit d825b68
8 files changed
Lines changed: 679 additions & 44 deletions
File tree
- src/maxtext
- layers
- trainers
- post_train/dpo
- pre_train
- utils
- tests
- integration
- unit
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
15 | | - | |
| 15 | + | |
16 | 16 | | |
17 | 17 | | |
18 | 18 | | |
| |||
25 | 25 | | |
26 | 26 | | |
27 | 27 | | |
28 | | - | |
| 28 | + | |
29 | 29 | | |
30 | | - | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
31 | 37 | | |
32 | 38 | | |
33 | | - | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
34 | 45 | | |
35 | 46 | | |
| 47 | + | |
| 48 | + | |
36 | 49 | | |
37 | 50 | | |
38 | 51 | | |
39 | 52 | | |
40 | 53 | | |
41 | | - | |
| 54 | + | |
| 55 | + | |
42 | 56 | | |
43 | 57 | | |
44 | 58 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
19 | 19 | | |
20 | 20 | | |
21 | 21 | | |
| 22 | + | |
| 23 | + | |
22 | 24 | | |
23 | 25 | | |
24 | 26 | | |
| |||
132 | 134 | | |
133 | 135 | | |
134 | 136 | | |
135 | | - | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
136 | 145 | | |
137 | 146 | | |
138 | 147 | | |
| |||
148 | 157 | | |
149 | 158 | | |
150 | 159 | | |
| 160 | + | |
| 161 | + | |
151 | 162 | | |
152 | 163 | | |
153 | 164 | | |
154 | 165 | | |
155 | 166 | | |
156 | 167 | | |
157 | 168 | | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
| 279 | + | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
| 284 | + | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
| 298 | + | |
| 299 | + | |
| 300 | + | |
| 301 | + | |
| 302 | + | |
| 303 | + | |
| 304 | + | |
| 305 | + | |
| 306 | + | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
59 | 59 | | |
60 | 60 | | |
61 | 61 | | |
62 | | - | |
| 62 | + | |
63 | 63 | | |
64 | 64 | | |
65 | 65 | | |
| |||
319 | 319 | | |
320 | 320 | | |
321 | 321 | | |
322 | | - | |
323 | | - | |
324 | | - | |
325 | | - | |
326 | | - | |
327 | | - | |
328 | | - | |
329 | 322 | | |
330 | | - | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
| 330 | + | |
331 | 331 | | |
332 | 332 | | |
333 | 333 | | |
| |||
393 | 393 | | |
394 | 394 | | |
395 | 395 | | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
396 | 401 | | |
397 | 402 | | |
398 | | - | |
| 403 | + | |
399 | 404 | | |
400 | 405 | | |
401 | 406 | | |
| |||
581 | 586 | | |
582 | 587 | | |
583 | 588 | | |
584 | | - | |
| 589 | + | |
| 590 | + | |
| 591 | + | |
| 592 | + | |
585 | 593 | | |
586 | 594 | | |
587 | 595 | | |
| |||
639 | 647 | | |
640 | 648 | | |
641 | 649 | | |
642 | | - | |
643 | | - | |
| 650 | + | |
| 651 | + | |
644 | 652 | | |
645 | 653 | | |
646 | 654 | | |
| |||
709 | 717 | | |
710 | 718 | | |
711 | 719 | | |
712 | | - | |
| 720 | + | |
713 | 721 | | |
714 | 722 | | |
715 | 723 | | |
| |||
756 | 764 | | |
757 | 765 | | |
758 | 766 | | |
759 | | - | |
| 767 | + | |
760 | 768 | | |
761 | 769 | | |
762 | 770 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
225 | 225 | | |
226 | 226 | | |
227 | 227 | | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
228 | 233 | | |
229 | 234 | | |
230 | 235 | | |
231 | | - | |
| 236 | + | |
| 237 | + | |
232 | 238 | | |
233 | 239 | | |
234 | 240 | | |
| |||
316 | 322 | | |
317 | 323 | | |
318 | 324 | | |
319 | | - | |
320 | | - | |
321 | 325 | | |
322 | 326 | | |
323 | 327 | | |
| |||
342 | 346 | | |
343 | 347 | | |
344 | 348 | | |
345 | | - | |
346 | | - | |
347 | | - | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
348 | 360 | | |
349 | 361 | | |
350 | 362 | | |
| |||
0 commit comments