Skip to content

Commit f809ea5

Browse files
committed
NNX: native DPO (TrainStateNNX.reference_model + dpo_loss_fn_nnx)
Implements NNX-native DPO so that the pure_nnx=True training path no longer raises NotImplementedError on use_dpo runs.

The Linen DPO overlay pattern (model.apply(params=..., reference_params=...)) does not translate to NNX modules, which carry their parameters internally. Instead, the policy and reference models are held as separate nnx.Module instances on TrainStateNNX, and a new dpo_loss_fn_nnx runs both forwards with stop_gradient on the reference logits.

TrainStateNNX:
- Add optional `reference_model: nnx.Module` field. apply_gradients continues to update only `self.model`, leaving `self.reference_model` bit-identical across steps.

dpo_utils.py:
- Add dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True). Signature mirrors the Linen dpo_loss_fn so it slots into gradient_accumulation_loss_and_grad's dispatcher (the dropout_rng / params slots are unused for NNX and carried only for parity; reference_model is passed as the single extra_dpo_args entry). With nnx.value_and_grad(..., argnums=0) over the policy, no gradient flows to the reference model's nnx.Param leaves; the explicit jax.lax.stop_gradient on ref_logits is a belt-and-braces guard.
- Both dpo_loss_fn (Linen) and dpo_loss_fn_nnx (NNX) now include indexer_loss=0.0 and mtp_loss=0.0 in aux so the gradient_accumulation aux pytree shape matches the non-DPO loss_fn.

train.py:
- Drop the NotImplementedError in train_step's NNX branch. When use_dpo, dispatch to dpo_loss_fn_nnx with state.reference_model as extra_dpo_args; otherwise use the regular loss_fn. eval_step gains the same dispatch.
- diff_wrapper picks _loss_fn / extra_dpo_args from the per-path init block, so both the GA and non-GA NNX paths route DPO identically.
- Checkpoint-save _split_dpo_state stripping is now Linen-only; TrainStateNNX is saved whole (reference_model included) — the step-0 reload later overwrites reference_model from the step-0 checkpoint.
train_utils.py:
- NNX init_state_fn materializes a frozen reference_model alongside the policy when config.use_dpo. Both are constructed by _create_model_partial() with config.init_weights_seed, so they start identical (standard DPO practice) until the step-0 reload.
- Step-0 checkpoint reload: copy step0_state["model"] into state["reference_model"]. Linen path unchanged.

Tests:
- New tests/unit/dpo_nnx_test.py (7 tests): TrainStateNNX reference_model init/hasattr semantics; apply_gradients leaves reference bit-identical; aux key set; identical policy/reference yields loss=log(2) and reward_accuracy=0.0 (strict > on equal logratios); dropout_rng/params slots are signature-compat only; nnx.value_and_grad(argnums=0) over the policy yields finite grads on policy params only.
- train_nnx_test.py: drop the two stale negative tests (vocab_tiling_raises_not_implemented, train_step_dpo_raises_for_nnx) — both features are now real.

Stats: 4 source files + 2 test files, +199/-22 source lines. Linen DPO path behaviorally unchanged (only adds two harmless aux-dict keys); NNX non-DPO path unchanged (all changes gated on config.use_dpo).
1 parent 3e84637 commit f809ea5

7 files changed

Lines changed: 412 additions & 43 deletions

File tree

src/maxtext/layers/train_state_nnx.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
""" The NNX Unified TrainState. """
15+
"""The NNX Unified TrainState."""
1616

1717
from typing import Any
1818

