Skip to content

Commit fe529ee

Browse files
Merge pull request #3773 from AI-Hypercomputer:feat/nnx-native-dpo
PiperOrigin-RevId: 921700027
2 parents d11185b + d825b68 commit fe529ee

8 files changed

Lines changed: 680 additions & 44 deletions

File tree

src/maxtext/layers/train_state_nnx.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
""" The NNX Unified TrainState. """
15+
"""The NNX Unified TrainState."""
1616

1717
from typing import Any
1818

@@ -25,20 +25,34 @@ class TrainStateNNX(nnx.Module):
2525
This replaces Linen's TrainState for checkpointing.
2626
2727
Linen TrainState pytree:
28-
{params: {...}, opt_state: {}...}
28+
{"params": {...}, "opt_state": {}...}
2929
TrainStateNNX state pytree:
30-
{“model”: {...}, “optimizer”: {“opt_state”: {...}}
30+
{"model": {...}, "optimizer": {"opt_state": {...}}}
31+
32+
For DPO (Direct Preference Optimization), an optional `reference_model`
33+
carries a frozen copy of the same architecture used to compute reference
34+
log-probabilities. Only `model` is updated by `apply_gradients`; the
35+
reference is held alongside so it is sharded, jit-traced, and checkpointed
36+
with the rest of the train state.
3137
"""
3238

33-
def __init__(self, model: nnx.Module, optimizer: nnx.Optimizer | None):
39+
def __init__(
40+
self,
41+
model: nnx.Module,
42+
optimizer: nnx.Optimizer | None,
43+
reference_model: nnx.Module | None = None,
44+
):
3445
self.model = model
3546
self.optimizer = optimizer
47+
if reference_model is not None:
48+
self.reference_model = reference_model
3649

3750
def apply_gradients(self, grads: Any):
3851
"""
3952
Mimics the Linen apply_gradients function.
4053
Updates the optimizer state, applies updates to parameters,
41-
and increments the step counter.
54+
and increments the step counter. Only updates `self.model`;
55+
`self.reference_model` (if present) is left untouched.
4256
"""
4357
if self.optimizer is None:
4458
raise RuntimeError(

src/maxtext/trainers/post_train/dpo/dpo_utils.py

Lines changed: 157 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import jax
2020
import jax.numpy as jnp
2121

22+
from flax import nnx
23+
2224
from maxtext.utils import maxtext_utils
2325

2426

@@ -132,7 +134,14 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t
132134
- jax.nn.log_sigmoid(-BETA * logratios_delta) * LABEL_SMOOTHING
133135
)
134136
total_loss, total_weights = jnp.mean(losses), losses.shape[0]
135-
loss = total_loss
137+
# Under manual gradient accumulation, return the unnormalized sum: the accumulator
138+
# sums per-microbatch grads then divides once by total_weights, so a pre-normalized
139+
# mean would scale the gradient down by an extra microbatch-size factor. Tunix GA
140+
# expects a normalized per-step loss. Mirrors loss_fn in train.py.
141+
if config.gradient_accumulation_steps > 1 and not config.use_tunix_gradient_accumulation:
142+
loss = jnp.sum(losses)
143+
else:
144+
loss = total_loss
136145

137146
moe_lb_loss = 0.0
138147
if config.num_experts > 1:
@@ -148,10 +157,157 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t
148157
"total_weights": total_weights,
149158
"moe_lb_loss": moe_lb_loss,
150159
"reward_accuracy": reward_accuracy,
160+
"indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility
161+
"mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility
151162
}
152163
return loss, aux
153164

154165

155166
def _merge_dpo_state(state, reference_params):
156167
"""Merge reference parameters back into DPO state."""
157168
return state.replace(params=dict(state.params, reference_params=reference_params))
169+
170+
171+
# NNX DPO has no split/merge counterpart: the Linen path overlays
172+
# `reference_params` inside `state.params`, so it must be peeled off and
173+
# reattached around `apply_gradients`. The NNX path holds the reference as a
174+
# sibling field `TrainStateNNX.reference_model`; `apply_gradients` already
175+
# only touches `self.model`, so no split/merge is needed.
176+
177+
178+
def dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True):
179+
"""NNX DPO loss_fn for both train and eval.
180+
181+
Signature mirrors the Linen `dpo_loss_fn` so it slots into the same
182+
dispatcher in `gradient_accumulation_loss_and_grad`:
183+
`(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True)`
184+
185+
Differences from the Linen `dpo_loss_fn`:
186+
* `policy_model` is an `nnx.Module` (carries its own params + RNG state).
187+
* `dropout_rng` and `params` are unused for NNX (kept positional for
188+
signature parity; NNX models manage these internally).
189+
* The 6th arg (the `extra_dpo_args[0]`) is a frozen reference
190+
`nnx.Module`, not a `reference_params` pytree.
191+
* Reference forward is wrapped in `jax.lax.stop_gradient`; combined with
192+
`nnx.value_and_grad(..., argnums=0)` over the policy, no gradient flows
193+
to the reference's `nnx.Param` leaves.
194+
195+
Args:
196+
policy_model: Policy `nnx.Module` (the model being trained).
197+
config: Config of parameters.
198+
data: Batch of preference data with `chosen` / `rejected` fields.
199+
dropout_rng: Unused for NNX (kept for signature parity with Linen).
200+
params: Unused for NNX (kept for signature parity with Linen).
201+
reference_model: Frozen reference `nnx.Module` for DPO logratio computation.
202+
is_train: True for train_step and False for eval_step.
203+
204+
Returns:
205+
loss: DPO preference loss + MoE load balance loss (if applicable).
206+
aux: dict with intermediate_outputs, xent_sum (always 0.0), dpo_loss,
207+
total_weights, moe_lb_loss, reward_accuracy.
208+
"""
209+
del dropout_rng, params # unused for NNX
210+
# decimate proportion of data when per_device_batch_size<1
211+
if is_train:
212+
for k, v in data.items():
213+
data[k] = v[: config.micro_batch_size_to_train_on, :]
214+
215+
# for DPO we don't support packed sequences (they shouldn't be present in the first place)
216+
data["chosen_segmentation"] = (data["chosen_segmentation"] == 1).astype(jnp.int32)
217+
data["rejected_segmentation"] = (data["rejected_segmentation"] == 1).astype(jnp.int32)
218+
data["chosen_position"] = data["chosen_position"] * (data["chosen_segmentation"] == 1)
219+
data["rejected_position"] = data["rejected_position"] * (data["rejected_segmentation"] == 1)
220+
221+
# concatenated policy/reference forward pass
222+
inputs = jnp.concatenate([data["chosen"], data["rejected"]], 0)
223+
inputs_position = jnp.concatenate([data["chosen_position"], data["rejected_position"]], 0)
224+
inputs_segmentation = jnp.concatenate([data["chosen_segmentation"], data["rejected_segmentation"]], 0)
225+
226+
logits = policy_model(
227+
decoder_input_tokens=inputs,
228+
decoder_positions=inputs_position,
229+
decoder_segment_ids=inputs_segmentation,
230+
enable_dropout=config.enable_dropout if is_train else False,
231+
)
232+
# pop (not snapshot) so sown Intermediates don't persist on the model across
233+
# microbatches during gradient accumulation; matches loss_fn in train.py.
234+
intermediates = nnx.pop(policy_model, nnx.Intermediate)
235+
intermediate_outputs = intermediates.to_pure_dict()
236+
237+
ref_logits = reference_model(
238+
decoder_input_tokens=inputs,
239+
decoder_positions=inputs_position,
240+
decoder_segment_ids=inputs_segmentation,
241+
enable_dropout=False,
242+
)
243+
ref_logits = jax.lax.stop_gradient(ref_logits)
244+
245+
# extract token ids, segmentation and logits for chosen and rejected sequences
246+
chosen_ids = data["chosen"][..., 1:]
247+
rejected_ids = data["rejected"][..., 1:]
248+
chosen_segmentation = data["chosen_segmentation"][..., 1:]
249+
rejected_segmentation = data["rejected_segmentation"][..., 1:]
250+
n_logits = logits.shape[-3] // 2 # [B, S, E] - [batch, sequence, embedding/vocab]
251+
chosen_logits, rejected_logits = logits[:n_logits, :, :], logits[n_logits:, :, :]
252+
chosen_ref_logits, rejected_ref_logits = ref_logits[:n_logits, :, :], ref_logits[n_logits:, :, :]
253+
254+
# common subsequence and padding mask
255+
common_prefix_mask = jnp.cumsum(chosen_ids != rejected_ids, axis=-1) == 0 # [B, S]
256+
valid_seq_mask = (chosen_segmentation != 0) & (rejected_segmentation != 0) & ~common_prefix_mask # [B, S]
257+
258+
# compute logratios from the sequence-reduced observed token log-probability
259+
chosen_logps_seq = jnp.take_along_axis( # [B, S]
260+
jax.nn.log_softmax(chosen_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1
261+
)[..., 0]
262+
chosen_logps = jnp.sum(chosen_logps_seq * valid_seq_mask, axis=-1) # [B]
263+
chosen_ref_logps_seq = jnp.take_along_axis( # [B, S]
264+
jax.nn.log_softmax(chosen_ref_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1
265+
)[..., 0]
266+
chosen_ref_logps = jnp.sum(chosen_ref_logps_seq * valid_seq_mask, axis=-1) # [B]
267+
chosen_logratios = chosen_logps - chosen_ref_logps # [B]
268+
269+
rejected_logps_seq = jnp.take_along_axis( # [B, S]
270+
jax.nn.log_softmax(rejected_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1
271+
)[..., 0]
272+
rejected_logps = jnp.sum(rejected_logps_seq * valid_seq_mask, axis=-1) # [B]
273+
rejected_ref_logps_seq = jnp.take_along_axis( # [B, S]
274+
jax.nn.log_softmax(rejected_ref_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1
275+
)[..., 0]
276+
rejected_ref_logps = jnp.sum(rejected_ref_logps_seq * valid_seq_mask, axis=-1) # [B]
277+
rejected_logratios = rejected_logps - rejected_ref_logps # [B]
278+
279+
# DPO loss from chosen and rejected logratios
280+
LABEL_SMOOTHING, BETA = config.dpo_label_smoothing, config.dpo_beta
281+
logratios_delta = BETA * (chosen_logratios - rejected_logratios) # [B]
282+
losses = ( # [B]
283+
-jax.nn.log_sigmoid(BETA * logratios_delta) * (1 - LABEL_SMOOTHING)
284+
- jax.nn.log_sigmoid(-BETA * logratios_delta) * LABEL_SMOOTHING
285+
)
286+
total_loss, total_weights = jnp.mean(losses), losses.shape[0]
287+
# Under manual gradient accumulation, return the unnormalized sum: the accumulator
288+
# sums per-microbatch grads then divides once by total_weights, so a pre-normalized
289+
# mean would scale the gradient down by an extra microbatch-size factor. Tunix GA
290+
# expects a normalized per-step loss. Mirrors loss_fn in train.py.
291+
if config.gradient_accumulation_steps > 1 and not config.use_tunix_gradient_accumulation:
292+
loss = jnp.sum(losses)
293+
else:
294+
loss = total_loss
295+
296+
moe_lb_loss = 0.0
297+
if config.num_experts > 1:
298+
moe_lb_losses = maxtext_utils.collect_intermediates_by_suffix(intermediate_outputs, "moe_lb_loss")
299+
if moe_lb_losses:
300+
moe_lb_loss = jnp.mean(jnp.concatenate(moe_lb_losses))
301+
loss += moe_lb_loss
302+
reward_accuracy = jnp.mean(chosen_logratios > rejected_logratios)
303+
aux = {
304+
"intermediate_outputs": intermediate_outputs,
305+
"xent_sum": 0.0, # DPO has no per-token cross-entropy sum; set to 0 for train_step compatibility
306+
"dpo_loss": total_loss, # pure preference loss before MoE lb, analogous to lm_loss in pre-training
307+
"total_weights": total_weights,
308+
"moe_lb_loss": moe_lb_loss,
309+
"reward_accuracy": reward_accuracy,
310+
"indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility
311+
"mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility
312+
}
313+
return loss, aux

src/maxtext/trainers/pre_train/train.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
from maxtext.common.gcloud_stub import vertex_tensorboard_modules
6060
from maxtext.common import metric_logger
6161
from maxtext.common.metric_logger import record_activation_metrics
62-
from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn
62+
from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn, dpo_loss_fn_nnx
6363
from maxtext.utils import exceptions
6464
from maxtext.utils import gcs_utils
6565
from maxtext.utils import max_logging
@@ -319,15 +319,15 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
319319
params = state.params
320320
ga_fn, ga_model, ga_params, ga_rng, ga_dpo = _loss_fn, model, params, dropout_rng, extra_dpo_args
321321
else:
322-
if config.use_dpo:
323-
raise NotImplementedError(
324-
"DPO is not yet supported for NNX modules. DPO requires a reference model "
325-
"stored alongside the policy model (Linen path uses state.params['reference_params']); "
326-
"the NNX TrainState equivalent has not been wired up. As a workaround, set "
327-
"pure_nnx=False for DPO runs."
328-
)
329322
state = nnx.merge(model, state) # reconstruct TrainStateNNX
330-
ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, []
323+
if config.use_dpo:
324+
# NNX DPO: reference_model is a sibling field on TrainStateNNX (set up by
325+
# init_initial_state when config.use_dpo=True). dpo_loss_fn_nnx mirrors
326+
# the Linen dpo_loss_fn signature, so it slots into the same dispatcher
327+
# with reference_model passed as the single extra_dpo_args entry.
328+
ga_fn, ga_model, ga_params, ga_rng, ga_dpo = (dpo_loss_fn_nnx, state.model, None, None, [state.reference_model])
329+
else:
330+
ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, []
331331

332332
# --- Gradient computation ---
333333
if config.gradient_accumulation_steps > 1:
@@ -393,9 +393,14 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
393393
)
394394
nnx.update(state.model, curr_params)
395395

396+
# `ga_fn` and `ga_dpo` were set up earlier (loss_fn vs dpo_loss_fn_nnx;
397+
# ga_dpo carries the frozen reference_model when use_dpo, else empty).
398+
_nnx_loss_fn = ga_fn
399+
_nnx_extra_dpo_args = ga_dpo
400+
396401
def diff_wrapper(param, rest, config, data):
397402
local_model = nnx.merge(model_graphdef, param, rest, copy=True)
398-
loss, aux = loss_fn(local_model, config, data, None, None, is_train=True)
403+
loss, aux = _nnx_loss_fn(local_model, config, data, None, None, *_nnx_extra_dpo_args, is_train=True)
399404
_, _, new_rest = nnx.split(local_model, nnx.Param, ...)
400405
return loss, (aux, new_rest)
401406

@@ -581,7 +586,10 @@ def eval_step(model, config, state, data, dropout_rng=None):
581586
loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, sparsity_state=batch_stats)
582587
else:
583588
state = nnx.merge(model, state) # reconstruct TrainStateNNX
584-
loss, aux = loss_fn(state.model, config, data, None, None, is_train=False)
589+
if config.use_dpo:
590+
loss, aux = dpo_loss_fn_nnx(state.model, config, data, None, None, state.reference_model, is_train=False)
591+
else:
592+
loss, aux = loss_fn(state.model, config, data, None, None, is_train=False)
585593

586594
mtp_acceptance_rate = 0.0
587595
if config.mtp_eval_target_module > 0:
@@ -639,8 +647,8 @@ def train_loop(config, recorder, state=None):
639647
state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"])
640648
jit_model = model
641649
else:
642-
if config.use_dpo:
643-
raise NotImplementedError("DPO is not supported for NNX models.")
650+
# NNX keeps the DPO reference model as a sibling field on TrainStateNNX
651+
# (set up in init_state_fn), so no reference-param merge is needed here.
644652
jit_model, state = nnx.split(state)
645653

646654
params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)
@@ -709,7 +717,7 @@ def train_loop(config, recorder, state=None):
709717

710718
step_time_delta = datetime.datetime.now() - last_step_completion
711719

712-
state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0]
720+
state_to_save = state if not (config.use_dpo and not config.pure_nnx) else _split_dpo_state(state)[0]
713721
checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step)
714722

715723
if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step):
@@ -756,7 +764,7 @@ def train_loop(config, recorder, state=None):
756764
metric_logger_instance.buffer_and_write_metrics(metrics, step, step_time_delta)
757765

758766
if config.save_checkpoint_on_completion:
759-
state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0]
767+
state_to_save = state if not (config.use_dpo and not config.pure_nnx) else _split_dpo_state(state)[0]
760768
checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator)
761769
if checkpoint_manager is not None:
762770
# in case the last checkpoint_period checkpoint is still in progress

src/maxtext/utils/train_utils.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -225,10 +225,16 @@ def setup_train_loop(config, recorder, devices=None):
225225

226226
if config.pure_nnx:
227227
# For NNX, the train state is wrapped in the TrainStateNNX module.
228+
# When DPO is enabled, also materialize a frozen reference model alongside
229+
# the policy. Both are constructed by `_create_model_partial()` (which uses
230+
# `config.init_weights_seed`), so the reference starts identical to the
231+
# policy — standard DPO practice. The reference is later overwritten by
232+
# the step-0 checkpoint in `setup_post_setup_state` below.
228233
def create_train_state_fn():
229234
model = _create_model_partial()
230235
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
231-
return train_state_nnx.TrainStateNNX(model, optimizer)
236+
reference_model = _create_model_partial() if config.use_dpo else None
237+
return train_state_nnx.TrainStateNNX(model, optimizer, reference_model=reference_model)
232238

233239
init_state_fn = create_train_state_fn
234240
else:
@@ -316,8 +322,6 @@ def create_train_state_fn():
316322
maxtext_utils.print_shardings_params(state_params, state_mesh_shardings_params, mesh, logical_annotations_params)
317323

318324
if config.use_dpo:
319-
if config.pure_nnx:
320-
raise NotImplementedError("DPO is not supported yet by NNX models.")
321325
abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training)
322326
max_logging.log(
323327
"Restoring reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'"
@@ -342,9 +346,18 @@ def create_train_state_fn():
342346
except FileNotFoundError:
343347
step0_restored = None
344348
if step0_restored is not None:
345-
# TODO: For pure_nnx, the dpo state manipulation is different.
346-
reference_params = step0_restored["items"].params["params"]
347-
state = _merge_dpo_state(state, reference_params)
349+
if config.pure_nnx:
350+
# step0_restored["items"] is the flat nnx.State of the step-0 TrainStateNNX
351+
# (typically from a non-DPO pre-training run, so its top-level fields are
352+
# `model` and `optimizer` — no `reference_model`). Copy its `model` substate
353+
# into our current state's `reference_model` slot.
354+
step0_state = step0_restored["items"]
355+
step0_model_substate = step0_state["model"] if "model" in step0_state else step0_state
356+
if isinstance(state, nnx.State):
357+
state["reference_model"] = step0_model_substate
358+
else:
359+
reference_params = step0_restored["items"].params["params"]
360+
state = _merge_dpo_state(state, reference_params)
348361
else:
349362
max_logging.log(
350363
"Could not restore reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'"

0 commit comments

Comments
 (0)