@@ -25,20 +25,34 @@ class TrainStateNNX(nnx.Module):
2525
This replaces Linen's TrainState for checkpointing.
2626
2727
Linen TrainState pytree:
28-
{params: {...}, opt_state: {}...}
28+
{"params": {...}, "opt_state": {}...}
2929
TrainStateNNX state pytree:
30-
{“model”: {...}, “optimizer”: {“opt_state”: {...}}
30+
{"model": {...}, "optimizer": {"opt_state": {...}}}
31+
32+
For DPO (Direct Preference Optimization), an optional `reference_model`
33+
carries a frozen copy of the same architecture used to compute reference
34+
log-probabilities. Only `model` is updated by `apply_gradients`; the
35+
reference is held alongside so it is sharded, jit-traced, and checkpointed
36+
with the rest of the train state.
3137
"""
3238

33-
def __init__(self, model: nnx.Module, optimizer: nnx.Optimizer | None):
39+
def __init__(
40+
self,
41+
model: nnx.Module,
42+
optimizer: nnx.Optimizer | None,
43+
reference_model: nnx.Module | None = None,
44+
):
3445
self.model = model
3546
self.optimizer = optimizer
47+
if reference_model is not None:
48+
self.reference_model = reference_model
3649

3750
def apply_gradients(self, grads: Any):
3851
"""
3952
Mimics the Linen apply_gradients function.
4053
Updates the optimizer state, applies updates to parameters,
41-
and increments the step counter.
54+
and increments the step counter. Only updates `self.model`;
55+
`self.reference_model` (if present) is left untouched.
4256
"""
4357
if self.optimizer is None:
4458
raise RuntimeError(

src/maxtext/trainers/post_train/dpo/dpo_utils.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import jax
2020
import jax.numpy as jnp
2121

22+
from flax import nnx
23+
2224
from maxtext.utils import maxtext_utils
2325

2426

@@ -148,10 +150,147 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t
148150
"total_weights": total_weights,
149151
"moe_lb_loss": moe_lb_loss,
150152
"reward_accuracy": reward_accuracy,
153+
"indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility
154+
"mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility
151155
}
152156
return loss, aux
153157

154158

155159
def _merge_dpo_state(state, reference_params):
156160
"""Merge reference parameters back into DPO state."""
157161
return state.replace(params=dict(state.params, reference_params=reference_params))
162+
163+
164+
# NNX DPO has no split/merge counterpart: the Linen path overlays
165+
# `reference_params` inside `state.params`, so it must be peeled off and
166+
# reattached around `apply_gradients`. The NNX path holds the reference as a
167+
# sibling field `TrainStateNNX.reference_model`; `apply_gradients` already
168+
# only touches `self.model`, so no split/merge is needed.
169+
170+
171+
def dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True):
172+
"""NNX DPO loss_fn for both train and eval.
173+
174+
Signature mirrors the Linen `dpo_loss_fn` so it slots into the same
175+
dispatcher in `gradient_accumulation_loss_and_grad`:
176+
`(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True)`
177+
178+
Differences from the Linen `dpo_loss_fn`:
179+
* `policy_model` is an `nnx.Module` (carries its own params + RNG state).
180+
* `dropout_rng` and `params` are unused for NNX (kept positional for
181+
signature parity; NNX models manage these internally).
182+
* The 6th arg (the `extra_dpo_args[0]`) is a frozen reference
183+
`nnx.Module`, not a `reference_params` pytree.
184+
* Reference logits are passed through `jax.lax.stop_gradient`; combined with
185+
`nnx.value_and_grad(..., argnums=0)` over the policy, no gradient flows
186+
to the reference's `nnx.Param` leaves.
187+
188+
Args:
189+
policy_model: Policy `nnx.Module` (the model being trained).
190+
config: Config of parameters.
191+
data: Batch of preference data with `chosen` / `rejected` fields.
192+
dropout_rng: Unused for NNX (kept for signature parity with Linen).
193+
params: Unused for NNX (kept for signature parity with Linen).
194+
reference_model: Frozen reference `nnx.Module` for DPO logratio computation.
195+
is_train: True for train_step and False for eval_step.
196+
197+
Returns:
198+
loss: DPO preference loss + MoE load balance loss (if applicable).
199+
aux: dict with intermediate_outputs, xent_sum (always 0.0), dpo_loss,
200+
total_weights, moe_lb_loss, reward_accuracy, indexer_loss and mtp_loss (both always 0.0).
201+
"""
202+
del dropout_rng, params # unused for NNX
203+
# decimate proportion of data when per_device_batch_size<1
204+
if is_train:
205+
for k, v in data.items():
206+
data[k] = v[: config.micro_batch_size_to_train_on, :]
207+
208+
# for DPO we don't support packed sequences (they shouldn't be present in the first place)
209+
data["chosen_segmentation"] = (data["chosen_segmentation"] == 1).astype(jnp.int32)
210+
data["rejected_segmentation"] = (data["rejected_segmentation"] == 1).astype(jnp.int32)
211+
data["chosen_position"] = data["chosen_position"] * (data["chosen_segmentation"] == 1)
212+
data["rejected_position"] = data["rejected_position"] * (data["rejected_segmentation"] == 1)
213+
214+
# concatenated policy/reference forward pass
215+
inputs = jnp.concatenate([data["chosen"], data["rejected"]], 0)
216+
inputs_position = jnp.concatenate([data["chosen_position"], data["rejected_position"]], 0)
217+
inputs_segmentation = jnp.concatenate([data["chosen_segmentation"], data["rejected_segmentation"]], 0)
218+
219+
logits = policy_model(
220+
decoder_input_tokens=inputs,
221+
decoder_positions=inputs_position,
222+
decoder_segment_ids=inputs_segmentation,
223+
enable_dropout=config.enable_dropout if is_train else False,
224+
)
225+
intermediate_outputs = nnx.state(policy_model, nnx.Intermediate).to_pure_dict()
226+
227+
ref_logits = reference_model(
228+
decoder_input_tokens=inputs,
229+
decoder_positions=inputs_position,
230+
decoder_segment_ids=inputs_segmentation,
231+
enable_dropout=False,
232+
)
233+
ref_logits = jax.lax.stop_gradient(ref_logits)
234+
235+
# extract token ids, segmentation and logits for chosen and rejected sequences
236+
chosen_ids = data["chosen"][..., 1:]
237+
rejected_ids = data["rejected"][..., 1:]
238+
chosen_segmentation = data["chosen_segmentation"][..., 1:]
239+
rejected_segmentation = data["rejected_segmentation"][..., 1:]
240+
n_logits = logits.shape[-3] // 2 # [B, S, E] - [batch, sequence, embedding/vocab]
241+
chosen_logits, rejected_logits = logits[:n_logits, :, :], logits[n_logits:, :, :]
242+
chosen_ref_logits, rejected_ref_logits = ref_logits[:n_logits, :, :], ref_logits[n_logits:, :, :]
243+
244+
# common subsequence and padding mask
245+
common_prefix_mask = jnp.cumsum(chosen_ids != rejected_ids, axis=-1) == 0 # [B, S]
246+
valid_seq_mask = (chosen_segmentation != 0) & (rejected_segmentation != 0) & ~common_prefix_mask # [B, S]
247+
248+
# compute logratios from the sequence-reduced observed token log-probability
249+
chosen_logps_seq = jnp.take_along_axis( # [B, S]
250+
jax.nn.log_softmax(chosen_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1
251+
)[..., 0]
252+
chosen_logps = jnp.sum(chosen_logps_seq * valid_seq_mask, axis=-1) # [B]
253+
chosen_ref_logps_seq = jnp.take_along_axis( # [B, S]
254+
jax.nn.log_softmax(chosen_ref_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1
255+
)[..., 0]
256+
chosen_ref_logps = jnp.sum(chosen_ref_logps_seq * valid_seq_mask, axis=-1) # [B]
257+
chosen_logratios = chosen_logps - chosen_ref_logps # [B]
258+
259+
rejected_logps_seq = jnp.take_along_axis( # [B, S]
260+
jax.nn.log_softmax(rejected_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1
261+
)[..., 0]
262+
rejected_logps = jnp.sum(rejected_logps_seq * valid_seq_mask, axis=-1) # [B]
263+
rejected_ref_logps_seq = jnp.take_along_axis( # [B, S]
264+
jax.nn.log_softmax(rejected_ref_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1
265+
)[..., 0]
266+
rejected_ref_logps = jnp.sum(rejected_ref_logps_seq * valid_seq_mask, axis=-1) # [B]
267+
rejected_logratios = rejected_logps - rejected_ref_logps # [B]
268+
269+
# DPO loss from chosen and rejected logratios
270+
LABEL_SMOOTHING, BETA = config.dpo_label_smoothing, config.dpo_beta
271+
logratios_delta = BETA * (chosen_logratios - rejected_logratios) # [B]
272+
losses = ( # [B]
273+
-jax.nn.log_sigmoid(BETA * logratios_delta) * (1 - LABEL_SMOOTHING)
274+
- jax.nn.log_sigmoid(-BETA * logratios_delta) * LABEL_SMOOTHING
275+
)
276+
total_loss, total_weights = jnp.mean(losses), losses.shape[0]
277+
loss = total_loss
278+
279+
moe_lb_loss = 0.0
280+
if config.num_experts > 1:
281+
moe_lb_losses = maxtext_utils.collect_intermediates_by_suffix(intermediate_outputs, "moe_lb_loss")
282+
if moe_lb_losses:
283+
moe_lb_loss = jnp.mean(jnp.concatenate(moe_lb_losses))
284+
loss += moe_lb_loss
285+
reward_accuracy = jnp.mean(chosen_logratios > rejected_logratios)
286+
aux = {
287+
"intermediate_outputs": intermediate_outputs,
288+
"xent_sum": 0.0, # DPO has no per-token cross-entropy sum; set to 0 for train_step compatibility
289+
"dpo_loss": total_loss, # pure preference loss before MoE lb, analogous to lm_loss in pre-training
290+
"total_weights": total_weights,
291+
"moe_lb_loss": moe_lb_loss,
292+
"reward_accuracy": reward_accuracy,
293+
"indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility
294+
"mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility
295+
}
296+
return loss, aux

src/maxtext/trainers/pre_train/train.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
from maxtext.common.gcloud_stub import vertex_tensorboard_modules
6262
from maxtext.common import metric_logger
6363
from maxtext.common.metric_logger import record_activation_metrics
64-
from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn
64+
from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn, dpo_loss_fn_nnx
6565
from maxtext.utils import exceptions
6666
from maxtext.utils import gcs_utils
6767
from maxtext.utils import max_logging
@@ -320,15 +320,15 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
320320
params = state.params
321321
ga_fn, ga_model, ga_params, ga_rng, ga_dpo = _loss_fn, model, params, dropout_rng, extra_dpo_args
322322
else:
323-
if config.use_dpo:
324-
raise NotImplementedError(
325-
"DPO is not yet supported for NNX modules. DPO requires a reference model "
326-
"stored alongside the policy model (Linen path uses state.params['reference_params']); "
327-
"the NNX TrainState equivalent has not been wired up. As a workaround, set "
328-
"pure_nnx=False for DPO runs."
329-
)
330323
state = nnx.merge(model, state) # reconstruct TrainStateNNX
331-
ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, []
324+
if config.use_dpo:
325+
# NNX DPO: reference_model is a sibling field on TrainStateNNX (set up by
326+
# the NNX `create_train_state_fn` when config.use_dpo=True). dpo_loss_fn_nnx mirrors
327+
# the Linen dpo_loss_fn signature, so it slots into the same dispatcher
328+
# with reference_model passed as the single extra_dpo_args entry.
329+
ga_fn, ga_model, ga_params, ga_rng, ga_dpo = (dpo_loss_fn_nnx, state.model, None, None, [state.reference_model])
330+
else:
331+
ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, []
332332

333333
# --- Gradient computation ---
334334
if config.gradient_accumulation_steps > 1:
@@ -394,9 +394,14 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
394394
)
395395
nnx.update(state.model, curr_params)
396396

397+
# `ga_fn` and `ga_dpo` were set up earlier (loss_fn vs dpo_loss_fn_nnx;
398+
# ga_dpo carries the frozen reference_model when use_dpo, else empty).
399+
_nnx_loss_fn = ga_fn
400+
_nnx_extra_dpo_args = ga_dpo
401+
397402
def diff_wrapper(param, rest, config, data):
398403
local_model = nnx.merge(model_graphdef, param, rest, copy=True)
399-
loss, aux = loss_fn(local_model, config, data, None, None, is_train=True)
404+
loss, aux = _nnx_loss_fn(local_model, config, data, None, None, *_nnx_extra_dpo_args, is_train=True)
400405
_, _, new_rest = nnx.split(local_model, nnx.Param, ...)
401406
return loss, (aux, new_rest)
402407

@@ -576,7 +581,10 @@ def eval_step(model, config, state, data, dropout_rng=None):
576581
loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, sparsity_state=batch_stats)
577582
else:
578583
state = nnx.merge(model, state) # reconstruct TrainStateNNX
579-
loss, aux = loss_fn(state.model, config, data, None, None, is_train=False)
584+
if config.use_dpo:
585+
loss, aux = dpo_loss_fn_nnx(state.model, config, data, None, None, state.reference_model, is_train=False)
586+
else:
587+
loss, aux = loss_fn(state.model, config, data, None, None, is_train=False)
580588

581589
mtp_acceptance_rate = 0.0
582590
if config.mtp_eval_target_module > 0:
@@ -702,7 +710,7 @@ def train_loop(config, recorder, state=None):
702710
step_time_delta = datetime.datetime.now() - last_step_completion
703711
last_step_completion = datetime.datetime.now()
704712

705-
state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0]
713+
state_to_save = state if not (config.use_dpo and not config.pure_nnx) else _split_dpo_state(state)[0]
706714
checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step)
707715

708716
if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step):
@@ -746,7 +754,7 @@ def train_loop(config, recorder, state=None):
746754
metric_logger_instance.buffer_and_write_train_metrics(metrics, step, step_time_delta)
747755

748756
if config.save_checkpoint_on_completion:
749-
state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0]
757+
state_to_save = state if not (config.use_dpo and not config.pure_nnx) else _split_dpo_state(state)[0]
750758
checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator)
751759
if checkpoint_manager is not None:
752760
# in case the last checkpoint_period checkpoint is still in progress

src/maxtext/utils/train_utils.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -225,10 +225,16 @@ def setup_train_loop(config, recorder, devices=None):
225225

226226
if config.pure_nnx:
227227
# For NNX, the train state is wrapped in the TrainStateNNX module.
228+
# When DPO is enabled, also materialize a frozen reference model alongside
229+
# the policy. Both are constructed by `_create_model_partial()` (which uses
230+
# `config.init_weights_seed`), so the reference starts identical to the
231+
# policy — standard DPO practice. The reference is later overwritten by
232+
# the step-0 checkpoint in `setup_post_setup_state` below.
228233
def create_train_state_fn():
229234
model = _create_model_partial()
230235
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
231-
return train_state_nnx.TrainStateNNX(model, optimizer)
236+
reference_model = _create_model_partial() if config.use_dpo else None
237+
return train_state_nnx.TrainStateNNX(model, optimizer, reference_model=reference_model)
232238

233239
init_state_fn = create_train_state_fn
234240
else:
@@ -316,8 +322,6 @@ def create_train_state_fn():
316322
maxtext_utils.print_shardings_params(state_params, state_mesh_shardings_params, mesh, logical_annotations_params)
317323

318324
if config.use_dpo:
319-
if config.pure_nnx:
320-
raise NotImplementedError("DPO is not supported yet by NNX models.")
321325
abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training)
322326
max_logging.log(
323327
"Restoring reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'"
@@ -342,9 +346,17 @@ def create_train_state_fn():
342346
except FileNotFoundError:
343347
step0_restored = None
344348
if step0_restored is not None:
345-
# TODO: For pure_nnx, the dpo state manipulation is different.
346-
reference_params = step0_restored["items"].params["params"]
347-
state = _merge_dpo_state(state, reference_params)
349+
if config.pure_nnx:
350+
# step0_restored["items"] is the flat nnx.State of the step-0 TrainStateNNX
351+
# (typically from a non-DPO pre-training run, so its top-level fields are
352+
# `model` and `optimizer` — no `reference_model`). Copy its `model` substate
353+
# into our current state's `reference_model` slot.
354+
step0_state = step0_restored["items"]
355+
step0_model_substate = step0_state["model"] if "model" in step0_state else step0_state
356+
state["reference_model"] = step0_model_substate
357+
else:
358+
reference_params = step0_restored["items"].params["params"]
359+
state = _merge_dpo_state(state, reference_params)
348360
else:
349361
max_logging.log(
350362
"Could not restore reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'"

tests/integration/setup_train_loop_nnx_test.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,6 @@ def test_pure_nnx_setup_param_only_split_matches_model(self):
126126

127127
del model
128128

129-
def test_pure_nnx_dpo_raises_not_implemented(self):
130-
"""The use_dpo branch (train_utils.py:319-320) must raise for NNX."""
131-
# use_dpo requires a few prerequisites; the simplest is to set the flag and
132-
# let setup_train_loop reach the NotImplementedError check before the more
133-
# involved DPO path runs.
134-
config = _tiny_nnx_pyconfig(use_dpo=True, packing=False)
135-
with self.assertRaises(NotImplementedError):
136-
setup_train_loop(config, recorder=None)
137-
138129

139130
if __name__ == "__main__":
140131
unittest.main()

0 commit comments

Comments
 (0)